summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/core.meta.slang67
-rw-r--r--source/slang/hlsl.meta.slang7
-rw-r--r--source/slang/slang-emit-metal.cpp70
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp5
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h32
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp154
-rw-r--r--source/slang/slang-ir.cpp26
-rw-r--r--source/slang/slang-ir.h9
-rw-r--r--tests/metal/simple-task.slang14
10 files changed, 368 insertions, 24 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 629737d6c..919c21473 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1435,13 +1435,34 @@ __intrinsic_type($(kIROp_VerticesType))
[__NonCopyableType]
struct OutputVertices
{
+ __intrinsic_op($(kIROp_MetalSetVertex))
+ static void _metalSetVertex(uint index, T val);
+
+ __intrinsic_op($(kIROp_MeshOutputSet))
+ static void _setVertex(This v, uint index, T val);
+
__subscript(uint index) -> T
{
// TODO: Make sure this remains write only, we can't do this with just
// a 'set' operation as it's legal to only write to part of the output
// buffer, or part of the output buffer at a time.
+
+ [mutating]
+ [require(glsl_hlsl_metal_spirv, meshshading)]
+ set
+ {
+ __target_switch
+ {
+ case metal: _metalSetVertex(index, newValue);
+ case glsl: _setVertex(this, index, newValue);
+ case hlsl: _setVertex(this, index, newValue);
+ case spirv: _setVertex(this, index, newValue);
+ }
+ }
+
//
// If a 'OutputVertices[index]' is referred to by a '__ref', call 'kIROp_MeshOutputRef(index)'
+ [require(glsl_hlsl_spirv, meshshading)]
__intrinsic_op($(kIROp_MeshOutputRef))
ref;
}
@@ -1453,13 +1474,29 @@ __intrinsic_type($(kIROp_IndicesType))
[__NonCopyableType]
struct OutputIndices
{
+ __intrinsic_op($(kIROp_MetalSetIndices))
+ static void __metalSetIndices(uint index, T val);
+
+ // for some reason only here in the indices array it uses the return value as an actual
+ // operand, while the others use the value of the instruction (the return value) to access
+ // the type of the vertex, as when using the ref there is no third operand
+ __intrinsic_op($(kIROp_MeshOutputSet))
+ static void __setIndices(This v, uint index, T val);
+
__subscript(uint index) -> T
{
- // It's illegal to not write out the entire primitive at once, so limit
- // this to set
[mutating]
- __intrinsic_op($(kIROp_MeshOutputSet))
- set;
+ [require(glsl_hlsl_metal_spirv, meshshading)]
+ set
+ {
+ __target_switch
+ {
+ case metal: __metalSetIndices(index, newValue);
+ case glsl: __setIndices(this, index, newValue);
+ case hlsl: __setIndices(this, index, newValue);
+ case spirv: __setIndices(this, index, newValue);
+ }
+ }
}
};
@@ -1469,9 +1506,29 @@ __intrinsic_type($(kIROp_PrimitivesType))
[__NonCopyableType]
struct OutputPrimitives
{
- __subscript(uint index) -> T
+ __intrinsic_op($(kIROp_MetalSetPrimitive))
+ static void __metalSetPrimitive(uint index, T val);
+
+ __intrinsic_op($(kIROp_MeshOutputSet))
+ static void __setPrimitive(This v, uint index, T val);
+
+ __subscript(uint index)->T
{
+ [mutating]
+ [require(glsl_hlsl_metal_spirv, meshshading)]
+ set
+ {
+ __target_switch
+ {
+ case metal: __metalSetPrimitive(index, newValue);
+ case glsl: __setPrimitive(this, index, newValue);
+ case hlsl: __setPrimitive(this, index, newValue);
+ case spirv: __setPrimitive(this, index, newValue);
+ }
+ }
+
// If a 'OutputPrimitives[index]' is referred to by a '__ref', call 'kIROp_MeshOutputRef(index)'
+ [require(glsl_hlsl_spirv, meshshading)]
__intrinsic_op($(kIROp_MeshOutputRef))
ref;
}
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 37d3ad19b..21cb9ef18 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -16288,7 +16288,8 @@ float dot2add(float2 left, float2 right, float acc);
// Set the number of output vertices and primitives for a mesh shader invocation.
__glsl_extension(GL_EXT_mesh_shader)
__glsl_version(450)
-[require(glsl_hlsl_spirv, meshshading)]
+[require(glsl_hlsl_metal_spirv, meshshading)]
+[noRefInline]
void SetMeshOutputCounts(uint vertexCount, uint primitiveCount)
{
__target_switch
@@ -16297,6 +16298,8 @@ void SetMeshOutputCounts(uint vertexCount, uint primitiveCount)
__intrinsic_asm "SetMeshOutputCounts";
case glsl:
__intrinsic_asm "SetMeshOutputsEXT";
+ case metal:
+ __intrinsic_asm "_slang_mesh.set_primitive_count($1)";
case spirv:
return spirv_asm
{
@@ -16328,7 +16331,7 @@ void DispatchMesh<P>(uint threadGroupCountX, uint threadGroupCountY, uint thread
// is dealt with separately by 'legalizeDispatchMeshPayloadForGLSL'.
__intrinsic_asm "EmitMeshTasksEXT($0, $1, $2)";
case metal:
- __intrinsic_asm "_slang_mesh_payload = *$3; _slang_mgp.set_threadgroups_per_grid(uint3($0, $1, $2)); return;";
+ __intrinsic_asm "*_slang_mesh_payload = *$3; _slang_mgp.set_threadgroups_per_grid(uint3($0, $1, $2)); return;";
case spirv:
return spirv_asm
{
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index 0b282c4d9..61e310401 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -492,6 +492,44 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO
m_writer->emit(")");
return true;
}
+ case kIROp_MetalSetVertex:
+ {
+ auto setVertex = as<IRMetalSetVertex>(inst);
+ m_writer->emit("_slang_mesh.set_vertex(");
+ emitOperand(setVertex->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(setVertex->getElementValue(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_MetalSetPrimitive:
+ {
+ auto setPrimitive = as<IRMetalSetPrimitive>(inst);
+ m_writer->emit("_slang_mesh.set_primitive(");
+ emitOperand(setPrimitive->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(",");
+ emitOperand(setPrimitive->getElementValue(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_MetalSetIndices:
+ {
+ auto setIndices = as<IRMetalSetIndices>(inst);
+ const auto indices = as<IRVectorType>(setIndices->getElementValue()->getDataType());
+ UInt numIndices = as<IRIntLit>(indices->getElementCount())->getValue();
+ for(UInt i = 0; i < numIndices; ++i) {
+ m_writer->emit("_slang_mesh.set_index(");
+ emitOperand(setIndices->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit("*");
+ m_writer->emitUInt64(numIndices);
+ m_writer->emit(",(");
+ emitOperand(setIndices->getElementValue(), getInfo(EmitOp::General));
+ m_writer->emit(")[");
+ m_writer->emitUInt64(i);
+ m_writer->emit("]);\n");
+ }
+ return true;
+ }
default: break;
}
// Not handled
@@ -699,6 +737,10 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_ConstRefType:
{
auto ptrType = cast<IRPtrTypeBase>(type);
+ if(type->getOp() == kIROp_ConstRefType)
+ {
+ m_writer->emit("const ");
+ }
emitType((IRType*)ptrType->getValueType());
switch (ptrType->getAddressSpace())
{
@@ -720,8 +762,7 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
break;
case AddressSpace::MetalObjectData:
m_writer->emit(" object_data");
- // object data is passed by reference
- m_writer->emit("&");
+ m_writer->emit("*");
break;
}
return;
@@ -785,6 +826,31 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
return;
}
+ else if (const auto meshType = as<IRMetalMeshType>(type))
+ {
+ m_writer->emit("metal::mesh<");
+ emitType(meshType->getVerticesType());
+ m_writer->emit(", ");
+ emitType(meshType->getPrimitivesType());
+ m_writer->emit(", ");
+ emitOperand(meshType->getNumVertices(), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(meshType->getNumPrimitives(), getInfo(EmitOp::General));
+ m_writer->emit(", metal::topology::");
+ switch(meshType->getTopology()->getValue()) {
+ case 1:
+ m_writer->emit("point");
+ break;
+ case 2:
+ m_writer->emit("line");
+ break;
+ case 3:
+ m_writer->emit("triangle");
+ break;
+ }
+ m_writer->emit(">");
+ return;
+ }
else if(auto specializedType = as<IRSpecialize>(type))
{
// If a `specialize` instruction made it this far, then
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 0242076f5..dabf0294f 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -2372,9 +2372,8 @@ static void legalizeMeshOutputParam(
else if(auto set = as<IRMeshOutputSet>(s))
{
auto elemType = composeGetters<IRType>(
- set,
- &IRInst::getFullType,
- &IRPtrTypeBase::getValueType);
+ set->getElementValue(),
+ &IRInst::getFullType);
auto d_ = getSubscriptVal(builder, elemType, d, set->getIndex());
assign(builder, d_, ScalarizedVal::value(set->getElementValue()));
set->removeAndDeallocate();
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 179ed3065..9fc1ab22c 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -193,6 +193,8 @@ INST(Nop, nop, 0, 0)
INST(PrimitivesType, Primitives, 2, HOISTABLE)
INST_RANGE(MeshOutputType, VerticesType, PrimitivesType)
+ /* Metal Mesh Type */
+ INST(MetalMeshType, metal::mesh, 5, HOISTABLE)
/* Metal Mesh Grid Properties */
INST(MetalMeshGridPropertiesType, mesh_grid_properties, 0, HOISTABLE)
@@ -513,6 +515,12 @@ INST(GetNaturalStride, getNaturalStride, 1, 0)
INST(MeshOutputRef, meshOutputRef, 2, 0)
INST(MeshOutputSet, meshOutputSet, 3, 0)
+// only two parameters as they are effectively static
+// TODO: make them reference the _slang_mesh object directly
+INST(MetalSetVertex, metalSetVertex, 2, 0)
+INST(MetalSetPrimitive, metalSetPrimitive, 2, 0)
+INST(MetalSetIndices, metalSetIndices, 2, 0)
+
// Construct a vector from a scalar
//
// %dst = MakeVectorFromScalar %T %N %val
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 79362799b..db1571e50 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1477,6 +1477,27 @@ struct IRMeshOutputSet : public IRInst
IRInst* getElementValue() { return getOperand(2); }
};
+struct IRMetalSetVertex : public IRInst
+{
+ IR_LEAF_ISA(MetalSetVertex)
+ IRInst* getIndex() { return getOperand(0); }
+ IRInst* getElementValue() { return getOperand(1); }
+};
+
+struct IRMetalSetPrimitive : public IRInst
+{
+ IR_LEAF_ISA(MetalSetPrimitive)
+ IRInst* getIndex() { return getOperand(0); }
+ IRInst* getElementValue() { return getOperand(1); }
+};
+
+struct IRMetalSetIndices : public IRInst
+{
+ IR_LEAF_ISA(MetalSetIndices)
+ IRInst* getIndex() { return getOperand(0); }
+ IRInst* getElementValue() { return getOperand(1); }
+};
+
/// An attribute that can be attached to another instruction as an operand.
///
/// Attributes serve a similar role to decorations, in that both are ways
@@ -3689,6 +3710,12 @@ public:
return (IRMetalMeshGridPropertiesType*)getType(kIROp_MetalMeshGridPropertiesType);
}
+ IRMetalMeshType* getMetalMeshType(IRType* vertexType, IRType* primitiveType, IRInst* numVertices, IRInst* numPrimitives, IRInst* topology)
+ {
+ IRInst* ops[5] = {vertexType, primitiveType, numVertices, numPrimitives, topology};
+ return (IRMetalMeshType*)getType(kIROp_MetalMeshType, 5, ops);
+ }
+
IRInst* emitDebugSource(UnownedStringSlice fileName, UnownedStringSlice source);
IRInst* emitDebugLine(IRInst* source, IRIntegerValue lineStart, IRIntegerValue lineEnd, IRIntegerValue colStart, IRIntegerValue colEnd);
IRInst* emitDebugVar(IRType* type, IRInst* source, IRInst* line, IRInst* col, IRInst* argIndex = nullptr);
@@ -4441,6 +4468,11 @@ public:
IRInst* emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index);
IRInst* emitNonUniformResourceIndexInst(IRInst* val);
+
+ IRMetalSetVertex* emitMetalSetVertex(IRInst* index, IRInst* vertex);
+ IRMetalSetPrimitive* emitMetalSetPrimitive(IRInst* index, IRInst* primitive);
+ IRMetalSetIndices* emitMetalSetIndices(IRInst* index, IRInst* indices);
+
//
// Decorations
//
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index 435abc369..942a40716 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -1269,12 +1269,68 @@ namespace Slang
{
auto func = entryPoint.entryPointFunc;
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Mesh)
+ IRBuilder builder{ func->getModule() };
+ for (auto param : func->getParams())
+ {
+ if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
+ {
+ IRVarLayout::Builder varLayoutBuilder(&builder, IRTypeLayout::Builder{&builder}.build());
+
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::MetalPayload);
+ auto paramVarLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(param, paramVarLayout);
+
+ IRPtrTypeBase* type = as<IRPtrTypeBase>(param->getDataType());
+
+ const auto annotatedPayloadType =
+ builder.getPtrType(
+ kIROp_ConstRefType,
+ type->getValueType(),
+ AddressSpace::MetalObjectData
+ );
+
+ param->setFullType(annotatedPayloadType);
+ }
+ }
+ IROutputTopologyDecoration* outputDeco = entryPoint.entryPointFunc->findDecoration<IROutputTopologyDecoration>();
+ if(outputDeco == nullptr)
{
+ SLANG_UNEXPECTED("Mesh shader output decoration missing");
+ return;
+ }
+ const auto topology = outputDeco->getTopology();
+ const auto topStr = topology->getStringSlice();
+ UInt topologyEnum = 0;
+ if(topStr.caseInsensitiveEquals(toSlice("point")))
+ {
+ topologyEnum = 1;
+ }
+ else if(topStr.caseInsensitiveEquals(toSlice("line")))
+ {
+ topologyEnum = 2;
+ }
+ else if(topStr.caseInsensitiveEquals(toSlice("triangle")))
+ {
+ topologyEnum = 3;
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unknown topology");
return;
}
- IRBuilder builder{ entryPoint.entryPointFunc->getModule() };
+ IRInst* topologyConst = builder.getIntValue(builder.getIntType(), topologyEnum);
+
+ IRType* vertexType = nullptr;
+ IRType* indicesType = nullptr;
+ IRType* primitiveType = nullptr;
+
+ IRInst* maxVertices = nullptr;
+ IRInst* maxPrimitives = nullptr;
+
+ IRInst* verticesParam = nullptr;
+ IRInst* indicesParam = nullptr;
+ IRInst* primitivesParam = nullptr;
for (auto param : func->getParams())
{
if(param->findDecorationImpl(kIROp_HLSLMeshPayloadDecoration))
@@ -1285,16 +1341,85 @@ namespace Slang
auto paramVarLayout = varLayoutBuilder.build();
builder.addLayoutDecoration(param, paramVarLayout);
}
+ if(param->findDecorationImpl(kIROp_VerticesDecoration))
+ {
+ auto vertexRefType = as<IRPtrTypeBase>(param->getDataType());
+ auto vertexOutputType = as<IRVerticesType>(vertexRefType->getValueType());
+ vertexType = vertexOutputType->getElementType();
+ maxVertices = vertexOutputType->getMaxElementCount();
+ SLANG_ASSERT(vertexType);
+
+ verticesParam = param;
+ auto vertStruct = as<IRStructType>(vertexType);
+ for(auto field : vertStruct->getFields())
+ {
+ auto key = field->getKey();
+ if(auto deco = key->findDecoration<IRSemanticDecoration>())
+ {
+ if(deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_position")))
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("position"));
+ }
+ }
+ }
+ }
+ if(param->findDecorationImpl(kIROp_IndicesDecoration))
+ {
+ auto indicesRefType = (IRConstRefType*)param->getDataType();
+ auto indicesOutputType = (IRIndicesType*)indicesRefType->getValueType();
+ indicesType = indicesOutputType->getElementType();
+ maxPrimitives = indicesOutputType->getMaxElementCount();
+ SLANG_ASSERT(indicesType);
+
+ indicesParam = param;
+ }
+ if(param->findDecorationImpl(kIROp_PrimitivesDecoration))
+ {
+ auto primitivesRefType = (IRConstRefType*)param->getDataType();
+ auto primitivesOutputType = (IRPrimitivesType*)primitivesRefType->getValueType();
+ primitiveType = primitivesOutputType->getElementType();
+ SLANG_ASSERT(primitiveType);
+
+ primitivesParam = param;
+ auto primStruct = as<IRStructType>(primitiveType);
+ for(auto field : primStruct->getFields())
+ {
+ auto key = field->getKey();
+ if(auto deco = key->findDecoration<IRSemanticDecoration>())
+ {
+ if(deco->getSemanticName().caseInsensitiveEquals(toSlice("sv_primitiveid")))
+ {
+ builder.addTargetSystemValueDecoration(key, toSlice("primitive_id"));
+ }
+ }
+ }
+ }
}
+ if(primitiveType == nullptr)
+ {
+ primitiveType = builder.getVoidType();
+ }
+ builder.setInsertBefore(entryPoint.entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ auto meshParam = builder.emitParam(builder.getMetalMeshType(vertexType, primitiveType, maxVertices, maxPrimitives, topologyConst));
+ builder.addExternCppDecoration(meshParam, toSlice("_slang_mesh"));
+
+ verticesParam->removeFromParent();
+ verticesParam->removeAndDeallocate();
+
+ indicesParam->removeFromParent();
+ indicesParam->removeAndDeallocate();
+
+ if(primitivesParam != nullptr)
+ {
+ primitivesParam->removeFromParent();
+ primitivesParam->removeAndDeallocate();
+ }
}
void legalizeDispatchMeshPayloadForMetal(EntryPointInfo entryPoint)
{
- if (entryPoint.entryPointDecor->getProfile().getStage() != Stage::Amplification)
- {
- return;
- }
// Find out DispatchMesh function
IRGlobalValueWithCode* dispatchMeshFunc = nullptr;
for (const auto globalInst : entryPoint.entryPointFunc->getModule()->getGlobalInsts())
@@ -1353,8 +1478,8 @@ namespace Slang
const auto meshGridPropertiesType = builder.getMetalMeshGridPropertiesType();
auto mgp = builder.emitParam(meshGridPropertiesType);
builder.addExternCppDecoration(mgp, toSlice("_slang_mgp"));
- }
- });
+ }
+ });
}
IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType)
@@ -1601,8 +1726,17 @@ namespace Slang
wrapReturnValueInStruct(entryPoint);
//Other Legalize
- legalizeMeshEntryPoint(entryPoint);
- legalizeDispatchMeshPayloadForMetal(entryPoint);
+ switch(entryPoint.entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Amplification:
+ legalizeDispatchMeshPayloadForMetal(entryPoint);
+ break;
+ case Stage::Mesh:
+ legalizeMeshEntryPoint(entryPoint);
+ break;
+ default:
+ break;
+ }
}
};
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6a76ccce3..acc6abd57 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5380,6 +5380,32 @@ namespace Slang
return emitSwizzle(type, base, elementCount, irElementIndices);
}
+ IRMetalSetVertex* IRBuilder::emitMetalSetVertex(
+ IRInst* index,
+ IRInst* vertex)
+ {
+ auto inst = createInst<IRMetalSetVertex>(this, kIROp_MetalSetVertex, getVoidType(), index, vertex);
+ addInst(inst);
+ return inst;
+ }
+
+ IRMetalSetPrimitive* IRBuilder::emitMetalSetPrimitive(
+ IRInst* index,
+ IRInst* primitive)
+ {
+ auto inst = createInst<IRMetalSetPrimitive>(this, kIROp_MetalSetVertex, getVoidType(), index, primitive);
+ addInst(inst);
+ return inst;
+ }
+
+ IRMetalSetIndices* IRBuilder::emitMetalSetIndices(
+ IRInst* index,
+ IRInst* indices)
+ {
+ auto inst = createInst<IRMetalSetIndices>(this, kIROp_MetalSetVertex, getVoidType(), index, indices);
+ addInst(inst);
+ return inst;
+ }
IRInst* IRBuilder::emitSwizzleSet(
IRType* type,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 036c7f3a7..99c62b214 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1550,6 +1550,15 @@ SIMPLE_IR_TYPE(VerticesType, MeshOutputType)
SIMPLE_IR_TYPE(IndicesType, MeshOutputType)
SIMPLE_IR_TYPE(PrimitivesType, MeshOutputType)
+struct IRMetalMeshType : IRType
+{
+ IRType* getVerticesType() { return (IRType*)getOperand(0); }
+ IRType* getPrimitivesType() { return (IRType*)getOperand(1); }
+ IRInst* getNumVertices() { return (IRInst*)getOperand(2); }
+ IRInst* getNumPrimitives() { return (IRInst*)getOperand(3); }
+ IRIntLit* getTopology() { return (IRIntLit*)getOperand(4); }
+};
+
SIMPLE_IR_TYPE(MetalMeshGridPropertiesType, Type)
SIMPLE_IR_TYPE(GLSLInputAttachmentType, Type)
diff --git a/tests/metal/simple-task.slang b/tests/metal/simple-task.slang
index fa38f6043..4a12719d0 100644
--- a/tests/metal/simple-task.slang
+++ b/tests/metal/simple-task.slang
@@ -18,7 +18,7 @@ struct MeshPayload
int exponent;
};
-// CHECK: MeshPayload_0 object_data& _slang_mesh_payload
+// CHECK: MeshPayload_0 object_data* _slang_mesh_payload
// CHECK: mesh_grid_properties _slang_mgp
[numthreads(1,1,1)]
void taskMain()
@@ -54,6 +54,11 @@ struct Vertex
int value : Value;
};
+struct Primitive
+{
+ uint prim : SV_PrimitiveID;
+};
+
const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;
@@ -66,7 +71,9 @@ void meshMain(
// requires:
// HLSL: , in payload MeshPayload
OutputVertices<Vertex, MAX_VERTS> verts,
- OutputIndices<uint3, MAX_PRIMS> triangles)
+ OutputIndices<uint3, MAX_PRIMS> triangles,
+ OutputPrimitives<Primitive, MAX_PRIMS> primitives
+ )
{
const uint numVertices = 12;
const uint numPrimitives = 4;
@@ -79,7 +86,10 @@ void meshMain(
}
if (tig < numPrimitives)
+ {
triangles[tig] = tig * 3 + uint3(0, 1, 2);
+ primitives[tig] = { tig };
+ }
}
//