summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/hlsl.meta.slang65
-rw-r--r--source/slang/slang-emit-spirv.cpp170
-rw-r--r--source/slang/slang-emit-wgsl.cpp16
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp435
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.h2
-rw-r--r--tests/compute/logic-no-short-circuit-evaluation.slang4
-rw-r--r--tests/glsl/matrix-bool-lowering.slang114
-rw-r--r--tests/glsl/matrix-integer-lowering.slang199
-rw-r--r--tests/metal/matrix-bool-lowering.slang119
-rw-r--r--tests/metal/matrix-integer-lowering.slang202
-rw-r--r--tests/spirv/matrix-bool-lowering.slang2
-rw-r--r--tests/spirv/matrix-integer-lowering.slang12
-rw-r--r--tests/wgsl/matrix-bool-lowering.slang114
-rw-r--r--tests/wgsl/matrix-integer-lowering.slang199
15 files changed, 1484 insertions, 172 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index aa494ec95..9fd5c8b6e 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -6481,14 +6481,10 @@ bool all(T x)
{
__target_switch
{
- default:
- __intrinsic_asm "bool($0)";
case hlsl:
__intrinsic_asm "all";
case metal:
__intrinsic_asm "all";
- case wgsl:
- __intrinsic_asm "all";
case spirv:
let zero = __default<T>();
if (__isInt<T>())
@@ -6505,6 +6501,8 @@ bool all(T x)
return __slang_noop_cast<bool>(x);
else
return false;
+ default:
+ __intrinsic_asm "bool($0)";
}
}
@@ -6550,9 +6548,17 @@ bool all(vector<T,N> x)
};
}
case wgsl:
+ // WGSL all() only works with boolean vectors
if (__isBool<T>())
- __intrinsic_asm "all";
- __intrinsic_asm "all(vec$N0<bool>($0))";
+ __intrinsic_asm "all($0)";
+ else
+ {
+ // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion
+ bool result = true;
+ for(int i = 0; i < N; ++i)
+ result = result && all(x[i]);
+ return result;
+ }
default:
bool result = true;
for(int i = 0; i < N; ++i)
@@ -6563,7 +6569,7 @@ bool all(vector<T,N> x)
__generic<T : __BuiltinType, let N : int, let M : int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_metal_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
bool all(matrix<T,N,M> x)
{
__target_switch
@@ -6655,7 +6661,8 @@ bool any(T x)
case metal:
__intrinsic_asm "any";
case wgsl:
- __intrinsic_asm "any";
+ // For scalars, any() doesn't exist in WGSL, just convert to bool
+ __intrinsic_asm "bool($0)";
case spirv:
let zero = __default<T>();
if (__isInt<T>())
@@ -6686,7 +6693,17 @@ bool any(vector<T, N> x)
case hlsl:
__intrinsic_asm "any";
case metal:
- __intrinsic_asm "any";
+ if (__isBool<T>())
+ __intrinsic_asm "any";
+ else
+ {
+ // For non-bool types, convert to bool vector first
+ // Metal's any() only works with bool vectors
+ bool result = false;
+ for(int i = 0; i < N; ++i)
+ result = result || any(x[i]);
+ return result;
+ }
case glsl:
__intrinsic_asm "any(bvec$N0($0))";
case spirv:
@@ -6714,7 +6731,17 @@ bool any(vector<T, N> x)
};
}
case wgsl:
- __intrinsic_asm "any";
+ // WGSL any() only works with boolean vectors
+ if (__isBool<T>())
+ __intrinsic_asm "any($0)";
+ else
+ {
+ // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion
+ bool result = false;
+ for(int i = 0; i < N; ++i)
+ result = result || any(x[i]);
+ return result;
+ }
default:
bool result = false;
for(int i = 0; i < N; ++i)
@@ -6725,7 +6752,7 @@ bool any(vector<T, N> x)
__generic<T : __BuiltinType, let N : int, let M : int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_spirv)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)]
bool any(matrix<T, N, M> x)
{
__target_switch
@@ -8626,11 +8653,8 @@ T determinant(matrix<T,N,N> m)
{
__target_switch
{
- case glsl: __intrinsic_asm "determinant";
case hlsl: __intrinsic_asm "determinant";
- case metal: __intrinsic_asm "determinant";
- case wgsl: __intrinsic_asm "determinant";
- // SPIR-V doesn't support integer determinants, so we need to implement it manually
+ // GLSL, WGSL, and SPIR-V don't support integer determinants for lowered matrices, so we need to implement it manually
default:
static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices");
if (N == 1)
@@ -13804,16 +13828,14 @@ matrix<T, M, N> transpose(matrix<T, N, M> x)
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
[PreferRecompute]
matrix<T, M, N> transpose(matrix<T, N, M> x)
{
__target_switch
{
- case glsl: __intrinsic_asm "transpose";
case hlsl: __intrinsic_asm "transpose";
- case wgsl: __intrinsic_asm "transpose";
- // SPIRV-V doenst't support integer matrices, so transpose it manually
+ // GLSL, WGSL, SPIR-V, and Metal don't support integer matrices when lowered, so transpose it manually
default:
matrix<T, M, N> result;
for (int r = 0; r < M; ++r)
@@ -13824,19 +13846,18 @@ matrix<T, M, N> transpose(matrix<T, N, M> x)
}
__generic<T : __BuiltinLogicalType, let N : int, let M : int>
[__readNone]
-[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)]
[PreferRecompute]
[OverloadRank(-1)]
matrix<T, M, N> transpose(matrix<T, N, M> x)
{
__target_switch
{
- case glsl: __intrinsic_asm "transpose";
case hlsl: __intrinsic_asm "transpose";
case spirv: return spirv_asm {
OpTranspose $$matrix<T, M, N> result $x
};
- case wgsl: __intrinsic_asm "transpose";
+ // GLSL, WGSL, and Metal don't support bool matrices when lowered, so transpose it manually
default:
matrix<T, M, N> result;
for (int r = 0; r < M; ++r)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index b1b4c4570..da2620856 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -360,7 +360,7 @@ struct SpvLiteralBits
// > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed
// > four per word, following the little-endian convention (i.e., the
// > first octet is in the lowest-order 8 bits of the word).
- // > The final word contains the string's nul-termination character (0), and
+ // > The final word contains the string’s nul-termination character (0), and
// > all contents past the end of the string in the final word are padded with 0.
// First work out the amount of words we'll need
@@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_MatrixType:
{
auto matrixType = static_cast<IRMatrixType*>(inst);
- auto elementType = matrixType->getElementType();
-
- // SPIR-V only supports floating-point matrices
- // bool/int matrices should be lowered to
- // arrays of vectors before reaching here
- SLANG_ASSERT(!as<IRBoolType>(elementType));
- SLANG_ASSERT(!as<IRIntType>(elementType));
- SLANG_ASSERT(!as<IRUIntType>(elementType));
-
auto vectorSpvType = ensureVectorType(
- static_cast<IRBasicType*>(elementType)->getBaseType(),
+ static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(),
static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(),
nullptr);
const auto columnCount =
static_cast<IRIntLit*>(matrixType->getRowCount())->getValue();
- const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount));
- SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv);
- return matrixSpvType;
+ auto matrixSPVType = emitOpTypeMatrix(
+ inst,
+ vectorSpvType,
+ SpvLiteralInteger::from32(int32_t(columnCount)));
+ return matrixSPVType;
}
case kIROp_ArrayType:
case kIROp_UnsizedArrayType:
@@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SpvWord arrayed =
inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed;
- // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored."
+ // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored."
SpvWord depth =
ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image
SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled
@@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// Otherwise, operands are raw elements, we need to construct row vectors first,
// then construct matrix from row vectors.
List<SpvInst*> rowVectors;
-
- IRIntegerValue rowCount;
- IRIntegerValue colCount;
- IRType* elementType;
-
- // Data type can be either matrix or vector depending on the
- // legalization requirements
- auto dataType = inst->getDataType();
-
- if (auto matrixType = as<IRMatrixType>(dataType))
- {
- elementType = matrixType->getElementType();
- rowCount = getIntVal(matrixType->getRowCount());
- colCount = getIntVal(matrixType->getColumnCount());
- }
- else if (auto arrayType = as<IRArrayType>(dataType))
- {
- auto vectorType = as<IRVectorType>(arrayType->getElementType());
- SLANG_ASSERT(vectorType);
-
- elementType = vectorType->getElementType();
- rowCount = getIntVal(arrayType->getElementCount());
- colCount = getIntVal(vectorType->getElementCount());
- }
- else
- {
- SLANG_UNEXPECTED("data type for makeMatrix operation is "
- "expected be either a matrix or array type");
- }
-
+ auto matrixType = cast<IRMatrixType>(inst->getDataType());
+ auto rowCount = getIntVal(matrixType->getRowCount());
+ auto colCount = getIntVal(matrixType->getColumnCount());
IRBuilder builder(inst);
builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(elementType, colCount);
-
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
List<IRInst*> colElements;
UInt index = 0;
for (IRIntegerValue j = 0; j < rowCount; j++)
@@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
ArrayView<IRInst*> operands)
{
IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType());
- SLANG_ASSERT(elementType);
-
IRBasicType* basicType = as<IRBasicType>(elementType);
- SLANG_ASSERT(basicType);
SpvOp opCode = _arithmeticOpCodeConvert(op, basicType);
if (opCode == SpvOpUndef)
@@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands");
}
- // Helper method to handle composite arithmetic operations for matrices and arrays
- SpvInst* emitCompositeArithmetic(
- SpvInstParent* parent,
- IRInst* inst,
- IRIntegerValue rowCount,
- IRIntegerValue colCount,
- IRType* elementType,
- IRType* resultType,
- bool isMatrixType)
- {
- IRBuilder builder(inst);
- builder.setInsertBefore(inst);
- auto rowVectorType = builder.getVectorType(elementType, colCount);
- List<SpvInst*> rows;
-
- for (IRIntegerValue i = 0; i < rowCount; i++)
- {
- List<IRInst*> operands;
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto originalOperand = inst->getOperand(j);
- bool shouldExtract =
- isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr
- : as<IRArrayType>(originalOperand->getDataType()) != nullptr;
-
- if (shouldExtract)
- {
- auto operand = builder.emitElementExtract(originalOperand, i);
- emitLocalInst(parent, operand);
- operands.add(operand);
- }
- else
- {
- operands.add(originalOperand);
- }
- }
- rows.add(emitVectorOrScalarArithmetic(
- parent,
- nullptr,
- rowVectorType,
- inst->getOp(),
- inst->getOperandCount(),
- operands.getArrayView()));
- }
- return emitCompositeConstruct(parent, inst, resultType, rows);
- }
SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst)
{
@@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
auto rowCount = getIntVal(matrixType->getRowCount());
auto colCount = getIntVal(matrixType->getColumnCount());
- return emitCompositeArithmetic(
- parent,
- inst,
- rowCount,
- colCount,
- matrixType->getElementType(),
- inst->getDataType(),
- true);
- }
- else if (const auto arrayType = as<IRArrayType>(inst->getDataType()))
- {
- // Only for legalization
- auto arrayElementType = arrayType->getElementType();
- SLANG_ASSERT(as<IRVectorType>(arrayElementType));
-
- auto vectorType = as<IRVectorType>(arrayElementType);
- auto elementType = vectorType->getElementType();
- SLANG_ASSERT(
- as<IRBoolType>(elementType) || as<IRUIntType>(elementType) ||
- as<IRIntType>(elementType));
-
- auto rowCount = getIntVal(arrayType->getElementCount());
- auto colCount = getIntVal(vectorType->getElementCount());
-
- return emitCompositeArithmetic(
- parent,
- inst,
- rowCount,
- colCount,
- elementType,
- inst->getDataType(),
- false);
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+ auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount);
+ List<SpvInst*> rows;
+ for (IRIntegerValue i = 0; i < rowCount; i++)
+ {
+ List<IRInst*> operands;
+ for (UInt j = 0; j < inst->getOperandCount(); j++)
+ {
+ auto originalOperand = inst->getOperand(j);
+ if (as<IRMatrixType>(originalOperand->getDataType()))
+ {
+ auto operand = builder.emitElementExtract(originalOperand, i);
+ emitLocalInst(parent, operand);
+ operands.add(operand);
+ }
+ else
+ {
+ operands.add(originalOperand);
+ }
+ }
+ rows.add(emitVectorOrScalarArithmetic(
+ parent,
+ nullptr,
+ rowVectorType,
+ inst->getOp(),
+ inst->getOperandCount(),
+ operands.getArrayView()));
+ }
+ return emitCompositeConstruct(parent, inst, inst->getDataType(), rows);
}
Array<IRInst*, 4> operands;
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index fbcb54d10..53c3aa487 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -1624,6 +1624,22 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
m_writer->emit(")");
return true;
}
+ case kIROp_Neg:
+ {
+ auto opType = inst->getOperand(0)->getDataType();
+ if (as<IRMatrixType>(opType) || as<IRVectorType>(opType))
+ {
+ // WGSL does not support negate operator on matrices and vectors,
+ // we should emit "(type(0) - op0)" instead.
+ m_writer->emit("(");
+ emitType(inst->getDataType());
+ m_writer->emit("(0) - ");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+ }
}
return false;
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index b548ef632..405bca5a2 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1339,7 +1339,10 @@ Result linkAndOptimizeIR(
}
legalizeMatrixTypes(targetProgram, irModule, sink);
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-MATRIX-LEGALIZATION");
+
legalizeVectorTypes(irModule, sink);
+ dumpIRIfEnabled(codeGenContext, irModule, "AFTER-VECTOR-LEGALIZATION");
// Once specialization and type legalization have been performed,
// we should perform some of our basic optimization steps again,
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
index 0b972b5bd..8c8cb0c84 100644
--- a/source/slang/slang-ir-legalize-matrix-types.cpp
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -1,6 +1,7 @@
#include "slang-ir-legalize-matrix-types.h"
#include "slang-compiler.h"
+#include "slang-ir-insts-enum.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"
#include "slang-ir.h"
@@ -50,6 +51,9 @@ struct MatrixTypeLoweringContext
case CodeGenTarget::WGSL:
case CodeGenTarget::WGSLSPIRV:
case CodeGenTarget::WGSLSPIRVAssembly:
+ case CodeGenTarget::Metal:
+ case CodeGenTarget::MetalLib:
+ case CodeGenTarget::MetalLibAssembly:
return true;
default:
return false;
@@ -66,33 +70,430 @@ struct MatrixTypeLoweringContext
as<IRIntType>(elementType);
}
- IRInst* getReplacement(IRInst* inst)
+ IRInst* legalizeMatrixTypeDeclaration(IRInst* inst)
{
- if (auto replacement = replacements.tryGetValue(inst))
- return *replacement;
+ auto matrixType = as<IRMatrixType>(inst);
+ if (shouldLowerMatrixType(matrixType))
+ {
+ // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C)
+ auto elementType = matrixType->getElementType();
+ auto rowCount = matrixType->getRowCount();
+ auto columnCount = matrixType->getColumnCount();
- IRInst* newInst = inst;
+ IRBuilder builder(matrixType);
+ builder.setInsertBefore(matrixType);
+
+ // Create vector type for columns: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type for rows: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ return arrayType;
+ }
+ return inst;
+ }
+
+ IRInst* legalizeMakeMatrix(IRInst* inst)
+ {
+ auto makeMatrix = as<IRMakeMatrix>(inst);
+ auto matrixType = as<IRMatrixType>(makeMatrix->getDataType());
+
+ SLANG_ASSERT(matrixType && "Matrix type is expected");
+ SLANG_ASSERT(
+ shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization");
+
+ // Lower makeMatrix to makeArray of makeVectors
+ auto elementType = matrixType->getElementType();
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(matrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
- if (auto matrixType = as<IRMatrixType>(inst))
+ IRBuilder builder(makeMatrix);
+ builder.setInsertBefore(makeMatrix);
+
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ // Group operands into rows and create vectors
+ List<IRInst*> rowVectors;
+ UInt operandIndex = 0;
+
+ // Assert that we have the expected number of operands
+ SLANG_ASSERT(
+ makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) &&
+ "makeMatrix operand count must match matrix dimensions");
+
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
{
- if (shouldLowerMatrixType(matrixType))
+ List<IRInst*> rowElements;
+ for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
{
- // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C)
- auto elementType = matrixType->getElementType();
- auto rowCount = matrixType->getRowCount();
- auto columnCount = matrixType->getColumnCount();
+ SLANG_ASSERT(
+ operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds");
+ rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex)));
+ operandIndex++;
+ }
+
+ SLANG_ASSERT(
+ rowElements.getCount() == columnCount->getValue() &&
+ "Row elements count must match column count");
+ auto rowVector = builder.emitMakeVector(vectorType, rowElements);
+ rowVectors.add(rowVector);
+ }
+
+ SLANG_ASSERT(
+ rowVectors.getCount() == rowCount->getValue() &&
+ "Row vectors count must match matrix row count");
+ return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer());
+ }
+
+ IRInst* legalizeMatrixMatrixBinaryOperation(
+ IRBuilder& builder,
+ IRInst* legalizedA,
+ IRInst* legalizedB,
+ IRMatrixType* resultMatrixType,
+ IROp binaryOp)
+ {
+ auto elementType = resultMatrixType->getElementType();
+ auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
- IRBuilder builder(matrixType);
- builder.setInsertBefore(matrixType);
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
- // Create vector type for columns: vector<T, C>
- auto vectorType = builder.getVectorType(elementType, columnCount);
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
- // Create array type for rows: vector<T, C>[R]
- auto arrayType = builder.getArrayType(vectorType, rowCount);
+ // Extract vectors from both arrays and apply binary operation
+ List<IRInst*> resultVectors;
+
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ // Extract the row vector from each operand array
+ auto rowIndexInst = builder.getIntValue(builder.getIntType(), row);
+ auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst);
+ auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst);
- newInst = arrayType;
+ // Apply the binary operation to the vectors
+ IRInst* args[] = {vectorA, vectorB};
+ auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args);
+
+ resultVectors.add(resultVector);
+ }
+
+ // Create the result array from the vectors
+ return builder.emitMakeArray(
+ arrayType,
+ resultVectors.getCount(),
+ resultVectors.getBuffer());
+ }
+
+
+ template<bool matrixIsFirst>
+ IRInst* legalizeMatrixMixedBinaryOperation(
+ IRBuilder& builder,
+ IRInst* legalizedMatrix,
+ IRInst* legalizedOther,
+ IRMatrixType* resultMatrixType,
+ IROp binaryOp)
+ {
+ // Verify that the other operand is either a vector or scalar type
+ auto otherType = legalizedOther->getDataType();
+ auto otherVectorType = as<IRVectorType>(otherType);
+ auto otherBasicType = as<IRBasicType>(otherType);
+ SLANG_ASSERT(
+ (otherVectorType || otherBasicType) && "Other operand must be vector or scalar type");
+
+ auto elementType = resultMatrixType->getElementType();
+ auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
+
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ // Extract vectors from matrix array and apply binary operation with other operand
+ List<IRInst*> resultVectors;
+
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ // Extract the row vector from matrix array
+ auto rowIndexInst = builder.getIntValue(builder.getIntType(), row);
+ auto matrixRowVector = builder.emitElementExtract(legalizedMatrix, rowIndexInst);
+
+ // Apply the binary operation between matrix row vector and other operand
+ IRInst* args[2];
+ if constexpr (matrixIsFirst)
+ {
+ args[0] = matrixRowVector;
+ args[1] = legalizedOther;
}
+ else
+ {
+ args[0] = legalizedOther;
+ args[1] = matrixRowVector;
+ }
+ auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args);
+
+ resultVectors.add(resultVector);
+ }
+
+ // Create the result array from the vectors
+ return builder.emitMakeArray(
+ arrayType,
+ resultVectors.getCount(),
+ resultVectors.getBuffer());
+ }
+
+ IRInst* legalizeBinaryOperation(IRInst* inst, IROp binaryOp)
+ {
+ IRInst* opdA = inst->getOperand(0);
+ IRInst* opdB = inst->getOperand(1);
+
+ // Check what types we're dealing with
+ auto typeA = opdA->getDataType();
+ auto typeB = opdB->getDataType();
+
+ auto matrixTypeA = as<IRMatrixType>(typeA);
+ auto matrixTypeB = as<IRMatrixType>(typeB);
+
+ bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA);
+ bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB);
+
+ // Get the result matrix type to determine dimensions
+ auto resultMatrixType = as<IRMatrixType>(inst->getDataType());
+ SLANG_ASSERT(resultMatrixType && "Binary operation should have matrix result type");
+ SLANG_ASSERT(
+ shouldLowerMatrixType(resultMatrixType) &&
+ "Result matrix type should need legalization");
+
+ // Create IRBuilder at the top level
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Get legalized operands once
+ IRInst* legalizedA = getReplacement(opdA);
+ IRInst* legalizedB = getReplacement(opdB);
+
+ if (shouldLowerA && shouldLowerB)
+ {
+ return legalizeMatrixMatrixBinaryOperation(
+ builder,
+ legalizedA,
+ legalizedB,
+ resultMatrixType,
+ binaryOp);
+ }
+ else if (shouldLowerA && !shouldLowerB)
+ {
+ return legalizeMatrixMixedBinaryOperation<true>(
+ builder,
+ legalizedA,
+ legalizedB,
+ resultMatrixType,
+ binaryOp);
+ }
+ else if (!shouldLowerA && shouldLowerB)
+ {
+ return legalizeMatrixMixedBinaryOperation<false>(
+ builder,
+ legalizedB,
+ legalizedA,
+ resultMatrixType,
+ binaryOp);
+ }
+
+ // Neither operand is a matrix that needs lowering, shouldn't reach here
+ SLANG_UNREACHABLE("legalizeBinaryOperation called but no matrix operand needs lowering");
+ }
+
+ IRInst* legalizeComparisonOperation(IRInst* inst, IROp comparisonOp)
+ {
+ IRInst* opdA = inst->getOperand(0);
+ IRInst* opdB = inst->getOperand(1);
+
+ // Check what types we're dealing with
+ auto typeA = opdA->getDataType();
+ auto typeB = opdB->getDataType();
+
+ auto matrixTypeA = as<IRMatrixType>(typeA);
+ auto matrixTypeB = as<IRMatrixType>(typeB);
+
+ bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA);
+ bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB);
+
+ // Only matrix-matrix comparisons are supported
+ SLANG_ASSERT(
+ shouldLowerA && shouldLowerB &&
+ "Comparison operations only supported between matrices that need lowering");
+
+ // Create IRBuilder at the top level
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Get legalized operands
+ IRInst* legalizedA = getReplacement(opdA);
+ IRInst* legalizedB = getReplacement(opdB);
+
+ auto rowCount = as<IRIntLit>(matrixTypeA->getRowCount());
+ auto columnCount = as<IRIntLit>(matrixTypeA->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
+
+ // Create boolean vector type for rows: vector<bool, C>
+ auto boolType = builder.getBoolType();
+ auto boolVectorType = builder.getVectorType(boolType, columnCount);
+
+ // Create array type: vector<bool, C>[R]
+ auto boolArrayType = builder.getArrayType(boolVectorType, rowCount);
+
+ // Extract vectors from both arrays and apply comparison operation
+ List<IRInst*> resultVectors;
+
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ // Extract the row vector from each operand array
+ auto rowIndexInst = builder.getIntValue(builder.getIntType(), row);
+ auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst);
+ auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst);
+
+ // Apply the comparison operation to the vectors
+ IRInst* args[] = {vectorA, vectorB};
+ auto resultVector = builder.emitIntrinsicInst(boolVectorType, comparisonOp, 2, args);
+
+ resultVectors.add(resultVector);
+ }
+
+ // Create the result array from the vectors
+ return builder.emitMakeArray(
+ boolArrayType,
+ resultVectors.getCount(),
+ resultVectors.getBuffer());
+ }
+
+ IRInst* legalizeUnaryOperation(IRInst* inst, IROp unaryOp)
+ {
+ IRInst* operand = inst->getOperand(0);
+
+ // Get the legalized operand (should be an array of vectors)
+ IRInst* legalizedOperand = getReplacement(operand);
+
+ // Get the result matrix type to determine dimensions
+ auto resultMatrixType = as<IRMatrixType>(inst->getDataType());
+ SLANG_ASSERT(resultMatrixType && "Unary operation should have matrix result type");
+ SLANG_ASSERT(
+ shouldLowerMatrixType(resultMatrixType) &&
+ "Result matrix type should need legalization");
+
+ auto elementType = resultMatrixType->getElementType();
+ auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
+
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ // Extract vectors from array and apply unary operation
+ List<IRInst*> resultVectors;
+
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ // Extract the row vector from operand array
+ auto rowIndexInst = builder.getIntValue(builder.getIntType(), row);
+ auto vector = builder.emitElementExtract(legalizedOperand, rowIndexInst);
+
+ // Apply the unary operation to the vector
+ IRInst* args[] = {vector};
+ auto resultVector = builder.emitIntrinsicInst(vectorType, unaryOp, 1, args);
+
+ resultVectors.add(resultVector);
+ }
+
+ // Create the result array from the vectors
+ return builder.emitMakeArray(
+ arrayType,
+ resultVectors.getCount(),
+ resultVectors.getBuffer());
+ }
+
+ IRInst* legalizeMatrixProducingInstruction(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_MakeMatrix:
+ return legalizeMakeMatrix(inst);
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ return legalizeBinaryOperation(inst, inst->getOp());
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ return legalizeComparisonOperation(inst, inst->getOp());
+ case kIROp_Not:
+ case kIROp_BitNot:
+ case kIROp_Neg:
+ return legalizeUnaryOperation(inst, inst->getOp());
+ default:
+ break;
+ }
+
+ return inst;
+ }
+
+ IRInst* getReplacement(IRInst* inst)
+ {
+ if (auto replacement = replacements.tryGetValue(inst))
+ return *replacement;
+
+ IRInst* newInst = inst;
+ if (as<IRMatrixType>(inst))
+ newInst = legalizeMatrixTypeDeclaration(inst);
+
+ IRType* resultType = inst->getDataType();
+ if (auto matrixType = as<IRMatrixType>(resultType))
+ {
+ if (shouldLowerMatrixType(matrixType))
+ newInst = legalizeMatrixProducingInstruction(inst);
}
replacements[inst] = newInst;
diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h
index 418e80a83..a2e71a402 100644
--- a/source/slang/slang-ir-legalize-matrix-types.h
+++ b/source/slang/slang-ir-legalize-matrix-types.h
@@ -7,7 +7,7 @@ struct IRModule;
class DiagnosticSink;
class TargetProgram;
-// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets
+// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, GLSL, and Metal targets
void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink);
} // namespace Slang \ No newline at end of file
diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang
index ea2b7a0c3..342a11f28 100644
--- a/tests/compute/logic-no-short-circuit-evaluation.slang
+++ b/tests/compute/logic-no-short-circuit-evaluation.slang
@@ -32,7 +32,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
//SM5:(all({{.*}}&&
//HLSL2018:(all({{.*}}&&
//SM6:(all(and(
- //WGS:(all(select(vec2<bool>(false),
+ //WGS:(all((select(vec2<bool>(false),
//MTL:(all({{.*}}&&
if (all(bool2(index >= 1) && assignFunc(index)))
{
@@ -54,7 +54,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
//SM5:(all({{.*}}?{{.*}}:
//HLSL2018:(all({{.*}}?{{.*}}:
//SM6:(all(select(
- //WGS:(all(select(vec2<bool>(false),
+ //WGS:(all((select(vec2<bool>(false),
//MTL:(all(select(bool2(false)
if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false)))
{
diff --git a/tests/glsl/matrix-bool-lowering.slang b/tests/glsl/matrix-bool-lowering.slang
new file mode 100644
index 000000000..9f2ad913f
--- /dev/null
+++ b/tests/glsl/matrix-bool-lowering.slang
@@ -0,0 +1,114 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -emit-spirv-via-glsl -shaderobj
+
+//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> inputBuffer;
+RWStructuredBuffer<int> outputBuffer;
+
+// Global bool constants to avoid constant folding
+static bool trueVal;
+static bool falseVal;
+
+struct matrixWrapper {
+ bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal);
+}
+
+bool elementAnd(bool2x2 matrix)
+{
+ return trueVal
+ && matrix[0][0]
+ && matrix[0][1]
+ && matrix[1][0]
+ && matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Load true/false values from input buffer to avoid constant folding
+ trueVal = inputBuffer[0] != 0;
+ falseVal = inputBuffer[1] != 0;
+
+ // Test bool matrix construction
+ bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal);
+ bool3x3 mat2 = bool3x3(
+ trueVal, falseVal, trueVal,
+ falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal
+ );
+ bool2x4 mat3 = bool2x4(
+ trueVal, falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal, falseVal
+ );
+
+ // Test bool matrix element access
+ bool val1 = mat1[0][0];
+ bool val2 = mat2[2][1];
+
+ // Test bool matrix row access
+ bool2 row = mat1[1];
+ bool3 row3 = mat2[0];
+
+ // Test logical operations
+ bool2x2 not_mat = !mat1;
+ bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal);
+
+ // Test element assignment
+ mat1[0][1] = trueVal;
+ mat2[1][2] = falseVal;
+
+ // Test passing bool matrices to functions
+ bool anded = elementAnd(mat1);
+
+ // Test structs with bool matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test any/all operations
+ bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal);
+ bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal);
+
+ bool test_all_true = all(all_true); // all elements true -> true
+ bool test_all_false = all(all_false); // all elements false -> false
+ bool test_all_mixed = all(mixed); // some elements false -> false
+ bool test_any_true = any(all_true); // some elements true -> true
+ bool test_any_false = any(all_false); // no elements true -> false
+ bool test_any_mixed = any(mixed); // some elements true -> true
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 0
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 0
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 1
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 0
+ outputBuffer[5] = not_mat[0][0];
+ // CHECK-NEXT: 0
+ outputBuffer[6] = and_mat[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[7] = mat1[0][1];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat3[0][1];
+ // CHECK-NEXT: 0
+ outputBuffer[9] = anded;
+ // CHECK-NEXT: 0
+ outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[11] = test_all_true;
+ // CHECK-NEXT: 1
+ outputBuffer[12] = test_all_false;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = test_all_mixed;
+ // CHECK-NEXT: 0
+ outputBuffer[14] = test_any_true;
+ // CHECK-NEXT: 1
+ outputBuffer[15] = test_any_false;
+ // CHECK-NEXT: 0
+ outputBuffer[16] = test_any_mixed;
+ // CHECK-NEXT: 1
+} \ No newline at end of file
diff --git a/tests/glsl/matrix-integer-lowering.slang b/tests/glsl/matrix-integer-lowering.slang
new file mode 100644
index 000000000..4d6033d79
--- /dev/null
+++ b/tests/glsl/matrix-integer-lowering.slang
@@ -0,0 +1,199 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=int
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=uint
+
+#ifndef TYPE
+#define TYPE int
+#endif
+
+typealias m2x2 = matrix<TYPE, 2, 2>;
+typealias m2x3 = matrix<TYPE, 2, 3>;
+typealias m3x3 = matrix<TYPE, 3, 3>;
+typealias m2x4 = matrix<TYPE, 2, 4>;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer
+RWStructuredBuffer<TYPE> outputBuffer;
+RWStructuredBuffer<TYPE> expectedBuffer;
+
+struct matrixWrapper {
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10);
+};
+
+TYPE elementAdd(m2x2 matrix)
+{
+ return matrix[0][0]
+ + matrix[0][1]
+ + matrix[1][0]
+ + matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Test matrix construction
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m3x3 mat2 = m3x3(
+ 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9
+ );
+ m2x4 mat3 = m2x4(
+ 10, 11, 12, 13,
+ 14, 15, 16, 17
+ );
+
+ // Test matrix element access
+ TYPE val1 = mat1[0][0];
+ TYPE val2 = mat2[2][1];
+
+ // Test matrix row access
+ vector<TYPE, 2> row = mat1[1];
+ vector<TYPE, 3> row3 = mat2[0];
+
+ // Test arithmetic operations
+ m2x2 mat5 = m2x2(2, 4, 6, 7);
+
+ m2x2 mat_scalar = 2 * mat1;
+ m2x2 mat_add = mat1 + mat5;
+ m2x2 mat_sub = mat5 - mat1;
+ m2x2 mat_mul = mat1 * mat5;
+
+ // Test passing matrices to functions
+ TYPE added = elementAdd(mat1);
+
+ // Test structs with matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test matrix intrinsic operations
+
+ // Test determinant for square matrices
+ m2x2 mat6 = m2x2(2, 1, 4, 3);
+ TYPE det2x2 = TYPE(determinant(mat6));
+ TYPE det3x3 = TYPE(determinant(mat2));
+
+ // Test transpose
+ matrix<TYPE, 2, 2> trans2x2 = transpose(mat1);
+ matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2);
+
+ // Test element-wise min/max
+ m2x2 mat_min = min(mat1, mat5);
+ m2x2 mat_max = max(mat1, mat5);
+
+ // Test all/any operations (these return bool, but we'll cast to TYPE for output)
+ m2x2 zero_mat = m2x2(0, 0, 0, 0);
+ m2x2 mixed_mat = m2x2(1, 0, 2, 0);
+
+ TYPE all_nonzero = TYPE(all(mat1));
+ TYPE all_zero = TYPE(all(zero_mat));
+ TYPE any_nonzero = TYPE(any(mixed_mat));
+ TYPE any_zero = TYPE(any(zero_mat));
+
+ // Test bit shift operations
+ m2x2 shift_mat = m2x2(1, 2, 4, 8);
+ m2x2 left_shift = shift_mat << 1;
+ m2x2 right_shift = shift_mat >> 1;
+
+ // Test comparison operations (these return bool matrices, cast to TYPE for output)
+ m2x2 comp_mat1 = m2x2(1, 3, 2, 4);
+ m2x2 comp_mat2 = m2x2(2, 2, 3, 3);
+
+ matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2;
+ matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2;
+ matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2;
+ matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2;
+ matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2;
+ matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2;
+
+ // Test matrix negation operations
+ m2x2 neg_mat = m2x2(1, -2, 3, -4);
+ m2x2 negated = -neg_mat;
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 8
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 3
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 4
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 2
+ outputBuffer[5] = mat_scalar[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[6] = mat_add[0][0];
+ // CHECK-NEXT: 3
+ outputBuffer[7] = mat_sub[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat_mul[1][1];
+ // CHECK-NEXT: 28
+ outputBuffer[9] = added;
+ // CHECK-NEXT: 10
+ outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0];
+ // CHECK-NEXT: 5
+
+ // Matrix intrinsic operation results
+ outputBuffer[11] = det2x2;
+ // CHECK-NEXT: 2
+ outputBuffer[12] = det3x3;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = mat_min[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[14] = mat_min[1][1];
+ // CHECK-NEXT: 4
+ outputBuffer[15] = mat_max[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[16] = mat_max[1][1];
+ // CHECK-NEXT: 7
+ outputBuffer[17] = all_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[18] = all_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[19] = any_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[20] = any_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[21] = trans2x2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[22] = trans2x2[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[23] = trans2x3[0][0];
+ // CHECK-NEXT: 5
+
+ // Bit shift operation results
+ outputBuffer[24] = left_shift[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[25] = left_shift[0][1];
+ // CHECK-NEXT: 4
+ outputBuffer[26] = right_shift[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[27] = right_shift[1][1];
+ // CHECK-NEXT: 4
+
+ // Comparison operation results (bool matrices cast to TYPE)
+ outputBuffer[28] = TYPE(less_than[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[29] = TYPE(less_than[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[30] = TYPE(greater_than[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[31] = TYPE(greater_than[1][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[32] = TYPE(less_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[33] = TYPE(less_equal[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[34] = TYPE(greater_equal[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[35] = TYPE(greater_equal[1][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[36] = TYPE(equal_to[0][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[37] = TYPE(not_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]);
+ // CHECK-NEXT: 1
+ outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]);
+ // CHECK-NEXT: 1
+} \ No newline at end of file
diff --git a/tests/metal/matrix-bool-lowering.slang b/tests/metal/matrix-bool-lowering.slang
new file mode 100644
index 000000000..4248bb573
--- /dev/null
+++ b/tests/metal/matrix-bool-lowering.slang
@@ -0,0 +1,119 @@
+//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain
+//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -mtl -shaderobj
+
+//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> inputBuffer;
+RWStructuredBuffer<int> outputBuffer;
+
+// Global bool constants to avoid constant folding
+static bool trueVal;
+static bool falseVal;
+
+struct matrixWrapper {
+ bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal);
+}
+
+bool elementAnd(bool2x2 matrix)
+{
+ return trueVal
+ && matrix[0][0]
+ && matrix[0][1]
+ && matrix[1][0]
+ && matrix[1][1];
+}
+
+// METAL: array<bool2, int(2)>
+// METALLIB: @computeMain
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Load true/false values from input buffer to avoid constant folding
+ trueVal = inputBuffer[0] != 0;
+ falseVal = inputBuffer[1] != 0;
+
+ // Test bool matrix construction
+ bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal);
+ bool3x3 mat2 = bool3x3(
+ trueVal, falseVal, trueVal,
+ falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal
+ );
+ bool2x4 mat3 = bool2x4(
+ trueVal, falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal, falseVal
+ );
+
+ // Test bool matrix element access
+ bool val1 = mat1[0][0];
+ bool val2 = mat2[2][1];
+
+ // Test bool matrix row access
+ bool2 row = mat1[1];
+ bool3 row3 = mat2[0];
+
+ // Test logical operations
+ bool2x2 not_mat = !mat1;
+ bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal);
+
+ // Test element assignment
+ mat1[0][1] = trueVal;
+ mat2[1][2] = falseVal;
+
+ // Test passing bool matrices to functions
+ bool anded = elementAnd(mat1);
+
+ // Test structs with bool matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test any/all operations
+ bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal);
+ bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal);
+
+ bool test_all_true = all(all_true); // all elements true -> true
+ bool test_all_false = all(all_false); // all elements false -> false
+ bool test_all_mixed = all(mixed); // some elements false -> false
+ bool test_any_true = any(all_true); // some elements true -> true
+ bool test_any_false = any(all_false); // no elements true -> false
+ bool test_any_mixed = any(mixed); // some elements true -> true
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 0
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 0
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 1
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 0
+ outputBuffer[5] = not_mat[0][0];
+ // CHECK-NEXT: 0
+ outputBuffer[6] = and_mat[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[7] = mat1[0][1];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat3[0][1];
+ // CHECK-NEXT: 0
+ outputBuffer[9] = anded;
+ // CHECK-NEXT: 0
+ outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[11] = test_all_true;
+ // CHECK-NEXT: 1
+ outputBuffer[12] = test_all_false;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = test_all_mixed;
+ // CHECK-NEXT: 0
+ outputBuffer[14] = test_any_true;
+ // CHECK-NEXT: 1
+ outputBuffer[15] = test_any_false;
+ // CHECK-NEXT: 0
+ outputBuffer[16] = test_any_mixed;
+ // CHECK-NEXT: 1
+} \ No newline at end of file
diff --git a/tests/metal/matrix-integer-lowering.slang b/tests/metal/matrix-integer-lowering.slang
new file mode 100644
index 000000000..04aec5a7c
--- /dev/null
+++ b/tests/metal/matrix-integer-lowering.slang
@@ -0,0 +1,202 @@
+//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain -DTYPE=int
+//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain -DTYPE=int
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=int
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=uint
+
+#ifndef TYPE
+#define TYPE int
+#endif
+
+typealias m2x2 = matrix<TYPE, 2, 2>;
+typealias m2x3 = matrix<TYPE, 2, 3>;
+typealias m3x3 = matrix<TYPE, 3, 3>;
+typealias m2x4 = matrix<TYPE, 2, 4>;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer
+RWStructuredBuffer<TYPE> outputBuffer;
+RWStructuredBuffer<TYPE> expectedBuffer;
+
+struct matrixWrapper {
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10);
+};
+
+TYPE elementAdd(m2x2 matrix)
+{
+ return matrix[0][0]
+ + matrix[0][1]
+ + matrix[1][0]
+ + matrix[1][1];
+}
+
+// METAL: array<{{(int|uint)}}2, int(2)>
+// METALLIB: @computeMain
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Test matrix construction
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m3x3 mat2 = m3x3(
+ 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9
+ );
+ m2x4 mat3 = m2x4(
+ 10, 11, 12, 13,
+ 14, 15, 16, 17
+ );
+
+ // Test matrix element access
+ TYPE val1 = mat1[0][0];
+ TYPE val2 = mat2[2][1];
+
+ // Test matrix row access
+ vector<TYPE, 2> row = mat1[1];
+ vector<TYPE, 3> row3 = mat2[0];
+
+ // Test arithmetic operations
+ m2x2 mat5 = m2x2(2, 4, 6, 7);
+
+ m2x2 mat_scalar = 2 * mat1;
+ m2x2 mat_add = mat1 + mat5;
+ m2x2 mat_sub = mat5 - mat1;
+ m2x2 mat_mul = mat1 * mat5;
+
+ // Test passing matrices to functions
+ TYPE added = elementAdd(mat1);
+
+ // Test structs with matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test matrix intrinsic operations
+
+ // Test determinant for square matrices
+ m2x2 mat6 = m2x2(2, 1, 4, 3);
+ TYPE det2x2 = TYPE(determinant(mat6));
+ TYPE det3x3 = TYPE(determinant(mat2));
+
+ // Test transpose
+ matrix<TYPE, 2, 2> trans2x2 = transpose(mat1);
+ matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2);
+
+ // Test element-wise min/max
+ m2x2 mat_min = min(mat1, mat5);
+ m2x2 mat_max = max(mat1, mat5);
+
+ // Test all/any operations (these return bool, but we'll cast to TYPE for output)
+ m2x2 zero_mat = m2x2(0, 0, 0, 0);
+ m2x2 mixed_mat = m2x2(1, 0, 2, 0);
+
+ TYPE all_nonzero = TYPE(all(mat1));
+ TYPE all_zero = TYPE(all(zero_mat));
+ TYPE any_nonzero = TYPE(any(mixed_mat));
+ TYPE any_zero = TYPE(any(zero_mat));
+
+ // Test bit shift operations
+ m2x2 shift_mat = m2x2(1, 2, 4, 8);
+ m2x2 left_shift = shift_mat << 1;
+ m2x2 right_shift = shift_mat >> 1;
+
+ // Test comparison operations (these return bool matrices, cast to TYPE for output)
+ m2x2 comp_mat1 = m2x2(1, 3, 2, 4);
+ m2x2 comp_mat2 = m2x2(2, 2, 3, 3);
+
+ matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2;
+ matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2;
+ matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2;
+ matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2;
+ matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2;
+ matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2;
+
+ // Test matrix negation operations
+ m2x2 neg_mat = m2x2(1, -2, 3, -4);
+ m2x2 negated = -neg_mat;
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 8
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 3
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 4
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 2
+ outputBuffer[5] = mat_scalar[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[6] = mat_add[0][0];
+ // CHECK-NEXT: 3
+ outputBuffer[7] = mat_sub[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat_mul[1][1];
+ // CHECK-NEXT: 28
+ outputBuffer[9] = added;
+ // CHECK-NEXT: 10
+ outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0];
+ // CHECK-NEXT: 5
+
+ // Matrix intrinsic operation results
+ outputBuffer[11] = det2x2;
+ // CHECK-NEXT: 2
+ outputBuffer[12] = det3x3;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = mat_min[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[14] = mat_min[1][1];
+ // CHECK-NEXT: 4
+ outputBuffer[15] = mat_max[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[16] = mat_max[1][1];
+ // CHECK-NEXT: 7
+ outputBuffer[17] = all_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[18] = all_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[19] = any_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[20] = any_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[21] = trans2x2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[22] = trans2x2[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[23] = trans2x3[0][0];
+ // CHECK-NEXT: 5
+
+ // Bit shift operation results
+ outputBuffer[24] = left_shift[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[25] = left_shift[0][1];
+ // CHECK-NEXT: 4
+ outputBuffer[26] = right_shift[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[27] = right_shift[1][1];
+ // CHECK-NEXT: 4
+
+ // Comparison operation results (bool matrices cast to TYPE)
+ outputBuffer[28] = TYPE(less_than[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[29] = TYPE(less_than[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[30] = TYPE(greater_than[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[31] = TYPE(greater_than[1][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[32] = TYPE(less_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[33] = TYPE(less_equal[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[34] = TYPE(greater_equal[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[35] = TYPE(greater_equal[1][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[36] = TYPE(equal_to[0][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[37] = TYPE(negated[0][0] == expectedBuffer[0]);
+ // CHECK-NEXT: 1
+ outputBuffer[38] = TYPE(negated[1][1] == expectedBuffer[1]);
+ // CHECK-NEXT: 1
+} \ No newline at end of file
diff --git a/tests/spirv/matrix-bool-lowering.slang b/tests/spirv/matrix-bool-lowering.slang
index 63b7caacf..f903fbf17 100644
--- a/tests/spirv/matrix-bool-lowering.slang
+++ b/tests/spirv/matrix-bool-lowering.slang
@@ -1,6 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -xslang -emit-spirv-directly
-//TEST_INPUT:ubuffer(data=[1 0], stride=4):in,name inputBuffer
+//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> inputBuffer;
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/spirv/matrix-integer-lowering.slang b/tests/spirv/matrix-integer-lowering.slang
index 518d0f78b..fded652a4 100644
--- a/tests/spirv/matrix-integer-lowering.slang
+++ b/tests/spirv/matrix-integer-lowering.slang
@@ -10,8 +10,10 @@ typealias m2x3 = matrix<TYPE, 2, 3>;
typealias m3x3 = matrix<TYPE, 3, 3>;
typealias m2x4 = matrix<TYPE, 2, 4>;
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer
RWStructuredBuffer<TYPE> outputBuffer;
+RWStructuredBuffer<TYPE> expectedBuffer;
struct matrixWrapper {
m2x2 mat1 = m2x2(1, 2, 3, 4);
@@ -103,6 +105,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2;
matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2;
+ // Test matrix negation operations
+ m2x2 neg_mat = m2x2(1, -2, 3, -4);
+ m2x2 negated = -neg_mat;
+
// Store results
outputBuffer[0] = val1;
// CHECK: 1
@@ -186,4 +192,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
// CHECK-NEXT: 0
outputBuffer[37] = TYPE(not_equal[0][0]);
// CHECK-NEXT: 1
+ outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]);
+ // CHECK-NEXT: 1
+ outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]);
+ // CHECK-NEXT: 1
} \ No newline at end of file
diff --git a/tests/wgsl/matrix-bool-lowering.slang b/tests/wgsl/matrix-bool-lowering.slang
new file mode 100644
index 000000000..4803fa73a
--- /dev/null
+++ b/tests/wgsl/matrix-bool-lowering.slang
@@ -0,0 +1,114 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -wgsl -shaderobj
+
+//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> inputBuffer;
+RWStructuredBuffer<int> outputBuffer;
+
+// Global bool constants to avoid constant folding
+static bool trueVal;
+static bool falseVal;
+
+struct matrixWrapper {
+ bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal);
+}
+
+bool elementAnd(bool2x2 matrix)
+{
+ return trueVal
+ && matrix[0][0]
+ && matrix[0][1]
+ && matrix[1][0]
+ && matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Load true/false values from input buffer to avoid constant folding
+ trueVal = inputBuffer[0] != 0;
+ falseVal = inputBuffer[1] != 0;
+
+ // Test bool matrix construction
+ bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal);
+ bool3x3 mat2 = bool3x3(
+ trueVal, falseVal, trueVal,
+ falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal
+ );
+ bool2x4 mat3 = bool2x4(
+ trueVal, falseVal, trueVal, falseVal,
+ trueVal, falseVal, trueVal, falseVal
+ );
+
+ // Test bool matrix element access
+ bool val1 = mat1[0][0];
+ bool val2 = mat2[2][1];
+
+ // Test bool matrix row access
+ bool2 row = mat1[1];
+ bool3 row3 = mat2[0];
+
+ // Test logical operations
+ bool2x2 not_mat = !mat1;
+ bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal);
+
+ // Test element assignment
+ mat1[0][1] = trueVal;
+ mat2[1][2] = falseVal;
+
+ // Test passing bool matrices to functions
+ bool anded = elementAnd(mat1);
+
+ // Test structs with bool matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test any/all operations
+ bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal);
+ bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal);
+ bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal);
+
+ bool test_all_true = all(all_true); // all elements true -> true
+ bool test_all_false = all(all_false); // all elements false -> false
+ bool test_all_mixed = all(mixed); // some elements false -> false
+ bool test_any_true = any(all_true); // some elements true -> true
+ bool test_any_false = any(all_false); // no elements true -> false
+ bool test_any_mixed = any(mixed); // some elements true -> true
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 0
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 0
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 1
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 0
+ outputBuffer[5] = not_mat[0][0];
+ // CHECK-NEXT: 0
+ outputBuffer[6] = and_mat[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[7] = mat1[0][1];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat3[0][1];
+ // CHECK-NEXT: 0
+ outputBuffer[9] = anded;
+ // CHECK-NEXT: 0
+ outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[11] = test_all_true;
+ // CHECK-NEXT: 1
+ outputBuffer[12] = test_all_false;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = test_all_mixed;
+ // CHECK-NEXT: 0
+ outputBuffer[14] = test_any_true;
+ // CHECK-NEXT: 1
+ outputBuffer[15] = test_any_false;
+ // CHECK-NEXT: 0
+ outputBuffer[16] = test_any_mixed;
+ // CHECK-NEXT: 1
+} \ No newline at end of file
diff --git a/tests/wgsl/matrix-integer-lowering.slang b/tests/wgsl/matrix-integer-lowering.slang
new file mode 100644
index 000000000..fc2a64382
--- /dev/null
+++ b/tests/wgsl/matrix-integer-lowering.slang
@@ -0,0 +1,199 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=int
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=uint
+
+#ifndef TYPE
+#define TYPE int
+#endif
+
+typealias m2x2 = matrix<TYPE, 2, 2>;
+typealias m2x3 = matrix<TYPE, 2, 3>;
+typealias m3x3 = matrix<TYPE, 3, 3>;
+typealias m2x4 = matrix<TYPE, 2, 4>;
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer
+RWStructuredBuffer<TYPE> outputBuffer;
+RWStructuredBuffer<TYPE> expectedBuffer;
+
+struct matrixWrapper {
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10);
+};
+
+TYPE elementAdd(m2x2 matrix)
+{
+ return matrix[0][0]
+ + matrix[0][1]
+ + matrix[1][0]
+ + matrix[1][1];
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // Test matrix construction
+ m2x2 mat1 = m2x2(1, 2, 3, 4);
+ m3x3 mat2 = m3x3(
+ 1, 2, 3,
+ 4, 5, 6,
+ 7, 8, 9
+ );
+ m2x4 mat3 = m2x4(
+ 10, 11, 12, 13,
+ 14, 15, 16, 17
+ );
+
+ // Test matrix element access
+ TYPE val1 = mat1[0][0];
+ TYPE val2 = mat2[2][1];
+
+ // Test matrix row access
+ vector<TYPE, 2> row = mat1[1];
+ vector<TYPE, 3> row3 = mat2[0];
+
+ // Test arithmetic operations
+ m2x2 mat5 = m2x2(2, 4, 6, 7);
+
+ m2x2 mat_scalar = 2 * mat1;
+ m2x2 mat_add = mat1 + mat5;
+ m2x2 mat_sub = mat5 - mat1;
+ m2x2 mat_mul = mat1 * mat5;
+
+ // Test passing matrices to functions
+ TYPE added = elementAdd(mat1);
+
+ // Test structs with matrix fields
+ matrixWrapper wrapper = {};
+
+ // Test matrix intrinsic operations
+
+ // Test determinant for square matrices
+ m2x2 mat6 = m2x2(2, 1, 4, 3);
+ TYPE det2x2 = TYPE(determinant(mat6));
+ TYPE det3x3 = TYPE(determinant(mat2));
+
+ // Test transpose
+ matrix<TYPE, 2, 2> trans2x2 = transpose(mat1);
+ matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2);
+
+ // Test element-wise min/max
+ m2x2 mat_min = min(mat1, mat5);
+ m2x2 mat_max = max(mat1, mat5);
+
+ // Test all/any operations (these return bool, but we'll cast to TYPE for output)
+ m2x2 zero_mat = m2x2(0, 0, 0, 0);
+ m2x2 mixed_mat = m2x2(1, 0, 2, 0);
+
+ TYPE all_nonzero = TYPE(all(mat1));
+ TYPE all_zero = TYPE(all(zero_mat));
+ TYPE any_nonzero = TYPE(any(mixed_mat));
+ TYPE any_zero = TYPE(any(zero_mat));
+
+ // Test bit shift operations
+ m2x2 shift_mat = m2x2(1, 2, 4, 8);
+ m2x2 left_shift = shift_mat << 1;
+ m2x2 right_shift = shift_mat >> 1;
+
+ // Test comparison operations (these return bool matrices, cast to TYPE for output)
+ m2x2 comp_mat1 = m2x2(1, 3, 2, 4);
+ m2x2 comp_mat2 = m2x2(2, 2, 3, 3);
+
+ matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2;
+ matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2;
+ matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2;
+ matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2;
+ matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2;
+ matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2;
+
+ // Test matrix negation operations
+ m2x2 neg_mat = m2x2(1, -2, 3, -4);
+ m2x2 negated = -neg_mat;
+
+ // Store results
+ outputBuffer[0] = val1;
+ // CHECK: 1
+ outputBuffer[1] = val2;
+ // CHECK-NEXT: 8
+ outputBuffer[2] = row.x;
+ // CHECK-NEXT: 3
+ outputBuffer[3] = row.y;
+ // CHECK-NEXT: 4
+ outputBuffer[4] = row3.y;
+ // CHECK-NEXT: 2
+ outputBuffer[5] = mat_scalar[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[6] = mat_add[0][0];
+ // CHECK-NEXT: 3
+ outputBuffer[7] = mat_sub[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[8] = mat_mul[1][1];
+ // CHECK-NEXT: 28
+ outputBuffer[9] = added;
+ // CHECK-NEXT: 10
+ outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0];
+ // CHECK-NEXT: 5
+
+ // Matrix intrinsic operation results
+ outputBuffer[11] = det2x2;
+ // CHECK-NEXT: 2
+ outputBuffer[12] = det3x3;
+ // CHECK-NEXT: 0
+ outputBuffer[13] = mat_min[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[14] = mat_min[1][1];
+ // CHECK-NEXT: 4
+ outputBuffer[15] = mat_max[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[16] = mat_max[1][1];
+ // CHECK-NEXT: 7
+ outputBuffer[17] = all_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[18] = all_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[19] = any_nonzero;
+ // CHECK-NEXT: 1
+ outputBuffer[20] = any_zero;
+ // CHECK-NEXT: 0
+ outputBuffer[21] = trans2x2[0][0];
+ // CHECK-NEXT: 1
+ outputBuffer[22] = trans2x2[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[23] = trans2x3[0][0];
+ // CHECK-NEXT: 5
+
+ // Bit shift operation results
+ outputBuffer[24] = left_shift[0][0];
+ // CHECK-NEXT: 2
+ outputBuffer[25] = left_shift[0][1];
+ // CHECK-NEXT: 4
+ outputBuffer[26] = right_shift[1][0];
+ // CHECK-NEXT: 2
+ outputBuffer[27] = right_shift[1][1];
+ // CHECK-NEXT: 4
+
+ // Comparison operation results (bool matrices cast to TYPE)
+ outputBuffer[28] = TYPE(less_than[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[29] = TYPE(less_than[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[30] = TYPE(greater_than[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[31] = TYPE(greater_than[1][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[32] = TYPE(less_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[33] = TYPE(less_equal[0][1]);
+ // CHECK-NEXT: 0
+ outputBuffer[34] = TYPE(greater_equal[0][1]);
+ // CHECK-NEXT: 1
+ outputBuffer[35] = TYPE(greater_equal[1][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[36] = TYPE(equal_to[0][0]);
+ // CHECK-NEXT: 0
+ outputBuffer[37] = TYPE(not_equal[0][0]);
+ // CHECK-NEXT: 1
+ outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]);
+ // CHECK-NEXT: 1
+ outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]);
+ // CHECK-NEXT: 1
+} \ No newline at end of file