diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-05-09 23:26:43 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-09 20:26:43 -0700 |
| commit | 48203ea02250ba517f749a222092f091d9bef15e (patch) | |
| tree | 75663fe17fcbae4a376d9cdbe31393de43171e7c | |
| parent | 5a6c2baadbc16fc2099a6951e389b9bd3cad08f6 (diff) | |
Fix SPIRV unsigned to signed widening casts (#7051)
* Fix unsigned to signed casts for SPIRV
* Add test
* Fix ICE crash
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 31 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 2 | ||||
| -rw-r--r-- | tests/spirv/int-cast-unsigned-to-signed.slang | 35 |
4 files changed, 92 insertions, 10 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index e4b5c2f4d..ad1f1bdf8 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -6599,11 +6599,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV)); const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV); const auto toType = getVectorOrCoopMatrixElementType(toTypeV); + IRBuilder builder(inst); if (as<IRBoolType>(fromType)) { // Cast from bool to int. - IRBuilder builder(inst); builder.setInsertBefore(inst); auto zero = builder.getIntValue(toType, 0); auto one = builder.getIntValue(toType, 1); @@ -6635,7 +6635,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex else if (as<IRBoolType>(toType)) { // Cast from int to bool. - IRBuilder builder(inst); builder.setInsertBefore(inst); auto zero = builder.getIntValue(fromType, 0); if (auto vecType = as<IRVectorType>(toTypeV)) @@ -6667,20 +6666,35 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex const auto toInfo = getIntTypeInfo(toType); if (fromInfo == toInfo) + { + // Same exact integer types, copy the object. return emitOpCopyObject(parent, inst, toTypeV, inst->getOperand(0)); + } else if (fromInfo.width == toInfo.width) + { + // Same bit width, perform bit cast. return emitOpBitcast(parent, inst, toTypeV, inst->getOperand(0)); + } else if (!fromInfo.isSigned && !toInfo.isSigned) - // unsigned to unsigned, don't sign extend + { + // Unsigned to unsigned, don't sign extend. return emitOpUConvert(parent, inst, toTypeV, inst->getOperand(0)); - else if (toInfo.isSigned) - // unsigned to signed, sign extend - return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); + } + else if (!fromInfo.isSigned && toInfo.isSigned) + { + // Unsigned to signed with different widths, don't sign extend. + // Perform unsigned conversion first to an unsigned integer of the same width as the + // result then perform bit cast to the signed result type. This is done because SPIRV's + // unsigned conversion (`OpUConvert`) requires result type to be unsigned. + auto unsignedV = emitOpUConvert( + parent, + nullptr, + builder.getType(getOppositeSignIntTypeOp(toType->getOp())), + inst->getOperand(0)); + return emitOpBitcast(parent, inst, toTypeV, unsignedV); + } else if (fromInfo.isSigned) - // signed to unsigned, sign extend - return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); - else if (fromInfo.isSigned && toInfo.isSigned) - // signed to signed, sign extend + // Signed to signed and signed to unsigned, sign extend. return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0)); SLANG_UNREACHABLE(__func__); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a99eddebb..98c0fa471 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7811,6 +7811,37 @@ IROp getIntTypeOpFromInfo(const IntInfo info) } } +IROp getOppositeSignIntTypeOp(IROp op) +{ + switch (op) + { + case kIROp_UInt8Type: + return kIROp_Int8Type; + case kIROp_UInt16Type: + return kIROp_Int16Type; + case kIROp_UIntType: + return kIROp_IntType; + case kIROp_UInt64Type: + return kIROp_Int64Type; + case kIROp_UIntPtrType: + return kIROp_IntPtrType; + + case kIROp_Int8Type: + return kIROp_UInt8Type; + case kIROp_Int16Type: + return kIROp_UInt16Type; + case kIROp_IntType: + return kIROp_UIntType; + case kIROp_Int64Type: + return kIROp_UInt64Type; + case kIROp_IntPtrType: + return kIROp_UIntPtrType; + + default: + SLANG_UNEXPECTED("Unhandled type passed to getOppositeSignIntTypeOp"); + } +} + FloatInfo getFloatingTypeInfo(const IRType* floatType) { switch (floatType->getOp()) diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 5a1ae94f7..91c2f018a 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1053,6 +1053,8 @@ IntInfo getIntTypeInfo(const IRType* intType); // left-inverse of getIntTypeInfo IROp getIntTypeOpFromInfo(const IntInfo info); +IROp getOppositeSignIntTypeOp(IROp op); + struct FloatInfo { Int width; diff --git a/tests/spirv/int-cast-unsigned-to-signed.slang b/tests/spirv/int-cast-unsigned-to-signed.slang new file mode 100644 index 000000000..0ec3203a4 --- /dev/null +++ b/tests/spirv/int-cast-unsigned-to-signed.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> outputBuffer; + +//TEST_INPUT:ubuffer(data=[200], stride=4):name input1 +RWStructuredBuffer<uint8_t> input1; + +//TEST_INPUT:ubuffer(data=[35000], stride=4):name input2 +RWStructuredBuffer<uint16_t> input2; + +//TEST_INPUT:ubuffer(data=[201], stride=4):name input3 +RWStructuredBuffer<uint8_t> input3; + +// +// Tests unsigned to signed casts to wider int. +// + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = dispatchThreadID.x; + + // BUF: 200 + outputBuffer[index] = input1[0]; + + // BUF: 35000 + outputBuffer[index + 1] = input2[0]; + + // BUF: 201 + int16_t a = input3[0]; + outputBuffer[index + 2] = a; +} |
