summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-loop-analysis.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-04-04 14:12:27 -0400
committerGitHub <noreply@github.com>2025-04-04 18:12:27 +0000
commit41e7e565eb3dfa13562cbfa3e8641874c2c6d66c (patch)
treef3ba43df8fe2a4c8ca4a63c9e9f577d7e0f7a028 /source/slang/slang-ir-autodiff-loop-analysis.cpp
parent83a42cb76feb1f702ff730040f359cabc01c571a (diff)
Add a loop analysis step to infer the exit values of loop phi parameters. (#6696)
* Initial loop analysis pass * More changes for a single-pass implication propagation * Update slang-ir-autodiff-loop-analysis.cpp * Cleanup + new system for loop analysis * Fixup bugs in loop analysis * Remove some relation types to simplify the analysis. Add test * Remove unused * Address comments * Fix issue with continue loops * Update reverse-loop-exit-value-inference-1.slang * Update reverse-continue-loop.slang
Diffstat (limited to 'source/slang/slang-ir-autodiff-loop-analysis.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-loop-analysis.cpp1047
1 files changed, 1047 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-loop-analysis.cpp b/source/slang/slang-ir-autodiff-loop-analysis.cpp
new file mode 100644
index 000000000..d4ff631a6
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-loop-analysis.cpp
@@ -0,0 +1,1047 @@
+// slang-ir-autodiff-loop-analysis.cpp
+
+#include "slang-ir-autodiff-loop-analysis.h"
+
+namespace Slang
+{
+
+static bool isCompareCmpInst(IRInst* inst)
+{
+ // Switch on the opcode of the instruction
+ switch (inst->getOp())
+ {
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_Leq:
+ case kIROp_Geq:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ return true;
+ default:
+ return false;
+ }
+}
+
+SimpleRelation mergeEqualityWithIntegerRelation(SimpleRelation equality, SimpleRelation relation)
+{
+ SLANG_ASSERT(
+ equality.type == SimpleRelation::IntegerRelation &&
+ relation.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(equality.comparator == SimpleRelation::Equal);
+
+ switch (relation.comparator)
+ {
+ case SimpleRelation::Equal:
+ if (relation.integerValue == equality.integerValue)
+ return relation;
+ break; // Technically we'd want to return a "set" here, but we don't have a representation
+ // for that.
+ case SimpleRelation::LessThanEqual:
+ if (equality.integerValue <= relation.integerValue)
+ return relation;
+ break;
+ case SimpleRelation::GreaterThanEqual:
+ if (equality.integerValue >= relation.integerValue)
+ return relation;
+ break;
+ default:
+ break;
+ }
+
+ return SimpleRelation::anyRelation();
+}
+
+SimpleRelation mergeIntervals(SimpleRelation a, SimpleRelation b)
+{
+ SLANG_ASSERT(
+ a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation);
+
+ if (a.comparator == SimpleRelation::Equal)
+ {
+ return mergeEqualityWithIntegerRelation(a, b);
+ }
+ else if (b.comparator == SimpleRelation::Equal)
+ {
+ return mergeEqualityWithIntegerRelation(b, a);
+ }
+
+ // TODO: Handle other cases...
+ return SimpleRelation::anyRelation();
+}
+
+// Returns the tighest "simple" relation such that (a v b -> result)
+//
+// Note: "simple" means that the relation is not a disjunction or conjunction of other relations.
+//
+SimpleRelation relationUnion(SimpleRelation a, SimpleRelation b)
+{
+ // Base case. The disjunction operator is idempotent.
+ if (a == b)
+ return a;
+
+ // If either side is trivially true, the result is trivially true.
+ if (a.type == SimpleRelation::Any || b.type == SimpleRelation::Any)
+ return SimpleRelation::anyRelation();
+
+ // If either side is trivially false, then the result is the other relation.
+ if (a.type == SimpleRelation::Impossible)
+ return b;
+
+ if (b.type == SimpleRelation::Impossible)
+ return a;
+
+ // If one is the negated form of the other, there's really nothing we can prove, since
+ // A OR ~A is always true.
+ //
+ if (a.negated() == b)
+ return SimpleRelation::anyRelation();
+
+ // Handle the case of where one is an inequality and the other is an equality.
+ if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation)
+ return mergeIntervals(a, b);
+
+ // TODO: Here's where we can handle subset cases like (a < 10) and (a < 20) => (a < 20), etc..
+ // But we don't _have_ to. The more we can prove, the more cases we can handle, but the result
+ // is still correct without it.
+ //
+
+ // Default to not being able to say anything.
+ return SimpleRelation::anyRelation();
+}
+
+SimpleRelation intersectEqualityWithIntegerRelation(
+ SimpleRelation equality,
+ SimpleRelation relation)
+{
+ SLANG_ASSERT(equality.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(relation.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(equality.comparator == SimpleRelation::Equal);
+
+ if (relation.comparator == SimpleRelation::Equal)
+ {
+ if (relation.integerValue == equality.integerValue)
+ return SimpleRelation::integerRelation(SimpleRelation::Equal, equality.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::LessThanEqual)
+ {
+ if (equality.integerValue <= relation.integerValue)
+ return SimpleRelation::integerRelation(
+ SimpleRelation::LessThanEqual,
+ relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::GreaterThanEqual)
+ {
+ if (equality.integerValue >= relation.integerValue)
+ return SimpleRelation::integerRelation(
+ SimpleRelation::GreaterThanEqual,
+ relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::NotEqual)
+ {
+ if (equality.integerValue != relation.integerValue)
+ return SimpleRelation::integerRelation(SimpleRelation::NotEqual, relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+
+ return SimpleRelation::anyRelation();
+}
+
+// Intersect intervals.
+SimpleRelation intersectIntervals(SimpleRelation a, SimpleRelation b)
+{
+ SLANG_ASSERT(
+ a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation);
+
+ if (a.comparator == SimpleRelation::Equal)
+ {
+ return intersectEqualityWithIntegerRelation(a, b);
+ }
+ else if (b.comparator == SimpleRelation::Equal)
+ {
+ return intersectEqualityWithIntegerRelation(b, a);
+ }
+
+ // TODO: Handle other cases...
+
+ // We'll just default to picking the first one, since (a ^ b) -> a is always true.
+ return a;
+}
+
+// Returns the best "simple" relation such that (a ^ b -> result)
+//
+SimpleRelation relationIntersection(SimpleRelation a, SimpleRelation b)
+{
+ // Base case. The conjunction operator is idempotent.
+ if (a == b)
+ return a;
+
+ // If one is the negated form of the other, then we can prove that the result is impossible.
+ // Doesn't necessarily mean that we have an error on our hands, but it does mean that whatever
+ // case we're considering can't happen, so can be ignored (unreachable)
+ //
+ if (a.negated() == b)
+ return SimpleRelation::impossibleRelation();
+
+ // If any of the relations is impossible, then the result is impossible.
+ if (a.type == SimpleRelation::Impossible || b.type == SimpleRelation::Impossible)
+ return SimpleRelation::impossibleRelation();
+
+ // If any one of the relations is trivially true, then the result is the other relation.
+ if (a.type == SimpleRelation::Any)
+ return b;
+
+ if (b.type == SimpleRelation::Any)
+ return a;
+
+ //
+ // We'll handle the cases where one is an equality and the other is an inequality or equality.
+ //
+ // i.e. For a conjunction (a == 10) ^ (a < 20), we can use the narrower relation (a == 10).
+ //
+ if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation)
+ return intersectIntervals(a, b);
+
+ // TODO: Handle other cases...
+ return SimpleRelation::anyRelation();
+}
+
+void StatementSet::disjunct(StatementSet other)
+{
+ // false v (a1 v a2 v a3 ...) = (a1 v a2 v a3 ...)
+ if (isTriviallyFalse())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // (a1 v a2 v a3 ...) v false = (a1 v a2 v a3 ...)
+ if (other.isTriviallyFalse())
+ return;
+
+ // true v (a1 v a2 v a3 ...) = true
+ if (other.isTriviallyTrue())
+ {
+ statements.clear();
+ return;
+ }
+
+ // (a1 v a2 v a3 ...) v true = true
+ if (isTriviallyTrue())
+ return;
+
+ for (auto& statement : other.statements)
+ {
+ // Since we hold only one statement per inst, we can perform disjunction
+ // on a per-inst basis.
+ // If an inst does not exist in the current set, then it's an empty statement.
+ //
+ if (statements.containsKey(statement.first))
+ {
+ auto newRelation = relationUnion(statement.second, statements[statement.first]);
+ set(statement.first, newRelation);
+ }
+ }
+
+ // Remove any insts that don't have a corresponding statement in the other set,
+ // since this effectively means "any".
+ //
+ for (auto& statement : statements)
+ {
+ if (!other.statements.containsKey(statement.first))
+ statements.remove(statement.first);
+ }
+}
+
+void StatementSet::conjunct(StatementSet other)
+{
+ // true ^ (a1 ^ a2 ^ a3 ...) = (a1 ^ a2 ^ a3 ...)
+ if (other.isTriviallyTrue())
+ return;
+
+ // (a1 ^ a2 ^ a3 ...) ^ true = (a1 ^ a2 ^ a3 ...)
+ if (isTriviallyTrue())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // false ^ (a1 ^ a2 ^ a3 ...) = false
+ if (isTriviallyFalse())
+ return;
+
+ // (a1 ^ a2 ^ a3 ...) ^ false = false
+ if (other.isTriviallyFalse())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // Otherwise do an element-wise conjunction.
+ for (auto& statement : other.statements)
+ {
+ if (statements.containsKey(statement.first))
+ {
+ set(statement.first,
+ relationIntersection(statement.second, statements[statement.first]));
+ }
+ else
+ {
+ set(statement.first, statement.second);
+ }
+ }
+}
+
+void StatementSet::conjunct(IRInst* inst, SimpleRelation relation)
+{
+ if (isTriviallyFalse())
+ return;
+
+ if (statements.containsKey(inst))
+ {
+ set(inst, relationIntersection(relation, statements[inst]));
+ }
+ else
+ {
+ set(inst, relation);
+ }
+}
+
+// This function answers the question: "Can we prove that relationB is true if relationA is true?"
+//
+// Note that this is not the same as "Does relationA imply relationB", since there can be cases
+// where this is indeed true, but we just don't have the logic to prove it.
+//
+bool doesRelationImply(SimpleRelation relationA, SimpleRelation relationB)
+{
+ // Equal relations imply each other
+ if (relationA == relationB)
+ return true;
+
+ // If B is trivially true, then A implies B
+ if (relationB.type == SimpleRelation::Any)
+ return true;
+
+ // If A is trivially true, then A implies B only if B is also trivially true
+ if (relationA.type == SimpleRelation::Any)
+ return (relationB.type == SimpleRelation::Any);
+
+ // If A is impossible, then technically what we return doesn't matter...
+ if (relationA.type == SimpleRelation::Impossible ||
+ relationB.type == SimpleRelation::Impossible)
+ return false;
+
+ // If A is a boolean relation, then A implies B if B is also a boolean relation and the values
+ // are the same.
+ //
+ if (relationA.type == SimpleRelation::BoolRelation)
+ return (relationB.type == SimpleRelation::BoolRelation) &&
+ (relationA.boolValue == relationB.boolValue);
+
+ if (relationA.type == SimpleRelation::IntegerRelation)
+ {
+ if (relationB.type != SimpleRelation::IntegerRelation)
+ return false;
+
+ // Technically, the equality case is already handled above, so we'll only consider
+ // cases where A and B are not the same relation, but where A -> B
+
+ // If A is an equality, and B is an inequality, we can test
+ if (relationA.comparator == SimpleRelation::Equal)
+ {
+ if (relationB.comparator == SimpleRelation::LessThanEqual)
+ return relationA.integerValue <= relationB.integerValue;
+ else if (relationB.comparator == SimpleRelation::GreaterThanEqual)
+ return relationA.integerValue >= relationB.integerValue;
+ }
+
+ // If A is an equality, and B is an inequality with different values, then
+ // A -> B
+ //
+ if (relationA.comparator == SimpleRelation::Equal &&
+ relationB.comparator == SimpleRelation::NotEqual)
+ {
+ return relationA.integerValue != relationB.integerValue;
+ }
+
+ if (relationA.comparator == SimpleRelation::GreaterThanEqual &&
+ relationB.comparator == SimpleRelation::GreaterThanEqual)
+ {
+ return relationA.integerValue >= relationB.integerValue;
+ }
+
+ if (relationA.comparator == SimpleRelation::LessThanEqual &&
+ relationB.comparator == SimpleRelation::LessThanEqual)
+ {
+ return relationA.integerValue <= relationB.integerValue;
+ }
+
+ // TODO: Handle other cases.. these come up rarely, so we can
+ }
+
+ return false;
+}
+
+bool isIntegerConstantValue(IRInst* inst)
+{
+ return inst->getOp() == kIROp_IntLit;
+}
+
+bool isBoolConstantValue(IRInst* inst)
+{
+ return inst->getOp() == kIROp_BoolLit;
+}
+
+IRIntegerValue getConstantIntegerValue(IRInst* inst)
+{
+ SLANG_ASSERT(isIntegerConstantValue(inst));
+ return as<IRIntLit>(inst)->getValue();
+}
+
+bool getConstantBoolValue(IRInst* inst)
+{
+ SLANG_ASSERT(isBoolConstantValue(inst));
+ return as<IRBoolLit>(inst)->getValue();
+}
+
+StatementSet tryExtractStatements(IRTerminatorInst* inst, IRBlock* block)
+{
+ StatementSet statements;
+
+ // From condInst, extract a statement about any inst such that we have an equality
+ // statement (integer or boolean) on the inst.
+ //
+ if (auto ifElse = as<IRIfElse>(inst))
+ {
+ // Check that the block is the true or false block of the if-else
+ bool isTrueBlock = ifElse->getTrueBlock() == block;
+ bool isFalseBlock = ifElse->getFalseBlock() == block;
+ if (!isTrueBlock && !isFalseBlock)
+ goto done;
+
+ auto condInst = inst->getOperand(0);
+ statements.conjunct(condInst, SimpleRelation::boolRelation(isTrueBlock));
+
+ if (condInst->getOp() == kIROp_Eql)
+ {
+ auto leftOperand = condInst->getOperand(0);
+ auto rightOperand = condInst->getOperand(1);
+
+ if (isIntegerConstantValue(leftOperand))
+ {
+ statements.conjunct(
+ rightOperand,
+ SimpleRelation::integerRelation(
+ (isTrueBlock ? SimpleRelation::Equal : SimpleRelation::NotEqual),
+ getConstantIntegerValue(leftOperand)));
+ }
+ else if (isIntegerConstantValue(rightOperand))
+ {
+ statements.conjunct(
+ leftOperand,
+ SimpleRelation::integerRelation(
+ (isTrueBlock ? SimpleRelation::Equal : SimpleRelation::NotEqual),
+ getConstantIntegerValue(rightOperand)));
+ }
+ }
+ else if (isCompareCmpInst(condInst))
+ {
+ auto leftOperand = condInst->getOperand(0);
+ auto rightOperand = condInst->getOperand(1);
+
+ bool isParamLeft = !isIntegerConstantValue(leftOperand);
+ bool isParamRight = !isIntegerConstantValue(rightOperand);
+
+ // If neither operand is an inst, we can't say anything.
+ if (!isParamLeft && !isParamRight)
+ goto done;
+
+ auto paramOperand = isParamLeft ? leftOperand : rightOperand;
+ auto otherOperand = isParamLeft ? rightOperand : leftOperand;
+
+ // Check if the "other" operand is a constant
+ if (!isIntegerConstantValue(otherOperand))
+ goto done;
+
+ auto constantVal = getConstantIntegerValue(otherOperand);
+
+ SimpleRelation::Comparator comparator;
+ switch (condInst->getOp())
+ {
+ case kIROp_Less:
+ comparator = SimpleRelation::LessThanEqual;
+ constantVal = constantVal - 1;
+ break;
+ case kIROp_Greater:
+ comparator = SimpleRelation::GreaterThanEqual;
+ constantVal = constantVal + 1;
+ break;
+ case kIROp_Leq:
+ comparator = SimpleRelation::LessThanEqual;
+ break;
+ case kIROp_Geq:
+ comparator = SimpleRelation::GreaterThanEqual;
+ break;
+ case kIROp_Eql:
+ comparator = SimpleRelation::Equal;
+ break;
+ case kIROp_Neq:
+ comparator = SimpleRelation::NotEqual;
+ break;
+ default:
+ SLANG_UNREACHABLE("unexpected op code");
+ }
+ auto relation = SimpleRelation::integerRelation(comparator, constantVal);
+ statements.conjunct(
+ paramOperand,
+ ((isParamLeft ^ !isTrueBlock) ? relation : relation.negated()));
+ }
+ }
+ else if (auto switchInst = as<IRSwitch>(inst))
+ {
+ // Check that the block is the default case of the switch
+ if (switchInst->getDefaultLabel() == block)
+ goto done;
+
+ // Check each case block
+ UInt caseCount = switchInst->getCaseCount();
+ for (UInt i = 0; i < caseCount; i++)
+ {
+ auto caseValue = switchInst->getCaseValue(i);
+ auto caseBlock = switchInst->getCaseLabel(i);
+
+ if (caseBlock == block && isIntegerConstantValue(caseValue))
+ {
+ auto constantVal = getConstantIntegerValue(caseValue);
+ statements.conjunct(
+ switchInst->getCondition(),
+ SimpleRelation::integerRelation(SimpleRelation::Equal, constantVal));
+ }
+ }
+ }
+
+done:
+ return statements;
+}
+
+enum class BlockStateFlags
+{
+ UpwardPropCompleted = 1 << 0,
+ DownwardPropCompleted = 1 << 1
+};
+
+void markUpwardPropCompleted(IRBlock* block)
+{
+ block->scratchData |= (UInt64)BlockStateFlags::UpwardPropCompleted;
+}
+
+void markDownwardPropCompleted(IRBlock* block)
+{
+ block->scratchData |= (UInt64)BlockStateFlags::DownwardPropCompleted;
+}
+
+bool isUpwardPropCompleted(IRBlock* block)
+{
+ return block->scratchData & (UInt64)BlockStateFlags::UpwardPropCompleted;
+}
+
+bool isDownwardPropCompleted(IRBlock* block)
+{
+ return block->scratchData & (UInt64)BlockStateFlags::DownwardPropCompleted;
+}
+
+void clearBlockState(IRBlock* block)
+{
+ block->scratchData = 0;
+}
+
+bool isLoopConditionBlock(IRBlock* block)
+{
+ for (auto use = block->firstUse; use; use = use->nextUse)
+ {
+ if (auto loop = as<IRLoop>(use->getUser()))
+ {
+ if (loop->getTargetBlock() == block)
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool isBlockReadyForUpwardProp(IRBlock* block)
+{
+ if (isLoopConditionBlock(block))
+ {
+ auto falseBlock = cast<IRIfElse>(block->getTerminator())->getFalseBlock();
+ return isUpwardPropCompleted(falseBlock);
+ }
+
+ // Check that successors have completed upward propagation.
+ for (auto successor : block->getSuccessors())
+ {
+ if (!isUpwardPropCompleted(successor))
+ return false;
+ }
+ return true;
+}
+
+bool isBlockReadyForDownwardProp(IRBlock* block)
+{
+ // Check that predecessors have completed downward propagation.
+ for (auto predecessor : block->getPredecessors())
+ {
+ if (!isDownwardPropCompleted(predecessor))
+ return false;
+ }
+ return true;
+}
+
+StatementSet propagateStatementUpwards(IRInst* inst, SimpleRelation relation)
+{
+ // Lambda to make a single-statement set.
+ auto makeStatementSet = [&](IRInst* inst, SimpleRelation relation)
+ {
+ StatementSet set;
+ set.conjunct(inst, relation);
+ return set;
+ };
+
+ if (as<IRParam>(inst))
+ return makeStatementSet(inst, relation);
+
+ if (isIntegerConstantValue(inst))
+ {
+ auto relationFromInst =
+ SimpleRelation::integerRelation(SimpleRelation::Equal, getConstantIntegerValue(inst));
+ if (doesRelationImply(relation, relationFromInst))
+ return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true
+ else if (doesRelationImply(relation, relationFromInst.negated()))
+ return makeStatementSet(inst, SimpleRelation::impossibleRelation());
+ else
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+ }
+ else if (isBoolConstantValue(inst))
+ {
+ auto relationFromInst = SimpleRelation::boolRelation(getConstantBoolValue(inst));
+ if (doesRelationImply(relation, relationFromInst))
+ return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true
+ else if (doesRelationImply(relation, relationFromInst.negated()))
+ return makeStatementSet(inst, SimpleRelation::impossibleRelation());
+ else
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+ }
+ else if (inst->getOp() == kIROp_Add || inst->getOp() == kIROp_Sub)
+ {
+ // TODO: Translate equality/inequality.
+ }
+
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+}
+
+StatementSet propagateUpwards(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* current,
+ IRBlock* predecessor,
+ StatementSet predicateSet)
+{
+ // Translate the set of predicates from the current block to the predecessor block.
+ //
+ // The key idea is that we need to find a set of predicate statements (A') for the predecessor
+ // block, such that A => A'.
+ //
+ // During the downward phase, the predecessor will then return a set of
+ // statements (B') such that A' => B'. This B' can be propagated "downwards" into a set
+ // of statements B such that B' => B.
+ //
+ // We can then combine these three rules A => A', A' => B' and B' => B to get A => B
+ // which is the statement set that we want for our current block.
+ //
+
+ StatementSet newPredicateSet;
+ for (auto& statementInstPair : predicateSet.statements)
+ {
+ auto predicateRelation = statementInstPair.second;
+ auto predicateInst = statementInstPair.first;
+ if (as<IRParam>(predicateInst) && predicateInst->getParent() == current)
+ {
+ auto paramIndex = getParamIndexInBlock(cast<IRParam>(predicateInst));
+ auto translatedInst =
+ as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex);
+
+ // If the translate inst is outside the block, add it in as-is, otherwise,
+ // we'll need to propagate it to the operands of the inst
+ //
+ auto statementSet = propagateStatementUpwards(translatedInst, predicateRelation);
+ newPredicateSet.conjunct(statementSet);
+ }
+ else
+ {
+ newPredicateSet.conjunct(predicateInst, predicateRelation);
+ }
+ }
+
+ // If our current block is a merge block for a conditional branch, we should add the condition
+ // to the predicate set.
+ //
+ for (auto blockUse = current->firstUse; blockUse; blockUse = blockUse->nextUse)
+ {
+ if (auto ifElse = as<IRIfElse>(blockUse->getUser()))
+ {
+ if (ifElse->getAfterBlock() == current)
+ {
+ // We're looking at the merge block for a conditional branch.
+
+ if (domTree->dominates(ifElse->getTrueBlock(), predecessor))
+ {
+ // True branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock()));
+ }
+ else if (domTree->dominates(ifElse->getFalseBlock(), predecessor))
+ {
+ // False branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getFalseBlock()));
+ }
+ else
+ {
+ // Panic
+ SLANG_UNREACHABLE("Unreachable block in conditional branch");
+ }
+ }
+ }
+
+ // We'll ignore switch statements for now, but they're trivial to add.
+ // TODO: Add switch statements.
+ }
+
+ // We have one more edge-case. The condition block of a loop inst.
+ if (auto ifElse = as<IRIfElse>(current->getTerminator()))
+ {
+ if (domTree->dominates(ifElse->getTrueBlock(), predecessor) &&
+ !domTree->dominates(ifElse->getFalseBlock(), predecessor))
+ {
+ // True branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock()));
+ }
+ }
+ return newPredicateSet;
+}
+
+StatementSet propagateStatementDownwards(
+ IRInst* srcInst,
+ IRInst* dstInst,
+ StatementSet srcStatements)
+{
+ // We'll keep translating through the inst, until we either hit a parameter
+ // until we either hit a parameter, or we leave the current block.
+ //
+
+ // Lambda to make a single-statement set.
+ auto singleStatement = [&](IRInst* inst, SimpleRelation relation)
+ {
+ StatementSet set;
+ set.conjunct(inst, relation);
+ return set;
+ };
+
+ if (srcStatements.statements.containsKey(srcInst))
+ return singleStatement(dstInst, srcStatements.statements[srcInst]);
+
+ if (isIntegerConstantValue(srcInst))
+ {
+ return singleStatement(
+ dstInst,
+ SimpleRelation::integerRelation(
+ SimpleRelation::Equal,
+ getConstantIntegerValue(srcInst)));
+ }
+ else if (isBoolConstantValue(srcInst))
+ {
+ return singleStatement(
+ dstInst,
+ SimpleRelation::boolRelation(getConstantBoolValue(srcInst)));
+ }
+
+ if (srcInst->getOp() == kIROp_Add || srcInst->getOp() == kIROp_Sub)
+ {
+ auto left = srcInst->getOperand(0);
+ auto right = srcInst->getOperand(1);
+
+ auto isLeftConstant = isIntegerConstantValue(left);
+ auto isRightConstant = isIntegerConstantValue(right);
+
+ if (!isLeftConstant && !isRightConstant)
+ return singleStatement(dstInst,
+ SimpleRelation::anyRelation()); // Can't say anything
+
+ if (srcInst->getOp() == kIROp_Add || (srcInst->getOp() == kIROp_Sub && isRightConstant))
+ {
+ auto constant =
+ isLeftConstant ? getConstantIntegerValue(left) : getConstantIntegerValue(right);
+ auto operand = isLeftConstant ? right : left;
+
+ constant = srcInst->getOp() == kIROp_Add ? constant : -constant;
+
+ auto operandStatement = propagateStatementDownwards(operand, operand, srcStatements);
+ auto relation = operandStatement.statements.containsKey(operand)
+ ? operandStatement.statements[operand]
+ : SimpleRelation::anyRelation();
+
+ if (relation.type == SimpleRelation::IntegerRelation)
+ {
+ switch (relation.comparator)
+ {
+ case SimpleRelation::Equal:
+ case SimpleRelation::NotEqual:
+ case SimpleRelation::LessThanEqual:
+ case SimpleRelation::GreaterThanEqual:
+ return singleStatement(
+ dstInst,
+ SimpleRelation::integerRelation(
+ relation.comparator,
+ constant + relation.integerValue));
+ }
+ }
+ }
+ }
+
+ // Default
+ return singleStatement(dstInst, SimpleRelation::anyRelation());
+}
+
+StatementSet propagateDownwards(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* successor,
+ IRBlock* predecessor,
+ StatementSet statementSet)
+{
+ // Translate a set of statements from the current block to the successor block.
+ //
+ // That is, find a set of statements (B') for the successor block such that B => B'
+ //
+ StatementSet newStatementSet;
+
+ if (statementSet.isTriviallyFalse())
+ {
+ return statementSet;
+ }
+
+ // Go over all the parameters of the successor block, find corresponding arguments, and
+ // convert any statements to the new set.
+ //
+ UInt paramIndex = 0;
+ for (auto param : successor->getParams())
+ {
+ auto arg = as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex);
+ auto statement = propagateStatementDownwards(arg, param, statementSet);
+ newStatementSet.conjunct(statement);
+ paramIndex++;
+ }
+
+ newStatementSet.conjunct(tryExtractStatements(predecessor->getTerminator(), successor));
+
+ // For all other statements in the statementSet, we'll add them in, but only
+ // if the predecessor dominates the successor.
+ // An exception is parameters defined in the successor (since these are getting
+ // redefined, we should not be considering existing statements)
+ //
+ for (auto& statement : statementSet.statements)
+ {
+ if (domTree->dominates(statement.first->getParent(), successor) &&
+ !(as<IRParam>(statement.first) && statement.first->getParent() == successor))
+ newStatementSet.conjunct(statement.first, statement.second);
+ }
+
+ return newStatementSet;
+}
+
+struct Edge
+{
+ IRBlock* predecessor;
+ IRBlock* successor;
+
+ bool operator==(const Edge& other) const
+ {
+ return predecessor == other.predecessor && successor == other.successor;
+ }
+
+ UInt64 getHashCode() const
+ {
+ UInt64 predHash = Slang::getHashCode(predecessor);
+ UInt64 succHash = Slang::getHashCode(successor);
+ return Slang::combineHash(predHash, succHash);
+ }
+};
+
+
+// This routine returns a set of implications for any insts visible in a block.
+//
+// The process uses a modified version of abstract interpretation, by first propagating a set
+// of predicates "backwards" repeatedly through the predecessors, then calculating the set of
+// implications "forwards" repeatedly through the successors.
+//
+// Note that the resulting implications don't contain all possible statements that could be inferred
+// statically (this is an undeciable problem), but rather whatever can be inferred in just two steps
+// through the blocks. This suffices for the vast majority of common loop structures.
+//
+StatementSet collectImplications(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* block,
+ StatementSet Predicates)
+{
+ List<Edge> orderedEdgeList; // Edges in the order that they're processed.
+ HashSet<Edge> falseEdges; // Edges between blocks where the successor's predicate does not imply
+ // the predecessor's predicate.
+
+ // Initialize a work list.
+ List<IRBlock*> workList;
+ workList.add(block);
+
+ // Clear scratch bits.
+ IRFunc* func = cast<IRFunc>(domTree->code);
+ for (auto _block : func->getBlocks())
+ {
+ clearBlockState(_block);
+ }
+
+ //
+ // Upward pass: Propagate predicates through predecessors until
+ // there're no more blocks left to process.
+ //
+
+ // We'll keep track of the predicates for each block.
+ Dictionary<IRBlock*, StatementSet> blockPredicates;
+
+ blockPredicates[block] = Predicates;
+
+ while (workList.getCount() > 0)
+ {
+ auto current = workList.getLast();
+ workList.removeLast();
+
+ // If the block has already been processed, skip it.
+ if (isUpwardPropCompleted(current))
+ continue;
+
+ // If the block is not ready for upward propagation, add it to the work list.
+ if (current != block && !isBlockReadyForUpwardProp(current))
+ {
+ workList.add(current);
+ // Then add all the successors to the work list.
+ for (auto successor : current->getSuccessors())
+ workList.add(successor);
+
+ continue;
+ }
+
+ // Otherwise, we'll process the block.
+ //
+ // Get our predicate set, then propagate it to all predecessors.
+ //
+ auto predicates = blockPredicates[current];
+
+ HashSet<IRBlock*> uniquePredecessors;
+ for (auto predecessor : current->getPredecessors())
+ uniquePredecessors.add(predecessor);
+
+ for (auto predecessor : uniquePredecessors)
+ {
+ // We also need to handle the recursive case, where the predecessor
+ // is already "sealed".
+ //
+ if (isUpwardPropCompleted(predecessor))
+ {
+ orderedEdgeList.add({predecessor, current});
+
+ // Verify that "current predicate" => "predecessor predicate".
+
+ // TODO: Implement later.
+ // For now, we can default to assuming that this edge is not
+ // valid. This works fine since we're not trying to prove anything recursive (like
+ // inductivity), but we should revisit this if we do want to unify the induction
+ // value inference pass with this loop analysis system.
+ //
+
+ // We'll add this to the set of false edges so that the downward prop pass
+ // doesn't propagate any implications through this edge.
+ //
+ falseEdges.add({predecessor, current});
+ continue;
+ }
+
+ auto newPredicates = propagateUpwards(domTree, current, predecessor, predicates);
+
+ if (!blockPredicates.containsKey(predecessor))
+ blockPredicates[predecessor] = newPredicates;
+ else
+ blockPredicates[predecessor].disjunct(newPredicates);
+
+ orderedEdgeList.add({predecessor, current});
+
+ // Add predecessors to work list.
+ workList.add(predecessor);
+ }
+
+ markUpwardPropCompleted(current);
+ }
+
+ //
+ // Downward pass: Propagate implications through successors until
+ // there're no more blocks left to process.
+ //
+
+ Dictionary<IRBlock*, StatementSet> blockImplications;
+
+ // Set 'block' to something trivial base case.
+ // blockImplications[block] = blockPredicates[block]; // statement => statement
+
+ while (orderedEdgeList.getCount() > 0)
+ {
+ auto edge = orderedEdgeList.getLast();
+ orderedEdgeList.removeLast();
+
+ // Get the predicate set for the predecessor.
+ auto predecessorPredicates = blockPredicates[edge.predecessor];
+
+ // Get the implication set for the predecessor.
+ auto predecessorImplications = StatementSet();
+
+ if (falseEdges.contains(edge))
+ {
+ // Since A' => B' is not true, effectively, we can't say anything..
+ predecessorImplications = StatementSet();
+ }
+ else
+ {
+ // (A' => B') => (A' => A' ^ B')
+ predecessorImplications = blockImplications[edge.predecessor];
+ predecessorImplications.conjunct(predecessorPredicates);
+ }
+
+ // Propagate the implication set to the successor.
+ auto successorImplications =
+ propagateDownwards(domTree, edge.successor, edge.predecessor, predecessorImplications);
+
+ if (!blockImplications.containsKey(edge.successor))
+ blockImplications[edge.successor] = successorImplications;
+ else
+ blockImplications[edge.successor].disjunct(successorImplications);
+ }
+
+ // Clear scratch bits.
+ for (auto _block : func->getBlocks())
+ {
+ clearBlockState(_block);
+ }
+
+ // We should have a final set of implications for our block.
+ return blockImplications[block];
+}
+
+} // namespace Slang