summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-loop-analysis.cpp1047
-rw-r--r--source/slang/slang-ir-autodiff-loop-analysis.h181
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp307
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir.cpp7
-rw-r--r--tests/autodiff/reverse-continue-loop.slang3
-rw-r--r--tests/autodiff/reverse-loop-diff-only-2.slang15
-rw-r--r--tests/autodiff/reverse-loop-exit-value-inference-1.slang240
-rw-r--r--tests/autodiff/reverse-loop-simple.slang (renamed from tests/autodiff/reverse-loop.slang)0
-rw-r--r--tests/autodiff/reverse-loop-simple.slang.expected.txt (renamed from tests/autodiff/reverse-loop.slang.expected.txt)0
13 files changed, 1804 insertions, 11 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 8d7577b52..4eb4719f0 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1037,7 +1037,10 @@ Result linkAndOptimizeIR(
// Report checkpointing information
if (codeGenContext->shouldReportCheckpointIntermediates())
+ {
+ simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink);
reportCheckpointIntermediates(codeGenContext, sink, irModule);
+ }
// Finalization is always run so AD-related instructions can be removed,
// even if the AD pass itself is not run.
diff --git a/source/slang/slang-ir-autodiff-loop-analysis.cpp b/source/slang/slang-ir-autodiff-loop-analysis.cpp
new file mode 100644
index 000000000..d4ff631a6
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-loop-analysis.cpp
@@ -0,0 +1,1047 @@
+// slang-ir-autodiff-loop-analysis.cpp
+
+#include "slang-ir-autodiff-loop-analysis.h"
+
+namespace Slang
+{
+
+static bool isCompareCmpInst(IRInst* inst)
+{
+ // Switch on the opcode of the instruction
+ switch (inst->getOp())
+ {
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_Leq:
+ case kIROp_Geq:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ return true;
+ default:
+ return false;
+ }
+}
+
+SimpleRelation mergeEqualityWithIntegerRelation(SimpleRelation equality, SimpleRelation relation)
+{
+ SLANG_ASSERT(
+ equality.type == SimpleRelation::IntegerRelation &&
+ relation.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(equality.comparator == SimpleRelation::Equal);
+
+ switch (relation.comparator)
+ {
+ case SimpleRelation::Equal:
+ if (relation.integerValue == equality.integerValue)
+ return relation;
+ break; // Technically we'd want to return a "set" here, but we don't have a representation
+ // for that.
+ case SimpleRelation::LessThanEqual:
+ if (equality.integerValue <= relation.integerValue)
+ return relation;
+ break;
+ case SimpleRelation::GreaterThanEqual:
+ if (equality.integerValue >= relation.integerValue)
+ return relation;
+ break;
+ default:
+ break;
+ }
+
+ return SimpleRelation::anyRelation();
+}
+
+SimpleRelation mergeIntervals(SimpleRelation a, SimpleRelation b)
+{
+ SLANG_ASSERT(
+ a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation);
+
+ if (a.comparator == SimpleRelation::Equal)
+ {
+ return mergeEqualityWithIntegerRelation(a, b);
+ }
+ else if (b.comparator == SimpleRelation::Equal)
+ {
+ return mergeEqualityWithIntegerRelation(b, a);
+ }
+
+ // TODO: Handle other cases...
+ return SimpleRelation::anyRelation();
+}
+
+// Returns the tighest "simple" relation such that (a v b -> result)
+//
+// Note: "simple" means that the relation is not a disjunction or conjunction of other relations.
+//
+SimpleRelation relationUnion(SimpleRelation a, SimpleRelation b)
+{
+ // Base case. The disjunction operator is idempotent.
+ if (a == b)
+ return a;
+
+ // If either side is trivially true, the result is trivially true.
+ if (a.type == SimpleRelation::Any || b.type == SimpleRelation::Any)
+ return SimpleRelation::anyRelation();
+
+ // If either side is trivially false, then the result is the other relation.
+ if (a.type == SimpleRelation::Impossible)
+ return b;
+
+ if (b.type == SimpleRelation::Impossible)
+ return a;
+
+ // If one is the negated form of the other, there's really nothing we can prove, since
+ // A OR ~A is always true.
+ //
+ if (a.negated() == b)
+ return SimpleRelation::anyRelation();
+
+ // Handle the case of where one is an inequality and the other is an equality.
+ if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation)
+ return mergeIntervals(a, b);
+
+ // TODO: Here's where we can handle subset cases like (a < 10) and (a < 20) => (a < 20), etc..
+ // But we don't _have_ to. The more we can prove, the more cases we can handle, but the result
+ // is still correct without it.
+ //
+
+ // Default to not being able to say anything.
+ return SimpleRelation::anyRelation();
+}
+
+SimpleRelation intersectEqualityWithIntegerRelation(
+ SimpleRelation equality,
+ SimpleRelation relation)
+{
+ SLANG_ASSERT(equality.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(relation.type == SimpleRelation::IntegerRelation);
+ SLANG_ASSERT(equality.comparator == SimpleRelation::Equal);
+
+ if (relation.comparator == SimpleRelation::Equal)
+ {
+ if (relation.integerValue == equality.integerValue)
+ return SimpleRelation::integerRelation(SimpleRelation::Equal, equality.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::LessThanEqual)
+ {
+ if (equality.integerValue <= relation.integerValue)
+ return SimpleRelation::integerRelation(
+ SimpleRelation::LessThanEqual,
+ relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::GreaterThanEqual)
+ {
+ if (equality.integerValue >= relation.integerValue)
+ return SimpleRelation::integerRelation(
+ SimpleRelation::GreaterThanEqual,
+ relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+ else if (relation.comparator == SimpleRelation::NotEqual)
+ {
+ if (equality.integerValue != relation.integerValue)
+ return SimpleRelation::integerRelation(SimpleRelation::NotEqual, relation.integerValue);
+ else
+ return SimpleRelation::impossibleRelation();
+ }
+
+ return SimpleRelation::anyRelation();
+}
+
+// Intersect intervals.
+SimpleRelation intersectIntervals(SimpleRelation a, SimpleRelation b)
+{
+ SLANG_ASSERT(
+ a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation);
+
+ if (a.comparator == SimpleRelation::Equal)
+ {
+ return intersectEqualityWithIntegerRelation(a, b);
+ }
+ else if (b.comparator == SimpleRelation::Equal)
+ {
+ return intersectEqualityWithIntegerRelation(b, a);
+ }
+
+ // TODO: Handle other cases...
+
+ // We'll just default to picking the first one, since (a ^ b) -> a is always true.
+ return a;
+}
+
+// Returns the best "simple" relation such that (a ^ b -> result)
+//
+SimpleRelation relationIntersection(SimpleRelation a, SimpleRelation b)
+{
+ // Base case. The conjunction operator is idempotent.
+ if (a == b)
+ return a;
+
+ // If one is the negated form of the other, then we can prove that the result is impossible.
+ // Doesn't necessarily mean that we have an error on our hands, but it does mean that whatever
+ // case we're considering can't happen, so can be ignored (unreachable)
+ //
+ if (a.negated() == b)
+ return SimpleRelation::impossibleRelation();
+
+ // If any of the relations is impossible, then the result is impossible.
+ if (a.type == SimpleRelation::Impossible || b.type == SimpleRelation::Impossible)
+ return SimpleRelation::impossibleRelation();
+
+ // If any one of the relations is trivially true, then the result is the other relation.
+ if (a.type == SimpleRelation::Any)
+ return b;
+
+ if (b.type == SimpleRelation::Any)
+ return a;
+
+ //
+ // We'll handle the cases where one is an equality and the other is an inequality or equality.
+ //
+ // i.e. For a conjunction (a == 10) ^ (a < 20), we can use the narrower relation (a == 10).
+ //
+ if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation)
+ return intersectIntervals(a, b);
+
+ // TODO: Handle other cases...
+ return SimpleRelation::anyRelation();
+}
+
+void StatementSet::disjunct(StatementSet other)
+{
+ // false v (a1 v a2 v a3 ...) = (a1 v a2 v a3 ...)
+ if (isTriviallyFalse())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // (a1 v a2 v a3 ...) v false = (a1 v a2 v a3 ...)
+ if (other.isTriviallyFalse())
+ return;
+
+ // true v (a1 v a2 v a3 ...) = true
+ if (other.isTriviallyTrue())
+ {
+ statements.clear();
+ return;
+ }
+
+ // (a1 v a2 v a3 ...) v true = true
+ if (isTriviallyTrue())
+ return;
+
+ for (auto& statement : other.statements)
+ {
+ // Since we hold only one statement per inst, we can perform disjunction
+ // on a per-inst basis.
+ // If an inst does not exist in the current set, then it's an empty statement.
+ //
+ if (statements.containsKey(statement.first))
+ {
+ auto newRelation = relationUnion(statement.second, statements[statement.first]);
+ set(statement.first, newRelation);
+ }
+ }
+
+ // Remove any insts that don't have a corresponding statement in the other set,
+ // since this effectively means "any".
+ //
+ for (auto& statement : statements)
+ {
+ if (!other.statements.containsKey(statement.first))
+ statements.remove(statement.first);
+ }
+}
+
+void StatementSet::conjunct(StatementSet other)
+{
+ // true ^ (a1 ^ a2 ^ a3 ...) = (a1 ^ a2 ^ a3 ...)
+ if (other.isTriviallyTrue())
+ return;
+
+ // (a1 ^ a2 ^ a3 ...) ^ true = (a1 ^ a2 ^ a3 ...)
+ if (isTriviallyTrue())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // false ^ (a1 ^ a2 ^ a3 ...) = false
+ if (isTriviallyFalse())
+ return;
+
+ // (a1 ^ a2 ^ a3 ...) ^ false = false
+ if (other.isTriviallyFalse())
+ {
+ statements = other.statements;
+ return;
+ }
+
+ // Otherwise do an element-wise conjunction.
+ for (auto& statement : other.statements)
+ {
+ if (statements.containsKey(statement.first))
+ {
+ set(statement.first,
+ relationIntersection(statement.second, statements[statement.first]));
+ }
+ else
+ {
+ set(statement.first, statement.second);
+ }
+ }
+}
+
+void StatementSet::conjunct(IRInst* inst, SimpleRelation relation)
+{
+ if (isTriviallyFalse())
+ return;
+
+ if (statements.containsKey(inst))
+ {
+ set(inst, relationIntersection(relation, statements[inst]));
+ }
+ else
+ {
+ set(inst, relation);
+ }
+}
+
+// This function answers the question: "Can we prove that relationB is true if relationA is true?"
+//
+// Note that this is not the same as "Does relationA imply relationB", since there can be cases
+// where this is indeed true, but we just don't have the logic to prove it.
+//
+bool doesRelationImply(SimpleRelation relationA, SimpleRelation relationB)
+{
+ // Equal relations imply each other
+ if (relationA == relationB)
+ return true;
+
+ // If B is trivially true, then A implies B
+ if (relationB.type == SimpleRelation::Any)
+ return true;
+
+ // If A is trivially true, then A implies B only if B is also trivially true
+ if (relationA.type == SimpleRelation::Any)
+ return (relationB.type == SimpleRelation::Any);
+
+ // If A is impossible, then technically what we return doesn't matter...
+ if (relationA.type == SimpleRelation::Impossible ||
+ relationB.type == SimpleRelation::Impossible)
+ return false;
+
+ // If A is a boolean relation, then A implies B if B is also a boolean relation and the values
+ // are the same.
+ //
+ if (relationA.type == SimpleRelation::BoolRelation)
+ return (relationB.type == SimpleRelation::BoolRelation) &&
+ (relationA.boolValue == relationB.boolValue);
+
+ if (relationA.type == SimpleRelation::IntegerRelation)
+ {
+ if (relationB.type != SimpleRelation::IntegerRelation)
+ return false;
+
+ // Technically, the equality case is already handled above, so we'll only consider
+ // cases where A and B are not the same relation, but where A -> B
+
+ // If A is an equality, and B is an inequality, we can test
+ if (relationA.comparator == SimpleRelation::Equal)
+ {
+ if (relationB.comparator == SimpleRelation::LessThanEqual)
+ return relationA.integerValue <= relationB.integerValue;
+ else if (relationB.comparator == SimpleRelation::GreaterThanEqual)
+ return relationA.integerValue >= relationB.integerValue;
+ }
+
+ // If A is an equality, and B is an inequality with different values, then
+ // A -> B
+ //
+ if (relationA.comparator == SimpleRelation::Equal &&
+ relationB.comparator == SimpleRelation::NotEqual)
+ {
+ return relationA.integerValue != relationB.integerValue;
+ }
+
+ if (relationA.comparator == SimpleRelation::GreaterThanEqual &&
+ relationB.comparator == SimpleRelation::GreaterThanEqual)
+ {
+ return relationA.integerValue >= relationB.integerValue;
+ }
+
+ if (relationA.comparator == SimpleRelation::LessThanEqual &&
+ relationB.comparator == SimpleRelation::LessThanEqual)
+ {
+ return relationA.integerValue <= relationB.integerValue;
+ }
+
+ // TODO: Handle other cases.. these come up rarely, so we can
+ }
+
+ return false;
+}
+
+bool isIntegerConstantValue(IRInst* inst)
+{
+ return inst->getOp() == kIROp_IntLit;
+}
+
+bool isBoolConstantValue(IRInst* inst)
+{
+ return inst->getOp() == kIROp_BoolLit;
+}
+
+IRIntegerValue getConstantIntegerValue(IRInst* inst)
+{
+ SLANG_ASSERT(isIntegerConstantValue(inst));
+ return as<IRIntLit>(inst)->getValue();
+}
+
+bool getConstantBoolValue(IRInst* inst)
+{
+ SLANG_ASSERT(isBoolConstantValue(inst));
+ return as<IRBoolLit>(inst)->getValue();
+}
+
+StatementSet tryExtractStatements(IRTerminatorInst* inst, IRBlock* block)
+{
+ StatementSet statements;
+
+ // From condInst, extract a statement about any inst such that we have an equality
+ // statement (integer or boolean) on the inst.
+ //
+ if (auto ifElse = as<IRIfElse>(inst))
+ {
+ // Check that the block is the true or false block of the if-else
+ bool isTrueBlock = ifElse->getTrueBlock() == block;
+ bool isFalseBlock = ifElse->getFalseBlock() == block;
+ if (!isTrueBlock && !isFalseBlock)
+ goto done;
+
+ auto condInst = inst->getOperand(0);
+ statements.conjunct(condInst, SimpleRelation::boolRelation(isTrueBlock));
+
+ if (condInst->getOp() == kIROp_Eql)
+ {
+ auto leftOperand = condInst->getOperand(0);
+ auto rightOperand = condInst->getOperand(1);
+
+ if (isIntegerConstantValue(leftOperand))
+ {
+ statements.conjunct(
+ rightOperand,
+ SimpleRelation::integerRelation(
+ (isTrueBlock ? SimpleRelation::Equal : SimpleRelation::NotEqual),
+ getConstantIntegerValue(leftOperand)));
+ }
+ else if (isIntegerConstantValue(rightOperand))
+ {
+ statements.conjunct(
+ leftOperand,
+ SimpleRelation::integerRelation(
+ (isTrueBlock ? SimpleRelation::Equal : SimpleRelation::NotEqual),
+ getConstantIntegerValue(rightOperand)));
+ }
+ }
+ else if (isCompareCmpInst(condInst))
+ {
+ auto leftOperand = condInst->getOperand(0);
+ auto rightOperand = condInst->getOperand(1);
+
+ bool isParamLeft = !isIntegerConstantValue(leftOperand);
+ bool isParamRight = !isIntegerConstantValue(rightOperand);
+
+ // If neither operand is an inst, we can't say anything.
+ if (!isParamLeft && !isParamRight)
+ goto done;
+
+ auto paramOperand = isParamLeft ? leftOperand : rightOperand;
+ auto otherOperand = isParamLeft ? rightOperand : leftOperand;
+
+ // Check if the "other" operand is a constant
+ if (!isIntegerConstantValue(otherOperand))
+ goto done;
+
+ auto constantVal = getConstantIntegerValue(otherOperand);
+
+ SimpleRelation::Comparator comparator;
+ switch (condInst->getOp())
+ {
+ case kIROp_Less:
+ comparator = SimpleRelation::LessThanEqual;
+ constantVal = constantVal - 1;
+ break;
+ case kIROp_Greater:
+ comparator = SimpleRelation::GreaterThanEqual;
+ constantVal = constantVal + 1;
+ break;
+ case kIROp_Leq:
+ comparator = SimpleRelation::LessThanEqual;
+ break;
+ case kIROp_Geq:
+ comparator = SimpleRelation::GreaterThanEqual;
+ break;
+ case kIROp_Eql:
+ comparator = SimpleRelation::Equal;
+ break;
+ case kIROp_Neq:
+ comparator = SimpleRelation::NotEqual;
+ break;
+ default:
+ SLANG_UNREACHABLE("unexpected op code");
+ }
+ auto relation = SimpleRelation::integerRelation(comparator, constantVal);
+ statements.conjunct(
+ paramOperand,
+ ((isParamLeft ^ !isTrueBlock) ? relation : relation.negated()));
+ }
+ }
+ else if (auto switchInst = as<IRSwitch>(inst))
+ {
+ // Check that the block is the default case of the switch
+ if (switchInst->getDefaultLabel() == block)
+ goto done;
+
+ // Check each case block
+ UInt caseCount = switchInst->getCaseCount();
+ for (UInt i = 0; i < caseCount; i++)
+ {
+ auto caseValue = switchInst->getCaseValue(i);
+ auto caseBlock = switchInst->getCaseLabel(i);
+
+ if (caseBlock == block && isIntegerConstantValue(caseValue))
+ {
+ auto constantVal = getConstantIntegerValue(caseValue);
+ statements.conjunct(
+ switchInst->getCondition(),
+ SimpleRelation::integerRelation(SimpleRelation::Equal, constantVal));
+ }
+ }
+ }
+
+done:
+ return statements;
+}
+
+enum class BlockStateFlags
+{
+ UpwardPropCompleted = 1 << 0,
+ DownwardPropCompleted = 1 << 1
+};
+
+void markUpwardPropCompleted(IRBlock* block)
+{
+ block->scratchData |= (UInt64)BlockStateFlags::UpwardPropCompleted;
+}
+
+void markDownwardPropCompleted(IRBlock* block)
+{
+ block->scratchData |= (UInt64)BlockStateFlags::DownwardPropCompleted;
+}
+
+bool isUpwardPropCompleted(IRBlock* block)
+{
+ return block->scratchData & (UInt64)BlockStateFlags::UpwardPropCompleted;
+}
+
+bool isDownwardPropCompleted(IRBlock* block)
+{
+ return block->scratchData & (UInt64)BlockStateFlags::DownwardPropCompleted;
+}
+
+void clearBlockState(IRBlock* block)
+{
+ block->scratchData = 0;
+}
+
+bool isLoopConditionBlock(IRBlock* block)
+{
+ for (auto use = block->firstUse; use; use = use->nextUse)
+ {
+ if (auto loop = as<IRLoop>(use->getUser()))
+ {
+ if (loop->getTargetBlock() == block)
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool isBlockReadyForUpwardProp(IRBlock* block)
+{
+ if (isLoopConditionBlock(block))
+ {
+ auto falseBlock = cast<IRIfElse>(block->getTerminator())->getFalseBlock();
+ return isUpwardPropCompleted(falseBlock);
+ }
+
+ // Check that successors have completed upward propagation.
+ for (auto successor : block->getSuccessors())
+ {
+ if (!isUpwardPropCompleted(successor))
+ return false;
+ }
+ return true;
+}
+
+bool isBlockReadyForDownwardProp(IRBlock* block)
+{
+ // Check that predecessors have completed downward propagation.
+ for (auto predecessor : block->getPredecessors())
+ {
+ if (!isDownwardPropCompleted(predecessor))
+ return false;
+ }
+ return true;
+}
+
+StatementSet propagateStatementUpwards(IRInst* inst, SimpleRelation relation)
+{
+ // Lambda to make a single-statement set.
+ auto makeStatementSet = [&](IRInst* inst, SimpleRelation relation)
+ {
+ StatementSet set;
+ set.conjunct(inst, relation);
+ return set;
+ };
+
+ if (as<IRParam>(inst))
+ return makeStatementSet(inst, relation);
+
+ if (isIntegerConstantValue(inst))
+ {
+ auto relationFromInst =
+ SimpleRelation::integerRelation(SimpleRelation::Equal, getConstantIntegerValue(inst));
+ if (doesRelationImply(relation, relationFromInst))
+ return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true
+ else if (doesRelationImply(relation, relationFromInst.negated()))
+ return makeStatementSet(inst, SimpleRelation::impossibleRelation());
+ else
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+ }
+ else if (isBoolConstantValue(inst))
+ {
+ auto relationFromInst = SimpleRelation::boolRelation(getConstantBoolValue(inst));
+ if (doesRelationImply(relation, relationFromInst))
+ return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true
+ else if (doesRelationImply(relation, relationFromInst.negated()))
+ return makeStatementSet(inst, SimpleRelation::impossibleRelation());
+ else
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+ }
+ else if (inst->getOp() == kIROp_Add || inst->getOp() == kIROp_Sub)
+ {
+ // TODO: Translate equality/inequality.
+ }
+
+ return makeStatementSet(inst, SimpleRelation::anyRelation());
+}
+
+StatementSet propagateUpwards(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* current,
+ IRBlock* predecessor,
+ StatementSet predicateSet)
+{
+ // Translate the set of predicates from the current block to the predecessor block.
+ //
+ // The key idea is that we need to find a set of predicate statements (A') for the predecessor
+ // block, such that A => A'.
+ //
+ // During the downward phase, the predecessor will then return a set of
+ // statements (B') such that A' => B'. This B' can be propagated "downwards" into a set
+ // of statements B such that B' => B.
+ //
+ // We can then combine these three rules A => A', A' => B' and B' => B to get A => B
+ // which is the statement set that we want for our current block.
+ //
+
+ StatementSet newPredicateSet;
+ for (auto& statementInstPair : predicateSet.statements)
+ {
+ auto predicateRelation = statementInstPair.second;
+ auto predicateInst = statementInstPair.first;
+ if (as<IRParam>(predicateInst) && predicateInst->getParent() == current)
+ {
+ auto paramIndex = getParamIndexInBlock(cast<IRParam>(predicateInst));
+ auto translatedInst =
+ as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex);
+
+ // If the translate inst is outside the block, add it in as-is, otherwise,
+ // we'll need to propagate it to the operands of the inst
+ //
+ auto statementSet = propagateStatementUpwards(translatedInst, predicateRelation);
+ newPredicateSet.conjunct(statementSet);
+ }
+ else
+ {
+ newPredicateSet.conjunct(predicateInst, predicateRelation);
+ }
+ }
+
+ // If our current block is a merge block for a conditional branch, we should add the condition
+ // to the predicate set.
+ //
+ for (auto blockUse = current->firstUse; blockUse; blockUse = blockUse->nextUse)
+ {
+ if (auto ifElse = as<IRIfElse>(blockUse->getUser()))
+ {
+ if (ifElse->getAfterBlock() == current)
+ {
+ // We're looking at the merge block for a conditional branch.
+
+ if (domTree->dominates(ifElse->getTrueBlock(), predecessor))
+ {
+ // True branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock()));
+ }
+ else if (domTree->dominates(ifElse->getFalseBlock(), predecessor))
+ {
+ // False branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getFalseBlock()));
+ }
+ else
+ {
+ // Panic
+ SLANG_UNREACHABLE("Unreachable block in conditional branch");
+ }
+ }
+ }
+
+ // We'll ignore switch statements for now, but they're trivial to add.
+ // TODO: Add switch statements.
+ }
+
+ // We have one more edge-case. The condition block of a loop inst.
+ if (auto ifElse = as<IRIfElse>(current->getTerminator()))
+ {
+ if (domTree->dominates(ifElse->getTrueBlock(), predecessor) &&
+ !domTree->dominates(ifElse->getFalseBlock(), predecessor))
+ {
+ // True branch
+ newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock()));
+ }
+ }
+ return newPredicateSet;
+}
+
+StatementSet propagateStatementDownwards(
+ IRInst* srcInst,
+ IRInst* dstInst,
+ StatementSet srcStatements)
+{
+ // We'll keep translating through the inst, until we either hit a parameter
+ // until we either hit a parameter, or we leave the current block.
+ //
+
+ // Lambda to make a single-statement set.
+ auto singleStatement = [&](IRInst* inst, SimpleRelation relation)
+ {
+ StatementSet set;
+ set.conjunct(inst, relation);
+ return set;
+ };
+
+ if (srcStatements.statements.containsKey(srcInst))
+ return singleStatement(dstInst, srcStatements.statements[srcInst]);
+
+ if (isIntegerConstantValue(srcInst))
+ {
+ return singleStatement(
+ dstInst,
+ SimpleRelation::integerRelation(
+ SimpleRelation::Equal,
+ getConstantIntegerValue(srcInst)));
+ }
+ else if (isBoolConstantValue(srcInst))
+ {
+ return singleStatement(
+ dstInst,
+ SimpleRelation::boolRelation(getConstantBoolValue(srcInst)));
+ }
+
+ if (srcInst->getOp() == kIROp_Add || srcInst->getOp() == kIROp_Sub)
+ {
+ auto left = srcInst->getOperand(0);
+ auto right = srcInst->getOperand(1);
+
+ auto isLeftConstant = isIntegerConstantValue(left);
+ auto isRightConstant = isIntegerConstantValue(right);
+
+ if (!isLeftConstant && !isRightConstant)
+ return singleStatement(dstInst,
+ SimpleRelation::anyRelation()); // Can't say anything
+
+ if (srcInst->getOp() == kIROp_Add || (srcInst->getOp() == kIROp_Sub && isRightConstant))
+ {
+ auto constant =
+ isLeftConstant ? getConstantIntegerValue(left) : getConstantIntegerValue(right);
+ auto operand = isLeftConstant ? right : left;
+
+ constant = srcInst->getOp() == kIROp_Add ? constant : -constant;
+
+ auto operandStatement = propagateStatementDownwards(operand, operand, srcStatements);
+ auto relation = operandStatement.statements.containsKey(operand)
+ ? operandStatement.statements[operand]
+ : SimpleRelation::anyRelation();
+
+ if (relation.type == SimpleRelation::IntegerRelation)
+ {
+ switch (relation.comparator)
+ {
+ case SimpleRelation::Equal:
+ case SimpleRelation::NotEqual:
+ case SimpleRelation::LessThanEqual:
+ case SimpleRelation::GreaterThanEqual:
+ return singleStatement(
+ dstInst,
+ SimpleRelation::integerRelation(
+ relation.comparator,
+ constant + relation.integerValue));
+ }
+ }
+ }
+ }
+
+ // Default
+ return singleStatement(dstInst, SimpleRelation::anyRelation());
+}
+
+StatementSet propagateDownwards(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* successor,
+ IRBlock* predecessor,
+ StatementSet statementSet)
+{
+ // Translate a set of statements from the current block to the successor block.
+ //
+ // That is, find a set of statements (B') for the successor block such that B => B'
+ //
+ StatementSet newStatementSet;
+
+ if (statementSet.isTriviallyFalse())
+ {
+ return statementSet;
+ }
+
+ // Go over all the parameters of the successor block, find corresponding arguments, and
+ // convert any statements to the new set.
+ //
+ UInt paramIndex = 0;
+ for (auto param : successor->getParams())
+ {
+ auto arg = as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex);
+ auto statement = propagateStatementDownwards(arg, param, statementSet);
+ newStatementSet.conjunct(statement);
+ paramIndex++;
+ }
+
+ newStatementSet.conjunct(tryExtractStatements(predecessor->getTerminator(), successor));
+
+ // For all other statements in the statementSet, we'll add them in, but only
+ // if the predecessor dominates the successor.
+ // An exception is parameters defined in the successor (since these are getting
+ // redefined, we should not be considering existing statements)
+ //
+ for (auto& statement : statementSet.statements)
+ {
+ if (domTree->dominates(statement.first->getParent(), successor) &&
+ !(as<IRParam>(statement.first) && statement.first->getParent() == successor))
+ newStatementSet.conjunct(statement.first, statement.second);
+ }
+
+ return newStatementSet;
+}
+
+struct Edge
+{
+ IRBlock* predecessor;
+ IRBlock* successor;
+
+ bool operator==(const Edge& other) const
+ {
+ return predecessor == other.predecessor && successor == other.successor;
+ }
+
+ UInt64 getHashCode() const
+ {
+ UInt64 predHash = Slang::getHashCode(predecessor);
+ UInt64 succHash = Slang::getHashCode(successor);
+ return Slang::combineHash(predHash, succHash);
+ }
+};
+
+
+// This routine returns a set of implications for any insts visible in a block.
+//
+// The process uses a modified version of abstract interpretation, by first propagating a set
+// of predicates "backwards" repeatedly through the predecessors, then calculating the set of
+// implications "forwards" repeatedly through the successors.
+//
+// Note that the resulting implications don't contain all possible statements that could be inferred
+// statically (this is an undeciable problem), but rather whatever can be inferred in just two steps
+// through the blocks. This suffices for the vast majority of common loop structures.
+//
+StatementSet collectImplications(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* block,
+ StatementSet Predicates)
+{
+ List<Edge> orderedEdgeList; // Edges in the order that they're processed.
+ HashSet<Edge> falseEdges; // Edges between blocks where the successor's predicate does not imply
+ // the predecessor's predicate.
+
+ // Initialize a work list.
+ List<IRBlock*> workList;
+ workList.add(block);
+
+ // Clear scratch bits.
+ IRFunc* func = cast<IRFunc>(domTree->code);
+ for (auto _block : func->getBlocks())
+ {
+ clearBlockState(_block);
+ }
+
+ //
+ // Upward pass: Propagate predicates through predecessors until
+ // there're no more blocks left to process.
+ //
+
+ // We'll keep track of the predicates for each block.
+ Dictionary<IRBlock*, StatementSet> blockPredicates;
+
+ blockPredicates[block] = Predicates;
+
+ while (workList.getCount() > 0)
+ {
+ auto current = workList.getLast();
+ workList.removeLast();
+
+ // If the block has already been processed, skip it.
+ if (isUpwardPropCompleted(current))
+ continue;
+
+ // If the block is not ready for upward propagation, add it to the work list.
+ if (current != block && !isBlockReadyForUpwardProp(current))
+ {
+ workList.add(current);
+ // Then add all the successors to the work list.
+ for (auto successor : current->getSuccessors())
+ workList.add(successor);
+
+ continue;
+ }
+
+ // Otherwise, we'll process the block.
+ //
+ // Get our predicate set, then propagate it to all predecessors.
+ //
+ auto predicates = blockPredicates[current];
+
+ HashSet<IRBlock*> uniquePredecessors;
+ for (auto predecessor : current->getPredecessors())
+ uniquePredecessors.add(predecessor);
+
+ for (auto predecessor : uniquePredecessors)
+ {
+ // We also need to handle the recursive case, where the predecessor
+ // is already "sealed".
+ //
+ if (isUpwardPropCompleted(predecessor))
+ {
+ orderedEdgeList.add({predecessor, current});
+
+ // Verify that "current predicate" => "predecessor predicate".
+
+ // TODO: Implement later.
+ // For now, we can default to assuming that this edge is not
+ // valid. This works fine since we're not trying to prove anything recursive (like
+ // inductivity), but we should revisit this if we do want to unify the induction
+ // value inference pass with this loop analysis system.
+ //
+
+ // We'll add this to the set of false edges so that the downward prop pass
+ // doesn't propagate any implications through this edge.
+ //
+ falseEdges.add({predecessor, current});
+ continue;
+ }
+
+ auto newPredicates = propagateUpwards(domTree, current, predecessor, predicates);
+
+ if (!blockPredicates.containsKey(predecessor))
+ blockPredicates[predecessor] = newPredicates;
+ else
+ blockPredicates[predecessor].disjunct(newPredicates);
+
+ orderedEdgeList.add({predecessor, current});
+
+ // Add predecessors to work list.
+ workList.add(predecessor);
+ }
+
+ markUpwardPropCompleted(current);
+ }
+
+ //
+ // Downward pass: Propagate implications through successors until
+ // there're no more blocks left to process.
+ //
+
+ Dictionary<IRBlock*, StatementSet> blockImplications;
+
+ // Set 'block' to something trivial base case.
+ // blockImplications[block] = blockPredicates[block]; // statement => statement
+
+ while (orderedEdgeList.getCount() > 0)
+ {
+ auto edge = orderedEdgeList.getLast();
+ orderedEdgeList.removeLast();
+
+ // Get the predicate set for the predecessor.
+ auto predecessorPredicates = blockPredicates[edge.predecessor];
+
+ // Get the implication set for the predecessor.
+ auto predecessorImplications = StatementSet();
+
+ if (falseEdges.contains(edge))
+ {
+ // Since A' => B' is not true, effectively, we can't say anything..
+ predecessorImplications = StatementSet();
+ }
+ else
+ {
+ // (A' => B') => (A' => A' ^ B')
+ predecessorImplications = blockImplications[edge.predecessor];
+ predecessorImplications.conjunct(predecessorPredicates);
+ }
+
+ // Propagate the implication set to the successor.
+ auto successorImplications =
+ propagateDownwards(domTree, edge.successor, edge.predecessor, predecessorImplications);
+
+ if (!blockImplications.containsKey(edge.successor))
+ blockImplications[edge.successor] = successorImplications;
+ else
+ blockImplications[edge.successor].disjunct(successorImplications);
+ }
+
+ // Clear scratch bits.
+ for (auto _block : func->getBlocks())
+ {
+ clearBlockState(_block);
+ }
+
+ // We should have a final set of implications for our block.
+ return blockImplications[block];
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-loop-analysis.h b/source/slang/slang-ir-autodiff-loop-analysis.h
new file mode 100644
index 000000000..f13ccd7b9
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-loop-analysis.h
@@ -0,0 +1,181 @@
+// slang-ir-autodiff-loop-analysis.h
+#pragma once
+
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct SimpleRelation
+{
+ enum Type
+ {
+ Any, // Target can be anything (all values are possible)
+ IntegerRelation, // Target satisfies a simple integer equality/inequality
+ BoolRelation, // Target satisfies boolean equality
+ Impossible // Target is impossible (has no possible values)
+ } type;
+
+ enum Comparator
+ {
+ LessThanEqual,
+ GreaterThanEqual,
+ Equal,
+ NotEqual
+ } comparator;
+ IRIntegerValue integerValue;
+ bool boolValue;
+
+ static SimpleRelation integerRelation(Comparator comparator, IRIntegerValue integerValue)
+ {
+ return SimpleRelation{IntegerRelation, comparator, integerValue, false};
+ }
+
+ static SimpleRelation boolRelation(bool boolValue)
+ {
+ return SimpleRelation{BoolRelation, Equal, 0, boolValue};
+ }
+
+ static SimpleRelation impossibleRelation()
+ {
+ return SimpleRelation{Impossible, Equal, 0, false};
+ }
+
+ static SimpleRelation anyRelation() { return SimpleRelation{Any, Equal, 0, false}; }
+
+ bool operator==(const SimpleRelation& other) const
+ {
+ switch (type)
+ {
+ case Any:
+ return other.type == Any;
+ case IntegerRelation:
+ return other.type == IntegerRelation && comparator == other.comparator &&
+ integerValue == other.integerValue;
+ case BoolRelation:
+ return other.type == BoolRelation && boolValue == other.boolValue;
+ case Impossible:
+ return other.type == Impossible;
+ default:
+ SLANG_UNREACHABLE("Unhandled relation type");
+ }
+ }
+
+ bool operator!=(const SimpleRelation& other) const { return !(*this == other); }
+
+ SimpleRelation negated() const
+ {
+ switch (type)
+ {
+ case Any:
+ return SimpleRelation{Impossible, Equal, 0, false};
+ case Impossible:
+ return SimpleRelation{Any, Equal, 0, false};
+ case BoolRelation:
+ return SimpleRelation{BoolRelation, Equal, 0, !boolValue};
+ case IntegerRelation:
+ switch (comparator)
+ {
+ case LessThanEqual:
+ return SimpleRelation{IntegerRelation, GreaterThanEqual, integerValue + 1, false};
+ case GreaterThanEqual:
+ return SimpleRelation{IntegerRelation, LessThanEqual, integerValue - 1, false};
+ case Equal:
+ return SimpleRelation{IntegerRelation, NotEqual, integerValue, false};
+ case NotEqual:
+ return SimpleRelation{IntegerRelation, Equal, integerValue, false};
+ default:
+ SLANG_UNREACHABLE("Unhandled comparator");
+ }
+ default:
+ SLANG_UNREACHABLE("Unhandled relation type");
+ }
+ }
+
+ HashCode64 getHashCode() const
+ {
+ HashCode64 code = Slang::getHashCode(int(type));
+ switch (type)
+ {
+ case IntegerRelation:
+ code = combineHash(code, Slang::getHashCode(comparator));
+ code = combineHash(code, Slang::getHashCode(integerValue));
+ break;
+ case BoolRelation:
+ code = combineHash(code, Slang::getHashCode(boolValue));
+ break;
+ case Impossible:
+ case Any:
+ break;
+ default:
+ SLANG_UNREACHABLE("Unhandled relation type");
+ }
+ return code;
+ }
+};
+
+struct StatementSet
+{
+ // A conjunction of independent statements (a1 ^ a2 ^ a3 ...)
+ // - One simple relation per inst.
+ // - The absence of an entry implies that the inst is unconstrained.
+ // - The presence of any "Impossible" relation indicates that the entire conjunction is always
+ // false.
+ //
+ Dictionary<IRInst*, SimpleRelation> statements;
+
+ // Disjunction of a conjunction of statements (a1 ^ a2 ^ a3 ...) with the current conjunction.
+ void disjunct(StatementSet other);
+
+ // Conjunction of a conjunction of statements (a1 ^ a2 ^ a3 ...) with the current conjunction.
+ void conjunct(StatementSet other);
+
+ // Conjunction of a single statement with the current conjunction.
+ void conjunct(IRInst* inst, SimpleRelation relation);
+
+ void set(IRInst* inst, SimpleRelation relation)
+ {
+ if (relation.type == SimpleRelation::Any)
+ {
+ if (statements.containsKey(inst))
+ statements.remove(inst);
+ return;
+ }
+
+ statements[inst] = relation;
+ }
+
+ bool isTriviallyFalse()
+ {
+ for (auto& statement : statements)
+ {
+ if (statement.second.type == SimpleRelation::Impossible)
+ return true;
+ }
+ return false;
+ }
+
+ bool isTriviallyTrue() { return statements.getCount() == 0; }
+};
+
+
+// Utility functions.
+bool isIntegerConstantValue(IRInst* inst);
+bool isBoolConstantValue(IRInst* inst);
+IRIntegerValue getConstantIntegerValue(IRInst* inst);
+bool getConstantBoolValue(IRInst* inst);
+
+bool doesRelationImply(SimpleRelation relationA, SimpleRelation relationB);
+
+// Try to collect a set of implications for any insts visible in a block,
+// subject to the set of predicates.
+//
+StatementSet collectImplications(
+ RefPtr<IRDominatorTree> domTree,
+ IRBlock* block,
+ StatementSet Predicates);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index c2403e53b..231221156 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -2,6 +2,7 @@
#include "../core/slang-func-ptr.h"
#include "slang-ast-support-types.h"
+#include "slang-ir-autodiff-loop-analysis.h"
#include "slang-ir-autodiff-region.h"
#include "slang-ir-insts.h"
#include "slang-ir-simplify-cfg.h"
@@ -295,9 +296,10 @@ bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInf
if (indicesA.getCount() > indicesB.getCount())
return false;
+ auto offset = (indicesB.getCount() - indicesA.getCount());
for (Index ii = 0; ii < indicesA.getCount(); ii++)
{
- if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
+ if (indicesA[ii].primalCountParam != indicesB[ii + offset].primalCountParam)
return false;
}
@@ -402,7 +404,7 @@ void splitLoopConditionBlockInsts(
if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0)
{
setInsertAfterOrdinaryInst(&builder, inst);
- auto copy = builder.emitCheckpointObject(inst);
+ auto copy = builder.emitLoopExitValue(inst);
// Copy source location so that checkpoint reporting is accurate
copy->sourceLoc = inst->sourceLoc;
@@ -425,6 +427,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
{
collectInductionValues(func);
+ collectLoopExitConditions(func);
+
RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
@@ -581,6 +585,19 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
workList.add(&branchInst->getArgs()[paramIndex]);
}
}
+ else if (auto exitValue = as<IRLoopExitValue>(result.instToRecompute))
+ {
+ // If we also have an exit value (a stronger condition on the param), record
+ // it.
+ //
+ if (auto loopExitValueInst =
+ loopExitValueInsts.tryGetValue(exitValue->getVal()))
+ {
+ checkpointInfo->loopExitValueInsts.addIfNotExists(
+ exitValue->getVal(),
+ *loopExitValueInst);
+ }
+ }
else
{
if (auto var = as<IRVar>(result.instToRecompute))
@@ -1046,6 +1063,248 @@ void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode*
}
}
+static bool isValueInRange(IRIntegerValue value, IRType* type)
+{
+ IRInst* innerType = unwrapAttributedType(type);
+ IRIntegerValue nBits;
+ bool isSigned;
+
+ switch (innerType->getOp())
+ {
+ case kIROp_IntType:
+ case kIROp_UIntType:
+ nBits = 32;
+ break;
+ case kIROp_Int16Type:
+ case kIROp_UInt16Type:
+ nBits = 16;
+ break;
+ case kIROp_Int8Type:
+ case kIROp_UInt8Type:
+ nBits = 8;
+ break;
+ case kIROp_Int64Type:
+ case kIROp_UInt64Type:
+ nBits = 64;
+ break;
+ default:
+ return false;
+ }
+
+ switch (innerType->getOp())
+ {
+ case kIROp_IntType:
+ case kIROp_Int16Type:
+ case kIROp_Int8Type:
+ case kIROp_Int64Type:
+ isSigned = true;
+ break;
+ case kIROp_UIntType:
+ case kIROp_UInt16Type:
+ case kIROp_UInt8Type:
+ case kIROp_UInt64Type:
+ isSigned = false;
+ break;
+ default:
+ return false;
+ }
+
+ if (nBits >= 64)
+ {
+ // IRIntegerValue is 64-bit, so we assume we can always represent the value.
+ // TODO: Corner cases like loops that _rely_ on 64-bit integer overflow might not be handled
+ // correctly.
+ //
+ return true;
+ }
+
+ if (isSigned)
+ {
+ IRIntegerValue maxValue = (1ULL << (nBits - 1)) - 1;
+ return value >= -maxValue && value <= maxValue;
+ }
+ else
+ {
+ IRIntegerValue maxValue = (1ULL << nBits) - 1;
+ return value >= 0 && value <= maxValue;
+ }
+}
+
+void AutodiffCheckpointPolicyBase::collectLoopExitConditions(IRGlobalValueWithCode* func)
+{
+ // Assume that the InductionValueInfo is already collected.
+ IRBuilder builder(func->getModule());
+ RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+ for (auto block : func->getBlocks())
+ {
+ auto loopInst = as<IRLoop>(block->getTerminator());
+ if (!loopInst)
+ continue;
+ auto targetBlock = loopInst->getTargetBlock();
+ auto ifElse = as<IRIfElse>(targetBlock->getTerminator());
+ if (!ifElse)
+ continue;
+
+ auto condParam = as<IRParam>(ifElse->getCondition());
+ if (!condParam || condParam->getParent() != targetBlock)
+ continue;
+
+ // Locate the loop counter.
+ IRInst* loopCounter = nullptr;
+ for (auto param : targetBlock->getParams())
+ {
+ if (param->findDecoration<IRLoopCounterDecoration>())
+ {
+ loopCounter = param;
+ break;
+ }
+ }
+
+ if (!loopCounter)
+ continue;
+
+ // Go over all loop phi parameters for which we have induction value info,
+ // and try to determine a relation on the exit value.
+ //
+ for (auto param : targetBlock->getParams())
+ {
+ auto inductionValueInfo = inductionValueInsts.tryGetValue(param);
+ if (!inductionValueInfo ||
+ inductionValueInfo->kind != LoopInductionValueInfo::AffineFunctionOfCounter)
+ continue;
+
+ // We need to have a known constant offset to be able to compute the loop exit value.
+ if (!isIntegerConstantValue(inductionValueInfo->counterOffset))
+ continue;
+
+ StatementSet conditionIsFalse;
+ conditionIsFalse.conjunct(condParam, SimpleRelation::boolRelation(false));
+
+ // Collect a statement that holds when the loop condition is false.
+ const auto implicationsForFalseCondition =
+ collectImplications(domTree, targetBlock, conditionIsFalse);
+
+ if (!implicationsForFalseCondition.statements.containsKey(param))
+ {
+ // The statement we collected says nothing about the parameter. No point continuing.
+ continue;
+ }
+
+ // Collect statements for the inverse.. i.e. some relation that holds if the condition
+ // is true.
+ StatementSet conditionIsTrue;
+ conditionIsTrue.conjunct(condParam, SimpleRelation::boolRelation(true));
+ const auto implicationsForTrueCondition =
+ collectImplications(domTree, targetBlock, conditionIsTrue);
+
+ if (!implicationsForTrueCondition.statements.containsKey(param))
+ {
+ // The statement we collected says nothing about the parameter. No point continuing.
+ continue;
+ }
+
+ // Extract A s.t. ~breakFlag => A.
+ //
+ // (Note that breakFlag == false is the case where the
+ // loop exits)
+ //
+ SimpleRelation statement = implicationsForFalseCondition.statements.getValue(param);
+
+ // Extract B s.t. breakFlag => B
+ SimpleRelation inverseStatement =
+ implicationsForTrueCondition.statements.getValue(param);
+
+ // If A => ~B, then by using the contrapositive, we get A <=> ~breakFlag
+ if (!doesRelationImply(statement, inverseStatement.negated()))
+ {
+ // If the above doesn't work, we can try using ~B instead.
+ if (!doesRelationImply(inverseStatement.negated(), statement))
+ continue; // Neither works.. we can't infer anything about param.
+ else
+ statement = inverseStatement.negated(); // Use ~B <=> ~breakFlag
+ }
+
+ // We found a relation on the parameter at the loop exit, and we also proved that
+ // if the relation holds, the loop must exit.
+ //
+ // If we have an inequality + information that a value is an inductive (i.e. follows a
+ // sequence of the form `start + i * step`), then we can use that to compute the exact
+ // value at the loop exit.
+ //
+ // We can do this by solving the inequality for the parameter, using the inductive value
+ // as the counter variable.
+ //
+ if (inductionValueInfo->kind == LoopInductionValueInfo::Kind::AffineFunctionOfCounter)
+ {
+ auto counterOffset = getConstantIntegerValue(inductionValueInfo->counterOffset);
+ auto counterFactor = inductionValueInfo->counterFactor;
+
+ SLANG_ASSERT(statement.type == SimpleRelation::Type::IntegerRelation);
+ auto relationValue = statement.integerValue;
+
+ auto recordExitValue = [&](IRIntegerValue exitIValue, IRIntegerValue exitParamValue)
+ {
+ // TODO: Maybe we should warn if the inferred exit value is out of range?
+ if (isValueInRange(exitParamValue, param->getDataType()))
+ {
+ this->loopExitValueInsts[param] =
+ builder.getIntValue(param->getDataType(), exitParamValue);
+ }
+
+ // The interesting part is that since we know that this variable is an bijective
+ // function of the loop counter, we can also compute the loop counter's exit
+ // value.
+ //
+ // Since this can come from multiple parameters, we'll verify to make sure that
+ // there are no contradictions.
+ //
+ IRInst* loopCounterExitValue;
+ if (this->loopExitValueInsts.tryGetValue(loopCounter, loopCounterExitValue))
+ {
+ auto loopCounterExitIValue = getConstantIntegerValue(loopCounterExitValue);
+ if (loopCounterExitIValue != exitIValue)
+ {
+ SLANG_ASSERT(!"contradictory loop exit values");
+ }
+ }
+ else
+ {
+ // TODO: Maybe we should warn if the inferred exit value is out of range?
+ if (isValueInRange(exitIValue, loopCounter->getDataType()))
+ {
+ this->loopExitValueInsts[loopCounter] =
+ builder.getIntValue(loopCounter->getDataType(), exitIValue);
+ }
+ }
+ };
+
+ if (counterFactor > 0 && statement.comparator == SimpleRelation::GreaterThanEqual)
+ {
+ // Find the smallest value that satisfies counterFactor * i + counterOffset >=
+ // relationValue
+ //
+ IRIntegerValue exitIValue =
+ (((relationValue - counterOffset) + counterFactor - 1) / counterFactor);
+ IRIntegerValue exitParamValue =
+ counterOffset + counterFactor * (exitIValue - 1);
+ recordExitValue(exitIValue, exitParamValue);
+ }
+ else if (counterFactor < 0 && statement.comparator == SimpleRelation::LessThanEqual)
+ {
+ // Find the largest value that satisfies counterFactor * i + counterOffset <=
+ // relationValue
+ //
+ IRIntegerValue exitIValue =
+ ((relationValue - counterOffset) + (counterFactor + 1)) / counterFactor;
+ IRIntegerValue exitParamValue = counterOffset + counterFactor * exitIValue;
+ recordExitValue(exitIValue, exitParamValue);
+ }
+ // TODO: handle other cases
+ }
+ }
+ }
+}
+
void applyToInst(
IRBuilder* builder,
CheckpointSetInfo* checkpointInfo,
@@ -1065,6 +1324,22 @@ void applyToInst(
bool isInstRecomputed = checkpointInfo->recomputeSet.contains(inst);
if (isInstRecomputed)
{
+ if (auto loopExitValueInst = as<IRLoopExitValue>(inst))
+ {
+ if (auto loopExitValue =
+ checkpointInfo->loopExitValueInsts.tryGetValue(loopExitValueInst->getVal()))
+ {
+ cloneCtx->cloneEnv.mapOldValToNew[inst] = *loopExitValue;
+ cloneCtx->registerClonedInst(builder, inst, *loopExitValue);
+ return;
+ }
+
+ // Should never happen. (Can't mark a LoopExitValue inst as recomputed without having an
+ // entry in loopExitValueInsts dict)
+ //
+ SLANG_ASSERT(!"no loop exit value found for inst");
+ }
+
if (as<IRParam>(inst))
{
// Can completely ignore first block parameters
@@ -2239,6 +2514,13 @@ void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func)
inst->removeAndDeallocate();
}
+ if (auto loopExitValueInst = as<IRLoopExitValue>(inst))
+ {
+ auto originalVal = loopExitValueInst->getVal();
+ loopExitValueInst->replaceUsesWith(originalVal);
+ loopExitValueInst->removeAndDeallocate();
+ }
+
inst = nextInst;
}
}
@@ -2268,9 +2550,17 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
sortBlocksInFunc(func);
+ // Dump IR.
+ /*IRDumpOptions options;
+ options.flags = IRDumpOptions::Flag::DumpDebugIds;
+ options.mode = IRDumpOptions::Mode::Detailed;
+ DiagnosticSinkWriter writer(sink);
+ writer.write("### BEFORE-PROCESS-FUNC\n", strlen("### BEFORE-PROCESS-FUNC\n"));
+ dumpIR(func, options, sink->getSourceManager(), &writer);*/
+
// Determine the strategy we should use to make a primal inst available.
- // If we decide to recompute the inst, emit the recompute inst in the corresponding recompute
- // block.
+ // If we decide to recompute the inst, emit the recompute inst in the corresponding
+ // recompute block.
//
RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule());
chkPolicy->preparePolicy(func);
@@ -2421,6 +2711,7 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_MatrixReshape:
case kIROp_VectorReshape:
case kIROp_GetTupleElement:
+ case kIROp_LoopExitValue:
return false;
case kIROp_Load:
@@ -2545,6 +2836,13 @@ bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use)
}
}
}
+ else if (auto exitValue = as<IRLoopExitValue>(use.usedVal))
+ {
+ if (loopExitValueInsts.containsKey(exitValue->getVal()))
+ return true;
+ else
+ return false;
+ }
return true;
}
@@ -2578,5 +2876,4 @@ HoistResult DefaultCheckpointPolicy::classify(UseOrPseudoUse use)
}
}
}
-
}; // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index 92bc8197d..f8c31940a 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -252,6 +252,7 @@ struct CheckpointSetInfo : public RefObject
HashSet<IRInst*> invertSet;
Dictionary<IRInst*, LoopInductionValueInfo> loopInductionInfo;
Dictionary<IRInst*, InversionInfo> invInfoMap;
+ Dictionary<IRInst*, IRInst*> loopExitValueInsts;
};
struct UseOrPseudoUse
@@ -323,7 +324,9 @@ public:
protected:
IRModule* module;
Dictionary<IRInst*, LoopInductionValueInfo> inductionValueInsts;
+ Dictionary<IRInst*, IRInst*> loopExitValueInsts;
void collectInductionValues(IRGlobalValueWithCode* func);
+ void collectLoopExitConditions(IRGlobalValueWithCode* func);
};
class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 8019fdd08..bd590f08f 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -717,6 +717,7 @@ INST(BitNot, bitnot, 1, 0)
INST(Select, select, 3, 0)
INST(CheckpointObject, checkpointObj, 1, 0)
+INST(LoopExitValue, loopExitValue, 1, 0)
INST(GetStringHash, getStringHash, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 86596f316..51608ebc9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2710,6 +2710,13 @@ struct IRCheckpointObject : IRInst
IRInst* getVal() { return getOperand(0); }
};
+struct IRLoopExitValue : IRInst
+{
+ IR_LEAF_ISA(LoopExitValue);
+
+ IRInst* getVal() { return getOperand(0); }
+};
+
// Signals that this point in the code should be unreachable.
// We can/should emit a dataflow error if we can ever determine
// that a block ending in one of these can actually be
@@ -4458,6 +4465,7 @@ public:
IRInst* emitDiscard();
IRInst* emitCheckpointObject(IRInst* value);
+ IRInst* emitLoopExitValue(IRInst* value);
IRInst* emitUnreachable();
IRInst* emitMissingReturn();
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index a74ac58a4..1287d1598 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5734,6 +5734,13 @@ IRInst* IRBuilder::emitCheckpointObject(IRInst* value)
return inst;
}
+IRInst* IRBuilder::emitLoopExitValue(IRInst* value)
+{
+ auto inst = createInst<IRLoopExitValue>(this, kIROp_LoopExitValue, value->getFullType(), value);
+ addInst(inst);
+ return inst;
+}
+
IRInst* IRBuilder::emitBranch(IRBlock* pBlock)
{
auto inst = createInst<IRUnconditionalBranch>(this, kIROp_unconditionalBranch, nullptr, pBlock);
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
index 51f17b611..77bfb358c 100644
--- a/tests/autodiff/reverse-continue-loop.slang
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -9,14 +9,13 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
-//CHK-DAG: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue'
+//CHK-DAG: note: checkpointing context of 20 bytes associated with function: 'test_loop_with_continue'
[BackwardDifferentiable]
float test_loop_with_continue(float y)
{
//CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
- //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item:
for (int i = 0; i < 3; i++)
{
if (t > 4.0)
diff --git a/tests/autodiff/reverse-loop-diff-only-2.slang b/tests/autodiff/reverse-loop-diff-only-2.slang
index 2cc33ecca..cc9e14736 100644
--- a/tests/autodiff/reverse-loop-diff-only-2.slang
+++ b/tests/autodiff/reverse-loop-diff-only-2.slang
@@ -32,10 +32,17 @@ float infinitesimal(float x)
// Test that computeLoop's intermediates have no float sitting
// around (must not cache the outvar from 'compute()')
-// CHECK: struct s_bwd_prop_computeLoop_Intermediates
-// CHECK-NEXT: {
-// CHECK-NOT: {{[A-Za-z0-9_]+}} {{[A-Za-z0-9_]+}}[{{.*}}]
-// CHECK: }
+//
+// Further, if loop exit value inference is working correctly,
+// then there should be no context type at all.
+//
+// CHECK-NOT: struct s_bwd_prop_computeLoop_Intermediates
+//
+// Check that the signature of the s_bwd_prop_computeLoop function only
+// contains an inout DiffPair_float_0 and a float.
+//
+// CHECK: void s_bwd_prop_computeLoop{{[_0-9]*}}(inout DiffPair_float{{[_0-9]*}} dpy{{[_0-9]*}}, float {{[_a-zA-Z0-9]*}})
+//
[BackwardDifferentiable]
[PreferRecompute]
diff --git a/tests/autodiff/reverse-loop-exit-value-inference-1.slang b/tests/autodiff/reverse-loop-exit-value-inference-1.slang
new file mode 100644
index 000000000..5a8d8c2c3
--- /dev/null
+++ b/tests/autodiff/reverse-loop-exit-value-inference-1.slang
@@ -0,0 +1,240 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cpu -compute -output-using-type -shaderobj
+//TEST:SIMPLE(filecheck=CHK_REPORT):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+// A variety of tests to check for loop exit value inference.
+// For all of these loops, we expect our inference pass to be able to
+// infer the loop exit value correctly.
+//
+// Further, if the optimization pass runs successfully, then there should
+// be absolutely no context stored for any of these tests.
+//
+
+// CHK_REPORT: (0): note: no checkpoint contexts to report
+
+[Differentiable]
+float test_simple(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 3; i++)
+ {
+ t = t * (i + 1);
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_strided(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 5; i+=2)
+ {
+ t = t * (i + 1);
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_offset(float y)
+{
+ float t = y;
+
+ for (int i = 2; i < 5; i+=2)
+ {
+ t = t * (i + 1);
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_negative_stride(float y)
+{
+ float t = y;
+
+ for (int i = 7; i >= 1; i-=2)
+ {
+ t = t * (i + 1);
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_nested(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 3; i++)
+ {
+ for (int j = 0; j < 3; j++)
+ {
+ t = t * (i + 4 * j + 1);
+ }
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_nested_with_offset(float y)
+{
+ float t = y;
+
+ for (int i = -3; i < 3; i++)
+ {
+ for (int j = -3; j < 3; j++)
+ {
+ t = t * ((abs(i) % 2) + (abs(j) % 2) + 1);
+ }
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_nested_with_conditions(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 3; i++)
+ {
+ if (i % 2 == 0)
+ {
+ for (int j = 0; j < 3; j++)
+ {
+ if (j % 2 == 0)
+ {
+ t = t * (i + 4 * j + 1);
+ }
+ }
+ }
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_with_continue(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 5; i++)
+ {
+ if (i % 2 == 0)
+ {
+ continue;
+ }
+
+ t = t * (i + 1);
+ }
+
+ return t;
+}
+
+[Differentiable]
+float test_nested_with_continue(float y)
+{
+ float t = y;
+
+ for (int i = 0; i < 3; i++)
+ {
+ if (i % 2 == 0)
+ continue;
+
+ for (int j = 0; j < 3; j++)
+ {
+ if (j % 2 == 0)
+ continue;
+
+ if (j == 0)
+ continue;
+
+ t = t * (i + 4 * j + 1);
+ }
+ }
+
+ return t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[0] = 0.0f; // CHECK: 0.000000
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_simple)(dpa, 1.0f);
+ outputBuffer[1] = dpa.d; // CHECK-NEXT: 6.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_strided)(dpa, 1.0f);
+ outputBuffer[2] = dpa.d; // CHECK-NEXT: 15.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_offset)(dpa, 1.0f);
+ outputBuffer[3] = dpa.d; // CHECK-NEXT: 15.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_negative_stride)(dpa, 1.0f);
+ outputBuffer[4] = dpa.d; // CHECK-NEXT: 384.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_nested)(dpa, 1.0f);
+ outputBuffer[5] = dpa.d; // CHECK-NEXT: 1247400.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_nested_with_offset)(dpa, 1.0f);
+ outputBuffer[6] = dpa.d; // CHECK-NEXT: 5159780352.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_nested_with_conditions)(dpa, 1.0f);
+ outputBuffer[7] = dpa.d; // CHECK-NEXT: 297.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_with_continue)(dpa, 1.0f);
+ outputBuffer[8] = dpa.d; // CHECK-NEXT: 8.000000
+ }
+
+ {
+ dpfloat dpa = dpfloat(0.4, 0.0);
+
+ __bwd_diff(test_nested_with_continue)(dpa, 1.0f);
+ outputBuffer[9] = dpa.d; // CHECK-NEXT: 6.000000
+ }
+}
+
+//CHK-NOT: note \ No newline at end of file
diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop-simple.slang
index 18b672860..18b672860 100644
--- a/tests/autodiff/reverse-loop.slang
+++ b/tests/autodiff/reverse-loop-simple.slang
diff --git a/tests/autodiff/reverse-loop.slang.expected.txt b/tests/autodiff/reverse-loop-simple.slang.expected.txt
index 76b7cf779..76b7cf779 100644
--- a/tests/autodiff/reverse-loop.slang.expected.txt
+++ b/tests/autodiff/reverse-loop-simple.slang.expected.txt