diff options
| -rw-r--r-- | slang.h | 4 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 100 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 4 | ||||
| -rw-r--r-- | tests/metal/simple-task.slang | 102 |
8 files changed, 246 insertions, 1 deletions
@@ -2281,6 +2281,9 @@ extern "C" // Metal [[attribute]] inputs. SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE, + // Metal [[payload]] inputs + SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD, + // SLANG_PARAMETER_CATEGORY_COUNT, @@ -2855,6 +2858,7 @@ namespace slang MetalTexture = SLANG_PARAMETER_CATEGORY_METAL_TEXTURE, MetalArgumentBufferElement = SLANG_PARAMETER_CATEGORY_METAL_ARGUMENT_BUFFER_ELEMENT, MetalAttribute = SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE, + MetalPayload = SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD, // DEPRECATED: VertexInput = SLANG_PARAMETER_CATEGORY_VERTEX_INPUT, diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 13b1f0131..72c8515b3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -15737,6 +15737,8 @@ void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint thread // This intrinsic doesn't take into account writing meshPayload. That // is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'. __intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)"; + case metal: + __intrinsic_asm "_slang_mesh_payload = *$3; _slang_mgp.set_threadgroups_per_grid(uint3($0, $1, $2)); return;"; case spirv: return spirv_asm { diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index c6ffee953..96843e286 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -165,6 +165,9 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param) case LayoutResourceKind::VaryingInput: m_writer->emit(" [[stage_in]]"); break; + case LayoutResourceKind::MetalPayload: + m_writer->emit(" [[payload]]"); + break; } } if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr()) @@ -191,6 +194,12 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi case Stage::Compute: m_writer->emit("[[kernel]] "); break; + case Stage::Mesh: + m_writer->emit("[[mesh]] "); + break; + case Stage::Amplification: + m_writer->emit("[[object]] "); + break; default: SLANG_ABORT_COMPILATION("unsupported stage."); } @@ -608,18 +617,26 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) { case AddressSpace::Global: m_writer->emit(" device"); + m_writer->emit("*"); break; case AddressSpace::Uniform: m_writer->emit(" constant"); + m_writer->emit("*"); break; case AddressSpace::ThreadLocal: m_writer->emit(" thread"); + m_writer->emit("*"); break; case AddressSpace::GroupShared: m_writer->emit(" threadgroup"); + m_writer->emit("*"); + break; + case AddressSpace::MetalObjectData: + m_writer->emit(" object_data"); + // object data is passed by reference + m_writer->emit("&"); break; } - m_writer->emit("*"); return; } case kIROp_ArrayType: @@ -631,6 +648,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit(">"); return; } + case kIROp_MetalMeshGridPropertiesType: + { + m_writer->emit("mesh_grid_properties "); + return; + } default: break; } @@ -939,6 +961,9 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI case AddressSpace::ThreadLocal: m_writer->emit("thread "); break; + case AddressSpace::MetalObjectData: + m_writer->emit("object_data "); + break; default: break; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 9ab8e9a02..4d39eb978 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -192,6 +192,9 @@ INST(Nop, nop, 0, 0) INST(PrimitivesType, Primitives, 2, HOISTABLE) INST_RANGE(MeshOutputType, VerticesType, PrimitivesType) + /* Metal Mesh Grid Properties */ + INST(MetalMeshGridPropertiesType, mesh_grid_properties, 0, HOISTABLE) + /* HLSLStructuredBufferTypeBase */ INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE) INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 93f3e1227..5670cad47 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3591,6 +3591,11 @@ public: return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer()); } + IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType() + { + return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType); + } + IRInst* emitDebugSource(UnownedStringSlice fileName, UnownedStringSlice source); IRInst* emitDebugLine(IRInst* source, IRIntegerValue lineStart, IRIntegerValue lineEnd, IRIntegerValue colStart, IRIntegerValue colEnd); IRInst* emitDebugVar(IRType* type, IRInst* source, IRInst* line, IRInst* col, IRInst* argIndex = nullptr); diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index ae91fd069..d4a234515 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -1,5 +1,6 @@ #include "slang-ir-metal-legalize.h" +#include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir-clone.h" @@ -138,6 +139,8 @@ namespace Slang continue; if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput)) continue; + if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + continue; paramsToPack.add(param); } @@ -306,6 +309,98 @@ namespace Slang fixUpFuncType(func, structType); } + void legalizeMeshEntryPoint(EntryPointInfo entryPoint) + { + auto func = entryPoint.entryPointFunc; + + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh) + { + return; + } + + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + for (auto param : func->getParams()) + { + if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration)) + { + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(param, paramVarLayout); + } + } + + } + + void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint) + { + if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification) + { + return; + } + // Find out DispatchMesh function + IRGlobalValueWithCode* dispatchMeshFunc = nullptr; + for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts()) + { + if (const auto func = as<IRGlobalValueWithCode>(globalInst)) + { + if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>()) + { + if (dec->getName() == "DispatchMesh") + { + SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found"); + dispatchMeshFunc = func; + } + } + } + } + + if (!dispatchMeshFunc) + return; + + IRBuilder builder{ entryPoint.entryPointFunc->getModule() }; + + // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid + traverseUses(dispatchMeshFunc, [&](const IRUse* use) { + if (const auto call = as<IRCall>(use->getUser())) + { + SLANG_ASSERT(call->getArgCount() == 4); + const auto payload = call->getArg(3); + + const auto payloadPtrType = composeGetters<IRPtrType>( + payload, + &IRInst::getDataType + ); + SLANG_ASSERT(payloadPtrType); + const auto payloadType = payloadPtrType->getValueType(); + SLANG_ASSERT(payloadType); + + builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + const auto annotatedPayloadType = + builder.getPtrType( + kIROp_RefType, + payloadPtrType->getValueType(), + AddressSpace::MetalObjectData + ); + auto packedParam = builder.emitParam(annotatedPayloadType); + builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload")); + IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build()); + + // Add the MetalPayload resource info, so we can emit [[payload]] + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload); + auto paramVarLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(packedParam, paramVarLayout); + + // Now we replace the call to DispatchMesh with a call to the mesh grid properties + // But first we need to create the parameter + const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType(); + auto mgp = builder.emitParam(meshGridPropertiesType); + builder.addExternCppDecoration(mgp, toSlice("_slang_mgp")); + } + }); + } + void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink) { SLANG_UNUSED(sink); @@ -313,8 +408,11 @@ namespace Slang hoistEntryPointParameterFromStruct(entryPoint); packStageInParameters(entryPoint); wrapReturnValueInStruct(entryPoint); + legalizeMeshEntryPoint(entryPoint); + legalizeDispatchMeshPayloadForMetal(entryPoint); } + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -337,4 +435,6 @@ namespace Slang specializeAddressSpace(module); } + } + diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9e8be7fc5..7c04729b0 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -46,6 +46,8 @@ enum class AddressSpace Global = 2, GroupShared = 3, Uniform = 4, + // specific address space for payload data in metal + MetalObjectData = 5, }; typedef unsigned int IROpFlags; @@ -1549,6 +1551,8 @@ SIMPLE_IR_TYPE(VerticesType, MeshOutputType) SIMPLE_IR_TYPE(IndicesType, MeshOutputType) SIMPLE_IR_TYPE(PrimitivesType, MeshOutputType) +SIMPLE_IR_TYPE(MetalMeshGridPropertiesType, Type) + SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type) SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType) SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType) diff --git a/tests/metal/simple-task.slang b/tests/metal/simple-task.slang new file mode 100644 index 000000000..fa38f6043 --- /dev/null +++ b/tests/metal/simple-task.slang @@ -0,0 +1,102 @@ +//TEST:SIMPLE(filecheck=CHECK): -entry taskMain -stage amplification -target metal + +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer + +uniform RWStructuredBuffer<float> outputBuffer; + +cbuffer Uniforms +{ + float4x4 modelViewProjection; +} + +// +// Task shader +// + +struct MeshPayload +{ + int exponent; +}; + +// CHECK: MeshPayload_0 object_data& _slang_mesh_payload +// CHECK: mesh_grid_properties _slang_mgp +[numthreads(1,1,1)] +void taskMain() +{ + // CHECK: _slang_mesh_payload + // CHECK: _slang_mgp.set_threadgroups_per_grid + MeshPayload p; + p.exponent = 3; + DispatchMesh(1, 1, 1, p); +} + +// +// Mesh shader +// + +const static float2 positions[3] = { + float2(0.0, -0.5), + float2(0.5, 0.5), + float2(-0.5, 0.5) +}; + +const static float3 colors[3] = { + float3(1.0, 1.0, 0.0), + float3(0.0, 1.0, 1.0), + float3(1.0, 0.0, 1.0) +}; + +struct Vertex +{ + float4 pos : SV_Position; + float3 color : Color; + int index : Index; + int value : Value; +}; + +const static uint MAX_VERTS = 12; +const static uint MAX_PRIMS = 4; + +[outputtopology("triangle")] +[numthreads(12, 1, 1)] +void meshMain( + in uint tig: SV_GroupIndex, + in payload MeshPayload meshPayload, + // Check that we correctly generate the specific 'in payload' that HLSL + // requires: + // HLSL: , in payload MeshPayload + OutputVertices<Vertex, MAX_VERTS> verts, + OutputIndices<uint3, MAX_PRIMS> triangles) +{ + const uint numVertices = 12; + const uint numPrimitives = 4; + SetMeshOutputCounts(numVertices, numPrimitives); + + if (tig < numVertices) + { + const int tri = tig / 3; + verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) }; + } + + if (tig < numPrimitives) + triangles[tig] = tig * 3 + uint3(0, 1, 2); +} + +// +// Fragment Shader +// + +struct Fragment +{ + float4 color : SV_Target; +}; + +Fragment fragmentMain(Vertex input) +{ + outputBuffer[input.index] = input.value; + + Fragment output; + output.color = float4(input.color, 1.0); + return output; +} + |
