diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-14 14:48:17 -0600 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-14 14:48:17 -0600 |
| commit | 705d00ab8528e0d7c14f68b7d0e15fb57280c16e (patch) | |
| tree | acf6e024ef803c5a49e2c6c0075ab0d9a49a11d3 | |
| parent | d6f4780e8a608fa37597116d5b0ac5c80034c2aa (diff) | |
Fix matrix division by scalar for Metal and WGSL targets (#6752)
* Fix matrix division by scalar for Metal and WGSL targets
* Add tests
* Minor fix
* Fix compilation error
* Convert to multiplication for WGSL
* Minor cleanup
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.cpp | 156 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-metal-legalize.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.cpp | 8 | ||||
| -rw-r--r-- | tests/bugs/matrix-divided-by-scalar.slang | 21 | ||||
| -rw-r--r-- | tests/diagnostics/division-by-matrix.slang | 15 |
7 files changed, 161 insertions, 59 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 6d84792fb..dfea9fede 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2622,6 +2622,11 @@ DIAGNOSTIC( resourceTypesInConstantBufferInParameterBlockNotAllowedOnMetal, "nesting a 'ConstantBuffer' containing resource types inside a 'ParameterBlock' is not " "supported on Metal, please use 'ParameterBlock' instead.") +DIAGNOSTIC( + 56102, + Error, + divisionByMatrixNotSupported, + "division by matrix is not supported for Metal and WGSL targets.") DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0") DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp index 1595aa130..f2f7cdef2 100644 --- a/source/slang/slang-ir-legalize-binary-operator.cpp +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -1,12 +1,105 @@ #include "slang-ir-legalize-binary-operator.h" +#include "compiler-core/slang-diagnostic-sink.h" #include "slang-ir-insts.h" namespace Slang { -void legalizeBinaryOp(IRInst* inst) +static bool isVectorOrMatrix(IRType* type) { + switch (type->getOp()) + { + case kIROp_VectorType: + case kIROp_MatrixType: + return true; + default: + return false; + } +}; + +static bool isDivisionByMatrix(IRInst* inst) +{ + return (inst->getOp() == kIROp_Div) && (as<IRMatrixType>(inst->getOperand(1)->getDataType())); +} + +static bool isMatrixDividedByScalar(IRInst* inst) +{ + return (inst->getOp() == kIROp_Div) && (as<IRMatrixType>(inst->getOperand(0)->getDataType())) && + (as<IRBasicType>(inst->getOperand(1)->getDataType())); +} + +// If one operand is a composite type (vector or matrix), and the other one is a scalar +// type, then the scalar is converted to a composite type. +static void legalizeScalarOperandsToMatchComposite(IRInst* inst) +{ + if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && + as<IRBasicType>(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(0)->getDataType(); + IRInst* scalarValue = inst->getOperand(1); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as<IRVectorType>(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands() + 1, newRhs); + } + else if ( + as<IRBasicType>(inst->getOperand(0)->getDataType()) && + isVectorOrMatrix(inst->getOperand(1)->getDataType())) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + IRType* compositeType = inst->getOperand(1)->getDataType(); + IRInst* scalarValue = inst->getOperand(0); + // Retain the scalar type for shifts + if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) + { + auto vectorType = as<IRVectorType>(compositeType); + compositeType = + builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); + } + auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); + builder.replaceOperand(inst->getOperands(), newLhs); + } +} + +// Replaces a division by scalar operation by a multiplication. +// This is done for WGSL where matrix divided by scalar operations are not supported. +static void replaceMatrixDividedByScalarWithMul(IRInst* inst) +{ + SLANG_ASSERT(isMatrixDividedByScalar(inst)); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + auto scalarType = inst->getOperand(1)->getDataType(); + auto newRhs = + builder.emitDiv(scalarType, builder.getFloatValue(scalarType, 1.0), inst->getOperand(1)); + auto newOp = builder.emitMul(inst->getDataType(), inst->getOperand(0), newRhs); + + inst->replaceUsesWith(newOp); + inst->transferDecorationsTo(newOp); +} + +void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target) +{ + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Division by matrix is not supported on Metal and WGSL. + if (isDivisionByMatrix(inst)) + { + sink->diagnose(inst, Diagnostics::divisionByMatrixNotSupported); + return; + } + // For shifts, ensure that the shift amount is unsigned, as required by // https://www.w3.org/TR/WGSL/#bit-expr. if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) @@ -19,8 +112,6 @@ void legalizeBinaryOp(IRInst* inst) IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType); if (opIntInfo.isSigned) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); opIntInfo.isSigned = false; shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); shiftAmountVectorType = builder.getVectorType( @@ -35,8 +126,6 @@ void legalizeBinaryOp(IRInst* inst) IntInfo opIntInfo = getIntTypeInfo(shiftAmountType); if (opIntInfo.isSigned) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); opIntInfo.isSigned = false; shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo)); IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount); @@ -45,54 +134,23 @@ void legalizeBinaryOp(IRInst* inst) } } - auto isVectorOrMatrix = [](IRType* type) + // For matrix divided by scalar operations, do not convert scalar divisor to dividend's matrix + // type. Division by matrix is not supported on Metal and WGSL. + if (!isMatrixDividedByScalar(inst)) { - switch (type->getOp()) - { - case kIROp_VectorType: - case kIROp_MatrixType: - return true; - default: - return false; - } - }; - if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) && - as<IRBasicType>(inst->getOperand(1)->getDataType())) + legalizeScalarOperandsToMatchComposite(inst); + } + else if (isWGPUTarget(target)) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - IRType* compositeType = inst->getOperand(0)->getDataType(); - IRInst* scalarValue = inst->getOperand(1); - // Retain the scalar type for shifts - if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) - { - auto vectorType = as<IRVectorType>(compositeType); - compositeType = - builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); - } - auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); - builder.replaceOperand(inst->getOperands() + 1, newRhs); + // WGSL does not support matrix division by scalar, convert it to multiplication. + replaceMatrixDividedByScalarWithMul(inst); } - else if ( - as<IRBasicType>(inst->getOperand(0)->getDataType()) && - isVectorOrMatrix(inst->getOperand(1)->getDataType())) + else { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - IRType* compositeType = inst->getOperand(1)->getDataType(); - IRInst* scalarValue = inst->getOperand(0); - // Retain the scalar type for shifts - if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh) - { - auto vectorType = as<IRVectorType>(compositeType); - compositeType = - builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount()); - } - auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue); - builder.replaceOperand(inst->getOperands(), newLhs); + // Matrix divided by scalar is natively supported on Metal - leave it as is. } - else if ( - isIntegralType(inst->getOperand(0)->getDataType()) && + + if (isIntegralType(inst->getOperand(0)->getDataType()) && isIntegralType(inst->getOperand(1)->getDataType())) { // Unless the operator is a shift, and if the integer operands differ in signedness, @@ -108,8 +166,6 @@ void legalizeBinaryOp(IRInst* inst) { int signedOpIndex = (int)opIntInfo[1].isSigned; opIntInfo[signedOpIndex].isSigned = false; - IRBuilder builder(inst); - builder.setInsertBefore(inst); auto newOp = builder.emitCast( builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])), inst->getOperand(signedOpIndex)); diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h index f9ebf90d8..cffa3efb6 100644 --- a/source/slang/slang-ir-legalize-binary-operator.h +++ b/source/slang/slang-ir-legalize-binary-operator.h @@ -1,17 +1,22 @@ #pragma once +#include "slang-compiler.h" + namespace Slang { struct IRInst; +class DiagnosticSink; +// Legalize binary operations for Metal and WGSL targets. +// // Ensures: // - Shift amounts are over unsigned scalar types. // - If one operand is a composite type (vector or matrix), and the other one is a scalar // type, then the scalar is converted to a composite type. // - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the // signed operand is converted to unsigned. -void legalizeBinaryOp(IRInst* inst); +void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target); // The logical binary operators such as AND and OR takes boolean types are its input. // If they are in integer type, as an example, we need to explicitly cast to bool type. diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp index e9f693622..589413cbf 100644 --- a/source/slang/slang-ir-metal-legalize.cpp +++ b/source/slang/slang-ir-metal-legalize.cpp @@ -181,7 +181,7 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner } }; -static void processInst(IRInst* inst) +static void processInst(IRInst* inst, DiagnosticSink* sink) { switch (inst->getOp()) { @@ -204,7 +204,7 @@ static void processInst(IRInst* inst) case kIROp_Less: case kIROp_Geq: case kIROp_Leq: - legalizeBinaryOp(inst); + legalizeBinaryOp(inst, sink, CodeGenTarget::Metal); break; case kIROp_MetalCastToDepthTexture: { @@ -220,7 +220,7 @@ static void processInst(IRInst* inst) default: for (auto child : inst->getModifiableChildren()) { - processInst(child); + processInst(child, sink); } } } @@ -248,7 +248,7 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink) MetalAddressSpaceAssigner metalAddressSpaceAssigner; specializeAddressSpace(module, &metalAddressSpaceAssigner); - processInst(module->getModuleInst()); + processInst(module->getModuleInst(), sink); } } // namespace Slang diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index efa028703..51f16e603 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -121,7 +121,7 @@ static void legalizeSwitch(IRSwitch* switchInst) switchInst->removeAndDeallocate(); } -static void processInst(IRInst* inst) +static void processInst(IRInst* inst, DiagnosticSink* sink) { switch (inst->getOp()) { @@ -154,7 +154,7 @@ static void processInst(IRInst* inst) case kIROp_Less: case kIROp_Geq: case kIROp_Leq: - legalizeBinaryOp(inst); + legalizeBinaryOp(inst, sink, CodeGenTarget::WGSL); break; case kIROp_Func: @@ -163,7 +163,7 @@ static void processInst(IRInst* inst) default: for (auto child : inst->getModifiableChildren()) { - processInst(child); + processInst(child, sink); } } } @@ -218,7 +218,7 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints); // Go through every instruction in the module and legalize them as needed. - processInst(module->getModuleInst()); + processInst(module->getModuleInst(), sink); // Some global insts are illegal, e.g. function calls. // We need to inline and remove those. diff --git a/tests/bugs/matrix-divided-by-scalar.slang b/tests/bugs/matrix-divided-by-scalar.slang new file mode 100644 index 000000000..27c23e501 --- /dev/null +++ b/tests/bugs/matrix-divided-by-scalar.slang @@ -0,0 +1,21 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -compute -entry computeMain -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-metal -compute -entry computeMain -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-wgpu -compute -entry computeMain -output-using-type + +//TEST_INPUT:ubuffer(data=[3.5], stride=4):name inputBuffer +StructuredBuffer<float> inputBuffer; + +//TEST_INPUT:ubuffer(data=[0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + // CHECK: 6.0 + outputBuffer[0] = (float3x3(15.0) / 2.5)[0][0]; + + // CHECK: 4.0 + outputBuffer[1] = (float4x4(14.0) / inputBuffer[0])[0][0]; +} + diff --git a/tests/diagnostics/division-by-matrix.slang b/tests/diagnostics/division-by-matrix.slang new file mode 100644 index 000000000..6ed78d353 --- /dev/null +++ b/tests/diagnostics/division-by-matrix.slang @@ -0,0 +1,15 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target metal +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target wgsl + +RWStructuredBuffer<float> outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + // CHECK: error 56102: division by matrix is not supported + float3x3 divisor = float3x3(2.5); + divisor[1][1] = 1.5; + outputBuffer[0] = (float3x3(15) / divisor)[0][0]; +} + |
