From c18c4365af77fde279abed33876388961a180b3d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 28 Jan 2025 21:46:40 -0800 Subject: Fix loophole in hoisting where an `OpVar`'s uses might not be properly registered for replacement (#6212) * ix loophole in hoisting where an IRVar's uses might not be properly registered for replacement * fix formatting --- source/slang/slang-ir-autodiff-primal-hoist.cpp | 25 ++++++++++++++++++++++++- source/slang/slang-ir-autodiff-primal-hoist.h | 6 ++++++ 2 files changed, 30 insertions(+), 1 deletion(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index ef5161104..06e3f409d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -384,7 +384,21 @@ RefPtr AutodiffCheckpointPolicyBase::processFunc( processedUses.add(use); - HoistResult result = this->classify(use); + HoistResult result = HoistResult::none(); + + // Sometimes, we already have a decision for this val. + // + // This is a workaround to some of the problems + // with the multi-pass approach where we can see an + // inst that was already classified, but through a + // different use. + // + if (checkpointInfo->recomputeSet.contains(use.usedVal)) + result = HoistResult::recompute(use.usedVal); + else if (checkpointInfo->storeSet.contains(use.usedVal)) + result = HoistResult::store(use.usedVal); + else + result = this->classify(use); if (result.mode == HoistResult::Mode::Store) { @@ -500,6 +514,15 @@ RefPtr AutodiffCheckpointPolicyBase::processFunc( callVarWorkList.add(callUser); } } + + // This is a bit of a hack.. ideally we need to add the var to the worklist for + // further processing rather than replicating those operations here. + // + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (isDifferentialInst(use->getUser())) + usesToReplace.add(use); + } } else if (auto call = as(inst)) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 13e2ca078..92bc8197d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -171,6 +171,10 @@ struct HoistResult case Mode::Invert: SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert"); break; + case Mode::None: + instToStore = nullptr; + instToRecompute = nullptr; + break; default: SLANG_UNEXPECTED("Unhandled hoist mode"); break; @@ -187,6 +191,8 @@ struct HoistResult static HoistResult recompute(IRInst* inst) { return HoistResult(Mode::Recompute, inst); } static HoistResult invert(InversionInfo inst) { return HoistResult(inst); } + + static HoistResult none() { return HoistResult(Mode::None, nullptr); } }; struct IndexTrackingInfo : public RefObject -- cgit v1.2.3