summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-redundancy-removal.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-23 06:59:25 -0800
committerGitHub <noreply@github.com>2023-01-23 06:59:25 -0800
commit46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch)
treec89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-redundancy-removal.cpp
parent263ca18ea516cfce43fda703c0a411aaf1938e42 (diff)
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-redundancy-removal.cpp')
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp125
1 files changed, 125 insertions, 0 deletions
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
new file mode 100644
index 000000000..a57bfce3e
--- /dev/null
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -0,0 +1,125 @@
+#include "slang-ir-redundancy-removal.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-util.h"
+
+namespace Slang
+{
+
+struct RedundancyRemovalContext
+{
+ RefPtr<IRDominatorTree> dom;
+ bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRBlock* block)
+ {
+ bool result = false;
+ for (auto instP : block->getChildren())
+ {
+ auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst)
+ {
+ auto parentBlock = as<IRBlock>(inst->getParent());
+ if (!parentBlock)
+ return false;
+ if (dom->isUnreachable(parentBlock))
+ return false;
+
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Module:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ case kIROp_GetElement:
+ case kIROp_GetElementPtr:
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
+ case kIROp_OptionalHasValue:
+ case kIROp_GetOptionalValue:
+ case kIROp_MakeOptionalValue:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
+ case kIROp_MakeStruct:
+ case kIROp_MakeArray:
+ case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_swizzle:
+ case kIROp_MatrixReshape:
+ case kIROp_MakeString:
+ case kIROp_MakeResultError:
+ case kIROp_MakeResultValue:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_CastIntToPtr:
+ case kIROp_CastPtrToBool:
+ case kIROp_CastPtrToInt:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_Reinterpret:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ return true;
+ default:
+ return false;
+ }
+ });
+ if (resultInst != instP)
+ result = true;
+ }
+ for (auto child : dom->getImmediatelyDominatedBlocks(block))
+ {
+ DeduplicateContext subContext;
+ subContext.deduplicateMap = deduplicateContext.deduplicateMap;
+ result |= removeRedundancyInBlock(subContext, child);
+ }
+ return result;
+ }
+};
+
+bool removeRedundancy(IRModule* module)
+{
+ bool changed = false;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto genericInst = as<IRGeneric>(inst))
+ {
+ removeRedundancyInFunc(genericInst);
+ inst = findGenericReturnVal(genericInst);
+ }
+ if (auto func = as<IRFunc>(inst))
+ {
+ changed |= removeRedundancyInFunc(func);
+ }
+ }
+ return changed;
+}
+
+bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
+{
+ auto root = func->getFirstBlock();
+ if (!root)
+ return false;
+
+ RedundancyRemovalContext context;
+ context.dom = computeDominatorTree(func);
+ DeduplicateContext deduplicateCtx;
+ return context.removeRedundancyInBlock(deduplicateCtx, root);
+}
+
+}