summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMukund Keshava <mkeshava@nvidia.com>2025-04-30 17:50:50 +0530
committerGitHub <noreply@github.com>2025-04-30 12:20:50 +0000
commit54acb11b1c8b6af2504ff3a3e0f56ca8baba4753 (patch)
tree6e4fd3548b2fcbd3a127ea6cef32fb037fa1d902
parentb0e150511a6a536c8ad9e74910b30ae179a10ec9 (diff)
cuda: Improve entry handling for SV_DispatchThreadID (#6925)
* cuda: Improve entry handling for SV_DispatchThreadID Fixes #6780 This commit improves CUDA entry point handling by extracting appropriate components from DispatchThreadID based on parameter type. It now properly handles uint scalar (x component only) and uint2 vector (x,y components) instead of always using the full uint3 value. Add a new test case to check for this. * format code * fix CI failure * Handle review comments --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp66
-rw-r--r--tests/cuda/dispatch-thread-id-extraction.slang48
2 files changed, 109 insertions, 5 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index 6ece10457..b31a6d92f 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -1270,13 +1270,13 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
switch (info.systemValueSemanticName)
{
case SystemValueSemanticName::GroupID:
- return LegalizedVaryingVal::makeValue(blockIdxGlobalParam);
+ return createLegalizedVal(info, blockIdxGlobalParam);
case SystemValueSemanticName::GroupThreadID:
- return LegalizedVaryingVal::makeValue(threadIdxGlobalParam);
+ return createLegalizedVal(info, threadIdxGlobalParam);
case SystemValueSemanticName::GroupIndex:
- return LegalizedVaryingVal::makeValue(groupThreadIndex);
+ return createLegalizedVal(info, groupThreadIndex);
case SystemValueSemanticName::DispatchThreadID:
- return LegalizedVaryingVal::makeValue(dispatchThreadID);
+ return createLegalizedVal(info, dispatchThreadID);
default:
return diagnoseUnsupportedSystemVal(info);
}
@@ -1331,6 +1331,62 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz
return diagnoseUnsupportedUserVal(info);
}
}
+
+ LegalizedVaryingVal createLegalizedVal(VaryingParamInfo const& info, IRInst* id)
+ {
+ // If the parameter type is not uint3, we need to extract components as needed
+ auto paramType = info.type->getOperand(0);
+ IRBuilder builder(m_module);
+ builder.setInsertBefore(m_firstOrdinaryInst);
+
+ if (as<IRBasicType>(paramType))
+ {
+ auto uintType = builder.getBasicType(BaseType::UInt);
+ UInt swizzleIndex = 0;
+ auto xComponent = builder.emitSwizzle(uintType, id, 1, &swizzleIndex);
+
+ if (auto basicType = as<IRBasicType>(paramType))
+ {
+ if (basicType->getBaseType() != BaseType::UInt)
+ {
+ xComponent = builder.emitBitCast(basicType, xComponent);
+ }
+ }
+ return LegalizedVaryingVal::makeValue(xComponent);
+ }
+ // For vector types, use a swizzle to extract the needed components
+ else if (auto vectorType = as<IRVectorType>(paramType))
+ {
+ auto elementCount = getIntVal(vectorType->getElementCount());
+
+ if (elementCount > 0 && elementCount <= 3)
+ {
+ // Setup indices for the swizzle (0 for x, 1 for y, 2 for z)
+ UInt swizzleIndices[3] = {0, 1, 2};
+ auto uintType = builder.getBasicType(BaseType::UInt);
+
+ // Use a swizzle to extract all needed components at once
+ auto extractedVector = builder.emitSwizzle(
+ builder.getVectorType(uintType, elementCount),
+ id,
+ elementCount,
+ swizzleIndices);
+
+ // Cast if the element type is not uint
+ auto elementType = vectorType->getElementType();
+ if (auto basicElementType = as<IRBasicType>(elementType))
+ {
+ if (basicElementType->getBaseType() != BaseType::UInt)
+ {
+ extractedVector = builder.emitBitCast(vectorType, extractedVector);
+ }
+ }
+ return LegalizedVaryingVal::makeValue(extractedVector);
+ }
+ }
+ // Default to the full uint3 if the parameter type doesn't match our expectations
+ return LegalizedVaryingVal::makeValue(id);
+ }
};
@@ -1763,7 +1819,7 @@ private:
void removeSemanticLayoutsFromLegalizedStructs()
{
// Metal and WGSL does not allow duplicate attributes to appear in the same shader.
- // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]`
+ // If we emit our own struct with `[[color(0)]]`, all existing uses of `[[color(0)]]`
// must be removed.
for (auto field : semanticInfoToRemove)
{
diff --git a/tests/cuda/dispatch-thread-id-extraction.slang b/tests/cuda/dispatch-thread-id-extraction.slang
new file mode 100644
index 000000000..5fc3c89a6
--- /dev/null
+++ b/tests/cuda/dispatch-thread-id-extraction.slang
@@ -0,0 +1,48 @@
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none
+
+// This test verifies that DispatchThreadID parameter of different types
+// correctly extracts components from the underlying uint3 value in CUDA.
+
+//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name cudaOutputBuffer
+RWStructuredBuffer<float> cudaOutputBuffer;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst)
+{
+ dst[tid.x] = src[tid.x];
+}
+// CHECK: uint _S1 = (blockIdx * blockDim + threadIdx).x;
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain2(uint2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst)
+{
+ dst[tid.x] = src[tid.y];
+}
+// CHECK: uint2 _S2 = uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y};
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain3(int2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst)
+{
+ dst[tid.x] = src[tid.x];
+}
+// CHECK: int _S3 = (slang_bit_cast<int2 >(uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y})).x;
+
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain4(int tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst)
+{
+ dst[tid.x] = src[tid.x];
+}
+// CHECK: int _S4 = (slang_bit_cast<int>((blockIdx * blockDim + threadIdx).x));
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain5(int tid: SV_GroupIndex, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst)
+{
+ dst[tid.x] = src[tid.x];
+}
+// CHECK: int _S5 = (slang_bit_cast<int>((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x)); \ No newline at end of file