summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2025-07-18 09:38:00 -0700
committerGitHub <noreply@github.com>2025-07-18 16:38:00 +0000
commit48b6e2432ea28c06d04931fccd633e31eed6d995 (patch)
treeb976380fd3464b231275e0ae2c1c6ac8af1bb6c3
parent85edfb178cd243134f4bb3d35ad71f154d76c81c (diff)
Lower int/uint/bool matrices to arrays for SPIRV (#7687)
* Add tests for expected behaviour * Allow matrix types in logical or/and * Legalize int/bool matrix types and construction with makeMatrix * Legalize uint matrices and operations * Limit testing to only SPIRV * Better tests for int and bool * Add test for uint * Remove GLSL tests * Remove old test for diagnosing int matrices * Emit SPIRV directly in tests * format code * Address PR comments * Improve testing * Address PR comments * format code * Add tests for matrix intrinsic operations * Move matrix lowering to dedicated legalization pass * Fix compiler warning * Remove signal again * Reorder matrix and vector legalization * Fix formatting * Add shift and comparison tests --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/hlsl.meta.slang53
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-emit-spirv.cpp172
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-legalize-binary-operator.cpp157
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp141
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.h13
-rw-r--r--source/slang/slang-ir-util.cpp3
-rw-r--r--source/slang/slang-ir-validate.cpp18
-rw-r--r--tests/compute/integer-matrix-diagnostic.slang22
-rw-r--r--tests/spirv/matrix-bool-lowering.slang114
-rw-r--r--tests/spirv/matrix-integer-lowering.slang189
12 files changed, 738 insertions, 152 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 264098bec..2ac886f61 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -8616,6 +8616,55 @@ T determinant(matrix<T,N,N> m)
}
}
+/// @param m The matrix.
+/// @return The determinant of the matrix.
+/// @category math
+__generic<T : __BuiltinIntegerType, let N : int>
+[__readNone]
+[require(glsl_hlsl_metal_spirv_wgsl)]
+T determinant(matrix<T,N,N> m)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "determinant";
+ case hlsl: __intrinsic_asm "determinant";
+ case metal: __intrinsic_asm "determinant";
+ case wgsl: __intrinsic_asm "determinant";
+ // SPIR-V doesn't support integer determinants, so we need to implement it manually
+ default:
+ static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices");
+ if (N == 1)
+ {
+ return m[0][0];
+ }
+ else if (N == 2)
+ {
+ return m[0][0] * m[1][1] - m[0][1] * m[1][0];
+ }
+ else if (N == 3)
+ {
+ return
+ m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
+ - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
+ + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
+ }
+ else// if (N == 4)
+ {
+ T a = m[2][2] * m[3][3] - m[2][3] * m[3][2];
+ T b = m[2][1] * m[3][3] - m[2][3] * m[3][1];
+ T c = m[2][1] * m[3][2] - m[2][2] * m[3][1];
+ T d = m[2][0] * m[3][3] - m[2][3] * m[3][0];
+ T e = m[2][0] * m[3][2] - m[2][2] * m[3][0];
+ T f = m[2][0] * m[3][1] - m[2][1] * m[3][0];
+ return
+ m[0][0] * (m[1][1] * a - m[1][2] * b + m[1][3] * c)
+ - m[0][1] * (m[1][0] * a - m[1][2] * d + m[1][3] * e)
+ + m[0][2] * (m[1][0] * b - m[1][1] * d + m[1][3] * f)
+ - m[0][3] * (m[1][0] * c - m[1][1] * e + m[1][2] * f);
+ }
+ }
+}
+
/// Barrier for device memory.
/// @category barrier
__glsl_extension(GL_KHR_memory_scope_semantics)
@@ -13720,10 +13769,8 @@ matrix<T, M, N> transpose(matrix<T, N, M> x)
{
case glsl: __intrinsic_asm "transpose";
case hlsl: __intrinsic_asm "transpose";
- case spirv: return spirv_asm {
- OpTranspose $$matrix<T, M, N> result $x
- };
case wgsl: __intrinsic_asm "transpose";
+ // SPIR-V doesn't support integer matrices, so transpose it manually
default:
matrix<T, M, N> result;
for (int r = 0; r < M; ++r)
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 0ce5d9f47..3dafda3be 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -2201,12 +2201,6 @@ DIAGNOSTIC(
DIAGNOSTIC(39999, Fatal, complationCeased, "compilation ceased")
DIAGNOSTIC(
- 38202,
- Error,
- matrixWithDisallowedElementTypeEncountered,
- "matrix with disallowed element type '$0' encountered")
-
-DIAGNOSTIC(
38203,
Error,
vectorWithDisallowedElementTypeEncountered,
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 2b6f1c821..bbed44c51 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -216,7 +216,7 @@ struct SpvInst : SpvInstParent
//
// > Word Count: The complete number of words taken by an instruction,
// > including the word holding the word count and opcode, and any optional
- // > operands. An instruction’s word count is the total space taken by the instruction.
+ // > operands. An instruction's word count is the total space taken by the instruction.
//
SpvWord wordCount = 1 + SpvWord(operandWordsCount);
@@ -360,7 +360,7 @@ struct SpvLiteralBits
// > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed
// > four per word, following the little-endian convention (i.e., the
// > first octet is in the lowest-order 8 bits of the word).
- // > The final word contains the string’s nul-termination character (0), and
+ // > The final word contains the string's nul-termination character (0), and
// > all contents past the end of the string in the final word are padded with 0.
// First work out the amount of words we'll need
@@ -2039,17 +2039,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
+ auto elementType = matrixType->getElementType();
+
+ // SPIR-V only supports floating-point matrices
+ // bool/int matrices should be lowered to
+ // arrays of vectors before reaching here
+ SLANG_ASSERT(!as<IRBoolType>(elementType));
+ SLANG_ASSERT(!as<IRIntType>(elementType));
+ SLANG_ASSERT(!as<IRUIntType>(elementType));
+
auto vectorSpvType = ensureVectorType(
- static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
+ static_cast<IRBasicType*>(elementType)->getBaseType(),
static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
const auto columnCount =
static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
- auto matrixSPVType = emitOpTypeMatrix(
- inst,
- vectorSpvType,
- SpvLiteralInteger::from32(int32_t(columnCount)));
- return matrixSPVType;
+ const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount));
+ SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv);
+ return matrixSpvType;
}
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
@@ -2621,7 +2628,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvWord arrayed =
inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed;
- // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored."
+ // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored."
SpvWord depth =
ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image
SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled
@@ -7767,12 +7774,40 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// Otherwise, operands are raw elements, we need to construct row vectors first,
// then construct matrix from row vectors.
List<SpvInst*> rowVectors;
- auto matrixType = as<IRMatrixType>(inst->getDataType());
- auto rowCount = getIntVal(matrixType->getRowCount());
- auto colCount = getIntVal(matrixType->getColumnCount());
+
+ IRIntegerValue rowCount;
+ IRIntegerValue colCount;
+ IRType* elementType;
+
+ // Data type can be either a matrix or an array of row vectors,
+ // depending on the legalization requirements
+ auto dataType = inst->getDataType();
+
+ if (auto matrixType = as<IRMatrixType>(dataType))
+ {
+ elementType = matrixType->getElementType();
+ rowCount = getIntVal(matrixType->getRowCount());
+ colCount = getIntVal(matrixType->getColumnCount());
+ }
+ else if (auto arrayType = as<IRArrayType>(dataType))
+ {
+ auto vectorType = as<IRVectorType>(arrayType->getElementType());
+ SLANG_ASSERT(vectorType);
+
+ elementType = vectorType->getElementType();
+ rowCount = getIntVal(arrayType->getElementCount());
+ colCount = getIntVal(vectorType->getElementCount());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("data type for makeMatrix operation is "
+ "expected be either a matrix or array type");
+ }
+
IRBuilder builder(inst);
builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ auto rowVectorType = builder.getVectorType(elementType, colCount);
+
List<IRInst*> colElements;
UInt index = 0;
for (IRIntegerValue j = 0; j < rowCount; j++)
@@ -7897,7 +7932,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
ArrayView<IRInst*> operands)
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
+ SLANG_ASSERT(elementType);
+
IRBasicType* basicType = as<IRBasicType>(elementType);
+ SLANG_ASSERT(basicType);
SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
if (opCode == SpvOpUndef)
@@ -7958,6 +7996,52 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
+ // Helper method to handle composite arithmetic operations for matrices and arrays
+ SpvInst* emitCompositeArithmetic(
+ SpvInstParent* parent,
+ IRInst* inst,
+ IRIntegerValue rowCount,
+ IRIntegerValue colCount,
+ IRType* elementType,
+ IRType* resultType,
+ bool isMatrixType)
+ {
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(elementType, colCount);
+ List<SpvInst*> rows;
+
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ bool shouldExtract =
+ isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr
+ : as<IRArrayType>(originalOperand->getDataType()) != nullptr;
+
+ if (shouldExtract)
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(
+ parent,
+ nullptr,
+ rowVectorType,
+ inst->getOp(),
+ inst->getOperandCount(),
+ operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, resultType, rows);
+ }
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
@@ -7965,36 +8049,38 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto rowCount = getIntVal(matrixType->getRowCount());
auto colCount = getIntVal(matrixType->getColumnCount());
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
- List<SpvInst*> rows;
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- List<IRInst*> operands;
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto originalOperand = inst->getOperand(j);
- if (as<IRMatrixType>(originalOperand->getDataType()))
- {
- auto operand = builder.emitElementExtract(originalOperand, i);
- emitLocalInst(parent, operand);
- operands.add(operand);
- }
- else
- {
- operands.add(originalOperand);
- }
- }
- rows.add(emitVectorOrScalarArithmetic(
- parent,
- nullptr,
- rowVectorType,
- inst->getOp(),
- inst->getOperandCount(),
- operands.getArrayView()));
- }
- return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
+ return emitCompositeArithmetic(
+ parent,
+ inst,
+ rowCount,
+ colCount,
+ matrixType->getElementType(),
+ inst->getDataType(),
+ true);
+ }
+ else if (const auto arrayType = as<IRArrayType>(inst->getDataType()))
+ {
+ // Arrays are expected here only as legalized int/uint/bool matrices (arrays of row vectors)
+ auto arrayElementType = arrayType->getElementType();
+ SLANG_ASSERT(as<IRVectorType>(arrayElementType));
+
+ auto vectorType = as<IRVectorType>(arrayElementType);
+ auto elementType = vectorType->getElementType();
+ SLANG_ASSERT(
+ as<IRBoolType>(elementType) || as<IRUIntType>(elementType) ||
+ as<IRIntType>(elementType));
+
+ auto rowCount = getIntVal(arrayType->getElementCount());
+ auto colCount = getIntVal(vectorType->getElementCount());
+
+ return emitCompositeArithmetic(
+ parent,
+ inst,
+ rowCount,
+ colCount,
+ elementType,
+ inst->getDataType(),
+ false);
}
Array<IRInst*, 4> operands;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index f40679bd9..067b5a551 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -61,6 +61,7 @@
#include "slang-ir-legalize-empty-array.h"
#include "slang-ir-legalize-global-values.h"
#include "slang-ir-legalize-image-subscript.h"
+#include "slang-ir-legalize-matrix-types.h"
#include "slang-ir-legalize-mesh-outputs.h"
#include "slang-ir-legalize-uniform-buffer-load.h"
#include "slang-ir-legalize-varying-params.h"
@@ -1334,6 +1335,7 @@ Result linkAndOptimizeIR(
legalizeEmptyTypes(targetProgram, irModule, sink);
}
+ legalizeMatrixTypes(targetProgram, irModule, sink);
legalizeVectorTypes(irModule, sink);
// Once specialization and type legalization have been performed,
diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp
index f2f7cdef2..24ba61fc6 100644
--- a/source/slang/slang-ir-legalize-binary-operator.cpp
+++ b/source/slang/slang-ir-legalize-binary-operator.cpp
@@ -176,93 +176,124 @@ void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target)
void legalizeLogicalAndOr(IRInst* inst)
{
- switch (inst->getOp())
+ auto op = inst->getOp();
+ if (op == kIROp_And || op == kIROp_Or)
{
- case kIROp_And:
- case kIROp_Or:
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Logical-AND and logical-OR take boolean types as their operands.
+ // If the operands are not boolean, legalize them by casting to boolean type.
+ //
+ SLANG_ASSERT(inst->getOperandCount() == 2);
+ for (UInt i = 0; i < 2; i++)
{
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
-
- // Logical-AND and logical-OR takes boolean types as its operands.
- // If they are not, legalize them by casting to boolean type.
- //
- SLANG_ASSERT(inst->getOperandCount() == 2);
- for (UInt i = 0; i < 2; i++)
- {
- auto operand = inst->getOperand(i);
- auto operandDataType = operand->getDataType();
+ auto operand = inst->getOperand(i);
+ auto operandDataType = operand->getDataType();
- if (auto vecType = as<IRVectorType>(operandDataType))
- {
- if (!as<IRBoolType>(vecType->getElementType()))
- {
- // Cast operand to vector<bool,N>
- auto elemCount = vecType->getElementCount();
- auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
- auto v = builder.emitCast(vb, operand);
- builder.replaceOperand(inst->getOperands() + i, v);
- }
- }
- else if (!as<IRBoolType>(operandDataType))
- {
- // Cast operand to bool
- auto s = builder.emitCast(builder.getBoolType(), operand);
- builder.replaceOperand(inst->getOperands() + i, s);
- }
- }
+ SLANG_ASSERT(
+ as<IRMatrixType>(operandDataType) || as<IRVectorType>(operandDataType) ||
+ as<IRArrayType>(operandDataType) || as<IRBoolType>(operandDataType));
- // Legalize the return type; mostly for SPIRV.
- // The return type of OpLogicalOr must be boolean type.
- // If not, we need to recreate the instruction with boolean return type.
- // Then, we have to cast it back to the original type so that other instrucitons that
- // use have the matching types.
- //
- auto dataType = inst->getDataType();
- auto lhs = inst->getOperand(0);
- auto rhs = inst->getOperand(1);
- IRInst* newInst = nullptr;
-
- if (auto vecType = as<IRVectorType>(dataType))
+ if (auto vecType = as<IRVectorType>(operandDataType))
{
if (!as<IRBoolType>(vecType->getElementType()))
{
- // Return type should be vector<bool,N>
+ // Cast operand to vector<bool,N>
auto elemCount = vecType->getElementCount();
auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
-
- if (inst->getOp() == kIROp_And)
- {
- newInst = builder.emitAnd(vb, lhs, rhs);
- }
- else
- {
- newInst = builder.emitOr(vb, lhs, rhs);
- }
- newInst = builder.emitCast(dataType, newInst);
+ auto v = builder.emitCast(vb, operand);
+ builder.replaceOperand(inst->getOperands() + i, v);
}
}
- else if (!as<IRBoolType>(dataType))
+ }
+
+ // Legalize the return type; mostly for SPIRV.
+ // The return type of OpLogicalOr must be boolean type.
+ // If not, we need to recreate the instruction with boolean return type.
+ // Then, we have to cast it back to the original type so that other instructions that
+ // use it have matching types.
+ //
+ auto dataType = inst->getDataType();
+ auto lhs = inst->getOperand(0);
+ auto rhs = inst->getOperand(1);
+ IRInst* newInst = nullptr;
+
+ SLANG_ASSERT(
+ as<IRMatrixType>(dataType) || as<IRVectorType>(dataType) || as<IRBoolType>(dataType) ||
+ as<IRArrayType>(dataType));
+ if (auto vecType = as<IRVectorType>(dataType))
+ {
+ if (!as<IRBoolType>(vecType->getElementType()))
{
- // Return type should be bool
+ // Return type should be vector<bool,N>
+ auto elemCount = vecType->getElementCount();
+ auto vb = builder.getVectorType(builder.getBoolType(), elemCount);
+
if (inst->getOp() == kIROp_And)
{
- newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs);
+ newInst = builder.emitAnd(vb, lhs, rhs);
}
else
{
- newInst = builder.emitOr(builder.getBoolType(), lhs, rhs);
+ newInst = builder.emitOr(vb, lhs, rhs);
}
newInst = builder.emitCast(dataType, newInst);
}
+ }
+ else if (auto arrayType = as<IRArrayType>(dataType))
+ {
+ // Handle lowered matrices (arrays of vectors)
+ auto arrayVecType = as<IRVectorType>(arrayType->getElementType());
+ SLANG_ASSERT(arrayVecType);
+
+ // At this point, lhs and rhs should already be converted to bool arrays
+ auto lhsArrayType = as<IRArrayType>(lhs->getDataType());
+ auto rhsArrayType = as<IRArrayType>(rhs->getDataType());
+ SLANG_ASSERT(lhsArrayType && rhsArrayType);
+
+ auto lhsVecType = as<IRVectorType>(lhsArrayType->getElementType());
+ auto rhsVecType = as<IRVectorType>(rhsArrayType->getElementType());
+ SLANG_ASSERT(lhsVecType && rhsVecType);
+
+ SLANG_ASSERT(
+ as<IRBoolType>(lhsVecType->getElementType()) &&
+ as<IRBoolType>(rhsVecType->getElementType()));
- if (newInst && inst != newInst)
+ auto arraySize = arrayType->getElementCount();
+ List<IRInst*> resultElements;
+
+ // Extract each vector from both arrays, perform AND/OR, collect results
+ for (IRIntegerValue i = 0; i < getIntVal(arraySize); i++)
{
- inst->replaceUsesWith(newInst);
- inst->removeAndDeallocate();
+ auto indexVal = builder.getIntValue(builder.getIntType(), i);
+ auto lhsElement = builder.emitElementExtract(lhs, indexVal);
+ auto rhsElement = builder.emitElementExtract(rhs, indexVal);
+
+ IRInst* resultElement;
+ if (inst->getOp() == kIROp_And)
+ {
+ resultElement =
+ builder.emitAnd(lhsElement->getDataType(), lhsElement, rhsElement);
+ }
+ else
+ {
+ resultElement =
+ builder.emitOr(lhsElement->getDataType(), lhsElement, rhsElement);
+ }
+ resultElements.add(resultElement);
}
+
+ // Construct the result array from the individual vector results
+ newInst =
+ builder.emitMakeArray(dataType, getIntVal(arraySize), resultElements.getBuffer());
+ }
+
+ if (newInst && inst != newInst)
+ {
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
}
- break;
}
for (auto child : inst->getModifiableChildren())
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
new file mode 100644
index 000000000..0b972b5bd
--- /dev/null
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -0,0 +1,141 @@
+#include "slang-ir-legalize-matrix-types.h"
+
+#include "slang-compiler.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+struct MatrixTypeLoweringContext
+{
+ TargetProgram* targetProgram;
+ IRModule* module;
+ DiagnosticSink* sink;
+
+ InstWorkList workList;
+ InstHashSet workListSet;
+
+ Dictionary<IRInst*, IRInst*> replacements;
+
+ MatrixTypeLoweringContext(TargetProgram* targetProgram, IRModule* module)
+ : targetProgram(targetProgram), module(module), workList(module), workListSet(module)
+ {
+ }
+
+ void addToWorkList(IRInst* inst)
+ {
+ for (auto ii = inst->getParent(); ii; ii = ii->getParent())
+ {
+ if (as<IRGeneric>(ii))
+ return;
+ }
+
+ if (workListSet.contains(inst))
+ return;
+
+ workList.add(inst);
+ workListSet.add(inst);
+ }
+
+ bool shouldLowerTarget()
+ {
+ auto target = targetProgram->getTargetReq()->getTarget();
+ switch (target)
+ {
+ case CodeGenTarget::SPIRV:
+ case CodeGenTarget::SPIRVAssembly:
+ case CodeGenTarget::GLSL:
+ case CodeGenTarget::WGSL:
+ case CodeGenTarget::WGSLSPIRV:
+ case CodeGenTarget::WGSLSPIRVAssembly:
+ return true;
+ default:
+ return false;
+ }
+ }
+
+ bool shouldLowerMatrixType(IRMatrixType* matrixType)
+ {
+ if (!shouldLowerTarget())
+ return false;
+
+ auto elementType = matrixType->getElementType();
+ return as<IRBoolType>(elementType) || as<IRUIntType>(elementType) ||
+ as<IRIntType>(elementType);
+ }
+
+ IRInst* getReplacement(IRInst* inst)
+ {
+ if (auto replacement = replacements.tryGetValue(inst))
+ return *replacement;
+
+ IRInst* newInst = inst;
+
+ if (auto matrixType = as<IRMatrixType>(inst))
+ {
+ if (shouldLowerMatrixType(matrixType))
+ {
+ // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C)
+ auto elementType = matrixType->getElementType();
+ auto rowCount = matrixType->getRowCount();
+ auto columnCount = matrixType->getColumnCount();
+
+ IRBuilder builder(matrixType);
+ builder.setInsertBefore(matrixType);
+
+ // Create the row vector type (one element per column): vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type for rows: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ newInst = arrayType;
+ }
+ }
+
+ replacements[inst] = newInst;
+ return newInst;
+ }
+
+ void processModule()
+ {
+ addToWorkList(module->getModuleInst());
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.remove(inst);
+
+ // Run this inst through the replacer
+ getReplacement(inst);
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+
+ // Apply all replacements
+ for (const auto& [old, replacement] : replacements)
+ {
+ if (old != replacement)
+ {
+ old->replaceUsesWith(replacement);
+ old->removeAndDeallocate();
+ }
+ }
+ }
+};
+
+void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink)
+{
+ MatrixTypeLoweringContext context(targetProgram, module);
+ context.sink = sink;
+ context.processModule();
+}
+
+} // namespace Slang \ No newline at end of file
diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h
new file mode 100644
index 000000000..418e80a83
--- /dev/null
+++ b/source/slang/slang-ir-legalize-matrix-types.h
@@ -0,0 +1,13 @@
+#pragma once
+
+namespace Slang
+{
+
+struct IRModule;
+class DiagnosticSink;
+class TargetProgram;
+
+// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets
+void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang \ No newline at end of file
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 1b11f8165..a1d043dff 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -453,6 +453,9 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type)
switch (type->getOp())
{
+ case kIROp_BoolType:
+ sb << "bool";
+ break;
case kIROp_FloatType:
sb << "float";
break;
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index bf5d8ed5d..156fe249f 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -598,21 +598,9 @@ void validateVectorsAndMatrices(
}
}
- // Verify that the element type is a floating point type, or an allowed integral type
- auto elementType = matrixType->getElementType();
- uint32_t allowedWidths = 0U;
- if (isCPUTarget(targetRequest))
- allowedWidths = 8U | 16U | 32U | 64U;
- else if (isCUDATarget(targetRequest))
- allowedWidths = 32U | 64U;
- else if (isD3DTarget(targetRequest))
- allowedWidths = 16U | 32U;
- validateVectorOrMatrixElementType(
- sink,
- matrixType->sourceLoc,
- elementType,
- allowedWidths,
- Diagnostics::matrixWithDisallowedElementTypeEncountered);
+ // Matrix element type validation removed to allow integer/bool matrices
+ // which will be lowered to arrays of vectors on targets that don't support them
+ // natively
}
else if (auto vectorType = as<IRVectorType>(globalInst))
{
diff --git a/tests/compute/integer-matrix-diagnostic.slang b/tests/compute/integer-matrix-diagnostic.slang
deleted file mode 100644
index bd69c28e4..000000000
--- a/tests/compute/integer-matrix-diagnostic.slang
+++ /dev/null
@@ -1,22 +0,0 @@
-// Check that using matrices with integer floating point type yields the correct diagnostic
-
-//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target glsl -entry computeMain -stage compute
-//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target metal -entry computeMain -stage compute
-//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute
-//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target wgsl -entry computeMain -stage compute
-
-cbuffer MatrixBuffer
-{
- // CHECK: error 38202
- int4x4 iMatrix;
-}
-
-//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name=outputBuffer
-RWStructuredBuffer<int4> outputBuffer;
-
-[numthreads(4, 1, 1)]
-void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
-{
- uint index = dispatchThreadID.x;
- outputBuffer[index] = iMatrix[0][0];
-} \ No newline at end of file
diff --git a/tests/spirv/matrix-bool-lowering.slang b/tests/spirv/matrix-bool-lowering.slang
new file mode 100644
index 000000000..63b7caacf
--- /dev/null
+++ b/tests/spirv/matrix-bool-lowering.slang
@@ -0,0 +1,114 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -xslang -emit-spirv-directly
+
+//TEST_INPUT:ubuffer(data=[1 0], stride=4):in,name inputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> inputBuffer;
+RWStructuredBuffer<int> outputBuffer;
+
+// Global bool constants to avoid constant folding
+static bool trueVal;
+static bool falseVal;
+
+struct matrixWrapper {
+ bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal);
+}
+
+bool elementAnd(bool2x2 matrix)
+{
+ return trueVal
+ && matrix[0][0]
+ && matrix[0][1]
+ && matrix[1][0]
+ && matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Load true/false values from input buffer to avoid constant folding
+ trueVal = inputBuffer[0] != 0;
+ falseVal = inputBuffer[1] != 0;
+
+ // Test bool matrix construction
+ bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal);
+ bool3x3 mat2 = bool3x3(
+ trueVal, falseVal, trueVal,
+ falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal
+ );
+ bool2x4 mat3 = bool2x4(
+ trueVal, falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal, falseVal
+ );
+
+ // Test bool matrix element access
+ bool val1 = mat1[0][0];
+ bool val2 = mat2[2][1];
+
+ // Test bool matrix row access
+ bool2 row = mat1[1];
+ bool3 row3 = mat2[0];
+
+ // Test logical operations
+ bool2x2 not_mat = !mat1;
+ bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal);
+
+ // Test element assignment
+ mat1[0][1] = trueVal;
+ mat2[1][2] = falseVal;
+
+ // Test passing bool matrices to functions
+ bool anded = elementAnd(mat1);
+
+ // Test structs with bool matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test any/all operations
+ bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal);
+ bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal);
+
+ bool test_all_true = all(all_true); // all elements true -> true
+ bool test_all_false = all(all_false); // all elements false -> false
+ bool test_all_mixed = all(mixed); // some elements false -> false
+ bool test_any_true = any(all_true); // some elements true -> true
+ bool test_any_false = any(all_false); // no elements true -> false
+ bool test_any_mixed = any(mixed); // some elements true -> true
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 0
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 0
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 1
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 0
+ outputBuffer[5] = not_mat[0][0];
+ // CHECK-NEXT: 0
+ outputBuffer[6] = and_mat[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[7] = mat1[0][1];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat3[0][1];
+ // CHECK-NEXT: 0
+ outputBuffer[9] = anded;
+ // CHECK-NEXT: 0
+ outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[11] = test_all_true;
+ // CHECK-NEXT: 1
+ outputBuffer[12] = test_all_false;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = test_all_mixed;
+ // CHECK-NEXT: 0
+ outputBuffer[14] = test_any_true;
+ // CHECK-NEXT: 1
+ outputBuffer[15] = test_any_false;
+ // CHECK-NEXT: 0
+ outputBuffer[16] = test_any_mixed;
+ // CHECK-NEXT: 1
+}
diff --git a/tests/spirv/matrix-integer-lowering.slang b/tests/spirv/matrix-integer-lowering.slang
new file mode 100644
index 000000000..518d0f78b
--- /dev/null
+++ b/tests/spirv/matrix-integer-lowering.slang
@@ -0,0 +1,189 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -vk -shaderobj -xslang -emit-spirv-directly -xslang -DTYPE=int
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -vk -shaderobj -xslang -emit-spirv-directly -xslang -DTYPE=uint
+
+#ifndef TYPE
+#define TYPE int
+#endif
+
+typealias m2x2 = matrix<TYPE, 2, 2>;
+typealias m2x3 = matrix<TYPE, 2, 3>;
+typealias m3x3 = matrix<TYPE, 3, 3>;
+typealias m2x4 = matrix<TYPE, 2, 4>;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<TYPE> outputBuffer;
+
+struct matrixWrapper {
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10);
+};
+
+TYPE elementAdd(m2x2 matrix)
+{
+ return matrix[0][0]
+ + matrix[0][1]
+ + matrix[1][0]
+ + matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Test matrix construction
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m3x3 mat2 = m3x3(
+ 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9
+ );
+ m2x4 mat3 = m2x4(
+ 10, 11, 12, 13,
+ 14, 15, 16, 17
+ );
+
+ // Test matrix element access
+ TYPE val1 = mat1[0][0];
+ TYPE val2 = mat2[2][1];
+
+ // Test matrix row access
+ vector<TYPE, 2> row = mat1[1];
+ vector<TYPE, 3> row3 = mat2[0];
+
+ // Test arithmetic operations
+ m2x2 mat5 = m2x2(2, 4, 6, 7);
+
+ m2x2 mat_scalar = 2 * mat1;
+ m2x2 mat_add = mat1 + mat5;
+ m2x2 mat_sub = mat5 - mat1;
+ m2x2 mat_mul = mat1 * mat5;
+
+ // Test passing matrices to functions
+ TYPE added = elementAdd(mat1);
+
+ // Test structs with matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test matrix intrinsic operations
+
+ // Test determinant for square matrices
+ m2x2 mat6 = m2x2(2, 1, 4, 3);
+ TYPE det2x2 = TYPE(determinant(mat6));
+ TYPE det3x3 = TYPE(determinant(mat2));
+
+ // Test transpose
+ matrix<TYPE, 2, 2> trans2x2 = transpose(mat1);
+ matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2);
+
+ // Test element-wise min/max
+ m2x2 mat_min = min(mat1, mat5);
+ m2x2 mat_max = max(mat1, mat5);
+
+ // Test all/any operations (these return bool, but we'll cast to TYPE for output)
+ m2x2 zero_mat = m2x2(0, 0, 0, 0);
+ m2x2 mixed_mat = m2x2(1, 0, 2, 0);
+
+ TYPE all_nonzero = TYPE(all(mat1));
+ TYPE all_zero = TYPE(all(zero_mat));
+ TYPE any_nonzero = TYPE(any(mixed_mat));
+ TYPE any_zero = TYPE(any(zero_mat));
+
+ // Test bit shift operations
+ m2x2 shift_mat = m2x2(1, 2, 4, 8);
+ m2x2 left_shift = shift_mat << 1;
+ m2x2 right_shift = shift_mat >> 1;
+
+ // Test comparison operations (these return bool matrices, cast to TYPE for output)
+ m2x2 comp_mat1 = m2x2(1, 3, 2, 4);
+ m2x2 comp_mat2 = m2x2(2, 2, 3, 3);
+
+ matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2;
+ matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2;
+ matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2;
+ matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2;
+ matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2;
+ matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2;
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 8
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 3
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 4
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 2
+ outputBuffer[5] = mat_scalar[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[6] = mat_add[0][0];
+ // CHECK-NEXT: 3
+ outputBuffer[7] = mat_sub[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat_mul[1][1];
+ // CHECK-NEXT: 28
+ outputBuffer[9] = added;
+ // CHECK-NEXT: 10
+ outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0];
+ // CHECK-NEXT: 5
+
+ // Matrix intrinsic operation results
+ outputBuffer[11] = det2x2;
+ // CHECK-NEXT: 2
+ outputBuffer[12] = det3x3;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = mat_min[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[14] = mat_min[1][1];
+ // CHECK-NEXT: 4
+ outputBuffer[15] = mat_max[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[16] = mat_max[1][1];
+ // CHECK-NEXT: 7
+ outputBuffer[17] = all_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[18] = all_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[19] = any_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[20] = any_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[21] = trans2x2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[22] = trans2x2[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[23] = trans2x3[0][0];
+ // CHECK-NEXT: 5
+
+ // Bit shift operation results
+ outputBuffer[24] = left_shift[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[25] = left_shift[0][1];
+ // CHECK-NEXT: 4
+ outputBuffer[26] = right_shift[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[27] = right_shift[1][1];
+ // CHECK-NEXT: 4
+
+ // Comparison operation results (bool matrices cast to TYPE)
+ outputBuffer[28] = TYPE(less_than[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[29] = TYPE(less_than[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[30] = TYPE(greater_than[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[31] = TYPE(greater_than[1][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[32] = TYPE(less_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[33] = TYPE(less_equal[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[34] = TYPE(greater_equal[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[35] = TYPE(greater_equal[1][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[36] = TYPE(equal_to[0][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[37] = TYPE(not_equal[0][0]);
+ // CHECK-NEXT: 1
+} \ No newline at end of file