summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAnders Leino <aleino@nvidia.com>2025-01-10 21:05:05 +0200
committerGitHub <noreply@github.com>2025-01-10 11:05:05 -0800
commit803e0c9f9a9dc4b01e29ebbf3b37a5bba782ac83 (patch)
tree4996c9f415c64692e8381ae8c9ab1ab914ee86ea
parent6437f2d37b08972db5e4515bd124639c2903dda1 (diff)
WGSL: Convert signed vector shift amounts to unsigned (#6023)
* WGSL: Fixes for signed shift amounts
- Handle the case of vector shift amounts
- Closes #5985
- Move handling of scalar case from emit to legalization
- Add tests for bitshifts.
* Move the binary operator legalization function to a common place
* Metal: Legalize binary operations
Closes #6029.
* Fix Metal filecheck test
The int shift amounts are now converted to unsigned.
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-emit-wgsl.cpp21
-rw-r--r--source/slang/slang-ir-legalize-binary-operator.cpp121
-rw-r--r--source/slang/slang-ir-legalize-binary-operator.h16
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp37
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp59
-rw-r--r--tests/metal/byte-address-buffer.slang8
-rw-r--r--tests/wgsl/bitshifts.slang92
-rw-r--r--tests/wgsl/bitshifts.slang.expected.txt47
8 files changed, 324 insertions, 77 deletions
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 3b2cf12d0..30a7af938 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -1372,10 +1372,10 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
case kIROp_Rsh:
case kIROp_Lsh:
{
- // Shift amounts must be an unsigned type in WGSL
+ // Shift amounts must be an unsigned type in WGSL.
+ // We ensure this during legalization.
// https://www.w3.org/TR/WGSL/#bit-expr
- IRInst* const shiftAmount = inst->getOperand(1);
- IRType* const shiftAmountType = shiftAmount->getDataType();
+ SLANG_ASSERT(inst->getOperand(1)->getDataType()->getOp() != kIROp_IntType);
// Dawn complains about mixing '<<' and '|', '^' and a bunch of other bit operators
// without parentheses, so we'll always emit parentheses around the shift amount.
@@ -1392,18 +1392,9 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit(info.op);
m_writer->emit(" ");
- if (shiftAmountType->getOp() == kIROp_IntType)
- {
- m_writer->emit("bitcast<u32>(");
- emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
- m_writer->emit(")");
- }
- else
- {
- m_writer->emit("(");
- emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
- m_writer->emit(")");
- }
+ m_writer->emit("(");
+ emitOperand(inst->getOperand(1), rightSide(outerPrec, info));
+ m_writer->emit(")");
maybeCloseParens(needClose);
diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp
new file mode 100644
index 000000000..a1affb7e9
--- /dev/null
+++ b/source/slang/slang-ir-legalize-binary-operator.cpp
@@ -0,0 +1,121 @@
+#include "slang-ir-legalize-binary-operator.h"
+
+#include "slang-ir-insts.h"
+
+namespace Slang
+{
+
+void legalizeBinaryOp(IRInst* inst) // Contract: see slang-ir-legalize-binary-operator.h
+{
+ // For shifts, cast a signed shift amount (scalar or vector) to its unsigned counterpart,
+ // as required by https://www.w3.org/TR/WGSL/#bit-expr.
+ if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
+ {
+ IRInst* shiftAmount = inst->getOperand(1);
+ IRType* shiftAmountType = shiftAmount->getDataType();
+ if (auto shiftAmountVectorType = as<IRVectorType>(shiftAmountType))
+ {
+ IRType* shiftAmountElementType = shiftAmountVectorType->getElementType();
+ IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType);
+ if (opIntInfo.isSigned)
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ opIntInfo.isSigned = false;
+ shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
+ shiftAmountVectorType = builder.getVectorType(
+ shiftAmountElementType,
+ shiftAmountVectorType->getElementCount());
+ IRInst* newShiftAmount = builder.emitCast(shiftAmountVectorType, shiftAmount);
+ builder.replaceOperand(inst->getOperands() + 1, newShiftAmount);
+ }
+ }
+ else if (isIntegralType(shiftAmountType))
+ {
+ IntInfo opIntInfo = getIntTypeInfo(shiftAmountType);
+ if (opIntInfo.isSigned)
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ opIntInfo.isSigned = false;
+ shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
+ IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount);
+ builder.replaceOperand(inst->getOperands() + 1, newShiftAmount);
+ }
+ }
+ }
+
+ auto isVectorOrMatrix = [](IRType* type)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return true;
+ default:
+ return false;
+ }
+ };
+ if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
+ as<IRBasicType>(inst->getOperand(1)->getDataType())) // composite LHS, scalar RHS
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ IRType* compositeType = inst->getOperand(0)->getDataType();
+ IRInst* scalarValue = inst->getOperand(1);
+ // Retain the scalar type for vector shifts; a matrix is not an IRVectorType, so guard.
+ if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
+ {
+ if (auto vectorType = as<IRVectorType>(compositeType))
+ compositeType =
+ builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
+ }
+ auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
+ builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ }
+ else if (
+ as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
+ isVectorOrMatrix(inst->getOperand(1)->getDataType())) // scalar LHS, composite RHS
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ IRType* compositeType = inst->getOperand(1)->getDataType();
+ IRInst* scalarValue = inst->getOperand(0);
+ // Retain the scalar type for vector shifts; a matrix is not an IRVectorType, so guard.
+ if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
+ {
+ if (auto vectorType = as<IRVectorType>(compositeType))
+ compositeType =
+ builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
+ }
+ auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
+ builder.replaceOperand(inst->getOperands(), newLhs);
+ }
+ else if (
+ isIntegralType(inst->getOperand(0)->getDataType()) &&
+ isIntegralType(inst->getOperand(1)->getDataType()))
+ {
+ // If the operator is not a shift and the integer operands differ in signedness,
+ // then convert the signed one to unsigned.
+ // We're assuming that the cases where this is bad have already been caught by
+ // common validation checks.
+ IntInfo opIntInfo[2] = {
+ getIntTypeInfo(inst->getOperand(0)->getDataType()),
+ getIntTypeInfo(inst->getOperand(1)->getDataType())};
+ bool isShift = inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh;
+ bool signednessDiffers = opIntInfo[0].isSigned != opIntInfo[1].isSigned;
+ if (!isShift && signednessDiffers)
+ {
+ int signedOpIndex = (int)opIntInfo[1].isSigned; // 1 if the RHS is signed, else 0 (LHS)
+ opIntInfo[signedOpIndex].isSigned = false;
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto newOp = builder.emitCast(
+ builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
+ inst->getOperand(signedOpIndex));
+ builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp);
+ }
+ }
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h
new file mode 100644
index 000000000..71c319718
--- /dev/null
+++ b/source/slang/slang-ir-legalize-binary-operator.h
@@ -0,0 +1,16 @@
+#pragma once
+
+namespace Slang
+{
+
+struct IRInst;
+
+// Legalizes the binary operation 'inst' in place. Ensures:
+// - Shift amounts are of an unsigned integer type (scalar or vector).
+// - If one operand is a composite type (vector or matrix), and the other one is a scalar
+// type, then the scalar is converted to a composite type.
+// - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the
+// signed operand is converted to unsigned.
+void legalizeBinaryOp(IRInst* inst);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index ce5b34c3e..5bfa62e4a 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
+#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-specialize-address-space.h"
#include "slang-ir-util.h"
@@ -2120,6 +2121,40 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner
}
};
+static void processInst(IRInst* inst) // Walks 'inst' recursively, legalizing binary operators.
+{
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_FRem:
+ case kIROp_IRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ legalizeBinaryOp(inst); // Binary operator: legalize its operands; children are not visited.
+ break;
+
+ default: // Any other instruction: descend into its children.
+ for (auto child : inst->getModifiableChildren())
+ {
+ processInst(child);
+ }
+ }
+}
+
void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
@@ -2145,6 +2180,8 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
MetalAddressSpaceAssigner metalAddressSpaceAssigner;
specializeAddressSpace(module, &metalAddressSpaceAssigner);
+
+ processInst(module->getModuleInst());
}
} // namespace Slang
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index f76a0541c..effc06f3e 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-wgsl-legalize.h"
#include "slang-ir-insts.h"
+#include "slang-ir-legalize-binary-operator.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-varying-params.h"
#include "slang-ir-util.h"
@@ -1487,64 +1488,6 @@ struct LegalizeWGSLEntryPointContext
switchInst->removeAndDeallocate();
}
- void legalizeBinaryOp(IRInst* inst)
- {
- auto isVectorOrMatrix = [](IRType* type)
- {
- switch (type->getOp())
- {
- case kIROp_VectorType:
- case kIROp_MatrixType:
- return true;
- default:
- return false;
- }
- };
- if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
- as<IRBasicType>(inst->getOperand(1)->getDataType()))
- {
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto newRhs = builder.emitMakeCompositeFromScalar(
- inst->getOperand(0)->getDataType(),
- inst->getOperand(1));
- builder.replaceOperand(inst->getOperands() + 1, newRhs);
- }
- else if (
- as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
- isVectorOrMatrix(inst->getOperand(1)->getDataType()))
- {
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto newLhs = builder.emitMakeCompositeFromScalar(
- inst->getOperand(1)->getDataType(),
- inst->getOperand(0));
- builder.replaceOperand(inst->getOperands(), newLhs);
- }
- else if (
- isIntegralType(inst->getOperand(0)->getDataType()) &&
- isIntegralType(inst->getOperand(1)->getDataType()))
- {
- // If integer operands differ in signedness, convert the signed one to unsigned.
- // We're assuming that the cases where this is bad have already been caught by
- // common validation checks.
- IntInfo opIntInfo[2] = {
- getIntTypeInfo(inst->getOperand(0)->getDataType()),
- getIntTypeInfo(inst->getOperand(1)->getDataType())};
- if (opIntInfo[0].isSigned != opIntInfo[1].isSigned)
- {
- int signedOpIndex = (int)opIntInfo[1].isSigned;
- opIntInfo[signedOpIndex].isSigned = false;
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto newOp = builder.emitCast(
- builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
- inst->getOperand(signedOpIndex));
- builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp);
- }
- }
- }
-
void processInst(IRInst* inst)
{
switch (inst->getOp())
diff --git a/tests/metal/byte-address-buffer.slang b/tests/metal/byte-address-buffer.slang
index 24802815e..d4b58061f 100644
--- a/tests/metal/byte-address-buffer.slang
+++ b/tests/metal/byte-address-buffer.slang
@@ -20,11 +20,11 @@ struct TestStruct
void main_kernel(uint3 tid: SV_DispatchThreadID)
{
// CHECK: uint [[WORD0:[a-zA-Z0-9_]+]] = as_type<uint>({{.*}}[(int(0))>>2]);
- // CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> int(0) & 255U);
+ // CHECK: uint8_t [[A:[a-zA-Z0-9_]+]] = uint8_t([[WORD0]] >> 0U & 255U);
// CHECK: uint [[WORD1:[a-zA-Z0-9_]+]] = as_type<uint>({{.*}}[(int(0))>>2]);
- // CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type<half>(ushort([[WORD1]] >> int(16) & 65535U));
+ // CHECK: half [[H:[a-zA-Z0-9_]+]] = as_type<half>(ushort([[WORD1]] >> 16U & 65535U));
- // CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | uint([[A]]) << int(0));
- // CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 65535U) | uint(as_type<ushort>([[H]])) << int(16));
+ // CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 4294967040U) | uint([[A]]) << 0U);
+ // CHECK: {{.*}}[(int(128))>>2] = as_type<uint32_t>(({{.*}} & 65535U) | uint(as_type<ushort>([[H]])) << 16U);
buffer.Store(128, buffer.Load<TestStruct>(0));
}
diff --git a/tests/wgsl/bitshifts.slang b/tests/wgsl/bitshifts.slang
new file mode 100644
index 000000000..50d2fc43d
--- /dev/null
+++ b/tests/wgsl/bitshifts.slang
@@ -0,0 +1,92 @@
+//TEST(compute):COMPARE_COMPUTE:-shaderobj
+
+//TEST_INPUT:ubuffer(data=[3 7 8 10], stride=4):name=inputBuffer
+RWStructuredBuffer<int> inputBuffer;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ int amount = 1; // Signed shift amount: must be converted to unsigned for WGSL.
+
+ outputBuffer[0] = inputBuffer[0] >> amount; // scalar >> scalar
+
+ int2 v2 = int2(inputBuffer[0], inputBuffer[1]) >> amount; // vector >> scalar
+ outputBuffer[1] = v2[0];
+ outputBuffer[2] = v2[1];
+
+ int3 v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) >> amount;
+ outputBuffer[3] = v3[0];
+ outputBuffer[4] = v3[1];
+ outputBuffer[5] = v3[2];
+
+ int4 v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[3]) >> amount;
+ outputBuffer[6] = v4[0];
+ outputBuffer[7] = v4[1];
+ outputBuffer[8] = v4[2];
+ outputBuffer[9] = v4[3];
+
+ outputBuffer[10] = inputBuffer[0] << amount; // scalar << scalar
+
+ v2 = int2(inputBuffer[0], inputBuffer[1]) << amount; // vector << scalar
+ outputBuffer[11] = v2[0];
+ outputBuffer[12] = v2[1];
+
+ v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) << amount;
+ outputBuffer[13] = v3[0];
+ outputBuffer[14] = v3[1];
+ outputBuffer[15] = v3[2];
+
+ v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[3]) << amount;
+ outputBuffer[16] = v4[0];
+ outputBuffer[17] = v4[1];
+ outputBuffer[18] = v4[2];
+ outputBuffer[19] = v4[3];
+
+ v2 = inputBuffer[0] >> int2(amount); // scalar >> vector
+ outputBuffer[20] = v2[0];
+ outputBuffer[21] = v2[1];
+
+ v3 = inputBuffer[1] >> int3(amount);
+ outputBuffer[22] = v3[0];
+ outputBuffer[23] = v3[1];
+ outputBuffer[24] = v3[2];
+
+ v4 = inputBuffer[2] >> int4(amount);
+ outputBuffer[25] = v4[0];
+ outputBuffer[26] = v4[1];
+ outputBuffer[27] = v4[2];
+ outputBuffer[28] = v4[3];
+
+ v2 = inputBuffer[0] << int2(amount); // scalar << vector
+ outputBuffer[29] = v2[0];
+ outputBuffer[30] = v2[1];
+
+ v3 = inputBuffer[1] << int3(amount);
+ outputBuffer[31] = v3[0];
+ outputBuffer[32] = v3[1];
+ outputBuffer[33] = v3[2];
+
+ v4 = inputBuffer[2] << int4(amount);
+ outputBuffer[34] = v4[0];
+ outputBuffer[35] = v4[1];
+ outputBuffer[36] = v4[2];
+ outputBuffer[37] = v4[3];
+
+ v2 = int2(inputBuffer[0], inputBuffer[1]) >> int2(1, 2); // vector >> vector
+ outputBuffer[38] = v2[0];
+ outputBuffer[39] = v2[1];
+
+ v3 = int3(inputBuffer[0], inputBuffer[1], inputBuffer[2]) >> int3(1, 2, 3);
+ outputBuffer[40] = v3[0];
+ outputBuffer[41] = v3[1];
+ outputBuffer[42] = v3[2];
+
+ v4 = int4(inputBuffer[0], inputBuffer[1], inputBuffer[2], inputBuffer[3]) >> int4(1, 2, 3, 4); // inputBuffer has 4 elements; last valid index is 3
+ outputBuffer[43] = v4[0];
+ outputBuffer[44] = v4[1];
+ outputBuffer[45] = v4[2];
+ outputBuffer[46] = v4[3];
+}
diff --git a/tests/wgsl/bitshifts.slang.expected.txt b/tests/wgsl/bitshifts.slang.expected.txt
new file mode 100644
index 000000000..ff1ae84e4
--- /dev/null
+++ b/tests/wgsl/bitshifts.slang.expected.txt
@@ -0,0 +1,47 @@
+1
+1
+3
+1
+3
+4
+1
+3
+4
+5
+6
+6
+E
+6
+E
+10
+6
+E
+10
+14
+1
+1
+3
+3
+3
+4
+4
+4
+4
+6
+6
+E
+E
+E
+10
+10
+10
+10
+1
+1
+1
+1
+1
+1
+1
+1
+0