diff options
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.cpp | 66 | ||||
| -rw-r--r-- | tests/cuda/dispatch-thread-id-extraction.slang | 48 |
2 files changed, 109 insertions, 5 deletions
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index 6ece10457..b31a6d92f 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -1270,13 +1270,13 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz switch (info.systemValueSemanticName) { case SystemValueSemanticName::GroupID: - return LegalizedVaryingVal::makeValue(blockIdxGlobalParam); + return createLegalizedVal(info, blockIdxGlobalParam); case SystemValueSemanticName::GroupThreadID: - return LegalizedVaryingVal::makeValue(threadIdxGlobalParam); + return createLegalizedVal(info, threadIdxGlobalParam); case SystemValueSemanticName::GroupIndex: - return LegalizedVaryingVal::makeValue(groupThreadIndex); + return createLegalizedVal(info, groupThreadIndex); case SystemValueSemanticName::DispatchThreadID: - return LegalizedVaryingVal::makeValue(dispatchThreadID); + return createLegalizedVal(info, dispatchThreadID); default: return diagnoseUnsupportedSystemVal(info); } @@ -1331,6 +1331,62 @@ struct CUDAEntryPointVaryingParamLegalizeContext : EntryPointVaryingParamLegaliz return diagnoseUnsupportedUserVal(info); } } + + LegalizedVaryingVal createLegalizedVal(VaryingParamInfo const& info, IRInst* id) + { + // If the parameter type is not uint3, we need to extract components as needed + auto paramType = info.type->getOperand(0); + IRBuilder builder(m_module); + builder.setInsertBefore(m_firstOrdinaryInst); + + if (as<IRBasicType>(paramType)) + { + auto uintType = builder.getBasicType(BaseType::UInt); + UInt swizzleIndex = 0; + auto xComponent = builder.emitSwizzle(uintType, id, 1, &swizzleIndex); + + if (auto basicType = as<IRBasicType>(paramType)) + { + if (basicType->getBaseType() != BaseType::UInt) + { + xComponent = builder.emitBitCast(basicType, xComponent); + } + } + return LegalizedVaryingVal::makeValue(xComponent); + } + // For vector types, use a swizzle to extract the needed components + else if (auto vectorType = as<IRVectorType>(paramType)) + { + auto elementCount = getIntVal(vectorType->getElementCount()); + + if (elementCount > 0 && elementCount <= 3) + { + // Setup indices for the swizzle (0 for x, 1 for y, 2 for z) + UInt swizzleIndices[3] = {0, 1, 2}; + auto uintType = builder.getBasicType(BaseType::UInt); + + // Use a swizzle to extract all needed components at once + auto extractedVector = builder.emitSwizzle( + builder.getVectorType(uintType, elementCount), + id, + elementCount, + swizzleIndices); + + // Cast if the element type is not uint + auto elementType = vectorType->getElementType(); + if (auto basicElementType = as<IRBasicType>(elementType)) + { + if (basicElementType->getBaseType() != BaseType::UInt) + { + extractedVector = builder.emitBitCast(vectorType, extractedVector); + } + } + return LegalizedVaryingVal::makeValue(extractedVector); + } + } + // Default to the full uint3 if the parameter type doesn't match our expectations + return LegalizedVaryingVal::makeValue(id); + } }; @@ -1763,7 +1819,7 @@ private: void removeSemanticLayoutsFromLegalizedStructs() { // Metal and WGSL does not allow duplicate attributes to appear in the same shader. - // If we emit our own struct with `[[color(0)]`, all existing uses of `[[color(0)]]` + // If we emit our own struct with `[[color(0)]]`, all existing uses of `[[color(0)]]` // must be removed. for (auto field : semanticInfoToRemove) { diff --git a/tests/cuda/dispatch-thread-id-extraction.slang b/tests/cuda/dispatch-thread-id-extraction.slang new file mode 100644 index 000000000..5fc3c89a6 --- /dev/null +++ b/tests/cuda/dispatch-thread-id-extraction.slang @@ -0,0 +1,48 @@ +//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none + +// This test verifies that DispatchThreadID parameter of different types +// correctly extracts components from the underlying uint3 value in CUDA. + +//TEST_INPUT: ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name cudaOutputBuffer +RWStructuredBuffer<float> cudaOutputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) +{ + dst[tid.x] = src[tid.x]; +} +// CHECK: uint _S1 = (blockIdx * blockDim + threadIdx).x; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain2(uint2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) +{ + dst[tid.x] = src[tid.y]; +} +// CHECK: uint2 _S2 = uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y}; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain3(int2 tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) +{ + dst[tid.x] = src[tid.x]; +} +// CHECK: int _S3 = (slang_bit_cast<int2 >(uint2 {(blockIdx * blockDim + threadIdx).x, (blockIdx * blockDim + threadIdx).y})).x; + + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain4(int tid: SV_DispatchThreadID, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) +{ + dst[tid.x] = src[tid.x]; +} +// CHECK: int _S4 = (slang_bit_cast<int>((blockIdx * blockDim + threadIdx).x)); + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain5(int tid: SV_GroupIndex, StructuredBuffer<uint> src, RWStructuredBuffer<uint> dst) +{ + dst[tid.x] = src[tid.x]; +} +// CHECK: int _S5 = (slang_bit_cast<int>((threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x));
\ No newline at end of file |
