From 6bca0ec355aae2955c7de1cd16c2dc0dfe46f19c Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 21 Feb 2023 11:37:15 -0500 Subject: Added support for simple while loops (#2667) * Added support for simple while loops * Fix support for while loops by changing logic to grab the loop update block --- source/slang/slang-ir-autodiff-cfg-norm.cpp | 16 ++++++++++------ source/slang/slang-ir-autodiff-unzip.h | 24 ++++++++++++++++++------ 2 files changed, 28 insertions(+), 12 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 4014473ea..d64c6d1f6 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -473,14 +473,18 @@ struct CFGNormalizationPass loopEndPoint = falseEndPoint; isLoopOnTrueSide = false; } + + // Right now, we only support loops where the loop is on the true side of + // the condition. If we ever encounter the other case, fill in logic to + // flip the condition. + // + SLANG_RELEASE_ASSERT(isLoopOnTrueSide); + // Expect atleast one basic block (other than the condition block), in + // the loop. + // SLANG_RELEASE_ASSERT(loopEndPoint.exitBlock); - - // Special case.. the if-else of a loop needs it's - // after block to be pointing at the last block before - // it loops back to the if-else. - // - // ifElse->afterBlock.set(loopEndPoint.exitBlock); + SLANG_RELEASE_ASSERT(!loopEndPoint.isRegionEmpty); // Does the loop endpoint have both 'break' and 'base' // control flows? diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 944df2c81..2ccb8d8e2 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -277,14 +277,26 @@ struct DiffUnzipPass IRBlock* getUpdateBlock(IndexedRegion* region) { - // TODO: What if the 'continue' region has multiple - // blocks? - // We ideally want the _last_ block before control loops back. + auto initBlock = getInitializerBlock(region); + + auto condBlock = region->firstBlock; + + IRBlock* lastLoopBlock = nullptr; + + for (auto predecessor : condBlock->getPredecessors()) + { + if (predecessor != initBlock) + lastLoopBlock = predecessor; + } + + // Should find atleast one predecessor that is _not_ the + // init block (that contains the loop info). This + // predecessor would be the last block in the loop + // before looping back to the condition. // - SLANG_RELEASE_ASSERT(as( - region->continueBlock->getTerminator())->getTargetBlock() == region->firstBlock); + SLANG_RELEASE_ASSERT(lastLoopBlock); - return region->continueBlock; + return lastLoopBlock; } IRBlock* getFirstLoopBodyBlock(IndexedRegion* region) -- cgit v1.2.3