summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-cpp.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-03 15:54:16 -0700
committerGitHub <noreply@github.com>2023-04-03 15:54:16 -0700
commitb68516e2c2e39af79dda2ec7871fe4d821ef67c4 (patch)
treeec61ca320368f8128cd531a9272e8e49d5353247 /source/slang/slang-emit-cpp.cpp
parent7a346b2982c69ef97ebc4b308c77a1f1c88c548f (diff)
Emit simpler vector element access code. (#2770)
* Emit simpler vector element access code * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-emit-cpp.cpp')
-rw-r--r--source/slang/slang-emit-cpp.cpp56
1 files changed, 41 insertions, 15 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index a178dfe67..d6764ae06 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1270,13 +1270,26 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
IRInst* baseInst = getElementInst->getBase();
IRType* baseType = baseInst->getDataType();
- if (as<IRVectorType>(baseType))
+ if (auto vectorBaseType = as<IRVectorType>(baseType))
{
- m_writer->emit("_slang_vector_get_element(");
- emitOperand(baseInst, getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit(")");
+ if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex()))
+ {
+ // For static index, we can emit simpler code using the `.x`, `.y` members.
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit(".");
+ m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]);
+ }
+ else
+ {
+ // For dynamic index, we emit using `_slang_vector_get_element` intrinsics.
+ m_writer->emit("_slang_vector_get_element(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
return true;
}
else if (as<IRMatrixType>(baseType))
@@ -1297,24 +1310,37 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
IRInst* baseInst = getElementInst->getBase();
IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType();
- if (as<IRVectorType>(baseType))
+ if (auto vectorBaseType = as<IRVectorType>(baseType))
{
- m_writer->emit("_slang_vector_get_element_ptr(");
- emitOperand(baseInst, getInfo(EmitOp::General));
- m_writer->emit(", ");
- emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit(")");
+ if (auto intLitIndex = as<IRIntLit>(getElementInst->getIndex()))
+ {
+ // For static index, we can emit simpler code using the `.x`, `.y` members.
+ m_writer->emit("&(");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit("->");
+ m_writer->emit(getVectorElementNames(vectorBaseType)[intLitIndex->getValue()]);
+ m_writer->emit(")");
+ }
+ else
+ {
+ m_writer->emit("_slang_vector_get_element_ptr(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
return true;
}
else if (as<IRMatrixType>(baseType))
{
- m_writer->emit("&(");
+ m_writer->emit("(");
auto outerPrec = getInfo(EmitOp::General);
auto prec = getInfo(EmitOp::Postfix);
emitOperand(baseInst, leftSide(outerPrec, prec));
- m_writer->emit("->rows[");
+ m_writer->emit("->rows + ");
emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
- m_writer->emit("]");
m_writer->emit(")");
return true;
}