From d40931cc8bde13520ea45769cf94e7cc6cc9065f Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Fri, 15 Mar 2024 08:48:41 +0800 Subject: Mesh shader refactoring and bugfixes (#3702) --- source/slang/core.meta.slang | 25 ++-- source/slang/slang-check-decl.cpp | 4 +- source/slang/slang-diagnostic-defs.h | 2 +- source/slang/slang-emit-c-like.cpp | 15 +++ source/slang/slang-emit-hlsl.cpp | 7 +- source/slang/slang-ir-glsl-legalize.cpp | 209 +++++++++++++++++-------------- source/slang/slang-ir-inst-defs.h | 1 + source/slang/slang-ir-insts.h | 10 +- source/slang/slang-ir-spirv-legalize.cpp | 24 +++- source/slang/slang-parameter-binding.cpp | 7 ++ 10 files changed, 191 insertions(+), 113 deletions(-) (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 237daac56..e55206ca9 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1255,12 +1255,15 @@ struct ParameterBlock {} __generic __magic_type(VerticesType) __intrinsic_type($(kIROp_VerticesType)) -struct Vertices +[__NonCopyableType] +struct OutputVertices { __subscript(uint index) -> T { - // TODO: Ellie make sure these remains write only - __intrinsic_op($(kIROp_GetElementPtr)) + // TODO: Make sure this remains write only, we can't do this with just + // a 'set' operation as it's legal to only write to part of the output + // buffer, or part of the output buffer at a time. + __intrinsic_op($(kIROp_MeshOutputRef)) ref; } }; @@ -1268,24 +1271,28 @@ struct Vertices __generic __magic_type(IndicesType) __intrinsic_type($(kIROp_IndicesType)) -struct Indices +[__NonCopyableType] +struct OutputIndices { __subscript(uint index) -> T { - // TODO: Ellie: It's illegal to not write out the whole primitive at once, should we use set over ref? - __intrinsic_op($(kIROp_GetElementPtr)) - ref; + // It's illegal to not write out the entire primitive at once, so limit + // this to set + [mutating] + __intrinsic_op($(kIROp_MeshOutputSet)) + set; } }; __generic __magic_type(PrimitivesType) __intrinsic_type($(kIROp_PrimitivesType)) -struct Primitives +[__NonCopyableType] +struct OutputPrimitives { __subscript(uint index) -> T { - __intrinsic_op($(kIROp_GetElementPtr)) + __intrinsic_op($(kIROp_MeshOutputRef)) ref; } }; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 88a707ed8..093e2599f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7032,7 +7032,9 @@ namespace Slang { return; } - if(!varDecl->findModifier()) + // HLSL requires an 'out' modifier here, but since we don't operate + // under such strict compatability we can just not warn here. + if(!varDecl->findModifier() && modifier) { getSink()->diagnose(varDecl, Diagnostics::meshOutputMustBeOut); } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5b66a0f91..47c75ab0c 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -782,7 +782,7 @@ DIAGNOSTIC(52007, Error, typeCannotBeUsedInDynamicDispatch, "failed to generate DIAGNOSTIC(52008, Error, dynamicDispatchOnSpecializeOnlyInterface, "type '$0' is marked for specialization only, but dynamic dispatch is needed for the call.") DIAGNOSTIC(53001, Error, invalidTypeMarshallingForImportedDLLSymbol, "invalid type marshalling in imported func $0.") -DIAGNOSTIC(54001, Error, meshOutputMustBeOut, "Mesh shader outputs must be declared with 'out'.") +DIAGNOSTIC(54001, Warning, meshOutputMustBeOut, "Mesh shader outputs must be declared with 'out'.") DIAGNOSTIC(54002, Error, meshOutputMustBeArray, "HLSL style mesh shader outputs must be arrays") DIAGNOSTIC(54003, Error, meshOutputArrayMustHaveSize, "HLSL style mesh shader output arrays must have a length specified") DIAGNOSTIC(54004, Warning, unnecessaryHLSLMeshOutputModifier, "Unnecessary HLSL style mesh shader output modifier") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 44d74f219..2019bfa8c 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2409,6 +2409,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO break; } case kIROp_GetElement: + case kIROp_MeshOutputRef: case kIROp_GetElementPtr: case kIROp_ImageSubscript: // HACK: deal with translation of GLSL geometry shader input arrays. @@ -2904,6 +2905,20 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) m_writer->emit(";\n"); } break; + case kIROp_MeshOutputSet: + { + auto ii = (IRMeshOutputSet*)inst; + auto subscriptOuter = getInfo(EmitOp::General); + auto subscriptPrec = getInfo(EmitOp::Postfix); + emitOperand(ii->getBase(), leftSide(subscriptOuter, subscriptPrec)); + m_writer->emit("["); + emitOperand(ii->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("]"); + m_writer->emit(" = "); + emitOperand(ii->getElementValue(), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + } + break; } } diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index fcab737e1..cf1ca794b 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1206,10 +1206,11 @@ void HLSLSourceEmitter::emitMeshShaderModifiersImpl(IRInst* varInst) { if(auto modifier = varInst->findDecoration()) { + // DXC requires that mesh payload parameters have "out" specified const char* s = - as(modifier) ? "vertices " - : as(modifier) ? "indices " - : as(modifier) ? "primitives " + as(modifier) ? "out vertices " + : as(modifier) ? "out indices " + : as(modifier) ? "out primitives " : nullptr; SLANG_ASSERT(s && "Unhandled type of mesh output decoration"); m_writer->emit(s); diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index e40dc33ff..c9bc1339b 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -2106,106 +2106,118 @@ static void legalizeMeshOutputParam( // the writes may only be writing to parts of the output struct, or may not // be writes at all (i.e. being passed as an out paramter). // - traverseUsers(g, [&](IRInst* u) - { - auto l = as(u); - SLANG_ASSERT(l && "Mesh Output sentinel parameter wasn't used in a load"); - - std::function assignUses = - [&](ScalarizedVal& d, IRInst* a) + std::function assignUses = + [&](ScalarizedVal& d, IRInst* a) + { + // If we're just writing to an address, we can seamlessly + // replace it with the address to the SOA representation. + // GLSL's `out` function parameters have copy-out semantics, so + // this is all above board. + if(d.flavor == ScalarizedVal::Flavor::address) + { + IRBuilderInsertLocScope locScope{builder}; + builder->setInsertBefore(a); + a->replaceUsesWith(d.irValue); + a->removeAndDeallocate(); + return; + } + // Otherwise, go through the uses one by one and see what we can do + traverseUsers(a, [&](IRInst* s) { - // If we're just writing to an address, we can seamlessly - // replace it with the address to the SOA representation. - // GLSL's `out` function parameters have copy-out semantics, so - // this is all above board. - if(d.flavor == ScalarizedVal::Flavor::address) + IRBuilderInsertLocScope locScope{builder}; + builder->setInsertBefore(s); + if(auto m = as(s)) { - IRBuilderInsertLocScope locScope{builder}; - builder->setInsertBefore(a); - a->replaceUsesWith(d.irValue); - a->removeAndDeallocate(); - return; + auto key = as(m->getField()); + SLANG_ASSERT(key && "Result of getField wasn't a struct key"); + + auto d_ = extractField(builder, d, kMaxUInt, key); + assignUses(d_, m); } - // Otherwise, go through the uses one by one and see what we can do - traverseUsers(a, [&](IRInst* s) + else if(auto ref = as(s)) { - IRBuilderInsertLocScope locScope{builder}; - builder->setInsertBefore(s); - if(auto m = as(s)) - { - auto key = as(m->getField()); - SLANG_ASSERT(key && "Result of getField wasn't a struct key"); - - auto d_ = extractField(builder, d, kMaxUInt, key); - assignUses(d_, m); - } - else if(auto g = as(s)) - { - // Writing to something like `struct Vertex{ Foo foo[10]; }` - // This case is also what's taken in the initial - // traversal, as every mesh output is an array. - auto elemType = composeGetters( - g, - &IRInst::getFullType, - &IRPtrTypeBase::getValueType); - auto d_ = getSubscriptVal(builder, elemType, d, g->getIndex()); - assignUses(d_, g); - } - else if(auto store = as(s)) - { - // Store using the SOA representation + auto elemType = composeGetters( + ref, + &IRInst::getFullType, + &IRPtrTypeBase::getValueType); + auto d_ = getSubscriptVal(builder, elemType, d, ref->getIndex()); + assignUses(d_, ref); + } + else if(auto set = as(s)) + { + auto elemType = composeGetters( + set, + &IRInst::getFullType, + &IRPtrTypeBase::getValueType); + auto d_ = getSubscriptVal(builder, elemType, d, set->getIndex()); + assign(builder, d_, ScalarizedVal::value(set->getElementValue())); + set->removeAndDeallocate(); + } + else if(auto g = as(s)) + { + // Writing to something like `struct Vertex{ Foo foo[10]; }` + // This case is also what's taken in the initial + // traversal, as every mesh output is an array. + auto elemType = composeGetters( + g, + &IRInst::getFullType, + &IRPtrTypeBase::getValueType); + auto d_ = getSubscriptVal(builder, elemType, d, g->getIndex()); + assignUses(d_, g); + } + else if(auto store = as(s)) + { + // Store using the SOA representation - assign( - builder, - d, - ScalarizedVal::value(store->getVal())); + assign( + builder, + d, + ScalarizedVal::value(store->getVal())); - // Stores aren't used, safe to remove here without checking - store->removeAndDeallocate(); - } - else if(auto c = as(s)) + // Stores aren't used, safe to remove here without checking + store->removeAndDeallocate(); + } + else if(auto c = as(s)) + { + // Translate + // foo(vertices[n]) + // to + // tmp + // foo(tmp) + // vertices[n] = tmp; + // + // This has copy-out semantics, which is really the + // best we can hope for without going and + // specializing foo. + auto ptr = as(a->getFullType()); + SLANG_ASSERT(ptr && "Mesh output parameter was passed by value"); + auto t = ptr->getValueType(); + auto tmp = builder->emitVar(t); + for(UInt i = 0; i < c->getOperandCount(); i++) { - // Translate - // foo(vertices[n]) - // to - // tmp - // foo(tmp) - // vertices[n] = tmp; - // - // This has copy-out semantics, which is really the - // best we can hope for without going and - // specializing foo. - auto ptr = as(a->getFullType()); - SLANG_ASSERT(ptr && "Mesh output parameter was passed by value"); - auto t = ptr->getValueType(); - auto tmp = builder->emitVar(t); - for(UInt i = 0; i < c->getOperandCount(); i++) + if(c->getOperand(i) == a) { - if(c->getOperand(i) == a) - { - c->setOperand(i, tmp); - } + c->setOperand(i, tmp); } - builder->setInsertAfter(c); - assign(builder, d, - ScalarizedVal::value(builder->emitLoad(tmp))); - } - else if(const auto swiz = as(s)) - { - SLANG_UNEXPECTED("Swizzled store to a non-address ScalarizedVal"); } - else - { - SLANG_UNEXPECTED("Unhandled use of mesh output parameter during GLSL legalization"); - } - }); - - SLANG_ASSERT(!a->hasUses()); - a->removeAndDeallocate(); - }; + builder->setInsertAfter(c); + assign(builder, d, + ScalarizedVal::value(builder->emitLoad(tmp))); + } + else if(const auto swiz = as(s)) + { + SLANG_UNEXPECTED("Swizzled store to a non-address ScalarizedVal"); + } + else + { + SLANG_UNEXPECTED("Unhandled use of mesh output parameter during GLSL legalization"); + } + }); - assignUses(globalOutputVal, l); - }); + SLANG_ASSERT(!a->hasUses()); + a->removeAndDeallocate(); + }; + assignUses(globalOutputVal, g); // // GLSL requires that builtins are written to a block named @@ -2348,15 +2360,18 @@ static void legalizeMeshOutputParam( { traverseUsers(builtin.param, [&](IRInst* u) { - auto p = as(u); - SLANG_ASSERT(p && "Mesh Output sentinel parameter wasn't used as an array"); - IRBuilderInsertLocScope locScope{builder}; - builder->setInsertBefore(p); - auto e = builder->emitElementAddress(builder->getPtrType(meshOutputBlockType), blockParam, p->getIndex()); + builder->setInsertBefore(u); + IRInst* index; + if(const auto p = as(u)) + index = p->getIndex(); + else if(const auto m = as(u)) + index = m->getIndex(); + else + SLANG_UNEXPECTED("Illegal use of mesh output parameter"); + auto e = builder->emitElementAddress(builder->getPtrType(meshOutputBlockType), blockParam, index); auto a = builder->emitFieldAddress(builder->getPtrType(builtin.type), e, builtin.key); - - p->replaceUsesWith(a); + u->replaceUsesWith(a); }); } } @@ -2461,7 +2476,7 @@ void legalizeEntryPointParameterForGLSL( // - Geometry shader output streams // - Mesh shader outputs // - Mesh shader payload input - if (auto paramPtrType = as(paramType)) + if (auto paramPtrType = as(paramType)) { valueType = paramPtrType->getValueType(); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 977fca904..aa0929a57 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -478,6 +478,7 @@ INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0) INST(GetNaturalStride, getNaturalStride, 1, 0) INST(MeshOutputRef, meshOutputRef, 2, 0) +INST(MeshOutputSet, meshOutputSet, 3, 0) // Construct a vector from a scalar // diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a85b279b8..3ae9f04d7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1379,12 +1379,20 @@ struct IRGLPositionInputDecoration : public IRDecoration struct IRMeshOutputRef : public IRInst { - enum { kOp = kIROp_MeshOutputRef }; IR_LEAF_ISA(MeshOutputRef) + IRInst* getBase() { return getOperand(0); } IRInst* getIndex() { return getOperand(1); } IRInst* getOutputType() { return cast(getFullType())->getValueType(); } }; +struct IRMeshOutputSet : public IRInst +{ + IR_LEAF_ISA(MeshOutputSet) + IRInst* getBase() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } + IRInst* getElementValue() { return getOperand(2); } +}; + /// An attribute that can be attached to another instruction as an operand. /// /// Attributes serve a similar role to decorations, in that both are ways diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 6dbbbcd1f..4dde2a035 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1064,6 +1064,22 @@ struct SPIRVLegalizationContext : public SourceEmitterBase processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex()); } + void processMeshOutputGetElementPtr(IRMeshOutputRef* gepInst) + { + processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex()); + } + + void processMeshOutputSet(IRMeshOutputSet* setInst) + { + IRBuilder builder(m_sharedContext->m_irModule); + builder.setInsertBefore(setInst); + const auto p = builder.emitElementAddress(setInst->getBase(), setInst->getIndex()); + const auto s = builder.emitStore(p, setInst->getElementValue()); + setInst->removeAndDeallocate(); + addToWorkList(p); + addToWorkList(s); + } + void processGetOffsetPtr(IRInst* offsetPtrInst) { auto ptrOperandType = as(offsetPtrInst->getOperand(0)->getDataType()); @@ -1752,7 +1768,13 @@ struct SPIRVLegalizationContext : public SourceEmitterBase processImageSubscript(as(inst)); break; case kIROp_RWStructuredBufferGetElementPtr: - processRWStructuredBufferGetElementPtr(as(inst)); + processRWStructuredBufferGetElementPtr(cast(inst)); + break; + case kIROp_MeshOutputRef: + processMeshOutputGetElementPtr(cast(inst)); + break; + case kIROp_MeshOutputSet: + processMeshOutputSet(cast(inst)); break; case kIROp_RWStructuredBufferLoad: case kIROp_StructuredBufferLoad: diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index ebaa58adb..5d7aaa651 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2278,6 +2278,13 @@ static RefPtr computeEntryPointParameterTypeLayout( state.directionMask |= kEntryPointParameterDirection_Output; } + // For the purposes of type layout, mesh shader outputs are always + // treated as output only, despite missing an 'out' modifier + if(as(paramDeclRef.getDecl()->getType())) + { + state.directionMask = kEntryPointParameterDirection_Output; + } + return processEntryPointVaryingParameterDecl( context, paramDeclRef.getDecl(), -- cgit v1.2.3