summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-simplify-for-emit.cpp
diff options
context:
space:
mode:
author: Yong He <yonghe@outlook.com> 2023-02-16 13:55:32 -0800
committer: GitHub <noreply@github.com> 2023-02-16 13:55:32 -0800
commit: 4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch)
tree: ed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source/slang/slang-ir-simplify-for-emit.cpp
parent: eda88e513e8b1e2abc05e9dc8555f237d96472df (diff)
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend. * Update IR documentation. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-simplify-for-emit.cpp')
-rw-r--r-- source/slang/slang-ir-simplify-for-emit.cpp 121
1 files changed, 116 insertions, 5 deletions
diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp
index 5e5f61a4a..67d95c59f 100644
--- a/source/slang/slang-ir-simplify-for-emit.cpp
+++ b/source/slang/slang-ir-simplify-for-emit.cpp
@@ -5,12 +5,16 @@
namespace Slang
{
+bool isCPUTarget(TargetRequest* targetReq);
+bool isCUDATarget(TargetRequest* targetReq);
+
struct SimplifyForEmitContext : public InstPassBase
{
- SimplifyForEmitContext(IRModule* inModule)
- : InstPassBase(inModule)
+ SimplifyForEmitContext(IRModule* inModule, TargetRequest* inTargetReq)
+ : InstPassBase(inModule), targetReq(inTargetReq)
{}
+ TargetRequest* targetReq;
List<IRInst*> followUpWorkList;
HashSet<IRInst*> followUpWorkListSet;
@@ -134,7 +138,7 @@ struct SimplifyForEmitContext : public InstPassBase
IRBuilder builder(sharedBuilderStorage);
builder.setInsertBefore(user);
auto newLoad = builder.emitLoad(load->getPtr());
- use->set(newLoad);
+ builder.replaceOperand(use, newLoad);
}
void processLoad(IRLoad* inst)
@@ -330,8 +334,115 @@ struct SimplifyForEmitContext : public InstPassBase
processInst(followUpWorkList[i]);
}
+    // Make both operands of vector/matrix-valued binary expressions agree in shape.
+    //
+    // Walks every instruction in `func`. For each two-operand arithmetic, logical,
+    // bitwise, comparison or shift instruction whose result type is a vector (or,
+    // failing that, a matrix), any operand whose type is an `IRBasicType` is replaced
+    // by a "make vector/matrix from scalar" instruction. The new instruction is typed
+    // after the *other* operand and inserted immediately before the binary instruction.
+    void unifyBinaryExprOperands(IRGlobalValueWithCode* func)
+    {
+        IRBuilder builder(func->getModule());
+
+        for (auto block : func->getBlocks())
+        {
+            for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+            {
+                // Filter: only the binary opcodes listed here are rewritten.
+                switch (inst->getOp())
+                {
+                case kIROp_Add:
+                case kIROp_Sub:
+                case kIROp_Mul:
+                case kIROp_Div:
+                case kIROp_IRem:
+                case kIROp_FRem:
+                case kIROp_And:
+                case kIROp_Or:
+                case kIROp_BitAnd:
+                case kIROp_BitOr:
+                case kIROp_BitXor:
+                case kIROp_Leq:
+                case kIROp_Less:
+                case kIROp_Geq:
+                case kIROp_Greater:
+                case kIROp_Eql:
+                case kIROp_Neq:
+                case kIROp_Lsh:
+                case kIROp_Rsh:
+                    break;
+                default:
+                    continue;
+                }
+
+                builder.setInsertBefore(inst);
+                SLANG_ASSERT(inst->getOperandCount() == 2);
+
+                // Vector results take precedence; matrix is only considered otherwise.
+                bool resultIsVector = false;
+                if (as<IRVectorType>(inst->getDataType()))
+                {
+                    resultIsVector = true;
+                }
+                else if (!as<IRMatrixType>(inst->getDataType()))
+                {
+                    // Neither a vector nor a matrix result: nothing to unify.
+                    continue;
+                }
+
+                for (UInt operandIndex = 0; operandIndex < 2; operandIndex++)
+                {
+                    if (!as<IRBasicType>(inst->getOperand(operandIndex)->getDataType()))
+                        continue;
+
+                    // The widened value takes the type of the opposite operand.
+                    auto siblingType = inst->getOperand(1 - operandIndex)->getDataType();
+                    auto scalarOperand = inst->getOperand(operandIndex);
+                    if (resultIsVector)
+                    {
+                        auto widened = builder.emitMakeVectorFromScalar(siblingType, scalarOperand);
+                        inst->setOperand(operandIndex, widened);
+                    }
+                    else
+                    {
+                        auto widened = builder.emitMakeMatrixFromScalar(siblingType, scalarOperand);
+                        inst->setOperand(operandIndex, widened);
+                    }
+                }
+            }
+        }
+    }
+
+    // Turn single-element vector values into scalars before they are used to call an intrinsic func.
+    //
+    // Walks every instruction in `func`. For each `kIROp_Call` whose callee (as resolved by
+    // `getResolvedInstForDecorations`) carries an `IRTargetIntrinsicDecoration`, every argument
+    // of type `vector<T,1>` is replaced with a one-component swizzle (component 0) of type `T`,
+    // emitted immediately before the call.
+    void lowerTrivialVector(IRGlobalValueWithCode* func)
+    {
+        IRBuilder builder(func->getModule());
+        for (auto block : func->getBlocks())
+        {
+            for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+            {
+                if (inst->getOp() != kIROp_Call)
+                    continue;
+
+                // If we are calling an intrinsic with any vector<T,1> argument, replace it with T.
+                auto callInst = as<IRCall>(inst);
+                if (!getResolvedInstForDecorations(callInst->getCallee())->findDecoration<IRTargetIntrinsicDecoration>())
+                    continue;
+
+                for (UInt a = 0; a < callInst->getArgCount(); a++)
+                {
+                    auto arg = callInst->getArg(a);
+                    auto argVectorType = as<IRVectorType>(arg->getDataType());
+                    if (!argVectorType)
+                        continue;
+
+                    // Only a literal element count can be compared against 1; skip anything
+                    // else instead of force-casting a non-literal count to IRIntLit.
+                    auto elementCount = as<IRIntLit>(argVectorType->getElementCount());
+                    if (!elementCount || elementCount->getValue() != 1)
+                        continue;
+
+                    builder.setInsertBefore(callInst);
+                    UInt idx = 0;
+                    auto newArg = builder.emitSwizzle(argVectorType->getElementType(), arg, 1, &idx);
+                    // Argument `a` is stored at operand `a + 1`
+                    // (presumably operand 0 holds the callee -- confirm against IRCall).
+                    callInst->setOperand(a + 1, newArg);
+                }
+            }
+        }
+    }
+
+
void processFunc(IRGlobalValueWithCode* func)
{
+ if (isCPUTarget(targetReq) || isCUDATarget(targetReq))
+ {
+ unifyBinaryExprOperands(func);
+ lowerTrivialVector(func);
+ }
eliminateCompositeConstruct(func);
deferAndDuplicateElementExtract(func);
deferAndDuplicateLoad(func);
@@ -345,9 +456,9 @@ struct SimplifyForEmitContext : public InstPassBase
}
};
-void simplifyForEmit(IRModule* module)
+void simplifyForEmit(IRModule* module, TargetRequest* targetRequest)
{
- SimplifyForEmitContext context(module);
+ SimplifyForEmitContext context(module, targetRequest);
context.processModule();
}