diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-10 09:41:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-10 09:41:31 -0700 |
| commit | 1dcd814f5038229703e52841b1b0304c22bffb73 (patch) | |
| tree | 817b95d66bb9ad665375d9b1fa09b5829ca4f38f /source/slang/slang-emit-metal.cpp | |
| parent | 926009a58315845b3a3a95e2724486a6c9e987ea (diff) | |
More Metal Intrinsics. (#4143)
Diffstat (limited to 'source/slang/slang-emit-metal.cpp')
| -rw-r--r-- | source/slang/slang-emit-metal.cpp | 118 |
1 files changed, 114 insertions, 4 deletions
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index 4d8a207c3..1eb4b9abe 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -11,6 +11,40 @@ namespace Slang { +static const char* kMetalBuiltinPreludeMatrixCompMult = R"( +template<typename T, int A, int B> +matrix<T,A,B> _slang_matrixCompMult(matrix<T,A,B> m1, matrix<T,A,B> m2) +{ + matrix<T,A,B> result; + for (int i = 0; i < A; i++) + result[i] = m1[i] * m2[i]; + return result; +} +)"; + +static const char* kMetalBuiltinPreludeMatrixReshape = R"( +template<int A, int B, typename T, int N, int M> +matrix<T,A,B> _slang_matrixReshape(matrix<T,N,M> m) +{ + matrix<T,A,B> result = T(0); + for (int i = 0; i < min(A,N); i++) + for (int j = 0; j < min(B,M); j++) + result[i] = m[i][j]; + return result; +} +)"; + +static const char* kMetalBuiltinPreludeVectorReshape = R"( +template<int A, typename T, int N> +vec<T,A> _slang_vectorReshape(vec<T,N> v) +{ + vec<T,A> result = T(0); + for (int i = 0; i < min(A,N); i++) + result[i] = v[i]; + return result; +} +)"; + void MetalSourceEmitter::_emitHLSLDecorationSingleString(const char* name, IRFunc* entryPoint, IRStringLit* val) { SLANG_UNUSED(entryPoint); @@ -163,7 +197,7 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi switch (stage) { - case Stage::Pixel: + case Stage::Pixel: { if (irFunc->findDecoration<IREarlyDepthStencilDecoration>()) { @@ -176,12 +210,36 @@ void MetalSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoi } } +void MetalSourceEmitter::ensurePrelude(const char* preludeText) +{ + IRStringLit* stringLit; + if (!m_builtinPreludes.tryGetValue(preludeText, stringLit)) + { + IRBuilder builder(m_irModule); + stringLit = builder.getStringValue(UnownedStringSlice(preludeText)); + m_builtinPreludes[preludeText] = stringLit; + } + m_requiredPreludes.add(stringLit); +} + +bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_discard: + m_writer->emit("discard_fragment();\n"); + return true; + } + return false; +} + bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch (inst->getOp()) { case kIROp_MakeVector: case kIROp_MakeMatrix: + case kIROp_MakeVectorFromScalar: { if (inst->getOperandCount() == 1) { @@ -190,19 +248,71 @@ bool MetalSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inO auto prec = getInfo(EmitOp::Prefix); needClose = maybeEmitParens(outerPrec, prec); - - // Need to emit as cast for HLSL emitType(inst->getDataType()); m_writer->emit("("); emitOperand(inst->getOperand(0), rightSide(outerPrec, prec)); m_writer->emit(") "); maybeCloseParens(needClose); - // Handled return true; } break; } + case kIROp_MatrixReshape: + { + ensurePrelude(kMetalBuiltinPreludeMatrixReshape); + m_writer->emit("_slang_matrixReshape<"); + auto matrixType = as<IRMatrixType>(inst->getDataType()); + emitOperand(matrixType->getRowCount(), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(matrixType->getColumnCount(), getInfo(EmitOp::General)); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_VectorReshape: + { + ensurePrelude(kMetalBuiltinPreludeVectorReshape); + m_writer->emit("_slang_vectorReshape<"); + auto vectorType = as<IRVectorType>(inst->getDataType()); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + case kIROp_Mul: + { + // Component-wise multiplication needs to be special cased, + // because Metal uses infix `*` to express inner product + // when working with matrices. + + // Are both operands matrices? + if (as<IRMatrixType>(inst->getOperand(0)->getDataType()) + && as<IRMatrixType>(inst->getOperand(1)->getDataType())) + { + ensurePrelude(kMetalBuiltinPreludeMatrixCompMult); + m_writer->emit("_slang_matrixCompMult("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + break; + } + case kIROp_Select: + { + m_writer->emit("select("); + emitOperand(inst->getOperand(2), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General))); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General))); + m_writer->emit(", "); + emitOperand(inst->getOperand(0), leftSide(getInfo(EmitOp::General), getInfo(EmitOp::General))); + m_writer->emit(")"); + return true; + } case kIROp_BitCast: { auto toType = inst->getDataType(); |
