diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 65 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 170 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 435 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.h | 2 |
6 files changed, 523 insertions, 168 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index aa494ec95..9fd5c8b6e 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6481,14 +6481,10 @@ bool all(T x) { __target_switch { - default: - __intrinsic_asm "bool($0)"; case hlsl: __intrinsic_asm "all"; case metal: __intrinsic_asm "all"; - case wgsl: - __intrinsic_asm "all"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6505,6 +6501,8 @@ bool all(T x) return __slang_noop_cast<bool>(x); else return false; + default: + __intrinsic_asm "bool($0)"; } } @@ -6550,9 +6548,17 @@ bool all(vector<T,N> x) }; } case wgsl: + // WGSL all() only works with boolean vectors if (__isBool<T>()) - __intrinsic_asm "all"; - __intrinsic_asm "all(vec$N0<bool>($0))"; + __intrinsic_asm "all($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = true; + for(int i = 0; i < N; ++i) + result = result && all(x[i]); + return result; + } default: bool result = true; for(int i = 0; i < N; ++i) @@ -6563,7 +6569,7 @@ bool all(vector<T,N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_metal_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool all(matrix<T,N,M> x) { __target_switch @@ -6655,7 +6661,8 @@ bool any(T x) case metal: __intrinsic_asm "any"; case wgsl: - __intrinsic_asm "any"; + // For scalars, any() doesn't exist in WGSL, just convert to bool + __intrinsic_asm "bool($0)"; case spirv: let zero = __default<T>(); if (__isInt<T>()) @@ -6686,7 +6693,17 @@ bool any(vector<T, N> x) case hlsl: __intrinsic_asm "any"; case metal: - __intrinsic_asm "any"; + if (__isBool<T>()) + __intrinsic_asm "any"; + else + { + // For non-bool types, convert to bool vector first + // Metal's any() only works with bool vectors + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } case glsl: __intrinsic_asm "any(bvec$N0($0))"; case spirv: @@ -6714,7 +6731,17 @@ bool any(vector<T, N> x) }; } case wgsl: - __intrinsic_asm "any"; + // WGSL any() only works with boolean vectors + if (__isBool<T>()) + __intrinsic_asm "any($0)"; + else + { + // Fall back to loop for non-boolean types since WGSL doesn't support direct conversion + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } default: bool result = false; for(int i = 0; i < N; ++i) @@ -6725,7 +6752,7 @@ bool any(vector<T, N> x) __generic<T : __BuiltinType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl)] bool any(matrix<T, N, M> x) { __target_switch @@ -8626,11 +8653,8 @@ T determinant(matrix<T,N,N> m) { __target_switch { - case glsl: __intrinsic_asm "determinant"; case hlsl: __intrinsic_asm "determinant"; - case metal: __intrinsic_asm "determinant"; - case wgsl: __intrinsic_asm "determinant"; - // SPIR-V doesn't support integer determinants, so we need to implement it manually + // GLSL, WGSL, and SPIR-V don't support integer determinants for lowered matrices, so we need to implement it manually default: static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices"); if (N == 1) @@ -13804,16 +13828,14 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinIntegerType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; - case wgsl: __intrinsic_asm "transpose"; - // SPIRV-V doenst't support integer matrices, so transpose it manually + // GLSL, WGSL, SPIR-V, and Metal don't support integer matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) @@ -13824,19 +13846,18 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) } __generic<T : __BuiltinLogicalType, let N : int, let M : int> [__readNone] -[require(cpp_cuda_glsl_hlsl_spirv_wgsl, sm_4_0_version)] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_4_0_version)] [PreferRecompute] [OverloadRank(-1)] matrix<T, M, N> transpose(matrix<T, N, M> x) { __target_switch { - case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; case spirv: return spirv_asm { OpTranspose $$matrix<T, M, N> result $x }; - case wgsl: __intrinsic_asm "transpose"; + // GLSL, WGSL, and Metal don't support bool matrices when lowered, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b1b4c4570..da2620856 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string's nul-termination character (0), and + // > The final word contains the string’s nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,24 +2039,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); - auto elementType = matrixType->getElementType(); - - // SPIR-V only supports floating-point matrices - // bool/int matrices should be lowered to - // arrays of vectors before reaching here - SLANG_ASSERT(!as<IRBoolType>(elementType)); - SLANG_ASSERT(!as<IRIntType>(elementType)); - SLANG_ASSERT(!as<IRUIntType>(elementType)); - auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(elementType)->getBaseType(), + static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); - SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); - return matrixSpvType; + auto matrixSPVType = emitOpTypeMatrix( + inst, + vectorSpvType, + SpvLiteralInteger::from32(int32_t(columnCount))); + return matrixSPVType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2628,7 +2621,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7780,40 +7773,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - - IRIntegerValue rowCount; - IRIntegerValue colCount; - IRType* elementType; - - // Data type can be either matrix or vector depending on the - // legalization requirements - auto dataType = inst->getDataType(); - - if (auto matrixType = as<IRMatrixType>(dataType)) - { - elementType = matrixType->getElementType(); - rowCount = getIntVal(matrixType->getRowCount()); - colCount = getIntVal(matrixType->getColumnCount()); - } - else if (auto arrayType = as<IRArrayType>(dataType)) - { - auto vectorType = as<IRVectorType>(arrayType->getElementType()); - SLANG_ASSERT(vectorType); - - elementType = vectorType->getElementType(); - rowCount = getIntVal(arrayType->getElementCount()); - colCount = getIntVal(vectorType->getElementCount()); - } - else - { - SLANG_UNEXPECTED("data type for makeMatrix operation is " - "expected be either a matrix or array type"); - } - + auto matrixType = cast<IRMatrixType>(inst->getDataType()); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7938,10 +7903,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); - SLANG_ASSERT(elementType); - IRBasicType* basicType = as<IRBasicType>(elementType); - SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -8002,52 +7964,6 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } - // Helper method to handle composite arithmetic operations for matrices and arrays - SpvInst* emitCompositeArithmetic( - SpvInstParent* parent, - IRInst* inst, - IRIntegerValue rowCount, - IRIntegerValue colCount, - IRType* elementType, - IRType* resultType, - bool isMatrixType) - { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(elementType, colCount); - List<SpvInst*> rows; - - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - bool shouldExtract = - isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr - : as<IRArrayType>(originalOperand->getDataType()) != nullptr; - - if (shouldExtract) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, resultType, rows); - } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -8055,38 +7971,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - matrixType->getElementType(), - inst->getDataType(), - true); - } - else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) - { - // Only for legalization - auto arrayElementType = arrayType->getElementType(); - SLANG_ASSERT(as<IRVectorType>(arrayElementType)); - - auto vectorType = as<IRVectorType>(arrayElementType); - auto elementType = vectorType->getElementType(); - SLANG_ASSERT( - as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || - as<IRIntType>(elementType)); - - auto rowCount = getIntVal(arrayType->getElementCount()); - auto colCount = getIntVal(vectorType->getElementCount()); - - return emitCompositeArithmetic( - parent, - inst, - rowCount, - colCount, - elementType, - inst->getDataType(), - false); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + List<SpvInst*> rows; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + if (as<IRMatrixType>(originalOperand->getDataType())) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); } Array<IRInst*, 4> operands; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index fbcb54d10..53c3aa487 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -1624,6 +1624,22 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu m_writer->emit(")"); return true; } + case kIROp_Neg: + { + auto opType = inst->getOperand(0)->getDataType(); + if (as<IRMatrixType>(opType) || as<IRVectorType>(opType)) + { + // WGSL does not support negate operator on matrices and vectors, + // we should emit "(type(0) - op0)" instead. + m_writer->emit("("); + emitType(inst->getDataType()); + m_writer->emit("(0) - "); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + } } return false; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index b548ef632..405bca5a2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1339,7 +1339,10 @@ Result linkAndOptimizeIR( } legalizeMatrixTypes(targetProgram, irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-MATRIX-LEGALIZATION"); + legalizeVectorTypes(irModule, sink); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-VECTOR-LEGALIZATION"); // Once specialization and type legalization have been performed, // we should perform some of our basic optimization steps again, diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp index 0b972b5bd..8c8cb0c84 100644 --- a/source/slang/slang-ir-legalize-matrix-types.cpp +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -1,6 +1,7 @@ #include "slang-ir-legalize-matrix-types.h" #include "slang-compiler.h" +#include "slang-ir-insts-enum.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -50,6 +51,9 @@ struct MatrixTypeLoweringContext case CodeGenTarget::WGSL: case CodeGenTarget::WGSLSPIRV: case CodeGenTarget::WGSLSPIRVAssembly: + case CodeGenTarget::Metal: + case CodeGenTarget::MetalLib: + case CodeGenTarget::MetalLibAssembly: return true; default: return false; @@ -66,33 +70,430 @@ struct MatrixTypeLoweringContext as<IRIntType>(elementType); } - IRInst* getReplacement(IRInst* inst) + IRInst* legalizeMatrixTypeDeclaration(IRInst* inst) { - if (auto replacement = replacements.tryGetValue(inst)) - return *replacement; + auto matrixType = as<IRMatrixType>(inst); + if (shouldLowerMatrixType(matrixType)) + { + // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) + auto elementType = matrixType->getElementType(); + auto rowCount = matrixType->getRowCount(); + auto columnCount = matrixType->getColumnCount(); - IRInst* newInst = inst; + IRBuilder builder(matrixType); + builder.setInsertBefore(matrixType); + + // Create vector type for columns: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type for rows: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + return arrayType; + } + return inst; + } + + IRInst* legalizeMakeMatrix(IRInst* inst) + { + auto makeMatrix = as<IRMakeMatrix>(inst); + auto matrixType = as<IRMatrixType>(makeMatrix->getDataType()); + + SLANG_ASSERT(matrixType && "Matrix type is expected"); + SLANG_ASSERT( + shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization"); + + // Lower makeMatrix to makeArray of makeVectors + auto elementType = matrixType->getElementType(); + auto rowCount = as<IRIntLit>(matrixType->getRowCount()); + auto columnCount = as<IRIntLit>(matrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - if (auto matrixType = as<IRMatrixType>(inst)) + IRBuilder builder(makeMatrix); + builder.setInsertBefore(makeMatrix); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Group operands into rows and create vectors + List<IRInst*> rowVectors; + UInt operandIndex = 0; + + // Assert that we have the expected number of operands + SLANG_ASSERT( + makeMatrix->getOperandCount() == UInt(rowCount->getValue() * columnCount->getValue()) && + "makeMatrix operand count must match matrix dimensions"); + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) { - if (shouldLowerMatrixType(matrixType)) + List<IRInst*> rowElements; + for (IRIntegerValue col = 0; col < columnCount->getValue(); col++) { - // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) - auto elementType = matrixType->getElementType(); - auto rowCount = matrixType->getRowCount(); - auto columnCount = matrixType->getColumnCount(); + SLANG_ASSERT( + operandIndex < makeMatrix->getOperandCount() && "Operand index out of bounds"); + rowElements.add(getReplacement(makeMatrix->getOperand(operandIndex))); + operandIndex++; + } + + SLANG_ASSERT( + rowElements.getCount() == columnCount->getValue() && + "Row elements count must match column count"); + auto rowVector = builder.emitMakeVector(vectorType, rowElements); + rowVectors.add(rowVector); + } + + SLANG_ASSERT( + rowVectors.getCount() == rowCount->getValue() && + "Row vectors count must match matrix row count"); + return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer()); + } + + IRInst* legalizeMatrixMatrixBinaryOperation( + IRBuilder& builder, + IRInst* legalizedA, + IRInst* legalizedB, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); - IRBuilder builder(matrixType); - builder.setInsertBefore(matrixType); + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); - // Create vector type for columns: vector<T, C> - auto vectorType = builder.getVectorType(elementType, columnCount); + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); - // Create array type for rows: vector<T, C>[R] - auto arrayType = builder.getArrayType(vectorType, rowCount); + // Extract vectors from both arrays and apply binary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); - newInst = arrayType; + // Apply the binary operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + + template<bool matrixIsFirst> + IRInst* legalizeMatrixMixedBinaryOperation( + IRBuilder& builder, + IRInst* legalizedMatrix, + IRInst* legalizedOther, + IRMatrixType* resultMatrixType, + IROp binaryOp) + { + // Verify that the other operand is either a vector or scalar type + auto otherType = legalizedOther->getDataType(); + auto otherVectorType = as<IRVectorType>(otherType); + auto otherBasicType = as<IRBasicType>(otherType); + SLANG_ASSERT( + (otherVectorType || otherBasicType) && "Other operand must be vector or scalar type"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from matrix array and apply binary operation with other operand + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from matrix array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto matrixRowVector = builder.emitElementExtract(legalizedMatrix, rowIndexInst); + + // Apply the binary operation between matrix row vector and other operand + IRInst* args[2]; + if constexpr (matrixIsFirst) + { + args[0] = matrixRowVector; + args[1] = legalizedOther; } + else + { + args[0] = legalizedOther; + args[1] = matrixRowVector; + } + auto resultVector = builder.emitIntrinsicInst(vectorType, binaryOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeBinaryOperation(IRInst* inst, IROp binaryOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Binary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands once + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + if (shouldLowerA && shouldLowerB) + { + return legalizeMatrixMatrixBinaryOperation( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (shouldLowerA && !shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<true>( + builder, + legalizedA, + legalizedB, + resultMatrixType, + binaryOp); + } + else if (!shouldLowerA && shouldLowerB) + { + return legalizeMatrixMixedBinaryOperation<false>( + builder, + legalizedB, + legalizedA, + resultMatrixType, + binaryOp); + } + + // Neither operand is a matrix that needs lowering, shouldn't reach here + SLANG_UNREACHABLE("legalizeBinaryOperation called but no matrix operand needs lowering"); + } + + IRInst* legalizeComparisonOperation(IRInst* inst, IROp comparisonOp) + { + IRInst* opdA = inst->getOperand(0); + IRInst* opdB = inst->getOperand(1); + + // Check what types we're dealing with + auto typeA = opdA->getDataType(); + auto typeB = opdB->getDataType(); + + auto matrixTypeA = as<IRMatrixType>(typeA); + auto matrixTypeB = as<IRMatrixType>(typeB); + + bool shouldLowerA = matrixTypeA && shouldLowerMatrixType(matrixTypeA); + bool shouldLowerB = matrixTypeB && shouldLowerMatrixType(matrixTypeB); + + // Only matrix-matrix comparisons are supported + SLANG_ASSERT( + shouldLowerA && shouldLowerB && + "Comparison operations only supported between matrices that need lowering"); + + // Create IRBuilder at the top level + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Get legalized operands + IRInst* legalizedA = getReplacement(opdA); + IRInst* legalizedB = getReplacement(opdB); + + auto rowCount = as<IRIntLit>(matrixTypeA->getRowCount()); + auto columnCount = as<IRIntLit>(matrixTypeA->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + // Create boolean vector type for rows: vector<bool, C> + auto boolType = builder.getBoolType(); + auto boolVectorType = builder.getVectorType(boolType, columnCount); + + // Create array type: vector<bool, C>[R] + auto boolArrayType = builder.getArrayType(boolVectorType, rowCount); + + // Extract vectors from both arrays and apply comparison operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from each operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vectorA = builder.emitElementExtract(legalizedA, rowIndexInst); + auto vectorB = builder.emitElementExtract(legalizedB, rowIndexInst); + + // Apply the comparison operation to the vectors + IRInst* args[] = {vectorA, vectorB}; + auto resultVector = builder.emitIntrinsicInst(boolVectorType, comparisonOp, 2, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + boolArrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeUnaryOperation(IRInst* inst, IROp unaryOp) + { + IRInst* operand = inst->getOperand(0); + + // Get the legalized operand (should be an array of vectors) + IRInst* legalizedOperand = getReplacement(operand); + + // Get the result matrix type to determine dimensions + auto resultMatrixType = as<IRMatrixType>(inst->getDataType()); + SLANG_ASSERT(resultMatrixType && "Unary operation should have matrix result type"); + SLANG_ASSERT( + shouldLowerMatrixType(resultMatrixType) && + "Result matrix type should need legalization"); + + auto elementType = resultMatrixType->getElementType(); + auto rowCount = as<IRIntLit>(resultMatrixType->getRowCount()); + auto columnCount = as<IRIntLit>(resultMatrixType->getColumnCount()); + + SLANG_ASSERT( + rowCount && columnCount && + "Matrix dimensions must be compile-time constants for lowering"); + + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Create vector type for rows: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + // Extract vectors from array and apply unary operation + List<IRInst*> resultVectors; + + for (IRIntegerValue row = 0; row < rowCount->getValue(); row++) + { + // Extract the row vector from operand array + auto rowIndexInst = builder.getIntValue(builder.getIntType(), row); + auto vector = builder.emitElementExtract(legalizedOperand, rowIndexInst); + + // Apply the unary operation to the vector + IRInst* args[] = {vector}; + auto resultVector = builder.emitIntrinsicInst(vectorType, unaryOp, 1, args); + + resultVectors.add(resultVector); + } + + // Create the result array from the vectors + return builder.emitMakeArray( + arrayType, + resultVectors.getCount(), + resultVectors.getBuffer()); + } + + IRInst* legalizeMatrixProducingInstruction(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_MakeMatrix: + return legalizeMakeMatrix(inst); + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + return legalizeBinaryOperation(inst, inst->getOp()); + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Leq: + return legalizeComparisonOperation(inst, inst->getOp()); + case kIROp_Not: + case kIROp_BitNot: + case kIROp_Neg: + return legalizeUnaryOperation(inst, inst->getOp()); + default: + break; + } + + return inst; + } + + IRInst* getReplacement(IRInst* inst) + { + if (auto replacement = replacements.tryGetValue(inst)) + return *replacement; + + IRInst* newInst = inst; + if (as<IRMatrixType>(inst)) + newInst = legalizeMatrixTypeDeclaration(inst); + + IRType* resultType = inst->getDataType(); + if (auto matrixType = as<IRMatrixType>(resultType)) + { + if (shouldLowerMatrixType(matrixType)) + newInst = legalizeMatrixProducingInstruction(inst); } replacements[inst] = newInst; diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h index 418e80a83..a2e71a402 100644 --- a/source/slang/slang-ir-legalize-matrix-types.h +++ b/source/slang/slang-ir-legalize-matrix-types.h @@ -7,7 +7,7 @@ struct IRModule; class DiagnosticSink; class TargetProgram; -// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets +// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, GLSL, and Metal targets void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink); } // namespace Slang
\ No newline at end of file |
