summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp66
-rw-r--r--tests/cuda/dispatch-thread-id-extraction.slang48
2 files changed, 109 insertions, 5 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 6ece10457..b31a6d92f 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -1270,13 +1270,13 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
switch (info.systemValueSemanticName)
{
case SystemValueSemanticName::GroupID:
- return LegalizedVaryingVal::makeValue(blockIdxGlobalParam);
+ return createLegalizedVal(info, blockIdxGlobalParam);
case SystemValueSemanticName::GroupThreadID:
- return LegalizedVaryingVal::makeValue(threadIdxGlobalParam);
+ return createLegalizedVal(info, threadIdxGlobalParam);
case SystemValueSemanticName::GroupIndex:
- return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ return createLegalizedVal(info, groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID:
- return LegalizedVaryingVal::makeValue(dispatchThreadID);
+ return createLegalizedVal(info, dispatchThreadID);
default:
return diagnoseUnsupportedSystemVal(info);
}
@@ -1331,6 +1331,62 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
return diagnoseUnsupportedUserVal(info);
}
}
+
+ LegalizedVaryingVal createLegalizedVal(VaryingParamInfo const& info, IRInst* id) // Wraps the system value `id` as a LegalizedVaryingVal, first narrowing/casting it toward the type described by `info`.
+ {
+ // If the parameter type is not uint3, we need to extract components as needed
+ auto paramType = info.type->getOperand(0); // NOTE(review): presumably the declared value type wrapped by info.type - confirm what operand 0 is for every caller.
+ IRBuilder builder(m_module);
+ builder.setInsertBefore(m_firstOrdinaryInst); // Any swizzle/bit-cast emitted below is inserted before m_firstOrdinaryInst.
+
+ if (as<IRBasicType>(paramType)) // Scalar parameter: take component 0 of `id`.
+ {
+ auto uintType = builder.getBasicType(BaseType::UInt);
+ UInt swizzleIndex = 0;
+ auto xComponent = builder.emitSwizzle(uintType, id, 1, &swizzleIndex); // One-element swizzle with index 0, typed as uint.
+
+ if (auto basicType = as<IRBasicType>(paramType)) // NOTE(review): repeats the test of the enclosing `if`; it is only needed to obtain the typed pointer.
+ {
+ if (basicType->getBaseType() != BaseType::UInt)
+ {
+ xComponent = builder.emitBitCast(basicType, xComponent); // Non-uint scalar (e.g. int): bit-cast the extracted uint to the parameter type.
+ }
+ }
+ return LegalizedVaryingVal::makeValue(xComponent);
+ }
+ // For vector types, use a swizzle to extract the needed components
+ else if (auto vectorType = as<IRVectorType>(paramType))
+ {
+ auto elementCount = getIntVal(vectorType->getElementCount());
+
+ if (elementCount > 0 && elementCount <= 3) // Counts outside 1..3 fall through to the unchanged-`id` return at the end.
+ {
+ // Setup indices for the swizzle (0 for x, 1 for y, 2 for z); elementCount below is passed as the number of indices
+ UInt swizzleIndices[3] = {0, 1, 2};
+ auto uintType = builder.getBasicType(BaseType::UInt);
+
+ // Use a swizzle to extract all needed components at once, typed as a uint vector of elementCount elements
+ auto extractedVector = builder.emitSwizzle(
+ builder.getVectorType(uintType, elementCount),
+ id,
+ elementCount,
+ swizzleIndices);
+
+ // Cast if the element type is a basic type other than uint; non-basic element types are left uncast
+ auto elementType = vectorType->getElementType();
+ if (auto basicElementType = as<IRBasicType>(elementType))
+ {
+ if (basicElementType->getBaseType() != BaseType::UInt)
+ {
+ extractedVector = builder.emitBitCast(vectorType, extractedVector); // Bit-cast the whole uint vector to the declared vector type.
+ }
+ }
+ return LegalizedVaryingVal::makeValue(extractedVector);
+ }
+ }
+ // Fallback: return `id` unchanged when the parameter type is neither a basic type nor a 1-3 element vector
+ return LegalizedVaryingVal::makeValue(id);
+ }
};
@@ -1763,7 +1819,7 @@ private:
void removeSemanticLayoutsFromLegalizedStructs()
{
// Metal and WGSL does not allow duplicate attributes to appear in the same shader.
- // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
+ // If we emit our own struct with `[[color(0)]]`, all existing uses of `[[color(0)]]`
// must be removed.
for (auto field : semanticInfoToRemove)
{
diff --git a/tests/cuda/dispatch-thread-id-extraction.slang b/tests/cuda/dispatch-thread-id-extraction.slang
new file mode 100644
index 000000000..5fc3c89a6
--- /dev/null
+++ b/tests/cuda/dispatch-thread-id-extraction.slang
@@ -0,0 +1,48 @@
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none
+
+// This test verifies that SV_DispatchThreadID and SV_GroupIndex parameters of different types
+// are correctly extracted (and bit-cast where needed) from the underlying value in CUDA (a uint3 for SV_DispatchThreadID).
+
+//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name cudaOutputBuffer
+RWStructuredBuffer<float> cudaOutputBuffer;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) // Case 1: scalar uint parameter bound to SV_DispatchThreadID.
+{
+ dst[tid.x] = src[tid.x]; // Uses tid as an index so the extracted value appears in the generated code.
+}
+// CHECK: uint _S1 = (blockIdx * blockDim + threadIdx).x;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain2(uint2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) // Case 2: uint2 parameter; the CHECK line expects the x and y components.
+{
+ dst[tid.x] = src[tid.y]; // Reads both components of tid.
+}
+// CHECK: uint2 _S2 = uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y};
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain3(int2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) // Case 3: int2 parameter; the CHECK line expects a bit cast of the extracted uint2 to int2.
+{
+ dst[tid.x] = src[tid.x]; // Only tid.x is read here.
+}
+// CHECK: int _S3 = (slang_bit_cast<int2 >(uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y})).x;
+
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain4(int tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) // Case 4: scalar int parameter; the CHECK line expects the x component bit-cast to int.
+{
+ dst[tid.x] = src[tid.x]; // Uses tid as an index so the cast value appears in the generated code.
+}
+// CHECK: int _S4 = (slang_bit_cast<int>((blockIdx * blockDim + threadIdx).x));
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain5(int tid: SV_GroupIndex, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) // Case 5: SV_GroupIndex declared as int; the CHECK line expects the flattened thread index bit-cast to int.
+{
+ dst[tid.x] = src[tid.x]; // Uses tid as an index so the cast value appears in the generated code.
+}
+// CHECK: int _S5 = (slang_bit_cast<int>((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x)); \ No newline at end of file