summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-metal.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-05-10 09:41:31 -0700
committerGitHub <noreply@github.com>2024-05-10 09:41:31 -0700
commit1dcd814f5038229703e52841b1b0304c22bffb73 (patch)
tree817b95d66bb9ad665375d9b1fa09b5829ca4f38f /source/slang/slang-emit-metal.cpp
parent926009a58315845b3a3a95e2724486a6c9e987ea (diff)
More Metal Intrinsics. (#4143)
Diffstat (limited to 'source/slang/slang-emit-metal.cpp')
-rw-r--r--source/slang/slang-emit-metal.cpp118
1 files changed, 114 insertions, 4 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index 4d8a207c3..1eb4b9abe 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -11,6 +11,40 @@
namespace Slang {
+static const char* kMetalBuiltinPreludeMatrixCompMult = R"(
+template<typename T, int A, int B>
+matrix<T,A,B> _slang_matrixCompMult(matrix<T,A,B> m1, matrix<T,A,B> m2)
+{
+ matrix<T,A,B> result;
+ for (int i = 0; i < A; i++)
+ result[i] = m1[i] * m2[i];
+ return result;
+}
+)";
+
+static const char* kMetalBuiltinPreludeMatrixReshape = R"(
+template<int A, int B, typename T, int N, int M>
+matrix<T,A,B> _slang_matrixReshape(matrix<T,N,M> m)
+{
+ matrix<T,A,B> result = T(0);
+ for (int i = 0; i < min(A,N); i++)
+ for (int j = 0; j < min(B,M); j++)
+ result[i] = m[i][j];
+ return result;
+}
+)";
+
+static const char* kMetalBuiltinPreludeVectorReshape = R"(
+template<int A, typename T, int N>
+vec<T,A> _slang_vectorReshape(vec<T,N> v)
+{
+ vec<T,A> result = T(0);
+ for (int i = 0; i < min(A,N); i++)
+ result[i] = v[i];
+ return result;
+}
+)";
+
void MetalSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val)
{
SLANG_UNUSED(entryPoint);
@@ -163,7 +197,7 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi
switch (stage)
{
- case Stage::Pixel:
+ case Stage::Pixel:
{
if (irFunc->findDecoration<IREarlyDepthStencilDecoration>())
{
@@ -176,12 +210,36 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi
}
}
+void MetalSourceEmitter::ensurePrelude(const char* preludeText)
+{
+ IRStringLit* stringLit;
+ if (!m_builtinPreludes.tryGetValue(preludeText, stringLit))
+ {
+ IRBuilder builder(m_irModule);
+ stringLit = builder.getStringValue(UnownedStringSlice(preludeText));
+ m_builtinPreludes[preludeText] = stringLit;
+ }
+ m_requiredPreludes.add(stringLit);
+}
+
+bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_discard:
+ m_writer->emit("discard_fragment();\n");
+ return true;
+ }
+ return false;
+}
+
bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
{
case kIROp_MakeVector:
case kIROp_MakeMatrix:
+ case kIROp_MakeVectorFromScalar:
{
if (inst->getOperandCount() == 1)
{
@@ -190,19 +248,71 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO
auto prec = getInfo(EmitOp::Prefix);
needClose = maybeEmitParens(outerPrec, prec);
-
- // Need to emit as cast for HLSL
emitType(inst->getDataType());
m_writer->emit("(");
emitOperand(inst->getOperand(0), rightSide(outerPrec, prec));
m_writer->emit(") ");
maybeCloseParens(needClose);
- // Handled
return true;
}
break;
}
+ case kIROp_MatrixReshape:
+ {
+ ensurePrelude(kMetalBuiltinPreludeMatrixReshape);
+ m_writer->emit("_slang_matrixReshape<");
+ auto matrixType = as<IRMatrixType>(inst->getDataType());
+ emitOperand(matrixType->getRowCount(), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(matrixType->getColumnCount(), getInfo(EmitOp::General));
+ m_writer->emit(">(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_VectorReshape:
+ {
+ ensurePrelude(kMetalBuiltinPreludeVectorReshape);
+ m_writer->emit("_slang_vectorReshape<");
+ auto vectorType = as<IRVectorType>(inst->getDataType());
+ emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ m_writer->emit(">(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_Mul:
+ {
+ // Component-wise multiplication needs to be special cased,
+ // because Metal uses infix `*` to express inner product
+ // when working with matrices.
+
+ // Are both operands matrices?
+ if (as<IRMatrixType>(inst->getOperand(0)->getDataType())
+ && as<IRMatrixType>(inst->getOperand(1)->getDataType()))
+ {
+ ensurePrelude(kMetalBuiltinPreludeMatrixCompMult);
+ m_writer->emit("_slang_matrixCompMult(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ break;
+ }
+ case kIROp_Select:
+ {
+ m_writer->emit("select(");
+ emitOperand(inst->getOperand(2), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General)));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General)));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General)));
+ m_writer->emit(")");
+ return true;
+ }
case kIROp_BitCast:
{
auto toType = inst->getDataType();