summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-04-14 14:48:17 -0600
committerGitHub <noreply@github.com>2025-04-14 14:48:17 -0600
commit705d00ab8528e0d7c14f68b7d0e15fb57280c16e (patch)
treeacf6e024ef803c5a49e2c6c0075ab0d9a49a11d3 /source
parentd6f4780e8a608fa37597116d5b0ac5c80034c2aa (diff)
Fix matrix division by scalar for Metal and WGSL targets (#6752)
* Fix matrix division by scalar for Metal and WGSL targets * Add tests * Minor fix * Fix compilation error * Convert to multiplication for WGSL * Minor cleanup --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-diagnostic-defs.h5
-rw-r--r--source/slang/slang-ir-legalize-binary-operator.cpp156
-rw-r--r--source/slang/slang-ir-legalize-binary-operator.h7
-rw-r--r--source/slang/slang-ir-metal-legalize.cpp8
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp8
5 files changed, 125 insertions, 59 deletions
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 6d84792fb..dfea9fede 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2622,6 +2622,11 @@ DIAGNOSTIC(
resourceTypesInConstantBufferInParameterBlockNotAllowedOnMetal,
"nesting a 'ConstantBuffer' containing resource types inside a 'ParameterBlock' is not "
"supported on Metal, please use 'ParameterBlock' instead.")
+DIAGNOSTIC(
+ 56102,
+ Error,
+ divisionByMatrixNotSupported,
+ "division by matrix is not supported for Metal and WGSL targets.")
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp
index 1595aa130..f2f7cdef2 100644
--- a/source/slang/slang-ir-legalize-binary-operator.cpp
+++ b/source/slang/slang-ir-legalize-binary-operator.cpp
@@ -1,12 +1,105 @@
#include "slang-ir-legalize-binary-operator.h"
+#include "compiler-core/slang-diagnostic-sink.h"
#include "slang-ir-insts.h"
namespace Slang
{
-void legalizeBinaryOp(IRInst* inst)
+static bool isVectorOrMatrix(IRType* type)
{
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ case kIROp_MatrixType:
+ return true;
+ default:
+ return false;
+ }
+};
+
+static bool isDivisionByMatrix(IRInst* inst)
+{
+ return (inst->getOp() == kIROp_Div) && (as<IRMatrixType>(inst->getOperand(1)->getDataType()));
+}
+
+static bool isMatrixDividedByScalar(IRInst* inst)
+{
+ return (inst->getOp() == kIROp_Div) && (as<IRMatrixType>(inst->getOperand(0)->getDataType())) &&
+ (as<IRBasicType>(inst->getOperand(1)->getDataType()));
+}
+
+// If one operand is a composite type (vector or matrix), and the other one is a scalar
+// type, then the scalar is converted to a composite type.
+static void legalizeScalarOperandsToMatchComposite(IRInst* inst)
+{
+ if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
+ as<IRBasicType>(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ IRType* compositeType = inst->getOperand(0)->getDataType();
+ IRInst* scalarValue = inst->getOperand(1);
+ // Retain the scalar type for shifts
+ if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
+ {
+ auto vectorType = as<IRVectorType>(compositeType);
+ compositeType =
+ builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
+ }
+ auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
+ builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ }
+ else if (
+ as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
+ isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ IRType* compositeType = inst->getOperand(1)->getDataType();
+ IRInst* scalarValue = inst->getOperand(0);
+ // Retain the scalar type for shifts
+ if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
+ {
+ auto vectorType = as<IRVectorType>(compositeType);
+ compositeType =
+ builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
+ }
+ auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
+ builder.replaceOperand(inst->getOperands(), newLhs);
+ }
+}
+
+// Replaces a division by scalar operation by a multiplication.
+// This is done for WGSL where matrix divided by scalar operations are not supported.
+static void replaceMatrixDividedByScalarWithMul(IRInst* inst)
+{
+ SLANG_ASSERT(isMatrixDividedByScalar(inst));
+
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ auto scalarType = inst->getOperand(1)->getDataType();
+ auto newRhs =
+ builder.emitDiv(scalarType, builder.getFloatValue(scalarType, 1.0), inst->getOperand(1));
+ auto newOp = builder.emitMul(inst->getDataType(), inst->getOperand(0), newRhs);
+
+ inst->replaceUsesWith(newOp);
+ inst->transferDecorationsTo(newOp);
+}
+
+void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target)
+{
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Division by matrix is not supported on Metal and WGSL.
+ if (isDivisionByMatrix(inst))
+ {
+ sink->diagnose(inst, Diagnostics::divisionByMatrixNotSupported);
+ return;
+ }
+
// For shifts, ensure that the shift amount is unsigned, as required by
// https://www.w3.org/TR/WGSL/#bit-expr.
if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
@@ -19,8 +112,6 @@ void legalizeBinaryOp(IRInst* inst)
IntInfo opIntInfo = getIntTypeInfo(shiftAmountElementType);
if (opIntInfo.isSigned)
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
opIntInfo.isSigned = false;
shiftAmountElementType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
shiftAmountVectorType = builder.getVectorType(
@@ -35,8 +126,6 @@ void legalizeBinaryOp(IRInst* inst)
IntInfo opIntInfo = getIntTypeInfo(shiftAmountType);
if (opIntInfo.isSigned)
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
opIntInfo.isSigned = false;
shiftAmountType = builder.getType(getIntTypeOpFromInfo(opIntInfo));
IRInst* newShiftAmount = builder.emitCast(shiftAmountType, shiftAmount);
@@ -45,54 +134,23 @@ void legalizeBinaryOp(IRInst* inst)
}
}
- auto isVectorOrMatrix = [](IRType* type)
+ // For matrix divided by scalar operations, do not convert scalar divisor to dividend's matrix
+ // type. Division by matrix is not supported on Metal and WGSL.
+ if (!isMatrixDividedByScalar(inst))
{
- switch (type->getOp())
- {
- case kIROp_VectorType:
- case kIROp_MatrixType:
- return true;
- default:
- return false;
- }
- };
- if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
- as<IRBasicType>(inst->getOperand(1)->getDataType()))
+ legalizeScalarOperandsToMatchComposite(inst);
+ }
+ else if (isWGPUTarget(target))
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- IRType* compositeType = inst->getOperand(0)->getDataType();
- IRInst* scalarValue = inst->getOperand(1);
- // Retain the scalar type for shifts
- if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
- {
- auto vectorType = as<IRVectorType>(compositeType);
- compositeType =
- builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
- }
- auto newRhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
- builder.replaceOperand(inst->getOperands() + 1, newRhs);
+ // WGSL does not support matrix division by scalar, convert it to multiplication.
+ replaceMatrixDividedByScalarWithMul(inst);
}
- else if (
- as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
- isVectorOrMatrix(inst->getOperand(1)->getDataType()))
+ else
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- IRType* compositeType = inst->getOperand(1)->getDataType();
- IRInst* scalarValue = inst->getOperand(0);
- // Retain the scalar type for shifts
- if (inst->getOp() == kIROp_Lsh || inst->getOp() == kIROp_Rsh)
- {
- auto vectorType = as<IRVectorType>(compositeType);
- compositeType =
- builder.getVectorType(scalarValue->getDataType(), vectorType->getElementCount());
- }
- auto newLhs = builder.emitMakeCompositeFromScalar(compositeType, scalarValue);
- builder.replaceOperand(inst->getOperands(), newLhs);
+ // Matrix divided by scalar is natively supported on Metal - leave it as is.
}
- else if (
- isIntegralType(inst->getOperand(0)->getDataType()) &&
+
+ if (isIntegralType(inst->getOperand(0)->getDataType()) &&
isIntegralType(inst->getOperand(1)->getDataType()))
{
// Unless the operator is a shift, and if the integer operands differ in signedness,
@@ -108,8 +166,6 @@ void legalizeBinaryOp(IRInst* inst)
{
int signedOpIndex = (int)opIntInfo[1].isSigned;
opIntInfo[signedOpIndex].isSigned = false;
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
auto newOp = builder.emitCast(
builder.getType(getIntTypeOpFromInfo(opIntInfo[signedOpIndex])),
inst->getOperand(signedOpIndex));
diff --git a/source/slang/slang-ir-legalize-binary-operator.h b/source/slang/slang-ir-legalize-binary-operator.h
index f9ebf90d8..cffa3efb6 100644
--- a/source/slang/slang-ir-legalize-binary-operator.h
+++ b/source/slang/slang-ir-legalize-binary-operator.h
@@ -1,17 +1,22 @@
#pragma once
+#include "slang-compiler.h"
+
namespace Slang
{
struct IRInst;
+class DiagnosticSink;
+// Legalize binary operations for Metal and WGSL targets.
+//
// Ensures:
// - Shift amounts are over unsigned scalar types.
// - If one operand is a composite type (vector or matrix), and the other one is a scalar
// type, then the scalar is converted to a composite type.
// - If 'inst' is not a shift, and if operands are integers of mixed signedness, then the
// signed operand is converted to unsigned.
-void legalizeBinaryOp(IRInst* inst);
+void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target);
// The logical binary operators such as AND and OR takes boolean types are its input.
// If they are in integer type, as an example, we need to explicitly cast to bool type.
diff --git a/source/slang/slang-ir-metal-legalize.cpp b/source/slang/slang-ir-metal-legalize.cpp
index e9f693622..589413cbf 100644
--- a/source/slang/slang-ir-metal-legalize.cpp
+++ b/source/slang/slang-ir-metal-legalize.cpp
@@ -181,7 +181,7 @@ struct MetalAddressSpaceAssigner : InitialAddressSpaceAssigner
}
};
-static void processInst(IRInst* inst)
+static void processInst(IRInst* inst, DiagnosticSink* sink)
{
switch (inst->getOp())
{
@@ -204,7 +204,7 @@ static void processInst(IRInst* inst)
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq:
- legalizeBinaryOp(inst);
+ legalizeBinaryOp(inst, sink, CodeGenTarget::Metal);
break;
case kIROp_MetalCastToDepthTexture:
{
@@ -220,7 +220,7 @@ static void processInst(IRInst* inst)
default:
for (auto child : inst->getModifiableChildren())
{
- processInst(child);
+ processInst(child, sink);
}
}
}
@@ -248,7 +248,7 @@ void legalizeIRForMetal(IRModule* module, DiagnosticSink* sink)
MetalAddressSpaceAssigner metalAddressSpaceAssigner;
specializeAddressSpace(module, &metalAddressSpaceAssigner);
- processInst(module->getModuleInst());
+ processInst(module->getModuleInst(), sink);
}
} // namespace Slang
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index efa028703..51f16e603 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -121,7 +121,7 @@ static void legalizeSwitch(IRSwitch* switchInst)
switchInst->removeAndDeallocate();
}
-static void processInst(IRInst* inst)
+static void processInst(IRInst* inst, DiagnosticSink* sink)
{
switch (inst->getOp())
{
@@ -154,7 +154,7 @@ static void processInst(IRInst* inst)
case kIROp_Less:
case kIROp_Geq:
case kIROp_Leq:
- legalizeBinaryOp(inst);
+ legalizeBinaryOp(inst, sink, CodeGenTarget::WGSL);
break;
case kIROp_Func:
@@ -163,7 +163,7 @@ static void processInst(IRInst* inst)
default:
for (auto child : inst->getModifiableChildren())
{
- processInst(child);
+ processInst(child, sink);
}
}
}
@@ -218,7 +218,7 @@ void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
legalizeEntryPointVaryingParamsForWGSL(module, sink, entryPoints);
// Go through every instruction in the module and legalize them as needed.
- processInst(module->getModuleInst());
+ processInst(module->getModuleInst(), sink);
// Some global insts are illegal, e.g. function calls.
// We need to inline and remove those.