summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-02-17 17:56:07 -0500
committerGitHub <noreply@github.com>2023-02-17 14:56:07 -0800
commit5cd39d1527f87ebab966cbd9c136b93058a709bc (patch)
tree0738ed293e79ab22edd200221e93cd09b6158ebe /source
parent051607368e8d3dd55d2ad2b2200ef656244ec70d (diff)
AD: Remove the original loop condition upon inversion (#2661)
* Remove the original condition upon loop inversion (it's redundant, and causes out-of-bounds accesses) * minor fix (also removed the first loop check skip) * Cleanup unused insts * minor comment fix
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h32
1 files changed, 2 insertions, 30 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index b74416b76..f998ae13f 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -1051,27 +1051,6 @@ struct DiffTransposePass
IRBlock* revLoopCondBlock = revBlockMap[firstLoopBlock];
builder->setInsertBefore(revLoopCondBlock->getTerminator());
- auto loopBaseCondition = as<IRIfElse>(revLoopCondBlock->getTerminator())->getCondition();
-
- // Convert the loop from a 'for' into a 'do-while' by skipping the first check
-
- IRBlock* revLoopStartBlock = revBlockMap[as<IRBlock>(loopInst->getBreakBlock())];
- builder->setInsertBefore(revLoopStartBlock->getTerminator());
-
- auto firstLoopCheckSkipVar = builder->emitVar(builder->getBoolType());
- builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(true));
-
- builder->setInsertBefore(revLoopCondBlock->getTerminator());
- auto firstLoopCheckSkipVal = builder->emitLoad(firstLoopCheckSkipVar);
-
- builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(false));
-
- loopBaseCondition = builder->emitIntrinsicInst(
- builder->getBoolType(),
- kIROp_Or,
- 2,
- List<IRInst*>(firstLoopCheckSkipVal, loopBaseCondition).getBuffer());
-
// Add a terminating condition based on the loop counter's initial primal value
IRParam* loopCounterParam = nullptr;
@@ -1080,7 +1059,7 @@ struct DiffTransposePass
{
if (param->findDecoration<IRLoopCounterDecoration>())
{
- // There really should be two (or more) loop counter params.
+ // There really not should be two (or more) loop counter params.
SLANG_RELEASE_ASSERT(loopCounterParam == nullptr);
loopCounterParam = param;
}
@@ -1102,15 +1081,8 @@ struct DiffTransposePass
List<IRInst*>(
hoistPrimalInst(builder, loopCounterParam),
hoistPrimalInst(builder, loopCounterInitVal)).getBuffer());
-
- loopBaseCondition = builder->emitIntrinsicInst(
- builder->getBoolType(),
- kIROp_And,
- 2,
- List<IRInst*>(paramBoundsCheck, loopBaseCondition).getBuffer());
-
- as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(loopBaseCondition);
+ as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
}
List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)