summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2025-07-31 15:12:21 -0700
committerGitHub <noreply@github.com>2025-07-31 22:12:21 +0000
commit30fd3c63fb4af9ea8d482c75921710df1b40e59e (patch)
treecd1001e90f5328f20fa7bc6d030bcfcc4e01979f /source
parentaefd1e3e0dbe4e77f8d7dbbfa04e15c2db615394 (diff)
Add matrix select intrinsic (#7566)
* Add matrix select intrinsic * Fix hlsl test * Restrict matrix select to HLSL * Better test for HLSL side * Select route for GLSL/SPIRV * Exclude matrices from select legalization * Exclude CUDA from select test * Inline and move * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang28
-rw-r--r--source/slang/slang-ir-legalize-composite-select.cpp23
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp55
-rw-r--r--source/slang/slang-ir.cpp14
4 files changed, 102 insertions, 18 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f42cd25cf..2c6bf04d1 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -958,6 +958,31 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b
__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse);
__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+[require(hlsl)]
+__generic<T, let N : int, let M : int> __intrinsic_op(select) matrix<T,N,M> __hlsl_select(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse);
+
+__generic<T, let N : int, let M : int>
+matrix<T,N,M> select(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse)
+{
+ __target_switch
+ {
+ case hlsl:
+ return __hlsl_select(condition, ifTrue, ifFalse);
+ default:
+ matrix<T,N,M> result;
+ [[unroll]]
+ for (uint32_t i = 0; i < N; i++)
+ result[i] = select(condition[i], ifTrue[i], ifFalse[i]);
+ return result;
+ }
+}
+
+__generic<T, let N : int, let M : int>
+matrix<T,N,M> operator?:(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse)
+{
+ return select(condition, ifTrue, ifFalse);
+}
+
[ForceInline]
__generic<T> Optional<T> select(bool condition, __none_t ifTrue, T ifFalse)
{
@@ -973,7 +998,8 @@ __generic<T> Optional<T> select(bool condition, T ifTrue, __none_t ifFalse)
// Allow real-number types to be cast into each other
//@hidden:
__intrinsic_op($(kIROp_FloatCast))
- T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
+T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
+
//@hidden:
__intrinsic_op($(kIROp_CastIntToFloat))
T __realCast<T : __BuiltinRealType, U : __BuiltinIntegerType>(U val);
diff --git a/source/slang/slang-ir-legalize-composite-select.cpp b/source/slang/slang-ir-legalize-composite-select.cpp
index 1b2ba0670..cdd37efa3 100644
--- a/source/slang/slang-ir-legalize-composite-select.cpp
+++ b/source/slang/slang-ir-legalize-composite-select.cpp
@@ -9,7 +9,8 @@
namespace Slang
{
-void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selectInst)
+
+void legalizeCompositeSelect(IRBuilder& builder, IRSelect* selectInst)
{
SLANG_ASSERT(selectInst);
@@ -49,6 +50,7 @@ void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selec
// Clean up
selectInst->removeAndDeallocate();
}
+
void legalizeNonVectorCompositeSelect(IRModule* module)
{
IRBuilder builder(module);
@@ -57,23 +59,24 @@ void legalizeNonVectorCompositeSelect(IRModule* module)
auto func = as<IRFunc>(globalInst);
if (!func)
continue;
+
for (auto block : func->getBlocks())
{
- auto inst = block->getFirstInst();
- IRInst* next;
- for (; inst; inst = next)
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
{
- next = inst->getNextInst();
- switch (inst->getOp())
+ if (auto select = as<IRSelect>(inst))
{
- case kIROp_Select:
// Replace OpSelect with if/else branch (same process as glslang)
- if (!isScalarOrVectorType(inst->getFullType()))
- legalizeASingleNonVectorCompositeSelect(builder, as<IRSelect>(inst));
- continue;
+ bool requiresLegalization = !as<IRBasicType>(select->getFullType()) &&
+ !as<IRVectorType>(select->getFullType()) &&
+ !as<IRMatrixType>(select->getFullType());
+
+ if (requiresLegalization)
+ legalizeCompositeSelect(builder, select);
}
}
}
}
}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
index 8c8cb0c84..327fa7ead 100644
--- a/source/slang/slang-ir-legalize-matrix-types.cpp
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -154,6 +154,59 @@ struct MatrixTypeLoweringContext
return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer());
}
+ IRInst* legalizeMakeMatrixFromScalar(IRInst* inst)
+ {
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+
+ SLANG_ASSERT(matrixType && "Matrix type is expected");
+ SLANG_ASSERT(
+ shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization");
+
+ // Lower makeMatrixFromScalar to makeArray of makeVectors from scalar
+ auto elementType = matrixType->getElementType();
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(matrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
+
+ SLANG_ASSERT(
+ inst->getOperandCount() == 1 && "makeMatrixFromScalar should have exactly one operand");
+
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Get the scalar operand
+ auto scalarOperand = getReplacement(inst->getOperand(0));
+
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ // Create a vector from the scalar (replicated C times)
+ List<IRInst*> vectorElements;
+ for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ {
+ vectorElements.add(scalarOperand);
+ }
+ auto rowVector = builder.emitMakeVector(vectorType, vectorElements);
+
+ // Create array with R copies of the same vector
+ List<IRInst*> rowVectors;
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ rowVectors.add(rowVector);
+ }
+
+ SLANG_ASSERT(
+ rowVectors.getCount() == rowCount->getValue() &&
+ "Row vectors count must match matrix row count");
+ return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer());
+ }
+
IRInst* legalizeMatrixMatrixBinaryOperation(
IRBuilder& builder,
IRInst* legalizedA,
@@ -450,6 +503,8 @@ struct MatrixTypeLoweringContext
{
case kIROp_MakeMatrix:
return legalizeMakeMatrix(inst);
+ case kIROp_MakeMatrixFromScalar:
+ return legalizeMakeMatrixFromScalar(inst);
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 63d3766ab..b48dcc7e6 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5995,14 +5995,14 @@ IRInst* IRBuilder::emitIfElseWithBlocks(
outTrueBlock = createBlock();
outAfterBlock = createBlock();
outFalseBlock = createBlock();
+
auto f = getFunc();
- SLANG_ASSERT(f);
- if (f)
- {
- f->addBlock(outTrueBlock);
- f->addBlock(outAfterBlock);
- f->addBlock(outFalseBlock);
- }
+
+ SLANG_ASSERT(f && "Expected function");
+ f->addBlock(outTrueBlock);
+ f->addBlock(outAfterBlock);
+ f->addBlock(outFalseBlock);
+
auto result = emitIfElse(val, outTrueBlock, outFalseBlock, outAfterBlock);
setInsertInto(outTrueBlock);
return result;