summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-01-07 22:26:31 -0800
committerGitHub <noreply@github.com>2025-01-07 22:26:31 -0800
commitc43f6fa55aca23365c86c6ec1737d42be74d9d3e (patch)
tree2c49bc1dbd12ae5f46d682a3f240465931471060 /source
parent1a56f58fdd0c704e6dc0fad0f0ec33a25a35e60b (diff)
Lower varying parameters as pointers instead of SSA values. (#5919)
* Add executable test on matrix-typed vertex input. * Fix emit logic of matrix layout qualifier. * Pass fragment shader varying input by constref to allow EvaluateAttributeAtCentroid etc. to be implemented correctly.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang2
-rw-r--r--source/slang/glsl.meta.slang12
-rw-r--r--source/slang/hlsl.meta.slang205
-rw-r--r--source/slang/slang-ast-decl.h2
-rw-r--r--source/slang/slang-ast-dump.cpp20
-rw-r--r--source/slang/slang-check-decl.cpp15
-rw-r--r--source/slang/slang-check-expr.cpp14
-rw-r--r--source/slang/slang-emit-c-like.cpp40
-rw-r--r--source/slang/slang-emit-cpp.cpp1
-rw-r--r--source/slang/slang-emit-glsl.cpp12
-rw-r--r--source/slang/slang-emit-hlsl.cpp55
-rw-r--r--source/slang/slang-emit-hlsl.h3
-rw-r--r--source/slang/slang-emit.cpp11
-rw-r--r--source/slang/slang-ir-fix-entrypoint-callsite.cpp101
-rw-r--r--source/slang/slang-ir-fix-entrypoint-callsite.h9
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp337
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp71
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h1
-rw-r--r--source/slang/slang-ir-lower-buffer-element-type.cpp67
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp1
-rw-r--r--source/slang/slang-ir-resolve-varying-input-ref.cpp92
-rw-r--r--source/slang/slang-ir-resolve-varying-input-ref.h10
-rw-r--r--source/slang/slang-ir-translate-glsl-global-var.cpp6
-rw-r--r--source/slang/slang-ir-vk-invert-y.cpp13
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp3
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp72
-rw-r--r--source/slang/slang-parser.cpp11
29 files changed, 977 insertions, 214 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 625f8f608..adb7470dd 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1,3 +1,5 @@
+//public module core;
+
// Slang `core` library
// Aliases for base types
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index 2ff71a74e..ba26b5d84 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -9080,7 +9080,7 @@ public vector<float, N> fwidthCoarse(vector<float, N> p)
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public float interpolateAtCentroid(__ref float interpolant)
+public float interpolateAtCentroid(__constref float interpolant)
{
__target_switch
{
@@ -9099,7 +9099,7 @@ __generic<let N : int>
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant)
+public vector<float, N> interpolateAtCentroid(__constref vector<float, N> interpolant)
{
__target_switch
{
@@ -9118,7 +9118,7 @@ public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public float interpolateAtSample(__ref float interpolant, int sample)
+public float interpolateAtSample(__constref float interpolant, int sample)
{
__target_switch
{
@@ -9137,7 +9137,7 @@ __generic<let N : int>
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant, int sample)
+public vector<float, N> interpolateAtSample(__constref vector<float, N> interpolant, int sample)
{
__target_switch
{
@@ -9156,7 +9156,7 @@ public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant,
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public float interpolateAtOffset(__ref float interpolant, vec2 offset)
+public float interpolateAtOffset(__constref float interpolant, vec2 offset)
{
__target_switch
{
@@ -9175,7 +9175,7 @@ __generic<let N : int>
[__NoSideEffect]
[__GLSLRequireShaderInputParameter(0)]
[require(glsl_spirv, fragmentprocessing)]
-public vector<float, N> interpolateAtOffset(__ref vector<float, N> interpolant, vec2 offset)
+public vector<float, N> interpolateAtOffset(__constref vector<float, N> interpolant, vec2 offset)
{
__target_switch
{
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index d8f9845b8..d5b59427f 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -8101,89 +8101,139 @@ RasterizerOrderedStructuredBuffer<T> __getEquivalentStructuredBuffer<T>(Rasteriz
// Attribute evaluation
+// Target-specific backend for the public `EvaluateAttributeAtCentroid` overloads.
+// Emits a call to HLSL `EvaluateAttributeAtCentroid` or GLSL `interpolateAtCentroid`.
+// `x` is taken by `__constref`; the public overloads pass `__ResolveVaryingInputRef(x)` here.
+// Only the `hlsl` and `glsl` targets are handled; other targets are handled by the callers.
+T __EvaluateAttributeAtCentroid<T>(__constref T x)
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "EvaluateAttributeAtCentroid";
+ case glsl: __intrinsic_asm "interpolateAtCentroid";
+ }
+}
+
// TODO: The matrix cases of these functions won't actually work
// when compiled to GLSL, since they only support scalar/vector
// TODO: Should these be constrained to `__BuiltinFloatingPointType`?
// TODO: SPIRV-direct does not support non-floating-point types.
+/// Interpolates vertex attribute at centroid position.
+/// @param x The vertex attribute to interpolate.
+/// @return The interpolated attribute value.
+/// @remarks `x` must be a direct reference to a fragment shader varying input.
+/// @category interpolation Vertex Interpolation Functions
__generic<T : __BuiltinArithmeticType>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-T EvaluateAttributeAtCentroid(T x)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+T EvaluateAttributeAtCentroid(__constref T x)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtCentroid";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x));
case spirv: return spirv_asm {
- OpExtInst $$T result glsl450 InterpolateAtCentroid $x
+ OpCapability InterpolationFunction;
+ OpExtInst $$T result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x)
};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+vector<T,N> EvaluateAttributeAtCentroid(__constref vector<T,N> x)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtCentroid";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x));
case spirv: return spirv_asm {
- OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $x
+ OpCapability InterpolationFunction;
+ OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x)
};
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+matrix<T,N,M> EvaluateAttributeAtCentroid(__constref matrix<T,N,M> x)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtCentroid";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x));
default:
MATRIX_MAP_UNARY(T, N, M, EvaluateAttributeAtCentroid, x);
}
}
+// Target-specific backend for the public `EvaluateAttributeAtSample` overloads.
+// `x` is taken by `__constref`; the public overloads pass `__ResolveVaryingInputRef(x)` here.
+// Only the `hlsl` and `glsl` targets are handled; other targets are handled by the callers.
+T __EvaluateAttributeAtSample<T>(__constref T x, uint sampleIndex)
+{
+    __target_switch
+    {
+    case hlsl: __intrinsic_asm "EvaluateAttributeAtSample";
+    // GLSL `interpolateAtSample` takes an `int` sample index and GLSL has no
+    // implicit uint -> int conversion, so the cast must be spelled out.
+    case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))";
+    }
+}
+
+/// Interpolates vertex attribute at the current fragment sample position.
+/// @param x The vertex attribute to interpolate.
+/// @return The interpolated attribute value.
+/// @remarks `x` must be a direct reference to a fragment shader varying input.
+/// @category interpolation Vertex Interpolation Functions
__generic<T : __BuiltinArithmeticType>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-T EvaluateAttributeAtSample(T x, uint sampleindex)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+T EvaluateAttributeAtSample(__constref T x, uint sampleindex)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex);
case spirv: return spirv_asm {
- OpExtInst $$T result glsl450 InterpolateAtSample $x $sampleindex
+ OpCapability InterpolationFunction;
+ OpExtInst $$T result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex
};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+vector<T,N> EvaluateAttributeAtSample(__constref vector<T,N> x, uint sampleindex)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex);
case spirv: return spirv_asm {
- OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $x $sampleindex
+ OpCapability InterpolationFunction;
+ OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex
};
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+matrix<T,N,M> EvaluateAttributeAtSample(__constref matrix<T,N,M> x, uint sampleindex)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex);
default:
matrix<T,N,M> result;
for(int i = 0; i < N; ++i)
@@ -8194,21 +8244,59 @@ matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex)
}
}
+// Target-specific backend for the public `EvaluateAttributeSnapped` overloads.
+// `x` is taken by `__constref`; the public overloads pass `__ResolveVaryingInputRef(x)` here.
+// `offset` uses the HLSL encoding: each component is a signed value in 1/16 pixel units.
+// Only the `hlsl` and `glsl` targets are handled; other targets are handled by the callers.
+T __EvaluateAttributeSnapped<T>(__constref T x, int2 offset)
+{
+    __target_switch
+    {
+    case hlsl: __intrinsic_asm "EvaluateAttributeSnapped";
+    // GLSL has no `EvaluateAttributeSnapped`; the equivalent builtin is
+    // `interpolateAtOffset`, which takes a floating-point pixel offset, so the
+    // 1/16-unit integer offset is converted and scaled.
+    case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)";
+    }
+}
+
+/// Interpolates vertex attribute at the specified subpixel offset.
+/// @param x The vertex attribute to interpolate.
+/// @param offset The subpixel offset. Each component is a 4-bit signed integer in range [-8, 7].
+/// @return The interpolated attribute value.
+/// @remarks `x` must be a direct reference to a fragment shader varying input.
+///
+/// The valid values of each component of `offset` are:
+///
+/// - 1000 = -0.5f (-8 / 16)
+/// - 1001 = -0.4375f (-7 / 16)
+/// - 1010 = -0.375f (-6 / 16)
+/// - 1011 = -0.3125f (-5 / 16)
+/// - 1100 = -0.25f (-4 / 16)
+/// - 1101 = -0.1875f (-3 / 16)
+/// - 1110 = -0.125f (-2 / 16)
+/// - 1111 = -0.0625f (-1 / 16)
+/// - 0000 = 0.0f ( 0 / 16)
+/// - 0001 = 0.0625f ( 1 / 16)
+/// - 0010 = 0.125f ( 2 / 16)
+/// - 0011 = 0.1875f ( 3 / 16)
+/// - 0100 = 0.25f ( 4 / 16)
+/// - 0101 = 0.3125f ( 5 / 16)
+/// - 0110 = 0.375f ( 6 / 16)
+/// - 0111 = 0.4375f ( 7 / 16)
+/// @category interpolation Vertex Interpolation Functions
__generic<T : __BuiltinArithmeticType>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-T EvaluateAttributeSnapped(T x, int2 offset)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+T EvaluateAttributeSnapped(__constref T x, int2 offset)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset);
case spirv:
{
const float2 tmp = float2(16.f, 16.f);
return spirv_asm {
+ OpCapability InterpolationFunction;
%foffset:$$float2 = OpConvertSToF $offset;
%offsetdiv16:$$float2 = OpFDiv %foffset $tmp;
- result:$$T = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16
+ result:$$T = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16
};
}
}
@@ -8216,19 +8304,23 @@ T EvaluateAttributeSnapped(T x, int2 offset)
__generic<T : __BuiltinArithmeticType, let N : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+vector<T,N> EvaluateAttributeSnapped(__constref vector<T,N> x, int2 offset)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset);
case spirv:
{
const float2 tmp = float2(16.f, 16.f);
return spirv_asm {
+ OpCapability InterpolationFunction;
%foffset:$$float2 = OpConvertSToF $offset;
%offsetdiv16:$$float2 = OpFDiv %foffset $tmp;
- result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16
+ result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16
};
}
}
@@ -8236,12 +8328,15 @@ vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset)
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
[__readNone]
-[require(glsl_spirv, fragmentprocessing)]
-matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset)
+[__unsafeForceInlineEarly]
+[require(glsl_hlsl_spirv, fragmentprocessing)]
+matrix<T,N,M> EvaluateAttributeSnapped(__constref matrix<T,N,M> x, int2 offset)
{
__target_switch
{
- case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)";
+ case hlsl:
+ case glsl:
+ return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset);
default:
matrix<T,N,M> result;
for(int i = 0; i < N; ++i)
@@ -9243,8 +9338,16 @@ matrix<T, N, M> fwidth(matrix<T, N, M> x)
}
}
+// Intrinsic that lowers to the `kIROp_ResolveVaryingInputRef` IR op.
+// Takes a `__constref` attribute and yields a `Ref<T>`; the attribute-evaluation
+// functions in this file wrap their varying-input argument with it before use.
+__intrinsic_op($(kIROp_ResolveVaryingInputRef))
+Ref<T> __ResolveVaryingInputRef<T>(__constref T attribute);
+
__intrinsic_op($(kIROp_GetPerVertexInputArray))
-Array<T, 3> __GetPerVertexInputArray<T>(T attribute);
+Ref<Array<T, 3>> __GetPerVertexInputArray<T>(__constref T attribute);
+
+// Backend for the `hlsl` case of the public `GetAttributeAtVertex` overloads.
+// Emits a call to the `GetAttributeAtVertex` intrinsic with `attribute` passed by `__constref`.
+T __GetAttributeAtVertex<T>(__constref T attribute, uint vertexIndex)
+{
+ __intrinsic_asm "GetAttributeAtVertex";
+}
/// Get the value of a vertex attribute at a specific vertex.
///
@@ -9265,15 +9368,15 @@ __glsl_extension(GL_EXT_fragment_shader_barycentric)
[require(glsl_hlsl_spirv, getattributeatvertex)]
[KnownBuiltin("GetAttributeAtVertex")]
[__unsafeForceInlineEarly]
-T GetAttributeAtVertex(T attribute, uint vertexIndex)
+T GetAttributeAtVertex(__constref T attribute, uint vertexIndex)
{
__target_switch
{
case hlsl:
- __intrinsic_asm "GetAttributeAtVertex";
+ return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex);
case glsl:
case spirv:
- return __GetPerVertexInputArray(attribute)[vertexIndex];
+ return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex];
}
}
@@ -9294,20 +9397,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_version(450)
__glsl_extension(GL_EXT_fragment_shader_barycentric)
[require(glsl_hlsl_spirv, getattributeatvertex)]
-vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex)
+[__unsafeForceInlineEarly]
+vector<T,N> GetAttributeAtVertex(__constref vector<T,N> attribute, uint vertexIndex)
{
__target_switch
{
case hlsl:
- __intrinsic_asm "GetAttributeAtVertex";
- case glsl:
- __intrinsic_asm "$0[$1]";
+ return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex);
+ case glsl:
case spirv:
- return spirv_asm {
- %_ptr_Input_vectorT = OpTypePointer Input $$vector<T,N>;
- %addr = OpAccessChain %_ptr_Input_vectorT $attribute $vertexIndex;
- result:$$vector<T,N> = OpLoad %addr;
- };
+ return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex];
}
}
@@ -9328,20 +9427,16 @@ __generic<T : __BuiltinType, let N : int, let M : int>
__glsl_version(450)
__glsl_extension(GL_EXT_fragment_shader_barycentric)
[require(glsl_hlsl_spirv, getattributeatvertex)]
-matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex)
+[__unsafeForceInlineEarly]
+matrix<T,N,M> GetAttributeAtVertex(__constref matrix<T,N,M> attribute, uint vertexIndex)
{
__target_switch
{
case hlsl:
- __intrinsic_asm "GetAttributeAtVertex";
- case glsl:
- __intrinsic_asm "$0[$1]";
+ return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex);
+ case glsl:
case spirv:
- return spirv_asm {
- %_ptr_Input_matrixT = OpTypePointer Input $$matrix<T,N,M>;
- %addr = OpAccessChain %_ptr_Input_matrixT $attribute $vertexIndex;
- result:$$matrix<T,N,M> = OpLoad %addr;
- };
+ return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex];
}
}
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index e8886a59a..911455f17 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -461,6 +461,8 @@ class ModuleDecl : public NamespaceDeclBase
/// `__implementing` etc.
bool isInLegacyLanguage = true;
+ DeclVisibility defaultVisibility = DeclVisibility::Internal;
+
SLANG_UNREFLECTED
/// Map a type to the list of extensions of that type (if any) declared in this module
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index b77003bff..85d2d0d9f 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -460,6 +460,26 @@ struct ASTDumpContext
}
}
+    /// Emits the textual form of a declaration visibility level to the dump writer.
+    /// The three named levels are written as their keywords; any other value is
+    /// written as its raw integer so that it still shows up in the dump.
+    void dump(DeclVisibility vis)
+    {
+        switch (vis)
+        {
+        case DeclVisibility::Private:
+            m_writer->emit("private");
+            return;
+        case DeclVisibility::Internal:
+            m_writer->emit("internal");
+            return;
+        case DeclVisibility::Public:
+            m_writer->emit("public");
+            return;
+        default:
+            // Not one of the named levels: fall back to the numeric value.
+            m_writer->emit(String((int)vis).getUnownedSlice());
+            return;
+        }
+    }
+
+
void dump(const QualType& qualType)
{
if (qualType.isLeftValue)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index ce3f1e64c..3667a36ba 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -3186,6 +3186,11 @@ void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl)
}
}
+ if (moduleDecl->findModifier<PublicModifier>())
+ {
+ moduleDecl->defaultVisibility = DeclVisibility::Public;
+ }
+
// We need/want to visit any `import` declarations before
// anything else, to make sure that scoping works.
//
@@ -12604,8 +12609,10 @@ DeclVisibility getDeclVisibility(Decl* decl)
}
auto defaultVis = DeclVisibility::Default;
if (auto parentModule = getModuleDecl(decl))
- defaultVis =
- parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal;
+ {
+ defaultVis = parentModule->isInLegacyLanguage ? DeclVisibility::Public
+ : parentModule->defaultVisibility;
+ }
// Members of other agg type decls will have their default visibility capped to the parents'.
if (as<NamespaceDecl>(decl))
@@ -12790,10 +12797,6 @@ void diagnoseCapabilityProvenance(
auto moduleDecl = getModuleDecl(declToPrint);
if (thisModule != moduleDecl)
break;
- if (moduleDecl && moduleDecl->isInLegacyLanguage)
- continue;
- if (getDeclVisibility(declToPrint) == DeclVisibility::Public)
- break;
}
if (previousDecl == declToPrint)
break;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 83b668a33..95d5a2a7c 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -4099,6 +4099,15 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas
elementType = QualType(ptrType->getValueType());
elementType.isLeftValue = true;
}
+ else
+ {
+ auto newExpr = maybeOpenRef(expr);
+ if (newExpr != expr)
+ {
+ expr = newExpr;
+ continue;
+ }
+ }
if (elementType.type)
{
auto derefExpr = m_astBuilder->create<DerefExpr>();
@@ -4108,9 +4117,10 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas
expr = derefExpr;
continue;
}
- // Default case: just use the expression as-is
- return expr;
+ break;
}
+ // Default case: just use the expression as-is
+ return expr;
}
Expr* SemanticsVisitor::CheckMatrixSwizzleExpr(
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index 3175f1b07..64cc9969c 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1537,6 +1537,30 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
return true;
}
+ if (auto load = as<IRLoad>(inst))
+ {
+ // Loads from a constref global param should always be folded.
+ auto ptrType = load->getPtr()->getDataType();
+ if (load->getPtr()->getOp() == kIROp_GlobalParam)
+ {
+ if (ptrType->getOp() == kIROp_ConstRefType)
+ return true;
+ if (auto ptrTypeBase = as<IRPtrTypeBase>(ptrType))
+ {
+ auto addrSpace = ptrTypeBase->getAddressSpace();
+ switch (addrSpace)
+ {
+ case Slang::AddressSpace::Uniform:
+ case Slang::AddressSpace::Input:
+ case Slang::AddressSpace::BuiltinInput:
+ return true;
+ default:
+ break;
+ }
+ }
+ }
+ }
+
// Always hold if inst is a call into an [__alwaysFoldIntoUseSite] function.
if (auto call = as<IRCall>(inst))
{
@@ -4709,9 +4733,21 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
auto rawType = varDecl->getDataType();
auto varType = rawType;
- if (auto outType = as<IROutTypeBase>(varType))
+ if (auto ptrType = as<IRPtrTypeBase>(varType))
{
- varType = outType->getValueType();
+ switch (ptrType->getAddressSpace())
+ {
+ case AddressSpace::Input:
+ case AddressSpace::Output:
+ case AddressSpace::BuiltinInput:
+ case AddressSpace::BuiltinOutput:
+ varType = ptrType->getValueType();
+ break;
+ default:
+ if (as<IROutTypeBase>(ptrType))
+ varType = ptrType->getValueType();
+ break;
+ }
}
if (as<IRVoidType>(varType))
return;
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index f91c4d06e..b0d1fbb4c 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -310,6 +310,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
}
case kIROp_NativePtrType:
case kIROp_PtrType:
+ case kIROp_ConstRefType:
{
auto elementType = (IRType*)type->getOperand(0);
SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out));
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index a863e7eb1..23fff37ac 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -3166,6 +3166,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
return;
}
+ case kIROp_ConstRefType:
+ {
+ emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType());
+ return;
+ }
default:
break;
}
@@ -3471,15 +3476,18 @@ void GLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* varType)
//
auto matrixType = as<IRMatrixType>(unwrapArray(varType));
-
if (matrixType)
{
+ auto layout = getIntVal(matrixType->getLayout());
+ if (layout == getTargetProgram()->getOptionSet().getMatrixLayoutMode())
+ return;
+
// Reminder: the meaning of row/column major layout
// in our semantics is the *opposite* of what GLSL
// calls them, because what they call "columns"
// are what we call "rows."
//
- switch (getIntVal(matrixType->getLayout()))
+ switch (layout)
{
case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR:
m_writer->emit("layout(row_major)\n");
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 40d6f75d9..83eec17b4 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -1180,6 +1180,34 @@ void HLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
Super::emitSimpleValueImpl(inst);
}
+/// Emits `type` followed by `declarator`, prefixing an explicit matrix layout
+/// qualifier when one is needed.
+///
+/// HLSL only allows a matrix layout modifier when declaring a variable or struct
+/// field, so the qualifier is considered only when a declarator is present. It is
+/// written only if the matrix type's layout differs from the target program's
+/// matrix layout mode. The base-class implementation then emits the type and
+/// declarator as usual.
+void HLSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator)
+{
+    // Without a declarator there is nothing to qualify, so skip the matrix check entirely.
+    if (auto matType = declarator ? as<IRMatrixType>(type) : nullptr)
+    {
+        const auto matrixLayout = getIntVal(matType->getLayout());
+        const bool matchesTargetDefault =
+            getTargetProgram()->getOptionSet().getMatrixLayoutMode() ==
+            (MatrixLayoutMode)matrixLayout;
+        if (!matchesTargetDefault)
+        {
+            if (matrixLayout == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
+                m_writer->emit("column_major ");
+            else if (matrixLayout == SLANG_MATRIX_LAYOUT_ROW_MAJOR)
+                m_writer->emit("row_major ");
+            // Any other layout value gets no qualifier.
+        }
+    }
+    Super::emitSimpleTypeAndDeclaratorImpl(type, declarator);
+}
+
void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
{
switch (type->getOp())
@@ -1313,6 +1341,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
return;
}
+ case kIROp_ConstRefType:
+ {
+ emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType());
+ return;
+ }
default:
break;
}
@@ -1671,28 +1704,6 @@ void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
}
}
-void HLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* type)
-{
- auto matType = as<IRMatrixType>(type);
- if (!matType)
- return;
- auto matrixLayout = getIntVal(matType->getLayout());
- if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != (MatrixLayoutMode)matrixLayout)
- {
- switch (matrixLayout)
- {
- case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR:
- m_writer->emit("column_major ");
- break;
- case SLANG_MATRIX_LAYOUT_ROW_MAJOR:
- m_writer->emit("row_major ");
- break;
- default:
- break;
- }
- }
-}
-
void HLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
{
if (inst->findDecoration<IRRequiresNVAPIDecoration>())
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index b2e2ca05a..6b99a7f50 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -55,11 +55,12 @@ protected:
IRPackOffsetDecoration* decoration) SLANG_OVERRIDE;
virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
+ virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator)
+ SLANG_OVERRIDE;
virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
- virtual void emitMatrixLayoutModifiersImpl(IRType* varType) SLANG_OVERRIDE;
virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE
{
emitMatrixLayoutModifiersImpl(type);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 7ea2fef88..b9217de41 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -42,6 +42,7 @@
#include "slang-ir-entry-point-uniforms.h"
#include "slang-ir-explicit-global-context.h"
#include "slang-ir-explicit-global-init.h"
+#include "slang-ir-fix-entrypoint-callsite.h"
#include "slang-ir-fuse-satcoop.h"
#include "slang-ir-glsl-legalize.h"
#include "slang-ir-glsl-liveness.h"
@@ -76,6 +77,7 @@
#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-resolve-texture-format.h"
+#include "slang-ir-resolve-varying-input-ref.h"
#include "slang-ir-restructure-scoping.h"
#include "slang-ir-restructure.h"
#include "slang-ir-sccp.h"
@@ -314,6 +316,7 @@ struct RequiredLoweringPassSet
bool glslSSBO;
bool byteAddressBuffer;
bool dynamicResource;
+ bool resolveVaryingInputRef;
};
// Scan the IR module and determine which lowering/legalization passes are needed based
@@ -423,6 +426,9 @@ void calcRequiredLoweringPassSet(
case kIROp_DynamicResourceType:
result.dynamicResource = true;
break;
+ case kIROp_ResolveVaryingInputRef:
+ result.resolveVaryingInputRef = true;
+ break;
}
if (!result.generics || !result.existentialTypeLayout)
{
@@ -591,6 +597,11 @@ Result linkAndOptimizeIR(
if (requiredLoweringPassSet.glslGlobalVar)
translateGLSLGlobalVar(codeGenContext, irModule);
+ if (requiredLoweringPassSet.resolveVaryingInputRef)
+ resolveVaryingInputRef(irModule);
+
+ fixEntryPointCallsites(irModule);
+
// Replace any global constants with their values.
//
replaceGlobalConstants(irModule);
diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp
new file mode 100644
index 000000000..7390f3a7f
--- /dev/null
+++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp
@@ -0,0 +1,101 @@
+#include "slang-ir-fix-entrypoint-callsite.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+// If the entrypoint is called by some other function, we need to clone the
+// entrypoint and replace the callsites to call the cloned entrypoint instead.
+// This is because we will be modifying the signature of the entrypoint during
+// entrypoint legalization to rewrite the way system values are passed in.
+// By replacing the callsites to call the cloned entrypoint, which acts as an
+// ordinary function, we no longer need to worry about changing the callsites
+// when we legalize the entry-points.
+//
+// The clone is created lazily, so an entry point that is never called is left
+// untouched. Only uses where the entry point is the *callee* of an `IRCall`
+// are redirected; any other use (including the entry point appearing as a
+// call argument) is left alone.
+//
+void fixEntryPointCallsites(IRFunc* entryPoint)
+{
+    IRFunc* clonedEntryPointForCall = nullptr;
+    auto ensureClonedEntryPointForCall = [&]() -> IRFunc*
+    {
+        if (clonedEntryPointForCall)
+            return clonedEntryPointForCall;
+        IRCloneEnv cloneEnv;
+        IRBuilder builder(entryPoint);
+        builder.setInsertBefore(entryPoint);
+        clonedEntryPointForCall = (IRFunc*)cloneInst(&cloneEnv, &builder, entryPoint);
+        // Remove entrypoint and linkage decorations from the cloned callee.
+        // They are collected first so the decoration list is not mutated while iterating.
+        List<IRInst*> decorsToRemove;
+        for (auto decor : clonedEntryPointForCall->getDecorations())
+        {
+            switch (decor->getOp())
+            {
+            case kIROp_ExportDecoration:
+            case kIROp_UserExternDecoration:
+            case kIROp_HLSLExportDecoration:
+            case kIROp_EntryPointDecoration:
+            case kIROp_LayoutDecoration:
+            case kIROp_NumThreadsDecoration:
+            case kIROp_ImportDecoration:
+            case kIROp_ExternCDecoration:
+            case kIROp_ExternCppDecoration:
+                decorsToRemove.add(decor);
+                break;
+            }
+        }
+        for (auto decor : decorsToRemove)
+            decor->removeAndDeallocate();
+        return clonedEntryPointForCall;
+    };
+    traverseUses(
+        entryPoint,
+        [&](IRUse* use)
+        {
+            auto user = use->getUser();
+            auto call = as<IRCall>(user);
+            if (!call)
+                return;
+            // Only redirect the call when the entry point is what is being called.
+            // If it is merely passed as an argument, overwriting operand 0 would
+            // clobber an unrelated callee.
+            if (call->getCallee() != entryPoint)
+                return;
+            auto callee = ensureClonedEntryPointForCall();
+            call->setOperand(0, callee);
+
+            // Fix up argument types: if the callee entrypoint is expecting a constref
+            // and the caller is passing a value, we need to wrap the value in a temporary var
+            // and pass the temporary var.
+            //
+            auto funcType = as<IRFuncType>(callee->getDataType());
+            SLANG_ASSERT(funcType);
+            IRBuilder builder(call);
+            builder.setInsertBefore(call);
+            List<IRParam*> params;
+            for (auto param : callee->getParams())
+                params.add(param);
+            // A parameter/argument count mismatch means the arguments cannot be
+            // paired up, so leave them as they are.
+            if ((UInt)params.getCount() != call->getArgCount())
+                return;
+            for (UInt i = 0; i < call->getArgCount(); i++)
+            {
+                auto paramType = params[i]->getDataType();
+                auto arg = call->getArg(i);
+                if (auto refType = as<IRConstRefType>(paramType))
+                {
+                    if (!as<IRPtrTypeBase>(arg->getDataType()))
+                    {
+                        auto tempVar = builder.emitVar(refType->getValueType());
+                        builder.emitStore(tempVar, arg);
+                        call->setArg(i, tempVar);
+                    }
+                }
+            }
+        });
+}
+
+// Runs the per-entry-point callsite fix on every global function in `module`
+// that carries an `IREntryPointDecoration`.
+void fixEntryPointCallsites(IRModule* module)
+{
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        if (!globalInst->findDecoration<IREntryPointDecoration>())
+            continue;
+        // Use a checked cast: a decorated global that is not an `IRFunc`
+        // must not be reinterpreted as one.
+        if (auto entryPoint = as<IRFunc>(globalInst))
+            fixEntryPointCallsites(entryPoint);
+    }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.h b/source/slang/slang-ir-fix-entrypoint-callsite.h
new file mode 100644
index 000000000..493d67a77
--- /dev/null
+++ b/source/slang/slang-ir-fix-entrypoint-callsite.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+void fixEntryPointCallsites(IRModule* module);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 09bf245df..39f970319 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -67,11 +67,7 @@ struct ScalarizedValImpl : RefObject
};
struct ScalarizedTupleValImpl;
struct ScalarizedTypeAdapterValImpl;
-
-struct ScalarizedArrayIndexValImpl : ScalarizedValImpl
-{
- Index index;
-};
+struct ScalarizedArrayIndexValImpl;
struct ScalarizedVal
{
@@ -132,15 +128,12 @@ struct ScalarizedVal
result.impl = (ScalarizedValImpl*)impl;
return result;
}
- static ScalarizedVal scalarizedArrayIndex(IRInst* irValue, Index index)
+ static ScalarizedVal scalarizedArrayIndex(ScalarizedArrayIndexValImpl* impl)
{
ScalarizedVal result;
result.flavor = Flavor::arrayIndex;
- auto impl = new ScalarizedArrayIndexValImpl;
- impl->index = index;
-
- result.irValue = irValue;
- result.impl = impl;
+ result.irValue = nullptr;
+ result.impl = (ScalarizedValImpl*)impl;
return result;
}
@@ -151,8 +144,6 @@ struct ScalarizedVal
RefPtr<ScalarizedValImpl> impl;
};
-IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val);
-
// This is the case for a value that is a "tuple" of other values
struct ScalarizedTupleValImpl : ScalarizedValImpl
{
@@ -175,6 +166,36 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl
IRType* pretendType; // the type this value pretends to have
};
+struct ScalarizedArrayIndexValImpl : ScalarizedValImpl // payload of a `Flavor::arrayIndex` value
+{
+ ScalarizedVal arrayVal; // the array value that is being indexed
+ Index index; // index of the element within `arrayVal`
+ IRType* elementType; // the type of the indexed element
+};
+
+ScalarizedVal extractField(
+ IRBuilder* builder,
+ ScalarizedVal const& val,
+ UInt fieldIndex, // Pass ~0 in to search for the index via the key
+ IRStructKey* fieldKey);
+ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* fromType);
+ScalarizedVal adaptType(
+ IRBuilder* builder,
+ ScalarizedVal const& val,
+ IRType* toType,
+ IRType* fromType);
+IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val);
+ScalarizedVal getSubscriptVal(
+ IRBuilder* builder,
+ IRType* elementType,
+ ScalarizedVal val,
+ IRInst* indexVal);
+ScalarizedVal getSubscriptVal(
+ IRBuilder* builder,
+ IRType* elementType,
+ ScalarizedVal val,
+ UInt index);
+
struct GlobalVaryingDeclarator
{
enum class Flavor
@@ -1303,6 +1324,22 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
}
}
+ AddressSpace addrSpace = AddressSpace::Uniform;
+ IROp ptrOpCode = kIROp_PtrType;
+ switch (kind)
+ {
+ case LayoutResourceKind::VaryingInput:
+ addrSpace = systemValueInfo ? AddressSpace::BuiltinInput : AddressSpace::Input;
+ break;
+ case LayoutResourceKind::VaryingOutput:
+ addrSpace = systemValueInfo ? AddressSpace::BuiltinOutput : AddressSpace::Output;
+ ptrOpCode = kIROp_OutType;
+ break;
+ default:
+ break;
+ }
+
+
// If we have a declarator, we just use the normal logic, as that seems to work correctly
//
if (systemValueInfo && systemValueInfo->arrayIndex >= 0 && declarator == nullptr)
@@ -1339,9 +1376,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
// Set the array size to 0, to mean it is unsized
auto arrayType = builder->getArrayType(type, 0);
- IRType* paramType = kind == LayoutResourceKind::VaryingOutput
- ? (IRType*)builder->getOutType(arrayType)
- : arrayType;
+ IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace);
auto globalParam = addGlobalParam(builder->getModule(), paramType);
moveValueBefore(globalParam, builder->getFunc());
@@ -1371,9 +1406,12 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
semanticGlobal->addIndex(systemValueInfo->arrayIndex);
// Make it an array index
- ScalarizedVal val = ScalarizedVal::scalarizedArrayIndex(
- semanticGlobal->globalParam,
- systemValueInfo->arrayIndex);
+ ScalarizedVal val = ScalarizedVal::address(semanticGlobal->globalParam);
+ RefPtr<ScalarizedArrayIndexValImpl> arrayImpl = new ScalarizedArrayIndexValImpl();
+ arrayImpl->arrayVal = val;
+ arrayImpl->index = systemValueInfo->arrayIndex;
+ arrayImpl->elementType = type;
+ val = ScalarizedVal::scalarizedArrayIndex(arrayImpl);
// We need to make this access, an array access to the global
if (auto fromType = systemValueInfo->requiredType)
@@ -1466,14 +1504,14 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
// like our IR function parameters, and need a wrapper
// `Out<...>` type to represent outputs.
//
- bool isOutput = (kind == LayoutResourceKind::VaryingOutput);
- IRType* paramType = isOutput ? builder->getOutType(type) : type;
+
+ // Non system value varying inputs shall be passed as pointers.
+ IRType* paramType = builder->getPtrType(ptrOpCode, type, addrSpace);
auto globalParam = addGlobalParam(builder->getModule(), paramType);
moveValueBefore(globalParam, builder->getFunc());
- ScalarizedVal val =
- isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam);
+ ScalarizedVal val = ScalarizedVal::address(globalParam);
if (systemValueInfo)
{
@@ -1958,10 +1996,10 @@ ScalarizedVal adaptType(
break;
case ScalarizedVal::Flavor::arrayIndex:
{
- auto element = builder->emitElementExtract(
- val.irValue,
- as<ScalarizedArrayIndexValImpl>(val.impl)->index);
- return adaptType(builder, element, toType, fromType);
+ auto arrayImpl = as<ScalarizedArrayIndexValImpl>(val.impl);
+ auto elementVal =
+ getSubscriptVal(builder, fromType, arrayImpl->arrayVal, arrayImpl->index);
+ return adaptType(builder, elementVal, toType, fromType);
}
break;
default:
@@ -1970,8 +2008,6 @@ ScalarizedVal adaptType(
}
}
-IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val);
-
void assign(
IRBuilder* builder,
ScalarizedVal const& left,
@@ -1988,16 +2024,12 @@ void assign(
// Determine the index
auto leftArrayIndexVal = as<ScalarizedArrayIndexValImpl>(left.impl);
- const auto arrayIndex = leftArrayIndexVal->index;
-
- auto arrayIndexInst = builder->getIntValue(builder->getIntType(), arrayIndex);
-
- // Store to the index
- auto address = builder->emitElementAddress(
- builder->getPtrType(right.irValue->getFullType()),
- left.irValue,
- arrayIndexInst);
- builder->emitStore(address, rhs);
+ auto leftVal = getSubscriptVal(
+ builder,
+ leftArrayIndexVal->elementType,
+ leftArrayIndexVal->arrayVal,
+ leftArrayIndexVal->index);
+ builder->emitStore(leftVal.irValue, rhs);
break;
}
@@ -2236,10 +2268,10 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val)
case ScalarizedVal::Flavor::arrayIndex:
{
- auto element = builder->emitElementExtract(
- val.irValue,
- as<ScalarizedArrayIndexValImpl>(val.impl)->index);
- return element;
+ auto impl = as<ScalarizedArrayIndexValImpl>(val.impl);
+ auto elementVal =
+ getSubscriptVal(builder, impl->elementType, impl->arrayVal, impl->index);
+ return materializeValue(builder, elementVal);
}
case ScalarizedVal::Flavor::tuple:
{
@@ -2735,9 +2767,9 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst*
IRBuilder builder(inputVertexAttr);
builder.setInsertBefore(inputVertexAttr);
auto arrayType = builder.getArrayType(
- inputVertexAttr->getDataType(),
+ tryGetPointedToType(&builder, inputVertexAttr->getDataType()),
builder.getIntValue(builder.getIntType(), 3));
- arrayInst = builder.createGlobalParam(arrayType);
+ arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input));
context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst;
builder.addDecoration(arrayInst, kIROp_PerVertexDecoration);
@@ -2770,21 +2802,105 @@ void tryReplaceUsesOfStageInput(
[&](IRUse* use)
{
auto user = use->getUser();
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ builder.replaceOperand(use, val.irValue);
+ });
+ }
+ break;
+ case ScalarizedVal::Flavor::address:
+ {
+ bool needMaterialize = false;
+ if (as<IRPtrTypeBase>(val.irValue->getDataType()))
+ {
+ if (!as<IRPtrTypeBase>(originalVal->getDataType()))
+ {
+ needMaterialize = true;
+ }
+ }
+ traverseUses(
+ originalVal,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
if (user->getOp() == kIROp_GetPerVertexInputArray)
{
auto arrayInst = getOrCreatePerVertexInputArray(context, val.irValue);
user->replaceUsesWith(arrayInst);
user->removeAndDeallocate();
+ return;
+ }
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ if (needMaterialize)
+ {
+ auto materializedVal = materializeValue(&builder, val);
+ builder.replaceOperand(use, materializedVal);
}
else
{
- IRBuilder builder(user);
- builder.setInsertBefore(user);
builder.replaceOperand(use, val.irValue);
}
});
}
break;
+ case ScalarizedVal::Flavor::typeAdapter:
+ {
+ traverseUses(
+ originalVal,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(val.impl);
+ auto materializedInner = materializeValue(&builder, typeAdapter->val);
+ auto adapted = adaptType(
+ &builder,
+ materializedInner,
+ typeAdapter->pretendType,
+ typeAdapter->actualType);
+ if (user->getOp() == kIROp_Load)
+ {
+ user->replaceUsesWith(adapted.irValue);
+ user->removeAndDeallocate();
+ }
+ else
+ {
+ use->set(adapted.irValue);
+ }
+ });
+ }
+ break;
+ case ScalarizedVal::Flavor::arrayIndex:
+ {
+ traverseUses(
+ originalVal,
+ [&](IRUse* use)
+ {
+ auto arrayIndexImpl = as<ScalarizedArrayIndexValImpl>(val.impl);
+ auto user = use->getUser();
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ auto subscriptVal = getSubscriptVal(
+ &builder,
+ arrayIndexImpl->elementType,
+ arrayIndexImpl->arrayVal,
+ arrayIndexImpl->index);
+ builder.setInsertBefore(user);
+ auto materializedInner = materializeValue(&builder, subscriptVal);
+ if (user->getOp() == kIROp_Load)
+ {
+ user->replaceUsesWith(materializedInner);
+ user->removeAndDeallocate();
+ }
+ else
+ {
+ use->set(materializedInner);
+ }
+ });
+ break;
+ }
case ScalarizedVal::Flavor::tuple:
{
auto tupleVal = as<ScalarizedTupleValImpl>(val.impl);
@@ -2793,22 +2909,36 @@ void tryReplaceUsesOfStageInput(
[&](IRUse* use)
{
auto user = use->getUser();
- if (auto fieldExtract = as<IRFieldExtract>(user))
+ switch (user->getOp())
{
- auto fieldKey = fieldExtract->getField();
- ScalarizedVal fieldVal;
- for (auto element : tupleVal->elements)
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
{
- if (element.key == fieldKey)
+ auto fieldKey = user->getOperand(1);
+ ScalarizedVal fieldVal;
+ for (auto element : tupleVal->elements)
{
- fieldVal = element.val;
- break;
+ if (element.key == fieldKey)
+ {
+ fieldVal = element.val;
+ break;
+ }
+ }
+ if (fieldVal.flavor != ScalarizedVal::Flavor::none)
+ {
+ tryReplaceUsesOfStageInput(context, fieldVal, user);
}
}
- if (fieldVal.flavor != ScalarizedVal::Flavor::none)
+ break;
+ case kIROp_Load:
{
- tryReplaceUsesOfStageInput(context, fieldVal, user);
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ auto materializedVal = materializeTupleValue(&builder, val);
+ user->replaceUsesWith(materializedVal);
+ user->removeAndDeallocate();
}
+ break;
}
});
}
@@ -3066,7 +3196,7 @@ void legalizeEntryPointParameterForGLSL(
// We are going to create a local variable of the appropriate
// type, which will replace the parameter, along with
// one or more global variables for the actual input/output.
-
+ setInsertAfterOrdinaryInst(builder, pp);
auto localVariable = builder->emitVar(valueType);
auto localVal = ScalarizedVal::address(localVariable);
@@ -3135,6 +3265,73 @@ void legalizeEntryPointParameterForGLSL(
assign(&terminatorBuilder, globalOutputVal, localVal);
}
}
+ else if (auto ptrType = as<IRPtrTypeBase>(paramType))
+ {
+ // This is the case where the parameter is passed by const
+ // reference. We simply replace existing uses of the parameter
+ // with the real global variable.
+ SLANG_ASSERT(
+ ptrType->getOp() == kIROp_ConstRefType ||
+ ptrType->getAddressSpace() == AddressSpace::Input ||
+ ptrType->getAddressSpace() == AddressSpace::BuiltinInput);
+
+ auto globalValue = createGLSLGlobalVaryings(
+ context,
+ codeGenContext,
+ builder,
+ valueType,
+ paramLayout,
+ LayoutResourceKind::VaryingInput,
+ stage,
+ pp);
+ tryReplaceUsesOfStageInput(context, globalValue, pp);
+ for (auto dec : pp->getDecorations())
+ {
+ if (dec->getOp() != kIROp_GlobalVariableShadowingGlobalParameterDecoration)
+ continue;
+ auto globalVar = dec->getOperand(0);
+ auto key = dec->getOperand(1);
+ IRInst* realGlobalVar = nullptr;
+ if (globalValue.flavor != ScalarizedVal::Flavor::tuple)
+ continue;
+ if (auto tupleVal = as<ScalarizedTupleValImpl>(globalValue.impl))
+ {
+ for (auto elem : tupleVal->elements)
+ {
+ if (elem.key == key)
+ {
+ realGlobalVar = elem.val.irValue;
+ break;
+ }
+ }
+ }
+ SLANG_ASSERT(realGlobalVar);
+
+ // Remove all stores into the global var introduced during
+ // the initial glsl global var translation pass since we are
+ // going to replace the global var with a pointer to the real
+ // input, and it makes no sense to store values into such real
+ // input locations.
+ traverseUses(
+ globalVar,
+ [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (auto store = as<IRStore>(user))
+ {
+ if (store->getPtrUse() == use)
+ {
+ store->removeAndDeallocate();
+ }
+ }
+ });
+ // Redirect every remaining use of the shadowing global variable to the
+ // real global varying input, then delete the now-unused shadowing
+ // variable.
+ globalVar->replaceUsesWith(realGlobalVar);
+ globalVar->removeAndDeallocate();
+ }
+ }
else
{
// This is the "easy" case where the parameter wasn't
@@ -3451,6 +3648,7 @@ ScalarizedVal legalizeEntryPointReturnValueForGLSL(
return result;
}
+
void legalizeEntryPointForGLSL(
Session* session,
IRModule* module,
@@ -3554,12 +3752,12 @@ void legalizeEntryPointForGLSL(
// and turn them into global variables.
if (auto firstBlock = func->getFirstBlock())
{
- // Any initialization code we insert for parameters needs
- // to be at the start of the "ordinary" instructions in the block:
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam())
{
+ // Any initialization code we insert for parameters needs
+ // to be at the start of the "ordinary" instructions in the block:
+ builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
+
// We assume that the entry-point parameters will all have
// layout information attached to them, which is kept up-to-date
// by any transformations affecting the parameter list.
@@ -3606,11 +3804,11 @@ void legalizeEntryPointForGLSL(
{
auto type = value.globalParam->getDataType();
- // Strip out if there is one
- auto outType = as<IROutType>(type);
- if (outType)
+ // Strip ptr if there is one.
+ auto ptrType = as<IRPtrTypeBase>(type);
+ if (ptrType)
{
- type = outType->getValueType();
+ type = ptrType->getValueType();
}
// Get the array type
@@ -3627,10 +3825,13 @@ void legalizeEntryPointForGLSL(
auto elementCountInst = builder.getIntValue(builder.getIntType(), value.maxIndex + 1);
IRType* sizedArrayType = builder.getArrayType(elementType, elementCountInst);
- // Re-add out if there was one on the input
- if (outType)
+ // Re-add ptr if there was one on the input
+ if (ptrType)
{
- sizedArrayType = builder.getOutType(sizedArrayType);
+ sizedArrayType = builder.getPtrType(
+ ptrType->getOp(),
+ sizedArrayType,
+ ptrType->getAddressSpace());
}
// Change the globals type
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 01466ed00..88a9ac5e3 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -736,7 +736,8 @@ INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1,
INST(GetLegalizedSPIRVGlobalParamAddr, GetLegalizedSPIRVGlobalParamAddr, 1, 0)
-INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, 0)
+INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE)
+INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE)
INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0)
INST(MetalAtomicCast, MetalAtomicCast, 1, 0)
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 025bcf1b8..33f3944fd 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -651,7 +651,9 @@ protected:
// The materialized value can be used to completely
// replace the original parameter.
//
- param->replaceUsesWith(materialized);
+ auto localVar = builder.emitVar(materialized->getDataType());
+ builder.emitStore(localVar, materialized);
+ param->replaceUsesWith(localVar);
param->removeAndDeallocate();
}
@@ -1475,4 +1477,71 @@ void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* si
context.processModule(module, sink);
}
+/// Rewrites pointer-typed varying input parameters of `entryPointFunc` into plain values.
+///
+/// A parameter is rewritten when its type is a `constref` type (unless the referenced value
+/// type is `kIROp_VerticesType`, `kIROp_IndicesType` or `kIROp_PrimitivesType`), or any other
+/// pointer type in the `Input` or `BuiltinInput` address space. Each rewritten parameter is
+/// retyped to its value type, and a local variable initialized from the parameter takes over
+/// all previous uses. Call sites that still pass a pointer for a rewritten parameter are
+/// updated to pass the loaded value instead.
+void depointerizeInputParams(IRFunc* entryPointFunc)
+{
+    List<IRParam*> workList;
+    List<Index> modifiedParamIndices;
+
+    // Decides whether a parameter has to be turned from a pointer into a value.
+    auto shouldDepointerize = [](IRParam* param) -> bool
+    {
+        auto paramType = param->getFullType();
+        if (auto constRefType = as<IRConstRefType>(paramType))
+        {
+            switch (constRefType->getValueType()->getOp())
+            {
+            case kIROp_VerticesType:
+            case kIROp_IndicesType:
+            case kIROp_PrimitivesType:
+                return false;
+            default:
+                return true;
+            }
+        }
+        if (auto ptrType = as<IRPtrTypeBase>(paramType))
+        {
+            switch (ptrType->getAddressSpace())
+            {
+            case AddressSpace::Input:
+            case AddressSpace::BuiltinInput:
+                return true;
+            default:
+                return false;
+            }
+        }
+        return false;
+    };
+
+    // The index has to advance for *every* parameter, including the ones that are skipped,
+    // because it is used below to address the matching argument at each call site.
+    Index paramIndex = 0;
+    for (auto param : entryPointFunc->getParams())
+    {
+        if (shouldDepointerize(param))
+        {
+            workList.add(param);
+            modifiedParamIndices.add(paramIndex);
+        }
+        paramIndex++;
+    }
+
+    for (auto param : workList)
+    {
+        auto valueType = as<IRPtrTypeBase>(param->getDataType())->getValueType();
+        IRBuilder builder(param);
+        setInsertBeforeOrdinaryInst(&builder, param);
+        auto var = builder.emitVar(valueType);
+        // Redirect the existing uses to the variable *before* emitting the store, so that
+        // the store emitted below stays a use of the parameter itself.
+        param->replaceUsesWith(var);
+        param->setFullType(valueType);
+        builder.emitStore(var, param);
+    }
+
+    fixUpFuncType(entryPointFunc);
+
+    // Fix up callsites of the entrypoint func.
+    for (auto use = entryPointFunc->firstUse; use; use = use->nextUse)
+    {
+        auto call = as<IRCall>(use->getUser());
+        if (!call)
+            continue;
+        IRBuilder builder(call);
+        builder.setInsertBefore(call);
+        for (auto argIndex : modifiedParamIndices)
+        {
+            // Arguments that are already values need no change.
+            auto arg = call->getArg(argIndex);
+            auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
+            if (!ptrType)
+                continue;
+            auto val = builder.emitLoad(arg);
+            call->setArg(argIndex, val);
+        }
+    }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index 7604cb245..efd61e87c 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -18,6 +18,7 @@ void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sin
void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink);
+void depointerizeInputParams(IRFunc* entryPoint);
// (#4375) Once `slang-ir-metal-legalize.cpp` is merged with
// `slang-ir-legalize-varying-params.cpp`, move the following
diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp
index 7f67c9254..74e84f1ee 100644
--- a/source/slang/slang-ir-lower-buffer-element-type.cpp
+++ b/source/slang/slang-ir-lower-buffer-element-type.cpp
@@ -412,6 +412,28 @@ struct LoweredElementTypeContext
return 4;
}
+ /// Tells whether `matrixType` has to be lowered (wrapped in a struct) under the layout
+ /// rule given by `config`, or whether it can be kept as is.
+ bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config)
+ {
+     // SPIRV only allows a matrix layout/stride to be specified where the matrix type
+     // defines a struct field. A matrix used anywhere else (for example as a varying
+     // parameter) therefore has to be wrapped in a struct, so every matrix type is
+     // lowered when emitting SPIRV directly.
+     if (target->shouldEmitSPIRVDirectly())
+         return true;
+
+     // On the other targets a matrix is kept untouched only when it uses the default
+     // matrix layout together with the natural layout rule.
+     const bool hasDefaultLayout = getIntVal(matrixType->getLayout()) == defaultMatrixLayout;
+     const bool canKeepAsIs =
+         hasDefaultLayout && config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural;
+     return !canKeepAsIs;
+ }
+
LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config)
{
IRBuilder builder(type);
@@ -422,18 +444,10 @@ struct LoweredElementTypeContext
if (auto matrixType = as<IRMatrixType>(type))
{
- // For spirv, we always want to lower all matrix types, because matrix types
- // are considered abstract types.
- if (!target->shouldEmitSPIRVDirectly())
+ if (!shouldLowerMatrixType(matrixType, config))
{
- // For other targets, we only lower the matrix types if they differ from the default
- // matrix layout.
- if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout &&
- config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural)
- {
- info.loweredType = type;
- return info;
- }
+ info.loweredType = type;
+ return info;
}
auto loweredType = builder.createStructType();
@@ -859,27 +873,24 @@ struct LoweredElementTypeContext
{
IRType* elementType = nullptr;
- if (options.lowerBufferPointer)
+ if (auto ptrType = as<IRPtrTypeBase>(globalInst))
{
- if (auto ptrType = as<IRPtrTypeBase>(globalInst))
+ switch (ptrType->getAddressSpace())
{
- switch (ptrType->getAddressSpace())
- {
- case AddressSpace::UserPointer:
- case AddressSpace::Input:
- case AddressSpace::Output:
- elementType = ptrType->getValueType();
- break;
- }
+ case AddressSpace::UserPointer:
+ if (!options.lowerBufferPointer)
+ continue;
+ [[fallthrough]];
+ case AddressSpace::Input:
+ case AddressSpace::Output:
+ elementType = ptrType->getValueType();
+ break;
}
}
- else
- {
- if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst))
- elementType = structBuffer->getElementType();
- else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
- elementType = constBuffer->getElementType();
- }
+ if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst))
+ elementType = structBuffer->getElementType();
+ else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst))
+ elementType = constBuffer->getElementType();
if (as<IRTextureBufferType>(globalInst))
continue;
if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) &&
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 835041a59..ce5b34c3e 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1924,6 +1924,7 @@ struct LegalizeMetalEntryPointContext
void legalizeEntryPointForMetal(EntryPointInfo entryPoint)
{
// Input Parameter Legalize
+ depointerizeInputParams(entryPoint.entryPointFunc);
hoistEntryPointParameterFromStruct(entryPoint);
packStageInParameters(entryPoint);
flattenInputParameters(entryPoint);
diff --git a/source/slang/slang-ir-resolve-varying-input-ref.cpp b/source/slang/slang-ir-resolve-varying-input-ref.cpp
new file mode 100644
index 000000000..0707c566f
--- /dev/null
+++ b/source/slang/slang-ir-resolve-varying-input-ref.cpp
@@ -0,0 +1,92 @@
+#include "slang-ir-resolve-varying-input-ref.h"
+
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+/// Replaces each `ResolveVaryingInputRef` instruction in `func` with the address it refers
+/// to, whenever that address can be traced back to a parameter or global parameter.
+/// Instructions that cannot be traced back are left in place.
+void resolveVaryingInputRef(IRFunc* func)
+{
+    // Resolved instructions are deleted only after the walk over all blocks has finished.
+    List<IRInst*> resolvedInsts;
+    for (auto block = func->getFirstBlock(); block; block = block->getNextBlock())
+    {
+        for (auto inst : block->getChildren())
+        {
+            if (inst->getOp() != kIROp_ResolveVaryingInputRef)
+                continue;
+
+            // Split the referenced address into its root and the access chain applied to it.
+            auto operand = inst->getOperand(0);
+            List<IRInst*> accessChain;
+            List<IRInst*> types;
+            auto rootAddr = getRootAddr(operand, accessChain, &types);
+            if (rootAddr->getOp() == kIROp_Param || rootAddr->getOp() == kIROp_GlobalParam)
+            {
+                // The operand is already rooted at a (global) param: use it directly.
+                inst->replaceUsesWith(operand);
+                resolvedInsts.add(inst);
+                continue;
+            }
+
+            // Otherwise look for a `store(root, load(param))` and take `param` as the real
+            // source. Finding a second such store makes the source ambiguous, so the
+            // search is abandoned in that case.
+            IRInst* srcPtr = nullptr;
+            for (auto use = rootAddr->firstUse; use; use = use->nextUse)
+            {
+                auto store = as<IRStore>(use->getUser());
+                if (!store || store->getPtrUse() != use)
+                    continue;
+                auto load = as<IRLoad>(store->getVal());
+                if (!load)
+                    continue;
+                auto ptr = load->getPtr();
+                if (ptr->getOp() != kIROp_Param && ptr->getOp() != kIROp_GlobalParam)
+                    continue;
+                if (srcPtr)
+                {
+                    srcPtr = nullptr;
+                    break;
+                }
+                srcPtr = ptr;
+            }
+            if (!srcPtr)
+                continue;
+
+            // Re-apply the original access chain on top of the source param.
+            IRBuilder builder(inst);
+            builder.setInsertBefore(inst);
+            auto resolvedPtr = builder.emitElementAddress(
+                srcPtr,
+                accessChain.getArrayView(),
+                types.getArrayView());
+            inst->replaceUsesWith(resolvedPtr);
+            resolvedInsts.add(inst);
+        }
+    }
+
+    for (auto inst : resolvedInsts)
+        inst->removeAndDeallocate();
+}
+
+/// Resolves `ResolveVaryingInputRef` instructions in every entry point function of `module`.
+///
+/// Walks the module's global instructions and processes each function that carries an
+/// `IREntryPointDecoration`.
+void resolveVaryingInputRef(IRModule* module)
+{
+    for (auto globalInst : module->getGlobalInsts())
+    {
+        // Use a checked cast: only a real `IRFunc` may be handed to the per-function
+        // overload, so any other kind of global instruction is skipped.
+        auto func = as<IRFunc>(globalInst);
+        if (!func)
+            continue;
+        if (func->findDecoration<IREntryPointDecoration>())
+            resolveVaryingInputRef(func);
+    }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-resolve-varying-input-ref.h b/source/slang/slang-ir-resolve-varying-input-ref.h
new file mode 100644
index 000000000..5cbff0f8c
--- /dev/null
+++ b/source/slang/slang-ir-resolve-varying-input-ref.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+void resolveVaryingInputRef(IRFunc* func); // resolves the `ResolveVaryingInputRef` insts of one function
+void resolveVaryingInputRef(IRModule* module); // same, for every entry point function in `module`
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp
index 65cb8f64f..a44e16a7c 100644
--- a/source/slang/slang-ir-translate-glsl-global-var.cpp
+++ b/source/slang/slang-ir-translate-glsl-global-var.cpp
@@ -122,7 +122,8 @@ struct GlobalVarTranslationContext
// Add an entry point parameter for all the inputs.
auto firstBlock = entryPointFunc->getFirstBlock();
builder.setInsertInto(firstBlock);
- auto inputParam = builder.emitParam(inputStructType);
+ auto inputParam = builder.emitParam(
+ builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input));
builder.addLayoutDecoration(inputParam, paramLayout);
// Initialize all global variables.
@@ -133,7 +134,8 @@ struct GlobalVarTranslationContext
auto inputType = cast<IRPtrTypeBase>(input->getDataType())->getValueType();
builder.emitStore(
input,
- builder.emitFieldExtract(inputType, inputParam, inputKeys[i]));
+ builder
+ .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i]));
// Relate "global variable" to a "global parameter" for use later in compilation
// to resolve a "global variable" shadowing a "global parameter" relationship.
builder.addGlobalVariableShadowingGlobalParameterDecoration(
diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp
index e7fc81144..70f2584ac 100644
--- a/source/slang/slang-ir-vk-invert-y.cpp
+++ b/source/slang/slang-ir-vk-invert-y.cpp
@@ -104,10 +104,15 @@ void rcpWOfPositionInput(IRModule* module)
[&](IRUse* use)
{
// Get the inverted vector.
- builder.setInsertBefore(use->getUser());
- auto invertedVal = _invertWOfVector(builder, globalInst);
- // Replace original uses with the invertex vector.
- builder.replaceOperand(use, invertedVal);
+ auto user = use->getUser();
+ if (user->getOp() == kIROp_Load)
+ {
+ builder.setInsertBefore(user);
+ auto val = builder.emitLoad(globalInst);
+ auto invertedVal = _invertWOfVector(builder, val);
+ user->replaceUsesWith(invertedVal);
+ user->removeAndDeallocate();
+ }
});
}
}
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index 907c2b8ba..f76a0541c 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -1362,6 +1362,9 @@ struct LegalizeWGSLEntryPointContext
void legalizeEntryPointForWGSL(EntryPointInfo entryPoint)
{
+ // If the entrypoint is receiving varying inputs as a pointer, turn it into a value.
+ depointerizeInputParams(entryPoint.entryPointFunc);
+
// Input Parameter Legalize
flattenInputParameters(entryPoint);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index ff1cd49ea..daeaca67b 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -8243,6 +8243,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options)
case kIROp_GetStringHash:
case kIROp_AllocateOpaqueHandle:
case kIROp_GetArrayLength:
+ case kIROp_ResolveVaryingInputRef:
+ case kIROp_GetPerVertexInputArray:
return false;
case kIROp_ForwardDifferentiate:
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5bbe44e9b..011ea6bc7 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2767,14 +2767,15 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl)
///
ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection)
{
- auto parentParent = getParentDecl(parentDecl);
+ auto parentParent = getParentAggTypeDecl(parentDecl);
+
// The `this` parameter for a `class` is always `in`.
if (as<ClassDecl>(parentParent))
{
return kParameterDirection_In;
}
- if (parentParent->findModifier<NonCopyableTypeAttribute>())
+ if (parentParent && parentParent->findModifier<NonCopyableTypeAttribute>())
{
if (parentDecl->hasModifier<MutatingAttribute>())
return kParameterDirection_Ref;
@@ -2982,6 +2983,9 @@ struct IRLoweringParameterInfo
// The direction (`in` vs `out` vs `in out`)
ParameterDirection direction;
+ // The direction declared in user code.
+ ParameterDirection declaredDirection = ParameterDirection::kParameterDirection_In;
+
// The variable/parameter declaration for
// this parameter (if any)
VarDeclBase* decl = nullptr;
@@ -3005,6 +3009,7 @@ IRLoweringParameterInfo getParameterInfo(
info.type = getParamType(context->astBuilder, paramDecl);
info.decl = paramDecl.getDecl();
info.direction = getParameterDirection(paramDecl.getDecl());
+ info.declaredDirection = info.direction;
info.isThisParam = false;
return info;
}
@@ -3051,6 +3056,7 @@ void addThisParameter(ParameterDirection direction, Type* type, ParameterLists*
info.type = type;
info.decl = nullptr;
info.direction = direction;
+ info.declaredDirection = direction;
info.isThisParam = true;
ioParameterLists->params.add(info);
@@ -3064,10 +3070,22 @@ void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resu
info.type = resultType;
info.decl = nullptr;
info.direction = kParameterDirection_Ref;
+ info.declaredDirection = info.direction;
info.isReturnDestination = true;
ioParameterLists->params.add(info);
}
}
+
+/// Turns an `in` parameter of an entry point into a `__constref` parameter.
+///
+/// The parameter is left untouched when it is not an `in` parameter, has no declaration,
+/// is marked `uniform`, or has a patch type.
+void makeVaryingInputParamConstRef(IRLoweringParameterInfo& paramInfo)
+{
+    if (paramInfo.direction != kParameterDirection_In)
+        return;
+
+    // `decl` is null for synthesized parameters (e.g. `this`), which have no modifiers to
+    // inspect; skip them instead of dereferencing a null pointer.
+    if (!paramInfo.decl)
+        return;
+    if (paramInfo.decl->findModifier<HLSLUniformModifier>())
+        return;
+    if (as<HLSLPatchType>(paramInfo.type))
+        return;
+
+    paramInfo.direction = kParameterDirection_ConstRef;
+}
//
// And here is our function that will do the recursive walk:
void collectParameterLists(
@@ -3137,13 +3155,31 @@ void collectParameterLists(
//
if (auto callableDeclRef = declRef.as<CallableDecl>())
{
+ // We need a special case here when lowering the varying parameters of an entrypoint
+ // function. Due to the existence of `EvaluateAttributeAtSample` and friends, we need to
+ // always lower the varying inputs as `__constref` parameters so we can pass pointers to
+ // these intrinsics.
+ // This means that although these parameters are declared as "in" parameters in the source,
+ // we will actually treat them as __constref parameters when lowering to IR. A complication
+ // resulting from this is that if the original source code actually modifies the input
+ // parameter we still need to create a local var to hold the modified value. In the future
+ // when we are able to update our language spec to always assume input parameters are
+ // immutable, then we can remove this ad hoc logic of introducing temporary variables.
+ // For now we will rely on a follow-up pass to remove unnecessary temporary variables if
+ // we can determine that they are never actually written to by the user.
+ //
+ bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>();
+
// Don't collect parameters from the outer scope if
// we are in a `static` context.
if (mode == kParameterListCollectMode_Default)
{
for (auto paramDeclRef : getParameters(context->astBuilder, callableDeclRef))
{
- ioParameterLists->params.add(getParameterInfo(context, paramDeclRef));
+ auto paramInfo = getParameterInfo(context, paramDeclRef);
+ if (lowerVaryingInputAsConstRef)
+ makeVaryingInputParamConstRef(paramInfo);
+ ioParameterLists->params.add(paramInfo);
}
maybeAddReturnDestinationParam(
ioParameterLists,
@@ -5623,9 +5659,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe
LoweredValInfo visitOpenRefExpr(OpenRefExpr* expr)
{
auto inner = lowerLValueExpr(context, expr->innerExpr);
- auto builder = getBuilder();
- auto irLoad = builder->emitLoad(inner.val);
- return LoweredValInfo::simple(irLoad);
+ return LoweredValInfo::ptr(inner.val);
}
};
@@ -9980,6 +10014,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (paramInfo.isReturnDestination)
subContext->returnDestination = paramVal;
+ if (paramInfo.declaredDirection == kParameterDirection_In &&
+ paramInfo.direction == kParameterDirection_ConstRef)
+ {
+ // If the parameter is originally declared as "in", but we are
+ // lowering it as constref for any reason (e.g. it is a varying input),
+ // then we need to emit a local variable to hold the original value, so
+ // that we can still generate correct code when the user tries to mutate
+ // the variable.
+ // The local variable introduced here is cleaned up by the SSA pass, if
+ // we can determine that there are no actual writes into the local var.
+ auto irLocal =
+ subBuilder->emitVar(tryGetPointedToType(subBuilder, irParamType));
+ auto localVal = LoweredValInfo::ptr(irLocal);
+ assign(subContext, localVal, paramVal);
+ paramVal = localVal;
+ }
// TODO: We might want to copy the pointed-to value into
// a temporary at the start of the function, and then copy
// back out at the end, so that we don't have to worry
@@ -10987,6 +11037,16 @@ static void lowerFrontEndEntryPointToIR(
auto entryPointFuncDecl = entryPoint->getFuncDecl();
+ if (!entryPointFuncDecl->findModifier<EntryPointAttribute>())
+ {
+ // If the entry point doesn't have an explicit `[shader("...")]` attribute,
+ // then we make sure to add one here, so the lowering logic knows it is an
+ // entry point.
+ auto entryPointAttr = context->astBuilder->create<EntryPointAttribute>();
+ entryPointAttr->capabilitySet = entryPoint->getProfile().getCapabilityName();
+ addModifier(entryPointFuncDecl, entryPointAttr);
+ }
+
auto builder = context->irBuilder;
builder->setInsertInto(builder->getModule()->getModuleInst());
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 22491c848..c275a868b 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -7531,12 +7531,6 @@ static IRFloatingPointValue _foldFloatPrefixOp(TokenType tokenType, IRFloatingPo
static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser)
{
- const auto slangIdentOperand = [&](auto flavor)
- {
- auto token = parser->tokenReader.peekToken();
- return SPIRVAsmOperand{flavor, token, parseAtomicExpr(parser)};
- };
-
const auto slangTypeExprOperand = [&](auto flavor)
{
auto tok = parser->tokenReader.peekToken();
@@ -7673,12 +7667,13 @@ static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser)
// A &foo variable reference (for the address of foo)
else if (AdvanceIf(parser, TokenType::OpBitAnd))
{
- return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr);
+ Expr* expr = parsePostfixExpr(parser);
+ return SPIRVAsmOperand{SPIRVAsmOperand::SlangValueAddr, Token{}, expr};
}
// A $foo variable
else if (AdvanceIf(parser, TokenType::Dollar))
{
- Expr* expr = parseAtomicExpr(parser);
+ Expr* expr = parsePostfixExpr(parser);
return SPIRVAsmOperand{SPIRVAsmOperand::SlangValue, Token{}, expr};
}
// A $$foo type