diff options
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-loop-analysis.cpp | 1047 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-loop-analysis.h | 181 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 307 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 7 | ||||
| -rw-r--r-- | tests/autodiff/reverse-continue-loop.slang | 3 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-diff-only-2.slang | 15 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-exit-value-inference-1.slang | 240 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-simple.slang (renamed from tests/autodiff/reverse-loop.slang) | 0 | ||||
| -rw-r--r-- | tests/autodiff/reverse-loop-simple.slang.expected.txt (renamed from tests/autodiff/reverse-loop.slang.expected.txt) | 0 |
13 files changed, 1804 insertions, 11 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 8d7577b52..4eb4719f0 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1037,7 +1037,10 @@ Result linkAndOptimizeIR( // Report checkpointing information if (codeGenContext->shouldReportCheckpointIntermediates()) + { + simplifyIR(targetProgram, irModule, fastIRSimplificationOptions, sink); reportCheckpointIntermediates(codeGenContext, sink, irModule); + } // Finalization is always run so AD-related instructions can be removed, // even if the AD pass itself is not run. diff --git a/source/slang/slang-ir-autodiff-loop-analysis.cpp b/source/slang/slang-ir-autodiff-loop-analysis.cpp new file mode 100644 index 000000000..d4ff631a6 --- /dev/null +++ b/source/slang/slang-ir-autodiff-loop-analysis.cpp @@ -0,0 +1,1047 @@ +// slang-ir-autodiff-loop-analysis.cpp + +#include "slang-ir-autodiff-loop-analysis.h" + +namespace Slang +{ + +static bool isCompareCmpInst(IRInst* inst) +{ + // Switch on the opcode of the instruction + switch (inst->getOp()) + { + case kIROp_Less: + case kIROp_Greater: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Eql: + case kIROp_Neq: + return true; + default: + return false; + } +} + +SimpleRelation mergeEqualityWithIntegerRelation(SimpleRelation equality, SimpleRelation relation) +{ + SLANG_ASSERT( + equality.type == SimpleRelation::IntegerRelation && + relation.type == SimpleRelation::IntegerRelation); + SLANG_ASSERT(equality.comparator == SimpleRelation::Equal); + + switch (relation.comparator) + { + case SimpleRelation::Equal: + if (relation.integerValue == equality.integerValue) + return relation; + break; // Technically we'd want to return a "set" here, but we don't have a representation + // for that. 
+ case SimpleRelation::LessThanEqual: + if (equality.integerValue <= relation.integerValue) + return relation; + break; + case SimpleRelation::GreaterThanEqual: + if (equality.integerValue >= relation.integerValue) + return relation; + break; + default: + break; + } + + return SimpleRelation::anyRelation(); +} + +SimpleRelation mergeIntervals(SimpleRelation a, SimpleRelation b) +{ + SLANG_ASSERT( + a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation); + + if (a.comparator == SimpleRelation::Equal) + { + return mergeEqualityWithIntegerRelation(a, b); + } + else if (b.comparator == SimpleRelation::Equal) + { + return mergeEqualityWithIntegerRelation(b, a); + } + + // TODO: Handle other cases... + return SimpleRelation::anyRelation(); +} + +// Returns the tightest "simple" relation such that (a v b -> result) +// +// Note: "simple" means that the relation is not a disjunction or conjunction of other relations. +// +SimpleRelation relationUnion(SimpleRelation a, SimpleRelation b) +{ + // Base case. The disjunction operator is idempotent. + if (a == b) + return a; + + // If either side is trivially true, the result is trivially true. + if (a.type == SimpleRelation::Any || b.type == SimpleRelation::Any) + return SimpleRelation::anyRelation(); + + // If either side is trivially false, then the result is the other relation. + if (a.type == SimpleRelation::Impossible) + return b; + + if (b.type == SimpleRelation::Impossible) + return a; + + // If one is the negated form of the other, there's really nothing we can prove, since + // A OR ~A is always true. + // + if (a.negated() == b) + return SimpleRelation::anyRelation(); + + // Handle the case of where one is an inequality and the other is an equality. + if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation) + return mergeIntervals(a, b); + + // TODO: Here's where we can handle subset cases like (a < 10) and (a < 20) => (a < 20), etc.. 
+ // But we don't _have_ to. The more we can prove, the more cases we can handle, but the result + // is still correct without it. + // + + // Default to not being able to say anything. + return SimpleRelation::anyRelation(); +} + +SimpleRelation intersectEqualityWithIntegerRelation( + SimpleRelation equality, + SimpleRelation relation) +{ + SLANG_ASSERT(equality.type == SimpleRelation::IntegerRelation); + SLANG_ASSERT(relation.type == SimpleRelation::IntegerRelation); + SLANG_ASSERT(equality.comparator == SimpleRelation::Equal); + + if (relation.comparator == SimpleRelation::Equal) + { + if (relation.integerValue == equality.integerValue) + return SimpleRelation::integerRelation(SimpleRelation::Equal, equality.integerValue); + else + return SimpleRelation::impossibleRelation(); + } + else if (relation.comparator == SimpleRelation::LessThanEqual) + { + if (equality.integerValue <= relation.integerValue) + return SimpleRelation::integerRelation( + SimpleRelation::LessThanEqual, + relation.integerValue); + else + return SimpleRelation::impossibleRelation(); + } + else if (relation.comparator == SimpleRelation::GreaterThanEqual) + { + if (equality.integerValue >= relation.integerValue) + return SimpleRelation::integerRelation( + SimpleRelation::GreaterThanEqual, + relation.integerValue); + else + return SimpleRelation::impossibleRelation(); + } + else if (relation.comparator == SimpleRelation::NotEqual) + { + if (equality.integerValue != relation.integerValue) + return SimpleRelation::integerRelation(SimpleRelation::NotEqual, relation.integerValue); + else + return SimpleRelation::impossibleRelation(); + } + + return SimpleRelation::anyRelation(); +} + +// Intersect intervals. 
+SimpleRelation intersectIntervals(SimpleRelation a, SimpleRelation b) +{ + SLANG_ASSERT( + a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation); + + if (a.comparator == SimpleRelation::Equal) + { + return intersectEqualityWithIntegerRelation(a, b); + } + else if (b.comparator == SimpleRelation::Equal) + { + return intersectEqualityWithIntegerRelation(b, a); + } + + // TODO: Handle other cases... + + // We'll just default to picking the first one, since (a ^ b) -> a is always true. + return a; +} + +// Returns the best "simple" relation such that (a ^ b -> result) +// +SimpleRelation relationIntersection(SimpleRelation a, SimpleRelation b) +{ + // Base case. The conjunction operator is idempotent. + if (a == b) + return a; + + // If one is the negated form of the other, then we can prove that the result is impossible. + // Doesn't necessarily mean that we have an error on our hands, but it does mean that whatever + // case we're considering can't happen, so can be ignored (unreachable) + // + if (a.negated() == b) + return SimpleRelation::impossibleRelation(); + + // If any of the relations is impossible, then the result is impossible. + if (a.type == SimpleRelation::Impossible || b.type == SimpleRelation::Impossible) + return SimpleRelation::impossibleRelation(); + + // If any one of the relations is trivially true, then the result is the other relation. + if (a.type == SimpleRelation::Any) + return b; + + if (b.type == SimpleRelation::Any) + return a; + + // + // We'll handle the cases where one is an equality and the other is an inequality or equality. + // + // i.e. For a conjunction (a == 10) ^ (a < 20), we can use the narrower relation (a == 10). + // + if (a.type == SimpleRelation::IntegerRelation && b.type == SimpleRelation::IntegerRelation) + return intersectIntervals(a, b); + + // TODO: Handle other cases... 
+ return SimpleRelation::anyRelation(); +} + +void StatementSet::disjunct(StatementSet other) +{ + // false v (a1 v a2 v a3 ...) = (a1 v a2 v a3 ...) + if (isTriviallyFalse()) + { + statements = other.statements; + return; + } + + // (a1 v a2 v a3 ...) v false = (a1 v a2 v a3 ...) + if (other.isTriviallyFalse()) + return; + + // true v (a1 v a2 v a3 ...) = true + if (other.isTriviallyTrue()) + { + statements.clear(); + return; + } + + // (a1 v a2 v a3 ...) v true = true + if (isTriviallyTrue()) + return; + + for (auto& statement : other.statements) + { + // Since we hold only one statement per inst, we can perform disjunction + // on a per-inst basis. + // If an inst does not exist in the current set, then it's an empty statement. + // + if (statements.containsKey(statement.first)) + { + auto newRelation = relationUnion(statement.second, statements[statement.first]); + set(statement.first, newRelation); + } + } + + // Remove any insts that don't have a corresponding statement in the other set, + // since this effectively means "any". + // + for (auto& statement : statements) + { + if (!other.statements.containsKey(statement.first)) + statements.remove(statement.first); + } +} + +void StatementSet::conjunct(StatementSet other) +{ + // true ^ (a1 ^ a2 ^ a3 ...) = (a1 ^ a2 ^ a3 ...) + if (other.isTriviallyTrue()) + return; + + // (a1 ^ a2 ^ a3 ...) ^ true = (a1 ^ a2 ^ a3 ...) + if (isTriviallyTrue()) + { + statements = other.statements; + return; + } + + // false ^ (a1 ^ a2 ^ a3 ...) = false + if (isTriviallyFalse()) + return; + + // (a1 ^ a2 ^ a3 ...) ^ false = false + if (other.isTriviallyFalse()) + { + statements = other.statements; + return; + } + + // Otherwise do an element-wise conjunction. 
+ for (auto& statement : other.statements) + { + if (statements.containsKey(statement.first)) + { + set(statement.first, + relationIntersection(statement.second, statements[statement.first])); + } + else + { + set(statement.first, statement.second); + } + } +} + +void StatementSet::conjunct(IRInst* inst, SimpleRelation relation) +{ + if (isTriviallyFalse()) + return; + + if (statements.containsKey(inst)) + { + set(inst, relationIntersection(relation, statements[inst])); + } + else + { + set(inst, relation); + } +} + +// This function answers the question: "Can we prove that relationB is true if relationA is true?" +// +// Note that this is not the same as "Does relationA imply relationB", since there can be cases +// where this is indeed true, but we just don't have the logic to prove it. +// +bool doesRelationImply(SimpleRelation relationA, SimpleRelation relationB) +{ + // Equal relations imply each other + if (relationA == relationB) + return true; + + // If B is trivially true, then A implies B + if (relationB.type == SimpleRelation::Any) + return true; + + // If A is trivially true, then A implies B only if B is also trivially true + if (relationA.type == SimpleRelation::Any) + return (relationB.type == SimpleRelation::Any); + + // If A is impossible, then technically what we return doesn't matter... + if (relationA.type == SimpleRelation::Impossible || + relationB.type == SimpleRelation::Impossible) + return false; + + // If A is a boolean relation, then A implies B if B is also a boolean relation and the values + // are the same. 
+ // + if (relationA.type == SimpleRelation::BoolRelation) + return (relationB.type == SimpleRelation::BoolRelation) && + (relationA.boolValue == relationB.boolValue); + + if (relationA.type == SimpleRelation::IntegerRelation) + { + if (relationB.type != SimpleRelation::IntegerRelation) + return false; + + // Technically, the equality case is already handled above, so we'll only consider + // cases where A and B are not the same relation, but where A -> B + + // If A is an equality, and B is an inequality, we can test + if (relationA.comparator == SimpleRelation::Equal) + { + if (relationB.comparator == SimpleRelation::LessThanEqual) + return relationA.integerValue <= relationB.integerValue; + else if (relationB.comparator == SimpleRelation::GreaterThanEqual) + return relationA.integerValue >= relationB.integerValue; + } + + // If A is an equality, and B is an inequality with different values, then + // A -> B + // + if (relationA.comparator == SimpleRelation::Equal && + relationB.comparator == SimpleRelation::NotEqual) + { + return relationA.integerValue != relationB.integerValue; + } + + if (relationA.comparator == SimpleRelation::GreaterThanEqual && + relationB.comparator == SimpleRelation::GreaterThanEqual) + { + return relationA.integerValue >= relationB.integerValue; + } + + if (relationA.comparator == SimpleRelation::LessThanEqual && + relationB.comparator == SimpleRelation::LessThanEqual) + { + return relationA.integerValue <= relationB.integerValue; + } + + // TODO: Handle other cases.. 
these come up rarely, so we can + } + + return false; +} + +bool isIntegerConstantValue(IRInst* inst) +{ + return inst->getOp() == kIROp_IntLit; +} + +bool isBoolConstantValue(IRInst* inst) +{ + return inst->getOp() == kIROp_BoolLit; +} + +IRIntegerValue getConstantIntegerValue(IRInst* inst) +{ + SLANG_ASSERT(isIntegerConstantValue(inst)); + return as<IRIntLit>(inst)->getValue(); +} + +bool getConstantBoolValue(IRInst* inst) +{ + SLANG_ASSERT(isBoolConstantValue(inst)); + return as<IRBoolLit>(inst)->getValue(); +} + +StatementSet tryExtractStatements(IRTerminatorInst* inst, IRBlock* block) +{ + StatementSet statements; + + // From condInst, extract a statement about any inst such that we have an equality + // statement (integer or boolean) on the inst. + // + if (auto ifElse = as<IRIfElse>(inst)) + { + // Check that the block is the true or false block of the if-else + bool isTrueBlock = ifElse->getTrueBlock() == block; + bool isFalseBlock = ifElse->getFalseBlock() == block; + if (!isTrueBlock && !isFalseBlock) + goto done; + + auto condInst = inst->getOperand(0); + statements.conjunct(condInst, SimpleRelation::boolRelation(isTrueBlock)); + + if (condInst->getOp() == kIROp_Eql) + { + auto leftOperand = condInst->getOperand(0); + auto rightOperand = condInst->getOperand(1); + + if (isIntegerConstantValue(leftOperand)) + { + statements.conjunct( + rightOperand, + SimpleRelation::integerRelation( + (isTrueBlock ? SimpleRelation::Equal : SimpleRelation::NotEqual), + getConstantIntegerValue(leftOperand))); + } + else if (isIntegerConstantValue(rightOperand)) + { + statements.conjunct( + leftOperand, + SimpleRelation::integerRelation( + (isTrueBlock ? 
SimpleRelation::Equal : SimpleRelation::NotEqual), + getConstantIntegerValue(rightOperand))); + } + } + else if (isCompareCmpInst(condInst)) + { + auto leftOperand = condInst->getOperand(0); + auto rightOperand = condInst->getOperand(1); + + bool isParamLeft = !isIntegerConstantValue(leftOperand); + bool isParamRight = !isIntegerConstantValue(rightOperand); + + // If neither operand is an inst, we can't say anything. + if (!isParamLeft && !isParamRight) + goto done; + + auto paramOperand = isParamLeft ? leftOperand : rightOperand; + auto otherOperand = isParamLeft ? rightOperand : leftOperand; + + // Check if the "other" operand is a constant + if (!isIntegerConstantValue(otherOperand)) + goto done; + + auto constantVal = getConstantIntegerValue(otherOperand); + + SimpleRelation::Comparator comparator; + switch (condInst->getOp()) + { + case kIROp_Less: + comparator = SimpleRelation::LessThanEqual; + constantVal = constantVal - 1; + break; + case kIROp_Greater: + comparator = SimpleRelation::GreaterThanEqual; + constantVal = constantVal + 1; + break; + case kIROp_Leq: + comparator = SimpleRelation::LessThanEqual; + break; + case kIROp_Geq: + comparator = SimpleRelation::GreaterThanEqual; + break; + case kIROp_Eql: + comparator = SimpleRelation::Equal; + break; + case kIROp_Neq: + comparator = SimpleRelation::NotEqual; + break; + default: + SLANG_UNREACHABLE("unexpected op code"); + } + auto relation = SimpleRelation::integerRelation(comparator, constantVal); + statements.conjunct( + paramOperand, + ((isParamLeft ^ !isTrueBlock) ? 
relation : relation.negated())); + } + } + else if (auto switchInst = as<IRSwitch>(inst)) + { + // Check that the block is the default case of the switch + if (switchInst->getDefaultLabel() == block) + goto done; + + // Check each case block + UInt caseCount = switchInst->getCaseCount(); + for (UInt i = 0; i < caseCount; i++) + { + auto caseValue = switchInst->getCaseValue(i); + auto caseBlock = switchInst->getCaseLabel(i); + + if (caseBlock == block && isIntegerConstantValue(caseValue)) + { + auto constantVal = getConstantIntegerValue(caseValue); + statements.conjunct( + switchInst->getCondition(), + SimpleRelation::integerRelation(SimpleRelation::Equal, constantVal)); + } + } + } + +done: + return statements; +} + +enum class BlockStateFlags +{ + UpwardPropCompleted = 1 << 0, + DownwardPropCompleted = 1 << 1 +}; + +void markUpwardPropCompleted(IRBlock* block) +{ + block->scratchData |= (UInt64)BlockStateFlags::UpwardPropCompleted; +} + +void markDownwardPropCompleted(IRBlock* block) +{ + block->scratchData |= (UInt64)BlockStateFlags::DownwardPropCompleted; +} + +bool isUpwardPropCompleted(IRBlock* block) +{ + return block->scratchData & (UInt64)BlockStateFlags::UpwardPropCompleted; +} + +bool isDownwardPropCompleted(IRBlock* block) +{ + return block->scratchData & (UInt64)BlockStateFlags::DownwardPropCompleted; +} + +void clearBlockState(IRBlock* block) +{ + block->scratchData = 0; +} + +bool isLoopConditionBlock(IRBlock* block) +{ + for (auto use = block->firstUse; use; use = use->nextUse) + { + if (auto loop = as<IRLoop>(use->getUser())) + { + if (loop->getTargetBlock() == block) + return true; + } + } + + return false; +} + +bool isBlockReadyForUpwardProp(IRBlock* block) +{ + if (isLoopConditionBlock(block)) + { + auto falseBlock = cast<IRIfElse>(block->getTerminator())->getFalseBlock(); + return isUpwardPropCompleted(falseBlock); + } + + // Check that successors have completed upward propagation. 
+ for (auto successor : block->getSuccessors()) + { + if (!isUpwardPropCompleted(successor)) + return false; + } + return true; +} + +bool isBlockReadyForDownwardProp(IRBlock* block) +{ + // Check that predecessors have completed downward propagation. + for (auto predecessor : block->getPredecessors()) + { + if (!isDownwardPropCompleted(predecessor)) + return false; + } + return true; +} + +StatementSet propagateStatementUpwards(IRInst* inst, SimpleRelation relation) +{ + // Lambda to make a single-statement set. + auto makeStatementSet = [&](IRInst* inst, SimpleRelation relation) + { + StatementSet set; + set.conjunct(inst, relation); + return set; + }; + + if (as<IRParam>(inst)) + return makeStatementSet(inst, relation); + + if (isIntegerConstantValue(inst)) + { + auto relationFromInst = + SimpleRelation::integerRelation(SimpleRelation::Equal, getConstantIntegerValue(inst)); + if (doesRelationImply(relation, relationFromInst)) + return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true + else if (doesRelationImply(relation, relationFromInst.negated())) + return makeStatementSet(inst, SimpleRelation::impossibleRelation()); + else + return makeStatementSet(inst, SimpleRelation::anyRelation()); + } + else if (isBoolConstantValue(inst)) + { + auto relationFromInst = SimpleRelation::boolRelation(getConstantBoolValue(inst)); + if (doesRelationImply(relation, relationFromInst)) + return makeStatementSet(inst, SimpleRelation::anyRelation()); // Trivially true + else if (doesRelationImply(relation, relationFromInst.negated())) + return makeStatementSet(inst, SimpleRelation::impossibleRelation()); + else + return makeStatementSet(inst, SimpleRelation::anyRelation()); + } + else if (inst->getOp() == kIROp_Add || inst->getOp() == kIROp_Sub) + { + // TODO: Translate equality/inequality. 
+ } + + return makeStatementSet(inst, SimpleRelation::anyRelation()); +} + +StatementSet propagateUpwards( + RefPtr<IRDominatorTree> domTree, + IRBlock* current, + IRBlock* predecessor, + StatementSet predicateSet) +{ + // Translate the set of predicates from the current block to the predecessor block. + // + // The key idea is that we need to find a set of predicate statements (A') for the predecessor + // block, such that A => A'. + // + // During the downward phase, the predecessor will then return a set of + // statements (B') such that A' => B'. This B' can be propagated "downwards" into a set + // of statements B such that B' => B. + // + // We can then combine these three rules A => A', A' => B' and B' => B to get A => B + // which is the statement set that we want for our current block. + // + + StatementSet newPredicateSet; + for (auto& statementInstPair : predicateSet.statements) + { + auto predicateRelation = statementInstPair.second; + auto predicateInst = statementInstPair.first; + if (as<IRParam>(predicateInst) && predicateInst->getParent() == current) + { + auto paramIndex = getParamIndexInBlock(cast<IRParam>(predicateInst)); + auto translatedInst = + as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex); + + // If the translate inst is outside the block, add it in as-is, otherwise, + // we'll need to propagate it to the operands of the inst + // + auto statementSet = propagateStatementUpwards(translatedInst, predicateRelation); + newPredicateSet.conjunct(statementSet); + } + else + { + newPredicateSet.conjunct(predicateInst, predicateRelation); + } + } + + // If our current block is a merge block for a conditional branch, we should add the condition + // to the predicate set. 
+ // + for (auto blockUse = current->firstUse; blockUse; blockUse = blockUse->nextUse) + { + if (auto ifElse = as<IRIfElse>(blockUse->getUser())) + { + if (ifElse->getAfterBlock() == current) + { + // We're looking at the merge block for a conditional branch. + + if (domTree->dominates(ifElse->getTrueBlock(), predecessor)) + { + // True branch + newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock())); + } + else if (domTree->dominates(ifElse->getFalseBlock(), predecessor)) + { + // False branch + newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getFalseBlock())); + } + else + { + // Panic + SLANG_UNREACHABLE("Unreachable block in conditional branch"); + } + } + } + + // We'll ignore switch statements for now, but they're trivial to add. + // TODO: Add switch statements. + } + + // We have one more edge-case. The condition block of a loop inst. + if (auto ifElse = as<IRIfElse>(current->getTerminator())) + { + if (domTree->dominates(ifElse->getTrueBlock(), predecessor) && + !domTree->dominates(ifElse->getFalseBlock(), predecessor)) + { + // True branch + newPredicateSet.conjunct(tryExtractStatements(ifElse, ifElse->getTrueBlock())); + } + } + return newPredicateSet; +} + +StatementSet propagateStatementDownwards( + IRInst* srcInst, + IRInst* dstInst, + StatementSet srcStatements) +{ + // We'll keep translating through the inst, until we either hit a parameter, + // or we leave the current block. + // + + // Lambda to make a single-statement set. 
+ auto singleStatement = [&](IRInst* inst, SimpleRelation relation) + { + StatementSet set; + set.conjunct(inst, relation); + return set; + }; + + if (srcStatements.statements.containsKey(srcInst)) + return singleStatement(dstInst, srcStatements.statements[srcInst]); + + if (isIntegerConstantValue(srcInst)) + { + return singleStatement( + dstInst, + SimpleRelation::integerRelation( + SimpleRelation::Equal, + getConstantIntegerValue(srcInst))); + } + else if (isBoolConstantValue(srcInst)) + { + return singleStatement( + dstInst, + SimpleRelation::boolRelation(getConstantBoolValue(srcInst))); + } + + if (srcInst->getOp() == kIROp_Add || srcInst->getOp() == kIROp_Sub) + { + auto left = srcInst->getOperand(0); + auto right = srcInst->getOperand(1); + + auto isLeftConstant = isIntegerConstantValue(left); + auto isRightConstant = isIntegerConstantValue(right); + + if (!isLeftConstant && !isRightConstant) + return singleStatement(dstInst, + SimpleRelation::anyRelation()); // Can't say anything + + if (srcInst->getOp() == kIROp_Add || (srcInst->getOp() == kIROp_Sub && isRightConstant)) + { + auto constant = + isLeftConstant ? getConstantIntegerValue(left) : getConstantIntegerValue(right); + auto operand = isLeftConstant ? right : left; + + constant = srcInst->getOp() == kIROp_Add ? constant : -constant; + + auto operandStatement = propagateStatementDownwards(operand, operand, srcStatements); + auto relation = operandStatement.statements.containsKey(operand) + ? 
operandStatement.statements[operand] + : SimpleRelation::anyRelation(); + + if (relation.type == SimpleRelation::IntegerRelation) + { + switch (relation.comparator) + { + case SimpleRelation::Equal: + case SimpleRelation::NotEqual: + case SimpleRelation::LessThanEqual: + case SimpleRelation::GreaterThanEqual: + return singleStatement( + dstInst, + SimpleRelation::integerRelation( + relation.comparator, + constant + relation.integerValue)); + } + } + } + } + + // Default + return singleStatement(dstInst, SimpleRelation::anyRelation()); +} + +StatementSet propagateDownwards( + RefPtr<IRDominatorTree> domTree, + IRBlock* successor, + IRBlock* predecessor, + StatementSet statementSet) +{ + // Translate a set of statements from the current block to the successor block. + // + // That is, find a set of statements (B') for the successor block such that B => B' + // + StatementSet newStatementSet; + + if (statementSet.isTriviallyFalse()) + { + return statementSet; + } + + // Go over all the parameters of the successor block, find corresponding arguments, and + // convert any statements to the new set. + // + UInt paramIndex = 0; + for (auto param : successor->getParams()) + { + auto arg = as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(paramIndex); + auto statement = propagateStatementDownwards(arg, param, statementSet); + newStatementSet.conjunct(statement); + paramIndex++; + } + + newStatementSet.conjunct(tryExtractStatements(predecessor->getTerminator(), successor)); + + // For all other statements in the statementSet, we'll add them in, but only + // if the predecessor dominates the successor. 
+ // An exception is parameters defined in the successor (since these are getting + // redefined, we should not be considering existing statements) + // + for (auto& statement : statementSet.statements) + { + if (domTree->dominates(statement.first->getParent(), successor) && + !(as<IRParam>(statement.first) && statement.first->getParent() == successor)) + newStatementSet.conjunct(statement.first, statement.second); + } + + return newStatementSet; +} + +struct Edge +{ + IRBlock* predecessor; + IRBlock* successor; + + bool operator==(const Edge& other) const + { + return predecessor == other.predecessor && successor == other.successor; + } + + UInt64 getHashCode() const + { + UInt64 predHash = Slang::getHashCode(predecessor); + UInt64 succHash = Slang::getHashCode(successor); + return Slang::combineHash(predHash, succHash); + } +}; + + +// This routine returns a set of implications for any insts visible in a block. +// +// The process uses a modified version of abstract interpretation, by first propagating a set +// of predicates "backwards" repeatedly through the predecessors, then calculating the set of +// implications "forwards" repeatedly through the successors. +// +// Note that the resulting implications don't contain all possible statements that could be inferred +// statically (this is an undecidable problem), but rather whatever can be inferred in just two steps +// through the blocks. This suffices for the vast majority of common loop structures. +// +StatementSet collectImplications( + RefPtr<IRDominatorTree> domTree, + IRBlock* block, + StatementSet Predicates) +{ + List<Edge> orderedEdgeList; // Edges in the order that they're processed. + HashSet<Edge> falseEdges; // Edges between blocks where the successor's predicate does not imply + // the predecessor's predicate. + + // Initialize a work list. + List<IRBlock*> workList; + workList.add(block); + + // Clear scratch bits. 
+ IRFunc* func = cast<IRFunc>(domTree->code); + for (auto _block : func->getBlocks()) + { + clearBlockState(_block); + } + + // + // Upward pass: Propagate predicates through predecessors until + // there're no more blocks left to process. + // + + // We'll keep track of the predicates for each block. + Dictionary<IRBlock*, StatementSet> blockPredicates; + + blockPredicates[block] = Predicates; + + while (workList.getCount() > 0) + { + auto current = workList.getLast(); + workList.removeLast(); + + // If the block has already been processed, skip it. + if (isUpwardPropCompleted(current)) + continue; + + // If the block is not ready for upward propagation, add it to the work list. + if (current != block && !isBlockReadyForUpwardProp(current)) + { + workList.add(current); + // Then add all the successors to the work list. + for (auto successor : current->getSuccessors()) + workList.add(successor); + + continue; + } + + // Otherwise, we'll process the block. + // + // Get our predicate set, then propagate it to all predecessors. + // + auto predicates = blockPredicates[current]; + + HashSet<IRBlock*> uniquePredecessors; + for (auto predecessor : current->getPredecessors()) + uniquePredecessors.add(predecessor); + + for (auto predecessor : uniquePredecessors) + { + // We also need to handle the recursive case, where the predecessor + // is already "sealed". + // + if (isUpwardPropCompleted(predecessor)) + { + orderedEdgeList.add({predecessor, current}); + + // Verify that "current predicate" => "predecessor predicate". + + // TODO: Implement later. + // For now, we can default to assuming that this edge is not + // valid. This works fine since we're not trying to prove anything recursive (like + // inductivity), but we should revisit this if we do want to unify the induction + // value inference pass with this loop analysis system. 
+ // + + // We'll add this to the set of false edges so that the downward prop pass + // doesn't propagate any implications through this edge. + // + falseEdges.add({predecessor, current}); + continue; + } + + auto newPredicates = propagateUpwards(domTree, current, predecessor, predicates); + + if (!blockPredicates.containsKey(predecessor)) + blockPredicates[predecessor] = newPredicates; + else + blockPredicates[predecessor].disjunct(newPredicates); + + orderedEdgeList.add({predecessor, current}); + + // Add predecessors to work list. + workList.add(predecessor); + } + + markUpwardPropCompleted(current); + } + + // + // Downward pass: Propagate implications through successors until + // there're no more blocks left to process. + // + + Dictionary<IRBlock*, StatementSet> blockImplications; + + // Set 'block' to something trivial base case. + // blockImplications[block] = blockPredicates[block]; // statement => statement + + while (orderedEdgeList.getCount() > 0) + { + auto edge = orderedEdgeList.getLast(); + orderedEdgeList.removeLast(); + + // Get the predicate set for the predecessor. + auto predecessorPredicates = blockPredicates[edge.predecessor]; + + // Get the implication set for the predecessor. + auto predecessorImplications = StatementSet(); + + if (falseEdges.contains(edge)) + { + // Since A' => B' is not true, effectively, we can't say anything.. + predecessorImplications = StatementSet(); + } + else + { + // (A' => B') => (A' => A' ^ B') + predecessorImplications = blockImplications[edge.predecessor]; + predecessorImplications.conjunct(predecessorPredicates); + } + + // Propagate the implication set to the successor. 
+ auto successorImplications = + propagateDownwards(domTree, edge.successor, edge.predecessor, predecessorImplications); + + if (!blockImplications.containsKey(edge.successor)) + blockImplications[edge.successor] = successorImplications; + else + blockImplications[edge.successor].disjunct(successorImplications); + } + + // Clear scratch bits. + for (auto _block : func->getBlocks()) + { + clearBlockState(_block); + } + + // We should have a final set of implications for our block. + return blockImplications[block]; +} + +} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-loop-analysis.h b/source/slang/slang-ir-autodiff-loop-analysis.h new file mode 100644 index 000000000..f13ccd7b9 --- /dev/null +++ b/source/slang/slang-ir-autodiff-loop-analysis.h @@ -0,0 +1,181 @@ +// slang-ir-autodiff-loop-analysis.h +#pragma once + +#include "slang-ir-autodiff-region.h" +#include "slang-ir-autodiff.h" +#include "slang-ir-dominators.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ +// A constraint on the value of a single inst: unconstrained (`Any`), unsatisfiable +// (`Impossible`), equal to a boolean constant, or an integer comparison against a constant. +struct SimpleRelation +{ + enum Type + { + Any, // Target can be anything (all values are possible) + IntegerRelation, // Target satisfies a simple integer equality/inequality + BoolRelation, // Target satisfies boolean equality + Impossible // Target is impossible (has no possible values) + } type; + + enum Comparator + { + LessThanEqual, + GreaterThanEqual, + Equal, + NotEqual + } comparator; + IRIntegerValue integerValue; + bool boolValue; + + static SimpleRelation integerRelation(Comparator comparator, IRIntegerValue integerValue) + { + return SimpleRelation{IntegerRelation, comparator, integerValue, false}; + } + + static SimpleRelation boolRelation(bool boolValue) + { + return SimpleRelation{BoolRelation, Equal, 0, boolValue}; + } + + static SimpleRelation impossibleRelation() + { + return SimpleRelation{Impossible, Equal, 0, false}; + } + + static SimpleRelation anyRelation() { return SimpleRelation{Any, Equal, 0, false}; } + + // Equality compares only the fields that are meaningful for the active `type`. + bool operator==(const
SimpleRelation& other) const + { + switch (type) + { + case Any: + return other.type == Any; + case IntegerRelation: + return other.type == IntegerRelation && comparator == other.comparator && + integerValue == other.integerValue; + case BoolRelation: + return other.type == BoolRelation && boolValue == other.boolValue; + case Impossible: + return other.type == Impossible; + default: + SLANG_UNREACHABLE("Unhandled relation type"); + } + } + + bool operator!=(const SimpleRelation& other) const { return !(*this == other); } + + // Returns the logical complement of this relation. For integer bounds this relies on + // integrality: `x <= v` becomes `x >= v + 1`, and `x >= v` becomes `x <= v - 1`. + SimpleRelation negated() const + { + switch (type) + { + case Any: + return SimpleRelation{Impossible, Equal, 0, false}; + case Impossible: + return SimpleRelation{Any, Equal, 0, false}; + case BoolRelation: + return SimpleRelation{BoolRelation, Equal, 0, !boolValue}; + case IntegerRelation: + switch (comparator) + { + case LessThanEqual: + return SimpleRelation{IntegerRelation, GreaterThanEqual, integerValue + 1, false}; + case GreaterThanEqual: + return SimpleRelation{IntegerRelation, LessThanEqual, integerValue - 1, false}; + case Equal: + return SimpleRelation{IntegerRelation, NotEqual, integerValue, false}; + case NotEqual: + return SimpleRelation{IntegerRelation, Equal, integerValue, false}; + default: + SLANG_UNREACHABLE("Unhandled comparator"); + } + default: + SLANG_UNREACHABLE("Unhandled relation type"); + } + } + + // Hashes the same fields that operator== compares for the active `type`. + HashCode64 getHashCode() const + { + HashCode64 code = Slang::getHashCode(int(type)); + switch (type) + { + case IntegerRelation: + code = combineHash(code, Slang::getHashCode(comparator)); + code = combineHash(code, Slang::getHashCode(integerValue)); + break; + case BoolRelation: + code = combineHash(code, Slang::getHashCode(boolValue)); + break; + case Impossible: + case Any: + break; + default: + SLANG_UNREACHABLE("Unhandled relation type"); + } + return code; + } +}; + +struct StatementSet +{ + // A conjunction of independent statements (a1 ^ a2 ^ a3 ...) + // - One simple relation per inst.
+ // - The absence of an entry implies that the inst is unconstrained. + // - The presence of any "Impossible" relation indicates that the entire conjunction is always + // false. + // + Dictionary<IRInst*, SimpleRelation> statements; + + // Disjunction of a conjunction of statements (a1 ^ a2 ^ a3 ...) with the current conjunction. + void disjunct(StatementSet other); + + // Conjunction of a conjunction of statements (a1 ^ a2 ^ a3 ...) with the current conjunction. + void conjunct(StatementSet other); + + // Conjunction of a single statement with the current conjunction. + void conjunct(IRInst* inst, SimpleRelation relation); + + // Records `relation` for `inst`, replacing any existing entry. An `Any` relation removes the + // entry instead, since a missing entry already means the inst is unconstrained. + void set(IRInst* inst, SimpleRelation relation) + { + if (relation.type == SimpleRelation::Any) + { + if (statements.containsKey(inst)) + statements.remove(inst); + return; + } + + statements[inst] = relation; + } + + // True if any inst carries an `Impossible` relation. + bool isTriviallyFalse() + { + for (auto& statement : statements) + { + if (statement.second.type == SimpleRelation::Impossible) + return true; + } + return false; + } + + // True if no inst is constrained (the set is empty). + bool isTriviallyTrue() { return statements.getCount() == 0; } +}; + + +// Utility functions. +bool isIntegerConstantValue(IRInst* inst); +bool isBoolConstantValue(IRInst* inst); +IRIntegerValue getConstantIntegerValue(IRInst* inst); +bool getConstantBoolValue(IRInst* inst); + +bool doesRelationImply(SimpleRelation relationA, SimpleRelation relationB); + +// Try to collect a set of implications for any insts visible in a block, +// subject to the set of predicates.
+// +StatementSet collectImplications( + RefPtr<IRDominatorTree> domTree, + IRBlock* block, + StatementSet Predicates); + +} // namespace Slang diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index c2403e53b..231221156 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -2,6 +2,7 @@ #include "../core/slang-func-ptr.h" #include "slang-ast-support-types.h" +#include "slang-ir-autodiff-loop-analysis.h" #include "slang-ir-autodiff-region.h" #include "slang-ir-insts.h" #include "slang-ir-simplify-cfg.h" @@ -295,9 +296,10 @@ bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInf if (indicesA.getCount() > indicesB.getCount()) return false; + auto offset = (indicesB.getCount() - indicesA.getCount()); for (Index ii = 0; ii < indicesA.getCount(); ii++) { - if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam) + if (indicesA[ii].primalCountParam != indicesB[ii + offset].primalCountParam) return false; } @@ -402,7 +404,7 @@ void splitLoopConditionBlockInsts( if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0) { setInsertAfterOrdinaryInst(&builder, inst); - auto copy = builder.emitCheckpointObject(inst); + auto copy = builder.emitLoopExitValue(inst); // Copy source location so that checkpoint reporting is accurate copy->sourceLoc = inst->sourceLoc; @@ -425,6 +427,8 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( { collectInductionValues(func); + collectLoopExitConditions(func); + RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo(); RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); @@ -581,6 +585,19 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( workList.add(&branchInst->getArgs()[paramIndex]); } } + else if (auto exitValue = as<IRLoopExitValue>(result.instToRecompute)) + { + // If we also have an exit value (a stronger 
condition on the param), record + // it. + // + if (auto loopExitValueInst = + loopExitValueInsts.tryGetValue(exitValue->getVal())) + { + checkpointInfo->loopExitValueInsts.addIfNotExists( + exitValue->getVal(), + *loopExitValueInst); + } + } else { if (auto var = as<IRVar>(result.instToRecompute)) @@ -1046,6 +1063,248 @@ void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* } } +// Checks whether `value` lies within the range this function accepts for the integer type +// `type` (attributed types are unwrapped first). Only 8/16/32/64-bit signed and unsigned +// integer types are recognized; any other type yields false, and 64-bit types always yield true. +static bool isValueInRange(IRIntegerValue value, IRType* type) +{ + IRInst* innerType = unwrapAttributedType(type); + IRIntegerValue nBits; + bool isSigned; + + switch (innerType->getOp()) + { + case kIROp_IntType: + case kIROp_UIntType: + nBits = 32; + break; + case kIROp_Int16Type: + case kIROp_UInt16Type: + nBits = 16; + break; + case kIROp_Int8Type: + case kIROp_UInt8Type: + nBits = 8; + break; + case kIROp_Int64Type: + case kIROp_UInt64Type: + nBits = 64; + break; + default: + return false; + } + + switch (innerType->getOp()) + { + case kIROp_IntType: + case kIROp_Int16Type: + case kIROp_Int8Type: + case kIROp_Int64Type: + isSigned = true; + break; + case kIROp_UIntType: + case kIROp_UInt16Type: + case kIROp_UInt8Type: + case kIROp_UInt64Type: + isSigned = false; + break; + default: + return false; + } + + if (nBits >= 64) + { + // IRIntegerValue is 64-bit, so we assume we can always represent the value. + // TODO: Corner cases like loops that _rely_ on 64-bit integer overflow might not be handled + // correctly. + // + return true; + } + + if (isSigned) + { + IRIntegerValue maxValue = (1ULL << (nBits - 1)) - 1; + // NOTE(review): the lower bound used here is -maxValue, so the two's-complement minimum + // (-maxValue - 1) is reported as out of range; confirm this is intentionally conservative. + return value >= -maxValue && value <= maxValue; + } + else + { + IRIntegerValue maxValue = (1ULL << nBits) - 1; + return value >= 0 && value <= maxValue; + } +} + +// For every loop whose condition block branches on one of its own block parameters and has a +// parameter marked with IRLoopCounterDecoration, tries to infer a constant exit value for each +// affine induction parameter (and for the loop counter) and records it in loopExitValueInsts. +void AutodiffCheckpointPolicyBase::collectLoopExitConditions(IRGlobalValueWithCode* func) +{ + // Assume that the InductionValueInfo is already collected.
+ IRBuilder builder(func->getModule()); + RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); + for (auto block : func->getBlocks()) + { + auto loopInst = as<IRLoop>(block->getTerminator()); + if (!loopInst) + continue; + auto targetBlock = loopInst->getTargetBlock(); + auto ifElse = as<IRIfElse>(targetBlock->getTerminator()); + if (!ifElse) + continue; + + auto condParam = as<IRParam>(ifElse->getCondition()); + if (!condParam || condParam->getParent() != targetBlock) + continue; + + // Locate the loop counter. + IRInst* loopCounter = nullptr; + for (auto param : targetBlock->getParams()) + { + if (param->findDecoration<IRLoopCounterDecoration>()) + { + loopCounter = param; + break; + } + } + + if (!loopCounter) + continue; + + // Go over all loop phi parameters for which we have induction value info, + // and try to determine a relation on the exit value. + // + for (auto param : targetBlock->getParams()) + { + auto inductionValueInfo = inductionValueInsts.tryGetValue(param); + if (!inductionValueInfo || + inductionValueInfo->kind != LoopInductionValueInfo::AffineFunctionOfCounter) + continue; + + // We need to have a known constant offset to be able to compute the loop exit value. + if (!isIntegerConstantValue(inductionValueInfo->counterOffset)) + continue; + + // NOTE(review): the two collectImplications() queries below take only domTree, + // targetBlock and condParam as input, so they could presumably be hoisted out of this + // per-param loop; confirm collectImplications has no per-call dependence on `param`. + StatementSet conditionIsFalse; + conditionIsFalse.conjunct(condParam, SimpleRelation::boolRelation(false)); + + // Collect a statement that holds when the loop condition is false. + const auto implicationsForFalseCondition = + collectImplications(domTree, targetBlock, conditionIsFalse); + + if (!implicationsForFalseCondition.statements.containsKey(param)) + { + // The statement we collected says nothing about the parameter. No point continuing. + continue; + } + + // Collect statements for the inverse.. i.e. some relation that holds if the condition + // is true.
+ StatementSet conditionIsTrue; + conditionIsTrue.conjunct(condParam, SimpleRelation::boolRelation(true)); + const auto implicationsForTrueCondition = + collectImplications(domTree, targetBlock, conditionIsTrue); + + if (!implicationsForTrueCondition.statements.containsKey(param)) + { + // The statement we collected says nothing about the parameter. No point continuing. + continue; + } + + // Extract A s.t. ~breakFlag => A. + // + // (Note that breakFlag == false is the case where the + // loop exits) + // + SimpleRelation statement = implicationsForFalseCondition.statements.getValue(param); + + // Extract B s.t. breakFlag => B + SimpleRelation inverseStatement = + implicationsForTrueCondition.statements.getValue(param); + + // If A => ~B, then by using the contrapositive, we get A <=> ~breakFlag + if (!doesRelationImply(statement, inverseStatement.negated())) + { + // If the above doesn't work, we can try using ~B instead. + if (!doesRelationImply(inverseStatement.negated(), statement)) + continue; // Neither works.. we can't infer anything about param. + else + statement = inverseStatement.negated(); // Use ~B <=> ~breakFlag + } + + // We found a relation on the parameter at the loop exit, and we also proved that + // if the relation holds, the loop must exit. + // + // If we have an inequality + information that a value is inductive (i.e. follows a + // sequence of the form `start + i * step`), then we can use that to compute the exact + // value at the loop exit. + // + // We can do this by solving the inequality for the parameter, using the inductive value + // as the counter variable.
+ // + if (inductionValueInfo->kind == LoopInductionValueInfo::Kind::AffineFunctionOfCounter) + { + auto counterOffset = getConstantIntegerValue(inductionValueInfo->counterOffset); + auto counterFactor = inductionValueInfo->counterFactor; + + SLANG_ASSERT(statement.type == SimpleRelation::Type::IntegerRelation); + auto relationValue = statement.integerValue; + + auto recordExitValue = [&](IRIntegerValue exitIValue, IRIntegerValue exitParamValue) + { + // TODO: Maybe we should warn if the inferred exit value is out of range? + if (isValueInRange(exitParamValue, param->getDataType())) + { + this->loopExitValueInsts[param] = + builder.getIntValue(param->getDataType(), exitParamValue); + } + + // The interesting part is that since we know that this variable is a bijective + // function of the loop counter, we can also compute the loop counter's exit + // value. + // + // Since this can come from multiple parameters, we'll verify to make sure that + // there are no contradictions. + // + IRInst* loopCounterExitValue; + if (this->loopExitValueInsts.tryGetValue(loopCounter, loopCounterExitValue)) + { + auto loopCounterExitIValue = getConstantIntegerValue(loopCounterExitValue); + if (loopCounterExitIValue != exitIValue) + { + SLANG_ASSERT(!"contradictory loop exit values"); + } + } + else + { + // TODO: Maybe we should warn if the inferred exit value is out of range?
+ if (isValueInRange(exitIValue, loopCounter->getDataType())) + { + this->loopExitValueInsts[loopCounter] = + builder.getIntValue(loopCounter->getDataType(), exitIValue); + } + } + }; + + if (counterFactor > 0 && statement.comparator == SimpleRelation::GreaterThanEqual) + { + // Find the smallest value that satisfies counterFactor * i + counterOffset >= + // relationValue + // + IRIntegerValue exitIValue = + (((relationValue - counterOffset) + counterFactor - 1) / counterFactor); + // NOTE(review): the param value is evaluated at (exitIValue - 1) here but at exitIValue in + // the negative-stride branch below; also, exitIValue is not clamped to >= 0 when the + // relation already holds at the initial value. Confirm both are intended. + IRIntegerValue exitParamValue = + counterOffset + counterFactor * (exitIValue - 1); + recordExitValue(exitIValue, exitParamValue); + } + else if (counterFactor < 0 && statement.comparator == SimpleRelation::LessThanEqual) + { + // Find the largest value that satisfies counterFactor * i + counterOffset <= + // relationValue + // + IRIntegerValue exitIValue = + ((relationValue - counterOffset) + (counterFactor + 1)) / counterFactor; + IRIntegerValue exitParamValue = counterOffset + counterFactor * exitIValue; + recordExitValue(exitIValue, exitParamValue); + } + // TODO: handle other cases + } + } + } +} + void applyToInst( IRBuilder* builder, CheckpointSetInfo* checkpointInfo, @@ -1065,6 +1324,22 @@ void applyToInst( bool isInstRecomputed = checkpointInfo->recomputeSet.contains(inst); if (isInstRecomputed) { + if (auto loopExitValueInst = as<IRLoopExitValue>(inst)) + { + if (auto loopExitValue = + checkpointInfo->loopExitValueInsts.tryGetValue(loopExitValueInst->getVal())) + { + cloneCtx->cloneEnv.mapOldValToNew[inst] = *loopExitValue; + cloneCtx->registerClonedInst(builder, inst, *loopExitValue); + return; + } + + // Should never happen.
(Can't mark a LoopExitValue inst as recomputed without having an + // entry in loopExitValueInsts dict) + // + SLANG_ASSERT(!"no loop exit value found for inst"); + } + if (as<IRParam>(inst)) { // Can completely ignore first block parameters @@ -2239,6 +2514,13 @@ void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func) inst->removeAndDeallocate(); } + if (auto loopExitValueInst = as<IRLoopExitValue>(inst)) + { + auto originalVal = loopExitValueInst->getVal(); + loopExitValueInst->replaceUsesWith(originalVal); + loopExitValueInst->removeAndDeallocate(); + } + inst = nextInst; } } @@ -2268,9 +2550,17 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) sortBlocksInFunc(func); + // Dump IR. + /*IRDumpOptions options; + options.flags = IRDumpOptions::Flag::DumpDebugIds; + options.mode = IRDumpOptions::Mode::Detailed; + DiagnosticSinkWriter writer(sink); + writer.write("### BEFORE-PROCESS-FUNC\n", strlen("### BEFORE-PROCESS-FUNC\n")); + dumpIR(func, options, sink->getSourceManager(), &writer);*/ + // Determine the strategy we should use to make a primal inst available. - // If we decide to recompute the inst, emit the recompute inst in the corresponding recompute - // block. + // If we decide to recompute the inst, emit the recompute inst in the corresponding + // recompute block. 
// RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule()); chkPolicy->preparePolicy(func); @@ -2421,6 +2711,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_MatrixReshape: case kIROp_VectorReshape: case kIROp_GetTupleElement: + case kIROp_LoopExitValue: return false; case kIROp_Load: @@ -2545,6 +2836,13 @@ bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use) } } } + else if (auto exitValue = as<IRLoopExitValue>(use.usedVal)) + { + if (loopExitValueInsts.containsKey(exitValue->getVal())) + return true; + else + return false; + } return true; } @@ -2578,5 +2876,4 @@ HoistResult DefaultCheckpointPolicy::classify(UseOrPseudoUse use) } } } - }; // namespace Slang diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 92bc8197d..f8c31940a 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -252,6 +252,7 @@ struct CheckpointSetInfo : public RefObject HashSet<IRInst*> invertSet; Dictionary<IRInst*, LoopInductionValueInfo> loopInductionInfo; Dictionary<IRInst*, InversionInfo> invInfoMap; + Dictionary<IRInst*, IRInst*> loopExitValueInsts; }; struct UseOrPseudoUse @@ -323,7 +324,9 @@ public: protected: IRModule* module; Dictionary<IRInst*, LoopInductionValueInfo> inductionValueInsts; + Dictionary<IRInst*, IRInst*> loopExitValueInsts; void collectInductionValues(IRGlobalValueWithCode* func); + void collectLoopExitConditions(IRGlobalValueWithCode* func); }; class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 8019fdd08..bd590f08f 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -717,6 +717,7 @@ INST(BitNot, bitnot, 1, 0) INST(Select, select, 3, 0) INST(CheckpointObject, checkpointObj, 1, 0) +INST(LoopExitValue, loopExitValue, 1, 0) INST(GetStringHash, 
getStringHash, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 86596f316..51608ebc9 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2710,6 +2710,13 @@ struct IRCheckpointObject : IRInst IRInst* getVal() { return getOperand(0); } }; +struct IRLoopExitValue : IRInst +{ + IR_LEAF_ISA(LoopExitValue); + + IRInst* getVal() { return getOperand(0); } +}; + // Signals that this point in the code should be unreachable. // We can/should emit a dataflow error if we can ever determine // that a block ending in one of these can actually be @@ -4458,6 +4465,7 @@ public: IRInst* emitDiscard(); IRInst* emitCheckpointObject(IRInst* value); + IRInst* emitLoopExitValue(IRInst* value); IRInst* emitUnreachable(); IRInst* emitMissingReturn(); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a74ac58a4..1287d1598 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5734,6 +5734,13 @@ IRInst* IRBuilder::emitCheckpointObject(IRInst* value) return inst; } +IRInst* IRBuilder::emitLoopExitValue(IRInst* value) +{ + auto inst = createInst<IRLoopExitValue>(this, kIROp_LoopExitValue, value->getFullType(), value); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitBranch(IRBlock* pBlock) { auto inst = createInst<IRUnconditionalBranch>(this, kIROp_unconditionalBranch, nullptr, pBlock); diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 51f17b611..77bfb358c 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -9,14 +9,13 @@ RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; -//CHK-DAG: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue' +//CHK-DAG: note: checkpointing context of 20 bytes associated with function: 'test_loop_with_continue' [BackwardDifferentiable] 
float test_loop_with_continue(float y) { //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: float t = y; - //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-loop-diff-only-2.slang b/tests/autodiff/reverse-loop-diff-only-2.slang index 2cc33ecca..cc9e14736 100644 --- a/tests/autodiff/reverse-loop-diff-only-2.slang +++ b/tests/autodiff/reverse-loop-diff-only-2.slang @@ -32,10 +32,17 @@ float infinitesimal(float x) // Test that computeLoop's intermediates have no float sitting // around (must not cache the outvar from 'compute()') -// CHECK: struct s_bwd_prop_computeLoop_Intermediates -// CHECK-NEXT: { -// CHECK-NOT: {{[A-Za-z0-9_]+}} {{[A-Za-z0-9_]+}}[{{.*}}] -// CHECK: } +// +// Further, if loop exit value inference is working correctly, +// then there should be no context type at all. +// +// CHECK-NOT: struct s_bwd_prop_computeLoop_Intermediates +// +// Check that the signature of the s_bwd_prop_computeLoop function only +// contains an inout DiffPair_float_0 and a float. 
+// +// CHECK: void s_bwd_prop_computeLoop{{[_0-9]*}}(inout DiffPair_float{{[_0-9]*}} dpy{{[_0-9]*}}, float {{[_a-zA-Z0-9]*}}) +// [BackwardDifferentiable] [PreferRecompute] diff --git a/tests/autodiff/reverse-loop-exit-value-inference-1.slang b/tests/autodiff/reverse-loop-exit-value-inference-1.slang new file mode 100644 index 000000000..5a8d8c2c3 --- /dev/null +++ b/tests/autodiff/reverse-loop-exit-value-inference-1.slang @@ -0,0 +1,240 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHK_REPORT):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +// A variety of tests to check for loop exit value inference. +// For all of these loops, we expect our inference pass to be able to +// infer the loop exit value correctly. +// +// Further, if the optimization pass runs successfully, then there should +// be absolutely no context stored for any of these tests. 
+// + +// CHK_REPORT: (0): note: no checkpoint contexts to report + +[Differentiable] +float test_simple(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + t = t * (i + 1); + } + + return t; +} + +[Differentiable] +float test_strided(float y) +{ + float t = y; + + for (int i = 0; i < 5; i+=2) + { + t = t * (i + 1); + } + + return t; +} + +[Differentiable] +float test_offset(float y) +{ + float t = y; + + for (int i = 2; i < 5; i+=2) + { + t = t * (i + 1); + } + + return t; +} + +[Differentiable] +float test_negative_stride(float y) +{ + float t = y; + + for (int i = 7; i >= 1; i-=2) + { + t = t * (i + 1); + } + + return t; +} + +[Differentiable] +float test_nested(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + for (int j = 0; j < 3; j++) + { + t = t * (i + 4 * j + 1); + } + } + + return t; +} + +[Differentiable] +float test_nested_with_offset(float y) +{ + float t = y; + + for (int i = -3; i < 3; i++) + { + for (int j = -3; j < 3; j++) + { + t = t * ((abs(i) % 2) + (abs(j) % 2) + 1); + } + } + + return t; +} + +[Differentiable] +float test_nested_with_conditions(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + if (i % 2 == 0) + { + for (int j = 0; j < 3; j++) + { + if (j % 2 == 0) + { + t = t * (i + 4 * j + 1); + } + } + } + } + + return t; +} + +[Differentiable] +float test_with_continue(float y) +{ + float t = y; + + for (int i = 0; i < 5; i++) + { + if (i % 2 == 0) + { + continue; + } + + t = t * (i + 1); + } + + return t; +} + +[Differentiable] +float test_nested_with_continue(float y) +{ + float t = y; + + for (int i = 0; i < 3; i++) + { + if (i % 2 == 0) + continue; + + for (int j = 0; j < 3; j++) + { + if (j % 2 == 0) + continue; + + if (j == 0) + continue; + + t = t * (i + 4 * j + 1); + } + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = 0.0f; // CHECK: 0.000000 + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + 
__bwd_diff(test_simple)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // CHECK-NEXT: 6.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_strided)(dpa, 1.0f); + outputBuffer[2] = dpa.d; // CHECK-NEXT: 15.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_offset)(dpa, 1.0f); + outputBuffer[3] = dpa.d; // CHECK-NEXT: 15.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_negative_stride)(dpa, 1.0f); + outputBuffer[4] = dpa.d; // CHECK-NEXT: 384.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_nested)(dpa, 1.0f); + outputBuffer[5] = dpa.d; // CHECK-NEXT: 1247400.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_nested_with_offset)(dpa, 1.0f); + outputBuffer[6] = dpa.d; // CHECK-NEXT: 5159780352.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_nested_with_conditions)(dpa, 1.0f); + outputBuffer[7] = dpa.d; // CHECK-NEXT: 297.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_with_continue)(dpa, 1.0f); + outputBuffer[8] = dpa.d; // CHECK-NEXT: 8.000000 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_nested_with_continue)(dpa, 1.0f); + outputBuffer[9] = dpa.d; // CHECK-NEXT: 6.000000 + } +} + +// NOTE(review): the -report-checkpoint-intermediates run in this file is declared with +// filecheck=CHK_REPORT, so the CHK-prefixed directive below is presumably never evaluated; +// confirm whether it should use the CHK_REPORT prefix instead. +//CHK-NOT: note
\ No newline at end of file diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop-simple.slang index 18b672860..18b672860 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop-simple.slang diff --git a/tests/autodiff/reverse-loop.slang.expected.txt b/tests/autodiff/reverse-loop-simple.slang.expected.txt index 76b7cf779..76b7cf779 100644 --- a/tests/autodiff/reverse-loop.slang.expected.txt +++ b/tests/autodiff/reverse-loop-simple.slang.expected.txt |
