diff options
| -rw-r--r-- | source/slang/hlsl.meta.slang | 65 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 435 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.h | 2 | ||||
| -rw-r--r-- | tests/compute/logic-no-short-circuit-evaluation.slang | 4 | ||||
| -rw-r--r-- | tests/glsl/matrix-bool-lowering.slang | 114 | ||||
| -rw-r--r-- | tests/glsl/matrix-integer-lowering.slang | 199 | ||||
| -rw-r--r-- | tests/metal/matrix-bool-lowering.slang | 119 | ||||
| -rw-r--r-- | tests/metal/matrix-integer-lowering.slang | 202 | ||||
| -rw-r--r-- | tests/spirv/matrix-bool-lowering.slang | 2 | ||||
| -rw-r--r-- | tests/spirv/matrix-integer-lowering.slang | 12 | ||||
| -rw-r--r-- | tests/wgsl/matrix-bool-lowering.slang | 114 | ||||
| -rw-r--r-- | tests/wgsl/matrix-integer-lowering.slang | 199 |
15 files changed, 1484 insertions, 172 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index aa494ec95..9fd5c8b6e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6481,14 +6481,10 @@ bool all(T x) { __target_switch { - default: - __intrinsic_asm "bool($0)"; case hlsl: __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; - case wgsl: - __intrinsic_asm "all"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6505,6 +6501,8 @@ bool all(T x) return __slang_noop_cast<bool>(x); else return false; + default: + __intrinsic_asm "bool($0)"; } } @@ -6550,9 +6548,17 @@ bool all(vector<T,N> x) }; } case wgsl: + // WGSL all() only works with boolean vectors if (__isBool<T>()) - __intrinsic_asm "all"; - __intrinsic_asm "all(vec$N0<bool>($0))"; + __intrinsic_asm "all($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = true; + for(int i = 0; i < N; ++i) + result = result && all(x[i]); + return result; + } default: bool result = true; for(int i = 0; i < N; ++i) @@ -6563,7 +6569,7 @@ bool all(vector<T,N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(matrix<T,N,M> x) { __target_switch @@ -6655,7 +6661,8 @@ bool any(T x) case metal: __intrinsic_asm "any"; case wgsl: - __intrinsic_asm "any"; + // For scalars, any() doesn't exist in WGSL, just convert to bool + __intrinsic_asm "bool($0)"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6686,7 +6693,17 @@ bool any(vector<T, N> x) case hlsl: __intrinsic_asm "any"; case metal: - __intrinsic_asm "any"; + if (__isBool<T>()) + __intrinsic_asm "any"; + else + { + // For non-bool types, convert to bool vector first + // Metal's any() only works with bool vectors + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } case glsl: __intrinsic_asm "any(bvec$N0($0))"; case spirv: @@ -6714,7 +6731,17 @@ bool any(vector<T, N> x) }; } case wgsl: - __intrinsic_asm "any"; + // WGSL any() only works with boolean vectors + if (__isBool<T>()) + __intrinsic_asm "any($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } default: bool result = false; for(int i = 0; i < N; ++i) @@ -6725,7 +6752,7 @@ bool any(vector<T, N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(matrix<T, N, M> x) { __target_switch @@ -8626,11 +8653,8 @@ T determinant(matrix<T,N,N> m) { __target_switch { - case glsl: __intrinsic_asm "determinant"; case hlsl: __intrinsic_asm "determinant"; - case metal: __intrinsic_asm "determinant"; - case wgsl: __intrinsic_asm "determinant"; - // SPIR-V doesn't support integer determinants, so we need to implement it manually + // GLSL, WGSL, and SPIR-V don't support integer determinants for lowered matrices, so we need to implement it manually default: static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices"); if (N == 1) @@ -13804,16 +13828,14 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinIntegerType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; - case wgsl: __intrinsic_asm "transpose"; - // SPIRV-V doenst't support integer matrices, so transpose it manually + // GLSL, WGSL, SPIR-V, and Metal don't support integer matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) @@ -13824,19 +13846,18 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinLogicalType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] [OverloadRank(-1)] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; case spirv: return spirv_asm { OpTranspose $$matrix<T, M, N> result $x }; - case wgsl: __intrinsic_asm "transpose"; + // GLSL, WGSL, and Metal don't support bool matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index fbcb54d10..53c3aa487 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1624,6 +1624,22 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_Neg: + { + auto opType = inst->getOperand(0)->getDataType(); + if (as<IRMatrixType>(opType) || as<IRVectorType>(opType)) + { + // WGSL does not support negate operator on matrices and vectors, + // we should emit "(type(0) - op0)" instead. + m_writer->emit("("); + emitType(inst->getDataType()); + m_writer->emit("(0) - "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + } } return false; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b548ef632..405bca5a2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1339,7 +1339,10 @@ Result linkAndOptimizeIR( } legalizeMatrixTypes(targetProgram, irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-MATRIX-LEGALIZATION"); + legalizeVectorTypes(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-VECTOR-LEGALIZATION"); // Once specialization and type legalization have been performed, // we should perform some of our basic optimization steps again, diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 0b972b5bd..8c8cb0c84 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -1,6 +1,7 @@ #include "slang-ir-legalize-matrix-types.h" #include "slang-compiler.h" +#include "slang-ir-insts-enum.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -50,6 +51,9 @@ struct MatrixTypeLoweringContext case CodeGenTarget::WGSL: case CodeGenTarget::WGSLSPIRV: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: return true; default: return false; @@ -66,33 +70,430 @@ struct MatrixTypeLoweringContext as<IRIntType>(elementType); } - IRInst* getReplacement(IRInst* inst) + IRInst* legalizeMatrixTypeDeclaration(IRInst* inst) { - if (auto replacement = replacements.tryGetValue(inst)) - return *replacement; + auto matrixType = as<IRMatrixType>(inst); + if (shouldLowerMatrixType(matrixType)) + { + // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) + auto elementType = matrixType->getElementType(); + auto rowCount = matrixType->getRowCount(); + auto columnCount = matrixType->getColumnCount(); - IRInst* newInst = inst; + IRBuilder builder(matrixType); + builder.setInsertBefore(matrixType); + + // Create vector type for columns: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type for rows: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + return arrayType; + } + return inst; + } + + IRInst* legalizeMakeMatrix(IRInst* inst) + { + auto makeMatrix = as<IRMakeMatrix>(inst); + auto matrixType = as<IRMatrixType>(makeMatrix->getDataType()); + + SLANG_ASSERT(matrixType && "Matrix type is expected"); + SLANG_ASSERT( + shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization"); + + // Lower makeMatrix to makeArray of makeVectors + auto elementType = matrixType->getElementType(); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - if (auto matrixType = as<IRMatrixType>(inst)) + IRBuilder builder(makeMatrix); + builder.setInsertBefore(makeMatrix); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Group operands into rows and create vectors + List<IRInst*> rowVectors; + UInt operandIndex = 0; + + // Assert that we have the expected number of operands + SLANG_ASSERT( + makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && + "makeMatrix operand count must match matrix dimensions"); + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { - if (shouldLowerMatrixType(matrixType)) + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) { - // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) - auto elementType = matrixType->getElementType(); - auto rowCount = matrixType->getRowCount(); - auto columnCount = matrixType->getColumnCount(); + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } + + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + + SLANG_ASSERT( + rowVectors.getCount() == rowCount->getValue() && + "Row vectors count must match matrix row count"); + return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer()); + } + + IRInst* legalizeMatrixMatrixBinaryOperation( + IRBuilder& builder, + IRInst* legalizedA, + IRInst* legalizedB, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - IRBuilder builder(matrixType); - builder.setInsertBefore(matrixType); + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); - // Create vector type for columns: vector<T, C> - auto vectorType = builder.getVectorType(elementType, columnCount); + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); - // Create array type for rows: vector<T, C>[R] - auto arrayType = builder.getArrayType(vectorType, rowCount); + // Extract vectors from both arrays and apply binary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); - newInst = arrayType; + // Apply the binary operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + + template<bool matrixIsFirst> + IRInst* legalizeMatrixMixedBinaryOperation( + IRBuilder& builder, + IRInst* legalizedMatrix, + IRInst* legalizedOther, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + // Verify that the other operand is either a vector or scalar type + auto otherType = legalizedOther->getDataType(); + auto otherVectorType = as<IRVectorType>(otherType); + auto otherBasicType = as<IRBasicType>(otherType); + SLANG_ASSERT( + (otherVectorType || otherBasicType) && "Other operand must be vector or scalar type"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from matrix array and apply binary operation with other operand + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from matrix array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto matrixRowVector = builder.emitElementExtract(legalizedMatrix, rowIndexInst); + + // Apply the binary operation between matrix row vector and other operand + IRInst* args[2]; + if constexpr (matrixIsFirst) + { + args[0] = matrixRowVector; + args[1] = legalizedOther; } + else + { + args[0] = legalizedOther; + args[1] = matrixRowVector; + } + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeBinaryOperation(IRInst* inst, IROp binaryOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Binary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands once + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + if (shouldLowerA && shouldLowerB) + { + return legalizeMatrixMatrixBinaryOperation( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (shouldLowerA && !shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<true>( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (!shouldLowerA && shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<false>( + builder, + legalizedB, + legalizedA, + resultMatrixType, + binaryOp); + } + + // Neither operand is a matrix that needs lowering, shouldn't reach here + SLANG_UNREACHABLE("legalizeBinaryOperation called but no matrix operand needs lowering"); + } + + IRInst* legalizeComparisonOperation(IRInst* inst, IROp comparisonOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Only matrix-matrix comparisons are supported + SLANG_ASSERT( + shouldLowerA && shouldLowerB && + "Comparison operations only supported between matrices that need lowering"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + auto rowCount = as<IRIntLit>(matrixTypeA->getRowCount()); + auto columnCount = as<IRIntLit>(matrixTypeA->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create boolean vector type for rows: vector<bool, C> + auto boolType = builder.getBoolType(); + auto boolVectorType = builder.getVectorType(boolType, columnCount); + + // Create array type: vector<bool, C>[R] + auto boolArrayType = builder.getArrayType(boolVectorType, rowCount); + + // Extract vectors from both arrays and apply comparison operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); + + // Apply the comparison operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(boolVectorType, comparisonOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + boolArrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeUnaryOperation(IRInst* inst, IROp unaryOp) + { + IRInst* operand = inst->getOperand(0); + + // Get the legalized operand (should be an array of vectors) + IRInst* legalizedOperand = getReplacement(operand); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Unary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from array and apply unary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vector = builder.emitElementExtract(legalizedOperand, rowIndexInst); + + // Apply the unary operation to the vector + IRInst* args[] = {vector}; + auto resultVector = builder.emitIntrinsicInst(vectorType, unaryOp, 1, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeMatrixProducingInstruction(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeMatrix: + return legalizeMakeMatrix(inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + return legalizeBinaryOperation(inst, inst->getOp()); + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + return legalizeComparisonOperation(inst, inst->getOp()); + case kIROp_Not: + case kIROp_BitNot: + case kIROp_Neg: + return legalizeUnaryOperation(inst, inst->getOp()); + default: + break; + } + + return inst; + } + + IRInst* getReplacement(IRInst* inst) + { + if (auto replacement = replacements.tryGetValue(inst)) + return *replacement; + + IRInst* newInst = inst; + if (as<IRMatrixType>(inst)) + newInst = legalizeMatrixTypeDeclaration(inst); + + IRType* resultType = inst->getDataType(); + if (auto matrixType = as<IRMatrixType>(resultType)) + { + if (shouldLowerMatrixType(matrixType)) + newInst = legalizeMatrixProducingInstruction(inst); } replacements[inst] = newInst; diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h index 418e80a83..a2e71a402 100644 --- a/source/slang/slang-ir-legalize-matrix-types.h +++ b/source/slang/slang-ir-legalize-matrix-types.h @@ -7,7 +7,7 @@ struct IRModule; class DiagnosticSink; class TargetProgram; -// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets +// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, GLSL, and Metal targets void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink); } // namespace Slang
\ No newline at end of file diff --git a/tests/compute/logic-no-short-circuit-evaluation.slang b/tests/compute/logic-no-short-circuit-evaluation.slang index ea2b7a0c3..342a11f28 100644 --- a/tests/compute/logic-no-short-circuit-evaluation.slang +++ b/tests/compute/logic-no-short-circuit-evaluation.slang @@ -32,7 +32,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}&& //HLSL2018:(all({{.*}}&& //SM6:(all(and( - //WGS:(all(select(vec2<bool>(false), + //WGS:(all((select(vec2<bool>(false), //MTL:(all({{.*}}&& if (all(bool2(index >= 1) && assignFunc(index))) { @@ -54,7 +54,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) //SM5:(all({{.*}}?{{.*}}: //HLSL2018:(all({{.*}}?{{.*}}: //SM6:(all(select( - //WGS:(all(select(vec2<bool>(false), + //WGS:(all((select(vec2<bool>(false), //MTL:(all(select(bool2(false) if (all(bool2(index >= 3) ? assignFunc(index) : bool2(false))) { diff --git a/tests/glsl/matrix-bool-lowering.slang b/tests/glsl/matrix-bool-lowering.slang new file mode 100644 index 000000000..9f2ad913f --- /dev/null +++ b/tests/glsl/matrix-bool-lowering.slang @@ -0,0 +1,114 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -emit-spirv-via-glsl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/glsl/matrix-integer-lowering.slang b/tests/glsl/matrix-integer-lowering.slang new file mode 100644 index 000000000..4d6033d79 --- /dev/null +++ b/tests/glsl/matrix-integer-lowering.slang @@ -0,0 +1,199 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -vk -output-using-type -compute -emit-spirv-via-glsl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(not_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/metal/matrix-bool-lowering.slang b/tests/metal/matrix-bool-lowering.slang new file mode 100644 index 000000000..4248bb573 --- /dev/null +++ b/tests/metal/matrix-bool-lowering.slang @@ -0,0 +1,119 @@ +//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain +//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -mtl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +// METAL: array<bool2, int(2)> +// METALLIB: @computeMain + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/metal/matrix-integer-lowering.slang b/tests/metal/matrix-integer-lowering.slang new file mode 100644 index 000000000..04aec5a7c --- /dev/null +++ b/tests/metal/matrix-integer-lowering.slang @@ -0,0 +1,202 @@ +//TEST:SIMPLE(filecheck=METAL): -target metal -stage compute -entry computeMain -DTYPE=int +//TEST:SIMPLE(filecheck=METALLIB): -target metallib -stage compute -entry computeMain -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -mtl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +// METAL: array<{{(int|uint)}}2, int(2)> +// METALLIB: @computeMain + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/spirv/matrix-bool-lowering.slang b/tests/spirv/matrix-bool-lowering.slang index 63b7caacf..f903fbf17 100644 --- a/tests/spirv/matrix-bool-lowering.slang +++ b/tests/spirv/matrix-bool-lowering.slang @@ -1,6 +1,6 @@ //TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -xslang -emit-spirv-directly -//TEST_INPUT:ubuffer(data=[1 0], stride=4):in,name inputBuffer +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer //TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer RWStructuredBuffer<int> inputBuffer; RWStructuredBuffer<int> outputBuffer; diff --git a/tests/spirv/matrix-integer-lowering.slang b/tests/spirv/matrix-integer-lowering.slang index 518d0f78b..fded652a4 100644 --- a/tests/spirv/matrix-integer-lowering.slang +++ b/tests/spirv/matrix-integer-lowering.slang @@ -10,8 +10,10 @@ typealias m2x3 = matrix<TYPE, 2, 3>; typealias m3x3 = matrix<TYPE, 3, 3>; typealias m2x4 = matrix<TYPE, 2, 4>; -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; struct matrixWrapper { m2x2 mat1 = m2x2(1, 2, 3, 4); @@ -103,6 +105,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + // Store results outputBuffer[0] = val1; // CHECK: 1 @@ -186,4 +192,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // CHECK-NEXT: 0 outputBuffer[37] = TYPE(not_equal[0][0]); // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 }
\ No newline at end of file diff --git a/tests/wgsl/matrix-bool-lowering.slang b/tests/wgsl/matrix-bool-lowering.slang new file mode 100644 index 000000000..4803fa73a --- /dev/null +++ b/tests/wgsl/matrix-bool-lowering.slang @@ -0,0 +1,114 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -wgsl -shaderobj + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +}
\ No newline at end of file diff --git a/tests/wgsl/matrix-integer-lowering.slang b/tests/wgsl/matrix-integer-lowering.slang new file mode 100644 index 000000000..fc2a64382 --- /dev/null +++ b/tests/wgsl/matrix-integer-lowering.slang @@ -0,0 +1,199 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -wgsl -shaderobj -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +//TEST_INPUT:ubuffer(data=[-1 4], stride=4):name expectedBuffer +RWStructuredBuffer<TYPE> outputBuffer; +RWStructuredBuffer<TYPE> expectedBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Test matrix negation operations + m2x2 neg_mat = m2x2(1, -2, 3, -4); + m2x2 negated = -neg_mat; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(not_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[38] = TYPE(negated[0][0] == expectedBuffer[0]); + // CHECK-NEXT: 1 + outputBuffer[39] = TYPE(negated[1][1] == expectedBuffer[1]); + // CHECK-NEXT: 1 +}
\ No newline at end of file |
