diff options
| author | Anders Leino <aleino@nvidia.com> | 2025-01-10 21:05:05 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-10 11:05:05 -0800 |
| commit | 803e0c9f9a9dc4b01e29ebbf3b37a5bba782ac83 (patch) | |
| tree | 4996c9f415c64692e8381ae8c9ab1ab914ee86ea /source | |
| parent | 6437f2d37b08972db5e4515bd124639c2903dda1 (diff) | |
WGSL: Convert signed vector shift amounts to unsigned (#6023)
* WGSL: Fixes for signed shift amounts
- Handle the case of vector shift amounts
- Closes #5985
- Move handling of scalar case from emit to legalization
- Add tests for bitshifts.
* Move the binary operator legalization function to a common place
* Metal: Legalize binary operations
Closes #6029.
* Fix Metal filecheck test
The int shift amounts are now converted to unsigned.
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.cpp | 121 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 37 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.cpp | 59 |
5 files changed, 181 insertions, 73 deletions
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 3b2cf12d0..30a7af938 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1372,10 +1372,10 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu case kIROp_Rsh: case kIROp_Lsh: { - // Shift amounts must be an unsigned type in WGSL + // Shift amounts must be an unsigned type in WGSL. + // We ensure this during legalization. // https://www.w3.org/TR/WGSL/#bit-expr - IRInst* const shiftAmount = inst->getOperand(1); - IRType* const shiftAmountType = shiftAmount->getDataType(); + SLANG_ASSERT(inst->getOperand(1)->getDataType()->getOp() != kIROp_IntType); // Dawn complains about mixing '<<' and '|', '^' and a bunch of other bit operators // without a paranthesis, so we'll always emit paranthesis around the shift amount. @@ -1392,18 +1392,9 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(info.op); m_writer->emit(" "); - if (shiftAmountType->getOp() == kIROp_IntType) - { - m_writer->emit("bitcast<u32>("); - emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); - m_writer->emit(")"); - } - else - { - m_writer->emit("("); - emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); - m_writer->emit(")"); - } + m_writer->emit("("); + emitOperand(inst->getOperand(1), rightSide(outerPrec, info)); + m_writer->emit(")"); maybeCloseParens(needClose); diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp new file mode 100644 index 000000000..a1affb7e9 --- /dev/null +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -0,0 +1,121 @@ +#include "slang-ir-legalize-binary-operator.h" + +#include "slang-ir-insts.h" + +namespace Slang +{ + +void legalizeBinaryOp(IRInst* inst) +{ + // For shifts, ensure that the shift amount is unsigned, as required by + // https://www.w3.org/TR/WGSL/#bit-expr. + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + IRInst* shiftAmount = inst->getOperand(1); + IRType* shiftAmountType = shiftAmount->getDataType(); + if (auto shiftAmountVectorType = as<IRVectorType>(shiftAmountType)) + { + IRType* shiftAmountElementType = shiftAmountVectorType->getElementType(); + IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType); + if (opIntInfo.isSigned) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + opIntInfo.isSigned = false; + shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); + shiftAmountVectorType = builder.getVectorType( + shiftAmountElementType, + shiftAmountVectorType->getElementCount()); + IRInst* newShiftAmount = builder.emitCast(shiftAmountVectorType, shiftAmount); + builder.replaceOperand(inst->getOperands() + 1, newShiftAmount); + } + } + else if (isIntegralType(shiftAmountType)) + { + IntInfo opIntInfo = getIntTypeInfo(shiftAmountType); + if (opIntInfo.isSigned) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + opIntInfo.isSigned = false; + shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); + IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount); + builder.replaceOperand(inst->getOperands() + 1, newShiftAmount); + } + } + } + + auto isVectorOrMatrix = [](IRType* type) + { + switch (type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } + }; + if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && + as<IRBasicType>(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(0)->getDataType(); + IRInst* scalarValue = inst->getOperand(1); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as<IRVectorType>(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands() + 1, newRhs); + } + else if ( + as<IRBasicType>(inst->getOperand(0)->getDataType()) && + isVectorOrMatrix(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(1)->getDataType(); + IRInst* scalarValue = inst->getOperand(0); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as<IRVectorType>(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands(), newLhs); + } + else if ( + isIntegralType(inst->getOperand(0)->getDataType()) && + isIntegralType(inst->getOperand(1)->getDataType())) + { + // Unless the operator is a shift, and if the integer operands differ in signedness, + // then convert the signed one to unsigned. + // We're assuming that the cases where this is bad have already been caught by + // common validation checks. + IntInfo opIntInfo[2] = { + getIntTypeInfo(inst->getOperand(0)->getDataType()), + getIntTypeInfo(inst->getOperand(1)->getDataType())}; + bool isShift = inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh; + bool signednessDiffers = opIntInfo[0].isSigned != opIntInfo[1].isSigned; + if (!isShift && signednessDiffers) + { + int signedOpIndex = (int)opIntInfo[1].isSigned; + opIntInfo[signedOpIndex].isSigned = false; + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto newOp = builder.emitCast( + builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])), + inst->getOperand(signedOpIndex)); + builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp); + } + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h new file mode 100644 index 000000000..71c319718 --- /dev/null +++ b/source/slang/slang-ir-legalize-binary-operator.h @@ -0,0 +1,16 @@ +#pragma once + +namespace Slang +{ + +struct IRInst; + +// Ensures: +// - Shift amounts are over unsigned scalar types. +// - If one operand is a composite type (vector or matrix), and the other one is a scalar +// type, then the scalar is converted to a composite type. +// - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the +// signed operand is converted to unsigned. +void legalizeBinaryOp(IRInst* inst); + +} // namespace Slang diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index ce5b34c3e..5bfa62e4a 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -2,6 +2,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir-specialize-address-space.h" #include "slang-ir-util.h" @@ -2120,6 +2121,40 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner } }; +static void processInst(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_FRem: + case kIROp_IRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + legalizeBinaryOp(inst); + break; + + default: + for (auto child : inst->getModifiableChildren()) + { + processInst(child); + } + } +} + void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -2145,6 +2180,8 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) MetalAddressSpaceAssigner metalAddressSpaceAssigner; specializeAddressSpace(module, &metalAddressSpaceAssigner); + + processInst(module->getModuleInst()); } } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index f76a0541c..effc06f3e 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -1,6 +1,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" +#include "slang-ir-legalize-binary-operator.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-varying-params.h" #include "slang-ir-util.h" @@ -1487,64 +1488,6 @@ struct LegalizeWGSLEntryPointContext switchInst->removeAndDeallocate(); } - void legalizeBinaryOp(IRInst* inst) - { - auto isVectorOrMatrix = [](IRType* type) - { - switch (type->getOp()) - { - case kIROp_VectorType: - case kIROp_MatrixType: - return true; - default: - return false; - } - }; - if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && - as<IRBasicType>(inst->getOperand(1)->getDataType())) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newRhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(0)->getDataType(), - inst->getOperand(1)); - builder.replaceOperand(inst->getOperands() + 1, newRhs); - } - else if ( - as<IRBasicType>(inst->getOperand(0)->getDataType()) && - isVectorOrMatrix(inst->getOperand(1)->getDataType())) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newLhs = builder.emitMakeCompositeFromScalar( - inst->getOperand(1)->getDataType(), - inst->getOperand(0)); - builder.replaceOperand(inst->getOperands(), newLhs); - } - else if ( - isIntegralType(inst->getOperand(0)->getDataType()) && - isIntegralType(inst->getOperand(1)->getDataType())) - { - // If integer operands differ in signedness, convert the signed one to unsigned. - // We're assuming that the cases where this is bad have already been caught by - // common validation checks. - IntInfo opIntInfo[2] = { - getIntTypeInfo(inst->getOperand(0)->getDataType()), - getIntTypeInfo(inst->getOperand(1)->getDataType())}; - if (opIntInfo[0].isSigned != opIntInfo[1].isSigned) - { - int signedOpIndex = (int)opIntInfo[1].isSigned; - opIntInfo[signedOpIndex].isSigned = false; - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto newOp = builder.emitCast( - builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])), - inst->getOperand(signedOpIndex)); - builder.replaceOperand(inst->getOperands() + signedOpIndex, newOp); - } - } - } - void processInst(IRInst* inst) { switch (inst->getOp()) |
