diff options
| author | Jay Kwak <82421531+jkwak-work@users.noreply.github.com> | 2025-01-09 20:49:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 20:49:25 -0800 |
| commit | 55ff4686e5685c414d82f16b9c1a4a331bd4f853 (patch) | |
| tree | 3ac28ec249ad345417b04247ea2d17d03adbd9b1 | |
| parent | fce63c2c550b8715e347a44b1d874f48157543d3 (diff) | |
Support a storage class, NodePayloadAMDX, for SPIRV work-graphs (#6052)
In order to unblock experiments with SPIRV work-graphs, Slang
needs to support the storage class, `NodePayloadAMDX`.
Note that this commit is only to support a storage class,
`NodePayloadAMDX`. There are many parts required for work-graphs
hasn't been implemented yet.
The implementation of `DispatchNodeInputRecord` is not required, but it
is implemented mostly for a testing purpose.
Closes #6049
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/core.meta.slang | 4 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 28 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv-ops.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-type-system-shared.h | 2 | ||||
| -rw-r--r-- | tests/workgraphs/consumer.slang | 32 |
6 files changed, 108 insertions, 12 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index f0324ba1a..ba8decc12 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -3965,3 +3965,7 @@ attribute_syntax [DerivativeGroupQuad] : DerivativeGroupQuadAttribute; /// effect on other targets. __attributeTarget(FuncDecl) attribute_syntax [DerivativeGroupLinear] : DerivativeGroupLinearAttribute; + +__generic<T> +typealias NodePayloadPtr = Ptr<T, $( (uint64_t)AddressSpace::NodePayloadAMDX)>; + diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a27cafbd4..07bf2ffbd 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -21365,3 +21365,31 @@ int8_t4_packed pack_clamp_s8(int16_t4 unpackedValue) } } +// Work-graphs + +//@public: +/// read-only input to Broadcasting launch node. +__generic<T> +//TODO: DispatchNodeInputRecord should be available only for broadcasting node shader. +//[require(broadcasting_node)] +[require(spirv)] +struct DispatchNodeInputRecord +{ + /// Provide an access to a record object that only holds a single record. + NodePayloadPtr<T> Get() + { + int index = 0; + __target_switch + { + case spirv: + return spirv_asm + { + %in_payload_t = OpTypeNodePayloadArrayAMDX $$T; + %in_payload_ptr_t = OpTypePointer NodePayloadAMDX %in_payload_t; + %var = OpVariable %in_payload_ptr_t NodePayloadAMDX; + result : $$NodePayloadPtr<T> = OpAccessChain %var $index; + }; + } + } +}; + diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index d385d54b2..1f2996646 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2552,4 +2552,18 @@ SpvInst* emitOpAtomicIDecrement( memory, semantics); } + +// https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/AMD/SPV_AMDX_shader_enqueue.html#OpTypeNodePayloadArrayAMDX +template<typename T> +SpvInst* emitOpTypeNodePayloadArray(IRInst* inst, const T& type) +{ + static_assert(isSingular<T>); + return emitInstMemoized( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + inst, + SpvOpTypeNodePayloadArrayAMDX, + kResultID, + type); +} + #endif // SLANG_IN_SPIRV_EMIT_CONTEXT diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index d8c479cd1..1407404ad 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1311,6 +1311,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex return SpvStorageClassImage; case AddressSpace::UserPointer: return SpvStorageClassPhysicalStorageBuffer; + case AddressSpace::NodePayloadAMDX: + return SpvStorageClassNodePayloadAMDX; case AddressSpace::Global: case AddressSpace::MetalObjectData: case AddressSpace::SpecializationConstant: @@ -1504,13 +1506,22 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_ASSERT(ptrType); if (ptrType->hasAddressSpace()) storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace()); - if (storageClass == SpvStorageClassStorageBuffer) + + switch (storageClass) + { + case SpvStorageClassStorageBuffer: ensureExtensionDeclaration( UnownedStringSlice("SPV_KHR_storage_buffer_storage_class")); - if (storageClass == SpvStorageClassPhysicalStorageBuffer) - { + break; + case SpvStorageClassPhysicalStorageBuffer: requirePhysicalStorageAddressing(); + break; + case SpvStorageClassNodePayloadAMDX: + requireSPIRVCapability(SpvCapabilityShaderEnqueueAMDX); + ensureExtensionDeclaration(UnownedStringSlice("SPV_AMDX_shader_enqueue")); + break; } + auto valueType = ptrType->getValueType(); // If we haven't emitted the inner type yet, we need to emit a forward declaration. bool useForwardDeclaration = @@ -1524,17 +1535,20 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex builder.setInsertBefore(valueType); valueTypeId = getID(ensureInst(builder.getUIntType())); } + else if (useForwardDeclaration) + { + valueTypeId = getIRInstSpvID(valueType); + } + else if (storageClass == SpvStorageClassNodePayloadAMDX) + { + auto spvValueType = ensureInst(valueType); + auto spvNodePayloadType = emitOpTypeNodePayloadArray(inst, spvValueType); + valueTypeId = getID(spvNodePayloadType); + } else { - if (useForwardDeclaration) - { - valueTypeId = getIRInstSpvID(valueType); - } - else - { - auto spvValueType = ensureInst(valueType); - valueTypeId = getID(spvValueType); - } + auto spvValueType = ensureInst(valueType); + valueTypeId = getID(spvValueType); } auto resultSpvType = emitOpTypePointer(inst, storageClass, valueTypeId); @@ -7564,6 +7578,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case SpvOpMemberDecorate: case SpvOpMemberDecorateString: return getSection(SpvLogicalSectionID::Annotations); + case SpvOpTypeNodePayloadArrayAMDX: + return getSection(SpvLogicalSectionID::ConstantsAndTypes); default: return defaultParent; } diff --git a/source/slang/slang-type-system-shared.h b/source/slang/slang-type-system-shared.h index d11bc8cd3..583eb2216 100644 --- a/source/slang/slang-type-system-shared.h +++ b/source/slang/slang-type-system-shared.h @@ -110,6 +110,8 @@ enum class AddressSpace : uint64_t Image, // Represents a SPIR-V specialization constant SpecializationConstant, + // Corresponds to SPIR-V's SpvStorageClassNodePayloadAMDX, + NodePayloadAMDX, // Default address space for a user-defined pointer UserPointer = 0x100000001ULL, diff --git a/tests/workgraphs/consumer.slang b/tests/workgraphs/consumer.slang new file mode 100644 index 000000000..5e211a2a1 --- /dev/null +++ b/tests/workgraphs/consumer.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CHK): -target spirv-asm -stage compute -entry main -skip-spirv-validation +struct RecordData +{ + int myData; +}; + +[shader("compute")] +[numthreads(1, 1, 1)] +void main(uint3 dispatchThreadId : SV_GroupThreadID) +{ + spirv_asm + { + OpExecutionMode $main ShaderIndexAMDX $(0); + OpExecutionMode $main StaticNumWorkgroupsAMDX $(1) $(1) $(1); + }; + + DispatchNodeInputRecord<RecordData> inputData; + + let recordData = inputData.Get(); + int myData = recordData.myData; +} + +//CHK: ; Types, variables and constants +//CHK: [[MemberType:%[a-zA-Z_0-9]+]] = OpTypeInt 32 1 +//CHK: [[StructType:%[a-zA-Z_0-9]+]] = OpTypeStruct [[MemberType]] +//CHK: [[PayloadType:%[a-zA-Z_0-9]+]] = OpTypeNodePayloadArrayAMDX [[StructType]] +//CHK: [[PtrType:%[a-zA-Z_0-9]+]] = OpTypePointer NodePayloadAMDX [[PayloadType]] + +//CHK: ; Function +//CHK: [[VarName:%[a-zA-Z_0-9]+]] = OpVariable [[PtrType]] NodePayloadAMDX +//CHK: = OpAccessChain [[PtrType]] [[VarName]] + |
