diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-16 13:55:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-16 13:55:32 -0800 |
| commit | 4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch) | |
| tree | ed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source/slang/slang-ir-simplify-for-emit.cpp | |
| parent | eda88e513e8b1e2abc05e9dc8555f237d96472df (diff) | |
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend.
* Update IR documentation.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-simplify-for-emit.cpp')
| -rw-r--r-- | source/slang/slang-ir-simplify-for-emit.cpp | 121 |
1 files changed, 116 insertions, 5 deletions
diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp index 5e5f61a4a..67d95c59f 100644 --- a/source/slang/slang-ir-simplify-for-emit.cpp +++ b/source/slang/slang-ir-simplify-for-emit.cpp @@ -5,12 +5,16 @@ namespace Slang { +bool isCPUTarget(TargetRequest* targetReq); +bool isCUDATarget(TargetRequest* targetReq); + struct SimplifyForEmitContext : public InstPassBase { - SimplifyForEmitContext(IRModule* inModule) - : InstPassBase(inModule) + SimplifyForEmitContext(IRModule* inModule, TargetRequest* inTargetReq) + : InstPassBase(inModule), targetReq(inTargetReq) {} + TargetRequest* targetReq; List<IRInst*> followUpWorkList; HashSet<IRInst*> followUpWorkListSet; @@ -134,7 +138,7 @@ struct SimplifyForEmitContext : public InstPassBase IRBuilder builder(sharedBuilderStorage); builder.setInsertBefore(user); auto newLoad = builder.emitLoad(load->getPtr()); - use->set(newLoad); + builder.replaceOperand(use, newLoad); } void processLoad(IRLoad* inst) @@ -330,8 +334,115 @@ struct SimplifyForEmitContext : public InstPassBase processInst(followUpWorkList[i]); } + void unifyBinaryExprOperands(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Leq: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Greater: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Lsh: + case kIROp_Rsh: + builder.setInsertBefore(inst); + SLANG_ASSERT(inst->getOperandCount() == 2); + if (as<IRVectorType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeVectorFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + else if (as<IRMatrixType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeMatrixFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + + break; + } + } + } + } + + // Turn single element vector values into scalars before using it to call an intrinsic func. + void lowerTrivialVector(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + List<IRInst*> instsToProcess; + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Call: + { + // If we are calling an intrinsic with any vector<T,1> argument, replace it with T. + auto callInst = as<IRCall>(inst); + if (getResolvedInstForDecorations(callInst->getCallee())->findDecoration<IRTargetIntrinsicDecoration>()) + { + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + if (auto argVectorType = as<IRVectorType>(arg->getDataType())) + { + if (cast<IRIntLit>(argVectorType->getElementCount())->getValue() == 1) + { + builder.setInsertBefore(callInst); + UInt idx = 0; + auto newArg = builder.emitSwizzle(argVectorType->getElementType(), arg, 1, &idx); + callInst->setOperand(a + 1, newArg); + } + } + } + } + } + break; + } + } + } + } + + void processFunc(IRGlobalValueWithCode* func) { + if (isCPUTarget(targetReq) || isCUDATarget(targetReq)) + { + unifyBinaryExprOperands(func); + lowerTrivialVector(func); + } eliminateCompositeConstruct(func); deferAndDuplicateElementExtract(func); deferAndDuplicateLoad(func); @@ -345,9 +456,9 @@ struct SimplifyForEmitContext : public InstPassBase } }; -void simplifyForEmit(IRModule* module) +void simplifyForEmit(IRModule* module, TargetRequest* targetRequest) { - SimplifyForEmitContext context(module); + SimplifyForEmitContext context(module, targetRequest); context.processModule(); } |
