diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-21 11:37:15 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-21 08:37:15 -0800 |
| commit | 6bca0ec355aae2955c7de1cd16c2dc0dfe46f19c (patch) | |
| tree | ab9bf433bd575375f6c0871bc7589a72b39b6615 | |
| parent | e5bec2fcb86da56775a3f1a0bc0af5039b722e86 (diff) | |
Added support for simple while loops (#2667)
* Added support for simple while loops
* Fix support for while loops by changing logic to grab the loop update block
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 24 | ||||
| -rw-r--r-- | tests/autodiff/reverse-while-loop.slang | 47 | ||||
| -rw-r--r-- | tests/autodiff/reverse-while-loop.slang.expected.txt | 6 |
4 files changed, 81 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 4014473ea..d64c6d1f6 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -473,14 +473,18 @@ struct CFGNormalizationPass loopEndPoint = falseEndPoint; isLoopOnTrueSide = false; } + + // Right now, we only support loops where the loop is on the true side of + // the condition. If we ever encounter the other case, fill in logic to + // flip the condition. + // + SLANG_RELEASE_ASSERT(isLoopOnTrueSide); + // Expect atleast one basic block (other than the condition block), in + // the loop. + // SLANG_RELEASE_ASSERT(loopEndPoint.exitBlock); - - // Special case.. the if-else of a loop needs it's - // after block to be pointing at the last block before - // it loops back to the if-else. - // - // ifElse->afterBlock.set(loopEndPoint.exitBlock); + SLANG_RELEASE_ASSERT(!loopEndPoint.isRegionEmpty); // Does the loop endpoint have both 'break' and 'base' // control flows? diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 944df2c81..2ccb8d8e2 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -277,14 +277,26 @@ struct DiffUnzipPass IRBlock* getUpdateBlock(IndexedRegion* region) { - // TODO: What if the 'continue' region has multiple - // blocks? - // We ideally want the _last_ block before control loops back. + auto initBlock = getInitializerBlock(region); + + auto condBlock = region->firstBlock; + + IRBlock* lastLoopBlock = nullptr; + + for (auto predecessor : condBlock->getPredecessors()) + { + if (predecessor != initBlock) + lastLoopBlock = predecessor; + } + + // Should find atleast one predecessor that is _not_ the + // init block (that contains the loop info). This + // predecessor would be the last block in the loop + // before looping back to the condition. // - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( - region->continueBlock->getTerminator())->getTargetBlock() == region->firstBlock); + SLANG_RELEASE_ASSERT(lastLoopBlock); - return region->continueBlock; + return lastLoopBlock; } IRBlock* getFirstLoopBodyBlock(IndexedRegion* region) diff --git a/tests/autodiff/reverse-while-loop.slang b/tests/autodiff/reverse-while-loop.slang new file mode 100644 index 000000000..c8d2c542a --- /dev/null +++ b/tests/autodiff/reverse-while-loop.slang @@ -0,0 +1,47 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test_simple_while(float y) +{ + float t = y; + + bool keepGoing = true; + int i = 2; + + [MaxIters(3)] + while (keepGoing) + { + i++; + t = t * t; + + keepGoing = (i < 5); + } + + return t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(test_simple_while)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 8.0 + } + + { + dpfloat dpa = dpfloat(0.4, 0.0); + + __bwd_diff(test_simple_while)(dpa, 1.0f); + outputBuffer[1] = dpa.d; // Expect: 0.0131072 + } +} diff --git a/tests/autodiff/reverse-while-loop.slang.expected.txt b/tests/autodiff/reverse-while-loop.slang.expected.txt new file mode 100644 index 000000000..76b7cf779 --- /dev/null +++ b/tests/autodiff/reverse-while-loop.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +8.000000 +0.013107 +0.000000 +0.000000 +0.000000 |
