diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-09-12 11:13:11 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-11 20:13:11 -0700 |
| commit | 09854a4596019ddb3bb315b8836b5c88e718cdc7 (patch) | |
| tree | 1556ae3e00da0fac91343f159b52cee1231a7fab /source | |
| parent | 87bb0b503544f1b8c6ec818e25c695b31cda24b7 (diff) | |
Add Mesh and Task shader support to GFX (#3190)
* Bump vulkan headers
Also just use vulkan-headers as a submodule
* Add drawMeshTasks to gfx graphics pipelines
* Add DispatchMesh overload with no payload, with GLSL intrinsic
* Require spirv 1.4 for mesh shaders
* Add vulkan mesh shader feature discovery
* Add mesh shader stage bits to vk-util
* Add mesh and task shader support to render-test
* Add mesh and task tests
* Preserve "payload" specifier in task shaders
* Add mesh shader pipeline support to gfx
* Add TODO
* Add numThreads attribute for amplification stage
* Add payload to task shader test
* Drop dependency on d3dx12
* Allow passing payloads from task to mesh shaders
* regenerate vs projects
* check DispatchMesh name correctly
* Add mesh shader tests to failing tests
* Detect wave-ops feature on vulkan
* Add fuse-product to expected failures
This fails because the global varaible `count` is not initialized
* Add required extension to WaveMaskMatch SPIR-V impl
* Remove meshShader member from pipeline desc
* Identify mesh shader support on d3d12
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 36 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-cuda.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-emit-glsl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-emit-hlsl.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-emit-hlsl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-fuse-satcoop.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 116 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 1 |
18 files changed, 234 insertions, 41 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index c670f234e..752c99a5b 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5948,9 +5948,8 @@ matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinType> __glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) +__spirv_version(1.1) __cuda_sm_version(7.0) -__spirv_capability(GroupNonUniformPartitionedNV) WaveMask WaveMaskMatch(WaveMask mask, T value) { __target_switch @@ -5959,14 +5958,18 @@ WaveMask WaveMaskMatch(WaveMask mask, T value) case cuda: __intrinsic_asm "_waveMatchScalar($0, $1).x"; case hlsl: __intrinsic_asm "WaveMatch($1).x"; case spirv: - return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x; + return (spirv_asm + { + OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }).x; } } __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) +__spirv_version(1.1) __cuda_sm_version(7.0) -__spirv_capability(GroupNonUniformPartitionedNV) WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value) { __target_switch @@ -5975,7 +5978,12 @@ WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value) case cuda: __intrinsic_asm "_waveMatchMultiple($0, $1).x"; case hlsl: __intrinsic_asm "WaveMatch($1).x"; case spirv: - return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x; + return (spirv_asm + { + OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }).x; } } @@ -7524,7 +7532,21 @@ void SetMeshOutputCounts(uint vertexCount, uint primitiveCount); // Specify the number of downstream mesh shader thread groups to invoke from an amplification shader, // and provide the values for per-mesh payload parameters. // -void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, P meshPayload); +// This function doesn't return. +// +[KnownBuiltin("DispatchMesh")] +void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, __ref P meshPayload) +{ + __target_switch + { + case hlsl: + __intrinsic_asm "DispatchMesh"; + case glsl: + // This intrinsic doesn't take into account writing meshPayload. That + // is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'. + __intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)"; + } +} // // "Sampler feedback" types `FeedbackTexture2D` and `FeedbackTexture2DArray`. diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 3bd52245a..d70651636 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -985,6 +985,11 @@ class HLSLPrimitivesModifier : public HLSLMeshShaderOutputModifier SLANG_AST_CLASS(HLSLPrimitivesModifier) }; +class HLSLPayloadModifier : public Modifier +{ + SLANG_AST_CLASS(HLSLPayloadModifier) +}; + // A modifier to indicate that a constructor/initializer can be used // to perform implicit type conversion, and to specify the cost of // the conversion, if applied. diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index a8080851f..fa6fd2b43 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1638,11 +1638,14 @@ void CLikeSourceEmitter::emitArgs(IRInst* inst) m_writer->emit(")"); } -void CLikeSourceEmitter::emitRateQualifiers(IRInst* value) +void CLikeSourceEmitter::emitRateQualifiersAndAddressSpace(IRInst* value) { - if (IRRate* rate = value->getRate()) + const auto rate = value->getRate(); + const auto ptrTy = composeGetters<IRPtrTypeBase>(value, &IRInst::getDataType); + const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : -1; + if (rate || addressSpace != -1) { - emitRateQualifiersImpl(rate); + emitRateQualifiersAndAddressSpaceImpl(rate, addressSpace); } } @@ -1657,7 +1660,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitTempModifiers(inst); - emitRateQualifiers(inst); + emitRateQualifiersAndAddressSpace(inst); if(as<IRModuleInst>(inst->getParent())) { @@ -3138,7 +3141,7 @@ void CLikeSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) || layout->usesResourceKind(LayoutResourceKind::VaryingOutput)) { emitInterpolationModifiers(param, paramType, layout); - emitMeshOutputModifiers(param); + emitMeshShaderModifiers(param); } } @@ -3514,9 +3517,9 @@ void CLikeSourceEmitter::emitInterpolationModifiers(IRInst* varInst, IRType* val emitInterpolationModifiersImpl(varInst, valueType, layout); } -void CLikeSourceEmitter::emitMeshOutputModifiers(IRInst* varInst) +void CLikeSourceEmitter::emitMeshShaderModifiers(IRInst* varInst) { - emitMeshOutputModifiersImpl(varInst); + emitMeshShaderModifiersImpl(varInst); } /// Emit modifiers that should apply even for a declaration of an SSA temporary. @@ -3547,7 +3550,7 @@ void CLikeSourceEmitter::emitVarModifiers(IRVarLayout* layout, IRInst* varDecl, || layout->usesResourceKind(LayoutResourceKind::VaryingOutput)) { emitInterpolationModifiers(varDecl, varType, layout); - emitMeshOutputModifiers(varDecl); + emitMeshShaderModifiers(varDecl); } // Output target specific qualifiers @@ -3643,7 +3646,7 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) break; } #endif - emitRateQualifiers(varDecl); + emitRateQualifiersAndAddressSpace(varDecl); emitType(varType, getName(varDecl)); @@ -3725,7 +3728,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl) emitVarModifiers(layout, varDecl, varType); - emitRateQualifiers(varDecl); + emitRateQualifiersAndAddressSpace(varDecl); emitType(varType, getName(varDecl)); // TODO: These shouldn't be needed for ordinary @@ -3788,7 +3791,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitVarModifiers(layout, varDecl, varType); - emitRateQualifiers(varDecl); + emitRateQualifiersAndAddressSpace(varDecl); emitType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 62f6d20a2..976ac8e19 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -349,7 +349,7 @@ public: void emitArgs(IRInst* inst); - void emitRateQualifiers(IRInst* value); + void emitRateQualifiersAndAddressSpace(IRInst* value); void emitInstResultDecl(IRInst* inst); @@ -430,7 +430,7 @@ public: void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); } void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout); - void emitMeshOutputModifiers(IRInst* varInst); + void emitMeshShaderModifiers(IRInst* varInst); virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {}; @@ -511,13 +511,13 @@ public: /// the appropriate generated declarations occur. virtual void emitPreModuleImpl() {} - virtual void emitRateQualifiersImpl(IRRate* rate) { SLANG_UNUSED(rate); } + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); } virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsetLayout); } virtual void emitSimpleFuncParamImpl(IRParam* param); virtual void emitSimpleFuncParamsImpl(IRFunc* func); virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) { SLANG_UNUSED(varInst); SLANG_UNUSED(valueType); SLANG_UNUSED(layout); } - virtual void emitMeshOutputModifiersImpl(IRInst* varInst) { SLANG_UNUSED(varInst) } + virtual void emitMeshShaderModifiersImpl(IRInst* varInst) { SLANG_UNUSED(varInst) } virtual void emitSimpleTypeImpl(IRType* type) = 0; virtual void emitVarDecorationsImpl(IRInst* varDecl) { SLANG_UNUSED(varDecl); } virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) { SLANG_UNUSED(layout); } diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index fa0e3c7aa..9364f4441 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -742,7 +742,7 @@ void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type) } } -void CUDASourceEmitter::emitRateQualifiersImpl(IRRate* rate) +void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) { if (as<IRGroupSharedRate>(rate)) { diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index 82d0240b3..a63d087bf 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -69,7 +69,7 @@ protected: virtual void emitPreModuleImpl() SLANG_OVERRIDE; - virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) SLANG_OVERRIDE; virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE; virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index c097de5b9..f82a76d4a 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -47,6 +47,7 @@ SlangResult GLSLSourceEmitter::init() case Stage::Amplification: { _requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_mesh_shader")); + _requireSPIRVVersion(SemanticVersion(1, 4)); break; } default: break; @@ -845,7 +846,7 @@ void GLSLSourceEmitter::_maybeEmitGLSLBuiltin(IRGlobalParam* var, UnownedStringS // SLANG_ASSERT(layout && "Mesh shader builtin output has no layout"); // SLANG_ASSERT(layout->usesResourceKind(LayoutResourceKind::VaryingOutput)); // emitVarModifiers(layout, var, arrayType); - emitMeshOutputModifiers(var); + emitMeshShaderModifiers(var); m_writer->emit("out"); m_writer->emit(" "); m_writer->emit(elementTypeName); @@ -1186,6 +1187,11 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin } } break; + case Stage::Amplification: + { + emitLocalSizeLayout(); + } + break; // TODO: There are other stages that will need this kind of handling. default: break; @@ -1211,7 +1217,7 @@ void GLSLSourceEmitter::_emitGLSLPerVertexVaryingFragmentInput(IRGlobalParam* pa emitVarModifiers(layout, param, type); - emitRateQualifiers(param); + emitRateQualifiersAndAddressSpace(param); auto name = getName(param); StringSliceLoc nameAndLoc(name.getUnownedSlice()); @@ -2440,9 +2446,13 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled type"); } -void GLSLSourceEmitter::emitRateQualifiersImpl(IRRate* rate) +void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { - if (as<IRConstExprRate>(rate)) + if(addressSpace == SpvStorageClassTaskPayloadWorkgroupEXT) + { + m_writer->emit("taskPayloadSharedEXT "); + } + else if (as<IRConstExprRate>(rate)) { m_writer->emit("const "); @@ -2565,7 +2575,7 @@ void GLSLSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTyp m_writer->emit(")\n"); } -void GLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst) +void GLSLSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst) { if(varInst->findDecoration<IRGLSLPrimitivesRateDecoration>()) { diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index 780b24453..f15ea03b4 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -31,11 +31,11 @@ protected: virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; - virtual void emitMeshOutputModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 66902a624..9defc9adc 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -442,6 +442,11 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin } break; } + case Stage::Amplification: + { + emitNumThreadsAttribute(); + break; + } // TODO: There are other stages that will need this kind of handling. default: break; @@ -1032,7 +1037,7 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) } } -void HLSLSourceEmitter::emitRateQualifiersImpl(IRRate* rate) +void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace) { if (as<IRGroupSharedRate>(rate)) { @@ -1138,8 +1143,11 @@ void HLSLSourceEmitter::_emitPrefixTypeAttr(IRAttr* attr) void HLSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) { - emitRateQualifiers(param); - emitMeshOutputModifiers(param); + // A mesh shader input payload has it's own weird stuff going on, handled + // in emitMeshShaderModifiers, skip this bit which will introduce an + // invalid "groupshared" keyword. + if (!param->findDecoration<IRHLSLMeshPayloadDecoration>()) + emitRateQualifiersAndAddressSpace(param); if (auto decor = param->findDecoration<IRGeometryInputPrimitiveTypeDecoration>()) { @@ -1200,7 +1208,7 @@ void HLSLSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTyp // We emit packoffset as a semantic in `emitSemantic`, so nothing to do here. } -void HLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst) +void HLSLSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst) { if(auto modifier = varInst->findDecoration<IRMeshOutputDecoration>()) { @@ -1212,6 +1220,11 @@ void HLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst) SLANG_ASSERT(s && "Unhandled type of mesh output decoration"); m_writer->emit(s); } + if(varInst->findDecoration<IRHLSLMeshPayloadDecoration>()) + { + // DXC requires that mesh payload parameters have "in" specified + m_writer->emit("in payload "); + } } void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl) diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index 08363bceb..707667e90 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -35,13 +35,13 @@ protected: virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE; + virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE; virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE; virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE; - virtual void emitMeshOutputModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; + virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 4bee37746..543bb089d 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -806,9 +806,11 @@ Result linkAndOptimizeIR( { case CodeGenTarget::GLSL: case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: { legalizeImageSubscriptForGLSL(irModule); legalizeConstantBufferLoadForGLSL(irModule); + legalizeDispatchMeshPayloadForGLSL(irModule); } break; default: @@ -879,6 +881,10 @@ Result linkAndOptimizeIR( // // If any have survived this far, change them back to regular (decorated) // arrays that the emitters can deal with. + // + // TODO: This is too early for the SPIR-V backend, which requires these + // types for when it calls legalizeEntryPointsForGLSL (later than GLSL does + // above) legalizeMeshOutputTypes(irModule); // We need to lower any types used in a buffer resource (e.g. ContantBuffer or StructuredBuffer) into diff --git a/source/slang/slang-ir-fuse-satcoop.cpp b/source/slang/slang-ir-fuse-satcoop.cpp index b672f3f7c..e6b3d7f10 100644 --- a/source/slang/slang-ir-fuse-satcoop.cpp +++ b/source/slang/slang-ir-fuse-satcoop.cpp @@ -432,7 +432,11 @@ IRCall* isKnownFunction(const char* n, IRInst* i) if(!generic) return nullptr; - auto h = generic->findDecoration<IRKnownBuiltinDecoration>(); + auto inner = findGenericReturnVal(generic); + if(!inner) + return nullptr; + + auto h = inner->findDecoration<IRKnownBuiltinDecoration>(); if(!h || h->getName() != n) return nullptr; return call; diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 23ddfc50f..866fd3a4d 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -7,8 +7,8 @@ #include "slang-ir-insts.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-specialize-function-call.h" - #include "slang-glsl-extension-tracker.h" +#include "../../external/spirv-headers/include/spirv/unified1/spirv.h" namespace Slang { @@ -1929,7 +1929,35 @@ void legalizeRayTracingEntryPointParameterForGLSL( builder->addDependsOnDecoration(func, globalParam); } -void legalizeMeshOutputParam( +static void legalizeMeshPayloadInputParam( + GLSLLegalizationContext* context, + CodeGenContext* codeGenContext, + IRParam* pp) +{ + auto builder = context->getBuilder(); + auto stage = context->getStage(); + SLANG_ASSERT(stage == Stage::Mesh && "legalizing mesh payload input, but we're not a mesh shader"); + IRBuilderInsertLocScope locScope{builder}; + builder->setInsertInto(builder->getModule()); + + const auto g = builder->emitVar(pp->getDataType(), SpvStorageClassTaskPayloadWorkgroupEXT); + g->setFullType(builder->getRateQualifiedType(builder->getGroupSharedRate(), g->getFullType())); + // moveValueBefore(g, builder->getFunc()); + builder->addNameHintDecoration(g, pp->findDecoration<IRNameHintDecoration>()->getName()); + pp->replaceUsesWith(g); + struct MeshPayloadInputSpecializationCondition : FunctionCallSpecializeCondition + { + bool doesParamWantSpecialization(IRParam*, IRInst* arg) + { + return arg == g; + } + IRInst* g; + } condition; + condition.g = g; + specializeFunctionCalls(codeGenContext, builder->getModule(), &condition); +} + +static void legalizeMeshOutputParam( GLSLLegalizationContext* context, CodeGenContext* codeGenContext, IRFunc* func, @@ -2354,6 +2382,7 @@ void legalizeEntryPointParameterForGLSL( // don't fit into the standard varying model. // - Geometry shader output streams // - Mesh shader outputs + // - Mesh shader payload input if( auto paramPtrType = as<IROutTypeBase>(paramType) ) { auto valueType = paramPtrType->getValueType(); @@ -2476,6 +2505,10 @@ void legalizeEntryPointParameterForGLSL( return legalizeMeshOutputParam(context, codeGenContext, func, pp, paramLayout, meshOutputType); } } + else if(pp->findDecoration<IRHLSLMeshPayloadDecoration>()) + { + return legalizeMeshPayloadInputParam(context, codeGenContext, pp); + } // When we have an HLSL ray tracing shader entry point, // we don't want to translate the inputs/outputs for GLSL/SPIR-V @@ -2878,4 +2911,83 @@ void legalizeConstantBufferLoadForGLSL(IRModule* module) } +void legalizeDispatchMeshPayloadForGLSL(IRModule* module) +{ + // Find out DispatchMesh function + IRGlobalValueWithCode* dispatchMeshFunc = nullptr; + for(const auto globalInst : module->getGlobalInsts()) + { + if(const auto func = as<IRGlobalValueWithCode>(globalInst)) + { + if(const auto dec = func->findDecoration<IRKnownBuiltinDecoration>()) + { + if(dec->getName() == "DispatchMesh") + { + SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); + dispatchMeshFunc = func; + } + } + } + } + + if(!dispatchMeshFunc) + return; + + IRBuilder builder{module}; + builder.setInsertBefore(dispatchMeshFunc); + + // We'll rewrite the calls to call EmitMeshTasksEXT + traverseUses(dispatchMeshFunc, [&](const IRUse* use){ + if(const auto call = as<IRCall>(use->getUser())) + { + SLANG_ASSERT(call->getArgCount() == 4); + const auto payload = call->getArg(3); + + const auto payloadPtrType = composeGetters<IRPtrTypeBase>( + payload, + &IRInst::getDataType + ); + SLANG_ASSERT(payloadPtrType); + const auto payloadType = payloadPtrType->getValueType(); + SLANG_ASSERT(payloadType); + + const bool isGroupsharedGlobal = + payload->getParent() == module->getModuleInst() && + composeGetters<IRGroupSharedRate>(payload, &IRInst::getRate); + if(isGroupsharedGlobal) + { + // If it's a groupshared global, then we put it in the address + // space we know to emit as taskPayloadSharedEXT instead (or + // naturally fall through correctly for SPIR-V emit) + // + // Keep it as a groupshared rate qualified type so we don't + // miss out on any further legalization requirement or + // optimization opportunities. + const auto payloadSharedPtrType = + builder.getRateQualifiedType( + builder.getGroupSharedRate(), + builder.getPtrType( + payloadPtrType->getOp(), + payloadPtrType->getValueType(), + SpvStorageClassTaskPayloadWorkgroupEXT + ) + ); + payload->setFullType(payloadSharedPtrType); + } + else + { + // ... + // If it's not a groupshared global, then create such a + // parameter and store into the value being passed to this + // call. + builder.setInsertInto(module->getModuleInst()); + const auto v = builder.emitVar(payloadType, SpvStorageClassTaskPayloadWorkgroupEXT); + v->setFullType(builder.getRateQualifiedType(builder.getGroupSharedRate(), v->getFullType())); + builder.setInsertBefore(call); + builder.emitStore(v, payload); + } + } + }); +} + } // namespace Slang diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h index 1816df1f2..6fd0b642e 100644 --- a/source/slang/slang-ir-glsl-legalize.h +++ b/source/slang/slang-ir-glsl-legalize.h @@ -25,4 +25,6 @@ void legalizeImageSubscriptForGLSL(IRModule* module); void legalizeConstantBufferLoadForGLSL(IRModule* module); +void legalizeDispatchMeshPayloadForGLSL(IRModule* module); + } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f48801162..552a8af6d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -813,6 +813,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(IndicesDecoration, indices, 1, 0) INST(PrimitivesDecoration, primitives, 1, 0) INST_RANGE(MeshOutputDecoration, VerticesDecoration, PrimitivesDecoration) + INST(HLSLMeshPayloadDecoration, payload, 0, 0) INST(GLSLPrimitivesRateDecoration, perprimitive, 0, 0) // Marks an inst that represents the gl_Position output. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index bfcca5b02..d8788aa06 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -360,6 +360,7 @@ IR_SIMPLE_DECORATION(NoInlineDecoration) IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration) IR_SIMPLE_DECORATION(StaticRequirementDecoration) IR_SIMPLE_DECORATION(NonCopyableTypeDecoration) +IR_SIMPLE_DECORATION(HLSLMeshPayloadDecoration) struct IRNVAPIMagicDecoration : IRDecoration { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4266b46f9..106e5b5a3 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1355,7 +1355,10 @@ static void addLinkageDecoration( } else if (as<KnownBuiltinAttribute>(modifier)) { - builder->addKnownBuiltinDecoration(inst, decl->getName()->text.getUnownedSlice()); + // We add this to the internal instruction, like other name-like + // decorations, for instance "nameHint". This prevents it becoming + // lost during specialization. + builder->addKnownBuiltinDecoration(inInst, decl->getName()->text.getUnownedSlice()); } } if (as<InterfaceDecl>(decl->parentDecl) && @@ -2110,6 +2113,10 @@ void addVarDecorations( { builder->addFormatDecoration(inst, formatAttr->format); } + else if(auto payloadMod = as<HLSLPayloadModifier>(mod)) + { + builder->addSimpleDecoration<IRHLSLMeshPayloadDecoration>(inst); + } // TODO: what are other modifiers we need to propagate through? } @@ -2977,6 +2984,12 @@ void _lowerFuncDeclBaseTypeInfo( irParamType = builder->getRateQualifiedType(builder->getGroupSharedRate(), irParamType); } + // The 'payload' parameter is a read-only groupshared value + if(paramInfo.decl && paramInfo.decl->hasModifier<HLSLPayloadModifier>()) + { + irParamType = builder->getRateQualifiedType(builder->getGroupSharedRate(), irParamType); + } + paramTypes.add(irParamType); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index e184585e3..3f5267577 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7302,6 +7302,7 @@ namespace Slang _makeParseModifier("vertices", HLSLVerticesModifier::kReflectClassInfo), _makeParseModifier("indices", HLSLIndicesModifier::kReflectClassInfo), _makeParseModifier("primitives", HLSLPrimitivesModifier::kReflectClassInfo), + _makeParseModifier("payload", HLSLPayloadModifier::kReflectClassInfo), // Modifiers for unary operator declarations _makeParseModifier("__prefix", PrefixModifier::kReflectClassInfo), |
