summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/core.meta.slang4
-rw-r--r--source/slang/hlsl.meta.slang28
-rw-r--r--source/slang/slang-emit-spirv-ops.h14
-rw-r--r--source/slang/slang-emit-spirv.cpp40
-rw-r--r--source/slang/slang-type-system-shared.h2
-rw-r--r--tests/workgraphs/consumer.slang32
6 files changed, 108 insertions, 12 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f0324ba1a..ba8decc12 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -3965,3 +3965,7 @@ attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute;
/// effect on other targets.
__attributeTarget(FuncDecl)
attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute;
+
+__generic<T>
+typealias NodePayloadPtr = Ptr<T, $( (uint64_t)AddressSpace::NodePayloadAMDX)>;
+
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index a27cafbd4..07bf2ffbd 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -21365,3 +21365,31 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue)
}
}
+// Work-graphs
+
+//@public:
+/// read-only input to Broadcasting launch node.
+__generic<T>
+//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader.
+//[require(broadcasting_node)]
+[require(spirv)]
+struct DispatchNodeInputRecord
+{
+ /// Provide an access to a record object that only holds a single record.
+ NodePayloadPtr<T> Get()
+ {
+ int index = 0;
+ __target_switch
+ {
+ case spirv:
+ return spirv_asm
+ {
+ %in_payload_t = OpTypeNodePayloadArrayAMDX $$T;
+ %in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t;
+ %var = OpVariable %in_payload_ptr_t NodePayloadAMDX;
+ result : $$NodePayloadPtr<T> = OpAccessChain %var $index;
+ };
+ }
+ }
+};
+
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index d385d54b2..1f2996646 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -2552,4 +2552,18 @@ SpvInst* emitOpAtomicIDecrement(
memory,
semantics);
}
+
+// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpTypeNodePayloadArrayAMDX
+template<typename T>
+SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type)
+{
+ static_assert(isSingular<T>);
+ return emitInstMemoized(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ inst,
+ SpvOpTypeNodePayloadArrayAMDX,
+ kResultID,
+ type);
+}
+
#endif // SLANG_IN_SPIRV_EMIT_CONTEXT
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index d8c479cd1..1407404ad 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1311,6 +1311,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
return SpvStorageClassImage;
case AddressSpace::UserPointer:
return SpvStorageClassPhysicalStorageBuffer;
+ case AddressSpace::NodePayloadAMDX:
+ return SpvStorageClassNodePayloadAMDX;
case AddressSpace::Global:
case AddressSpace::MetalObjectData:
case AddressSpace::SpecializationConstant:
@@ -1504,13 +1506,22 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_ASSERT(ptrType);
if (ptrType->hasAddressSpace())
storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
- if (storageClass == SpvStorageClassStorageBuffer)
+
+ switch (storageClass)
+ {
+ case SpvStorageClassStorageBuffer:
ensureExtensionDeclaration(
UnownedStringSlice("SPV_KHR_storage_buffer_storage_class"));
- if (storageClass == SpvStorageClassPhysicalStorageBuffer)
- {
+ break;
+ case SpvStorageClassPhysicalStorageBuffer:
requirePhysicalStorageAddressing();
+ break;
+ case SpvStorageClassNodePayloadAMDX:
+ requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX);
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_AMDX_shader_enqueue"));
+ break;
}
+
auto valueType = ptrType->getValueType();
// If we haven't emitted the inner type yet, we need to emit a forward declaration.
bool useForwardDeclaration =
@@ -1524,17 +1535,20 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
builder.setInsertBefore(valueType);
valueTypeId = getID(ensureInst(builder.getUIntType()));
}
+ else if (useForwardDeclaration)
+ {
+ valueTypeId = getIRInstSpvID(valueType);
+ }
+ else if (storageClass == SpvStorageClassNodePayloadAMDX)
+ {
+ auto spvValueType = ensureInst(valueType);
+ auto spvNodePayloadType = emitOpTypeNodePayloadArray(inst, spvValueType);
+ valueTypeId = getID(spvNodePayloadType);
+ }
else
{
- if (useForwardDeclaration)
- {
- valueTypeId = getIRInstSpvID(valueType);
- }
- else
- {
- auto spvValueType = ensureInst(valueType);
- valueTypeId = getID(spvValueType);
- }
+ auto spvValueType = ensureInst(valueType);
+ valueTypeId = getID(spvValueType);
}
auto resultSpvType = emitOpTypePointer(inst, storageClass, valueTypeId);
@@ -7564,6 +7578,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case SpvOpMemberDecorate:
case SpvOpMemberDecorateString:
return getSection(SpvLogicalSectionID::Annotations);
+ case SpvOpTypeNodePayloadArrayAMDX:
+ return getSection(SpvLogicalSectionID::ConstantsAndTypes);
default:
return defaultParent;
}
diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h
index d11bc8cd3..583eb2216 100644
--- a/source/slang/slang-type-system-shared.h
+++ b/source/slang/slang-type-system-shared.h
@@ -110,6 +110,8 @@ enum class AddressSpace : uint64_t
Image,
// Represents a SPIR-V specialization constant
SpecializationConstant,
+ // Corresponds to SPIR-V's SpvStorageClassNodePayloadAMDX,
+ NodePayloadAMDX,
// Default address space for a user-defined pointer
UserPointer = 0x100000001ULL,
diff --git a/tests/workgraphs/consumer.slang b/tests/workgraphs/consumer.slang
new file mode 100644
index 000000000..5e211a2a1
--- /dev/null
+++ b/tests/workgraphs/consumer.slang
@@ -0,0 +1,32 @@
+//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry main -skip-spirv-validation
+struct RecordData
+{
+ int myData;
+};
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(uint3 dispatchThreadId : SV_GroupThreadID)
+{
+ spirv_asm
+ {
+ OpExecutionMode $main ShaderIndexAMDX $(0);
+ OpExecutionMode $main StaticNumWorkgroupsAMDX $(1) $(1) $(1);
+ };
+
+ DispatchNodeInputRecord<RecordData> inputData;
+
+ let recordData = inputData.Get();
+ int myData = recordData.myData;
+}
+
+//CHK: ; Types, variables and constants
+//CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1
+//CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]]
+//CHK: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]]
+//CHK: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]]
+
+//CHK: ; Function
+//CHK: [[VarName:%[a-zA-Z_0-9]+]] = OpVariable [[PtrType]] NodePayloadAMDX
+//CHK: = OpAccessChain [[PtrType]] [[VarName]]
+