diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-08-01 12:43:51 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-01 12:43:51 +0800 |
| commit | c34a7b6627d4c07531daf7d99dceaf7f89bd1c0a (patch) | |
| tree | 36eef7ee055c3706bce32493f47fddb5c0af3a4f | |
| parent | 5349241098076bead63f638daf2e4b9a9cb3e496 (diff) | |
Generalize collectInductionValues (#3031)
* Generalize collectInductionValues
* Support affine transformations of loop index as induction variables
* Test for generalized induction value collection
* Neaten inductive variable finding
* Store the type of implication success when finding inductive variables
* Test that loop induction finding does not alway succeed
* Support chains of additions and branches of additions in induction variable finding
* Use c++17 for downstream compilers
| -rw-r--r-- | source/compiler-core/slang-gcc-compiler-util.cpp | 15 | ||||
| -rw-r--r-- | source/compiler-core/slang-nvrtc-compiler.cpp | 5 | ||||
| -rw-r--r-- | source/core/slang-hash.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 405 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 12 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-branching-addition.slang | 57 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-branching-addition.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-chained-addition.slang | 42 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-chained-addition.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-multiple.slang | 41 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-multiple.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-noninductive.slang | 40 | ||||
| -rw-r--r-- | tests/autodiff/long-loop-noninductive.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/autodiff/long-while-loop.slang | 43 | ||||
| -rw-r--r-- | tests/autodiff/long-while-loop.slang.expected.txt | 2 | ||||
| -rw-r--r-- | tests/cpu-program/cpu-hello-world.slang | 2 |
17 files changed, 549 insertions, 133 deletions
diff --git a/source/compiler-core/slang-gcc-compiler-util.cpp b/source/compiler-core/slang-gcc-compiler-util.cpp index f003ffb0f..c4354f40e 100644 --- a/source/compiler-core/slang-gcc-compiler-util.cpp +++ b/source/compiler-core/slang-gcc-compiler-util.cpp @@ -466,8 +466,8 @@ static SlangResult _parseGCCFamilyLine(SliceAllocator& allocator, const UnownedS { cmdLine.addArg("-fvisibility=hidden"); - // Need C++14 for partial specialization - cmdLine.addArg("-std=c++14"); + // C++17 since we share headers with slang itself (which uses c++17) + cmdLine.addArg("-std=c++17"); } // TODO(JS): Here we always set -m32 on x86. It could be argued it is only necessary when creating a shared library @@ -696,19 +696,16 @@ static SlangResult _parseGCCFamilyLine(SliceAllocator& allocator, const UnownedS ComPtr<IDownstreamCompiler> compiler; if (SLANG_SUCCEEDED(createCompiler(ExecutableLocation(path, "g++"), compiler))) { - // A downstream compiler for Slang must currently support C++14 - such that + // A downstream compiler for Slang must currently support C++17 - such that // the prelude and generated code works. // - // The first version of gcc that supports `-std=c++14` is 5.0 + // The first version of gcc that supports stable `-std=c++17` is 9.0 // https://gcc.gnu.org/projects/cxx-status.html - // - // If could be argued to allow C/C++ compilations via older versions through an older version - // but that requires some more complex behavior, so we don't allow for now. auto desc = compiler->getDesc(); - if (desc.version.m_major < 5) + if (desc.version.m_major < 9) { - // If the version isn't 5 or higher, we don't add this version of the compiler. + // If the version isn't 9 or higher, we don't add this version of the compiler. return SLANG_OK; } diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp index e2f3e678c..c756955ec 100644 --- a/source/compiler-core/slang-nvrtc-compiler.cpp +++ b/source/compiler-core/slang-nvrtc-compiler.cpp @@ -784,8 +784,9 @@ SlangResult NVRTCDownstreamCompiler::compile(const DownstreamCompileOptions& inO // Neither of these options are strictly required, for general use of nvrtc, // but are enabled to make use withing Slang work more smoothly { - // Require c++14, as makes initialization construction with {} available and so simplifies code generation - cmdLine.addArg("-std=c++14"); + // Require c++17, the default at the time of writing, since we share + // some functionality between slang itself and the compiled code + cmdLine.addArg("-std=c++17"); // Disable all warnings // This is arguably too much - but nvrtc does not appear to have a mechanism to switch off individual warnings. diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h index 8c4fbac42..bc4b30ccc 100644 --- a/source/core/slang-hash.h +++ b/source/core/slang-hash.h @@ -164,6 +164,13 @@ namespace Slang return PointerHash<std::is_pointer<TKey>::value>::getHashCode(key); } + template<typename TKey> + HashCode getHashCodeBytewise(const TKey& t) + { + static_assert(std::has_unique_object_representations_v<TKey>); + return getHashCode(reinterpret_cast<const char*>(&t), sizeof(TKey)); + } + inline HashCode combineHash(HashCode left, HashCode right) { return (left * 16777619) ^ right; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 6947fd7c5..898a86049 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1,6 +1,9 @@ #include "slang-ir-autodiff-primal-hoist.h" +#include "slang-ast-support-types.h" #include "slang-ir-autodiff-region.h" #include "slang-ir-simplify-cfg.h" +#include "slang-ir-util.h" +#include "slang-ir.h" namespace Slang { @@ -493,6 +496,219 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( return hoistInfo; } +struct ImplicationParams +{ + IRInst *condition, *induction, *block; + HashCode getHashCode() const { return getHashCodeBytewise(*this); } + // C++20: friend auto operator<=>(const ImplicationParams&, const ImplicationParams&) = default; + friend bool operator==(const ImplicationParams& x, const ImplicationParams& y) + { + return x.condition == y.condition && x.induction == y.induction && x.block == y.block; + } +}; + +struct ImplicationResult +{ + enum + { + // The value was not a constant offset from the induction variable and + // the condition variable was true + Falsified, + // The condition variable is false, so the value's relationship to the + // induction variable doesn't matter + AntecedentHolds, + // The value is a constant offset from the induction variable, stored + // in 'factor' + ConsequentHolds + } e; + IRIntegerValue factor; +}; + +static ImplicationResult join(const ImplicationResult& a, const ImplicationResult& b) +{ + if(a.e == ImplicationResult::Falsified || b.e == ImplicationResult::Falsified) + return {ImplicationResult::Falsified, 0}; + if(a.e == ImplicationResult::AntecedentHolds) + return b; + if(b.e == ImplicationResult::AntecedentHolds) + return a; + if(a.factor != b.factor) + return {ImplicationResult::Falsified, 0}; + return a; +} + +static ImplicationResult inductionImplicationHolds( + Dictionary<ImplicationParams, ImplicationResult>& memo, + IRInst* const prevVal, + IRInst* const conditionVal, + IRInst* const inductiveVal, + IRBlock* const block); + +static bool unpackConstantAddition(IRInst* addOrSub, IRInst*& operand, IRIntegerValue& constant) +{ + if(addOrSub->getOp() != kIROp_Add && addOrSub->getOp() != kIROp_Sub) + return false; + const bool negate = addOrSub->getOp() == kIROp_Sub; + + auto o = addOrSub->getOperand(0); + auto c = addOrSub->getOperand(1); + if(!as<IRIntLit>(c)) + std::swap(o, c); + const auto cLit = as<IRIntLit>(c); + if(!cLit) + return false; + operand = o; + constant = cLit->getValue(); + // Check that we can actually represent this! + if(negate && constant == std::numeric_limits<IRIntegerValue>::min()) + return false; + constant *= negate ? -1 : 1; + return true; +} + +// isAdditionOf(m, i, p, c, f) returns true if it can prove `c => i = p + f` +// for some constant factor f +static bool isAdditionOf( + Dictionary<ImplicationParams, ImplicationResult>& memo, + IRInst* inductiveVal, + IRInst* prevVal, + IRInst* conditionVal, + IRIntegerValue& factor) +{ + IRInst* operand; + IRIntegerValue constant; + if(!unpackConstantAddition(inductiveVal, operand, constant)) + return false; + + const auto impRes = inductionImplicationHolds( + memo, + prevVal, + conditionVal, + operand, + as<IRBlock>(inductiveVal->getParent())); + + if(impRes.e == ImplicationResult::ConsequentHolds) + { + // TODO: Check for overflow here (strictly speaking it shouldn't + // matter in the end numerically (except that this could be UB)). + factor = impRes.factor + constant; + return true; + } + return false; +} + +// Returns true if we can prove that in this block this value is +// always false +static bool isAlwaysFalseInBlock(IRInst* inst, IRBlock* block) +{ + const auto b = as<IRBoolLit>(inst); + if(b) + return !b->getValue(); + + // At the moment we just check that the predecessors of this + // block have us on the false path of a conditional branch on + // the instruction under question. + bool isFalse = true; + for(const auto predecessor : block->getPredecessors()) + { + const auto predConditionalBranch = as<IRConditionalBranch>(predecessor->getTerminator()); + isFalse &= + predConditionalBranch && + predConditionalBranch->getCondition() == inst && + predConditionalBranch->getFalseBlock() == block; + if(!isFalse) + break; + } + return isFalse; +} + +// This function takes: +// - A block with an unconditional branch with at least one parameter 'i' +// - The index of 'i' +// - A condition variable 'c' +// - The predecessor case in the induction 'p' +// +// It returns true if at the time of the branch: 'isTrue(c) => isInductiveValue(i, p)' +// It return false if it can't prove this implication holds. +static ImplicationResult inductionImplicationHolds( + Dictionary<ImplicationParams, ImplicationResult>& memo, + IRInst* const prevVal, + IRInst* const conditionVal, + IRInst* const inductiveVal, + IRBlock* const block) +{ + // If we have a result memoized we can safely return that. + const ImplicationParams i = {conditionVal, inductiveVal, block}; + const auto memoized = memo.tryGetValue(i); + if(memoized) + return *memoized; + + // While we are detemining if the implication holds at this position we set + // the result to Falsified so as to fail if we require a self-referential + // proof + memo.add(i, {ImplicationResult::Falsified, 0}); + // A helper to record the solution as we're returning + const auto andRemember = [&memo, i](ImplicationResult r) { + memo.set(i, r); + return r; + }; + + // Our most general solution is if the left hand side of the implication is + // false, in which case we can return success without specifying a factor + if(isAlwaysFalseInBlock(conditionVal, block)) + return andRemember({ImplicationResult::AntecedentHolds, 0}); + + // Otherwise, we handle the additive case + // One easy case is that this *is* the previous value, in which case it's a + // trivial solution with an addition of 0 + if(prevVal == inductiveVal) + return andRemember({ImplicationResult::ConsequentHolds, 0}); + + // Otherwise is it a function over the inductive variable + IRIntegerValue factor; + if(isAdditionOf(memo, inductiveVal, prevVal, conditionVal, factor)) + return andRemember({ImplicationResult::ConsequentHolds, factor}); + + // The last thing to try is to consider the case where the + // inductive value under consideration is a parameter, in that case we can + // recurse into the predecessors of this block, replacing the parameter and + // condition variables with their arguments where appropriate. + const auto inductiveParam = as<IRParam>(inductiveVal); + + // If it's not a parameter then we don't know how to continue + // (in principle we could also hadle instructions such as loads here) + if(!inductiveParam) + return {ImplicationResult::Falsified, 0}; + + const auto conditionParam = as<IRParam>(conditionVal); + const auto inductiveParamIndex = block->getParamIndex(inductiveParam); + const auto conditionParamIndex = block->getParamIndex(conditionParam); + + // If we have no predecessors, then all the possible values (none) of the + // condition variable are false, so our antecedent holds + ImplicationResult res = {ImplicationResult::AntecedentHolds, 0}; + + for(const auto predecessor : block->getPredecessors()) + { + const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator()); + SLANG_ASSERT(inductiveParamIndex == -1 || predTerminator); + SLANG_ASSERT(conditionParamIndex == -1 || predTerminator); + + const auto nextInductiveParam = + inductiveParamIndex == -1 ? inductiveParam : predTerminator->getArg(inductiveParamIndex); + const auto nextConditionParam = + conditionParamIndex == -1 ? conditionParam : predTerminator->getArg(conditionParamIndex); + + const auto predResult = inductionImplicationHolds(memo, prevVal, nextConditionParam, nextInductiveParam, predecessor); + res = join(res, predResult); + + if(res.e == ImplicationResult::Falsified) + break; + } + + return andRemember(res); +} + void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* func) { // Collect loop induction values. @@ -536,139 +752,80 @@ void AutodiffCheckpointPolicyBase::collectInductionValues(IRGlobalValueWithCode* if (conditionParamIndex == -1) continue; - // Next, we try to identify the original induction variables, if they exist. - // These are trickier, and we have to hard code the complex pattern that - // we can recognize. - // This pattern matching logic is ugly and fragile against changes to cfg - // normalization, but it is the easiest way to do it right now. - // Basically, we are looking for this pattern: - // loop(..., i=initVal) - // { - // targetBlock: - // ... - // param int i; - // param bool condition; - // ... - // branch condtionBlock; - // conditionBlock: - // if (condition) - // { - // } - // else - // { - // break; - // } - // // ... - // someBodyBlock: - // ... - // if (condition) - // { - // ... - // // Check condition 1: i is used by an `add` - // // Check condition 2: parent of (i+1) is a branch target of if(condition) - // // Check condition 3: branches to parentBlock with i1 = i + 1. - // goto parentBlock(i + 1); - // } - // else - // goto parentBlock(other); - // parentBlock: - // // Check condition 4: parentBlock branches to finalBlock. - // param int i1; - // goto finalBlock; - // finalBlock: - // // Check condition 5: finalBlock branches to targetBlock with new i = i1. - // goto loopHeader(i1); - // } + // Next we try to identify any induction variables. // + // An inductive parameter must: + // - Be initialized as anything from a single predecessor, the base + // case + // - Be passed a function of itself only on any other entries to + // the loop, the inductive case + // + // In terms of matching here, we allow the base case to be + // anything, and the inductive case to be the successor function + // + // We also handle the case where something other than the base or + // inductive case is passed to the top of the loop when the "condition + // parameter" is false, in which case the "non-inductive" value isn't + // actually used. + paramIndex = -1; for (auto param : targetBlock->getParams()) { paramIndex++; - if (!param->getDataType()) + + const auto t = param->getDataType(); + if (!t || !isScalarIntegerType(t)) continue; - if (isScalarIntegerType(param->getDataType())) + + // This *is* the loop counter! + if(param->findDecoration<IRLoopCounterDecoration>()) + continue; + + auto predecessors = targetBlock->getPredecessors(); + Dictionary<ImplicationParams, ImplicationResult> memo; + ImplicationResult impRes = {ImplicationResult::AntecedentHolds, 0}; + for(const auto predecessor : predecessors) { - // If the param is always equal to the loop index, we don't need to store it. - IRInst* addUse = nullptr; - for (auto use = param->firstUse; use && !addUse; use = use->nextUse) - { - auto user = use->getUser(); - if (user->getOp() != kIROp_Add) - continue; - auto intLit = as<IRIntLit>(use->getUser()->getOperand(1)); - if (!intLit) - continue; - if (intLit->getValue() != 1) - continue; - - // The add inst's parent block is behind a `ifelse(loopCondition)`. - auto addInstBlock = as<IRBlock>(user->getParent()); - if (!addInstBlock) - continue; - auto predecessors = addInstBlock->getPredecessors(); - if (predecessors.getCount() != 1) - continue; - auto parentIfElse = as<IRIfElse>(predecessors.b->getUser()); - if (!parentIfElse) - continue; - auto parentCondition = parentIfElse->getCondition(); - - auto branch = as<IRUnconditionalBranch>(addInstBlock->getTerminator()); - if (!branch) - continue; - - // The add inst should be used as a branchArg. - UInt argIndex = 0; - for (UInt i = 0; i < branch->getArgCount(); i++) - { - if (branch->getArg(i) == user) - { - addUse = user; - argIndex = i; - break; - } - } - if (!addUse) - continue; - auto branchTarget1 = branch->getTargetBlock(); - auto branchParam = branchTarget1->getFirstParam(); - for (UInt i = 0; i < argIndex; i++) - if (branchParam) - branchParam = branchParam->getNextParam(); - if (!branchParam) - continue; - - // The branchParam is used as argument to branch back to loop header. - auto branch2 = as<IRUnconditionalBranch>(branchTarget1->getTerminator()); - if (!branch2) - continue; - if (branch2->getTargetBlock() != targetBlock) - continue; - argIndex = 0; - for (UInt i = 0; i < branch2->getArgCount(); i++) - { - if (branch2->getArg(i) == branchParam) - { - argIndex = i; - break; - } - } - if (argIndex != (UInt)paramIndex) - continue; + // Since this is branching with a parameter, it can only be an + // unconditional branch. + const auto predTerminator = as<IRUnconditionalBranch>(predecessor->getTerminator()); + SLANG_ASSERT(predTerminator); + + // Is this the base case? + if(predTerminator == loopInst) + continue; - // parentCondition is also used as the new condition in the back jump. - if (conditionParamIndex < 0 || (UInt)conditionParamIndex >= branch2->getArgCount() || - branch2->getArg((UInt)conditionParamIndex) != parentCondition) - continue; + const auto conditionArg = predTerminator->getArg(conditionParamIndex); + const auto inductiveArg = predTerminator->getArg(paramIndex); + // Check that the required implication holds for this block + const auto predRes = inductionImplicationHolds(memo, param, conditionArg, inductiveArg, predecessor); + impRes = join(impRes, predRes); + if(impRes.e == ImplicationResult::Falsified) + break; + } + + switch(impRes.e) + { + // This wasn't an induction variable + case ImplicationResult::Falsified: + break; + + // The loop doesn't loop (because in every case the break flag is + // true) + case ImplicationResult::AntecedentHolds: + break; + + case ImplicationResult::ConsequentHolds: + { // The use of the add inst matches all of our conditions as an induction value - // that is equivalent to loop counter. + // that is a constant offset from a multiple of the loop counter. LoopInductionValueInfo info; - info.kind = LoopInductionValueInfo::Kind::EqualsToCounter; + info.kind = LoopInductionValueInfo::Kind::AffineFunctionOfCounter; info.loopInst = loopInst; info.counterOffset = loopInst->getArg(paramIndex); + info.counterFactor = impRes.factor; inductionValueInsts[param] = info; - break; } } } @@ -710,12 +867,20 @@ void applyToInst( { replacement = builder->getBoolValue(true); } - else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::EqualsToCounter) + else if (inductionValueInfo.kind == LoopInductionValueInfo::Kind::AffineFunctionOfCounter) { auto indexInfo = blockIndexInfo.tryGetValue(inductionValueInfo.loopInst->getTargetBlock()); SLANG_ASSERT(indexInfo); SLANG_ASSERT(indexInfo->getCount() != 0); replacement = indexInfo->getFirst().diffCountParam; + if (inductionValueInfo.counterFactor != 1) + { + setInsertAfterOrdinaryInst(builder, replacement); + replacement = builder->emitMul( + replacement->getDataType(), + replacement, + builder->getIntValue(replacement->getDataType(), inductionValueInfo.counterFactor)); + } if (inductionValueInfo.counterOffset) { setInsertAfterOrdinaryInst(builder, replacement); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 59c70e862..96e4ea99b 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -240,11 +240,12 @@ namespace Slang enum Kind { AlwaysTrue, - EqualsToCounter, + AffineFunctionOfCounter, }; Kind kind; IRLoop* loopInst = nullptr; IRInst* counterOffset = nullptr; + IRIntegerValue counterFactor = 1; }; // Information on which insts are to be stored, recomputed diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index e7cc0d6c8..b0d9bb109 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1208,6 +1208,18 @@ struct IRBlock : IRInst getFirstParam(), getLastParam()); } + // Linear in the parameter index, returns -1 if the param doesn't exist + Index getParamIndex(IRParam* const needle) + { + Index ret = 0; + for(const auto p : getParams()) + { + if (p == needle) + return ret; + ret++; + } + return -1; + } void addParam(IRParam* param); void insertParamAtHead(IRParam* param); diff --git a/tests/autodiff/long-loop-branching-addition.slang b/tests/autodiff/long-loop-branching-addition.slang new file mode 100644 index 000000000..beb371bd0 --- /dev/null +++ b/tests/autodiff/long-loop-branching-addition.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float sin_series(float x, int iterations) +{ + float result = x; + float term = x; + int i = 1; + [MaxIters(30)] + do + { + term *= -1.0f * x * x / ((2 * i) * (2 * i + 1)); + result += term; + if(result > 1000000) + { + i += 1; + if(result > 2000000) + { + term += 1; + i += 1; + } + else + { + i += 1; + } + } + else + { + i += 2; + } + i += -1; + } while (i < iterations); + return result; +} + +// Check that the intermediate context of sin_series does not have an array for `i`. +// This test inparticular checks that can identify induction variables through +// branching control flow + +// CHECK: struct s_bwd_sin_series_Intermediates +// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}] +// CHECK: } + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var x = diffPair(float.getPi(), 1.0); + + __bwd_diff(sin_series)(x, 30, 1.0f); + + outputBuffer[0] = x.d; // -1.0 +} diff --git a/tests/autodiff/long-loop-branching-addition.slang.expected.txt b/tests/autodiff/long-loop-branching-addition.slang.expected.txt new file mode 100644 index 000000000..77bdc5ea4 --- /dev/null +++ b/tests/autodiff/long-loop-branching-addition.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +-1.0 diff --git a/tests/autodiff/long-loop-chained-addition.slang b/tests/autodiff/long-loop-chained-addition.slang new file mode 100644 index 000000000..8f75744a9 --- /dev/null +++ b/tests/autodiff/long-loop-chained-addition.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float sin_series(float x, int iterations) +{ + float result = x; + float term = x; + int i = 1; + [MaxIters(30)] + do + { + term *= -1.0f * x * x / ((2 * i) * (2 * i + 1)); + i += 2; + i++; + result += term; + i -= 2; + } while (i < iterations); + return result; +} + +// Check that the intermediate context of sin_series does not have an array for `i`. +// This test inparticular checks that can identify induction variables with +// more than one operation applied to them during the loop + +// CHECK: struct s_bwd_sin_series_Intermediates +// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}] +// CHECK: } + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var x = diffPair(float.getPi(), 1.0); + + __bwd_diff(sin_series)(x, 30, 1.0f); + + outputBuffer[0] = x.d; // -1.0 +} diff --git a/tests/autodiff/long-loop-chained-addition.slang.expected.txt b/tests/autodiff/long-loop-chained-addition.slang.expected.txt new file mode 100644 index 000000000..77bdc5ea4 --- /dev/null +++ b/tests/autodiff/long-loop-chained-addition.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +-1.0 diff --git a/tests/autodiff/long-loop-multiple.slang b/tests/autodiff/long-loop-multiple.slang new file mode 100644 index 000000000..a696beccf --- /dev/null +++ b/tests/autodiff/long-loop-multiple.slang @@ -0,0 +1,41 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float sin_series(float x, int iterations) +{ + float result = x; + float term = x; + [MaxIters(30)] + for (int i = 1; i < iterations * 10; i += 10) + { + term *= -1.0f * x * x / ((2 * i / 10 + 1) * (2 * i / 10 + 2)); + result += term; + } + return result; +} + +// Check that the intermediate context of sin_series does not have an array for `i`. +// This test differs from ./long-loop.slang in that the loop counter is +// relative to a multiple of the loop iteration + +// CHECK: struct s_bwd_sin_series_Intermediates +// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}] +// CHECK: } + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var x = diffPair(float.getPi(), 1.0); + + __bwd_diff(sin_series)(x, 30, 1.0f); + + outputBuffer[0] = x.d; // -1.0 +} + + + diff --git a/tests/autodiff/long-loop-multiple.slang.expected.txt b/tests/autodiff/long-loop-multiple.slang.expected.txt new file mode 100644 index 000000000..77bdc5ea4 --- /dev/null +++ b/tests/autodiff/long-loop-multiple.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +-1.0 diff --git a/tests/autodiff/long-loop-noninductive.slang b/tests/autodiff/long-loop-noninductive.slang new file mode 100644 index 000000000..bfd37c4f2 --- /dev/null +++ b/tests/autodiff/long-loop-noninductive.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float sin_series(float x, int iterations) +{ + float result = x; + float term = x; + [MaxIters(35)] + for (int i = 1; i < iterations; i++) + { + if(i == 32) + i += 1; + term *= -1.0f * x * x / ((2 * i) * (2 * i + 1)); + result += term; + } + return result; +} + +// Check that the intermediate context of sin_series still has an array for +// `i`. This test checks that the induction variable finder doesn't +// accidentally succeed all the time + +// CHECK: struct s_bwd_sin_series_Intermediates +// CHECK: int {{[A-Za-z0-9_]+}}[{{.*}}] +// CHECK: } + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var x = diffPair(float.getPi(), 1.0); + + __bwd_diff(sin_series)(x, 30, 1.0f); + + outputBuffer[0] = x.d; // -1.0 +} diff --git a/tests/autodiff/long-loop-noninductive.slang.expected.txt b/tests/autodiff/long-loop-noninductive.slang.expected.txt new file mode 100644 index 000000000..77bdc5ea4 --- /dev/null +++ b/tests/autodiff/long-loop-noninductive.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +-1.0 diff --git a/tests/autodiff/long-while-loop.slang b/tests/autodiff/long-while-loop.slang new file mode 100644 index 000000000..20d802e2a --- /dev/null +++ b/tests/autodiff/long-while-loop.slang @@ -0,0 +1,43 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[BackwardDifferentiable] +float sin_series(float x, int iterations) +{ + float result = x; + float term = x; + int i = 1; + [MaxIters(30)] + do + { + term *= -1.0f * x * x / ((2 * i) * (2 * i + 1)); + result += term; + i++; + } while (i < iterations); + return result; +} + +// Check that the intermediate context of sin_series does not have an array for `i`. +// This differs from ./long-loop.slang in that it uses an equivalent do/while +// loop, this tests checks that induction variables are still correctly identified. + +// CHECK: struct s_bwd_sin_series_Intermediates +// CHECK-NOT: int {{[A-Za-z0-9_]+}}[{{.*}}] +// CHECK: } + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + var x = diffPair(float.getPi(), 1.0); + + __bwd_diff(sin_series)(x, 30, 1.0f); + + outputBuffer[0] = x.d; // -1.0 +} + + + diff --git a/tests/autodiff/long-while-loop.slang.expected.txt b/tests/autodiff/long-while-loop.slang.expected.txt new file mode 100644 index 000000000..77bdc5ea4 --- /dev/null +++ b/tests/autodiff/long-while-loop.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +-1.0 diff --git a/tests/cpu-program/cpu-hello-world.slang b/tests/cpu-program/cpu-hello-world.slang index f91e354bc..f1285f889 100644 --- a/tests/cpu-program/cpu-hello-world.slang +++ b/tests/cpu-program/cpu-hello-world.slang @@ -4,4 +4,4 @@ public __extern_cpp int main() { printf("Hello World.\n"); return 0; -}
\ No newline at end of file +} |
