diff options
| author | venkataram-nv <vedavamadath@nvidia.com> | 2025-07-30 09:27:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-30 16:27:38 +0000 |
| commit | 92ee2927d0012dd454dff7bb53b900f5240073d5 (patch) | |
| tree | d0a648fbb1e6b08c6eec90fadb23435731c1eefe /source/slang | |
| parent | 42dc521f7817328a20e40b3352ae667dfd124edb (diff) | |
Lowering unsupported matrix types for GLSL/WGSL/Metal targets (#7936)
* Add emit cases for WGSL and GLSL
* Fix compilation warnings
Modify short cutting test to reflect change in emit logic
Lower matrix for metal as well
Add emit matrix logic for metal
Fix compiler warning
Brace initializer for lowered matrices
Fix compiler warnings
* Tests for metal
* Fix mult, any, and determinant
* Fix matrix-matrix multiplication
* Fix mat mul to be element-wise
* Fix compiler warning
* Move makeMatrix to legalization
* Move unary and binary arithmetic operator lowering to legalization
* Remove emit logic and move final comparison operators to legalization
* Handle vector/matrix negation for WGSL
* Restore older SPIR-V emit logic
* Address PR comments
* Revert to zero minus for negation
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 65 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 435 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.h | 2 |
6 files changed, 523 insertions, 168 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index aa494ec95..9fd5c8b6e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6481,14 +6481,10 @@ bool all(T x) { __target_switch { - default: - __intrinsic_asm "bool($0)"; case hlsl: __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; - case wgsl: - __intrinsic_asm "all"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6505,6 +6501,8 @@ bool all(T x) return __slang_noop_cast<bool>(x); else return false; + default: + __intrinsic_asm "bool($0)"; } } @@ -6550,9 +6548,17 @@ bool all(vector<T,N> x) }; } case wgsl: + // WGSL all() only works with boolean vectors if (__isBool<T>()) - __intrinsic_asm "all"; - __intrinsic_asm "all(vec$N0<bool>($0))"; + __intrinsic_asm "all($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = true; + for(int i = 0; i < N; ++i) + result = result && all(x[i]); + return result; + } default: bool result = true; for(int i = 0; i < N; ++i) @@ -6563,7 +6569,7 @@ bool all(vector<T,N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(matrix<T,N,M> x) { __target_switch @@ -6655,7 +6661,8 @@ bool any(T x) case metal: __intrinsic_asm "any"; case wgsl: - __intrinsic_asm "any"; + // For scalars, any() doesn't exist in WGSL, just convert to bool + __intrinsic_asm "bool($0)"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6686,7 +6693,17 @@ bool any(vector<T, N> x) case hlsl: __intrinsic_asm "any"; case metal: - __intrinsic_asm "any"; + if (__isBool<T>()) + __intrinsic_asm "any"; + else + { + // For non-bool types, convert to bool vector first + // Metal's any() only works with bool vectors + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } case glsl: __intrinsic_asm "any(bvec$N0($0))"; case spirv: @@ -6714,7 +6731,17 @@ bool any(vector<T, N> x) }; } case wgsl: - __intrinsic_asm "any"; + // WGSL any() only works with boolean vectors + if (__isBool<T>()) + __intrinsic_asm "any($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } default: bool result = false; for(int i = 0; i < N; ++i) @@ -6725,7 +6752,7 @@ bool any(vector<T, N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(matrix<T, N, M> x) { __target_switch @@ -8626,11 +8653,8 @@ T determinant(matrix<T,N,N> m) { __target_switch { - case glsl: __intrinsic_asm "determinant"; case hlsl: __intrinsic_asm "determinant"; - case metal: __intrinsic_asm "determinant"; - case wgsl: __intrinsic_asm "determinant"; - // SPIR-V doesn't support integer determinants, so we need to implement it manually + // GLSL, WGSL, and SPIR-V don't support integer determinants for lowered matrices, so we need to implement it manually default: static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices"); if (N == 1) @@ -13804,16 +13828,14 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinIntegerType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; - case wgsl: __intrinsic_asm "transpose"; - // SPIRV-V doenst't support integer matrices, so transpose it manually + // GLSL, WGSL, SPIR-V, and Metal don't support integer matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) @@ -13824,19 +13846,18 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinLogicalType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] [OverloadRank(-1)] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; case spirv: return spirv_asm { OpTranspose $$matrix<T, M, N> result $x }; - case wgsl: __intrinsic_asm "transpose"; + // GLSL, WGSL, and Metal don't support bool matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index fbcb54d10..53c3aa487 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1624,6 +1624,22 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_Neg: + { + auto opType = inst->getOperand(0)->getDataType(); + if (as<IRMatrixType>(opType) || as<IRVectorType>(opType)) + { + // WGSL does not support negate operator on matrices and vectors, + // we should emit "(type(0) - op0)" instead. + m_writer->emit("("); + emitType(inst->getDataType()); + m_writer->emit("(0) - "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + } } return false; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b548ef632..405bca5a2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1339,7 +1339,10 @@ Result linkAndOptimizeIR( } legalizeMatrixTypes(targetProgram, irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-MATRIX-LEGALIZATION"); + legalizeVectorTypes(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-VECTOR-LEGALIZATION"); // Once specialization and type legalization have been performed, // we should perform some of our basic optimization steps again, diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 0b972b5bd..8c8cb0c84 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -1,6 +1,7 @@ #include "slang-ir-legalize-matrix-types.h" #include "slang-compiler.h" +#include "slang-ir-insts-enum.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -50,6 +51,9 @@ struct MatrixTypeLoweringContext case CodeGenTarget::WGSL: case CodeGenTarget::WGSLSPIRV: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: return true; default: return false; @@ -66,33 +70,430 @@ struct MatrixTypeLoweringContext as<IRIntType>(elementType); } - IRInst* getReplacement(IRInst* inst) + IRInst* legalizeMatrixTypeDeclaration(IRInst* inst) { - if (auto replacement = replacements.tryGetValue(inst)) - return *replacement; + auto matrixType = as<IRMatrixType>(inst); + if (shouldLowerMatrixType(matrixType)) + { + // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) + auto elementType = matrixType->getElementType(); + auto rowCount = matrixType->getRowCount(); + auto columnCount = matrixType->getColumnCount(); - IRInst* newInst = inst; + IRBuilder builder(matrixType); + builder.setInsertBefore(matrixType); + + // Create vector type for columns: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type for rows: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + return arrayType; + } + return inst; + } + + IRInst* legalizeMakeMatrix(IRInst* inst) + { + auto makeMatrix = as<IRMakeMatrix>(inst); + auto matrixType = as<IRMatrixType>(makeMatrix->getDataType()); + + SLANG_ASSERT(matrixType && "Matrix type is expected"); + SLANG_ASSERT( + shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization"); + + // Lower makeMatrix to makeArray of makeVectors + auto elementType = matrixType->getElementType(); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - if (auto matrixType = as<IRMatrixType>(inst)) + IRBuilder builder(makeMatrix); + builder.setInsertBefore(makeMatrix); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Group operands into rows and create vectors + List<IRInst*> rowVectors; + UInt operandIndex = 0; + + // Assert that we have the expected number of operands + SLANG_ASSERT( + makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && + "makeMatrix operand count must match matrix dimensions"); + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { - if (shouldLowerMatrixType(matrixType)) + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) { - // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) - auto elementType = matrixType->getElementType(); - auto rowCount = matrixType->getRowCount(); - auto columnCount = matrixType->getColumnCount(); + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } + + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + + SLANG_ASSERT( + rowVectors.getCount() == rowCount->getValue() && + "Row vectors count must match matrix row count"); + return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer()); + } + + IRInst* legalizeMatrixMatrixBinaryOperation( + IRBuilder& builder, + IRInst* legalizedA, + IRInst* legalizedB, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - IRBuilder builder(matrixType); - builder.setInsertBefore(matrixType); + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); - // Create vector type for columns: vector<T, C> - auto vectorType = builder.getVectorType(elementType, columnCount); + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); - // Create array type for rows: vector<T, C>[R] - auto arrayType = builder.getArrayType(vectorType, rowCount); + // Extract vectors from both arrays and apply binary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); - newInst = arrayType; + // Apply the binary operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + + template<bool matrixIsFirst> + IRInst* legalizeMatrixMixedBinaryOperation( + IRBuilder& builder, + IRInst* legalizedMatrix, + IRInst* legalizedOther, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + // Verify that the other operand is either a vector or scalar type + auto otherType = legalizedOther->getDataType(); + auto otherVectorType = as<IRVectorType>(otherType); + auto otherBasicType = as<IRBasicType>(otherType); + SLANG_ASSERT( + (otherVectorType || otherBasicType) && "Other operand must be vector or scalar type"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from matrix array and apply binary operation with other operand + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from matrix array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto matrixRowVector = builder.emitElementExtract(legalizedMatrix, rowIndexInst); + + // Apply the binary operation between matrix row vector and other operand + IRInst* args[2]; + if constexpr (matrixIsFirst) + { + args[0] = matrixRowVector; + args[1] = legalizedOther; } + else + { + args[0] = legalizedOther; + args[1] = matrixRowVector; + } + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeBinaryOperation(IRInst* inst, IROp binaryOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Binary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands once + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + if (shouldLowerA && shouldLowerB) + { + return legalizeMatrixMatrixBinaryOperation( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (shouldLowerA && !shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<true>( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (!shouldLowerA && shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<false>( + builder, + legalizedB, + legalizedA, + resultMatrixType, + binaryOp); + } + + // Neither operand is a matrix that needs lowering, shouldn't reach here + SLANG_UNREACHABLE("legalizeBinaryOperation called but no matrix operand needs lowering"); + } + + IRInst* legalizeComparisonOperation(IRInst* inst, IROp comparisonOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Only matrix-matrix comparisons are supported + SLANG_ASSERT( + shouldLowerA && shouldLowerB && + "Comparison operations only supported between matrices that need lowering"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + auto rowCount = as<IRIntLit>(matrixTypeA->getRowCount()); + auto columnCount = as<IRIntLit>(matrixTypeA->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create boolean vector type for rows: vector<bool, C> + auto boolType = builder.getBoolType(); + auto boolVectorType = builder.getVectorType(boolType, columnCount); + + // Create array type: vector<bool, C>[R] + auto boolArrayType = builder.getArrayType(boolVectorType, rowCount); + + // Extract vectors from both arrays and apply comparison operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); + + // Apply the comparison operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(boolVectorType, comparisonOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + boolArrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeUnaryOperation(IRInst* inst, IROp unaryOp) + { + IRInst* operand = inst->getOperand(0); + + // Get the legalized operand (should be an array of vectors) + IRInst* legalizedOperand = getReplacement(operand); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Unary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from array and apply unary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vector = builder.emitElementExtract(legalizedOperand, rowIndexInst); + + // Apply the unary operation to the vector + IRInst* args[] = {vector}; + auto resultVector = builder.emitIntrinsicInst(vectorType, unaryOp, 1, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeMatrixProducingInstruction(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeMatrix: + return legalizeMakeMatrix(inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + return legalizeBinaryOperation(inst, inst->getOp()); + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + return legalizeComparisonOperation(inst, inst->getOp()); + case kIROp_Not: + case kIROp_BitNot: + case kIROp_Neg: + return legalizeUnaryOperation(inst, inst->getOp()); + default: + break; + } + + return inst; + } + + IRInst* getReplacement(IRInst* inst) + { + if (auto replacement = replacements.tryGetValue(inst)) + return *replacement; + + IRInst* newInst = inst; + if (as<IRMatrixType>(inst)) + newInst = legalizeMatrixTypeDeclaration(inst); + + IRType* resultType = inst->getDataType(); + if (auto matrixType = as<IRMatrixType>(resultType)) + { + if (shouldLowerMatrixType(matrixType)) + newInst = legalizeMatrixProducingInstruction(inst); } replacements[inst] = newInst; diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h index 418e80a83..a2e71a402 100644 --- a/source/slang/slang-ir-legalize-matrix-types.h +++ b/source/slang/slang-ir-legalize-matrix-types.h @@ -7,7 +7,7 @@ struct IRModule; class DiagnosticSink; class TargetProgram; -// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets +// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, GLSL, and Metal targets void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink); } // namespace Slang
\ No newline at end of file |
