summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-01 12:43:51 +0800
committerGitHub <noreply@github.com>2023-08-01 12:43:51 +0800
commitc34a7b6627d4c07531daf7d99dceaf7f89bd1c0a (patch)
tree36eef7ee055c3706bce32493f47fddb5c0af3a4f
parent5349241098076bead63f638daf2e4b9a9cb3e496 (diff)
Generalize collectInductionValues (#3031)
* Generalize collectInductionValues * Support affine transformations of loop index as induction variables * Test for generalized induction value collection * Neaten inductive variable finding * Store the type of implication success when finding inductive variables * Test that loop induction finding does not always succeed * Support chains of additions and branches of additions in induction variable finding * Use c++17 for downstream compilers
-rw-r--r--source/compiler-core/slang-gcc-compiler-util.cpp15
-rw-r--r--source/compiler-core/slang-nvrtc-compiler.cpp5
-rw-r--r--source/core/slang-hash.h7
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp405
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h3
-rw-r--r--source/slang/slang-ir.h12
-rw-r--r--tests/autodiff/long-loop-branching-addition.slang57
-rw-r--r--tests/autodiff/long-loop-branching-addition.slang.expected.txt2
-rw-r--r--tests/autodiff/long-loop-chained-addition.slang42
-rw-r--r--tests/autodiff/long-loop-chained-addition.slang.expected.txt2
-rw-r--r--tests/autodiff/long-loop-multiple.slang41
-rw-r--r--tests/autodiff/long-loop-multiple.slang.expected.txt2
-rw-r--r--tests/autodiff/long-loop-noninductive.slang40
-rw-r--r--tests/autodiff/long-loop-noninductive.slang.expected.txt2
-rw-r--r--tests/autodiff/long-while-loop.slang43
-rw-r--r--tests/autodiff/long-while-loop.slang.expected.txt2
-rw-r--r--tests/cpu-program/cpu-hello-world.slang2
17 files changed, 549 insertions, 133 deletions
diff --git a/source/compiler-core/slang-gcc-compiler-util.cpp b/source/compiler-core/slang-gcc-compiler-util.cpp
index f003ffb0f..c4354f40e 100644
--- a/source/compiler-core/slang-gcc-compiler-util.cpp
+++ b/source/compiler-core/slang-gcc-compiler-util.cpp
@@ -466,8 +466,8 @@ static SlangResult _parseGCCFamilyLine(SliceAllocator& allocator, const UnownedS
{
cmdLine.addArg("-fvisibility=hidden");
- // Need C++14 for partial specialization
- cmdLine.addArg("-std=c++14");
+ // C++17 since we share headers with slang itself (which uses c++17)
+ cmdLine.addArg("-std=c++17");
}
// TODO(JS): Here we always set -m32 on x86. It could be argued it is only necessary when creating a shared library
@@ -696,19 +696,16 @@ static SlangResult _parseGCCFamilyLine(SliceAllocator& allocator, const UnownedS
ComPtr<IDownstreamCompiler> compiler;
if (SLANG_SUCCEEDED(createCompiler(ExecutableLocation(path, "g++"), compiler)))
{
- // A downstream compiler for Slang must currently support C++14 - such that
+ // A downstream compiler for Slang must currently support C++17 - such that
// the prelude and generated code works.
//
- // The first version of gcc that supports `-std=c++14` is 5.0
+ // The first version of gcc that supports stable `-std=c++17` is 9.0
// https://gcc.gnu.org/projects/cxx-status.html
- //
- // If could be argued to allow C/C++ compilations via older versions through an older version
- // but that requires some more complex behavior, so we don't allow for now.
auto desc = compiler->getDesc();
- if (desc.version.m_major < 5)
+ if (desc.version.m_major < 9)
{
- // If the version isn't 5 or higher, we don't add this version of the compiler.
+ // If the version isn't 9 or higher, we don't add this version of the compiler.
return SLANG_OK;
}
diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp
index e2f3e678c..c756955ec 100644
--- a/source/compiler-core/slang-nvrtc-compiler.cpp
+++ b/source/compiler-core/slang-nvrtc-compiler.cpp
@@ -784,8 +784,9 @@ SlangResult NVRTCDownstreamCompiler::compile(const DownstreamCompileOptions& inO
// Neither of these options are strictly required, for general use of nvrtc,
// but are enabled to make use withing Slang work more smoothly
{
- // Require c++14, as makes initialization construction with {} available and so simplifies code generation
- cmdLine.addArg("-std=c++14");
+ // Require c++17, the default at the time of writing, since we share
+ // some functionality between slang itself and the compiled code
+ cmdLine.addArg("-std=c++17");
// Disable all warnings
// This is arguably too much - but nvrtc does not appear to have a mechanism to switch off individual warnings.
diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h
index 8c4fbac42..bc4b30ccc 100644
--- a/source/core/slang-hash.h
+++ b/source/core/slang-hash.h
@@ -164,6 +164,13 @@ namespace Slang
return PointerHash<std::is_pointer<TKey>::value>::getHashCode(key);
}
+ // Hashes the raw bytes of `t`. This is only meaningful for types where equal
+ // values always share a single byte pattern, which the static_assert enforces.
+ template<typename TKey>
+ HashCode getHashCodeBytewise(const TKey& t)
+ {
+     static_assert(std::has_unique_object_representations_v<TKey>);
+     const char* const rawBytes = reinterpret_cast<const char*>(&t);
+     return getHashCode(rawBytes, sizeof(TKey));
+ }
+
inline HashCode combineHash(HashCode left, HashCode right)
{
return (left * 16777619) ^ right;
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 6947fd7c5..898a86049 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1,6 +1,9 @@
#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ast-support-types.h"
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
+#include "slang-ir-util.h"
+#include "slang-ir.h"
namespace Slang
{
@@ -493,6 +496,219 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
return hoistInfo;
}
+// The triple of instructions used as the key when memoizing
+// inductionImplicationHolds.
+struct ImplicationParams
+{
+    IRInst* condition;
+    IRInst* induction;
+    IRInst* block;
+
+    // The members are hashed as one contiguous run of bytes
+    HashCode getHashCode() const { return getHashCodeBytewise(*this); }
+
+    // C++20: friend auto operator<=>(const ImplicationParams&, const ImplicationParams&) = default;
+    friend bool operator==(const ImplicationParams& lhs, const ImplicationParams& rhs)
+    {
+        if(lhs.condition != rhs.condition)
+            return false;
+        if(lhs.induction != rhs.induction)
+            return false;
+        return lhs.block == rhs.block;
+    }
+};
+
+// The outcome of trying to prove 'condition => value is a constant offset
+// from the induction variable'.
+struct ImplicationResult
+{
+    enum Kind
+    {
+        // The condition variable was true while the value was not a constant
+        // offset from the induction variable
+        Falsified,
+        // The condition variable is false, so how the value relates to the
+        // induction variable is irrelevant
+        AntecedentHolds,
+        // The value is a constant offset from the induction variable; that
+        // offset is stored in 'factor'
+        ConsequentHolds
+    };
+
+    Kind e;
+    IRIntegerValue factor;
+};
+
+// Combines the results obtained along two incoming paths:
+// - a falsified path falsifies the whole,
+// - a path on which the antecedent holds defers to the other path,
+// - two paths on which the consequent holds must agree on the factor.
+static ImplicationResult join(const ImplicationResult& a, const ImplicationResult& b)
+{
+    const ImplicationResult falsified = {ImplicationResult::Falsified, 0};
+
+    const bool eitherFalsified =
+        a.e == ImplicationResult::Falsified ||
+        b.e == ImplicationResult::Falsified;
+    if(eitherFalsified)
+        return falsified;
+
+    if(a.e == ImplicationResult::AntecedentHolds)
+        return b;
+    if(b.e == ImplicationResult::AntecedentHolds)
+        return a;
+
+    // Both are ConsequentHolds from here on
+    return a.factor == b.factor ? a : falsified;
+}
+
+static ImplicationResult inductionImplicationHolds(
+ Dictionary<ImplicationParams, ImplicationResult>& memo,
+ IRInst* const prevVal,
+ IRInst* const conditionVal,
+ IRInst* const inductiveVal,
+ IRBlock* const block);
+
+// Matches `x + k`, `k + x` or `x - k` where k is an integer literal.
+//
+// On success returns true, stores x in `operand` and the signed amount added
+// to x (k, or -k for a subtraction) in `constant`.
+// On failure returns false and leaves `operand` and `constant` untouched.
+//
+// `k - x` is deliberately rejected: subtraction does not commute, so it is not
+// a constant offset from x.
+static bool unpackConstantAddition(IRInst* addOrSub, IRInst*& operand, IRIntegerValue& constant)
+{
+    const auto op = addOrSub->getOp();
+    if(op != kIROp_Add && op != kIROp_Sub)
+        return false;
+    const bool negate = op == kIROp_Sub;
+
+    auto o = addOrSub->getOperand(0);
+    auto c = addOrSub->getOperand(1);
+    if(!as<IRIntLit>(c))
+    {
+        // Only an addition may have its literal on the left hand side
+        if(negate)
+            return false;
+        std::swap(o, c);
+    }
+    const auto cLit = as<IRIntLit>(c);
+    if(!cLit)
+        return false;
+
+    IRIntegerValue value = cLit->getValue();
+    if(negate)
+    {
+        // Check that we can actually represent this!
+        if(value == std::numeric_limits<IRIntegerValue>::min())
+            return false;
+        value = -value;
+    }
+
+    operand = o;
+    constant = value;
+    return true;
+}
+
+// isAdditionOf(m, i, p, c, f) returns true if it can prove `c => i = p + f`
+// for some constant f, in which case f is written to `factor`.
+//
+// It returns false, leaving `factor` untouched, when:
+// - `i` is not an addition or subtraction of an integer literal,
+// - `i` does not live directly in a block,
+// - the implication could not be shown for the non-literal operand of `i`, or
+// - the accumulated offset is not representable as an IRIntegerValue.
+static bool isAdditionOf(
+    Dictionary<ImplicationParams, ImplicationResult>& memo,
+    IRInst* inductiveVal,
+    IRInst* prevVal,
+    IRInst* conditionVal,
+    IRIntegerValue& factor)
+{
+    IRInst* operand = nullptr;
+    IRIntegerValue constant = 0;
+    if(!unpackConstantAddition(inductiveVal, operand, constant))
+        return false;
+
+    // The recursion walks the predecessors of the block holding the
+    // instruction, so without a parent block there is nothing to examine
+    const auto parentBlock = as<IRBlock>(inductiveVal->getParent());
+    if(!parentBlock)
+        return false;
+
+    const auto impRes = inductionImplicationHolds(
+        memo,
+        prevVal,
+        conditionVal,
+        operand,
+        parentBlock);
+
+    if(impRes.e != ImplicationResult::ConsequentHolds)
+        return false;
+
+    // Signed overflow is undefined behaviour, so give up on (rather than
+    // compute) a sum which doesn't fit
+    using Limits = std::numeric_limits<IRIntegerValue>;
+    if(constant > 0 && impRes.factor > Limits::max() - constant)
+        return false;
+    if(constant < 0 && impRes.factor < Limits::min() - constant)
+        return false;
+
+    factor = impRes.factor + constant;
+    return true;
+}
+
+// Returns true when `inst` can be shown to be false on every entry to `block`.
+//
+// A boolean literal is answered directly from its value. For anything else
+// the only evidence accepted is that every predecessor of `block` ends in a
+// conditional branch on `inst` whose false target is `block`. A block without
+// predecessors therefore yields true.
+static bool isAlwaysFalseInBlock(IRInst* inst, IRBlock* block)
+{
+    if(const auto literal = as<IRBoolLit>(inst))
+        return !literal->getValue();
+
+    for(const auto predecessor : block->getPredecessors())
+    {
+        const auto condBranch = as<IRConditionalBranch>(predecessor->getTerminator());
+        if(!condBranch)
+            return false;
+        if(condBranch->getCondition() != inst)
+            return false;
+        if(condBranch->getFalseBlock() != block)
+            return false;
+    }
+    return true;
+}
+
+// This function takes:
+// - A memo table of results already computed (or currently being computed),
+//   keyed on (condition, inductive value, block)
+// - The predecessor case in the induction 'p' (prevVal)
+// - A condition variable 'c' (conditionVal)
+// - The candidate inductive value 'i' (inductiveVal)
+// - The block at whose terminator the implication is considered
+//
+// It tries to prove that at the time of the branch:
+// 'isTrue(c) => isInductiveValue(i, p)', and returns
+// - ConsequentHolds, with the constant offset in 'factor', if i was shown to
+//   be p plus a constant
+// - AntecedentHolds if c was shown to be false
+// - Falsified if neither could be shown. This covers "can't prove it" as well
+//   as "it doesn't hold".
+static ImplicationResult inductionImplicationHolds(
+    Dictionary<ImplicationParams, ImplicationResult>& memo,
+    IRInst* const prevVal,
+    IRInst* const conditionVal,
+    IRInst* const inductiveVal,
+    IRBlock* const block)
+{
+    // Everything below inspects the block and its predecessors, so there is
+    // nothing we can prove without one
+    if(!block)
+        return {ImplicationResult::Falsified, 0};
+
+    // If we have a result memoized we can safely return that.
+    const ImplicationParams i = {conditionVal, inductiveVal, block};
+    const auto memoized = memo.tryGetValue(i);
+    if(memoized)
+        return *memoized;
+
+    // While we are determining if the implication holds at this position we set
+    // the result to Falsified so as to fail if we require a self-referential
+    // proof
+    memo.add(i, {ImplicationResult::Falsified, 0});
+    // A helper to record the solution as we're returning
+    const auto andRemember = [&memo, i](ImplicationResult r) {
+        memo.set(i, r);
+        return r;
+    };
+
+    // Our most general solution is if the left hand side of the implication is
+    // false, in which case we can return success without specifying a factor
+    if(isAlwaysFalseInBlock(conditionVal, block))
+        return andRemember({ImplicationResult::AntecedentHolds, 0});
+
+    // Otherwise, we handle the additive case
+    // One easy case is that this *is* the previous value, in which case it's a
+    // trivial solution with an addition of 0
+    if(prevVal == inductiveVal)
+        return andRemember({ImplicationResult::ConsequentHolds, 0});
+
+    // Otherwise is it a function over the inductive variable
+    IRIntegerValue factor;
+    if(isAdditionOf(memo, inductiveVal, prevVal, conditionVal, factor))
+        return andRemember({ImplicationResult::ConsequentHolds, factor});
+
+    // The last thing to try is to consider the case where the
+    // inductive value under consideration is a parameter, in that case we can
+    // recurse into the predecessors of this block, replacing the parameter and
+    // condition variables with their arguments where appropriate.
+    const auto inductiveParam = as<IRParam>(inductiveVal);
+
+    // If it's not a parameter then we don't know how to continue
+    // (in principle we could also handle instructions such as loads here)
+    // The memo already holds Falsified for this key, so there is nothing to
+    // record.
+    if(!inductiveParam)
+        return {ImplicationResult::Falsified, 0};
+
+    // An index of -1 means "not a parameter of this block"
+    // NOTE(review): when conditionVal is not an IRParam, conditionParam is
+    // null and that null (not conditionVal) is what is handed to the
+    // predecessors below — confirm this is intended.
+    const auto conditionParam = as<IRParam>(conditionVal);
+    const auto inductiveParamIndex = block->getParamIndex(inductiveParam);
+    const auto conditionParamIndex = block->getParamIndex(conditionParam);
+    const bool needsBranchArgs = inductiveParamIndex != -1 || conditionParamIndex != -1;
+
+    // If we have no predecessors, then all the possible values (none) of the
+    // condition variable are false, so our antecedent holds
+    ImplicationResult res = {ImplicationResult::AntecedentHolds, 0};
+
+    for(const auto predecessor : block->getPredecessors())
+    {
+        const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator());
+
+        // Arguments for our parameters are read off an unconditional branch
+        SLANG_ASSERT(!needsBranchArgs || predTerminator);
+        if(needsBranchArgs && !predTerminator)
+        {
+            // The assert may be compiled out; fail the proof instead of
+            // reading arguments through a null terminator
+            res = {ImplicationResult::Falsified, 0};
+            break;
+        }
+
+        const auto nextInductiveParam =
+            inductiveParamIndex == -1 ? inductiveParam : predTerminator->getArg(inductiveParamIndex);
+        const auto nextConditionParam =
+            conditionParamIndex == -1 ? conditionParam : predTerminator->getArg(conditionParamIndex);
+
+        const auto predResult = inductionImplicationHolds(memo, prevVal, nextConditionParam, nextInductiveParam, predecessor);
+        res = join(res, predResult);
+
+        // No later predecessor can undo a falsification
+        if(res.e == ImplicationResult::Falsified)
+            break;
+    }
+
+    return andRemember(res);
+}
+
void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* func)
{
// Collect loop induction values.
@@ -536,139 +752,80 @@ void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode*
if (conditionParamIndex == -1)
continue;
- // Next, we try to identify the original induction variables, if they exist.
- // These are trickier, and we have to hard code the complex pattern that
- // we can recognize.
- // This pattern matching logic is ugly and fragile against changes to cfg
- // normalization, but it is the easiest way to do it right now.
- // Basically, we are looking for this pattern:
- // loop(..., i=initVal)
- // {
- // targetBlock:
- // ...
- // param int i;
- // param bool condition;
- // ...
- // branch condtionBlock;
- // conditionBlock:
- // if (condition)
- // {
- // }
- // else
- // {
- // break;
- // }
- // // ...
- // someBodyBlock:
- // ...
- // if (condition)
- // {
- // ...
- // // Check condition 1: i is used by an `add`
- // // Check condition 2: parent of (i+1) is a branch target of if(condition)
- // // Check condition 3: branches to parentBlock with i1 = i + 1.
- // goto parentBlock(i + 1);
- // }
- // else
- // goto parentBlock(other);
- // parentBlock:
- // // Check condition 4: parentBlock branches to finalBlock.
- // param int i1;
- // goto finalBlock;
- // finalBlock:
- // // Check condition 5: finalBlock branches to targetBlock with new i = i1.
- // goto loopHeader(i1);
- // }
+ // Next we try to identify any induction variables.
//
+ // An inductive parameter must:
+ // - Be initialized as anything from a single predecessor, the base
+ // case
+ // - Be passed a function of itself only on any other entries to
+ // the loop, the inductive case
+ //
+ // In terms of matching here, we allow the base case to be
+ // anything, and the inductive case to be the successor function
+ //
+ // We also handle the case where something other than the base or
+ // inductive case is passed to the top of the loop when the "condition
+ // parameter" is false, in which case the "non-inductive" value isn't
+ // actually used.
+
paramIndex = -1;
for (auto param : targetBlock->getParams())
{
paramIndex++;
- if (!param->getDataType())
+
+ const auto t = param->getDataType();
+ if (!t || !isScalarIntegerType(t))
continue;
- if (isScalarIntegerType(param->getDataType()))
+
+ // This *is* the loop counter!
+ if(param->findDecoration<IRLoopCounterDecoration>())
+ continue;
+
+ auto predecessors = targetBlock->getPredecessors();
+ Dictionary<ImplicationParams, ImplicationResult> memo;
+ ImplicationResult impRes = {ImplicationResult::AntecedentHolds, 0};
+ for(const auto predecessor : predecessors)
{
- // If the param is always equal to the loop index, we don't need to store it.
- IRInst* addUse = nullptr;
- for (auto use = param->firstUse; use && !addUse; use = use->nextUse)
- {
- auto user = use->getUser();
- if (user->getOp() != kIROp_Add)
- continue;
- auto intLit = as<IRIntLit>(use->getUser()->getOperand(1));
- if (!intLit)
- continue;
- if (intLit->getValue() != 1)
- continue;
-
- // The add inst's parent block is behind a `ifelse(loopCondition)`.
- auto addInstBlock = as<IRBlock>(user->getParent());
- if (!addInstBlock)
- continue;
- auto predecessors = addInstBlock->getPredecessors();
- if (predecessors.getCount() != 1)
- continue;
- auto parentIfElse = as<IRIfElse>(predecessors.b->getUser());
- if (!parentIfElse)
- continue;
- auto parentCondition = parentIfElse->getCondition();
-
- auto branch = as<IRUnconditionalBranch>(addInstBlock->getTerminator());
- if (!branch)
- continue;
-
- // The add inst should be used as a branchArg.
- UInt argIndex = 0;
- for (UInt i = 0; i < branch->getArgCount(); i++)
- {
- if (branch->getArg(i) == user)
- {
- addUse = user;
- argIndex = i;
- break;
- }
- }
- if (!addUse)
- continue;
- auto branchTarget1 = branch->getTargetBlock();
- auto branchParam = branchTarget1->getFirstParam();
- for (UInt i = 0; i < argIndex; i++)
- if (branchParam)
- branchParam = branchParam->getNextParam();
- if (!branchParam)
- continue;
-
- // The branchParam is used as argument to branch back to loop header.
- auto branch2 = as<IRUnconditionalBranch>(branchTarget1->getTerminator());
- if (!branch2)
- continue;
- if (branch2->getTargetBlock() != targetBlock)
- continue;
- argIndex = 0;
- for (UInt i = 0; i < branch2->getArgCount(); i++)
- {
- if (branch2->getArg(i) == branchParam)
- {
- argIndex = i;
- break;
- }
- }
- if (argIndex != (UInt)paramIndex)
- continue;
+ // Since this is branching with a parameter, it can only be an
+ // unconditional branch.
+ const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator());
+ SLANG_ASSERT(predTerminator);
+
+ // Is this the base case?
+ if(predTerminator == loopInst)
+ continue;
- // parentCondition is also used as the new condition in the back jump.
- if (conditionParamIndex < 0 || (UInt)conditionParamIndex >= branch2->getArgCount() ||
- branch2->getArg((UInt)conditionParamIndex) != parentCondition)
- continue;
+ const auto conditionArg = predTerminator->getArg(conditionParamIndex);
+ const auto inductiveArg = predTerminator->getArg(paramIndex);
+ // Check that the required implication holds for this block
+ const auto predRes = inductionImplicationHolds(memo, param, conditionArg, inductiveArg, predecessor);
+ impRes = join(impRes, predRes);
+ if(impRes.e == ImplicationResult::Falsified)
+ break;
+ }
+
+ switch(impRes.e)
+ {
+ // This wasn't an induction variable
+ case ImplicationResult::Falsified:
+ break;
+
+ // The loop doesn't loop (because in every case the break flag is
+ // true)
+ case ImplicationResult::AntecedentHolds:
+ break;
+
+ case ImplicationResult::ConsequentHolds:
+ {
// The use of the add inst matches all of our conditions as an induction value
- // that is equivalent to loop counter.
+ // that is a constant offset from a multiple of the loop counter.
LoopInductionValueInfo info;
- info.kind = LoopInductionValueInfo::Kind::EqualsToCounter;
+ info.kind = LoopInductionValueInfo::Kind::AffineFunctionOfCounter;
info.loopInst = loopInst;
info.counterOffset = loopInst->getArg(paramIndex);
+ info.counterFactor = impRes.factor;
inductionValueInsts[param] = info;
- break;
}
}
}
@@ -710,12 +867,20 @@ void applyToInst(
{
replacement = builder->getBoolValue(true);
}
- else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::EqualsToCounter)
+ else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::AffineFunctionOfCounter)
{
auto indexInfo = blockIndexInfo.tryGetValue(inductionValueInfo.loopInst->getTargetBlock());
SLANG_ASSERT(indexInfo);
SLANG_ASSERT(indexInfo->getCount() != 0);
replacement = indexInfo->getFirst().diffCountParam;
+ if (inductionValueInfo.counterFactor != 1)
+ {
+ setInsertAfterOrdinaryInst(builder, replacement);
+ replacement = builder->emitMul(
+ replacement->getDataType(),
+ replacement,
+ builder->getIntValue(replacement->getDataType(), inductionValueInfo.counterFactor));
+ }
if (inductionValueInfo.counterOffset)
{
setInsertAfterOrdinaryInst(builder, replacement);
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index 59c70e862..96e4ea99b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -240,11 +240,12 @@ namespace Slang
enum Kind
{
AlwaysTrue,
- EqualsToCounter,
+ AffineFunctionOfCounter,
};
Kind kind;
IRLoop* loopInst = nullptr;
IRInst* counterOffset = nullptr;
+ IRIntegerValue counterFactor = 1;
};
// Information on which insts are to be stored, recomputed
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index e7cc0d6c8..b0d9bb109 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1208,6 +1208,18 @@ struct IRBlock : IRInst
getFirstParam(),
getLastParam());
}
+ // Returns the zero-based position of `needle` among this block's
+ // parameters, or -1 if it isn't one of them. The parameters are scanned
+ // in order, so the cost is linear in the returned index.
+ Index getParamIndex(IRParam* const needle)
+ {
+     Index position = -1;
+     for(const auto candidate : getParams())
+     {
+         ++position;
+         if (candidate == needle)
+             return position;
+     }
+     return -1;
+ }
void addParam(IRParam* param);
void insertParamAtHead(IRParam* param);
diff --git a/tests/autodiff/long-loop-branching-addition.slang b/tests/autodiff/long-loop-branching-addition.slang
new file mode 100644
index 000000000..beb371bd0
--- /dev/null
+++ b/tests/autodiff/long-loop-branching-addition.slang
@@ -0,0 +1,57 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float sin_series(float x, int iterations)
+{
+ float result = x;
+ float term = x;
+ int i = 1; // counter under test: every path through the loop body nets i += 1
+ [MaxIters(30)]
+ do
+ {
+ term *= -1.0f * x * x / ((2 * i) * (2 * i + 1));
+ result += term;
+ if(result > 1000000)
+ {
+ i += 1; // first +1 on this path
+ if(result > 2000000)
+ {
+ term += 1;
+ i += 1; // second +1: +2 so far
+ }
+ else
+ {
+ i += 1; // second +1: +2 so far
+ }
+ }
+ else
+ {
+ i += 2; // +2 in a single step on this path
+ }
+ i += -1; // shared -1: whichever branch ran, the iteration advanced i by exactly 1
+ } while (i < iterations);
+ return result;
+}
+
+// Check that the intermediate context of sin_series does not have an array for `i`.
+// This test in particular checks that we can identify induction variables
+// through branching control flow
+
+// CHECK: struct s_bwd_sin_series_Intermediates
+// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}]
+// CHECK: }
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var x = diffPair(float.getPi(), 1.0); // presumably (primal, differential) order, i.e. primal = pi — confirm
+
+ __bwd_diff(sin_series)(x, 30, 1.0f); // iterations = 30; 1.0f is presumably the incoming output gradient — confirm
+
+ outputBuffer[0] = x.d; // -1.0 (the value recorded in this test's .expected.txt)
+}
diff --git a/tests/autodiff/long-loop-branching-addition.slang.expected.txt b/tests/autodiff/long-loop-branching-addition.slang.expected.txt
new file mode 100644
index 000000000..77bdc5ea4
--- /dev/null
+++ b/tests/autodiff/long-loop-branching-addition.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+-1.0
diff --git a/tests/autodiff/long-loop-chained-addition.slang b/tests/autodiff/long-loop-chained-addition.slang
new file mode 100644
index 000000000..8f75744a9
--- /dev/null
+++ b/tests/autodiff/long-loop-chained-addition.slang
@@ -0,0 +1,42 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float sin_series(float x, int iterations)
+{
+ float result = x;
+ float term = x;
+ int i = 1; // counter under test: updated three times per iteration, netting i += 1
+ [MaxIters(30)]
+ do
+ {
+ term *= -1.0f * x * x / ((2 * i) * (2 * i + 1));
+ i += 2; // +2
+ i++; // +3 so far
+ result += term;
+ i -= 2; // back down to a net +1 for the iteration
+ } while (i < iterations);
+ return result;
+}
+
+// Check that the intermediate context of sin_series does not have an array for `i`.
+// This test in particular checks that we can identify induction variables with
+// more than one operation applied to them during the loop
+
+// CHECK: struct s_bwd_sin_series_Intermediates
+// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}]
+// CHECK: }
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var x = diffPair(float.getPi(), 1.0); // presumably (primal, differential) order, i.e. primal = pi — confirm
+
+ __bwd_diff(sin_series)(x, 30, 1.0f); // iterations = 30; 1.0f is presumably the incoming output gradient — confirm
+
+ outputBuffer[0] = x.d; // -1.0 (the value recorded in this test's .expected.txt)
+}
diff --git a/tests/autodiff/long-loop-chained-addition.slang.expected.txt b/tests/autodiff/long-loop-chained-addition.slang.expected.txt
new file mode 100644
index 000000000..77bdc5ea4
--- /dev/null
+++ b/tests/autodiff/long-loop-chained-addition.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+-1.0
diff --git a/tests/autodiff/long-loop-multiple.slang b/tests/autodiff/long-loop-multiple.slang
new file mode 100644
index 000000000..a696beccf
--- /dev/null
+++ b/tests/autodiff/long-loop-multiple.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float sin_series(float x, int iterations)
+{
+ float result = x;
+ float term = x;
+ [MaxIters(30)]
+ for (int i = 1; i < iterations * 10; i += 10) // i starts at 1 and steps by 10, so it is 1 + 10 * (iterations completed)
+ {
+ term *= -1.0f * x * x / ((2 * i / 10 + 1) * (2 * i / 10 + 2));
+ result += term;
+ }
+ return result;
+}
+
+// Check that the intermediate context of sin_series does not have an array for `i`.
+// This test differs from ./long-loop.slang in that the loop counter is
+// relative to a multiple of the loop iteration
+
+// CHECK: struct s_bwd_sin_series_Intermediates
+// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}]
+// CHECK: }
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var x = diffPair(float.getPi(), 1.0); // presumably (primal, differential) order, i.e. primal = pi — confirm
+
+ __bwd_diff(sin_series)(x, 30, 1.0f); // iterations = 30; 1.0f is presumably the incoming output gradient — confirm
+
+ outputBuffer[0] = x.d; // -1.0 (the value recorded in this test's .expected.txt)
+}
+
+
+
diff --git a/tests/autodiff/long-loop-multiple.slang.expected.txt b/tests/autodiff/long-loop-multiple.slang.expected.txt
new file mode 100644
index 000000000..77bdc5ea4
--- /dev/null
+++ b/tests/autodiff/long-loop-multiple.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+-1.0
diff --git a/tests/autodiff/long-loop-noninductive.slang b/tests/autodiff/long-loop-noninductive.slang
new file mode 100644
index 000000000..bfd37c4f2
--- /dev/null
+++ b/tests/autodiff/long-loop-noninductive.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float sin_series(float x, int iterations)
+{
+ float result = x;
+ float term = x;
+ [MaxIters(35)]
+ for (int i = 1; i < iterations; i++)
+ {
+ if(i == 32) // value-dependent guard, not the loop condition
+ i += 1; // extra bump on this path only, so the per-iteration step of i is not one fixed constant
+ term *= -1.0f * x * x / ((2 * i) * (2 * i + 1));
+ result += term;
+ }
+ return result;
+}
+
+// Check that the intermediate context of sin_series still has an array for
+// `i`. This test checks that the induction variable finder doesn't
+// accidentally succeed all the time
+
+// CHECK: struct s_bwd_sin_series_Intermediates
+// CHECK: int {{[A-Za-z0-9_]+}}[{{.*}}]
+// CHECK: }
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var x = diffPair(float.getPi(), 1.0); // presumably (primal, differential) order, i.e. primal = pi — confirm
+
+ __bwd_diff(sin_series)(x, 30, 1.0f); // iterations = 30; 1.0f is presumably the incoming output gradient — confirm
+
+ outputBuffer[0] = x.d; // -1.0 (the value recorded in this test's .expected.txt)
+}
diff --git a/tests/autodiff/long-loop-noninductive.slang.expected.txt b/tests/autodiff/long-loop-noninductive.slang.expected.txt
new file mode 100644
index 000000000..77bdc5ea4
--- /dev/null
+++ b/tests/autodiff/long-loop-noninductive.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+-1.0
diff --git a/tests/autodiff/long-while-loop.slang b/tests/autodiff/long-while-loop.slang
new file mode 100644
index 000000000..20d802e2a
--- /dev/null
+++ b/tests/autodiff/long-while-loop.slang
@@ -0,0 +1,43 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[BackwardDifferentiable]
+float sin_series(float x, int iterations)
+{
+ float result = x;
+ float term = x;
+ int i = 1; // counter under test, declared outside the do/while
+ [MaxIters(30)]
+ do
+ {
+ term *= -1.0f * x * x / ((2 * i) * (2 * i + 1));
+ result += term;
+ i++; // the only update to i: +1 per iteration
+ } while (i < iterations);
+ return result;
+}
+
+// Check that the intermediate context of sin_series does not have an array for `i`.
+// This differs from ./long-loop.slang in that it uses an equivalent do/while
+// loop; this test checks that induction variables are still correctly identified.
+
+// CHECK: struct s_bwd_sin_series_Intermediates
+// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}]
+// CHECK: }
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ var x = diffPair(float.getPi(), 1.0); // presumably (primal, differential) order, i.e. primal = pi — confirm
+
+ __bwd_diff(sin_series)(x, 30, 1.0f); // iterations = 30; 1.0f is presumably the incoming output gradient — confirm
+
+ outputBuffer[0] = x.d; // -1.0 (the value recorded in this test's .expected.txt)
+}
+
+
+
diff --git a/tests/autodiff/long-while-loop.slang.expected.txt b/tests/autodiff/long-while-loop.slang.expected.txt
new file mode 100644
index 000000000..77bdc5ea4
--- /dev/null
+++ b/tests/autodiff/long-while-loop.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+-1.0
diff --git a/tests/cpu-program/cpu-hello-world.slang b/tests/cpu-program/cpu-hello-world.slang
index f91e354bc..f1285f889 100644
--- a/tests/cpu-program/cpu-hello-world.slang
+++ b/tests/cpu-program/cpu-hello-world.slang
@@ -4,4 +4,4 @@ public __extern_cpp int main()
{
printf("Hello World.\n");
return 0;
-} \ No newline at end of file
+}