diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-28 21:46:40 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-29 05:46:40 +0000 |
| commit | c18c4365af77fde279abed33876388961a180b3d (patch) | |
| tree | 5701907d22654b90ce44ee8f8223bbf8c77183a4 /source | |
| parent | 1c282b80b9fbcfea9dc3dab7f5f546b069143e01 (diff) | |
Fix loophole in hoisting where an `OpVar`'s uses might not be properly registered for replacement (#6212)
* ix loophole in hoisting where an IRVar's uses might not be properly registered for replacement
* fix formatting
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 6 |
2 files changed, 30 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index ef5161104..06e3f409d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -384,7 +384,21 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( processedUses.add(use); - HoistResult result = this->classify(use); + HoistResult result = HoistResult::none(); + + // Sometimes, we already have a decision for this val. + // + // This is a workaround to some of the problems + // with the multi-pass approach where we can see an + // inst that was already classified, but through a + // different use. + // + if (checkpointInfo->recomputeSet.contains(use.usedVal)) + result = HoistResult::recompute(use.usedVal); + else if (checkpointInfo->storeSet.contains(use.usedVal)) + result = HoistResult::store(use.usedVal); + else + result = this->classify(use); if (result.mode == HoistResult::Mode::Store) { @@ -500,6 +514,15 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( callVarWorkList.add(callUser); } } + + // This is a bit of a hack.. ideally we need to add the var to the worklist for + // further processing rather than replicating those operations here. + // + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (isDifferentialInst(use->getUser())) + usesToReplace.add(use); + } } else if (auto call = as<IRCall>(inst)) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 13e2ca078..92bc8197d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -171,6 +171,10 @@ struct HoistResult case Mode::Invert: SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert"); break; + case Mode::None: + instToStore = nullptr; + instToRecompute = nullptr; + break; default: SLANG_UNEXPECTED("Unhandled hoist mode"); break; @@ -187,6 +191,8 @@ struct HoistResult static HoistResult recompute(IRInst* inst) { return HoistResult(Mode::Recompute, inst); } static HoistResult invert(InversionInfo inst) { return HoistResult(inst); } + + static HoistResult none() { return HoistResult(Mode::None, nullptr); } }; struct IndexTrackingInfo : public RefObject |
