summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-09-12 11:13:11 +0800
committerGitHub <noreply@github.com>2023-09-11 20:13:11 -0700
commit09854a4596019ddb3bb315b8836b5c88e718cdc7 (patch)
tree1556ae3e00da0fac91343f159b52cee1231a7fab /source
parent87bb0b503544f1b8c6ec818e25c695b31cda24b7 (diff)
Add Mesh and Task shader support to GFX (#3190)
* Bump vulkan headers Also just use vulkan-headers as a submodule * Add drawMeshTasks to gfx graphics pipelines * Add DispatchMesh overload with no payload, with GLSL intrinsic * Require spirv 1.4 for mesh shaders * Add vulkan mesh shader feature discovery * Add mesh shader stage bits to vk-util * Add mesh and task shader support to render-test * Add mesh and task tests * Preserve "payload" specifier in task shaders * Add mesh shader pipeline support to gfx * Add TODO * Add numThreads attribute for amplification stage * Add payload to task shader test * Drop dependency on d3dx12 * Allow passing payloads from task to mesh shaders * regenerate vs projects * check DispatchMesh name correctly * Add mesh shader tests to failing tests * Detect wave-ops feature on vulkan * Add fuse-product to expected failures This fails because the global varaible `count` is not initialized * Add required extension to WaveMaskMatch SPIR-V impl * Remove meshShader member from pipeline desc * Identify mesh shader support on d3d12
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang36
-rw-r--r--source/slang/slang-ast-modifier.h5
-rw-r--r--source/slang/slang-emit-c-like.cpp25
-rw-r--r--source/slang/slang-emit-c-like.h8
-rw-r--r--source/slang/slang-emit-cuda.cpp2
-rw-r--r--source/slang/slang-emit-cuda.h2
-rw-r--r--source/slang/slang-emit-glsl.cpp20
-rw-r--r--source/slang/slang-emit-glsl.h4
-rw-r--r--source/slang/slang-emit-hlsl.cpp21
-rw-r--r--source/slang/slang-emit-hlsl.h4
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-fuse-satcoop.cpp6
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp116
-rw-r--r--source/slang/slang-ir-glsl-legalize.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp15
-rw-r--r--source/slang/slang-parser.cpp1
18 files changed, 234 insertions, 41 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index c670f234e..752c99a5b 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -5948,9 +5948,8 @@ matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinType>
__glsl_extension(GL_NV_shader_subgroup_partitioned)
-__spirv_version(1.3)
+__spirv_version(1.1)
__cuda_sm_version(7.0)
-__spirv_capability(GroupNonUniformPartitionedNV)
WaveMask WaveMaskMatch(WaveMask mask, T value)
{
__target_switch
@@ -5959,14 +5958,18 @@ WaveMask WaveMaskMatch(WaveMask mask, T value)
case cuda: __intrinsic_asm "_waveMatchScalar($0, $1).x";
case hlsl: __intrinsic_asm "WaveMatch($1).x";
case spirv:
- return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x;
+ return (spirv_asm
+ {
+ OpCapability GroupNonUniformPartitionedNV;
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpGroupNonUniformPartitionNV $$uint4 result $value
+ }).x;
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_NV_shader_subgroup_partitioned)
-__spirv_version(1.3)
+__spirv_version(1.1)
__cuda_sm_version(7.0)
-__spirv_capability(GroupNonUniformPartitionedNV)
WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value)
{
__target_switch
@@ -5975,7 +5978,12 @@ WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value)
case cuda: __intrinsic_asm "_waveMatchMultiple($0, $1).x";
case hlsl: __intrinsic_asm "WaveMatch($1).x";
case spirv:
- return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x;
+ return (spirv_asm
+ {
+ OpCapability GroupNonUniformPartitionedNV;
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpGroupNonUniformPartitionNV $$uint4 result $value
+ }).x;
}
}
@@ -7524,7 +7532,21 @@ void SetMeshOutputCounts(uint vertexCount, uint primitiveCount);
// Specify the number of downstream mesh shader thread groups to invoke from an amplification shader,
// and provide the values for per-mesh payload parameters.
//
-void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, P meshPayload);
+// This function doesn't return.
+//
+[KnownBuiltin("DispatchMesh")]
+void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint threadGroupCountZ, __ref P meshPayload)
+{
+ __target_switch
+ {
+ case hlsl:
+ __intrinsic_asm "DispatchMesh";
+ case glsl:
+ // This intrinsic doesn't take into account writing meshPayload. That
+ // is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'.
+ __intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)";
+ }
+}
//
// "Sampler feedback" types `FeedbackTexture2D` and `FeedbackTexture2DArray`.
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 3bd52245a..d70651636 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -985,6 +985,11 @@ class HLSLPrimitivesModifier : public HLSLMeshShaderOutputModifier
SLANG_AST_CLASS(HLSLPrimitivesModifier)
};
+class HLSLPayloadModifier : public Modifier
+{
+ SLANG_AST_CLASS(HLSLPayloadModifier)
+};
+
// A modifier to indicate that a constructor/initializer can be used
// to perform implicit type conversion, and to specify the cost of
// the conversion, if applied.
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index a8080851f..fa6fd2b43 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1638,11 +1638,14 @@ void CLikeSourceEmitter::emitArgs(IRInst* inst)
m_writer->emit(")");
}
-void CLikeSourceEmitter::emitRateQualifiers(IRInst* value)
+void CLikeSourceEmitter::emitRateQualifiersAndAddressSpace(IRInst* value)
{
- if (IRRate* rate = value->getRate())
+ const auto rate = value->getRate();
+ const auto ptrTy = composeGetters<IRPtrTypeBase>(value, &IRInst::getDataType);
+ const auto addressSpace = ptrTy ? ptrTy->getAddressSpace() : -1;
+ if (rate || addressSpace != -1)
{
- emitRateQualifiersImpl(rate);
+ emitRateQualifiersAndAddressSpaceImpl(rate, addressSpace);
}
}
@@ -1657,7 +1660,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
emitTempModifiers(inst);
- emitRateQualifiers(inst);
+ emitRateQualifiersAndAddressSpace(inst);
if(as<IRModuleInst>(inst->getParent()))
{
@@ -3138,7 +3141,7 @@ void CLikeSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
|| layout->usesResourceKind(LayoutResourceKind::VaryingOutput))
{
emitInterpolationModifiers(param, paramType, layout);
- emitMeshOutputModifiers(param);
+ emitMeshShaderModifiers(param);
}
}
@@ -3514,9 +3517,9 @@ void CLikeSourceEmitter::emitInterpolationModifiers(IRInst* varInst, IRType* val
emitInterpolationModifiersImpl(varInst, valueType, layout);
}
-void CLikeSourceEmitter::emitMeshOutputModifiers(IRInst* varInst)
+void CLikeSourceEmitter::emitMeshShaderModifiers(IRInst* varInst)
{
- emitMeshOutputModifiersImpl(varInst);
+ emitMeshShaderModifiersImpl(varInst);
}
/// Emit modifiers that should apply even for a declaration of an SSA temporary.
@@ -3547,7 +3550,7 @@ void CLikeSourceEmitter::emitVarModifiers(IRVarLayout* layout, IRInst* varDecl,
|| layout->usesResourceKind(LayoutResourceKind::VaryingOutput))
{
emitInterpolationModifiers(varDecl, varType, layout);
- emitMeshOutputModifiers(varDecl);
+ emitMeshShaderModifiers(varDecl);
}
// Output target specific qualifiers
@@ -3643,7 +3646,7 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl)
break;
}
#endif
- emitRateQualifiers(varDecl);
+ emitRateQualifiersAndAddressSpace(varDecl);
emitType(varType, getName(varDecl));
@@ -3725,7 +3728,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl)
emitVarModifiers(layout, varDecl, varType);
- emitRateQualifiers(varDecl);
+ emitRateQualifiersAndAddressSpace(varDecl);
emitType(varType, getName(varDecl));
// TODO: These shouldn't be needed for ordinary
@@ -3788,7 +3791,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
emitVarModifiers(layout, varDecl, varType);
- emitRateQualifiers(varDecl);
+ emitRateQualifiersAndAddressSpace(varDecl);
emitType(varType, getName(varDecl));
emitSemantics(varDecl);
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index 62f6d20a2..976ac8e19 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -349,7 +349,7 @@ public:
void emitArgs(IRInst* inst);
- void emitRateQualifiers(IRInst* value);
+ void emitRateQualifiersAndAddressSpace(IRInst* value);
void emitInstResultDecl(IRInst* inst);
@@ -430,7 +430,7 @@ public:
void emitPostKeywordTypeAttributes(IRInst* inst) { emitPostKeywordTypeAttributesImpl(inst); }
void emitInterpolationModifiers(IRInst* varInst, IRType* valueType, IRVarLayout* layout);
- void emitMeshOutputModifiers(IRInst* varInst);
+ void emitMeshShaderModifiers(IRInst* varInst);
virtual void emitPackOffsetModifier(IRInst* /*varInst*/, IRType* /*valueType*/, IRPackOffsetDecoration* /*decoration*/) {};
@@ -511,13 +511,13 @@ public:
/// the appropriate generated declarations occur.
virtual void emitPreModuleImpl() {}
- virtual void emitRateQualifiersImpl(IRRate* rate) { SLANG_UNUSED(rate); }
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) { SLANG_UNUSED(rate); SLANG_UNUSED(addressSpace); }
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) { SLANG_UNUSED(inst); SLANG_UNUSED(allowOffsetLayout); }
virtual void emitSimpleFuncParamImpl(IRParam* param);
virtual void emitSimpleFuncParamsImpl(IRFunc* func);
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) { SLANG_UNUSED(varInst); SLANG_UNUSED(valueType); SLANG_UNUSED(layout); }
- virtual void emitMeshOutputModifiersImpl(IRInst* varInst) { SLANG_UNUSED(varInst) }
+ virtual void emitMeshShaderModifiersImpl(IRInst* varInst) { SLANG_UNUSED(varInst) }
virtual void emitSimpleTypeImpl(IRType* type) = 0;
virtual void emitVarDecorationsImpl(IRInst* varDecl) { SLANG_UNUSED(varDecl); }
virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) { SLANG_UNUSED(layout); }
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index fa0e3c7aa..9364f4441 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -742,7 +742,7 @@ void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}
-void CUDASourceEmitter::emitRateQualifiersImpl(IRRate* rate)
+void CUDASourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
{
if (as<IRGroupSharedRate>(rate))
{
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index 82d0240b3..a63d087bf 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -69,7 +69,7 @@ protected:
virtual void emitPreModuleImpl() SLANG_OVERRIDE;
- virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsetLayout) SLANG_OVERRIDE;
virtual void emitSimpleFuncImpl(IRFunc* func) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamsImpl(IRFunc* func) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index c097de5b9..f82a76d4a 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -47,6 +47,7 @@ SlangResult GLSLSourceEmitter::init()
case Stage::Amplification:
{
_requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_mesh_shader"));
+ _requireSPIRVVersion(SemanticVersion(1, 4));
break;
}
default: break;
@@ -845,7 +846,7 @@ void GLSLSourceEmitter::_maybeEmitGLSLBuiltin(IRGlobalParam* var, UnownedStringS
// SLANG_ASSERT(layout && "Mesh shader builtin output has no layout");
// SLANG_ASSERT(layout->usesResourceKind(LayoutResourceKind::VaryingOutput));
// emitVarModifiers(layout, var, arrayType);
- emitMeshOutputModifiers(var);
+ emitMeshShaderModifiers(var);
m_writer->emit("out");
m_writer->emit(" ");
m_writer->emit(elementTypeName);
@@ -1186,6 +1187,11 @@ void GLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
}
}
break;
+ case Stage::Amplification:
+ {
+ emitLocalSizeLayout();
+ }
+ break;
// TODO: There are other stages that will need this kind of handling.
default:
break;
@@ -1211,7 +1217,7 @@ void GLSLSourceEmitter::_emitGLSLPerVertexVaryingFragmentInput(IRGlobalParam* pa
emitVarModifiers(layout, param, type);
- emitRateQualifiers(param);
+ emitRateQualifiersAndAddressSpace(param);
auto name = getName(param);
StringSliceLoc nameAndLoc(name.getUnownedSlice());
@@ -2440,9 +2446,13 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
SLANG_DIAGNOSE_UNEXPECTED(getSink(), SourceLoc(), "unhandled type");
}
-void GLSLSourceEmitter::emitRateQualifiersImpl(IRRate* rate)
+void GLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace)
{
- if (as<IRConstExprRate>(rate))
+ if(addressSpace == SpvStorageClassTaskPayloadWorkgroupEXT)
+ {
+ m_writer->emit("taskPayloadSharedEXT ");
+ }
+ else if (as<IRConstExprRate>(rate))
{
m_writer->emit("const ");
@@ -2565,7 +2575,7 @@ void GLSLSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTyp
m_writer->emit(")\n");
}
-void GLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst)
+void GLSLSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst)
{
if(varInst->findDecoration<IRGLSLPrimitivesRateDecoration>())
{
diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h
index 780b24453..f15ea03b4 100644
--- a/source/slang/slang-emit-glsl.h
+++ b/source/slang/slang-emit-glsl.h
@@ -31,11 +31,11 @@ protected:
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE;
- virtual void emitMeshOutputModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
+ virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index 66902a624..9defc9adc 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -442,6 +442,11 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
}
break;
}
+ case Stage::Amplification:
+ {
+ emitNumThreadsAttribute();
+ break;
+ }
// TODO: There are other stages that will need this kind of handling.
default:
break;
@@ -1032,7 +1037,7 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}
-void HLSLSourceEmitter::emitRateQualifiersImpl(IRRate* rate)
+void HLSLSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, [[maybe_unused]] IRIntegerValue addressSpace)
{
if (as<IRGroupSharedRate>(rate))
{
@@ -1138,8 +1143,11 @@ void HLSLSourceEmitter::_emitPrefixTypeAttr(IRAttr* attr)
void HLSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
{
- emitRateQualifiers(param);
- emitMeshOutputModifiers(param);
+ // A mesh shader input payload has it's own weird stuff going on, handled
+ // in emitMeshShaderModifiers, skip this bit which will introduce an
+ // invalid "groupshared" keyword.
+ if (!param->findDecoration<IRHLSLMeshPayloadDecoration>())
+ emitRateQualifiersAndAddressSpace(param);
if (auto decor = param->findDecoration<IRGeometryInputPrimitiveTypeDecoration>())
{
@@ -1200,7 +1208,7 @@ void HLSLSourceEmitter::emitPackOffsetModifier(IRInst* varInst, IRType* valueTyp
// We emit packoffset as a semantic in `emitSemantic`, so nothing to do here.
}
-void HLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst)
+void HLSLSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst)
{
if(auto modifier = varInst->findDecoration<IRMeshOutputDecoration>())
{
@@ -1212,6 +1220,11 @@ void HLSLSourceEmitter::emitMeshOutputModifiersImpl(IRInst* varInst)
SLANG_ASSERT(s && "Unhandled type of mesh output decoration");
m_writer->emit(s);
}
+ if(varInst->findDecoration<IRHLSLMeshPayloadDecoration>())
+ {
+ // DXC requires that mesh payload parameters have "in" specified
+ m_writer->emit("in payload ");
+ }
}
void HLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index 08363bceb..707667e90 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -35,13 +35,13 @@ protected:
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitRateQualifiersImpl(IRRate* rate) SLANG_OVERRIDE;
+ virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRIntegerValue addressSpace) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitInterpolationModifiersImpl(IRInst* varInst, IRType* valueType, IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitPackOffsetModifier(IRInst* varInst, IRType* valueType, IRPackOffsetDecoration* decoration) SLANG_OVERRIDE;
- virtual void emitMeshOutputModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
+ virtual void emitMeshShaderModifiersImpl(IRInst* varInst) SLANG_OVERRIDE;
virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 4bee37746..543bb089d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -806,9 +806,11 @@ Result linkAndOptimizeIR(
{
case CodeGenTarget::GLSL:
case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
{
legalizeImageSubscriptForGLSL(irModule);
legalizeConstantBufferLoadForGLSL(irModule);
+ legalizeDispatchMeshPayloadForGLSL(irModule);
}
break;
default:
@@ -879,6 +881,10 @@ Result linkAndOptimizeIR(
//
// If any have survived this far, change them back to regular (decorated)
// arrays that the emitters can deal with.
+ //
+ // TODO: This is too early for the SPIR-V backend, which requires these
+ // types for when it calls legalizeEntryPointsForGLSL (later than GLSL does
+ // above)
legalizeMeshOutputTypes(irModule);
// We need to lower any types used in a buffer resource (e.g. ContantBuffer or StructuredBuffer) into
diff --git a/source/slang/slang-ir-fuse-satcoop.cpp b/source/slang/slang-ir-fuse-satcoop.cpp
index b672f3f7c..e6b3d7f10 100644
--- a/source/slang/slang-ir-fuse-satcoop.cpp
+++ b/source/slang/slang-ir-fuse-satcoop.cpp
@@ -432,7 +432,11 @@ IRCall* isKnownFunction(const char* n, IRInst* i)
if(!generic)
return nullptr;
- auto h = generic->findDecoration<IRKnownBuiltinDecoration>();
+ auto inner = findGenericReturnVal(generic);
+ if(!inner)
+ return nullptr;
+
+ auto h = inner->findDecoration<IRKnownBuiltinDecoration>();
if(!h || h->getName() != n)
return nullptr;
return call;
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 23ddfc50f..866fd3a4d 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -7,8 +7,8 @@
#include "slang-ir-insts.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-specialize-function-call.h"
-
#include "slang-glsl-extension-tracker.h"
+#include "../../external/spirv-headers/include/spirv/unified1/spirv.h"
namespace Slang
{
@@ -1929,7 +1929,35 @@ void legalizeRayTracingEntryPointParameterForGLSL(
builder->addDependsOnDecoration(func, globalParam);
}
-void legalizeMeshOutputParam(
+static void legalizeMeshPayloadInputParam(
+ GLSLLegalizationContext* context,
+ CodeGenContext* codeGenContext,
+ IRParam* pp)
+{
+ auto builder = context->getBuilder();
+ auto stage = context->getStage();
+ SLANG_ASSERT(stage == Stage::Mesh && "legalizing mesh payload input, but we're not a mesh shader");
+ IRBuilderInsertLocScope locScope{builder};
+ builder->setInsertInto(builder->getModule());
+
+ const auto g = builder->emitVar(pp->getDataType(), SpvStorageClassTaskPayloadWorkgroupEXT);
+ g->setFullType(builder->getRateQualifiedType(builder->getGroupSharedRate(), g->getFullType()));
+ // moveValueBefore(g, builder->getFunc());
+ builder->addNameHintDecoration(g, pp->findDecoration<IRNameHintDecoration>()->getName());
+ pp->replaceUsesWith(g);
+ struct MeshPayloadInputSpecializationCondition : FunctionCallSpecializeCondition
+ {
+ bool doesParamWantSpecialization(IRParam*, IRInst* arg)
+ {
+ return arg == g;
+ }
+ IRInst* g;
+ } condition;
+ condition.g = g;
+ specializeFunctionCalls(codeGenContext, builder->getModule(), &condition);
+}
+
+static void legalizeMeshOutputParam(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
IRFunc* func,
@@ -2354,6 +2382,7 @@ void legalizeEntryPointParameterForGLSL(
// don't fit into the standard varying model.
// - Geometry shader output streams
// - Mesh shader outputs
+ // - Mesh shader payload input
if( auto paramPtrType = as<IROutTypeBase>(paramType) )
{
auto valueType = paramPtrType->getValueType();
@@ -2476,6 +2505,10 @@ void legalizeEntryPointParameterForGLSL(
return legalizeMeshOutputParam(context, codeGenContext, func, pp, paramLayout, meshOutputType);
}
}
+ else if(pp->findDecoration<IRHLSLMeshPayloadDecoration>())
+ {
+ return legalizeMeshPayloadInputParam(context, codeGenContext, pp);
+ }
// When we have an HLSL ray tracing shader entry point,
// we don't want to translate the inputs/outputs for GLSL/SPIR-V
@@ -2878,4 +2911,83 @@ void legalizeConstantBufferLoadForGLSL(IRModule* module)
}
+void legalizeDispatchMeshPayloadForGLSL(IRModule* module)
+{
+ // Find out DispatchMesh function
+ IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
+ for(const auto globalInst : module->getGlobalInsts())
+ {
+ if(const auto func = as<IRGlobalValueWithCode>(globalInst))
+ {
+ if(const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
+ {
+ if(dec->getName() == "DispatchMesh")
+ {
+ SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
+ dispatchMeshFunc = func;
+ }
+ }
+ }
+ }
+
+ if(!dispatchMeshFunc)
+ return;
+
+ IRBuilder builder{module};
+ builder.setInsertBefore(dispatchMeshFunc);
+
+ // We'll rewrite the calls to call EmitMeshTasksEXT
+ traverseUses(dispatchMeshFunc, [&](const IRUse* use){
+ if(const auto call = as<IRCall>(use->getUser()))
+ {
+ SLANG_ASSERT(call->getArgCount() == 4);
+ const auto payload = call->getArg(3);
+
+ const auto payloadPtrType = composeGetters<IRPtrTypeBase>(
+ payload,
+ &IRInst::getDataType
+ );
+ SLANG_ASSERT(payloadPtrType);
+ const auto payloadType = payloadPtrType->getValueType();
+ SLANG_ASSERT(payloadType);
+
+ const bool isGroupsharedGlobal =
+ payload->getParent() == module->getModuleInst() &&
+ composeGetters<IRGroupSharedRate>(payload, &IRInst::getRate);
+ if(isGroupsharedGlobal)
+ {
+ // If it's a groupshared global, then we put it in the address
+ // space we know to emit as taskPayloadSharedEXT instead (or
+ // naturally fall through correctly for SPIR-V emit)
+ //
+ // Keep it as a groupshared rate qualified type so we don't
+ // miss out on any further legalization requirement or
+ // optimization opportunities.
+ const auto payloadSharedPtrType =
+ builder.getRateQualifiedType(
+ builder.getGroupSharedRate(),
+ builder.getPtrType(
+ payloadPtrType->getOp(),
+ payloadPtrType->getValueType(),
+ SpvStorageClassTaskPayloadWorkgroupEXT
+ )
+ );
+ payload->setFullType(payloadSharedPtrType);
+ }
+ else
+ {
+ // ...
+ // If it's not a groupshared global, then create such a
+ // parameter and store into the value being passed to this
+ // call.
+ builder.setInsertInto(module->getModuleInst());
+ const auto v = builder.emitVar(payloadType, SpvStorageClassTaskPayloadWorkgroupEXT);
+ v->setFullType(builder.getRateQualifiedType(builder.getGroupSharedRate(), v->getFullType()));
+ builder.setInsertBefore(call);
+ builder.emitStore(v, payload);
+ }
+ }
+ });
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h
index 1816df1f2..6fd0b642e 100644
--- a/source/slang/slang-ir-glsl-legalize.h
+++ b/source/slang/slang-ir-glsl-legalize.h
@@ -25,4 +25,6 @@ void legalizeImageSubscriptForGLSL(IRModule* module);
void legalizeConstantBufferLoadForGLSL(IRModule* module);
+void legalizeDispatchMeshPayloadForGLSL(IRModule* module);
+
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f48801162..552a8af6d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -813,6 +813,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(IndicesDecoration, indices, 1, 0)
INST(PrimitivesDecoration, primitives, 1, 0)
INST_RANGE(MeshOutputDecoration, VerticesDecoration, PrimitivesDecoration)
+ INST(HLSLMeshPayloadDecoration, payload, 0, 0)
INST(GLSLPrimitivesRateDecoration, perprimitive, 0, 0)
// Marks an inst that represents the gl_Position output.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index bfcca5b02..d8788aa06 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -360,6 +360,7 @@ IR_SIMPLE_DECORATION(NoInlineDecoration)
IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration)
IR_SIMPLE_DECORATION(StaticRequirementDecoration)
IR_SIMPLE_DECORATION(NonCopyableTypeDecoration)
+IR_SIMPLE_DECORATION(HLSLMeshPayloadDecoration)
struct IRNVAPIMagicDecoration : IRDecoration
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 4266b46f9..106e5b5a3 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1355,7 +1355,10 @@ static void addLinkageDecoration(
}
else if (as<KnownBuiltinAttribute>(modifier))
{
- builder->addKnownBuiltinDecoration(inst, decl->getName()->text.getUnownedSlice());
+ // We add this to the internal instruction, like other name-like
+ // decorations, for instance "nameHint". This prevents it becoming
+ // lost during specialization.
+ builder->addKnownBuiltinDecoration(inInst, decl->getName()->text.getUnownedSlice());
}
}
if (as<InterfaceDecl>(decl->parentDecl) &&
@@ -2110,6 +2113,10 @@ void addVarDecorations(
{
builder->addFormatDecoration(inst, formatAttr->format);
}
+ else if(auto payloadMod = as<HLSLPayloadModifier>(mod))
+ {
+ builder->addSimpleDecoration<IRHLSLMeshPayloadDecoration>(inst);
+ }
// TODO: what are other modifiers we need to propagate through?
}
@@ -2977,6 +2984,12 @@ void _lowerFuncDeclBaseTypeInfo(
irParamType = builder->getRateQualifiedType(builder->getGroupSharedRate(), irParamType);
}
+ // The 'payload' parameter is a read-only groupshared value
+ if(paramInfo.decl && paramInfo.decl->hasModifier<HLSLPayloadModifier>())
+ {
+ irParamType = builder->getRateQualifiedType(builder->getGroupSharedRate(), irParamType);
+ }
+
paramTypes.add(irParamType);
}
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index e184585e3..3f5267577 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -7302,6 +7302,7 @@ namespace Slang
_makeParseModifier("vertices", HLSLVerticesModifier::kReflectClassInfo),
_makeParseModifier("indices", HLSLIndicesModifier::kReflectClassInfo),
_makeParseModifier("primitives", HLSLPrimitivesModifier::kReflectClassInfo),
+ _makeParseModifier("payload", HLSLPayloadModifier::kReflectClassInfo),
// Modifiers for unary operator declarations
_makeParseModifier("__prefix", PrefixModifier::kReflectClassInfo),