summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2025-07-31 15:12:21 -0700
committerGitHub <noreply@github.com>2025-07-31 22:12:21 +0000
commit30fd3c63fb4af9ea8d482c75921710df1b40e59e (patch)
treecd1001e90f5328f20fa7bc6d030bcfcc4e01979f
parentaefd1e3e0dbe4e77f8d7dbbfa04e15c2db615394 (diff)
Add matrix select intrinsic (#7566)
* Add matrix select intrinsic * Fix hlsl test * Restrict matrix select to HLSL * Better test for HLSL side * Select route for GLSL/SPIRV * Exclude matrices from select legalization * Exclude CUDA from select test * Inline and move * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
-rw-r--r--source/slang/core.meta.slang28
-rw-r--r--source/slang/slang-ir-legalize-composite-select.cpp23
-rw-r--r--source/slang/slang-ir-legalize-matrix-types.cpp55
-rw-r--r--source/slang/slang-ir.cpp14
-rw-r--r--tests/hlsl-intrinsic/matrix-cast-to-vector.slang9
-rw-r--r--tests/language-feature/matrix-select.slang45
6 files changed, 151 insertions, 23 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f42cd25cf..2c6bf04d1 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -958,6 +958,31 @@ __generic<T, let N : int> __intrinsic_op(select) vector<T,N> operator?:(vector<b
__generic<T> __intrinsic_op(select) T select(bool condition, T ifTrue, T ifFalse);
__generic<T, let N : int> __intrinsic_op(select) vector<T,N> select(vector<bool,N> condition, vector<T,N> ifTrue, vector<T,N> ifFalse);
+[require(hlsl)]
+__generic<T, let N : int, let M : int> __intrinsic_op(select) matrix<T,N,M> __hlsl_select(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse);
+
+__generic<T, let N : int, let M : int>
+matrix<T,N,M> select(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse)
+{
+ __target_switch
+ {
+ case hlsl:
+ return __hlsl_select(condition, ifTrue, ifFalse);
+ default:
+ matrix<T,N,M> result;
+ [[unroll]]
+ for (uint32_t i = 0; i < N; i++)
+ result[i] = select(condition[i], ifTrue[i], ifFalse[i]);
+ return result;
+ }
+}
+
+__generic<T, let N : int, let M : int>
+matrix<T,N,M> operator?:(matrix<bool,N,M> condition, matrix<T,N,M> ifTrue, matrix<T,N,M> ifFalse)
+{
+ return select(condition, ifTrue, ifFalse);
+}
+
[ForceInline]
__generic<T> Optional<T> select(bool condition, __none_t ifTrue, T ifFalse)
{
@@ -973,7 +998,8 @@ __generic<T> Optional<T> select(bool condition, T ifTrue, __none_t ifFalse)
// Allow real-number types to be cast into each other
//@hidden:
__intrinsic_op($(kIROp_FloatCast))
- T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
+T __realCast<T : __BuiltinRealType, U : __BuiltinRealType>(U val);
+
//@hidden:
__intrinsic_op($(kIROp_CastIntToFloat))
T __realCast<T : __BuiltinRealType, U : __BuiltinIntegerType>(U val);
diff --git a/source/slang/slang-ir-legalize-composite-select.cpp b/source/slang/slang-ir-legalize-composite-select.cpp
index 1b2ba0670..cdd37efa3 100644
--- a/source/slang/slang-ir-legalize-composite-select.cpp
+++ b/source/slang/slang-ir-legalize-composite-select.cpp
@@ -9,7 +9,8 @@
namespace Slang
{
-void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selectInst)
+
+void legalizeCompositeSelect(IRBuilder& builder, IRSelect* selectInst)
{
SLANG_ASSERT(selectInst);
@@ -49,6 +50,7 @@ void legalizeASingleNonVectorCompositeSelect(IRBuilder& builder, IRSelect* selec
// Clean up
selectInst->removeAndDeallocate();
}
+
void legalizeNonVectorCompositeSelect(IRModule* module)
{
IRBuilder builder(module);
@@ -57,23 +59,24 @@ void legalizeNonVectorCompositeSelect(IRModule* module)
auto func = as<IRFunc>(globalInst);
if (!func)
continue;
+
for (auto block : func->getBlocks())
{
- auto inst = block->getFirstInst();
- IRInst* next;
- for (; inst; inst = next)
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
{
- next = inst->getNextInst();
- switch (inst->getOp())
+ if (auto select = as<IRSelect>(inst))
{
- case kIROp_Select:
// Replace OpSelect with if/else branch (same process as glslang)
- if (!isScalarOrVectorType(inst->getFullType()))
- legalizeASingleNonVectorCompositeSelect(builder, as<IRSelect>(inst));
- continue;
+ bool requiresLegalization = !as<IRBasicType>(select->getFullType()) &&
+ !as<IRVectorType>(select->getFullType()) &&
+ !as<IRMatrixType>(select->getFullType());
+
+ if (requiresLegalization)
+ legalizeCompositeSelect(builder, select);
}
}
}
}
}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-legalize-matrix-types.cpp b/source/slang/slang-ir-legalize-matrix-types.cpp
index 8c8cb0c84..327fa7ead 100644
--- a/source/slang/slang-ir-legalize-matrix-types.cpp
+++ b/source/slang/slang-ir-legalize-matrix-types.cpp
@@ -154,6 +154,59 @@ struct MatrixTypeLoweringContext
return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer());
}
+ IRInst* legalizeMakeMatrixFromScalar(IRInst* inst)
+ {
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+
+ SLANG_ASSERT(matrixType && "Matrix type is expected");
+ SLANG_ASSERT(
+ shouldLowerMatrixType(matrixType) && "Matrix type is expected to need legalization");
+
+ // Lower makeMatrixFromScalar to makeArray of makeVectors from scalar
+ auto elementType = matrixType->getElementType();
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount());
+ auto columnCount = as<IRIntLit>(matrixType->getColumnCount());
+
+ SLANG_ASSERT(
+ rowCount && columnCount &&
+ "Matrix dimensions must be compile-time constants for lowering");
+
+ SLANG_ASSERT(
+ inst->getOperandCount() == 1 && "makeMatrixFromScalar should have exactly one operand");
+
+ IRBuilder builder(inst);
+ builder.setInsertBefore(inst);
+
+ // Get the scalar operand
+ auto scalarOperand = getReplacement(inst->getOperand(0));
+
+ // Create vector type for rows: vector<T, C>
+ auto vectorType = builder.getVectorType(elementType, columnCount);
+
+ // Create array type: vector<T, C>[R]
+ auto arrayType = builder.getArrayType(vectorType, rowCount);
+
+ // Create a vector from the scalar (replicated C times)
+ List<IRInst*> vectorElements;
+ for (IRIntegerValue col = 0; col < columnCount->getValue(); col++)
+ {
+ vectorElements.add(scalarOperand);
+ }
+ auto rowVector = builder.emitMakeVector(vectorType, vectorElements);
+
+ // Create array with R copies of the same vector
+ List<IRInst*> rowVectors;
+ for (IRIntegerValue row = 0; row < rowCount->getValue(); row++)
+ {
+ rowVectors.add(rowVector);
+ }
+
+ SLANG_ASSERT(
+ rowVectors.getCount() == rowCount->getValue() &&
+ "Row vectors count must match matrix row count");
+ return builder.emitMakeArray(arrayType, rowVectors.getCount(), rowVectors.getBuffer());
+ }
+
IRInst* legalizeMatrixMatrixBinaryOperation(
IRBuilder& builder,
IRInst* legalizedA,
@@ -450,6 +503,8 @@ struct MatrixTypeLoweringContext
{
case kIROp_MakeMatrix:
return legalizeMakeMatrix(inst);
+ case kIROp_MakeMatrixFromScalar:
+ return legalizeMakeMatrixFromScalar(inst);
case kIROp_Add:
case kIROp_Sub:
case kIROp_Mul:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 63d3766ab..b48dcc7e6 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5995,14 +5995,14 @@ IRInst* IRBuilder::emitIfElseWithBlocks(
outTrueBlock = createBlock();
outAfterBlock = createBlock();
outFalseBlock = createBlock();
+
auto f = getFunc();
- SLANG_ASSERT(f);
- if (f)
- {
- f->addBlock(outTrueBlock);
- f->addBlock(outAfterBlock);
- f->addBlock(outFalseBlock);
- }
+
+ SLANG_ASSERT(f && "Expected function");
+ f->addBlock(outTrueBlock);
+ f->addBlock(outAfterBlock);
+ f->addBlock(outFalseBlock);
+
auto result = emitIfElse(val, outTrueBlock, outFalseBlock, outAfterBlock);
setInsertInto(outTrueBlock);
return result;
diff --git a/tests/hlsl-intrinsic/matrix-cast-to-vector.slang b/tests/hlsl-intrinsic/matrix-cast-to-vector.slang
index 522f3ce11..74df140d3 100644
--- a/tests/hlsl-intrinsic/matrix-cast-to-vector.slang
+++ b/tests/hlsl-intrinsic/matrix-cast-to-vector.slang
@@ -17,10 +17,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
float2x2 matrix2x2_2 = (float2x2)vector4_2;
outputBuffer[0] = uint(true
- && all(vector4_1 == float4(1, 2, 3, 4))
-
- && all(matrix2x2_2[0] == float2(1,2))
- && all(matrix2x2_2[1] == float2(3,4))
- );
+ && all(vector4_1 == float4(1, 2, 3, 4))
+ && all(matrix2x2_2[0] == float2(1,2))
+ && all(matrix2x2_2[1] == float2(3,4))
+ );
//BUF: 1
}
diff --git a/tests/language-feature/matrix-select.slang b/tests/language-feature/matrix-select.slang
new file mode 100644
index 000000000..a3cab6906
--- /dev/null
+++ b/tests/language-feature/matrix-select.slang
@@ -0,0 +1,45 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-dx12 -use-dxil -compute -shaderobj -output-using-type -xslang -matrix-layout-column-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-dx12 -use-dxil -compute -shaderobj -output-using-type -xslang -matrix-layout-row-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -matrix-layout-column-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-vk -compute -shaderobj -output-using-type -xslang -matrix-layout-row-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-mtl -compute -output-using-type -xslang -matrix-layout-column-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-mtl -compute -output-using-type -xslang -matrix-layout-row-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-wgpu -compute -output-using-type -xslang -matrix-layout-column-major
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHK):-wgpu -compute -output-using-type -xslang -matrix-layout-row-major
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+int selectDims<int N, int M>(bool cond)
+{
+ return select(
+ matrix<bool, N, M>(cond),
+ matrix<int, N, M>(1),
+ matrix<int, N, M>(0)
+ )[0][0];
+}
+
+int selectDimsDigit<int N, int M, int D>(int x)
+{
+ return selectDims<N, M>(((x >> D) & 0b1) == 0b1) << D;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int x = 324;
+
+ int s = 0;
+ s += selectDimsDigit<2, 2, 0>(x);
+ s += selectDimsDigit<2, 3, 1>(x);
+ s += selectDimsDigit<2, 4, 2>(x);
+ s += selectDimsDigit<3, 2, 3>(x);
+ s += selectDimsDigit<3, 3, 4>(x);
+ s += selectDimsDigit<3, 4, 5>(x);
+ s += selectDimsDigit<4, 2, 6>(x);
+ s += selectDimsDigit<4, 3, 7>(x);
+ s += selectDimsDigit<4, 4, 8>(x);
+
+ // CHK: 324
+ outputBuffer[0] = s;
+} \ No newline at end of file