diff options
| author | Yong He <yonghe@outlook.com> | 2025-01-07 22:26:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-07 22:26:31 -0800 |
| commit | c43f6fa55aca23365c86c6ec1737d42be74d9d3e (patch) | |
| tree | 2c49bc1dbd12ae5f46d682a3f240465931471060 | |
| parent | 1a56f58fdd0c704e6dc0fad0f0ec33a25a35e60b (diff) | |
Lower varying parameters as pointers instead of SSA values. (#5919)
* Add executable test on matrix-typed vertex input.
* Fix emit logic of matrix layout qualifier.
* Pass fragment shader varying input by constref to allow EvaluateAttributeAtCentroid etc. to be implemented correctly.
47 files changed, 1155 insertions, 303 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 625f8f608..adb7470dd 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1,3 +1,5 @@ +//public module core; + // Slang `core` library // Aliases for base types diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 2ff71a74e..ba26b5d84 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -9080,7 +9080,7 @@ public vector<float, N> fwidthCoarse(vector<float, N> p) [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtCentroid(__ref float interpolant) +public float interpolateAtCentroid(__constref float interpolant) { __target_switch { @@ -9099,7 +9099,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant) +public vector<float, N> interpolateAtCentroid(__constref vector<float, N> interpolant) { __target_switch { @@ -9118,7 +9118,7 @@ public vector<float, N> interpolateAtCentroid(__ref vector<float, N> interpolant [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtSample(__ref float interpolant, int sample) +public float interpolateAtSample(__constref float interpolant, int sample) { __target_switch { @@ -9137,7 +9137,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant, int sample) +public vector<float, N> interpolateAtSample(__constref vector<float, N> interpolant, int sample) { __target_switch { @@ -9156,7 +9156,7 @@ public vector<float, N> interpolateAtSample(__ref vector<float, N> interpolant, [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public float interpolateAtOffset(__ref float interpolant, vec2 offset) +public float interpolateAtOffset(__constref float interpolant, vec2 offset) { __target_switch { @@ -9175,7 +9175,7 @@ __generic<let N : int> [__NoSideEffect] [__GLSLRequireShaderInputParameter(0)] [require(glsl_spirv, fragmentprocessing)] -public vector<float, N> interpolateAtOffset(__ref vector<float, N> interpolant, vec2 offset) +public vector<float, N> interpolateAtOffset(__constref vector<float, N> interpolant, vec2 offset) { __target_switch { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d8f9845b8..d5b59427f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8101,89 +8101,139 @@ RasterizerOrderedStructuredBuffer<T> __getEquivalentStructuredBuffer<T>(Rasteriz // Attribute evaluation +T __EvaluateAttributeAtCentroid<T>(__constref T x) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtCentroid"; + case glsl: __intrinsic_asm "interpolateAtCentroid"; + } +} + // TODO: The matrix cases of these functions won't actuall work // when compiled to GLSL, since they only support scalar/vector // TODO: Should these be constrains to `__BuiltinFloatingPointType`? // TODO: SPIRV-direct does not support non-floating-point types. +/// Interpolates vertex attribute at centroid position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtCentroid(T x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtCentroid(__constref T x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeAtCentroid(vector<T,N> x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeAtCentroid(__constref vector<T,N> x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); case spirv: return spirv_asm { - OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $x + OpCapability InterpolationFunction; + OpExtInst $$vector<T,N> result glsl450 InterpolateAtCentroid $__ResolveVaryingInputRef(x) }; } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeAtCentroid(matrix<T,N,M> x) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeAtCentroid(__constref matrix<T,N,M> x) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtCentroid"; + case hlsl: + case glsl: + return __EvaluateAttributeAtCentroid(__ResolveVaryingInputRef(x)); default: MATRIX_MAP_UNARY(T, N, M, EvaluateAttributeAtCentroid, x); } } +T __EvaluateAttributeAtSample<T>(__constref T x, uint sampleIndex) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeAtSample"; + case glsl: __intrinsic_asm "interpolateAtSample"; + } +} + +/// Interpolates vertex attribute at the current fragment sample position. +/// @param x The vertex attribute to interpolate. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeAtSample(T x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeAtSample(__constref T x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$T result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$T result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeAtSample(vector<T,N> x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeAtSample(__constref vector<T,N> x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); case spirv: return spirv_asm { - OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $x $sampleindex + OpCapability InterpolationFunction; + OpExtInst $$vector<T,N> result glsl450 InterpolateAtSample $__ResolveVaryingInputRef(x) $sampleindex }; } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeAtSample(__constref matrix<T,N,M> x, uint sampleindex) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtSample($0, int($1))"; + case hlsl: + case glsl: + return __EvaluateAttributeAtSample(__ResolveVaryingInputRef(x), sampleindex); default: matrix<T,N,M> result; for(int i = 0; i < N; ++i) @@ -8194,21 +8244,59 @@ matrix<T,N,M> EvaluateAttributeAtSample(matrix<T,N,M> x, uint sampleindex) } } +T __EvaluateAttributeSnapped<T>(__constref T x, int2 offset) +{ + __target_switch + { + case hlsl: __intrinsic_asm "EvaluateAttributeSnapped"; + case glsl: __intrinsic_asm "EvaluateAttributeSnapped"; + } +} + +/// Interpolates vertex attribute at the specified subpixel offset. +/// @param x The vertex attribute to interpolate. +/// @param offset The subpixel offset. Each component is a 4-bit signed integer in range [-8, 7]. +/// @return The interpolated attribute value. +/// @remarks `x` must be a direct reference to a fragment shader varying input. +/// +/// The valid values of each component of `offset` are: +/// +/// - 1000 = -0.5f (-8 / 16) +/// - 1001 = -0.4375f (-7 / 16) +/// - 1010 = -0.375f (-6 / 16) +/// - 1011 = -0.3125f (-5 / 16) +/// - 1100 = -0.25f (-4 / 16) +/// - 1101 = -0.1875f (-3 / 16) +/// - 1110 = -0.125f (-2 / 16) +/// - 1111 = -0.0625f (-1 / 16) +/// - 0000 = 0.0f ( 0 / 16) +/// - 0001 = 0.0625f ( 1 / 16) +/// - 0010 = 0.125f ( 2 / 16) +/// - 0011 = 0.1875f ( 3 / 16) +/// - 0100 = 0.25f ( 4 / 16) +/// - 0101 = 0.3125f ( 5 / 16) +/// - 0110 = 0.375f ( 6 / 16) +/// - 0111 = 0.4375f ( 7 / 16) +/// @category interpolation Vertex Interpolation Functions __generic<T : __BuiltinArithmeticType> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -T EvaluateAttributeSnapped(T x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +T EvaluateAttributeSnapped(__constref T x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$T = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$T = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8216,19 +8304,23 @@ T EvaluateAttributeSnapped(T x, int2 offset) __generic<T : __BuiltinArithmeticType, let N : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +vector<T,N> EvaluateAttributeSnapped(__constref vector<T,N> x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); case spirv: { const float2 tmp = float2(16.f, 16.f); return spirv_asm { + OpCapability InterpolationFunction; %foffset:$$float2 = OpConvertSToF $offset; %offsetdiv16:$$float2 = OpFDiv %foffset $tmp; - result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $x %offsetdiv16 + result:$$vector<T,N> = OpExtInst glsl450 InterpolateAtOffset $__ResolveVaryingInputRef(x) %offsetdiv16 }; } } @@ -8236,12 +8328,15 @@ vector<T,N> EvaluateAttributeSnapped(vector<T,N> x, int2 offset) __generic<T : __BuiltinArithmeticType, let N : int, let M : int> [__readNone] -[require(glsl_spirv, fragmentprocessing)] -matrix<T,N,M> EvaluateAttributeSnapped(matrix<T,N,M> x, int2 offset) +[__unsafeForceInlineEarly] +[require(glsl_hlsl_spirv, fragmentprocessing)] +matrix<T,N,M> EvaluateAttributeSnapped(__constref matrix<T,N,M> x, int2 offset) { __target_switch { - case glsl: __intrinsic_asm "interpolateAtOffset($0, vec2($1) / 16.0f)"; + case hlsl: + case glsl: + return __EvaluateAttributeSnapped(__ResolveVaryingInputRef(x), offset); default: matrix<T,N,M> result; for(int i = 0; i < N; ++i) @@ -9243,8 +9338,16 @@ matrix<T, N, M> fwidth(matrix<T, N, M> x) } } +__intrinsic_op($(kIROp_ResolveVaryingInputRef)) +Ref<T> __ResolveVaryingInputRef<T>(__constref T attribute); + __intrinsic_op($(kIROp_GetPerVertexInputArray)) -Array<T, 3> __GetPerVertexInputArray<T>(T attribute); +Ref<Array<T, 3>> __GetPerVertexInputArray<T>(__constref T attribute); + +T __GetAttributeAtVertex<T>(__constref T attribute, uint vertexIndex) +{ + __intrinsic_asm "GetAttributeAtVertex"; +} /// Get the value of a vertex attribute at a specific vertex. /// @@ -9265,15 +9368,15 @@ __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] [KnownBuiltin("GetAttributeAtVertex")] [__unsafeForceInlineEarly] -T GetAttributeAtVertex(T attribute, uint vertexIndex) +T GetAttributeAtVertex(__constref T attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); case glsl: case spirv: - return __GetPerVertexInputArray(attribute)[vertexIndex]; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9294,20 +9397,16 @@ __generic<T : __BuiltinType, let N : int> __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +vector<T,N> GetAttributeAtVertex(__constref vector<T,N> attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_vectorT = OpTypePointer Input $$vector<T,N>; - %addr = OpAccessChain %_ptr_Input_vectorT $attribute $vertexIndex; - result:$$vector<T,N> = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } @@ -9328,20 +9427,16 @@ __generic<T : __BuiltinType, let N : int, let M : int> __glsl_version(450) __glsl_extension(GL_EXT_fragment_shader_barycentric) [require(glsl_hlsl_spirv, getattributeatvertex)] -matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex) +[__unsafeForceInlineEarly] +matrix<T,N,M> GetAttributeAtVertex(__constref matrix<T,N,M> attribute, uint vertexIndex) { __target_switch { case hlsl: - __intrinsic_asm "GetAttributeAtVertex"; - case glsl: - __intrinsic_asm "$0[$1]"; + return __GetAttributeAtVertex(__ResolveVaryingInputRef(attribute), vertexIndex); + case glsl: case spirv: - return spirv_asm { - %_ptr_Input_matrixT = OpTypePointer Input $$matrix<T,N,M>; - %addr = OpAccessChain %_ptr_Input_matrixT $attribute $vertexIndex; - result:$$matrix<T,N,M> = OpLoad %addr; - }; + return __GetPerVertexInputArray(__ResolveVaryingInputRef(attribute))[vertexIndex]; } } diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index e8886a59a..911455f17 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -461,6 +461,8 @@ class ModuleDecl : public NamespaceDeclBase /// `__implementing` etc. bool isInLegacyLanguage = true; + DeclVisibility defaultVisibility = DeclVisibility::Internal; + SLANG_UNREFLECTED /// Map a type to the list of extensions of that type (if any) declared in this module diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index b77003bff..85d2d0d9f 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -460,6 +460,26 @@ struct ASTDumpContext } } + void dump(DeclVisibility vis) + { + switch (vis) + { + case DeclVisibility::Private: + m_writer->emit("private"); + break; + case DeclVisibility::Internal: + m_writer->emit("internal"); + break; + case DeclVisibility::Public: + m_writer->emit("public"); + break; + default: + m_writer->emit(String((int)vis).getUnownedSlice()); + break; + } + } + + void dump(const QualType& qualType) { if (qualType.isLeftValue) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index ce3f1e64c..3667a36ba 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3186,6 +3186,11 @@ void SemanticsDeclVisitorBase::checkModule(ModuleDecl* moduleDecl) } } + if (moduleDecl->findModifier<PublicModifier>()) + { + moduleDecl->defaultVisibility = DeclVisibility::Public; + } + // We need/want to visit any `import` declarations before // anything else, to make sure that scoping works. // @@ -12604,8 +12609,10 @@ DeclVisibility getDeclVisibility(Decl* decl) } auto defaultVis = DeclVisibility::Default; if (auto parentModule = getModuleDecl(decl)) - defaultVis = - parentModule->isInLegacyLanguage ? DeclVisibility::Public : DeclVisibility::Internal; + { + defaultVis = parentModule->isInLegacyLanguage ? DeclVisibility::Public + : parentModule->defaultVisibility; + } // Members of other agg type decls will have their default visibility capped to the parents'. if (as<NamespaceDecl>(decl)) @@ -12790,10 +12797,6 @@ void diagnoseCapabilityProvenance( auto moduleDecl = getModuleDecl(declToPrint); if (thisModule != moduleDecl) break; - if (moduleDecl && moduleDecl->isInLegacyLanguage) - continue; - if (getDeclVisibility(declToPrint) == DeclVisibility::Public) - break; } if (previousDecl == declToPrint) break; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 83b668a33..95d5a2a7c 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -4099,6 +4099,15 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas elementType = QualType(ptrType->getValueType()); elementType.isLeftValue = true; } + else + { + auto newExpr = maybeOpenRef(expr); + if (newExpr != expr) + { + expr = newExpr; + continue; + } + } if (elementType.type) { auto derefExpr = m_astBuilder->create<DerefExpr>(); @@ -4108,9 +4117,10 @@ Expr* SemanticsVisitor::maybeDereference(Expr* inExpr, CheckBaseContext checkBas expr = derefExpr; continue; } - // Default case: just use the expression as-is - return expr; + break; } + // Default case: just use the expression as-is + return expr; } Expr* SemanticsVisitor::CheckMatrixSwizzleExpr( diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 3175f1b07..64cc9969c 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1537,6 +1537,30 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } + if (auto load = as<IRLoad>(inst)) + { + // Loads from a constref global param should always be folded. + auto ptrType = load->getPtr()->getDataType(); + if (load->getPtr()->getOp() == kIROp_GlobalParam) + { + if (ptrType->getOp() == kIROp_ConstRefType) + return true; + if (auto ptrTypeBase = as<IRPtrTypeBase>(ptrType)) + { + auto addrSpace = ptrTypeBase->getAddressSpace(); + switch (addrSpace) + { + case Slang::AddressSpace::Uniform: + case Slang::AddressSpace::Input: + case Slang::AddressSpace::BuiltinInput: + return true; + default: + break; + } + } + } + } + // Always hold if inst is a call into an [__alwaysFoldIntoUseSite] function. if (auto call = as<IRCall>(inst)) { @@ -4709,9 +4733,21 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) auto rawType = varDecl->getDataType(); auto varType = rawType; - if (auto outType = as<IROutTypeBase>(varType)) + if (auto ptrType = as<IRPtrTypeBase>(varType)) { - varType = outType->getValueType(); + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::Output: + case AddressSpace::BuiltinInput: + case AddressSpace::BuiltinOutput: + varType = ptrType->getValueType(); + break; + default: + if (as<IROutTypeBase>(ptrType)) + varType = ptrType->getValueType(); + break; + } } if (as<IRVoidType>(varType)) return; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index f91c4d06e..b0d1fbb4c 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -310,6 +310,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S } case kIROp_NativePtrType: case kIROp_PtrType: + case kIROp_ConstRefType: { auto elementType = (IRType*)type->getOperand(0); SLANG_RETURN_ON_FAIL(calcTypeName(elementType, target, out)); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index a863e7eb1..23fff37ac 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -3166,6 +3166,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType()); + return; + } default: break; } @@ -3471,15 +3476,18 @@ void GLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* varType) // auto matrixType = as<IRMatrixType>(unwrapArray(varType)); - if (matrixType) { + auto layout = getIntVal(matrixType->getLayout()); + if (layout == getTargetProgram()->getOptionSet().getMatrixLayoutMode()) + return; + // Reminder: the meaning of row/column major layout // in our semantics is the *opposite* of what GLSL // calls them, because what they call "columns" // are what we call "rows." // - switch (getIntVal(matrixType->getLayout())) + switch (layout) { case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: m_writer->emit("layout(row_major)\n"); diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 40d6f75d9..83eec17b4 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1180,6 +1180,34 @@ void HLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst) Super::emitSimpleValueImpl(inst); } +void HLSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) +{ + if (declarator) + { + // HLSL only allow matrix layout modifier when declaring a variable or struct field. + if (auto matType = as<IRMatrixType>(type)) + { + auto matrixLayout = getIntVal(matType->getLayout()); + if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != + (MatrixLayoutMode)matrixLayout) + { + switch (matrixLayout) + { + case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: + m_writer->emit("column_major "); + break; + case SLANG_MATRIX_LAYOUT_ROW_MAJOR: + m_writer->emit("row_major "); + break; + default: + break; + } + } + } + } + Super::emitSimpleTypeAndDeclaratorImpl(type, declarator); +} + void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { switch (type->getOp()) @@ -1313,6 +1341,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); return; } + case kIROp_ConstRefType: + { + emitSimpleTypeImpl(as<IRConstRefType>(type)->getValueType()); + return; + } default: break; } @@ -1671,28 +1704,6 @@ void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) } } -void HLSLSourceEmitter::emitMatrixLayoutModifiersImpl(IRType* type) -{ - auto matType = as<IRMatrixType>(type); - if (!matType) - return; - auto matrixLayout = getIntVal(matType->getLayout()); - if (getTargetProgram()->getOptionSet().getMatrixLayoutMode() != (MatrixLayoutMode)matrixLayout) - { - switch (matrixLayout) - { - case SLANG_MATRIX_LAYOUT_COLUMN_MAJOR: - m_writer->emit("column_major "); - break; - case SLANG_MATRIX_LAYOUT_ROW_MAJOR: - m_writer->emit("row_major "); - break; - default: - break; - } - } -} - void HLSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) { if (inst->findDecoration<IRRequiresNVAPIDecoration>()) diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index b2e2ca05a..6b99a7f50 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -55,11 +55,12 @@ protected: IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) + SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; - virtual void emitMatrixLayoutModifiersImpl(IRType* varType) SLANG_OVERRIDE; virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE { emitMatrixLayoutModifiersImpl(type); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 7ea2fef88..b9217de41 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -42,6 +42,7 @@ #include "slang-ir-entry-point-uniforms.h" #include "slang-ir-explicit-global-context.h" #include "slang-ir-explicit-global-init.h" +#include "slang-ir-fix-entrypoint-callsite.h" #include "slang-ir-fuse-satcoop.h" #include "slang-ir-glsl-legalize.h" #include "slang-ir-glsl-liveness.h" @@ -76,6 +77,7 @@ #include "slang-ir-pytorch-cpp-binding.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-resolve-texture-format.h" +#include "slang-ir-resolve-varying-input-ref.h" #include "slang-ir-restructure-scoping.h" #include "slang-ir-restructure.h" #include "slang-ir-sccp.h" @@ -314,6 +316,7 @@ struct RequiredLoweringPassSet bool glslSSBO; bool byteAddressBuffer; bool dynamicResource; + bool resolveVaryingInputRef; }; // Scan the IR module and determine which lowering/legalization passes are needed based @@ -423,6 +426,9 @@ void calcRequiredLoweringPassSet( case kIROp_DynamicResourceType: result.dynamicResource = true; break; + case kIROp_ResolveVaryingInputRef: + result.resolveVaryingInputRef = true; + break; } if (!result.generics || !result.existentialTypeLayout) { @@ -591,6 +597,11 @@ Result linkAndOptimizeIR( if (requiredLoweringPassSet.glslGlobalVar) translateGLSLGlobalVar(codeGenContext, irModule); + if (requiredLoweringPassSet.resolveVaryingInputRef) + resolveVaryingInputRef(irModule); + + fixEntryPointCallsites(irModule); + // Replace any global constants with their values. // replaceGlobalConstants(irModule); diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.cpp b/source/slang/slang-ir-fix-entrypoint-callsite.cpp new file mode 100644 index 000000000..7390f3a7f --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.cpp @@ -0,0 +1,101 @@ +#include "slang-ir-fix-entrypoint-callsite.h" + +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +// If the entrypoint is called by some other function, we need to clone the +// entrypoint and replace the callsites to call the cloned entrypoint instead. +// This is because we will be modifying the signature of the entrypoint during +// entrypoint legalization to rewrite the way system values are passed in. +// By replacing the callsites to call the cloned entrypoint that act as ordinary +// functions, we will no longer need to worry about changing the callsites when we +// legalize the entry-points. +// +void fixEntryPointCallsites(IRFunc* entryPoint) +{ + IRFunc* clonedEntryPointForCall = nullptr; + auto ensureClonedEntryPointForCall = [&]() -> IRFunc* + { + if (clonedEntryPointForCall) + return clonedEntryPointForCall; + IRCloneEnv cloneEnv; + IRBuilder builder(entryPoint); + builder.setInsertBefore(entryPoint); + clonedEntryPointForCall = (IRFunc*)cloneInst(&cloneEnv, &builder, entryPoint); + // Remove entrypoint and linkage decorations from the cloned callee. + List<IRInst*> decorsToRemove; + for (auto decor : clonedEntryPointForCall->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ExportDecoration: + case kIROp_UserExternDecoration: + case kIROp_HLSLExportDecoration: + case kIROp_EntryPointDecoration: + case kIROp_LayoutDecoration: + case kIROp_NumThreadsDecoration: + case kIROp_ImportDecoration: + case kIROp_ExternCDecoration: + case kIROp_ExternCppDecoration: + decorsToRemove.add(decor); + break; + } + } + for (auto decor : decorsToRemove) + decor->removeAndDeallocate(); + return clonedEntryPointForCall; + }; + traverseUses( + entryPoint, + [&](IRUse* use) + { + auto user = use->getUser(); + auto call = as<IRCall>(user); + if (!call) + return; + auto callee = ensureClonedEntryPointForCall(); + call->setOperand(0, callee); + + // Fix up argument types: if the callee entrypoint is expecting a constref + // and the caller is passing a value, we need to wrap the value in a temporary var + // and pass the temporary var. + // + auto funcType = as<IRFuncType>(callee->getDataType()); + SLANG_ASSERT(funcType); + IRBuilder builder(call); + builder.setInsertBefore(call); + List<IRParam*> params; + for (auto param : callee->getParams()) + params.add(param); + if ((UInt)params.getCount() != call->getArgCount()) + return; + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto paramType = params[i]->getDataType(); + auto arg = call->getArg(i); + if (auto refType = as<IRConstRefType>(paramType)) + { + if (!as<IRPtrTypeBase>(arg->getDataType())) + { + auto tempVar = builder.emitVar(refType->getValueType()); + builder.emitStore(tempVar, arg); + call->setArg(i, tempVar); + } + } + } + }); +} + +void fixEntryPointCallsites(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + fixEntryPointCallsites((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-fix-entrypoint-callsite.h b/source/slang/slang-ir-fix-entrypoint-callsite.h new file mode 100644 index 000000000..493d67a77 --- /dev/null +++ b/source/slang/slang-ir-fix-entrypoint-callsite.h @@ -0,0 +1,9 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void fixEntryPointCallsites(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 09bf245df..39f970319 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -67,11 +67,7 @@ struct ScalarizedValImpl : RefObject }; struct ScalarizedTupleValImpl; struct ScalarizedTypeAdapterValImpl; - -struct ScalarizedArrayIndexValImpl : ScalarizedValImpl -{ - Index index; -}; +struct ScalarizedArrayIndexValImpl; struct ScalarizedVal { @@ -132,15 +128,12 @@ struct ScalarizedVal result.impl = (ScalarizedValImpl*)impl; return result; } - static ScalarizedVal scalarizedArrayIndex(IRInst* irValue, Index index) + static ScalarizedVal scalarizedArrayIndex(ScalarizedArrayIndexValImpl* impl) { ScalarizedVal result; result.flavor = Flavor::arrayIndex; - auto impl = new ScalarizedArrayIndexValImpl; - impl->index = index; - - result.irValue = irValue; - result.impl = impl; + result.irValue = nullptr; + result.impl = (ScalarizedValImpl*)impl; return result; } @@ -151,8 +144,6 @@ struct ScalarizedVal RefPtr<ScalarizedValImpl> impl; }; -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - // This is the case for a value that is a "tuple" of other values struct ScalarizedTupleValImpl : ScalarizedValImpl { @@ -175,6 +166,36 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl IRType* pretendType; // the type this value pretends to have }; +struct ScalarizedArrayIndexValImpl : ScalarizedValImpl +{ + ScalarizedVal arrayVal; + Index index; + IRType* elementType; +}; + +ScalarizedVal extractField( + IRBuilder* builder, + ScalarizedVal const& val, + UInt fieldIndex, // Pass ~0 in to search for the index via the key + IRStructKey* fieldKey); +ScalarizedVal adaptType(IRBuilder* builder, IRInst* val, IRType* toType, IRType* fromType); +ScalarizedVal adaptType( + IRBuilder* builder, + ScalarizedVal const& val, + IRType* toType, + IRType* fromType); +IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + IRInst* indexVal); +ScalarizedVal getSubscriptVal( + IRBuilder* builder, + IRType* elementType, + ScalarizedVal val, + UInt index); + struct GlobalVaryingDeclarator { enum class Flavor @@ -1303,6 +1324,22 @@ ScalarizedVal createSimpleGLSLGlobalVarying( } } + AddressSpace addrSpace = AddressSpace::Uniform; + IROp ptrOpCode = kIROp_PtrType; + switch (kind) + { + case LayoutResourceKind::VaryingInput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinInput : AddressSpace::Input; + break; + case LayoutResourceKind::VaryingOutput: + addrSpace = systemValueInfo ? AddressSpace::BuiltinOutput : AddressSpace::Output; + ptrOpCode = kIROp_OutType; + break; + default: + break; + } + + // If we have a declarator, we just use the normal logic, as that seems to work correctly // if (systemValueInfo && systemValueInfo->arrayIndex >= 0 && declarator == nullptr) @@ -1339,9 +1376,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // Set the array size to 0, to mean it is unsized auto arrayType = builder->getArrayType(type, 0); - IRType* paramType = kind == LayoutResourceKind::VaryingOutput - ? (IRType*)builder->getOutType(arrayType) - : arrayType; + IRType* paramType = builder->getPtrType(ptrOpCode, arrayType, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); @@ -1371,9 +1406,12 @@ ScalarizedVal createSimpleGLSLGlobalVarying( semanticGlobal->addIndex(systemValueInfo->arrayIndex); // Make it an array index - ScalarizedVal val = ScalarizedVal::scalarizedArrayIndex( - semanticGlobal->globalParam, - systemValueInfo->arrayIndex); + ScalarizedVal val = ScalarizedVal::address(semanticGlobal->globalParam); + RefPtr<ScalarizedArrayIndexValImpl> arrayImpl = new ScalarizedArrayIndexValImpl(); + arrayImpl->arrayVal = val; + arrayImpl->index = systemValueInfo->arrayIndex; + arrayImpl->elementType = type; + val = ScalarizedVal::scalarizedArrayIndex(arrayImpl); // We need to make this access, an array access to the global if (auto fromType = systemValueInfo->requiredType) @@ -1466,14 +1504,14 @@ ScalarizedVal createSimpleGLSLGlobalVarying( // like our IR function parameters, and need a wrapper // `Out<...>` type to represent outputs. // - bool isOutput = (kind == LayoutResourceKind::VaryingOutput); - IRType* paramType = isOutput ? builder->getOutType(type) : type; + + // Non system value varying inputs shall be passed as pointers. + IRType* paramType = builder->getPtrType(ptrOpCode, type, addrSpace); auto globalParam = addGlobalParam(builder->getModule(), paramType); moveValueBefore(globalParam, builder->getFunc()); - ScalarizedVal val = - isOutput ? ScalarizedVal::address(globalParam) : ScalarizedVal::value(globalParam); + ScalarizedVal val = ScalarizedVal::address(globalParam); if (systemValueInfo) { @@ -1958,10 +1996,10 @@ ScalarizedVal adaptType( break; case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as<ScalarizedArrayIndexValImpl>(val.impl)->index); - return adaptType(builder, element, toType, fromType); + auto arrayImpl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto elementVal = + getSubscriptVal(builder, fromType, arrayImpl->arrayVal, arrayImpl->index); + return adaptType(builder, elementVal, toType, fromType); } break; default: @@ -1970,8 +2008,6 @@ ScalarizedVal adaptType( } } -IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val); - void assign( IRBuilder* builder, ScalarizedVal const& left, @@ -1988,16 +2024,12 @@ void assign( // Determine the index auto leftArrayIndexVal = as<ScalarizedArrayIndexValImpl>(left.impl); - const auto arrayIndex = leftArrayIndexVal->index; - - auto arrayIndexInst = builder->getIntValue(builder->getIntType(), arrayIndex); - - // Store to the index - auto address = builder->emitElementAddress( - builder->getPtrType(right.irValue->getFullType()), - left.irValue, - arrayIndexInst); - builder->emitStore(address, rhs); + auto leftVal = getSubscriptVal( + builder, + leftArrayIndexVal->elementType, + leftArrayIndexVal->arrayVal, + leftArrayIndexVal->index); + builder->emitStore(leftVal.irValue, rhs); break; } @@ -2236,10 +2268,10 @@ IRInst* materializeValue(IRBuilder* builder, ScalarizedVal const& val) case ScalarizedVal::Flavor::arrayIndex: { - auto element = builder->emitElementExtract( - val.irValue, - as<ScalarizedArrayIndexValImpl>(val.impl)->index); - return element; + auto impl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto elementVal = + getSubscriptVal(builder, impl->elementType, impl->arrayVal, impl->index); + return materializeValue(builder, elementVal); } case ScalarizedVal::Flavor::tuple: { @@ -2735,9 +2767,9 @@ IRInst* getOrCreatePerVertexInputArray(GLSLLegalizationContext* context, IRInst* IRBuilder builder(inputVertexAttr); builder.setInsertBefore(inputVertexAttr); auto arrayType = builder.getArrayType( - inputVertexAttr->getDataType(), + tryGetPointedToType(&builder, inputVertexAttr->getDataType()), builder.getIntValue(builder.getIntType(), 3)); - arrayInst = builder.createGlobalParam(arrayType); + arrayInst = builder.createGlobalParam(builder.getPtrType(arrayType, AddressSpace::Input)); context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst; builder.addDecoration(arrayInst, kIROp_PerVertexDecoration); @@ -2770,21 +2802,105 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + builder.replaceOperand(use, val.irValue); + }); + } + break; + case ScalarizedVal::Flavor::address: + { + bool needMaterialize = false; + if (as<IRPtrTypeBase>(val.irValue->getDataType())) + { + if (!as<IRPtrTypeBase>(originalVal->getDataType())) + { + needMaterialize = true; + } + } + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); if (user->getOp() == kIROp_GetPerVertexInputArray) { auto arrayInst = getOrCreatePerVertexInputArray(context, val.irValue); user->replaceUsesWith(arrayInst); user->removeAndDeallocate(); + return; + } + IRBuilder builder(user); + builder.setInsertBefore(user); + if (needMaterialize) + { + auto materializedVal = materializeValue(&builder, val); + builder.replaceOperand(use, materializedVal); } else { - IRBuilder builder(user); - builder.setInsertBefore(user); builder.replaceOperand(use, val.irValue); } }); } break; + case ScalarizedVal::Flavor::typeAdapter: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto typeAdapter = as<ScalarizedTypeAdapterValImpl>(val.impl); + auto materializedInner = materializeValue(&builder, typeAdapter->val); + auto adapted = adaptType( + &builder, + materializedInner, + typeAdapter->pretendType, + typeAdapter->actualType); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(adapted.irValue); + user->removeAndDeallocate(); + } + else + { + use->set(adapted.irValue); + } + }); + } + break; + case ScalarizedVal::Flavor::arrayIndex: + { + traverseUses( + originalVal, + [&](IRUse* use) + { + auto arrayIndexImpl = as<ScalarizedArrayIndexValImpl>(val.impl); + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto subscriptVal = getSubscriptVal( + &builder, + arrayIndexImpl->elementType, + arrayIndexImpl->arrayVal, + arrayIndexImpl->index); + builder.setInsertBefore(user); + auto materializedInner = materializeValue(&builder, subscriptVal); + if (user->getOp() == kIROp_Load) + { + user->replaceUsesWith(materializedInner); + user->removeAndDeallocate(); + } + else + { + use->set(materializedInner); + } + }); + break; + } case ScalarizedVal::Flavor::tuple: { auto tupleVal = as<ScalarizedTupleValImpl>(val.impl); @@ -2793,22 +2909,36 @@ void tryReplaceUsesOfStageInput( [&](IRUse* use) { auto user = use->getUser(); - if (auto fieldExtract = as<IRFieldExtract>(user)) + switch (user->getOp()) { - auto fieldKey = fieldExtract->getField(); - ScalarizedVal fieldVal; - for (auto element : tupleVal->elements) + case kIROp_FieldExtract: + case kIROp_FieldAddress: { - if (element.key == fieldKey) + auto fieldKey = user->getOperand(1); + ScalarizedVal fieldVal; + for (auto element : tupleVal->elements) { - fieldVal = element.val; - break; + if (element.key == fieldKey) + { + fieldVal = element.val; + break; + } + } + if (fieldVal.flavor != ScalarizedVal::Flavor::none) + { + tryReplaceUsesOfStageInput(context, fieldVal, user); } } - if (fieldVal.flavor != ScalarizedVal::Flavor::none) + break; + case kIROp_Load: { - tryReplaceUsesOfStageInput(context, fieldVal, user); + IRBuilder builder(user); + builder.setInsertBefore(user); + auto materializedVal = materializeTupleValue(&builder, val); + user->replaceUsesWith(materializedVal); + user->removeAndDeallocate(); } + break; } }); } @@ -3066,7 +3196,7 @@ void legalizeEntryPointParameterForGLSL( // We are going to create a local variable of the appropriate // type, which will replace the parameter, along with // one or more global variables for the actual input/output. - + setInsertAfterOrdinaryInst(builder, pp); auto localVariable = builder->emitVar(valueType); auto localVal = ScalarizedVal::address(localVariable); @@ -3135,6 +3265,73 @@ void legalizeEntryPointParameterForGLSL( assign(&terminatorBuilder, globalOutputVal, localVal); } } + else if (auto ptrType = as<IRPtrTypeBase>(paramType)) + { + // This is the case where the parameter is passed by const + // reference. We simply replace existing uses of the parameter + // with the real global variable. + SLANG_ASSERT( + ptrType->getOp() == kIROp_ConstRefType || + ptrType->getAddressSpace() == AddressSpace::Input || + ptrType->getAddressSpace() == AddressSpace::BuiltinInput); + + auto globalValue = createGLSLGlobalVaryings( + context, + codeGenContext, + builder, + valueType, + paramLayout, + LayoutResourceKind::VaryingInput, + stage, + pp); + tryReplaceUsesOfStageInput(context, globalValue, pp); + for (auto dec : pp->getDecorations()) + { + if (dec->getOp() != kIROp_GlobalVariableShadowingGlobalParameterDecoration) + continue; + auto globalVar = dec->getOperand(0); + auto key = dec->getOperand(1); + IRInst* realGlobalVar = nullptr; + if (globalValue.flavor != ScalarizedVal::Flavor::tuple) + continue; + if (auto tupleVal = as<ScalarizedTupleValImpl>(globalValue.impl)) + { + for (auto elem : tupleVal->elements) + { + if (elem.key == key) + { + realGlobalVar = elem.val.irValue; + break; + } + } + } + SLANG_ASSERT(realGlobalVar); + + // Remove all stores into the global var introduced during + // the initial glsl global var translation pass since we are + // going to replace the global var with a pointer to the real + // input, and it makes no sense to store values into such real + // input locations. + traverseUses( + globalVar, + [&](IRUse* use) + { + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + if (store->getPtrUse() == use) + { + store->removeAndDeallocate(); + } + } + }); + // we will be replacing uses of `globalVarToReplace`. We need + // globalVarToReplaceNextUse to catch the next use before it is removed from the + // list of uses. + globalVar->replaceUsesWith(realGlobalVar); + globalVar->removeAndDeallocate(); + } + } else { // This is the "easy" case where the parameter wasn't @@ -3451,6 +3648,7 @@ ScalarizedVal legalizeEntryPointReturnValueForGLSL( return result; } + void legalizeEntryPointForGLSL( Session* session, IRModule* module, @@ -3554,12 +3752,12 @@ void legalizeEntryPointForGLSL( // and turn them into global variables. if (auto firstBlock = func->getFirstBlock()) { - // Any initialization code we insert for parameters needs - // to be at the start of the "ordinary" instructions in the block: - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - for (auto pp = firstBlock->getFirstParam(); pp; pp = pp->getNextParam()) { + // Any initialization code we insert for parameters needs + // to be at the start of the "ordinary" instructions in the block: + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + // We assume that the entry-point parameters will all have // layout information attached to them, which is kept up-to-date // by any transformations affecting the parameter list. @@ -3606,11 +3804,11 @@ void legalizeEntryPointForGLSL( { auto type = value.globalParam->getDataType(); - // Strip out if there is one - auto outType = as<IROutType>(type); - if (outType) + // Strip ptr if there is one. + auto ptrType = as<IRPtrTypeBase>(type); + if (ptrType) { - type = outType->getValueType(); + type = ptrType->getValueType(); } // Get the array type @@ -3627,10 +3825,13 @@ void legalizeEntryPointForGLSL( auto elementCountInst = builder.getIntValue(builder.getIntType(), value.maxIndex + 1); IRType* sizedArrayType = builder.getArrayType(elementType, elementCountInst); - // Re-add out if there was one on the input - if (outType) + // Re-add ptr if there was one on the input + if (ptrType) { - sizedArrayType = builder.getOutType(sizedArrayType); + sizedArrayType = builder.getPtrType( + ptrType->getOp(), + sizedArrayType, + ptrType->getAddressSpace()); } // Change the globals type diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 01466ed00..88a9ac5e3 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -736,7 +736,8 @@ INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1, INST(GetLegalizedSPIRVGlobalParamAddr, GetLegalizedSPIRVGlobalParamAddr, 1, 0) -INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, 0) +INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE) +INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE) INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0) INST(MetalAtomicCast, MetalAtomicCast, 1, 0) diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 025bcf1b8..33f3944fd 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -651,7 +651,9 @@ protected: // The materialized value can be used to completely // replace the original parameter. // - param->replaceUsesWith(materialized); + auto localVar = builder.emitVar(materialized->getDataType()); + builder.emitStore(localVar, materialized); + param->replaceUsesWith(localVar); param->removeAndDeallocate(); } @@ -1475,4 +1477,71 @@ void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* si context.processModule(module, sink); } +void depointerizeInputParams(IRFunc* entryPointFunc) +{ + List<IRParam*> workList; + List<Index> modifiedParamIndices; + Index i = 0; + for (auto param : entryPointFunc->getParams()) + { + if (auto constRefType = as<IRConstRefType>(param->getFullType())) + { + switch (constRefType->getValueType()->getOp()) + { + case kIROp_VerticesType: + case kIROp_IndicesType: + case kIROp_PrimitivesType: + continue; + default: + break; + } + workList.add(param); + modifiedParamIndices.add(i); + } + else if (auto ptrType = as<IRPtrTypeBase>(param->getFullType())) + { + switch (ptrType->getAddressSpace()) + { + case AddressSpace::Input: + case AddressSpace::BuiltinInput: + workList.add(param); + modifiedParamIndices.add(i); + break; + } + } + i++; + } + for (auto param : workList) + { + auto valueType = as<IRPtrTypeBase>(param->getDataType())->getValueType(); + IRBuilder builder(param); + setInsertBeforeOrdinaryInst(&builder, param); + auto var = builder.emitVar(valueType); + param->replaceUsesWith(var); + param->setFullType(valueType); + builder.emitStore(var, param); + } + + fixUpFuncType(entryPointFunc); + + // Fix up callsites of the entrypoint func. + for (auto use = entryPointFunc->firstUse; use; use = use->nextUse) + { + auto call = as<IRCall>(use->getUser()); + if (!call) + continue; + IRBuilder builder(call); + builder.setInsertBefore(call); + for (auto paramIndex : modifiedParamIndices) + { + auto arg = call->getArg(paramIndex); + auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); + if (!ptrType) + continue; + auto val = builder.emitLoad(arg); + call->setArg(paramIndex, val); + } + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index 7604cb245..efd61e87c 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -18,6 +18,7 @@ void legalizeEntryPointVaryingParamsForCPU(IRModule* module, DiagnosticSink* sin void legalizeEntryPointVaryingParamsForCUDA(IRModule* module, DiagnosticSink* sink); +void depointerizeInputParams(IRFunc* entryPoint); // (#4375) Once `slang-ir-metal-legalize.cpp` is merged with // `slang-ir-legalize-varying-params.cpp`, move the following diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp index 7f67c9254..74e84f1ee 100644 --- a/source/slang/slang-ir-lower-buffer-element-type.cpp +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -412,6 +412,28 @@ struct LoweredElementTypeContext return 4; } + bool shouldLowerMatrixType(IRMatrixType* matrixType, TypeLoweringConfig config) + { + // For spirv, we always want to lower all matrix types, because SPIRV does not support + // specifying matrix layout/stride if the matrix type is used in places other than + // defining a struct field. This means that if a matrix is used to define a varying + // parameter, we always want to wrap it in a struct. + // + if (target->shouldEmitSPIRVDirectly()) + { + return true; + } + + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && + config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) + { + // For other targets, we only lower the matrix types if they differ from the default + // matrix layout. + return false; + } + return true; + } + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type, TypeLoweringConfig config) { IRBuilder builder(type); @@ -422,18 +444,10 @@ struct LoweredElementTypeContext if (auto matrixType = as<IRMatrixType>(type)) { - // For spirv, we always want to lower all matrix types, because matrix types - // are considered abstract types. - if (!target->shouldEmitSPIRVDirectly()) + if (!shouldLowerMatrixType(matrixType, config)) { - // For other targets, we only lower the matrix types if they differ from the default - // matrix layout. - if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout && - config.layoutRule->ruleName == IRTypeLayoutRuleName::Natural) - { - info.loweredType = type; - return info; - } + info.loweredType = type; + return info; } auto loweredType = builder.createStructType(); @@ -859,27 +873,24 @@ struct LoweredElementTypeContext { IRType* elementType = nullptr; - if (options.lowerBufferPointer) + if (auto ptrType = as<IRPtrTypeBase>(globalInst)) { - if (auto ptrType = as<IRPtrTypeBase>(globalInst)) + switch (ptrType->getAddressSpace()) { - switch (ptrType->getAddressSpace()) - { - case AddressSpace::UserPointer: - case AddressSpace::Input: - case AddressSpace::Output: - elementType = ptrType->getValueType(); - break; - } + case AddressSpace::UserPointer: + if (!options.lowerBufferPointer) + continue; + [[fallthrough]]; + case AddressSpace::Input: + case AddressSpace::Output: + elementType = ptrType->getValueType(); + break; } } - else - { - if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) - elementType = structBuffer->getElementType(); - else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) - elementType = constBuffer->getElementType(); - } + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); if (as<IRTextureBufferType>(globalInst)) continue; if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index 835041a59..ce5b34c3e 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1924,6 +1924,7 @@ struct LegalizeMetalEntryPointContext void legalizeEntryPointForMetal(EntryPointInfo entryPoint) { // Input Parameter Legalize + depointerizeInputParams(entryPoint.entryPointFunc); hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir-resolve-varying-input-ref.cpp b/source/slang/slang-ir-resolve-varying-input-ref.cpp new file mode 100644 index 000000000..0707c566f --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.cpp @@ -0,0 +1,92 @@ +#include "slang-ir-resolve-varying-input-ref.h" + +#include "slang-ir-insts.h" +#include "slang-ir-util.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func) +{ + List<IRInst*> toRemove; + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + for (auto inst : bb->getChildren()) + { + switch (inst->getOp()) + { + case kIROp_ResolveVaryingInputRef: + { + // Resolve a reference to varying input to the actual global param + // representing the varying input. + auto operand = inst->getOperand(0); + List<IRInst*> accessChain; + List<IRInst*> types; + auto rootAddr = getRootAddr(operand, accessChain, &types); + if (rootAddr->getOp() == kIROp_Param || rootAddr->getOp() == kIROp_GlobalParam) + { + // If the referred operand is already a global param, use it directly. + inst->replaceUsesWith(operand); + toRemove.add(inst); + break; + } + // If the referred operand is a local var, + // and there is a store(var, load(globalParam)), + // replace `inst` with `globalParam`. + IRInst* srcPtr = nullptr; + for (auto use = rootAddr->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (auto store = as<IRStore>(user)) + { + if (store->getPtrUse() == use) + { + if (auto load = as<IRLoad>(store->getVal())) + { + auto ptr = load->getPtr(); + if (ptr->getOp() == kIROp_Param || + ptr->getOp() == kIROp_GlobalParam) + { + if (!srcPtr) + srcPtr = ptr; + else + { + srcPtr = nullptr; + break; + } + } + } + } + } + } + if (srcPtr) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto resolvedPtr = builder.emitElementAddress( + srcPtr, + accessChain.getArrayView(), + types.getArrayView()); + inst->replaceUsesWith(resolvedPtr); + toRemove.add(inst); + } + } + break; + } + } + } + for (auto inst : toRemove) + { + inst->removeAndDeallocate(); + } +} + +void resolveVaryingInputRef(IRModule* module) +{ + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->findDecoration<IREntryPointDecoration>()) + resolveVaryingInputRef((IRFunc*)globalInst); + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-resolve-varying-input-ref.h b/source/slang/slang-ir-resolve-varying-input-ref.h new file mode 100644 index 000000000..5cbff0f8c --- /dev/null +++ b/source/slang/slang-ir-resolve-varying-input-ref.h @@ -0,0 +1,10 @@ +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +void resolveVaryingInputRef(IRFunc* func); +void resolveVaryingInputRef(IRModule* module); + +} // namespace Slang diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 65cb8f64f..a44e16a7c 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -122,7 +122,8 @@ struct GlobalVarTranslationContext // Add an entry point parameter for all the inputs. auto firstBlock = entryPointFunc->getFirstBlock(); builder.setInsertInto(firstBlock); - auto inputParam = builder.emitParam(inputStructType); + auto inputParam = builder.emitParam( + builder.getPtrType(kIROp_ConstRefType, inputStructType, AddressSpace::Input)); builder.addLayoutDecoration(inputParam, paramLayout); // Initialize all global variables. @@ -133,7 +134,8 @@ struct GlobalVarTranslationContext auto inputType = cast<IRPtrTypeBase>(input->getDataType())->getValueType(); builder.emitStore( input, - builder.emitFieldExtract(inputType, inputParam, inputKeys[i])); + builder + .emitFieldExtract(inputType, builder.emitLoad(inputParam), inputKeys[i])); // Relate "global variable" to a "global parameter" for use later in compilation // to resolve a "global variable" shadowing a "global parameter" relationship. builder.addGlobalVariableShadowingGlobalParameterDecoration( diff --git a/source/slang/slang-ir-vk-invert-y.cpp b/source/slang/slang-ir-vk-invert-y.cpp index e7fc81144..70f2584ac 100644 --- a/source/slang/slang-ir-vk-invert-y.cpp +++ b/source/slang/slang-ir-vk-invert-y.cpp @@ -104,10 +104,15 @@ void rcpWOfPositionInput(IRModule* module) [&](IRUse* use) { // Get the inverted vector. - builder.setInsertBefore(use->getUser()); - auto invertedVal = _invertWOfVector(builder, globalInst); - // Replace original uses with the invertex vector. - builder.replaceOperand(use, invertedVal); + auto user = use->getUser(); + if (user->getOp() == kIROp_Load) + { + builder.setInsertBefore(user); + auto val = builder.emitLoad(globalInst); + auto invertedVal = _invertWOfVector(builder, val); + user->replaceUsesWith(invertedVal); + user->removeAndDeallocate(); + } }); } } diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index 907c2b8ba..f76a0541c 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1362,6 +1362,9 @@ struct LegalizeWGSLEntryPointContext void legalizeEntryPointForWGSL(EntryPointInfo entryPoint) { + // If the entrypoint is receiving varying inputs as a pointer, turn it into a value. + depointerizeInputParams(entryPoint.entryPointFunc); + // Input Parameter Legalize flattenInputParameters(entryPoint); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index ff1cd49ea..daeaca67b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8243,6 +8243,8 @@ bool IRInst::mightHaveSideEffects(SideEffectAnalysisOptions options) case kIROp_GetStringHash: case kIROp_AllocateOpaqueHandle: case kIROp_GetArrayLength: + case kIROp_ResolveVaryingInputRef: + case kIROp_GetPerVertexInputArray: return false; case kIROp_ForwardDifferentiate: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5bbe44e9b..011ea6bc7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2767,14 +2767,15 @@ ParameterDirection getParameterDirection(VarDeclBase* paramDecl) /// ParameterDirection getThisParamDirection(Decl* parentDecl, ParameterDirection defaultDirection) { - auto parentParent = getParentDecl(parentDecl); + auto parentParent = getParentAggTypeDecl(parentDecl); + // The `this` parameter for a `class` is always `in`. if (as<ClassDecl>(parentParent)) { return kParameterDirection_In; } - if (parentParent->findModifier<NonCopyableTypeAttribute>()) + if (parentParent && parentParent->findModifier<NonCopyableTypeAttribute>()) { if (parentDecl->hasModifier<MutatingAttribute>()) return kParameterDirection_Ref; @@ -2982,6 +2983,9 @@ struct IRLoweringParameterInfo // The direction (`in` vs `out` vs `in out`) ParameterDirection direction; + // The direction declared in user code. + ParameterDirection declaredDirection = ParameterDirection::kParameterDirection_In; + // The variable/parameter declaration for // this parameter (if any) VarDeclBase* decl = nullptr; @@ -3005,6 +3009,7 @@ IRLoweringParameterInfo getParameterInfo( info.type = getParamType(context->astBuilder, paramDecl); info.decl = paramDecl.getDecl(); info.direction = getParameterDirection(paramDecl.getDecl()); + info.declaredDirection = info.direction; info.isThisParam = false; return info; } @@ -3051,6 +3056,7 @@ void addThisParameter(ParameterDirection direction, Type* type, ParameterLists* info.type = type; info.decl = nullptr; info.direction = direction; + info.declaredDirection = direction; info.isThisParam = true; ioParameterLists->params.add(info); @@ -3064,10 +3070,22 @@ void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resu info.type = resultType; info.decl = nullptr; info.direction = kParameterDirection_Ref; + info.declaredDirection = info.direction; info.isReturnDestination = true; ioParameterLists->params.add(info); } } + +void makeVaryingInputParamConstRef(IRLoweringParameterInfo& paramInfo) +{ + if (paramInfo.direction != kParameterDirection_In) + return; + if (paramInfo.decl->findModifier<HLSLUniformModifier>()) + return; + if (as<HLSLPatchType>(paramInfo.type)) + return; + paramInfo.direction = kParameterDirection_ConstRef; +} // // And here is our function that will do the recursive walk: void collectParameterLists( @@ -3137,13 +3155,31 @@ void collectParameterLists( // if (auto callableDeclRef = declRef.as<CallableDecl>()) { + // We need a special case here when lowering the varying parameters of an entrypoint + // function. Due to the existence of `EvaluateAttributeAtSample` and friends, we need to + // always lower the varying inputs as `__constref` parameters so we can pass pointers to + // these intrinsics. + // This means that although these parameters are declared as "in" parameters in the source, + // we will actually treat them as __constref parameters when lowering to IR. A complication + // result from this is that if the original source code actually modifies the input + // parameter we still need to create a local var to hold the modified value. In the future + // when we are able to update our language spec to always assume input parameters are + // immutable, then we can remove this adhoc logic of introducing temporary variables. For + // For now we will rely on a follow up pass to remove unnecessary temporary variables if + // we can determine that they are never actually writtten to by the user. + // + bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>(); + // Don't collect parameters from the outer scope if // we are in a `static` context. if (mode == kParameterListCollectMode_Default) { for (auto paramDeclRef : getParameters(context->astBuilder, callableDeclRef)) { - ioParameterLists->params.add(getParameterInfo(context, paramDeclRef)); + auto paramInfo = getParameterInfo(context, paramDeclRef); + if (lowerVaryingInputAsConstRef) + makeVaryingInputParamConstRef(paramInfo); + ioParameterLists->params.add(paramInfo); } maybeAddReturnDestinationParam( ioParameterLists, @@ -5623,9 +5659,7 @@ struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLowe LoweredValInfo visitOpenRefExpr(OpenRefExpr* expr) { auto inner = lowerLValueExpr(context, expr->innerExpr); - auto builder = getBuilder(); - auto irLoad = builder->emitLoad(inner.val); - return LoweredValInfo::simple(irLoad); + return LoweredValInfo::ptr(inner.val); } }; @@ -9980,6 +10014,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (paramInfo.isReturnDestination) subContext->returnDestination = paramVal; + if (paramInfo.declaredDirection == kParameterDirection_In && + paramInfo.direction == kParameterDirection_ConstRef) + { + // If the parameter is originally declared as "in", but we are + // lowering it as constref for any reason (e.g. it is a varying input), + // then we need to emit a local variable to hold the original value, so + // that we can still generate correct code when the user trys to mutate + // the variable. + // The local variable introduced here is cleaned up by the SSA pass, if + // we can determine that there are no actual writes into the local var. + auto irLocal = + subBuilder->emitVar(tryGetPointedToType(subBuilder, irParamType)); + auto localVal = LoweredValInfo::ptr(irLocal); + assign(subContext, localVal, paramVal); + paramVal = localVal; + } // TODO: We might want to copy the pointed-to value into // a temporary at the start of the function, and then copy // back out at the end, so that we don't have to worry @@ -10987,6 +11037,16 @@ static void lowerFrontEndEntryPointToIR( auto entryPointFuncDecl = entryPoint->getFuncDecl(); + if (!entryPointFuncDecl->findModifier<EntryPointAttribute>()) + { + // If the entry point doesn't have an explicit `[shader("...")]` attribute, + // then we make sure to add one here, so the lowering logic knows it is an + // entry point. + auto entryPointAttr = context->astBuilder->create<EntryPointAttribute>(); + entryPointAttr->capabilitySet = entryPoint->getProfile().getCapabilityName(); + addModifier(entryPointFuncDecl, entryPointAttr); + } + auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 22491c848..c275a868b 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7531,12 +7531,6 @@ static IRFloatingPointValue _foldFloatPrefixOp(TokenType tokenType, IRFloatingPo static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser) { - const auto slangIdentOperand = [&](auto flavor) - { - auto token = parser->tokenReader.peekToken(); - return SPIRVAsmOperand{flavor, token, parseAtomicExpr(parser)}; - }; - const auto slangTypeExprOperand = [&](auto flavor) { auto tok = parser->tokenReader.peekToken(); @@ -7673,12 +7667,13 @@ static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser) // A &foo variable reference (for the address of foo) else if (AdvanceIf(parser, TokenType::OpBitAnd)) { - return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr); + Expr* expr = parsePostfixExpr(parser); + return SPIRVAsmOperand{SPIRVAsmOperand::SlangValueAddr, Token{}, expr}; } // A $foo variable else if (AdvanceIf(parser, TokenType::Dollar)) { - Expr* expr = parseAtomicExpr(parser); + Expr* expr = parsePostfixExpr(parser); return SPIRVAsmOperand{SPIRVAsmOperand::SlangValue, Token{}, expr}; } // A $$foo type diff --git a/tests/bugs/gh-841.slang b/tests/bugs/gh-841.slang index ba746984b..5f7e0c81f 100644 --- a/tests/bugs/gh-841.slang +++ b/tests/bugs/gh-841.slang @@ -11,8 +11,8 @@ struct RasterVertex float4 c : COLOR; // Make sure that the input value in location 1 is decorated as Flat - // SPV-DAG: [[VAL:%[_A-Za-z0-9]+]] = OpVariable {{.*}} Input - // SPV-DAG: OpDecorate [[VAL]] Location 1 + // SPV-DAG: OpDecorate [[VAL:%[_A-Za-z0-9]+]] Location 1 + // SPV-DAG: [[VAL]] = OpVariable {{.*}} Input // SPV-DAG: OpDecorate [[VAL]] Flat // // Likewise for GLSL diff --git a/tests/bugs/vk-structured-buffer-load.hlsl b/tests/bugs/vk-structured-buffer-load.hlsl index d9e54d925..ac8a86a5c 100644 --- a/tests/bugs/vk-structured-buffer-load.hlsl +++ b/tests/bugs/vk-structured-buffer-load.hlsl @@ -1,4 +1,9 @@ //TEST:CROSS_COMPILE: -profile glsl_460+GL_NV_ray_tracing -entry HitMain -stage closesthit -target spirv-assembly +//TEST:SIMPLE(filecheck=DXIL): -target dxil -entry HitMain -stage closesthit -profile sm_6_5 +//TEST:SIMPLE(filecheck=SPV): -target spirv + +// DXIL: define void @ +// SPV: OpEntryPoint #define USE_RCP 0 diff --git a/tests/cross-compile/glsl-generic-in.slang b/tests/cross-compile/glsl-generic-in.slang index a743c32cb..6bf2d28fb 100644 --- a/tests/cross-compile/glsl-generic-in.slang +++ b/tests/cross-compile/glsl-generic-in.slang @@ -1,9 +1,9 @@ //TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -entry main -profile vs_5_0 -emit-spirv-directly //TEST:SIMPLE(filecheck=CHECK): -target spirv-assembly -entry main -profile vs_5_0 -emit-spirv-via-glsl -// CHECK: vIn_field_v0{{.*}} = OpVariable %_ptr_Input_v4float Input -// CHECK: %vIn_field_v1{{.*}}= OpVariable %_ptr_Input_v2float Input -// CHECK: %vIn_p0{{.*}}= OpVariable %_ptr_Input_v3float Input +// CHECK-DAG: vIn_field_v0{{.*}} = OpVariable %_ptr_Input_v4float Input +// CHECK-DAG: %vIn_field_v1{{.*}}= OpVariable %_ptr_Input_v2float Input +// CHECK-DAG: %vIn_p0{{.*}}= OpVariable %_ptr_Input_v3float Input interface IField { diff --git a/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang b/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang index 909679bbe..c69752cd0 100644 --- a/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang +++ b/tests/glsl-intrinsic/fragment-processing/fragment-processing.slang @@ -76,8 +76,8 @@ bool testFragmentProcessingDerivativeFunctionsVector() } bool testFragmentProcessingInterpolateFunctions() { -// CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtCentroid %inDataV1 -// CHECK_GLSL: interpolateAtCentroid{{.*}}inDataV1 +// CHECK_SPV-DAG: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtCentroid %inDataV1 +// CHECK_GLSL-DAG: interpolateAtCentroid{{.*}}inDataV1 // CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtSample %inDataV1 {{.*}} // CHECK_GLSL: interpolateAtSample{{.*}}inDataV1 // CHECK_SPV: {{.*}} = OpExtInst {{.*}} {{.*}} InterpolateAtOffset %inDataV1 {{.*}} diff --git a/tests/glsl/matrix-mul.slang b/tests/glsl/matrix-mul.slang index 156673b87..3bdd1cb8d 100644 --- a/tests/glsl/matrix-mul.slang +++ b/tests/glsl/matrix-mul.slang @@ -1,6 +1,6 @@ //TEST:SIMPLE(filecheck=SPIRV): -target spirv -stage vertex -entry main -allow-glsl -emit-spirv-directly //TEST:SIMPLE(filecheck=SPIRV): -target spirv -stage vertex -entry main -allow-glsl -//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry main -allow-glsl +//TEST:SIMPLE(filecheck=METAL): -target metal -stage vertex -entry main -allow-glsl -matrix-layout-row-major #version 310 es layout(location = 0) in highp vec4 a_position; diff --git a/tests/hlsl-intrinsic/fragment-interpolate.slang b/tests/hlsl-intrinsic/fragment-interpolate.slang new file mode 100644 index 000000000..f64e4e13b --- /dev/null +++ b/tests/hlsl-intrinsic/fragment-interpolate.slang @@ -0,0 +1,17 @@ +//TEST:SIMPLE(filecheck=CHECK_HLSL): -target hlsl -stage fragment -entry main +//TEST:SIMPLE(filecheck=CHECK_SPV): -target spirv -emit-spirv-directly -stage fragment -entry main + +struct VertexOut +{ + float4 pos : SV_Position; + float3 color; +} + +// CHECK_SPV: %v_color = OpVariable %_ptr_Input_v3float Input +// CHECK_SPV: %{{.*}} = OpExtInst %v3float %{{.*}} InterpolateAtCentroid %v_color +// CHECK_HLSL: EvaluateAttributeAtCentroid(v_0.color_0) + +float4 main(VertexOut v) : SV_Target +{ + return float4(EvaluateAttributeAtCentroid(v.color), 1.0); +}
\ No newline at end of file diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang deleted file mode 100644 index d7bdbc69c..000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang +++ /dev/null @@ -1,17 +0,0 @@ -// get-attribute-at-vertex.slang - -// Basic test for `GetAttributeAtVertex` function - -//TEST:CROSS_COMPILE:-target dxil -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment -profile sm_6_1 -//TEST:CROSS_COMPILE:-target spirv -capability GL_NV_fragment_shader_barycentric -entry main -stage fragment -profile glsl_450 - -[shader("fragment")] -void main( - pervertex float4 color : COLOR, - float3 bary : SV_Barycentrics, - out float4 result : SV_Target) -{ - result = bary.x * GetAttributeAtVertex(color, 0) - + bary.y * GetAttributeAtVertex(color, 1) - + bary.z * GetAttributeAtVertex(color, 2); -} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl deleted file mode 100644 index 820918d8b..000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.glsl +++ /dev/null @@ -1,20 +0,0 @@ -// get-attribute-at-vertex.slang.glsl -//TEST_IGNORE_FILE: - -#version 450 -#extension GL_EXT_fragment_shader_barycentric : require -layout(row_major) uniform; -layout(row_major) buffer; - -pervertexEXT layout(location = 0) -in vec4 color_0[3]; - -layout(location = 0) -out vec4 result_0; - -void main() -{ - result_0 = gl_BaryCoordEXT.x * ((color_0)[(0U)]) + gl_BaryCoordEXT.y * ((color_0)[(1U)]) + gl_BaryCoordEXT.z * ((color_0)[(2U)]); - return; -} - diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl b/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl deleted file mode 100644 index a6b45eab4..000000000 --- a/tests/pipeline/rasterization/get-attribute-at-vertex-nv.slang.hlsl +++ /dev/null @@ -1,16 +0,0 @@ -// get-attribute-at-vertex.slang.hlsl - -//TEST_IGNORE_FILE: - -#pragma warning(disable: 3557) - -[shader("pixel")] -void main( - nointerpolation vector<float,4> color_0 : COLOR, - vector<float,3> bary_0 : SV_BARYCENTRICS, - out vector<float,4> result_0 : SV_TARGET) -{ - result_0 = bary_0.x * GetAttributeAtVertex(color_0, 0U) - + bary_0.y * GetAttributeAtVertex(color_0, 1U) - + bary_0.z * GetAttributeAtVertex(color_0, 2U); -} diff --git a/tests/pipeline/rasterization/get-attribute-at-vertex.slang b/tests/pipeline/rasterization/get-attribute-at-vertex.slang index 9ae347a3a..c334200fb 100644 --- a/tests/pipeline/rasterization/get-attribute-at-vertex.slang +++ b/tests/pipeline/rasterization/get-attribute-at-vertex.slang @@ -2,8 +2,6 @@ // Basic test for `GetAttributeAtVertex` function -//TEST:CROSS_COMPILE:-target dxil -entry main -stage fragment -profile sm_6_1 -//TEST:CROSS_COMPILE:-target spirv -entry main -stage fragment -profile glsl_450+GL_EXT_fragment_shader_barycentric //TEST:SIMPLE(filecheck=CHECK):-emit-spirv-directly -target spirv -entry main -stage fragment -profile glsl_450+GL_EXT_fragment_shader_barycentric // CHECK: OpCapability FragmentBarycentricKHR diff --git a/tests/pipeline/rasterization/varying-to-inout.slang b/tests/pipeline/rasterization/varying-to-inout.slang new file mode 100644 index 000000000..7a54fd82f --- /dev/null +++ b/tests/pipeline/rasterization/varying-to-inout.slang @@ -0,0 +1,22 @@ +// Test passing a varying parameter direclty to an inout parameter. + +//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry main -stage fragment + +// CHECK: OpEntryPoint Fragment %main "main" +struct PS_IN +{ + float3 pos : SV_Position; + float4 color : COLOR; +} + +void test(inout PS_IN v) +{ + v.color = v.color + v.pos.x; +} + +[shader("fragment")] +float4 main(PS_IN psIn):SV_Target +{ + test(psIn); + return psIn.color; +} diff --git a/tests/spirv/array-uniform-param.slang b/tests/spirv/array-uniform-param.slang index 235e85bbd..672543b9a 100644 --- a/tests/spirv/array-uniform-param.slang +++ b/tests/spirv/array-uniform-param.slang @@ -1,8 +1,10 @@ // array-uniform-param.slang -//TESTD:SIMPLE:-target spirv -entry computeMain -stage compute -emit-spirv-directly -force-glsl-scalar-layout +//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry computeMain -stage compute -emit-spirv-directly -force-glsl-scalar-layout //TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -output-using-type +// CHECK: OpEntryPoint + // Test direct SPIR-V emit on arrays in uniforms. //TEST_INPUT:set resultBuffer = out ubuffer(data=[0 0 0 0], stride=4) diff --git a/tests/spirv/matrix-vertex-input.slang b/tests/spirv/matrix-vertex-input.slang index fc4af8c61..b6277bead 100644 --- a/tests/spirv/matrix-vertex-input.slang +++ b/tests/spirv/matrix-vertex-input.slang @@ -1,23 +1,84 @@ -//TEST:SIMPLE(filecheck=CHECK): -target spirv -// CHECK: OpVectorTimesMatrix +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=ROWMAJOR): -vk -output-using-type +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=ROWMAJOR): -d3d11 -output-using-type -struct Vertex +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=COLMAJOR): -vk -output-using-type -emit-spirv-directly -xslang -DCOLUMN_MAJOR +//TEST(compute):COMPARE_RENDER_COMPUTE(filecheck-buffer=COLMAJOR): -d3d11 -output-using-type -xslang -DCOLUMN_MAJOR + +// Check that row_major and column_major matrix typed vertex input are correctly handled. + +//TEST_INPUT: Texture2D(size=4, content = one):name t +//TEST_INPUT: Sampler:name s +//TEST_INPUT: ubuffer(data=[0], stride=4):out, name outputBuffer + +Texture2D t; +SamplerState s; +RWStructuredBuffer<float> outputBuffer; + +cbuffer Uniforms { - float4x4 m; - float4 pos; + float4x4 modelViewProjection; } -struct VertexOut +struct AssembledVertex +{ + float3 position; + float3 color; + float2 uv; +#ifdef COLUMN_MAJOR + column_major float4x4 m; +#else + row_major float4x4 m; +#endif +}; + +struct CoarseVertex +{ + float3 color; +}; + +struct Fragment +{ + float4 color; +}; + +// Vertex Shader + +struct VertexStageInput +{ + AssembledVertex assembledVertex : A; +}; + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +VertexStageOutput vertexMain(VertexStageInput input) { - float4 pos : SV_Position; - float4 color; + VertexStageOutput output; + output.coarseVertex.color = input.assembledVertex.m[1][2]; + output.sv_position = mul(modelViewProjection, float4(input.assembledVertex.position, 1.0)); + return output; } -[shader("vertex")] -VertexOut vertMain(Vertex v) +struct FragmentStageInput { - VertexOut o; - o.pos = mul(v.m, v.pos); - o.color = v.pos; - return o; -}
\ No newline at end of file + CoarseVertex coarseVertex : CoarseVertex; +}; + +struct FragmentStageOutput +{ + Fragment fragment : SV_Target; +}; + +FragmentStageOutput fragmentMain(FragmentStageInput input) +{ + FragmentStageOutput output; + float3 color = input.coarseVertex.color; + output.fragment.color = float4(color, 1.0); + outputBuffer[0] = color.x; + // ROWMAJOR: 7.0 + // COLMAJOR: 10.0 + return output; +} diff --git a/tests/spirv/nested-entrypoint.slang b/tests/spirv/nested-entrypoint.slang new file mode 100644 index 000000000..28e9b9c4a --- /dev/null +++ b/tests/spirv/nested-entrypoint.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -fvk-use-entrypoint-name + +// CHECK: OpEntryPoint + +RWStructuredBuffer<int> output; + +[shader("compute")] +[numthreads(1,1,1)] +void innerMain(int id : SV_DispatchThreadID) +{ + output[id] = id; +} + +[shader("compute")] +[numthreads(1,1,1)] +void outerMain(int id : SV_DispatchThreadID) +{ + innerMain(id); +}
\ No newline at end of file diff --git a/tests/spirv/optional-vertex-output.slang b/tests/spirv/optional-vertex-output.slang index df15befa2..7baf02d0b 100644 --- a/tests/spirv/optional-vertex-output.slang +++ b/tests/spirv/optional-vertex-output.slang @@ -20,7 +20,10 @@ struct VSOut { VSOut vertMain(VIn i) { VSOut o; - o.a = i.inA; + if (i.inA.hasValue) + o.a = i.inA; + else + o.a = 0.0; o.outputValues = { true, false, true }; return o; }
\ No newline at end of file diff --git a/tests/vkray/anyhit.slang b/tests/vkray/anyhit.slang index 45d35b1fa..8f5a6e597 100644 --- a/tests/vkray/anyhit.slang +++ b/tests/vkray/anyhit.slang @@ -57,7 +57,6 @@ void main( // SPIRV: OpEntryPoint // SPIRV: BuiltIn HitTriangleVertexPositionsKHR // SPIRV: OpTypePointer HitAttribute{{NV|KHR}} -// SPIRV: OpTypePointer HitAttribute{{NV|KHR}} // SPIRV: OpVariable{{.*}}HitAttribute{{NV|KHR}} // SPIRV: OpIgnoreIntersectionKHR // SPIRV: OpTerminateRayKHR @@ -70,7 +69,6 @@ void main( // GL_SPIRV: OpEntryPoint // GL_SPIRV: BuiltIn HitTriangleVertexPositionsKHR // GL_SPIRV-DAG: OpTypePointer HitAttribute{{NV|KHR}} -// GL_SPIRV-DAG: OpTypePointer HitAttribute{{NV|KHR}} // GL_SPIRV: OpTerminateRayKHR // GL_SPIRV: OpIgnoreIntersectionKHR // GL_SPIRV-DAG: %{{.*}} = OpAccessChain %{{.*}} %{{.*}} %{{.*}} diff --git a/tests/vkray/anyhit.slang.glsl b/tests/vkray/anyhit.slang.glsl index 8255599b9..4d2e5a0dd 100644 --- a/tests/vkray/anyhit.slang.glsl +++ b/tests/vkray/anyhit.slang.glsl @@ -8,7 +8,7 @@ struct Params_0 }; layout(binding = 0) -layout(std140) uniform _S1 +layout(std140) uniform block_Params_0 { int mode_0; }gParams_0; @@ -23,20 +23,21 @@ struct SphereHitAttributes_0 vec3 normal_0; }; -hitAttributeEXT SphereHitAttributes_0 _S2; +hitAttributeEXT SphereHitAttributes_0 _S1; struct ShadowRay_0 { vec4 hitDistance_0; + vec3 dummyOut_0; }; -rayPayloadInEXT ShadowRay_0 _S3; +rayPayloadInEXT ShadowRay_0 _S2; void main() { if(gParams_0.mode_0 != 0) { - if((textureLod(sampler2D(gParams_alphaMap_0,gParams_sampler_0), (_S2.normal_0.xy), (0.0)).x) > 0.0) + if((textureLod(sampler2D(gParams_alphaMap_0,gParams_sampler_0), (_S1.normal_0.xy), (0.0)).x) > 0.0) { terminateRayEXT;; } @@ -45,6 +46,14 @@ void main() ignoreIntersectionEXT;; } } + + vec3 _S3 = (gl_HitTriangleVertexPositionsEXT[(0U)]); + _S2.dummyOut_0 = _S3; + vec3 _S4 = (gl_HitTriangleVertexPositionsEXT[(1U)]); + vec3 _S5 = _S3 + _S4; + _S2.dummyOut_0 = _S5; + vec3 _S6 = (gl_HitTriangleVertexPositionsEXT[(2U)]); + _S2.dummyOut_0 = _S5 + _S6; return; } diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index 5907be66d..b1f957551 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -45,12 +45,16 @@ struct Vertex float position[3]; float color[3]; float uv[2]; + float customData0[4]; + float customData1[4]; + float customData2[4]; + float customData3[4]; }; static const Vertex kVertexData[] = { - {{0, 0, 0.5}, {1, 0, 0}, {0, 0}}, - {{0, 1, 0.5}, {0, 0, 1}, {1, 0}}, - {{1, 0, 0.5}, {0, 1, 0}, {1, 1}}, + {{0, 0, 0.5}, {1, 0, 0}, {0, 0}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, + {{0, 1, 0.5}, {0, 0, 1}, {1, 0}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, + {{1, 0, 0.5}, {0, 1, 0}, {1, 1}, {1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}, }; static const int kVertexCount = SLANG_COUNT_OF(kVertexData); @@ -614,6 +618,10 @@ SlangResult RenderTestApp::initialize( {"A", 0, Format::R32G32B32_FLOAT, offsetof(Vertex, position)}, {"A", 1, Format::R32G32B32_FLOAT, offsetof(Vertex, color)}, {"A", 2, Format::R32G32_FLOAT, offsetof(Vertex, uv)}, + {"A", 3, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData0)}, + {"A", 4, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData1)}, + {"A", 5, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData2)}, + {"A", 6, Format::R32G32B32A32_FLOAT, offsetof(Vertex, customData3)}, }; ComPtr<IInputLayout> inputLayout; |
