diff options
| author | Yong He <yonghe@outlook.com> | 2024-03-25 09:24:04 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-03-25 09:24:04 -0700 |
| commit | cbe55261ce07c9f737b5117dd1f703950190d843 (patch) | |
| tree | 2962165562aad4fb94862a41d28c8f6d04e468ab | |
| parent | c9df734b836a503dbc09c48bfd54b35facd0f105 (diff) | |
Fix missing PerPrimitive decoration in mesh shader output. (#3828)
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 61 | ||||
| -rw-r--r-- | source/slang/slang-ir-translate-glsl-global-var.cpp | 4 | ||||
| -rw-r--r-- | tests/spirv/mesh-primitive.slang | 64 |
4 files changed, 126 insertions, 17 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 6ecffecd1..c48f0e978 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2141,10 +2141,18 @@ struct SPIRVEmitContext } } - if (var->findDecorationImpl(kIROp_RequireSPIRVDescriptorIndexingExtensionDecoration)) + for (auto decor : var->getDecorations()) { - ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_descriptor_indexing")); - requireSPIRVCapability(SpvCapabilityRuntimeDescriptorArray); + switch (decor->getOp()) + { + case kIROp_GLSLPrimitivesRateDecoration: + emitOpDecorate(getSection(SpvLogicalSectionID::Annotations), decor, varInst, SpvDecorationPerPrimitiveEXT); + break; + case kIROp_RequireSPIRVDescriptorIndexingExtensionDecoration: + ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_descriptor_indexing")); + requireSPIRVCapability(SpvCapabilityRuntimeDescriptorArray); + break; + } } } diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 3c6592164..c4a8fff5d 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -957,6 +957,20 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo( return nullptr; } +// Hold the in-stack linked list that represents the access chain +// to the current global varying parameter being created. +// e.g. if the user code has: +// struct Params { in float member; } +// void main(in Params inParams); +// Then the `outerParamInfo` when we get to `createSimpleGLSLVarying` for `member` +// will be: {IRStructField member} -> {IRParam inParams} -> {IRFunc main}. +// +struct OuterParamInfoLink +{ + IRInst* outerParam; + OuterParamInfoLink* next; +}; + void createVarLayoutForLegalizedGlobalParam( IRInst* globalParam, IRBuilder* builder, @@ -966,7 +980,7 @@ void createVarLayoutForLegalizedGlobalParam( UInt bindingIndex, UInt bindingSpace, GlobalVaryingDeclarator* declarator, - IRInst* leafVar, + OuterParamInfoLink* outerParamInfo, GLSLSystemValueInfo* systemValueInfo) { // We need to construct a fresh layout for the variable, even @@ -982,13 +996,22 @@ void createVarLayoutForLegalizedGlobalParam( IRVarLayout* varLayout = varLayoutBuilder.build(); builder->addLayoutDecoration(globalParam, varLayout); - if (leafVar) + // Traverse the entire access chain for the current leaf var and see if + // there are interpolation mode decorations along the way. + // Make sure we respect the decoration on the inner most node. + // So that the decoration on a struct field overrides the outer decoration + // on a parameter of the struct type. + for (; outerParamInfo; outerParamInfo = outerParamInfo->next) { - if (auto interpolationModeDecor = leafVar->findDecoration<IRInterpolationModeDecoration>()) + auto paramInfo = outerParamInfo->outerParam; + auto decorParent = paramInfo; + if (auto field = as<IRStructField>(decorParent)) + decorParent = field->getKey(); + if (auto interpolationModeDecor = decorParent->findDecoration<IRInterpolationModeDecoration>()) { builder->addInterpolationModeDecoration(globalParam, interpolationModeDecor->getMode()); + break; } - } if (declarator && declarator->flavor == GlobalVaryingDeclarator::Flavor::meshOutputPrimitives) @@ -1031,7 +1054,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( UInt bindingIndex, UInt bindingSpace, GlobalVaryingDeclarator* declarator, - IRInst* leafVar, + OuterParamInfoLink* outerParamInfo, StringBuilder& nameHintSB) { // Check if we have a system value on our hands. @@ -1101,7 +1124,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( builder->addImportDecoration(globalParam, systemValueName); createVarLayoutForLegalizedGlobalParam( - globalParam, builder, inVarLayout, inTypeLayout, kind, bindingIndex, bindingSpace, declarator, leafVar, systemValueInfo); + globalParam, builder, inVarLayout, inTypeLayout, kind, bindingIndex, bindingSpace, declarator, outerParamInfo, systemValueInfo); semanticGlobalTmp.globalParam = globalParam; @@ -1239,7 +1262,7 @@ ScalarizedVal createSimpleGLSLGlobalVarying( } createVarLayoutForLegalizedGlobalParam( - globalParam, builder, inVarLayout, typeLayout, kind, bindingIndex, bindingSpace, declarator, leafVar, systemValueInfo); + globalParam, builder, inVarLayout, typeLayout, kind, bindingIndex, bindingSpace, declarator, outerParamInfo, systemValueInfo); return val; } @@ -1255,6 +1278,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( UInt bindingIndex, UInt bindingSpace, GlobalVaryingDeclarator* declarator, + OuterParamInfoLink* outerParamInfo, IRInst* leafVar, StringBuilder& nameHintSB) { @@ -1267,14 +1291,14 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( return createSimpleGLSLGlobalVarying( context, codeGenContext, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, leafVar, nameHintSB); + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, outerParamInfo, nameHintSB); } else if( as<IRVectorType>(type) ) { return createSimpleGLSLGlobalVarying( context, codeGenContext, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, leafVar, nameHintSB); + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, outerParamInfo, nameHintSB); } else if( as<IRMatrixType>(type) ) { @@ -1282,7 +1306,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( return createSimpleGLSLGlobalVarying( context, codeGenContext, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, leafVar, nameHintSB); + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, outerParamInfo, nameHintSB); } else if( auto arrayType = as<IRArrayType>(type) ) { @@ -1311,6 +1335,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( bindingIndex, bindingSpace, &arrayDeclarator, + outerParamInfo, leafVar, nameHintSB); } @@ -1355,6 +1380,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( bindingIndex, bindingSpace, &arrayDeclarator, + outerParamInfo, leafVar, nameHintSB); } @@ -1377,6 +1403,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( bindingIndex, bindingSpace, declarator, + outerParamInfo, leafVar, nameHintSB); } @@ -1389,6 +1416,11 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( SLANG_ASSERT(structTypeLayout); RefPtr<ScalarizedTupleValImpl> tupleValImpl = new ScalarizedTupleValImpl(); + // Since we are going to recurse into struct fields, + // we need to create a new node in `outerParamInfo` to keep track of + // the access chain to get to the new leafVar. + OuterParamInfoLink fieldParentInfo; + fieldParentInfo.next = outerParamInfo; // Construct the actual type for the tuple (including any outer arrays) IRType* fullType = type; @@ -1436,6 +1468,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( nameHintSB << "."; nameHintSB << fieldNameHint->getName(); } + fieldParentInfo.outerParam = field; auto fieldVal = createGLSLGlobalVaryingsImpl( context, codeGenContext, @@ -1448,6 +1481,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( fieldBindingIndex, fieldBindingSpace, declarator, + &fieldParentInfo, field, nameHintSB); if (fieldVal.flavor != ScalarizedVal::Flavor::none) @@ -1467,7 +1501,7 @@ ScalarizedVal createGLSLGlobalVaryingsImpl( return createSimpleGLSLGlobalVarying( context, codeGenContext, - builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, leafVar, nameHintSB); + builder, type, varLayout, typeLayout, kind, stage, bindingIndex, bindingSpace, declarator, outerParamInfo, nameHintSB); } ScalarizedVal createGLSLGlobalVaryings( @@ -1492,10 +1526,13 @@ ScalarizedVal createGLSLGlobalVaryings( { namehintSB << nameHint->getName(); } + OuterParamInfoLink outerParamInfo; + outerParamInfo.next = nullptr; + outerParamInfo.outerParam = leafVar; return createGLSLGlobalVaryingsImpl( context, codeGenContext, - builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, nullptr, leafVar, namehintSB); + builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, nullptr, &outerParamInfo, leafVar, namehintSB); } ScalarizedVal extractField( diff --git a/source/slang/slang-ir-translate-glsl-global-var.cpp b/source/slang/slang-ir-translate-glsl-global-var.cpp index 775c1b9d6..90bc49b32 100644 --- a/source/slang/slang-ir-translate-glsl-global-var.cpp +++ b/source/slang/slang-ir-translate-glsl-global-var.cpp @@ -79,7 +79,7 @@ namespace Slang { builder.addNameHintDecoration(key, nameHint->getName()); } - auto field = builder.createStructField(inputStructType, key, inputType); + builder.createStructField(inputStructType, key, inputType); IRTypeLayout::Builder fieldTypeLayout(&builder); IRVarLayout::Builder varLayoutBuilder(&builder, fieldTypeLayout.build()); varLayoutBuilder.setStage(entryPointDecor->getProfile().getStage()); @@ -105,7 +105,7 @@ namespace Slang inputVarIndex++; } inputStructTypeLayoutBuilder.addField(key, varLayoutBuilder.build()); - input->transferDecorationsTo(field); + input->transferDecorationsTo(key); } auto paramTypeLayout = inputStructTypeLayoutBuilder.build(); IRVarLayout::Builder paramVarLayoutBuilder(&builder, paramTypeLayout); diff --git a/tests/spirv/mesh-primitive.slang b/tests/spirv/mesh-primitive.slang new file mode 100644 index 000000000..4daf21749 --- /dev/null +++ b/tests/spirv/mesh-primitive.slang @@ -0,0 +1,64 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-directly + +// CHECK: OpDecorate %primitives_color Location 0 +// CHECK: OpDecorate %primitives_color PerPrimitive +// CHECK: OpDecorate %prim_color Location 0 +// CHECK: OpDecorate %prim_color Flat + +const static uint MAX_VERTS = 6; +const static uint MAX_PRIMS = 2; + +const static float2 positions[MAX_VERTS] = { + float2(0.0, -0.5), + float2(0.5, 0), + float2(-0.5, 0), + float2(0.0, 0.5), + float2(0.5, 0), + float2(-0.5, 0), +}; + +struct Vertex +{ + float4 pos : SV_Position; +}; + +struct Primitive +{ + [[vk::location(0)]] float3 color; +} + +[outputtopology("triangle")] +[numthreads(MAX_VERTS, 1, 1)] +[shader("mesh")] +void entry_mesh( + in uint tig : SV_GroupThreadID, + OutputVertices<Vertex, MAX_VERTS> verts, + OutputIndices<uint3, MAX_PRIMS> triangles, + OutputPrimitives<Primitive, MAX_PRIMS> primitives) +{ + const uint numVertices = MAX_VERTS; + const uint numPrimitives = MAX_PRIMS; + SetMeshOutputCounts(numVertices, numPrimitives); + + if(tig < numVertices) { + verts[tig] = {float4(positions[tig], 0, 1)}; + } + + if(tig < numPrimitives) { + triangles[tig] = uint3(0,1,2) + tig * 3; + primitives[tig] = { float3(1,0,0) }; + } +} + +struct FragmentOut +{ + [[vk::location(0)]] float3 color; +}; + +[shader("fragment")] +FragmentOut entry_fragment(in nointerpolation Primitive prim) +{ + FragmentOut frag_out; + frag_out.color = prim.color; + return frag_out; +}
\ No newline at end of file |
