diff options
| author | Yong He <yonghe@outlook.com> | 2025-01-07 22:26:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-07 22:26:31 -0800 |
| commit | c43f6fa55aca23365c86c6ec1737d42be74d9d3e (patch) | |
| tree | 2c49bc1dbd12ae5f46d682a3f240465931471060 /source | |
| parent | 1a56f58fdd0c704e6dc0fad0f0ec33a25a35e60b (diff) | |
Lower varying parameters as pointers instead of SSA values. (#5919)
* Add executable test on matrix-typed vertex input.
* Fix emit logic of matrix layout qualifier.
* Pass fragment shader varying input by constref to allow EvaluateAttributeAtCentroid etc. to be implemented correctly.
Diffstat (limited to 'source')
29 files changed, 977 insertions, 214 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 625f8f608..adb7470dd 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1,3 +1,5 @@ +//public module core; + // Slang `core` library // Aliases for base types diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 2ff71a74e..ba26b5d84 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -9080,7 +9080,7 @@ public vector<float, N> fwidthCoarse(vector<float, N> p) [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtCentroid(__ref float interpolant) +public float interpolateAtCentroid(__constref float interpolant) { __target_switch { @@ -9099,7 +9099,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant) +public vector<float, N> interpolateAtCentroid(__constref vector<float, N> interpolant) { __target_switch { @@ -9118,7 +9118,7 @@ public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtSample(__ref float interpolant, int sample) +public float interpolateAtSample(__constref float interpolant, int sample) { __target_switch { @@ -9137,7 +9137,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant, int sample) +public vector<float, N> interpolateAtSample(__constref vector<float, N> interpolant, int sample) { __target_switch { @@ -9156,7 +9156,7 @@ public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant, [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtOffset(__ref float interpolant, vec2 offset) +public float interpolateAtOffset(__constref float interpolant, vec2 offset) { __target_switch { @@ -9175,7 +9175,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtOffset(__ref vector<float, N> interpolant, vec2 offset) +public vector<float, N> interpolateAtOffset(__constref vector<float, N> interpolant, vec2 offset) { __target_switch { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d8f9845b8..d5b59427f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8101,89 +8101,139 @@ RasterizerOrderedStructuredBuffer<T> __getEquivalentStructuredBuffer<T>(Rasteriz // Attribute evaluation +T __EvaluateAttributeAtCentroid<T>(__constref T x) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtCentroid"; + case glsl: __intrinsic_asm "interpolateAtCentroid"; + } +} + // TODO: The matrix cases of these functions won't actuall work // when compiled to GLSL, since they only support scalar/vector // TODO: Should these be constrains to `__BuiltinFloatingPointType`? // TODO: SPIRV-direct does not support non-floating-point types. +/// Interpolates vertex attribute at centroid position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtCentroid(T x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtCentroid(__constref T x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeAtCentroid(__constref vector<T,N> x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeAtCentroid(__constref matrix<T,N,M> x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); default: MATRIX_MAP_UNARY(T, N, M, EvaluateAttributeAtCentroid, x); } } +T __EvaluateAttributeAtSample<T>(__constref T x, uint sampleIndex) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtSample"; + case glsl: __intrinsic_asm "interpolateAtSample"; + } +} + +/// Interpolates vertex attribute at the current fragment sample position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtSample(T x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtSample(__constref T x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeAtSample(__constref vector<T,N> x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeAtSample(__constref matrix<T,N,M> x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); default: matrix<T,N,M> result; for(int i = 0; i < N; ++i) @@ -8194,21 +8244,59 @@ matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex) } } +T __EvaluateAttributeSnapped<T>(__constref T x, int2 offset) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeSnapped"; + case glsl: __intrinsic_asm "EvaluateAttributeSnapped"; + } +} + +/// Interpolates vertex attribute at the specified subpixel offset. +/// @param x The vertex attribute to interpolate. +/// @param offset The subpixel offset. Each component is a 4-bit signed integer in range [-8, 7]. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// +/// The valid values of each component of `offset` are: +/// +/// - 1000 = -0.5f (-8 / 16) +/// - 1001 = -0.4375f (-7 / 16) +/// - 1010 = -0.375f (-6 / 16) +/// - 1011 = -0.3125f (-5 / 16) +/// - 1100 = -0.25f (-4 / 16) +/// - 1101 = -0.1875f (-3 / 16) +/// - 1110 = -0.125f (-2 / 16) +/// - 1111 = -0.0625f (-1 / 16) +/// - 0000 = 0.0f ( 0 / 16) +/// - 0001 = 0.0625f ( 1 / 16) +/// - 0010 = 0.125f ( 2 / 16) +/// - 0011 = 0.1875f ( 3 / 16) +/// - 0100 = 0.25f ( 4 / 16) +/// - 0101 = 0.3125f ( 5 / 16) +/// - 0110 = 0.375f ( 6 / 16) +/// - 0111 = 0.4375f ( 7 / 16) +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeSnapped(T x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeSnapped(__constref T x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$T = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$T = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8216,19 +8304,23 @@ T EvaluateAttributeSnapped(T x, int2 offset) __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeSnapped(__constref vector<T,N> x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8236,12 +8328,15 @@ vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset) __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeSnapped(__constref matrix<T,N,M> x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); default: matrix<T,N,M> result; for(int i = 0; i < N; ++i) @@ -9243,8 +9338,16 @@ matrix<T, N, M> fwidth(matrix<T, N, M> x) } } +__intrinsic_op($(kIROp_ResolveVaryingInputRef)) +Ref<T> __ResolveVaryingInputRef<T>(__constref T attribute); + __intrinsic_op($(kIROp_GetPerVertexInputArray)) -Array<T, 3> __GetPerVertexInputArray<T>(T attribute); +Ref<Array<T, 3>> __GetPerVertexInputArray<T>(__constref T attribute); + +T __GetAttributeAtVertex<T>(__constref T attribute, uint vertexIndex) +{ + __intrinsic_asm "GetAttributeAtVertex"; +} /// Get the value of a vertex attribute at a specific vertex. /// @@ -9265,15 +9368,15 @@ __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] [KnownBuiltin("GetAttributeAtVertex")] [__unsafeForceInlineEarly] -T GetAttributeAtVertex(T attribute, uint vertexIndex) +T GetAttributeAtVertex(__constref T attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); case glsl: case spirv: - return __GetPerVertexInputArray(attribute)[vertexIndex]; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9294,20 +9397,16 @@ __generic<T : __BuiltinType, let N : int> __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +vector<T,N> GetAttributeAtVertex(__constref vector<T,N> attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_vectorT = OpTypePointer Input $$vector<T,N>; - %addr = OpAccessChain %_ptr_Input_vectorT $attribute $vertexIndex; - result:$$vector<T,N> = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9328,20 +9427,16 @@ __generic<T : __BuiltinType, let N : int, let M : int> __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +matrix<T,N,M> GetAttributeAtVertex(__constref matrix<T,N,M> attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_matrixT = OpTypePointer Input $$matrix<T,N,M>; - %addr = OpAccessChain %_ptr_Input_matrixT $attribute $vertexIndex; - result:$$matrix<T,N,M> = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e8886a59a..911455f17 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -461,6 +461,8 @@ class ModuleDecl : public NamespaceDeclBase /// `__implementing` etc. bool isInLegacyLanguage = true; + DeclVisibility defaultVisibility = DeclVisibility::Internal; + SLANG_UNREFLECTED /// Map a type to the list of extensions of that type (if any) declared in this module diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index b77003bff..85d2d0d9f 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -460,6 +460,26 @@ struct ASTDumpContext } } + void dump(DeclVisibility vis) + { + switch (vis) + { + case DeclVisibility::Private: + m_writer->emit("private"); + break; + case DeclVisibility::Internal: + m_writer->emit("internal"); + break; + case DeclVisibility::Public: + m_writer->emit("public"); + break; + default: + m_writer->emit(String((int)vis).getUnownedSlice()); + break; + } + } + + void dump(const QualType& qualType) { if (qualType.isLeftValue) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ce3f1e64c..3667a36ba 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3186,6 +3186,11 @@ void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) } } + if (moduleDecl->findModifier<PublicModifier>()) + { + moduleDecl->defaultVisibility = DeclVisibility::Public; + } + // We need/want to visit any `import` declarations before // anything else, to make sure that scoping works. // @@ -12604,8 +12609,10 @@ DeclVisibility getDeclVisibility(Decl* decl) } auto defaultVis = DeclVisibility::Default; if (auto parentModule = getModuleDecl(decl)) - defaultVis = - parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + { + defaultVis = parentModule->isInLegacyLanguage ? DeclVisibility::Public + : parentModule->defaultVisibility; + } // Members of other agg type decls will have their default visibility capped to the parents'. if (as<NamespaceDecl>(decl)) @@ -12790,10 +12797,6 @@ void diagnoseCapabilityProvenance( auto moduleDecl = getModuleDecl(declToPrint); if (thisModule != moduleDecl) break; - if (moduleDecl && moduleDecl->isInLegacyLanguage) - continue; - if (getDeclVisibility(declToPrint) == DeclVisibility::Public) - break; } if (previousDecl == declToPrint) break; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 83b668a33..95d5a2a7c 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4099,6 +4099,15 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas elementType = QualType(ptrType->getValueType()); elementType.isLeftValue = true; } + else + { + auto newExpr = maybeOpenRef(expr); + if (newExpr != expr) + { + expr = newExpr; + continue; + } + } if (elementType.type) { auto derefExpr = m_astBuilder->create<DerefExpr>(); @@ -4108,9 +4117,10 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas expr = derefExpr; continue; } - // Default case: just use the expression as-is - return expr; + break; } + // Default case: just use the expression as-is + return expr; } Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 3175f1b07..64cc9969c 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1537,6 +1537,30 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } + if (auto load = as<IRLoad>(inst)) + { + // Loads from a constref global param should always be folded. + auto ptrType = load->getPtr()->getDataType(); + if (load->getPtr()->getOp() == kIROp_GlobalParam) + { + if (ptrType->getOp() == kIROp_ConstRefType) + return true; + if (auto ptrTypeBase = as<IRPtrTypeBase>(ptrType)) + { + auto addrSpace = ptrTypeBase->getAddressSpace(); + switch (addrSpace) + { + case Slang::AddressSpace::Uniform: + case Slang::AddressSpace::Input: + case Slang::AddressSpace::BuiltinInput: + return true; + default: + break; + } + } + } + } + // Always hold if inst is a call into an [__alwaysFoldIntoUseSite] function. if (auto call = as<IRCall>(inst)) { @@ -4709,9 +4733,21 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) auto rawType = varDecl->getDataType(); auto varType = rawType; - if (auto outType = as<IROutTypeBase>(varType)) + if (auto ptrType = as<IRPtrTypeBase>(varType)) { - varType = outType->getValueType(); + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::Output: + case AddressSpace::BuiltinInput: + case AddressSpace::BuiltinOutput: + varType = ptrType->getValueType(); + break; + default: + if (as<IROutTypeBase>(ptrType)) + varType = ptrType->getValueType(); + break; + } } if (as<IRVoidType>(varType)) return; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index f91c4d06e..b0d1fbb4c 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -310,6 +310,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S } case kIROp_NativePtrType: case kIROp_PtrType: + case kIROp_ConstRefType: { auto elementType = (IRType*)type->getOperand(0); SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index a863e7eb1..23fff37ac 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -3166,6 +3166,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType()); + return; + } default: break; } @@ -3471,15 +3476,18 @@ void GLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* varType) // auto matrixType = as<IRMatrixType>(unwrapArray(varType)); - if (matrixType) { + auto layout = getIntVal(matrixType->getLayout()); + if (layout == getTargetProgram()->getOptionSet().getMatrixLayoutMode()) + return; + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL // calls them, because what they call "columns" // are what we call "rows." // - switch (getIntVal(matrixType->getLayout())) + switch (layout) { case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: m_writer->emit("layout(row_major)\n"); diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 40d6f75d9..83eec17b4 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1180,6 +1180,34 @@ void HLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) Super::emitSimpleValueImpl(inst); } +void HLSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + if (declarator) + { + // HLSL only allow matrix layout modifier when declaring a variable or struct field. + if (auto matType = as<IRMatrixType>(type)) + { + auto matrixLayout = getIntVal(matType->getLayout()); + if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != + (MatrixLayoutMode)matrixLayout) + { + switch (matrixLayout) + { + case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: + m_writer->emit("column_major "); + break; + case SLANG_MATRIX_LAYOUT_ROW_MAJOR: + m_writer->emit("row_major "); + break; + default: + break; + } + } + } + } + Super::emitSimpleTypeAndDeclaratorImpl(type, declarator); +} + void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { switch (type->getOp()) @@ -1313,6 +1341,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType()); + return; + } default: break; } @@ -1671,28 +1704,6 @@ void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) } } -void HLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* type) -{ - auto matType = as<IRMatrixType>(type); - if (!matType) - return; - auto matrixLayout = getIntVal(matType->getLayout()); - if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != (MatrixLayoutMode)matrixLayout) - { - switch (matrixLayout) - { - case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: - m_writer->emit("column_major "); - break; - case SLANG_MATRIX_LAYOUT_ROW_MAJOR: - m_writer->emit("row_major "); - break; - default: - break; - } - } -} - void HLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) { if (inst->findDecoration<IRRequiresNVAPIDecoration>()) diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index b2e2ca05a..6b99a7f50 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -55,11 +55,12 @@ protected: IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) + SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; - virtual void emitMatrixLayoutModifiersImpl(IRType* varType) SLANG_OVERRIDE; virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE { emitMatrixLayoutModifiersImpl(type); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7ea2fef88..b9217de41 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -42,6 +42,7 @@ #include "slang-ir-entry-point-uniforms.h" #include "slang-ir-explicit-global-context.h" #include "slang-ir-explicit-global-init.h" +#include "slang-ir-fix-entrypoint-callsite.h" #include "slang-ir-fuse-satcoop.h" #include "slang-ir-glsl-legalize.h" #include "slang-ir-glsl-liveness.h" @@ -76,6 +77,7 @@ #include "slang-ir-pytorch-cpp-binding.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-resolve-texture-format.h" +#include "slang-ir-resolve-varying-input-ref.h" #include "slang-ir-restructure-scoping.h" #include "slang-ir-restructure.h" #include "slang-ir-sccp.h" @@ -314,6 +316,7 @@ struct RequiredLoweringPassSet bool glslSSBO; bool byteAddressBuffer; bool dynamicResource; + bool resolveVaryingInputRef; }; // Scan the IR module and determine which lowering/legalization passes are needed based @@ -423,6 +426,9 @@ void calcRequiredLoweringPassSet( case kIROp_DynamicResourceType: result.dynamicResource = true; break; + case kIROp_ResolveVaryingInputRef: + result.resolveVaryingInputRef = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -591,6 +597,11 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.glslGlobalVar) translateGLSLGlobalVar(codeGenContext, irModule); + if (requiredLoweringPassSet.resolveVaryingInputRef) + resolveVaryingInputRef(irModule); + + fixEntryPointCallsites(irModule); + // Replace any global constants with their values. // replaceGlobalConstants(irModule); diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp new file mode 100644 index 000000000..7390f3a7f --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp @@ -0,0 +1,101 @@ +#include "slang-ir-fix-entrypoint-callsite.h" + +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +// If the entrypoint is called by some other function, we need to clone the +// entrypoint and replace the callsites to call the cloned entrypoint instead. +// This is because we will be modifying the signature of the entrypoint during +// entrypoint legalization to rewrite the way system values are passed in. +// By replacing the callsites to call the cloned entrypoint that act as ordinary +// functions, we will no longer need to worry about changing the callsites when we +// legalize the entry-points. +// +void fixEntryPointCallsites(IRFunc* entryPoint) +{ + IRFunc* clonedEntryPointForCall = nullptr; + auto ensureClonedEntryPointForCall = [&]() -> IRFunc* + { + if (clonedEntryPointForCall) + return clonedEntryPointForCall; + IRCloneEnv cloneEnv; + IRBuilder builder(entryPoint); + builder.setInsertBefore(entryPoint); + clonedEntryPointForCall = (IRFunc*)cloneInst(&cloneEnv, &builder, entryPoint); + // Remove entrypoint and linkage decorations from the cloned callee. + List<IRInst*> decorsToRemove; + for (auto decor : clonedEntryPointForCall->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ExportDecoration: + case kIROp_UserExternDecoration: + case kIROp_HLSLExportDecoration: + case kIROp_EntryPointDecoration: + case kIROp_LayoutDecoration: + case kIROp_NumThreadsDecoration: + case kIROp_ImportDecoration: + case kIROp_ExternCDecoration: + case kIROp_ExternCppDecoration: + decorsToRemove.add(decor); + break; + } + } + for (auto decor : decorsToRemove) + decor->removeAndDeallocate(); + return clonedEntryPointForCall; + }; + traverseUses( + entryPoint, + [&](IRUse* use) + { + auto user = use->getUser(); + auto call = as<IRCall>(user); + if (!call) + return; + auto callee = ensureClonedEntryPointForCall(); + call->setOperand(0, callee); + + // Fix up argument types: if the callee entrypoint is expecting a constref + // and the caller is passing a value, we need to wrap the value in a temporary var + // and pass the temporary var. + // + auto funcType = as<IRFuncType>(callee->getDataType()); + SLANG_ASSERT(funcType); + IRBuilder builder(call); + builder.setInsertBefore(call); + List<IRParam*> params; + for (auto param : callee->getParams()) + params.add(param); + if ((UInt)params.getCount() != call->getArgCount()) + return; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto paramType = params[i]->getDataType(); + auto arg = call->getArg(i); + if (auto refType = as<IRConstRefType>(paramType)) + { + if (!as<IRPtrTypeBase>(arg->getDataType())) + { + auto tempVar = builder.emitVar(refType->getValueType()); + builder.emitStore(tempVar, arg); + call->setArg(i, tempVar); + } + } + } + }); +} + +void fixEntryPointCallsites(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + fixEntryPointCallsites((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.h b/source/slang/slang-ir-fix-entrypoint-callsite.h new file mode 100644 index 000000000..493d67a77 --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.h @@ -0,0 +1,9 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void fixEntryPointCallsites(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 09bf245df..39f970319 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -67,11 +67,7 @@ struct ScalarizedValImpl : RefObject }; struct ScalarizedTupleValImpl; struct ScalarizedTypeAdapterValImpl; - -struct ScalarizedArrayIndexValImpl : ScalarizedValImpl -{ - Index index; -}; +struct ScalarizedArrayIndexValImpl; struct ScalarizedVal { @@ -132,15 +128,12 @@ struct ScalarizedVal result.impl = (ScalarizedValImpl*)impl; return result; } - static ScalarizedVal scalarizedArrayIndex(IRInst* irValue, Index index) + static ScalarizedVal scalarizedArrayIndex(ScalarizedArrayIndexValImpl* impl) { ScalarizedVal result; result.flavor = Flavor::arrayIndex; - auto impl = new ScalarizedArrayIndexValImpl; - impl->index = index; - - result.irValue = irValue; - result.impl = impl; + result.irValue = nullptr; + result.impl = (ScalarizedValImpl*)impl; return result; } @@ -151,8 +144,6 @@ struct ScalarizedVal RefPtr<ScalarizedValImpl> impl; }; -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - // This is the case for a value that is a "tuple" of other values struct ScalarizedTupleValImpl : ScalarizedValImpl { @@ -175,6 +166,36 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl IRType* pretendType; // the type this value pretends to have }; +struct ScalarizedArrayIndexValImpl : ScalarizedValImpl +{ + ScalarizedVal arrayVal; + Index index; + IRType* elementType; +}; + +ScalarizedVal extractField( + IRBuilder* builder, + ScalarizedVal const& val, + UInt fieldIndex, // Pass ~0 in to search for the index via the key + IRStructKey* fieldKey); +ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* fromType); +ScalarizedVal adaptType( + IRBuilder* builder, + ScalarizedVal const& val, + IRType* toType, + IRType* fromType); +IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + IRInst* indexVal); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + UInt index); + struct GlobalVaryingDeclarator { enum class Flavor @@ -1303,6 +1324,22 @@ ScalarizedVal createSimpleGLSLGlobalVarying( } } + AddressSpace addrSpace = AddressSpace::Uniform; + IROp ptrOpCode = kIROp_PtrType; + switch (kind) + { + case LayoutResourceKind::VaryingInput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinInput : AddressSpace::Input; + break; + case LayoutResourceKind::VaryingOutput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinOutput : AddressSpace::Output; + ptrOpCode = kIROp_OutType; + break; + default: + break; + } + + // If we have a declarator, we just use the normal logic, as that seems to work correctly // if (systemValueInfo && systemValueInfo->arrayIndex >= 0 && declarator == nullptr) @@ -1339,9 +1376,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // Set the array size to 0, to mean it is unsized auto arrayType = builder->getArrayType(type, 0); - IRType* paramType = kind == LayoutResourceKind::VaryingOutput - ? (IRType*)builder->getOutType(arrayType) - : arrayType; + IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); @@ -1371,9 +1406,12 @@ ScalarizedVal createSimpleGLSLGlobalVarying( semanticGlobal->addIndex(systemValueInfo->arrayIndex); // Make it an array index - ScalarizedVal val = ScalarizedVal::scalarizedArrayIndex( - semanticGlobal->globalParam, - systemValueInfo->arrayIndex); + ScalarizedVal val = ScalarizedVal::address(semanticGlobal->globalParam); + RefPtr<ScalarizedArrayIndexValImpl> arrayImpl = new ScalarizedArrayIndexValImpl(); + arrayImpl->arrayVal = val; + arrayImpl->index = systemValueInfo->arrayIndex; + arrayImpl->elementType = type; + val = ScalarizedVal::scalarizedArrayIndex(arrayImpl); // We need to make this access, an array access to the global if (auto fromType = systemValueInfo->requiredType) @@ -1466,14 +1504,14 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // like our IR function parameters, and need a wrapper // `Out<...>` type to represent outputs. // - bool isOutput = (kind == LayoutResourceKind::VaryingOutput); - IRType* paramType = isOutput ? builder->getOutType(type) : type; + + // Non system value varying inputs shall be passed as pointers. + IRType* paramType = builder->getPtrType(ptrOpCode, type, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); - ScalarizedVal val = - isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); + ScalarizedVal val = ScalarizedVal::address(globalParam); if (systemValueInfo) { @@ -1958,10 +1996,10 @@ ScalarizedVal adaptType( break; case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as<ScalarizedArrayIndexValImpl>(val.impl)->index); - return adaptType(builder, element, toType, fromType); + auto arrayImpl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto elementVal = + getSubscriptVal(builder, fromType, arrayImpl->arrayVal, arrayImpl->index); + return adaptType(builder, elementVal, toType, fromType); } break; default: @@ -1970,8 +2008,6 @@ ScalarizedVal adaptType( } } -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - void assign( IRBuilder* builder, ScalarizedVal const& left, @@ -1988,16 +2024,12 @@ void assign( // Determine the index auto leftArrayIndexVal = as<ScalarizedArrayIndexValImpl>(left.impl); - const auto arrayIndex = leftArrayIndexVal->index; - - auto arrayIndexInst = builder->getIntValue(builder->getIntType(), arrayIndex); - - // Store to the index - auto address = builder->emitElementAddress( - builder->getPtrType(right.irValue->getFullType()), - left.irValue, - arrayIndexInst); - builder->emitStore(address, rhs); + auto leftVal = getSubscriptVal( + builder, + leftArrayIndexVal->elementType, + leftArrayIndexVal->arrayVal, + leftArrayIndexVal->index); + builder->emitStore(leftVal.irValue, rhs); break; } @@ -2236,10 +2268,10 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as<ScalarizedArrayIndexValImpl>(val.impl)->index); - return element; + auto impl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto elementVal = + getSubscriptVal(builder, impl->elementType, impl->arrayVal, impl->index); + return materializeValue(builder, elementVal); } case ScalarizedVal::Flavor::tuple: { @@ -2735,9 +2767,9 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst* IRBuilder builder(inputVertexAttr); builder.setInsertBefore(inputVertexAttr); auto arrayType = builder.getArrayType( - inputVertexAttr->getDataType(), + tryGetPointedToType(&builder, inputVertexAttr->getDataType()), builder.getIntValue(builder.getIntType(), 3)); - arrayInst = builder.createGlobalParam(arrayType); + arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input)); context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst; builder.addDecoration(arrayInst, kIROp_PerVertexDecoration); @@ -2770,21 +2802,105 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + builder.replaceOperand(use, val.irValue); + }); + } + break; + case ScalarizedVal::Flavor::address: + { + bool needMaterialize = false; + if (as<IRPtrTypeBase>(val.irValue->getDataType())) + { + if (!as<IRPtrTypeBase>(originalVal->getDataType())) + { + needMaterialize = true; + } + } + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); if (user->getOp() == kIROp_GetPerVertexInputArray) { auto arrayInst = getOrCreatePerVertexInputArray(context, val.irValue); user->replaceUsesWith(arrayInst); user->removeAndDeallocate(); + return; + } + IRBuilder builder(user); + builder.setInsertBefore(user); + if (needMaterialize) + { + auto materializedVal = materializeValue(&builder, val); + builder.replaceOperand(use, materializedVal); } else { - IRBuilder builder(user); - builder.setInsertBefore(user); builder.replaceOperand(use, val.irValue); } }); } break; + case ScalarizedVal::Flavor::typeAdapter: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(val.impl); + auto materializedInner = materializeValue(&builder, typeAdapter->val); + auto adapted = adaptType( + &builder, + materializedInner, + typeAdapter->pretendType, + typeAdapter->actualType); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(adapted.irValue); + user->removeAndDeallocate(); + } + else + { + use->set(adapted.irValue); + } + }); + } + break; + case ScalarizedVal::Flavor::arrayIndex: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto arrayIndexImpl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto subscriptVal = getSubscriptVal( + &builder, + arrayIndexImpl->elementType, + arrayIndexImpl->arrayVal, + arrayIndexImpl->index); + builder.setInsertBefore(user); + auto materializedInner = materializeValue(&builder, subscriptVal); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(materializedInner); + user->removeAndDeallocate(); + } + else + { + use->set(materializedInner); + } + }); + break; + } case ScalarizedVal::Flavor::tuple: { auto tupleVal = as<ScalarizedTupleValImpl>(val.impl); @@ -2793,22 +2909,36 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); - if (auto fieldExtract = as<IRFieldExtract>(user)) + switch (user->getOp()) { - auto fieldKey = fieldExtract->getField(); - ScalarizedVal fieldVal; - for (auto element : tupleVal->elements) + case kIROp_FieldExtract: + case kIROp_FieldAddress: { - if (element.key == fieldKey) + auto fieldKey = user->getOperand(1); + ScalarizedVal fieldVal; + for (auto element : tupleVal->elements) { - fieldVal = element.val; - break; + if (element.key == fieldKey) + { + fieldVal = element.val; + break; + } + } + if (fieldVal.flavor != ScalarizedVal::Flavor::none) + { + tryReplaceUsesOfStageInput(context, fieldVal, user); } } - if (fieldVal.flavor != ScalarizedVal::Flavor::none) + break; + case kIROp_Load: { - tryReplaceUsesOfStageInput(context, fieldVal, user); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto materializedVal = materializeTupleValue(&builder, val); + user->replaceUsesWith(materializedVal); + user->removeAndDeallocate(); } + break; } }); } @@ -3066,7 +3196,7 @@ void legalizeEntryPointParameterForGLSL( // We are going to create a local variable of the appropriate // type, which will replace the parameter, along with // one or more global variables for the actual input/output. - + setInsertAfterOrdinaryInst(builder, pp); auto localVariable = builder->emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); @@ -3135,6 +3265,73 @@ void legalizeEntryPointParameterForGLSL( assign(&terminatorBuilder, globalOutputVal, localVal); } } + else if (auto ptrType = as<IRPtrTypeBase>(paramType)) + { + // This is the case where the parameter is passed by const + // reference. We simply replace existing uses of the parameter + // with the real global variable. + SLANG_ASSERT( + ptrType->getOp() == kIROp_ConstRefType || + ptrType->getAddressSpace() == AddressSpace::Input || + ptrType->getAddressSpace() == AddressSpace::BuiltinInput); + + auto globalValue = createGLSLGlobalVaryings( + context, + codeGenContext, + builder, + valueType, + paramLayout, + LayoutResourceKind::VaryingInput, + stage, + pp); + tryReplaceUsesOfStageInput(context, globalValue, pp); + for (auto dec : pp->getDecorations()) + { + if (dec->getOp() != kIROp_GlobalVariableShadowingGlobalParameterDecoration) + continue; + auto globalVar = dec->getOperand(0); + auto key = dec->getOperand(1); + IRInst* realGlobalVar = nullptr; + if (globalValue.flavor != ScalarizedVal::Flavor::tuple) + continue; + if (auto tupleVal = as<ScalarizedTupleValImpl>(globalValue.impl)) + { + for (auto elem : tupleVal->elements) + { + if (elem.key == key) + { + realGlobalVar = elem.val.irValue; + break; + } + } + } + SLANG_ASSERT(realGlobalVar); + + // Remove all stores into the global var introduced during + // the initial glsl global var translation pass since we are + // going to replace the global var with a pointer to the real + // input, and it makes no sense to store values into such real + // input locations. + traverseUses( + globalVar, + [&](IRUse* use) + { + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + if (store->getPtrUse() == use) + { + store->removeAndDeallocate(); + } + } + }); + // we will be replacing uses of `globalVarToReplace`. We need + // globalVarToReplaceNextUse to catch the next use before it is removed from the + // list of uses. + globalVar->replaceUsesWith(realGlobalVar); + globalVar->removeAndDeallocate(); + } + } else { // This is the "easy" case where the parameter wasn't @@ -3451,6 +3648,7 @@ ScalarizedVal legalizeEntryPointReturnValueForGLSL( return result; } + void legalizeEntryPointForGLSL( Session* session, IRModule* module, @@ -3554,12 +3752,12 @@ void legalizeEntryPointForGLSL( // and turn them into global variables. if (auto firstBlock = func->getFirstBlock()) { - // Any initialization code we insert for parameters needs - // to be at the start of the "ordinary" instructions in the block: - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { + // Any initialization code we insert for parameters needs + // to be at the start of the "ordinary" instructions in the block: + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + // We assume that the entry-point parameters will all have // layout information attached to them, which is kept up-to-date // by any transformations affecting the parameter list. @@ -3606,11 +3804,11 @@ void legalizeEntryPointForGLSL( { auto type = value.globalParam->getDataType(); - // Strip out if there is one - auto outType = as<IROutType>(type); - if (outType) + // Strip ptr if there is one. + auto ptrType = as<IRPtrTypeBase>(type); + if (ptrType) { - type = outType->getValueType(); + type = ptrType->getValueType(); } // Get the array type @@ -3627,10 +3825,13 @@ void legalizeEntryPointForGLSL( auto elementCountInst = builder.getIntValue(builder.getIntType(), value.maxIndex + 1); IRType* sizedArrayType = builder.getArrayType(elementType, elementCountInst); - // Re-add out if there was one on the input - if (outType) + // Re-add ptr if there was one on the input + if (ptrType) { - sizedArrayType = builder.getOutType(sizedArrayType); + sizedArrayType = builder.getPtrType( + ptrType->getOp(), + sizedArrayType, + ptrType->getAddressSpace()); } // Change the globals type diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 01466ed00..88a9ac5e3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,7 +736,8 @@ INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1, INST(GetLegalizedSPIRVGlobalParamAddr, GetLegalizedSPIRVGlobalParamAddr, 1, 0) -INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, 0) +INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE) +INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE) INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0) INST(MetalAtomicCast, MetalAtomicCast, 1, 0) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 025bcf1b8..33f3944fd 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -651,7 +651,9 @@ protected: // The materialized value can be used to completely // replace the original parameter. // - param->replaceUsesWith(materialized); + auto localVar = builder.emitVar(materialized->getDataType()); + builder.emitStore(localVar, materialized); + param->replaceUsesWith(localVar); param->removeAndDeallocate(); } @@ -1475,4 +1477,71 @@ void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* si context.processModule(module, sink); } +void depointerizeInputParams(IRFunc* entryPointFunc) +{ + List<IRParam*> workList; + List<Index> modifiedParamIndices; + Index i = 0; + for (auto param : entryPointFunc->getParams()) + { + if (auto constRefType = as<IRConstRefType>(param->getFullType())) + { + switch (constRefType->getValueType()->getOp()) + { + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + continue; + default: + break; + } + workList.add(param); + modifiedParamIndices.add(i); + } + else if (auto ptrType = as<IRPtrTypeBase>(param->getFullType())) + { + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::BuiltinInput: + workList.add(param); + modifiedParamIndices.add(i); + break; + } + } + i++; + } + for (auto param : workList) + { + auto valueType = as<IRPtrTypeBase>(param->getDataType())->getValueType(); + IRBuilder builder(param); + setInsertBeforeOrdinaryInst(&builder, param); + auto var = builder.emitVar(valueType); + param->replaceUsesWith(var); + param->setFullType(valueType); + builder.emitStore(var, param); + } + + fixUpFuncType(entryPointFunc); + + // Fix up callsites of the entrypoint func. + for (auto use = entryPointFunc->firstUse; use; use = use->nextUse) + { + auto call = as<IRCall>(use->getUser()); + if (!call) + continue; + IRBuilder builder(call); + builder.setInsertBefore(call); + for (auto paramIndex : modifiedParamIndices) + { + auto arg = call->getArg(paramIndex); + auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); + if (!ptrType) + continue; + auto val = builder.emitLoad(arg); + call->setArg(paramIndex, val); + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index 7604cb245..efd61e87c 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -18,6 +18,7 @@ void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sin void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink); +void depointerizeInputParams(IRFunc* entryPoint); // (#4375) Once `slang-ir-metal-legalize.cpp` is merged with // `slang-ir-legalize-varying-params.cpp`, move the following diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 7f67c9254..74e84f1ee 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -412,6 +412,28 @@ struct LoweredElementTypeContext return 4; } + bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) + { + // For spirv, we always want to lower all matrix types, because SPIRV does not support + // specifying matrix layout/stride if the matrix type is used in places other than + // defining a struct field. This means that if a matrix is used to define a varying + // parameter, we always want to wrap it in a struct. + // + if (target->shouldEmitSPIRVDirectly()) + { + return true; + } + + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) + { + // For other targets, we only lower the matrix types if they differ from the default + // matrix layout. + return false; + } + return true; + } + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); @@ -422,18 +444,10 @@ struct LoweredElementTypeContext if (auto matrixType = as<IRMatrixType>(type)) { - // For spirv, we always want to lower all matrix types, because matrix types - // are considered abstract types. - if (!target->shouldEmitSPIRVDirectly()) + if (!shouldLowerMatrixType(matrixType, config)) { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } auto loweredType = builder.createStructType(); @@ -859,27 +873,24 @@ struct LoweredElementTypeContext { IRType* elementType = nullptr; - if (options.lowerBufferPointer) + if (auto ptrType = as<IRPtrTypeBase>(globalInst)) { - if (auto ptrType = as<IRPtrTypeBase>(globalInst)) + switch (ptrType->getAddressSpace()) { - switch (ptrType->getAddressSpace()) - { - case AddressSpace::UserPointer: - case AddressSpace::Input: - case AddressSpace::Output: - elementType = ptrType->getValueType(); - break; - } + case AddressSpace::UserPointer: + if (!options.lowerBufferPointer) + continue; + [[fallthrough]]; + case AddressSpace::Input: + case AddressSpace::Output: + elementType = ptrType->getValueType(); + break; } } - else - { - if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) - elementType = structBuffer->getElementType(); - else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) - elementType = constBuffer->getElementType(); - } + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); if (as<IRTextureBufferType>(globalInst)) continue; if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 835041a59..ce5b34c3e 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1924,6 +1924,7 @@ struct LegalizeMetalEntryPointContext void legalizeEntryPointForMetal(EntryPointInfo entryPoint) { // Input Parameter Legalize + depointerizeInputParams(entryPoint.entryPointFunc); hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir-resolve-varying-input-ref.cpp b/source/slang/slang-ir-resolve-varying-input-ref.cpp new file mode 100644 index 000000000..0707c566f --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.cpp @@ -0,0 +1,92 @@ +#include "slang-ir-resolve-varying-input-ref.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func) +{ + List<IRInst*> toRemove; + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + for (auto inst : bb->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_ResolveVaryingInputRef: + { + // Resolve a reference to varying input to the actual global param + // representing the varying input. + auto operand = inst->getOperand(0); + List<IRInst*> accessChain; + List<IRInst*> types; + auto rootAddr = getRootAddr(operand, accessChain, &types); + if (rootAddr->getOp() == kIROp_Param || rootAddr->getOp() == kIROp_GlobalParam) + { + // If the referred operand is already a global param, use it directly. + inst->replaceUsesWith(operand); + toRemove.add(inst); + break; + } + // If the referred operand is a local var, + // and there is a store(var, load(globalParam)), + // replace `inst` with `globalParam`. + IRInst* srcPtr = nullptr; + for (auto use = rootAddr->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + if (store->getPtrUse() == use) + { + if (auto load = as<IRLoad>(store->getVal())) + { + auto ptr = load->getPtr(); + if (ptr->getOp() == kIROp_Param || + ptr->getOp() == kIROp_GlobalParam) + { + if (!srcPtr) + srcPtr = ptr; + else + { + srcPtr = nullptr; + break; + } + } + } + } + } + } + if (srcPtr) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto resolvedPtr = builder.emitElementAddress( + srcPtr, + accessChain.getArrayView(), + types.getArrayView()); + inst->replaceUsesWith(resolvedPtr); + toRemove.add(inst); + } + } + break; + } + } + } + for (auto inst : toRemove) + { + inst->removeAndDeallocate(); + } +} + +void resolveVaryingInputRef(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + resolveVaryingInputRef((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-resolve-varying-input-ref.h b/source/slang/slang-ir-resolve-varying-input-ref.h new file mode 100644 index 000000000..5cbff0f8c --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func); +void resolveVaryingInputRef(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 65cb8f64f..a44e16a7c 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -122,7 +122,8 @@ struct GlobalVarTranslationContext // Add an entry point parameter for all the inputs. auto firstBlock = entryPointFunc->getFirstBlock(); builder.setInsertInto(firstBlock); - auto inputParam = builder.emitParam(inputStructType); + auto inputParam = builder.emitParam( + builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); builder.addLayoutDecoration(inputParam, paramLayout); // Initialize all global variables. @@ -133,7 +134,8 @@ struct GlobalVarTranslationContext auto inputType = cast<IRPtrTypeBase>(input->getDataType())->getValueType(); builder.emitStore( input, - builder.emitFieldExtract(inputType, inputParam, inputKeys[i])); + builder + .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i])); // Relate "global variable" to a "global parameter" for use later in compilation // to resolve a "global variable" shadowing a "global parameter" relationship. builder.addGlobalVariableShadowingGlobalParameterDecoration( diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp index e7fc81144..70f2584ac 100644 --- a/source/slang/slang-ir-vk-invert-y.cpp +++ b/source/slang/slang-ir-vk-invert-y.cpp @@ -104,10 +104,15 @@ void rcpWOfPositionInput(IRModule* module) [&](IRUse* use) { // Get the inverted vector. - builder.setInsertBefore(use->getUser()); - auto invertedVal = _invertWOfVector(builder, globalInst); - // Replace original uses with the invertex vector. - builder.replaceOperand(use, invertedVal); + auto user = use->getUser(); + if (user->getOp() == kIROp_Load) + { + builder.setInsertBefore(user); + auto val = builder.emitLoad(globalInst); + auto invertedVal = _invertWOfVector(builder, val); + user->replaceUsesWith(invertedVal); + user->removeAndDeallocate(); + } }); } } diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 907c2b8ba..f76a0541c 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1362,6 +1362,9 @@ struct LegalizeWGSLEntryPointContext void legalizeEntryPointForWGSL(EntryPointInfo entryPoint) { + // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. + depointerizeInputParams(entryPoint.entryPointFunc); + // Input Parameter Legalize flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ff1cd49ea..daeaca67b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8243,6 +8243,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetStringHash: case kIROp_AllocateOpaqueHandle: case kIROp_GetArrayLength: + case kIROp_ResolveVaryingInputRef: + case kIROp_GetPerVertexInputArray: return false; case kIROp_ForwardDifferentiate: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5bbe44e9b..011ea6bc7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2767,14 +2767,15 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl) /// ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection) { - auto parentParent = getParentDecl(parentDecl); + auto parentParent = getParentAggTypeDecl(parentDecl); + // The `this` parameter for a `class` is always `in`. if (as<ClassDecl>(parentParent)) { return kParameterDirection_In; } - if (parentParent->findModifier<NonCopyableTypeAttribute>()) + if (parentParent && parentParent->findModifier<NonCopyableTypeAttribute>()) { if (parentDecl->hasModifier<MutatingAttribute>()) return kParameterDirection_Ref; @@ -2982,6 +2983,9 @@ struct IRLoweringParameterInfo // The direction (`in` vs `out` vs `in out`) ParameterDirection direction; + // The direction declared in user code. + ParameterDirection declaredDirection = ParameterDirection::kParameterDirection_In; + // The variable/parameter declaration for // this parameter (if any) VarDeclBase* decl = nullptr; @@ -3005,6 +3009,7 @@ IRLoweringParameterInfo getParameterInfo( info.type = getParamType(context->astBuilder, paramDecl); info.decl = paramDecl.getDecl(); info.direction = getParameterDirection(paramDecl.getDecl()); + info.declaredDirection = info.direction; info.isThisParam = false; return info; } @@ -3051,6 +3056,7 @@ void addThisParameter(ParameterDirection direction, Type* type, ParameterLists* info.type = type; info.decl = nullptr; info.direction = direction; + info.declaredDirection = direction; info.isThisParam = true; ioParameterLists->params.add(info); @@ -3064,10 +3070,22 @@ void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resu info.type = resultType; info.decl = nullptr; info.direction = kParameterDirection_Ref; + info.declaredDirection = info.direction; info.isReturnDestination = true; ioParameterLists->params.add(info); } } + +void makeVaryingInputParamConstRef(IRLoweringParameterInfo& paramInfo) +{ + if (paramInfo.direction != kParameterDirection_In) + return; + if (paramInfo.decl->findModifier<HLSLUniformModifier>()) + return; + if (as<HLSLPatchType>(paramInfo.type)) + return; + paramInfo.direction = kParameterDirection_ConstRef; +} // // And here is our function that will do the recursive walk: void collectParameterLists( @@ -3137,13 +3155,31 @@ void collectParameterLists( // if (auto callableDeclRef = declRef.as<CallableDecl>()) { + // We need a special case here when lowering the varying parameters of an entrypoint + // function. Due to the existence of `EvaluateAttributeAtSample` and friends, we need to + // always lower the varying inputs as `__constref` parameters so we can pass pointers to + // these intrinsics. + // This means that although these parameters are declared as "in" parameters in the source, + // we will actually treat them as __constref parameters when lowering to IR. A complication + // result from this is that if the original source code actually modifies the input + // parameter we still need to create a local var to hold the modified value. In the future + // when we are able to update our language spec to always assume input parameters are + // immutable, then we can remove this adhoc logic of introducing temporary variables. For + // For now we will rely on a follow up pass to remove unnecessary temporary variables if + // we can determine that they are never actually writtten to by the user. + // + bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>(); + // Don't collect parameters from the outer scope if // we are in a `static` context. if (mode == kParameterListCollectMode_Default) { for (auto paramDeclRef : getParameters(context->astBuilder, callableDeclRef)) { - ioParameterLists->params.add(getParameterInfo(context, paramDeclRef)); + auto paramInfo = getParameterInfo(context, paramDeclRef); + if (lowerVaryingInputAsConstRef) + makeVaryingInputParamConstRef(paramInfo); + ioParameterLists->params.add(paramInfo); } maybeAddReturnDestinationParam( ioParameterLists, @@ -5623,9 +5659,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe LoweredValInfo visitOpenRefExpr(OpenRefExpr* expr) { auto inner = lowerLValueExpr(context, expr->innerExpr); - auto builder = getBuilder(); - auto irLoad = builder->emitLoad(inner.val); - return LoweredValInfo::simple(irLoad); + return LoweredValInfo::ptr(inner.val); } }; @@ -9980,6 +10014,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (paramInfo.isReturnDestination) subContext->returnDestination = paramVal; + if (paramInfo.declaredDirection == kParameterDirection_In && + paramInfo.direction == kParameterDirection_ConstRef) + { + // If the parameter is originally declared as "in", but we are + // lowering it as constref for any reason (e.g. it is a varying input), + // then we need to emit a local variable to hold the original value, so + // that we can still generate correct code when the user trys to mutate + // the variable. + // The local variable introduced here is cleaned up by the SSA pass, if + // we can determine that there are no actual writes into the local var. + auto irLocal = + subBuilder->emitVar(tryGetPointedToType(subBuilder, irParamType)); + auto localVal = LoweredValInfo::ptr(irLocal); + assign(subContext, localVal, paramVal); + paramVal = localVal; + } // TODO: We might want to copy the pointed-to value into // a temporary at the start of the function, and then copy // back out at the end, so that we don't have to worry @@ -10987,6 +11037,16 @@ static void lowerFrontEndEntryPointToIR( auto entryPointFuncDecl = entryPoint->getFuncDecl(); + if (!entryPointFuncDecl->findModifier<EntryPointAttribute>()) + { + // If the entry point doesn't have an explicit `[shader("...")]` attribute, + // then we make sure to add one here, so the lowering logic knows it is an + // entry point. + auto entryPointAttr = context->astBuilder->create<EntryPointAttribute>(); + entryPointAttr->capabilitySet = entryPoint->getProfile().getCapabilityName(); + addModifier(entryPointFuncDecl, entryPointAttr); + } + auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 22491c848..c275a868b 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7531,12 +7531,6 @@ static IRFloatingPointValue _foldFloatPrefixOp(TokenType tokenType, IRFloatingPo static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser) { - const auto slangIdentOperand = [&](auto flavor) - { - auto token = parser->tokenReader.peekToken(); - return SPIRVAsmOperand{flavor, token, parseAtomicExpr(parser)}; - }; - const auto slangTypeExprOperand = [&](auto flavor) { auto tok = parser->tokenReader.peekToken(); @@ -7673,12 +7667,13 @@ static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser) // A &foo variable reference (for the address of foo) else if (AdvanceIf(parser, TokenType::OpBitAnd)) { - return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr); + Expr* expr = parsePostfixExpr(parser); + return SPIRVAsmOperand{SPIRVAsmOperand::SlangValueAddr, Token{}, expr}; } // A $foo variable else if (AdvanceIf(parser, TokenType::Dollar)) { - Expr* expr = parseAtomicExpr(parser); + Expr* expr = parsePostfixExpr(parser); return SPIRVAsmOperand{SPIRVAsmOperand::SlangValue, Token{}, expr}; } // A $$foo type |
