diff options
| -rw-r--r-- | source/slang/hlsl.meta.slang | 53 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 172 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-binary-operator.cpp | 157 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.cpp | 141 | ||||
| -rw-r--r-- | source/slang/slang-ir-legalize-matrix-types.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 18 | ||||
| -rw-r--r-- | tests/compute/integer-matrix-diagnostic.slang | 22 | ||||
| -rw-r--r-- | tests/spirv/matrix-bool-lowering.slang | 114 | ||||
| -rw-r--r-- | tests/spirv/matrix-integer-lowering.slang | 189 |
12 files changed, 738 insertions, 152 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 264098bec..2ac886f61 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8616,6 +8616,55 @@ T determinant(matrix<T,N,N> m) } } +/// @param m The matrix. +/// @return The determinant of the matrix. +/// @category math +__generic<T : __BuiltinIntegerType, let N : int> +[__readNone] +[require(glsl_hlsl_metal_spirv_wgsl)] +T determinant(matrix<T,N,N> m) +{ + __target_switch + { + case glsl: __intrinsic_asm "determinant"; + case hlsl: __intrinsic_asm "determinant"; + case metal: __intrinsic_asm "determinant"; + case wgsl: __intrinsic_asm "determinant"; + // SPIR-V doesn't support integer determinants, so we need to implement it manually + default: + static_assert(N >= 1 && N <= 4, "determinant is only implemented up to 4x4 matrices"); + if (N == 1) + { + return m[0][0]; + } + else if (N == 2) + { + return m[0][0] * m[1][1] - m[0][1] * m[1][0]; + } + else if (N == 3) + { + return + m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) + - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) + + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]); + } + else// if (N == 4) + { + T a = m[2][2] * m[3][3] - m[2][3] * m[3][2]; + T b = m[2][1] * m[3][3] - m[2][3] * m[3][1]; + T c = m[2][1] * m[3][2] - m[2][2] * m[3][1]; + T d = m[2][0] * m[3][3] - m[2][3] * m[3][0]; + T e = m[2][0] * m[3][2] - m[2][2] * m[3][0]; + T f = m[2][0] * m[3][1] - m[2][1] * m[3][0]; + return + m[0][0] * (m[1][1] * a - m[1][2] * b + m[1][3] * c) + - m[0][1] * (m[1][0] * a - m[1][2] * d + m[1][3] * e) + + m[0][2] * (m[1][0] * b - m[1][1] * d + m[1][3] * f) + - m[0][3] * (m[1][0] * c - m[1][1] * e + m[1][2] * f); + } + } +} + /// Barrier for device memory. /// @category barrier __glsl_extension(GL_KHR_memory_scope_semantics) @@ -13720,10 +13769,8 @@ matrix<T, M, N> transpose(matrix<T, N, M> x) { case glsl: __intrinsic_asm "transpose"; case hlsl: __intrinsic_asm "transpose"; - case spirv: return spirv_asm { - OpTranspose $$matrix<T, M, N> result $x - }; case wgsl: __intrinsic_asm "transpose"; + // SPIRV-V doenst't support integer matrices, so transpose it manually default: matrix<T, M, N> result; for (int r = 0; r < M; ++r) diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0ce5d9f47..3dafda3be 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -2201,12 +2201,6 @@ DIAGNOSTIC( DIAGNOSTIC(39999, Fatal, complationCeased, "compilation ceased") DIAGNOSTIC( - 38202, - Error, - matrixWithDisallowedElementTypeEncountered, - "matrix with disallowed element type '$0' encountered") - -DIAGNOSTIC( 38203, Error, vectorWithDisallowedElementTypeEncountered, diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 2b6f1c821..bbed44c51 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -216,7 +216,7 @@ struct SpvInst : SpvInstParent // // > Word Count: The complete number of words taken by an instruction, // > including the word holding the word count and opcode, and any optional - // > operands. An instruction’s word count is the total space taken by the instruction. + // > operands. An instruction's word count is the total space taken by the instruction. // SpvWord wordCount = 1 + SpvWord(operandWordsCount); @@ -360,7 +360,7 @@ struct SpvLiteralBits // > UTF-8 encoding scheme. The UTF-8 octets (8-bit bytes) are packed // > four per word, following the little-endian convention (i.e., the // > first octet is in the lowest-order 8 bits of the word). - // > The final word contains the string’s nul-termination character (0), and + // > The final word contains the string's nul-termination character (0), and // > all contents past the end of the string in the final word are padded with 0. // First work out the amount of words we'll need @@ -2039,17 +2039,24 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex case kIROp_MatrixType: { auto matrixType = static_cast<IRMatrixType*>(inst); + auto elementType = matrixType->getElementType(); + + // SPIR-V only supports floating-point matrices + // bool/int matrices should be lowered to + // arrays of vectors before reaching here + SLANG_ASSERT(!as<IRBoolType>(elementType)); + SLANG_ASSERT(!as<IRIntType>(elementType)); + SLANG_ASSERT(!as<IRUIntType>(elementType)); + auto vectorSpvType = ensureVectorType( - static_cast<IRBasicType*>(matrixType->getElementType())->getBaseType(), + static_cast<IRBasicType*>(elementType)->getBaseType(), static_cast<IRIntLit*>(matrixType->getColumnCount())->getValue(), nullptr); const auto columnCount = static_cast<IRIntLit*>(matrixType->getRowCount())->getValue(); - auto matrixSPVType = emitOpTypeMatrix( - inst, - vectorSpvType, - SpvLiteralInteger::from32(int32_t(columnCount))); - return matrixSPVType; + const auto columnCountSpv = SpvLiteralInteger::from32(int32_t(columnCount)); + SpvInst* matrixSpvType = emitOpTypeMatrix(inst, vectorSpvType, columnCountSpv); + return matrixSpvType; } case kIROp_ArrayType: case kIROp_UnsizedArrayType: @@ -2621,7 +2628,7 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SpvWord arrayed = inst->isArray() ? ImageOpConstants::isArrayed : ImageOpConstants::notArrayed; - // Vulkan spec 16.1: "The “Depth” operand of OpTypeImage is ignored." + // Vulkan spec 16.1: "The "Depth" operand of OpTypeImage is ignored." SpvWord depth = ImageOpConstants::unknownDepthImage; // No knowledge of if this is a depth image SpvWord ms = inst->isMultisample() ? ImageOpConstants::isMultisampled @@ -7767,12 +7774,40 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex // Otherwise, operands are raw elements, we need to construct row vectors first, // then construct matrix from row vectors. List<SpvInst*> rowVectors; - auto matrixType = as<IRMatrixType>(inst->getDataType()); - auto rowCount = getIntVal(matrixType->getRowCount()); - auto colCount = getIntVal(matrixType->getColumnCount()); + + IRIntegerValue rowCount; + IRIntegerValue colCount; + IRType* elementType; + + // Data type can be either matrix or vector depending on the + // legalization requirements + auto dataType = inst->getDataType(); + + if (auto matrixType = as<IRMatrixType>(dataType)) + { + elementType = matrixType->getElementType(); + rowCount = getIntVal(matrixType->getRowCount()); + colCount = getIntVal(matrixType->getColumnCount()); + } + else if (auto arrayType = as<IRArrayType>(dataType)) + { + auto vectorType = as<IRVectorType>(arrayType->getElementType()); + SLANG_ASSERT(vectorType); + + elementType = vectorType->getElementType(); + rowCount = getIntVal(arrayType->getElementCount()); + colCount = getIntVal(vectorType->getElementCount()); + } + else + { + SLANG_UNEXPECTED("data type for makeMatrix operation is " + "expected be either a matrix or array type"); + } + IRBuilder builder(inst); builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List<IRInst*> colElements; UInt index = 0; for (IRIntegerValue j = 0; j < rowCount; j++) @@ -7897,7 +7932,10 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex ArrayView<IRInst*> operands) { IRType* elementType = getVectorOrCoopMatrixElementType(operands[0]->getDataType()); + SLANG_ASSERT(elementType); + IRBasicType* basicType = as<IRBasicType>(elementType); + SLANG_ASSERT(basicType); SpvOp opCode = _arithmeticOpCodeConvert(op, basicType); if (opCode == SpvOpUndef) @@ -7958,6 +7996,52 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex SLANG_UNREACHABLE("Arithmetic op with 0 or more than 2 operands"); } + // Helper method to handle composite arithmetic operations for matrices and arrays + SpvInst* emitCompositeArithmetic( + SpvInstParent* parent, + IRInst* inst, + IRIntegerValue rowCount, + IRIntegerValue colCount, + IRType* elementType, + IRType* resultType, + bool isMatrixType) + { + IRBuilder builder(inst); + builder.setInsertBefore(inst); + auto rowVectorType = builder.getVectorType(elementType, colCount); + List<SpvInst*> rows; + + for (IRIntegerValue i = 0; i < rowCount; i++) + { + List<IRInst*> operands; + for (UInt j = 0; j < inst->getOperandCount(); j++) + { + auto originalOperand = inst->getOperand(j); + bool shouldExtract = + isMatrixType ? as<IRMatrixType>(originalOperand->getDataType()) != nullptr + : as<IRArrayType>(originalOperand->getDataType()) != nullptr; + + if (shouldExtract) + { + auto operand = builder.emitElementExtract(originalOperand, i); + emitLocalInst(parent, operand); + operands.add(operand); + } + else + { + operands.add(originalOperand); + } + } + rows.add(emitVectorOrScalarArithmetic( + parent, + nullptr, + rowVectorType, + inst->getOp(), + inst->getOperandCount(), + operands.getArrayView())); + } + return emitCompositeConstruct(parent, inst, resultType, rows); + } SpvInst* emitArithmetic(SpvInstParent* parent, IRInst* inst) { @@ -7965,36 +8049,38 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex { auto rowCount = getIntVal(matrixType->getRowCount()); auto colCount = getIntVal(matrixType->getColumnCount()); - IRBuilder builder(inst); - builder.setInsertBefore(inst); - auto rowVectorType = builder.getVectorType(matrixType->getElementType(), colCount); - List<SpvInst*> rows; - for (IRIntegerValue i = 0; i < rowCount; i++) - { - List<IRInst*> operands; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto originalOperand = inst->getOperand(j); - if (as<IRMatrixType>(originalOperand->getDataType())) - { - auto operand = builder.emitElementExtract(originalOperand, i); - emitLocalInst(parent, operand); - operands.add(operand); - } - else - { - operands.add(originalOperand); - } - } - rows.add(emitVectorOrScalarArithmetic( - parent, - nullptr, - rowVectorType, - inst->getOp(), - inst->getOperandCount(), - operands.getArrayView())); - } - return emitCompositeConstruct(parent, inst, inst->getDataType(), rows); + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + matrixType->getElementType(), + inst->getDataType(), + true); + } + else if (const auto arrayType = as<IRArrayType>(inst->getDataType())) + { + // Only for legalization + auto arrayElementType = arrayType->getElementType(); + SLANG_ASSERT(as<IRVectorType>(arrayElementType)); + + auto vectorType = as<IRVectorType>(arrayElementType); + auto elementType = vectorType->getElementType(); + SLANG_ASSERT( + as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || + as<IRIntType>(elementType)); + + auto rowCount = getIntVal(arrayType->getElementCount()); + auto colCount = getIntVal(vectorType->getElementCount()); + + return emitCompositeArithmetic( + parent, + inst, + rowCount, + colCount, + elementType, + inst->getDataType(), + false); } Array<IRInst*, 4> operands; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index f40679bd9..067b5a551 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -61,6 +61,7 @@ #include "slang-ir-legalize-empty-array.h" #include "slang-ir-legalize-global-values.h" #include "slang-ir-legalize-image-subscript.h" +#include "slang-ir-legalize-matrix-types.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-legalize-uniform-buffer-load.h" #include "slang-ir-legalize-varying-params.h" @@ -1334,6 +1335,7 @@ Result linkAndOptimizeIR( legalizeEmptyTypes(targetProgram, irModule, sink); } + legalizeMatrixTypes(targetProgram, irModule, sink); legalizeVectorTypes(irModule, sink); // Once specialization and type legalization have been performed, diff --git a/source/slang/slang-ir-legalize-binary-operator.cpp b/source/slang/slang-ir-legalize-binary-operator.cpp index f2f7cdef2..24ba61fc6 100644 --- a/source/slang/slang-ir-legalize-binary-operator.cpp +++ b/source/slang/slang-ir-legalize-binary-operator.cpp @@ -176,93 +176,124 @@ void legalizeBinaryOp(IRInst* inst, DiagnosticSink* sink, CodeGenTarget target) void legalizeLogicalAndOr(IRInst* inst) { - switch (inst->getOp()) + auto op = inst->getOp(); + if (op == kIROp_And || op == kIROp_Or) { - case kIROp_And: - case kIROp_Or: + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + // Logical-AND and logical-OR takes boolean types as its operands. + // If they are not, legalize them by casting to boolean type. + // + SLANG_ASSERT(inst->getOperandCount() == 2); + for (UInt i = 0; i < 2; i++) { - IRBuilder builder(inst); - builder.setInsertBefore(inst); - - // Logical-AND and logical-OR takes boolean types as its operands. - // If they are not, legalize them by casting to boolean type. - // - SLANG_ASSERT(inst->getOperandCount() == 2); - for (UInt i = 0; i < 2; i++) - { - auto operand = inst->getOperand(i); - auto operandDataType = operand->getDataType(); + auto operand = inst->getOperand(i); + auto operandDataType = operand->getDataType(); - if (auto vecType = as<IRVectorType>(operandDataType)) - { - if (!as<IRBoolType>(vecType->getElementType())) - { - // Cast operand to vector<bool,N> - auto elemCount = vecType->getElementCount(); - auto vb = builder.getVectorType(builder.getBoolType(), elemCount); - auto v = builder.emitCast(vb, operand); - builder.replaceOperand(inst->getOperands() + i, v); - } - } - else if (!as<IRBoolType>(operandDataType)) - { - // Cast operand to bool - auto s = builder.emitCast(builder.getBoolType(), operand); - builder.replaceOperand(inst->getOperands() + i, s); - } - } + SLANG_ASSERT( + as<IRMatrixType>(operandDataType) || as<IRVectorType>(operandDataType) || + as<IRArrayType>(operandDataType) || as<IRBoolType>(operandDataType)); - // Legalize the return type; mostly for SPIRV. - // The return type of OpLogicalOr must be boolean type. - // If not, we need to recreate the instruction with boolean return type. - // Then, we have to cast it back to the original type so that other instrucitons that - // use have the matching types. - // - auto dataType = inst->getDataType(); - auto lhs = inst->getOperand(0); - auto rhs = inst->getOperand(1); - IRInst* newInst = nullptr; - - if (auto vecType = as<IRVectorType>(dataType)) + if (auto vecType = as<IRVectorType>(operandDataType)) { if (!as<IRBoolType>(vecType->getElementType())) { - // Return type should be vector<bool,N> + // Cast operand to vector<bool,N> auto elemCount = vecType->getElementCount(); auto vb = builder.getVectorType(builder.getBoolType(), elemCount); - - if (inst->getOp() == kIROp_And) - { - newInst = builder.emitAnd(vb, lhs, rhs); - } - else - { - newInst = builder.emitOr(vb, lhs, rhs); - } - newInst = builder.emitCast(dataType, newInst); + auto v = builder.emitCast(vb, operand); + builder.replaceOperand(inst->getOperands() + i, v); } } - else if (!as<IRBoolType>(dataType)) + } + + // Legalize the return type; mostly for SPIRV. + // The return type of OpLogicalOr must be boolean type. + // If not, we need to recreate the instruction with boolean return type. + // Then, we have to cast it back to the original type so that other instrucitons that + // use have the matching types. + // + auto dataType = inst->getDataType(); + auto lhs = inst->getOperand(0); + auto rhs = inst->getOperand(1); + IRInst* newInst = nullptr; + + SLANG_ASSERT( + as<IRMatrixType>(dataType) || as<IRVectorType>(dataType) || as<IRBoolType>(dataType) || + as<IRArrayType>(dataType)); + if (auto vecType = as<IRVectorType>(dataType)) + { + if (!as<IRBoolType>(vecType->getElementType())) { - // Return type should be bool + // Return type should be vector<bool,N> + auto elemCount = vecType->getElementCount(); + auto vb = builder.getVectorType(builder.getBoolType(), elemCount); + if (inst->getOp() == kIROp_And) { - newInst = builder.emitAnd(builder.getBoolType(), lhs, rhs); + newInst = builder.emitAnd(vb, lhs, rhs); } else { - newInst = builder.emitOr(builder.getBoolType(), lhs, rhs); + newInst = builder.emitOr(vb, lhs, rhs); } newInst = builder.emitCast(dataType, newInst); } + } + else if (auto arrayType = as<IRArrayType>(dataType)) + { + // Handle lowered matrices (arrays of vectors) + auto arrayVecType = as<IRVectorType>(arrayType->getElementType()); + SLANG_ASSERT(arrayVecType); + + // At this point, lhs and rhs should already be converted to bool arrays + auto lhsArrayType = as<IRArrayType>(lhs->getDataType()); + auto rhsArrayType = as<IRArrayType>(rhs->getDataType()); + SLANG_ASSERT(lhsArrayType && rhsArrayType); + + auto lhsVecType = as<IRVectorType>(lhsArrayType->getElementType()); + auto rhsVecType = as<IRVectorType>(rhsArrayType->getElementType()); + SLANG_ASSERT(lhsVecType && rhsVecType); + + SLANG_ASSERT( + as<IRBoolType>(lhsVecType->getElementType()) && + as<IRBoolType>(rhsVecType->getElementType())); - if (newInst && inst != newInst) + auto arraySize = arrayType->getElementCount(); + List<IRInst*> resultElements; + + // Extract each vector from both arrays, perform AND/OR, collect results + for (IRIntegerValue i = 0; i < getIntVal(arraySize); i++) { - inst->replaceUsesWith(newInst); - inst->removeAndDeallocate(); + auto indexVal = builder.getIntValue(builder.getIntType(), i); + auto lhsElement = builder.emitElementExtract(lhs, indexVal); + auto rhsElement = builder.emitElementExtract(rhs, indexVal); + + IRInst* resultElement; + if (inst->getOp() == kIROp_And) + { + resultElement = + builder.emitAnd(lhsElement->getDataType(), lhsElement, rhsElement); + } + else + { + resultElement = + builder.emitOr(lhsElement->getDataType(), lhsElement, rhsElement); + } + resultElements.add(resultElement); } + + // Construct the result array from the individual vector results + newInst = + builder.emitMakeArray(dataType, getIntVal(arraySize), resultElements.getBuffer()); + } + + if (newInst && inst != newInst) + { + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); } - break; } for (auto child : inst->getModifiableChildren()) diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp new file mode 100644 index 000000000..0b972b5bd --- /dev/null +++ b/source/slang/slang-ir-legalize-matrix-types.cpp @@ -0,0 +1,141 @@ +#include "slang-ir-legalize-matrix-types.h" + +#include "slang-compiler.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct MatrixTypeLoweringContext +{ + TargetProgram* targetProgram; + IRModule* module; + DiagnosticSink* sink; + + InstWorkList workList; + InstHashSet workListSet; + + Dictionary<IRInst*, IRInst*> replacements; + + MatrixTypeLoweringContext(TargetProgram* targetProgram, IRModule* module) + : targetProgram(targetProgram), module(module), workList(module), workListSet(module) + { + } + + void addToWorkList(IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as<IRGeneric>(ii)) + return; + } + + if (workListSet.contains(inst)) + return; + + workList.add(inst); + workListSet.add(inst); + } + + bool shouldLowerTarget() + { + auto target = targetProgram->getTargetReq()->getTarget(); + switch (target) + { + case CodeGenTarget::SPIRV: + case CodeGenTarget::SPIRVAssembly: + case CodeGenTarget::GLSL: + case CodeGenTarget::WGSL: + case CodeGenTarget::WGSLSPIRV: + case CodeGenTarget::WGSLSPIRVAssembly: + return true; + default: + return false; + } + } + + bool shouldLowerMatrixType(IRMatrixType* matrixType) + { + if (!shouldLowerTarget()) + return false; + + auto elementType = matrixType->getElementType(); + return as<IRBoolType>(elementType) || as<IRUIntType>(elementType) || + as<IRIntType>(elementType); + } + + IRInst* getReplacement(IRInst* inst) + { + if (auto replacement = replacements.tryGetValue(inst)) + return *replacement; + + IRInst* newInst = inst; + + if (auto matrixType = as<IRMatrixType>(inst)) + { + if (shouldLowerMatrixType(matrixType)) + { + // Lower matrix<T, R, C> to T[R][C] (array of R vectors of length C) + auto elementType = matrixType->getElementType(); + auto rowCount = matrixType->getRowCount(); + auto columnCount = matrixType->getColumnCount(); + + IRBuilder builder(matrixType); + builder.setInsertBefore(matrixType); + + // Create vector type for columns: vector<T, C> + auto vectorType = builder.getVectorType(elementType, columnCount); + + // Create array type for rows: vector<T, C>[R] + auto arrayType = builder.getArrayType(vectorType, rowCount); + + newInst = arrayType; + } + } + + replacements[inst] = newInst; + return newInst; + } + + void processModule() + { + addToWorkList(module->getModuleInst()); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.remove(inst); + + // Run this inst through the replacer + getReplacement(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + + // Apply all replacements + for (const auto& [old, replacement] : replacements) + { + if (old != replacement) + { + old->replaceUsesWith(replacement); + old->removeAndDeallocate(); + } + } + } +}; + +void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink) +{ + MatrixTypeLoweringContext context(targetProgram, module); + context.sink = sink; + context.processModule(); +} + +} // namespace Slang
\ No newline at end of file diff --git a/source/slang/slang-ir-legalize-matrix-types.h b/source/slang/slang-ir-legalize-matrix-types.h new file mode 100644 index 000000000..418e80a83 --- /dev/null +++ b/source/slang/slang-ir-legalize-matrix-types.h @@ -0,0 +1,13 @@ +#pragma once + +namespace Slang +{ + +struct IRModule; +class DiagnosticSink; +class TargetProgram; + +// Lower int/uint/bool matrix types to arrays for SPIRV, WGSL, and GLSL targets +void legalizeMatrixTypes(TargetProgram* targetProgram, IRModule* module, DiagnosticSink* sink); + +} // namespace Slang
\ No newline at end of file diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 1b11f8165..a1d043dff 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -453,6 +453,9 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type) switch (type->getOp()) { + case kIROp_BoolType: + sb << "bool"; + break; case kIROp_FloatType: sb << "float"; break; diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index bf5d8ed5d..156fe249f 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -598,21 +598,9 @@ void validateVectorsAndMatrices( } } - // Verify that the element type is a floating point type, or an allowed integral type - auto elementType = matrixType->getElementType(); - uint32_t allowedWidths = 0U; - if (isCPUTarget(targetRequest)) - allowedWidths = 8U | 16U | 32U | 64U; - else if (isCUDATarget(targetRequest)) - allowedWidths = 32U | 64U; - else if (isD3DTarget(targetRequest)) - allowedWidths = 16U | 32U; - validateVectorOrMatrixElementType( - sink, - matrixType->sourceLoc, - elementType, - allowedWidths, - Diagnostics::matrixWithDisallowedElementTypeEncountered); + // Matrix element type validation removed to allow integer/bool matrices + // which will be lowered to arrays of vectors on targets that don't support them + // natively } else if (auto vectorType = as<IRVectorType>(globalInst)) { diff --git a/tests/compute/integer-matrix-diagnostic.slang b/tests/compute/integer-matrix-diagnostic.slang deleted file mode 100644 index bd69c28e4..000000000 --- a/tests/compute/integer-matrix-diagnostic.slang +++ /dev/null @@ -1,22 +0,0 @@ -// Check that using matrices with integer floating point type yields the correct diagnostic - -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target glsl -entry computeMain -stage compute -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target metal -entry computeMain -stage compute -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute -//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target wgsl -entry computeMain -stage compute - -cbuffer MatrixBuffer -{ - // CHECK: error 38202 - int4x4 iMatrix; -} - -//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name=outputBuffer -RWStructuredBuffer<int4> outputBuffer; - -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - uint index = dispatchThreadID.x; - outputBuffer[index] = iMatrix[0][0]; -}
\ No newline at end of file diff --git a/tests/spirv/matrix-bool-lowering.slang b/tests/spirv/matrix-bool-lowering.slang new file mode 100644 index 000000000..63b7caacf --- /dev/null +++ b/tests/spirv/matrix-bool-lowering.slang @@ -0,0 +1,114 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -vk -shaderobj -xslang -emit-spirv-directly + +//TEST_INPUT:ubuffer(data=[1 0], stride=4):in,name inputBuffer +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> inputBuffer; +RWStructuredBuffer<int> outputBuffer; + +// Global bool constants to avoid constant folding +static bool trueVal; +static bool falseVal; + +struct matrixWrapper { + bool2x2 mat1 = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x3 mat2 = bool2x3(trueVal, trueVal, falseVal, falseVal, falseVal, trueVal); +} + +bool elementAnd(bool2x2 matrix) +{ + return trueVal + && matrix[0][0] + && matrix[0][1] + && matrix[1][0] + && matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Load true/false values from input buffer to avoid constant folding + trueVal = inputBuffer[0] != 0; + falseVal = inputBuffer[1] != 0; + + // Test bool matrix construction + bool2x2 mat1 = bool2x2(trueVal, falseVal, falseVal, trueVal); + bool3x3 mat2 = bool3x3( + trueVal, falseVal, trueVal, + falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal + ); + bool2x4 mat3 = bool2x4( + trueVal, falseVal, trueVal, falseVal, + trueVal, falseVal, trueVal, falseVal + ); + + // Test bool matrix element access + bool val1 = mat1[0][0]; + bool val2 = mat2[2][1]; + + // Test bool matrix row access + bool2 row = mat1[1]; + bool3 row3 = mat2[0]; + + // Test logical operations + bool2x2 not_mat = !mat1; + bool2x2 and_mat = mat1 && bool2x2(trueVal, trueVal, falseVal, falseVal); + + // Test element assignment + mat1[0][1] = trueVal; + mat2[1][2] = falseVal; + + // Test passing bool matrices to functions + bool anded = elementAnd(mat1); + + // Test structs with bool matrix fields + matrixWrapper wrapper = {}; + + // Test any/all operations + bool2x2 all_true = bool2x2(trueVal, trueVal, trueVal, trueVal); + bool2x2 all_false = bool2x2(falseVal, falseVal, falseVal, falseVal); + bool2x2 mixed = bool2x2(trueVal, falseVal, trueVal, falseVal); + + bool test_all_true = all(all_true); // all elements true -> true + bool test_all_false = all(all_false); // all elements false -> false + bool test_all_mixed = all(mixed); // some elements false -> false + bool test_any_true = any(all_true); // some elements true -> true + bool test_any_false = any(all_false); // no elements true -> false + bool test_any_mixed = any(mixed); // some elements true -> true + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 0 + outputBuffer[2] = row.x; + // CHECK-NEXT: 0 + outputBuffer[3] = row.y; + // CHECK-NEXT: 1 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 0 + outputBuffer[5] = not_mat[0][0]; + // CHECK-NEXT: 0 + outputBuffer[6] = and_mat[0][0]; + // CHECK-NEXT: 1 + outputBuffer[7] = mat1[0][1]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat3[0][1]; + // CHECK-NEXT: 0 + outputBuffer[9] = anded; + // CHECK-NEXT: 0 + outputBuffer[10] = wrapper.mat1[0][0] || wrapper.mat2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[11] = test_all_true; + // CHECK-NEXT: 1 + outputBuffer[12] = test_all_false; + // CHECK-NEXT: 0 + outputBuffer[13] = test_all_mixed; + // CHECK-NEXT: 0 + outputBuffer[14] = test_any_true; + // CHECK-NEXT: 1 + outputBuffer[15] = test_any_false; + // CHECK-NEXT: 0 + outputBuffer[16] = test_any_mixed; + // CHECK-NEXT: 1 +} diff --git a/tests/spirv/matrix-integer-lowering.slang b/tests/spirv/matrix-integer-lowering.slang new file mode 100644 index 000000000..518d0f78b --- /dev/null +++ b/tests/spirv/matrix-integer-lowering.slang @@ -0,0 +1,189 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -vk -shaderobj -xslang -emit-spirv-directly -xslang -DTYPE=int +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -output-using-type -compute -vk -shaderobj -xslang -emit-spirv-directly -xslang -DTYPE=uint + +#ifndef TYPE +#define TYPE int +#endif + +typealias m2x2 = matrix<TYPE, 2, 2>; +typealias m2x3 = matrix<TYPE, 2, 3>; +typealias m3x3 = matrix<TYPE, 3, 3>; +typealias m2x4 = matrix<TYPE, 2, 4>; + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<TYPE> outputBuffer; + +struct matrixWrapper { + m2x2 mat1 = m2x2(1, 2, 3, 4); + m2x3 mat2 = m2x3(5, 6, 7, 8, 9, 10); +}; + +TYPE elementAdd(m2x2 matrix) +{ + return matrix[0][0] + + matrix[0][1] + + matrix[1][0] + + matrix[1][1]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Test matrix construction + m2x2 mat1 = m2x2(1, 2, 3, 4); + m3x3 mat2 = m3x3( + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ); + m2x4 mat3 = m2x4( + 10, 11, 12, 13, + 14, 15, 16, 17 + ); + + // Test matrix element access + TYPE val1 = mat1[0][0]; + TYPE val2 = mat2[2][1]; + + // Test matrix row access + vector<TYPE, 2> row = mat1[1]; + vector<TYPE, 3> row3 = mat2[0]; + + // Test arithmetic operations + m2x2 mat5 = m2x2(2, 4, 6, 7); + + m2x2 mat_scalar = 2 * mat1; + m2x2 mat_add = mat1 + mat5; + m2x2 mat_sub = mat5 - mat1; + m2x2 mat_mul = mat1 * mat5; + + // Test passing matrices to functions + TYPE added = elementAdd(mat1); + + // Test structs with matrix fields + matrixWrapper wrapper = {}; + + // Test matrix intrinsic operations + + // Test determinant for square matrices + m2x2 mat6 = m2x2(2, 1, 4, 3); + TYPE det2x2 = TYPE(determinant(mat6)); + TYPE det3x3 = TYPE(determinant(mat2)); + + // Test transpose + matrix<TYPE, 2, 2> trans2x2 = transpose(mat1); + matrix<TYPE, 3, 2> trans2x3 = transpose(wrapper.mat2); + + // Test element-wise min/max + m2x2 mat_min = min(mat1, mat5); + m2x2 mat_max = max(mat1, mat5); + + // Test all/any operations (these return bool, but we'll cast to TYPE for output) + m2x2 zero_mat = m2x2(0, 0, 0, 0); + m2x2 mixed_mat = m2x2(1, 0, 2, 0); + + TYPE all_nonzero = TYPE(all(mat1)); + TYPE all_zero = TYPE(all(zero_mat)); + TYPE any_nonzero = TYPE(any(mixed_mat)); + TYPE any_zero = TYPE(any(zero_mat)); + + // Test bit shift operations + m2x2 shift_mat = m2x2(1, 2, 4, 8); + m2x2 left_shift = shift_mat << 1; + m2x2 right_shift = shift_mat >> 1; + + // Test comparison operations (these return bool matrices, cast to TYPE for output) + m2x2 comp_mat1 = m2x2(1, 3, 2, 4); + m2x2 comp_mat2 = m2x2(2, 2, 3, 3); + + matrix<bool, 2, 2> less_than = comp_mat1 < comp_mat2; + matrix<bool, 2, 2> greater_than = comp_mat1 > comp_mat2; + matrix<bool, 2, 2> less_equal = comp_mat1 <= comp_mat2; + matrix<bool, 2, 2> greater_equal = comp_mat1 >= comp_mat2; + matrix<bool, 2, 2> equal_to = comp_mat1 == comp_mat2; + matrix<bool, 2, 2> not_equal = comp_mat1 != comp_mat2; + + // Store results + outputBuffer[0] = val1; + // CHECK: 1 + outputBuffer[1] = val2; + // CHECK-NEXT: 8 + outputBuffer[2] = row.x; + // CHECK-NEXT: 3 + outputBuffer[3] = row.y; + // CHECK-NEXT: 4 + outputBuffer[4] = row3.y; + // CHECK-NEXT: 2 + outputBuffer[5] = mat_scalar[0][0]; + // CHECK-NEXT: 2 + outputBuffer[6] = mat_add[0][0]; + // CHECK-NEXT: 3 + outputBuffer[7] = mat_sub[0][0]; + // CHECK-NEXT: 1 + outputBuffer[8] = mat_mul[1][1]; + // CHECK-NEXT: 28 + outputBuffer[9] = added; + // CHECK-NEXT: 10 + outputBuffer[10] = wrapper.mat1[0][0] * wrapper.mat2[0][0]; + // CHECK-NEXT: 5 + + // Matrix intrinsic operation results + outputBuffer[11] = det2x2; + // CHECK-NEXT: 2 + outputBuffer[12] = det3x3; + // CHECK-NEXT: 0 + outputBuffer[13] = mat_min[0][0]; + // CHECK-NEXT: 1 + outputBuffer[14] = mat_min[1][1]; + // CHECK-NEXT: 4 + outputBuffer[15] = mat_max[0][0]; + // CHECK-NEXT: 2 + outputBuffer[16] = mat_max[1][1]; + // CHECK-NEXT: 7 + outputBuffer[17] = all_nonzero; + // CHECK-NEXT: 1 + outputBuffer[18] = all_zero; + // CHECK-NEXT: 0 + outputBuffer[19] = any_nonzero; + // CHECK-NEXT: 1 + outputBuffer[20] = any_zero; + // CHECK-NEXT: 0 + outputBuffer[21] = trans2x2[0][0]; + // CHECK-NEXT: 1 + outputBuffer[22] = trans2x2[1][0]; + // CHECK-NEXT: 2 + outputBuffer[23] = trans2x3[0][0]; + // CHECK-NEXT: 5 + + // Bit shift operation results + outputBuffer[24] = left_shift[0][0]; + // CHECK-NEXT: 2 + outputBuffer[25] = left_shift[0][1]; + // CHECK-NEXT: 4 + outputBuffer[26] = right_shift[1][0]; + // CHECK-NEXT: 2 + outputBuffer[27] = right_shift[1][1]; + // CHECK-NEXT: 4 + + // Comparison operation results (bool matrices cast to TYPE) + outputBuffer[28] = TYPE(less_than[0][0]); + // CHECK-NEXT: 1 + outputBuffer[29] = TYPE(less_than[0][1]); + // CHECK-NEXT: 0 + outputBuffer[30] = TYPE(greater_than[0][1]); + // CHECK-NEXT: 1 + outputBuffer[31] = TYPE(greater_than[1][1]); + // CHECK-NEXT: 1 + outputBuffer[32] = TYPE(less_equal[0][0]); + // CHECK-NEXT: 1 + outputBuffer[33] = TYPE(less_equal[0][1]); + // CHECK-NEXT: 0 + outputBuffer[34] = TYPE(greater_equal[0][1]); + // CHECK-NEXT: 1 + outputBuffer[35] = TYPE(greater_equal[1][0]); + // CHECK-NEXT: 0 + outputBuffer[36] = TYPE(equal_to[0][0]); + // CHECK-NEXT: 0 + outputBuffer[37] = TYPE(not_equal[0][0]); + // CHECK-NEXT: 1 +}
\ No newline at end of file |
