summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h4
-rw-r--r--source/slang/hlsl.meta.slang2
-rw-r--r--source/slang/slang-emit-metal.cpp27
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h5
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp100
-rw-r--r--source/slang/slang-ir.h4
-rw-r--r--tests/metal/simple-task.slang102
8 files changed, 246 insertions, 1 deletions
diff --git a/slang.h b/slang.h
index 337ce9bef..a758df2be 100644
--- a/slang.h
+++ b/slang.h
@@ -2281,6 +2281,9 @@ extern "C"
// Metal [[attribute]] inputs.
SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE,
+ // Metal [[payload]] inputs
+ SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD,
+
//
SLANG_PARAMETER_CATEGORY_COUNT,
@@ -2855,6 +2858,7 @@ namespace slang
MetalTexture = SLANG_PARAMETER_CATEGORY_METAL_TEXTURE,
MetalArgumentBufferElement = SLANG_PARAMETER_CATEGORY_METAL_ARGUMENT_BUFFER_ELEMENT,
MetalAttribute = SLANG_PARAMETER_CATEGORY_METAL_ATTRIBUTE,
+ MetalPayload = SLANG_PARAMETER_CATEGORY_METAL_PAYLOAD,
// DEPRECATED:
VertexInput = SLANG_PARAMETER_CATEGORY_VERTEX_INPUT,
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 13b1f0131..72c8515b3 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -15737,6 +15737,8 @@ void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint thread
// This intrinsic doesn't take into account writing meshPayload. That
// is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'.
__intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)";
+ case metal:
+ __intrinsic_asm "_slang_mesh_payload = *$3; _slang_mgp.set_threadgroups_per_grid(uint3($0, $1, $2)); return;";
case spirv:
return spirv_asm
{
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index c6ffee953..96843e286 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -165,6 +165,9 @@ void MetalSourceEmitter::emitFuncParamLayoutImpl(IRInst* param)
case LayoutResourceKind::VaryingInput:
m_writer->emit(" [[stage_in]]");
break;
+ case LayoutResourceKind::MetalPayload:
+ m_writer->emit(" [[payload]]");
+ break;
}
}
if (auto sysSemanticAttr = layout->findSystemValueSemanticAttr())
@@ -191,6 +194,12 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi
case Stage::Compute:
m_writer->emit("[[kernel]] ");
break;
+ case Stage::Mesh:
+ m_writer->emit("[[mesh]] ");
+ break;
+ case Stage::Amplification:
+ m_writer->emit("[[object]] ");
+ break;
default:
SLANG_ABORT_COMPILATION("unsupported stage.");
}
@@ -608,18 +617,26 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
{
case AddressSpace::Global:
m_writer->emit(" device");
+ m_writer->emit("*");
break;
case AddressSpace::Uniform:
m_writer->emit(" constant");
+ m_writer->emit("*");
break;
case AddressSpace::ThreadLocal:
m_writer->emit(" thread");
+ m_writer->emit("*");
break;
case AddressSpace::GroupShared:
m_writer->emit(" threadgroup");
+ m_writer->emit("*");
+ break;
+ case AddressSpace::MetalObjectData:
+ m_writer->emit(" object_data");
+ // object data is passed by reference
+ m_writer->emit("&");
break;
}
- m_writer->emit("*");
return;
}
case kIROp_ArrayType:
@@ -631,6 +648,11 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit(">");
return;
}
+ case kIROp_MetalMeshGridPropertiesType:
+ {
+ m_writer->emit("mesh_grid_properties ");
+ return;
+ }
default:
break;
}
@@ -939,6 +961,9 @@ void MetalSourceEmitter::emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, IRI
case AddressSpace::ThreadLocal:
m_writer->emit("thread ");
break;
+ case AddressSpace::MetalObjectData:
+ m_writer->emit("object_data ");
+ break;
default:
break;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 9ab8e9a02..4d39eb978 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -192,6 +192,9 @@ INST(Nop, nop, 0, 0)
INST(PrimitivesType, Primitives, 2, HOISTABLE)
INST_RANGE(MeshOutputType, VerticesType, PrimitivesType)
+ /* Metal Mesh Grid Properties */
+ INST(MetalMeshGridPropertiesType, mesh_grid_properties, 0, HOISTABLE)
+
/* HLSLStructuredBufferTypeBase */
INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE)
INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 93f3e1227..5670cad47 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3591,6 +3591,11 @@ public:
return getAttributedType(baseType, attributes.getCount(), attributes.getBuffer());
}
+ IRMetalMeshGridPropertiesType* getMetalMeshGridPropertiesType()
+ {
+ return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType);
+ }
+
IRInst* emitDebugSource(UnownedStringSlice fileName, UnownedStringSlice source);
IRInst* emitDebugLine(IRInst* source, IRIntegerValue lineStart, IRIntegerValue lineEnd, IRIntegerValue colStart, IRIntegerValue colEnd);
IRInst* emitDebugVar(IRType* type, IRInst* source, IRInst* line, IRInst* col, IRInst* argIndex = nullptr);
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index ae91fd069..d4a234515 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1,5 +1,6 @@
#include "slang-ir-metal-legalize.h"
+#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
@@ -138,6 +139,8 @@ namespace Slang
continue;
if (!layout->findOffsetAttr(LayoutResourceKind::VaryingInput))
continue;
+ if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ continue;
paramsToPack.add(param);
}
@@ -306,6 +309,98 @@ namespace Slang
fixUpFuncType(func, structType);
}
+ void legalizeMeshEntryPoint(EntryPointInfo entryPoint)
+ {
+ auto func = entryPoint.entryPointFunc;
+
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
+ {
+ return;
+ }
+
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ for (auto param : func->getParams())
+ {
+ if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+ }
+ }
+
+ }
+
+ void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
+ {
+ if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
+ {
+ return;
+ }
+ // Find out DispatchMesh function
+ IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
+ for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
+ {
+ if (const auto func = as<IRGlobalValueWithCode>(globalInst))
+ {
+ if (const auto dec = func->findDecoration<IRKnownBuiltinDecoration>())
+ {
+ if (dec->getName() == "DispatchMesh")
+ {
+ SLANG_ASSERT(!dispatchMeshFunc && "Multiple DispatchMesh functions found");
+ dispatchMeshFunc = func;
+ }
+ }
+ }
+ }
+
+ if (!dispatchMeshFunc)
+ return;
+
+ IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+
+ // We'll rewrite the call to use mesh_grid_properties.set_threadgroups_per_grid
+ traverseUses(dispatchMeshFunc, [&](const IRUse* use) {
+ if (const auto call = as<IRCall>(use->getUser()))
+ {
+ SLANG_ASSERT(call->getArgCount() == 4);
+ const auto payload = call->getArg(3);
+
+ const auto payloadPtrType = composeGetters<IRPtrType>(
+ payload,
+ &IRInst::getDataType
+ );
+ SLANG_ASSERT(payloadPtrType);
+ const auto payloadType = payloadPtrType->getValueType();
+ SLANG_ASSERT(payloadType);
+
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+ const auto annotatedPayloadType =
+ builder.getPtrType(
+ kIROp_RefType,
+ payloadPtrType->getValueType(),
+ AddressSpace::MetalObjectData
+ );
+ auto packedParam = builder.emitParam(annotatedPayloadType);
+ builder.addExternCppDecoration(packedParam, toSlice("_slang_mesh_payload"));
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ // Add the MetalPayload resource info, so we can emit [[payload]]
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(packedParam, paramVarLayout);
+
+ // Now we replace the call to DispatchMesh with a call to the mesh grid properties
+ // But first we need to create the parameter
+ const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
+ auto mgp = builder.emitParam(meshGridPropertiesType);
+ builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
+ }
+ });
+ }
+
void legalizeEntryPointForMetal(EntryPointInfo entryPoint, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
@@ -313,8 +408,11 @@ namespace Slang
hoistEntryPointParameterFromStruct(entryPoint);
packStageInParameters(entryPoint);
wrapReturnValueInStruct(entryPoint);
+ legalizeMeshEntryPoint(entryPoint);
+ legalizeDispatchMeshPayloadForMetal(entryPoint);
}
+
void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
@@ -337,4 +435,6 @@ namespace Slang
specializeAddressSpace(module);
}
+
}
+
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 9e8be7fc5..7c04729b0 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -46,6 +46,8 @@ enum class AddressSpace
Global = 2,
GroupShared = 3,
Uniform = 4,
+ // specific address space for payload data in metal
+ MetalObjectData = 5,
};
typedef unsigned int IROpFlags;
@@ -1549,6 +1551,8 @@ SIMPLE_IR_TYPE(VerticesType, MeshOutputType)
SIMPLE_IR_TYPE(IndicesType, MeshOutputType)
SIMPLE_IR_TYPE(PrimitivesType, MeshOutputType)
+SIMPLE_IR_TYPE(MetalMeshGridPropertiesType, Type)
+
SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type)
SIMPLE_IR_PARENT_TYPE(ParameterGroupType, PointerLikeType)
SIMPLE_IR_PARENT_TYPE(UniformParameterGroupType, ParameterGroupType)
diff --git a/tests/metal/simple-task.slang b/tests/metal/simple-task.slang
new file mode 100644
index 000000000..fa38f6043
--- /dev/null
+++ b/tests/metal/simple-task.slang
@@ -0,0 +1,102 @@
+//TEST:SIMPLE(filecheck=CHECK): -entry taskMain -stage amplification -target metal
+
+//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
+
+uniform RWStructuredBuffer<float> outputBuffer;
+
+cbuffer Uniforms
+{
+ float4x4 modelViewProjection;
+}
+
+//
+// Task shader
+//
+
+struct MeshPayload
+{
+ int exponent;
+};
+
+// CHECK: MeshPayload_0 object_data& _slang_mesh_payload
+// CHECK: mesh_grid_properties _slang_mgp
+[numthreads(1,1,1)]
+void taskMain()
+{
+ // CHECK: _slang_mesh_payload
+ // CHECK: _slang_mgp.set_threadgroups_per_grid
+ MeshPayload p;
+ p.exponent = 3;
+ DispatchMesh(1, 1, 1, p);
+}
+
+//
+// Mesh shader
+//
+
+const static float2 positions[3] = {
+ float2(0.0, -0.5),
+ float2(0.5, 0.5),
+ float2(-0.5, 0.5)
+};
+
+const static float3 colors[3] = {
+ float3(1.0, 1.0, 0.0),
+ float3(0.0, 1.0, 1.0),
+ float3(1.0, 0.0, 1.0)
+};
+
+struct Vertex
+{
+ float4 pos : SV_Position;
+ float3 color : Color;
+ int index : Index;
+ int value : Value;
+};
+
+const static uint MAX_VERTS = 12;
+const static uint MAX_PRIMS = 4;
+
+[outputtopology("triangle")]
+[numthreads(12, 1, 1)]
+void meshMain(
+ in uint tig: SV_GroupIndex,
+ in payload MeshPayload meshPayload,
+ // Check that we correctly generate the specific 'in payload' that HLSL
+ // requires:
+ // HLSL: , in payload MeshPayload
+ OutputVertices<Vertex, MAX_VERTS> verts,
+ OutputIndices<uint3, MAX_PRIMS> triangles)
+{
+ const uint numVertices = 12;
+ const uint numPrimitives = 4;
+ SetMeshOutputCounts(numVertices, numPrimitives);
+
+ if (tig < numVertices)
+ {
+ const int tri = tig / 3;
+ verts[tig] = { float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent)) };
+ }
+
+ if (tig < numPrimitives)
+ triangles[tig] = tig * 3 + uint3(0, 1, 2);
+}
+
+//
+// Fragment Shader
+//
+
+struct Fragment
+{
+ float4 color : SV_Target;
+};
+
+Fragment fragmentMain(Vertex input)
+{
+ outputBuffer[input.index] = input.value;
+
+ Fragment output;
+ output.color = float4(input.color, 1.0);
+ return output;
+}
+