diff options
| author | venkataram-nv <vedavamadath@nvidia.com> | 2025-07-30 09:27:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-30 16:27:38 +0000 |
| commit | 92ee2927d0012dd454dff7bb53b900f5240073d5 (patch) | |
| tree | d0a648fbb1e6b08c6eec90fadb23435731c1eefe | |
| parent | 42dc521f7817328a20e40b3352ae667dfd124edb (diff) | |
Lowering unsupported matrix types for GLSL/WGSL/Metal targets (#7936)
* Add emit cases for WGSL and GLSL
* Fix compilation warnings
Modify short cutting test to reflect change in emit logic
Lower matrix for metal as well
Add emit matrix logic for metal
Fix compiler warning
Brace initializer for lowered matrices
Fix compiler warnings
* Tests for metal
* Fix mult, any, and determinant
* Fix matrix-matrix multiplication
* Fix mat mul to be element-wise
* Fix compiler warning
* Move makeMatrix to legalization
* Move unary and binary arithmetic operator lowering to legalization
* Remove emit logic and move final comparison operators to legalization
* Handle vector/matrix negation for WGSL
* Restore older SPIR-V emit logic
* Address PR comments
* Revert to zero minus for negation
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
| -rw-r--r-- | source/slang/hlsl.meta.slang | 65 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 435 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.h | 2 | ||||
| -rw-r--r-- | tests/compute/logic-no-short-circuit-evaluation.slang | 4 | ||||
| -rw-r--r-- | tests/glsl/matrix-bool-lowering.slang | 114 | ||||
| -rw-r--r-- | tests/glsl/matrix-integer-lowering.slang | 199 | ||||
| -rw-r--r-- | tests/metal/matrix-bool-lowering.slang | 119 | ||||
| -rw-r--r-- | tests/metal/matrix-integer-lowering.slang | 202 | ||||
| -rw-r--r-- | tests/spirv/matrix-bool-lowering.slang | 2 | ||||
| -rw-r--r-- | tests/spirv/matrix-integer-lowering.slang | 12 | ||||
| -rw-r--r-- | tests/wgsl/matrix-bool-lowering.slang | 114 | ||||
| -rw-r--r-- | tests/wgsl/matrix-integer-lowering.slang | 199 |
15 files changed, 1484 insertions, 172 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index aa494ec95..9fd5c8b6e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6481,14 +6481,10 @@ bool all(T x) { __target_switch { - default: - __intrinsic_asm "bool($0)"; case hlsl: __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; - case wgsl: - __intrinsic_asm "all"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6505,6 +6501,8 @@ bool all(T x) return __slang_noop_cast<bool>(x); else return false; + default: + __intrinsic_asm "bool($0)"; } } @@ -6550,9 +6548,17 @@ bool all(vector<T,N> x) }; } case wgsl: + // WGSL all() only works with boolean vectors if (__isBool<T>()) - __intrinsic_asm "all"; - __intrinsic_asm "all(vec$N0<bool>($0))"; + __intrinsic_asm "all($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = true; + for(int i = 0; i < N; ++i) + result = result && all(x[i]); + return result; + } default: bool result = true; for(int i = 0; i < N; ++i) @@ -6563,7 +6569,7 @@ bool all(vector<T,N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(matrix<T,N,M> x) { __target_switch @@ -6655,7 +6661,8 @@ bool any(T x) case metal: __intrinsic_asm "any"; case wgsl: - __intrinsic_asm "any"; + // For scalars, any() doesn't exist in WGSL, just convert to bool + __intrinsic_asm "bool($0)"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6686,7 +6693,17 @@ bool any(vector<T, N> x) case hlsl: __intrinsic_asm "any"; case metal: - __intrinsic_asm "any"; + if (__isBool<T>()) + __intrinsic_asm "any"; + else + { + // For non-bool types, convert to bool vector first + // Metal's any() only works with bool vectors + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } case glsl: __intrinsic_asm "any(bvec$N0($0))"; case spirv: @@ -6714,7 +6731,17 @@ bool any(vector<T, N> x) }; } case wgsl: - __intrinsic_asm "any"; + // WGSL any() only works with boolean vectors + if (__isBool<T>()) + __intrinsic_asm "any($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } default: bool result = false; for(int i = 0; i < N; ++i) @@ -6725,7 +6752,7 @@ bool any(vector<T, N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(matrix<T, N, M> x) { __target_switch @@ -8626,11 +8653,8 @@ T determinant(matrix<T,N,N> m) { __target_switch { - case glsl: __intrinsic_asm "determinant"; case hlsl: __intrinsic_asm "determinant"; - case metal: __intrinsic_asm "determinant"; - case wgsl: __intrinsic_asm "determinant"; - // SPIR-V doesn't support integer determinants, so we need to implement it manually + // GLSL, WGSL, and SPIR-V don't support integer determinants for lowered matrices, so we need to implement it manually default: static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices"); if (N == 1) @@ -13804,16 +13828,14 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinIntegerType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; - case wgsl: __intrinsic_asm "transpose"; - // SPIRV-V doenst't support integer matrices, so transpose it manually + // GLSL, WGSL, SPIR-V, and Metal don't support integer matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) @@ -13824,19 +13846,18 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinLogicalType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] [OverloadRank(-1)] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; case spirv: return spirv_asm { OpTranspose $$matrix<T, M, N> result $x }; - case wgsl: __intrinsic_asm "transpose"; + // GLSL, WGSL, and Metal don't support bool matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index fbcb54d10..53c3aa487 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1624,6 +1624,22 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_Neg: + { + auto opType = inst->getOperand(0)->getDataType(); + if (as<IRMatrixType>(opType) || as<IRVectorType>(opType)) + { + // WGSL does not support negate operator on matrices and vectors, + // we should emit "(type(0) - op0)" instead. + m_writer->emit("("); + emitType(inst->getDataType()); + m_writer->emit("(0) - "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + } } return false; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b548ef632..405bca5a2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1339,7 +1339,10 @@ Result linkAndOptimizeIR( } legalizeMatrixTypes(targetProgram, irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-MATRIX-LEGALIZATION"); + legalizeVectorTypes(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-VECTOR-LEGALIZATION"); // Once specialization and type legalization have been performed, // we should perform some of our basic optimization steps again, diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 0b972b5bd..8c8cb0c84 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -1,6 +1,7 @@ #include "slang-ir-legalize-matrix-types.h" #include "slang-compiler.h" +#include "slang-ir-insts-enum.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -50,6 +51,9 @@ struct MatrixTypeLoweringContext case CodeGenTarget::WGSL: case CodeGenTarget::WGSLSPIRV: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: return true; default: return false; @@ -66,33 +70,430 @@ struct MatrixTypeLoweringContext as<IRIntType>(elementType); } - IRInst* getReplacement(IRInst* inst) + IRInst* legalizeMatrixTypeDeclaration(IRInst* inst) { - if (auto replacement = replacements.tryGetValue(inst)) - return *replacement; + auto matrixType = as<IRMatrixType>(inst); + if (shouldLowerMatrixType(matrixType)) + { + // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) + auto elementType = matrixType->getElementType(); + auto rowCount = matrixType->getRowCount(); + auto columnCount = matrixType->getColumnCount(); - IRInst* newInst = inst; + IRBuilder builder(matrixType); + builder.setInsertBefore(matrixType); + + // Create vector type for columns: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type for rows: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + return arrayType; + } + return inst; + } + + IRInst* legalizeMakeMatrix(IRInst* inst) + { + auto makeMatrix = as<IRMakeMatrix>(inst); + auto matrixType = as<IRMatrixType>(makeMatrix->getDataType()); + + SLANG_ASSERT(matrixType && "Matrix type is expected"); + SLANG_ASSERT( + shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization"); + + // Lower makeMatrix to makeArray of makeVectors + auto elementType = matrixType->getElementType(); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - if (auto matrixType = as<IRMatrixType>(inst)) + IRBuilder builder(makeMatrix); + builder.setInsertBefore(makeMatrix); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Group operands into rows and create vectors + List<IRInst*> rowVectors; + UInt operandIndex = 0; + + // Assert that we have the expected number of operands + SLANG_ASSERT( + makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && + "makeMatrix operand count must match matrix dimensions"); + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { - if (shouldLowerMatrixType(matrixType)) + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) { - // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) - auto elementType = matrixType->getElementType(); - auto rowCount = matrixType->getRowCount(); - auto columnCount = matrixType->getColumnCount(); + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } + + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + + SLANG_ASSERT( + rowVectors.getCount() == rowCount->getValue() && + "Row vectors count must match matrix row count"); + return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer()); + } + + IRInst* legalizeMatrixMatrixBinaryOperation( + IRBuilder& builder, + IRInst* legalizedA, + IRInst* legalizedB, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - IRBuilder builder(matrixType); - builder.setInsertBefore(matrixType); + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); - // Create vector type for columns: vector<T, C> - auto vectorType = builder.getVectorType(elementType, columnCount); + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); - // Create array type for rows: vector<T, C>[R] - auto arrayType = builder.getArrayType(vectorType, rowCount); + // Extract vectors from both arrays and apply binary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); - newInst = arrayType; + // Apply the binary operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + + template<bool matrixIsFirst> + IRInst* legalizeMatrixMixedBinaryOperation( + IRBuilder& builder, + IRInst* legalizedMatrix, + IRInst* legalizedOther, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + // Verify that the other operand is either a vector or scalar type + auto otherType = legalizedOther->getDataType(); + auto otherVectorType = as<IRVectorType>(otherType); + auto otherBasicType = as<IRBasicType>(otherType); + SLANG_ASSERT( + (otherVectorType || otherBasicType) && "Other operand must be vector or scalar type"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from matrix array and apply binary operation with other operand + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from matrix array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto matrixRowVector = builder.emitElementExtract(legalizedMatrix, rowIndexInst); + + // Apply the binary operation between matrix row vector and other operand + IRInst* args[2]; + if constexpr (matrixIsFirst) + { + args[0] = matrixRowVector; + args[1] = legalizedOther; } + else + { + args[0] = legalizedOther; + args[1] = matrixRowVector; + } + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeBinaryOperation(IRInst* inst, IROp binaryOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Binary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands once + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + if (shouldLowerA && shouldLowerB) + { + return legalizeMatrixMatrixBinaryOperation( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (shouldLowerA && !shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<true>( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (!shouldLowerA && shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<false>( + builder, + legalizedB, + legalizedA, + resultMatrixType, + binaryOp); + } + + // Neither operand is a matrix that needs lowering, shouldn't reach here + SLANG_UNREACHABLE("legalizeBinaryOperation called but no matrix operand needs lowering"); + } + + IRInst* legalizeComparisonOperation(IRInst* inst, IROp comparisonOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Only matrix-matrix comparisons are supported + SLANG_ASSERT( + shouldLowerA && shouldLowerB && + "Comparison operations only supported between matrices that need lowering"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + auto rowCount = as<IRIntLit>(matrixTypeA->getRowCount()); + auto columnCount = as<IRIntLit>(matrixTypeA->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create boolean vector type for rows: vector<bool, C> + auto boolType = builder.getBoolType(); + auto boolVectorType = builder.getVectorType(boolType, columnCount); + + // Create array type: vector<bool, C>[R] + auto boolArrayType = builder.getArrayType(boolVectorType, rowCount); + + // Extract vectors from both arrays and apply comparison operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); + + // Apply the comparison operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(boolVectorType, comparisonOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + boolArrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeUnaryOperation(IRInst* inst, IROp unaryOp) + { + IRInst* operand = inst->getOperand(0); + + // Get the legalized operand (should be an array of vectors) + IRInst* legalizedOperand = getReplacement(operand); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Unary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from array and apply unary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vector = builder.emitElementExtract(legalizedOperand, rowIndexInst); + + // Apply the unary operation to the vector + IRInst* args[] = {vector}; + auto resultVector = builder.emitIntrinsicInst(vectorType, unaryOp, 1, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeMatrixProducingInstruction(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeMatrix: + return legalizeMakeMatrix(inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + return legalizeBinaryOperation(inst, inst->getOp()); + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + return legalizeComparisonOperation(inst, inst->getOp()); + case kIROp_Not: + case kIROp_BitNot: + case kIROp_Neg: + return legalizeUnaryOperation(inst, inst->getOp()); + default: + break; + } + + return inst; + } + + IRInst* getReplacement(IRInst* inst) + { + if (auto replacement = replacements.tryGetValue(inst)) + return *replacement; + + IRInst* newInst = inst; + if (as<IRMatrixType>(inst)) + newInst = legalizeMatrixTypeDeclaration(inst); + + IRType* resultType = inst->getDataType(); + if (auto matrixType = as<IRMatrixType>(resultType)) + { + if (shouldLowerMatrixType(matrixType)) + newInst = legalizeMatrixProducingInstruction(inst); } replacements[inst] = newInst; diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h index 418e80a83..a2e71a402 100644 --- a/source/slang/slang-ir-legalize-matrix-types.h +++ b/source/slang/slang-ir-legalize-matrix-types.h @@ -7,7 +7,7 @@ struct IRModule; class DiagnosticSink; class TargetProgram; -// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets +// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, GLSL, and Metal targets void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink); } // namespace Slang
\ No newline at end of file diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index ea2b7a0c3..342a11f28 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -32,7 +32,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}&& //HLSL2018:(all({{.*}}&& //SM6:(all(and( - //WGS:(all(select(vec2<bool>(false), + //WGS:(all((select(vec2<bool>(false), //MTL:(all({{.*}}&& if (all(bool2(index >= 1) && assignFunc(index))) { @@ -54,7 +54,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}?{{.*}}: //HLSL2018:(all({{.*}}?{{.*}}: //SM6:(all(select( - //WGS:(all(select(vec2<bool>(false), + //WGS:(all((select(vec2<bool>(false), //MTL:(all(select(bool2(false) if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) { diff --git a/tests/glsl/matrix-bool-lowering.slang b/tests/glsl/matrix-bool-lowering.slang new file mode 100644 index 000000000..9f2ad913f --- /dev/null +++ b/tests/glsl/matrix-bool-lowering.slang @@ -0,0 +1,114 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -emit-spirv-via-glsl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/glsl/matrix-integer-lowering.slang b/tests/glsl/matrix-integer-lowering.slang new file mode 100644 index 000000000..4d6033d79 --- /dev/null +++ b/tests/glsl/matrix-integer-lowering.slang @@ -0,0 +1,199 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(not_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/metal/matrix-bool-lowering.slang b/tests/metal/matrix-bool-lowering.slang new file mode 100644 index 000000000..4248bb573 --- /dev/null +++ b/tests/metal/matrix-bool-lowering.slang @@ -0,0 +1,119 @@ +//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -mtl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +// METAL: array<bool2, int(2)> +// METALLIB: @computeMain + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/metal/matrix-integer-lowering.slang b/tests/metal/matrix-integer-lowering.slang new file mode 100644 index 000000000..04aec5a7c --- /dev/null +++ b/tests/metal/matrix-integer-lowering.slang @@ -0,0 +1,202 @@ +//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain -DTYPE=int +//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +// METAL: array<{{(int|uint)}}2, int(2)> +// METALLIB: @computeMain + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/spirv/matrix-bool-lowering.slang b/tests/spirv/matrix-bool-lowering.slang index 63b7caacf..f903fbf17 100644 --- a/tests/spirv/matrix-bool-lowering.slang +++ b/tests/spirv/matrix-bool-lowering.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -xslang -emit-spirv-directly -//TEST_INPUT:ubuffer(data=[1 0], stride=4):in,name inputBuffer +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<int> inputBuffer; RWStructuredBuffer<int> outputBuffer; diff --git a/tests/spirv/matrix-integer-lowering.slang b/tests/spirv/matrix-integer-lowering.slang index 518d0f78b..fded652a4 100644 --- a/tests/spirv/matrix-integer-lowering.slang +++ b/tests/spirv/matrix-integer-lowering.slang @@ -10,8 +10,10 @@ typealias m2x3 = matrix<TYPE, 2, 3>; typealias m3x3 = matrix<TYPE, 3, 3>; typealias m2x4 = matrix<TYPE, 2, 4>; -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; struct matrixWrapper { m2x2 mat1 = m2x2(1, 2, 3, 4); @@ -103,6 +105,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + // Store results outputBuffer[0] = val1; // CHECK: 1 @@ -186,4 +192,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // CHECK-NEXT: 0 outputBuffer[37] = TYPE(not_equal[0][0]); // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 }
\ No newline at end of file diff --git a/tests/wgsl/matrix-bool-lowering.slang b/tests/wgsl/matrix-bool-lowering.slang new file mode 100644 index 000000000..4803fa73a --- /dev/null +++ b/tests/wgsl/matrix-bool-lowering.slang @@ -0,0 +1,114 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -wgsl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/wgsl/matrix-integer-lowering.slang b/tests/wgsl/matrix-integer-lowering.slang new file mode 100644 index 000000000..fc2a64382 --- /dev/null +++ b/tests/wgsl/matrix-integer-lowering.slang @@ -0,0 +1,199 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(not_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file |
