summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-01 12:43:51 +0800
committerGitHub <noreply@github.com>2023-08-01 12:43:51 +0800
commitc34a7b6627d4c07531daf7d99dceaf7f89bd1c0a (patch)
tree36eef7ee055c3706bce32493f47fddb5c0af3a4f /source/slang
parent5349241098076bead63f638daf2e4b9a9cb3e496 (diff)
Generalize collectInductionValues (#3031)
* Generalize collectInductionValues * Support affine transformations of loop index as induction variables * Test for generalized induction value collection * Neaten inductive variable finding * Store the type of implication success when finding inductive variables * Test that loop induction finding does not alway succeed * Support chains of additions and branches of additions in induction variable finding * Use c++17 for downstream compilers
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp405
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h3
-rw-r--r--source/slang/slang-ir.h12
3 files changed, 299 insertions, 121 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 6947fd7c5..898a86049 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1,6 +1,9 @@
#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ast-support-types.h"
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
namespace Slang
{
@@ -493,6 +496,219 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
return hoistInfo;
}
+struct ImplicationParams
+{
+ IRInst *condition, *induction, *block;
+ HashCode getHashCode() const { return getHashCodeBytewise(*this); }
+ // C++20: friend auto operator<=>(const ImplicationParams&, const ImplicationParams&) = default;
+ friend bool operator==(const ImplicationParams& x, const ImplicationParams& y)
+ {
+ return x.condition == y.condition && x.induction == y.induction && x.block == y.block;
+ }
+};
+
+struct ImplicationResult
+{
+ enum
+ {
+ // The value was not a constant offset from the induction variable and
+ // the condition variable was true
+ Falsified,
+ // The condition variable is false, so the value's relationship to the
+ // induction variable doesn't matter
+ AntecedentHolds,
+ // The value is a constant offset from the induction variable, stored
+ // in 'factor'
+ ConsequentHolds
+ } e;
+ IRIntegerValue factor;
+};
+
+static ImplicationResult join(const ImplicationResult& a, const ImplicationResult& b)
+{
+ if(a.e == ImplicationResult::Falsified || b.e == ImplicationResult::Falsified)
+ return {ImplicationResult::Falsified, 0};
+ if(a.e == ImplicationResult::AntecedentHolds)
+ return b;
+ if(b.e == ImplicationResult::AntecedentHolds)
+ return a;
+ if(a.factor != b.factor)
+ return {ImplicationResult::Falsified, 0};
+ return a;
+}
+
+static ImplicationResult inductionImplicationHolds(
+ Dictionary<ImplicationParams, ImplicationResult>& memo,
+ IRInst* const prevVal,
+ IRInst* const conditionVal,
+ IRInst* const inductiveVal,
+ IRBlock* const block);
+
+static bool unpackConstantAddition(IRInst* addOrSub, IRInst*& operand, IRIntegerValue& constant)
+{
+ if(addOrSub->getOp() != kIROp_Add && addOrSub->getOp() != kIROp_Sub)
+ return false;
+ const bool negate = addOrSub->getOp() == kIROp_Sub;
+
+ auto o = addOrSub->getOperand(0);
+ auto c = addOrSub->getOperand(1);
+ if(!as<IRIntLit>(c))
+ std::swap(o, c);
+ const auto cLit = as<IRIntLit>(c);
+ if(!cLit)
+ return false;
+ operand = o;
+ constant = cLit->getValue();
+ // Check that we can actually represent this!
+ if(negate && constant == std::numeric_limits<IRIntegerValue>::min())
+ return false;
+ constant *= negate ? -1 : 1;
+ return true;
+}
+
+// isAdditionOf(m, i, p, c, f) returns true if it can prove `c => i = p + f`
+// for some constant factor f
+static bool isAdditionOf(
+ Dictionary<ImplicationParams, ImplicationResult>& memo,
+ IRInst* inductiveVal,
+ IRInst* prevVal,
+ IRInst* conditionVal,
+ IRIntegerValue& factor)
+{
+ IRInst* operand;
+ IRIntegerValue constant;
+ if(!unpackConstantAddition(inductiveVal, operand, constant))
+ return false;
+
+ const auto impRes = inductionImplicationHolds(
+ memo,
+ prevVal,
+ conditionVal,
+ operand,
+ as<IRBlock>(inductiveVal->getParent()));
+
+ if(impRes.e == ImplicationResult::ConsequentHolds)
+ {
+ // TODO: Check for overflow here (strictly speaking it shouldn't
+ // matter in the end numerically (except that this could be UB)).
+ factor = impRes.factor + constant;
+ return true;
+ }
+ return false;
+}
+
+// Returns true if we can prove that in this block this value is
+// always false
+static bool isAlwaysFalseInBlock(IRInst* inst, IRBlock* block)
+{
+ const auto b = as<IRBoolLit>(inst);
+ if(b)
+ return !b->getValue();
+
+ // At the moment we just check that the predecessors of this
+ // block have us on the false path of a conditional branch on
+ // the instruction under question.
+ bool isFalse = true;
+ for(const auto predecessor : block->getPredecessors())
+ {
+ const auto predConditionalBranch = as<IRConditionalBranch>(predecessor->getTerminator());
+ isFalse &=
+ predConditionalBranch &&
+ predConditionalBranch->getCondition() == inst &&
+ predConditionalBranch->getFalseBlock() == block;
+ if(!isFalse)
+ break;
+ }
+ return isFalse;
+}
+
+// This function takes:
+// - A block with an unconditional branch with at least one parameter 'i'
+// - The index of 'i'
+// - A condition variable 'c'
+// - The predecessor case in the induction 'p'
+//
+// It returns true if at the time of the branch: 'isTrue(c) => isInductiveValue(i, p)'
+// It return false if it can't prove this implication holds.
+static ImplicationResult inductionImplicationHolds(
+ Dictionary<ImplicationParams, ImplicationResult>& memo,
+ IRInst* const prevVal,
+ IRInst* const conditionVal,
+ IRInst* const inductiveVal,
+ IRBlock* const block)
+{
+ // If we have a result memoized we can safely return that.
+ const ImplicationParams i = {conditionVal, inductiveVal, block};
+ const auto memoized = memo.tryGetValue(i);
+ if(memoized)
+ return *memoized;
+
+ // While we are detemining if the implication holds at this position we set
+ // the result to Falsified so as to fail if we require a self-referential
+ // proof
+ memo.add(i, {ImplicationResult::Falsified, 0});
+ // A helper to record the solution as we're returning
+ const auto andRemember = [&memo, i](ImplicationResult r) {
+ memo.set(i, r);
+ return r;
+ };
+
+ // Our most general solution is if the left hand side of the implication is
+ // false, in which case we can return success without specifying a factor
+ if(isAlwaysFalseInBlock(conditionVal, block))
+ return andRemember({ImplicationResult::AntecedentHolds, 0});
+
+ // Otherwise, we handle the additive case
+ // One easy case is that this *is* the previous value, in which case it's a
+ // trivial solution with an addition of 0
+ if(prevVal == inductiveVal)
+ return andRemember({ImplicationResult::ConsequentHolds, 0});
+
+ // Otherwise is it a function over the inductive variable
+ IRIntegerValue factor;
+ if(isAdditionOf(memo, inductiveVal, prevVal, conditionVal, factor))
+ return andRemember({ImplicationResult::ConsequentHolds, factor});
+
+ // The last thing to try is to consider the case where the
+ // inductive value under consideration is a parameter, in that case we can
+ // recurse into the predecessors of this block, replacing the parameter and
+ // condition variables with their arguments where appropriate.
+ const auto inductiveParam = as<IRParam>(inductiveVal);
+
+ // If it's not a parameter then we don't know how to continue
+ // (in principle we could also hadle instructions such as loads here)
+ if(!inductiveParam)
+ return {ImplicationResult::Falsified, 0};
+
+ const auto conditionParam = as<IRParam>(conditionVal);
+ const auto inductiveParamIndex = block->getParamIndex(inductiveParam);
+ const auto conditionParamIndex = block->getParamIndex(conditionParam);
+
+ // If we have no predecessors, then all the possible values (none) of the
+ // condition variable are false, so our antecedent holds
+ ImplicationResult res = {ImplicationResult::AntecedentHolds, 0};
+
+ for(const auto predecessor : block->getPredecessors())
+ {
+ const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator());
+ SLANG_ASSERT(inductiveParamIndex == -1 || predTerminator);
+ SLANG_ASSERT(conditionParamIndex == -1 || predTerminator);
+
+ const auto nextInductiveParam =
+ inductiveParamIndex == -1 ? inductiveParam : predTerminator->getArg(inductiveParamIndex);
+ const auto nextConditionParam =
+ conditionParamIndex == -1 ? conditionParam : predTerminator->getArg(conditionParamIndex);
+
+ const auto predResult = inductionImplicationHolds(memo, prevVal, nextConditionParam, nextInductiveParam, predecessor);
+ res = join(res, predResult);
+
+ if(res.e == ImplicationResult::Falsified)
+ break;
+ }
+
+ return andRemember(res);
+}
+
void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* func)
{
// Collect loop induction values.
@@ -536,139 +752,80 @@ void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode*
if (conditionParamIndex == -1)
continue;
- // Next, we try to identify the original induction variables, if they exist.
- // These are trickier, and we have to hard code the complex pattern that
- // we can recognize.
- // This pattern matching logic is ugly and fragile against changes to cfg
- // normalization, but it is the easiest way to do it right now.
- // Basically, we are looking for this pattern:
- // loop(..., i=initVal)
- // {
- // targetBlock:
- // ...
- // param int i;
- // param bool condition;
- // ...
- // branch condtionBlock;
- // conditionBlock:
- // if (condition)
- // {
- // }
- // else
- // {
- // break;
- // }
- // // ...
- // someBodyBlock:
- // ...
- // if (condition)
- // {
- // ...
- // // Check condition 1: i is used by an `add`
- // // Check condition 2: parent of (i+1) is a branch target of if(condition)
- // // Check condition 3: branches to parentBlock with i1 = i + 1.
- // goto parentBlock(i + 1);
- // }
- // else
- // goto parentBlock(other);
- // parentBlock:
- // // Check condition 4: parentBlock branches to finalBlock.
- // param int i1;
- // goto finalBlock;
- // finalBlock:
- // // Check condition 5: finalBlock branches to targetBlock with new i = i1.
- // goto loopHeader(i1);
- // }
+ // Next we try to identify any induction variables.
//
+ // An inductive parameter must:
+ // - Be initialized as anything from a single predecessor, the base
+ // case
+ // - Be passed a function of itself only on any other entries to
+ // the loop, the inductive case
+ //
+ // In terms of matching here, we allow the base case to be
+ // anything, and the inductive case to be the successor function
+ //
+ // We also handle the case where something other than the base or
+ // inductive case is passed to the top of the loop when the "condition
+ // parameter" is false, in which case the "non-inductive" value isn't
+ // actually used.
+
paramIndex = -1;
for (auto param : targetBlock->getParams())
{
paramIndex++;
- if (!param->getDataType())
+
+ const auto t = param->getDataType();
+ if (!t || !isScalarIntegerType(t))
continue;
- if (isScalarIntegerType(param->getDataType()))
+
+ // This *is* the loop counter!
+ if(param->findDecoration<IRLoopCounterDecoration>())
+ continue;
+
+ auto predecessors = targetBlock->getPredecessors();
+ Dictionary<ImplicationParams, ImplicationResult> memo;
+ ImplicationResult impRes = {ImplicationResult::AntecedentHolds, 0};
+ for(const auto predecessor : predecessors)
{
- // If the param is always equal to the loop index, we don't need to store it.
- IRInst* addUse = nullptr;
- for (auto use = param->firstUse; use && !addUse; use = use->nextUse)
- {
- auto user = use->getUser();
- if (user->getOp() != kIROp_Add)
- continue;
- auto intLit = as<IRIntLit>(use->getUser()->getOperand(1));
- if (!intLit)
- continue;
- if (intLit->getValue() != 1)
- continue;
-
- // The add inst's parent block is behind a `ifelse(loopCondition)`.
- auto addInstBlock = as<IRBlock>(user->getParent());
- if (!addInstBlock)
- continue;
- auto predecessors = addInstBlock->getPredecessors();
- if (predecessors.getCount() != 1)
- continue;
- auto parentIfElse = as<IRIfElse>(predecessors.b->getUser());
- if (!parentIfElse)
- continue;
- auto parentCondition = parentIfElse->getCondition();
-
- auto branch = as<IRUnconditionalBranch>(addInstBlock->getTerminator());
- if (!branch)
- continue;
-
- // The add inst should be used as a branchArg.
- UInt argIndex = 0;
- for (UInt i = 0; i < branch->getArgCount(); i++)
- {
- if (branch->getArg(i) == user)
- {
- addUse = user;
- argIndex = i;
- break;
- }
- }
- if (!addUse)
- continue;
- auto branchTarget1 = branch->getTargetBlock();
- auto branchParam = branchTarget1->getFirstParam();
- for (UInt i = 0; i < argIndex; i++)
- if (branchParam)
- branchParam = branchParam->getNextParam();
- if (!branchParam)
- continue;
-
- // The branchParam is used as argument to branch back to loop header.
- auto branch2 = as<IRUnconditionalBranch>(branchTarget1->getTerminator());
- if (!branch2)
- continue;
- if (branch2->getTargetBlock() != targetBlock)
- continue;
- argIndex = 0;
- for (UInt i = 0; i < branch2->getArgCount(); i++)
- {
- if (branch2->getArg(i) == branchParam)
- {
- argIndex = i;
- break;
- }
- }
- if (argIndex != (UInt)paramIndex)
- continue;
+ // Since this is branching with a parameter, it can only be an
+ // unconditional branch.
+ const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator());
+ SLANG_ASSERT(predTerminator);
+
+ // Is this the base case?
+ if(predTerminator == loopInst)
+ continue;
- // parentCondition is also used as the new condition in the back jump.
- if (conditionParamIndex < 0 || (UInt)conditionParamIndex >= branch2->getArgCount() ||
- branch2->getArg((UInt)conditionParamIndex) != parentCondition)
- continue;
+ const auto conditionArg = predTerminator->getArg(conditionParamIndex);
+ const auto inductiveArg = predTerminator->getArg(paramIndex);
+ // Check that the required implication holds for this block
+ const auto predRes = inductionImplicationHolds(memo, param, conditionArg, inductiveArg, predecessor);
+ impRes = join(impRes, predRes);
+ if(impRes.e == ImplicationResult::Falsified)
+ break;
+ }
+
+ switch(impRes.e)
+ {
+ // This wasn't an induction variable
+ case ImplicationResult::Falsified:
+ break;
+
+ // The loop doesn't loop (because in every case the break flag is
+ // true)
+ case ImplicationResult::AntecedentHolds:
+ break;
+
+ case ImplicationResult::ConsequentHolds:
+ {
// The use of the add inst matches all of our conditions as an induction value
- // that is equivalent to loop counter.
+ // that is a constant offset from a multiple of the loop counter.
LoopInductionValueInfo info;
- info.kind = LoopInductionValueInfo::Kind::EqualsToCounter;
+ info.kind = LoopInductionValueInfo::Kind::AffineFunctionOfCounter;
info.loopInst = loopInst;
info.counterOffset = loopInst->getArg(paramIndex);
+ info.counterFactor = impRes.factor;
inductionValueInsts[param] = info;
- break;
}
}
}
@@ -710,12 +867,20 @@ void applyToInst(
{
replacement = builder->getBoolValue(true);
}
- else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::EqualsToCounter)
+ else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::AffineFunctionOfCounter)
{
auto indexInfo = blockIndexInfo.tryGetValue(inductionValueInfo.loopInst->getTargetBlock());
SLANG_ASSERT(indexInfo);
SLANG_ASSERT(indexInfo->getCount() != 0);
replacement = indexInfo->getFirst().diffCountParam;
+ if (inductionValueInfo.counterFactor != 1)
+ {
+ setInsertAfterOrdinaryInst(builder, replacement);
+ replacement = builder->emitMul(
+ replacement->getDataType(),
+ replacement,
+ builder->getIntValue(replacement->getDataType(), inductionValueInfo.counterFactor));
+ }
if (inductionValueInfo.counterOffset)
{
setInsertAfterOrdinaryInst(builder, replacement);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index 59c70e862..96e4ea99b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -240,11 +240,12 @@ namespace Slang
enum Kind
{
AlwaysTrue,
- EqualsToCounter,
+ AffineFunctionOfCounter,
};
Kind kind;
IRLoop* loopInst = nullptr;
IRInst* counterOffset = nullptr;
+ IRIntegerValue counterFactor = 1;
};
// Information on which insts are to be stored, recomputed
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index e7cc0d6c8..b0d9bb109 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1208,6 +1208,18 @@ struct IRBlock : IRInst
getFirstParam(),
getLastParam());
}
+ // Linear in the parameter index, returns -1 if the param doesn't exist
+ Index getParamIndex(IRParam* const needle)
+ {
+ Index ret = 0;
+ for(const auto p : getParams())
+ {
+ if (p == needle)
+ return ret;
+ ret++;
+ }
+ return -1;
+ }
void addParam(IRParam* param);
void insertParamAtHead(IRParam* param);