diff options
| author | Dynamitos <dynamitos15@gmail.com> | 2024-06-02 22:18:37 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-06-02 13:18:37 -0700 |
| commit | 753a524be885cf463fa6e60734aa739fcce1396f (patch) | |
| tree | d20e2bd2baf12a621f314727aabb395f5a8b5d5e /source/slang | |
| parent | 0bc89bc13251fedc9ed90cf473d2e6eb7fda3abf (diff) | |
Metal Task Shader payload (#4238)
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 100 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 4 |
6 files changed, 140 insertions, 1 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 13b1f0131..72c8515b3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -15737,6 +15737,8 @@ void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint thread // This intrinsic doesn't take into account writing meshPayload. That // is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'. __intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)"; + case metal: + __intrinsic_asm "_slang_mesh_payload = *$3; _slang_mgp.set_threadgroups_per_grid(uint3($0, $1, $2)); return;"; case spirv: return spirv_asm { diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index c6ffee953..96843e286 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -165,6 +165,9 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param) case LayoutResourceKind::VaryingInput: m_writer->emit(" [[stage_in]]"); break; + case LayoutResourceKind::MetalPayload: + m_writer->emit(" [[payload]]"); + break; } } if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr()) @@ -191,6 +194,12 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi case Stage::Compute: m_writer->emit("[[kernel]] "); break; + case Stage::Mesh: + m_writer->emit("[[mesh]] "); + break; + case Stage::Amplification: + m_writer->emit("[[object]] "); + break; default: SLANG_ABORT_COMPILATION("unsupported stage."); } @@ -608,18 +617,26 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) { case AddressSpace::Global: m_writer->emit(" device"); + m_writer->emit("*"); break; case AddressSpace::Uniform: m_writer->emit(" constant"); + m_writer->emit("*"); break; case AddressSpace::ThreadLocal: m_writer->emit(" thread"); + m_writer->emit("*"); break; case AddressSpace::GroupShared: m_writer->emit(" threadgroup"); + m_writer->emit("*"); + break; + case AddressSpace::MetalObjectData: + m_writer->emit(" object_data"); + // object data is passed by reference + m_writer->emit("&"); break; } - m_writer->emit("*"); return; } case kIROp_ArrayType: @@ -631,6 +648,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit(">"); return; } + case kIROp_MetalMeshGridPropertiesType: + { + m_writer->emit("mesh_grid_properties "); + return; + } default: break; } @@ -939,6 +961,9 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI case AddressSpace::ThreadLocal: m_writer->emit("thread "); break; + case AddressSpace::MetalObjectData: + m_writer->emit("object_data "); + break; default: break; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 9ab8e9a02..4d39eb978 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -192,6 +192,9 @@ INST(Nop, nop, 0, 0) INST(PrimitivesType, Primitives, 2, HOISTABLE) INST_RANGE(MeshOutputType, VerticesType, PrimitivesType) + /* Metal Mesh Grid Properties */ + INST(MetalMeshGridPropertiesType, mesh_grid_properties, 0, HOISTABLE) + /* HLSLStructuredBufferTypeBase */ INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE) INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 93f3e1227..5670cad47 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3591,6 +3591,11 @@ public: return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer()); } + IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType() + { + return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType); + } + IRInst* emitDebugSource(UnownedStringSlice fileName, UnownedStringSlice source); IRInst* emitDebugLine(IRInst* source, IRIntegerValue lineStart, IRIntegerValue lineEnd, IRIntegerValue colStart, IRIntegerValue colEnd); IRInst* emitDebugVar(IRType* type, IRInst* source, IRInst* line, IRInst* col, IRInst* argIndex = nullptr); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index ae91fd069..d4a234515 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1,5 +1,6 @@ #include "slang-ir-metal-legalize.h" +#include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir-clone.h" @@ -138,6 +139,8 @@ namespace Slang continue; if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) continue; + if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; paramsToPack.add(param); } @@ -306,6 +309,98 @@ namespace Slang fixUpFuncType(func, structType); } + void legalizeMeshEntryPoint(EntryPointInfo entryPoint) + { + auto func = entryPoint.entryPointFunc; + + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh) + { + return; + } + + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + for (auto param : func->getParams()) + { + if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + { + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(param, paramVarLayout); + } + } + + } + + void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint) + { + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification) + { + return; + } + // Find out DispatchMesh function + IRGlobalValueWithCode* dispatchMeshFunc = nullptr; + for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) + { + if (const auto func = as<IRGlobalValueWithCode>(globalInst)) + { + if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>()) + { + if (dec->getName() == "DispatchMesh") + { + SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); + dispatchMeshFunc = func; + } + } + } + } + + if (!dispatchMeshFunc) + return; + + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + + // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid + traverseUses(dispatchMeshFunc, [&](const IRUse* use) { + if (const auto call = as<IRCall>(use->getUser())) + { + SLANG_ASSERT(call->getArgCount() == 4); + const auto payload = call->getArg(3); + + const auto payloadPtrType = composeGetters<IRPtrType>( + payload, + &IRInst::getDataType + ); + SLANG_ASSERT(payloadPtrType); + const auto payloadType = payloadPtrType->getValueType(); + SLANG_ASSERT(payloadType); + + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + const auto annotatedPayloadType = + builder.getPtrType( + kIROp_RefType, + payloadPtrType->getValueType(), + AddressSpace::MetalObjectData + ); + auto packedParam = builder.emitParam(annotatedPayloadType); + builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + // Add the MetalPayload resource info, so we can emit [[payload]] + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Now we replace the call to DispatchMesh with a call to the mesh grid properties + // But first we need to create the parameter + const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); + auto mgp = builder.emitParam(meshGridPropertiesType); + builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); + } + }); + } + void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -313,8 +408,11 @@ namespace Slang hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); wrapReturnValueInStruct(entryPoint); + legalizeMeshEntryPoint(entryPoint); + legalizeDispatchMeshPayloadForMetal(entryPoint); } + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -337,4 +435,6 @@ namespace Slang specializeAddressSpace(module); } + } + diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9e8be7fc5..7c04729b0 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -46,6 +46,8 @@ enum class AddressSpace Global = 2, GroupShared = 3, Uniform = 4, + // specific address space for payload data in metal + MetalObjectData = 5, }; typedef unsigned int IROpFlags; @@ -1549,6 +1551,8 @@ SIMPLE_IR_TYPE(VerticesType, MeshOutputType) SIMPLE_IR_TYPE(IndicesType, MeshOutputType) SIMPLE_IR_TYPE(PrimitivesType, MeshOutputType) +SIMPLE_IR_TYPE(MetalMeshGridPropertiesType, Type) + SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type) SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType) SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType) |
