summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-05-09 23:26:43 -0400
committerGitHub <noreply@github.com>2025-05-09 20:26:43 -0700
commit48203ea02250ba517f749a222092f091d9bef15e (patch)
tree75663fe17fcbae4a376d9cdbe31393de43171e7c
parent5a6c2baadbc16fc2099a6951e389b9bd3cad08f6 (diff)
Fix SPIRV unsigned to signed widening casts (#7051)
* Fix unsigned to signed casts for SPIRV * Add test * Fix ICE crash
-rw-r--r--source/slang/slang-emit-spirv.cpp34
-rw-r--r--source/slang/slang-ir.cpp31
-rw-r--r--source/slang/slang-ir.h2
-rw-r--r--tests/spirv/int-cast-unsigned-to-signed.slang35
4 files changed, 92 insertions, 10 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index e4b5c2f4d..ad1f1bdf8 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -6599,11 +6599,11 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_ASSERT(!as<IRVectorType>(fromTypeV) == !as<IRVectorType>(toTypeV));
const auto fromType = getVectorOrCoopMatrixElementType(fromTypeV);
const auto toType = getVectorOrCoopMatrixElementType(toTypeV);
+ IRBuilder builder(inst);
if (as<IRBoolType>(fromType))
{
// Cast from bool to int.
- IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto zero = builder.getIntValue(toType, 0);
auto one = builder.getIntValue(toType, 1);
@@ -6635,7 +6635,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
else if (as<IRBoolType>(toType))
{
// Cast from int to bool.
- IRBuilder builder(inst);
builder.setInsertBefore(inst);
auto zero = builder.getIntValue(fromType, 0);
if (auto vecType = as<IRVectorType>(toTypeV))
@@ -6667,20 +6666,35 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
const auto toInfo = getIntTypeInfo(toType);
if (fromInfo == toInfo)
+ {
+ // Same exact integer types, copy the object.
return emitOpCopyObject(parent, inst, toTypeV, inst->getOperand(0));
+ }
else if (fromInfo.width == toInfo.width)
+ {
+ // Same bit width, perform bit cast.
return emitOpBitcast(parent, inst, toTypeV, inst->getOperand(0));
+ }
else if (!fromInfo.isSigned && !toInfo.isSigned)
- // unsigned to unsigned, don't sign extend
+ {
+ // Unsigned to unsigned, don't sign extend.
return emitOpUConvert(parent, inst, toTypeV, inst->getOperand(0));
- else if (toInfo.isSigned)
- // unsigned to signed, sign extend
- return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0));
+ }
+ else if (!fromInfo.isSigned && toInfo.isSigned)
+ {
+ // Unsigned to signed with different widths, don't sign extend.
+ // Perform unsigned conversion first to an unsigned integer of the same width as the
+ // result then perform bit cast to the signed result type. This is done because SPIRV's
+ // unsigned conversion (`OpUConvert`) requires result type to be unsigned.
+ auto unsignedV = emitOpUConvert(
+ parent,
+ nullptr,
+ builder.getType(getOppositeSignIntTypeOp(toType->getOp())),
+ inst->getOperand(0));
+ return emitOpBitcast(parent, inst, toTypeV, unsignedV);
+ }
else if (fromInfo.isSigned)
- // signed to unsigned, sign extend
- return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0));
- else if (fromInfo.isSigned && toInfo.isSigned)
- // signed to signed, sign extend
+ // Signed to signed and signed to unsigned, sign extend.
return emitOpSConvert(parent, inst, toTypeV, inst->getOperand(0));
SLANG_UNREACHABLE(__func__);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a99eddebb..98c0fa471 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -7811,6 +7811,37 @@ IROp getIntTypeOpFromInfo(const IntInfo info)
}
}
+IROp getOppositeSignIntTypeOp(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_UInt8Type:
+ return kIROp_Int8Type;
+ case kIROp_UInt16Type:
+ return kIROp_Int16Type;
+ case kIROp_UIntType:
+ return kIROp_IntType;
+ case kIROp_UInt64Type:
+ return kIROp_Int64Type;
+ case kIROp_UIntPtrType:
+ return kIROp_IntPtrType;
+
+ case kIROp_Int8Type:
+ return kIROp_UInt8Type;
+ case kIROp_Int16Type:
+ return kIROp_UInt16Type;
+ case kIROp_IntType:
+ return kIROp_UIntType;
+ case kIROp_Int64Type:
+ return kIROp_UInt64Type;
+ case kIROp_IntPtrType:
+ return kIROp_UIntPtrType;
+
+ default:
+ SLANG_UNEXPECTED("Unhandled type passed to getOppositeSignIntTypeOp");
+ }
+}
+
FloatInfo getFloatingTypeInfo(const IRType* floatType)
{
switch (floatType->getOp())
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 5a1ae94f7..91c2f018a 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1053,6 +1053,8 @@ IntInfo getIntTypeInfo(const IRType* intType);
// left-inverse of getIntTypeInfo
IROp getIntTypeOpFromInfo(const IntInfo info);
+IROp getOppositeSignIntTypeOp(IROp op);
+
struct FloatInfo
{
Int width;
diff --git a/tests/spirv/int-cast-unsigned-to-signed.slang b/tests/spirv/int-cast-unsigned-to-signed.slang
new file mode 100644
index 000000000..0ec3203a4
--- /dev/null
+++ b/tests/spirv/int-cast-unsigned-to-signed.slang
@@ -0,0 +1,35 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cpu -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUF):-cuda -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+//TEST_INPUT:ubuffer(data=[200], stride=4):name input1
+RWStructuredBuffer<uint8_t> input1;
+
+//TEST_INPUT:ubuffer(data=[35000], stride=4):name input2
+RWStructuredBuffer<uint16_t> input2;
+
+//TEST_INPUT:ubuffer(data=[201], stride=4):name input3
+RWStructuredBuffer<uint8_t> input3;
+
+//
+// Tests unsigned to signed casts to wider int.
+//
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint index = dispatchThreadID.x;
+
+ // BUF: 200
+ outputBuffer[index] = input1[0];
+
+ // BUF: 35000
+ outputBuffer[index + 1] = input2[0];
+
+ // BUF: 201
+ int16_t a = input3[0];
+ outputBuffer[index + 2] = a;
+}