diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 6 |
2 files changed, 30 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index ef5161104..06e3f409d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -384,7 +384,21 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( processedUses.add(use); - HoistResult result = this->classify(use); + HoistResult result = HoistResult::none(); + + // Sometimes, we already have a decision for this val. + // + // This is a workaround to some of the problems + // with the multi-pass approach where we can see an + // inst that was already classified, but through a + // different use. + // + if (checkpointInfo->recomputeSet.contains(use.usedVal)) + result = HoistResult::recompute(use.usedVal); + else if (checkpointInfo->storeSet.contains(use.usedVal)) + result = HoistResult::store(use.usedVal); + else + result = this->classify(use); if (result.mode == HoistResult::Mode::Store) { @@ -500,6 +514,15 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( callVarWorkList.add(callUser); } } + + // This is a bit of a hack.. ideally we need to add the var to the worklist for + // further processing rather than replicating those operations here. + // + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (isDifferentialInst(use->getUser())) + usesToReplace.add(use); + } } else if (auto call = as<IRCall>(inst)) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 13e2ca078..92bc8197d 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -171,6 +171,10 @@ struct HoistResult case Mode::Invert: SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert"); break; + case Mode::None: + instToStore = nullptr; + instToRecompute = nullptr; + break; default: SLANG_UNEXPECTED("Unhandled hoist mode"); break; @@ -187,6 +191,8 @@ struct HoistResult static HoistResult recompute(IRInst* inst) { return HoistResult(Mode::Recompute, inst); } static HoistResult invert(InversionInfo inst) { return HoistResult(inst); } + + static HoistResult none() { return HoistResult(Mode::None, nullptr); } }; struct IndexTrackingInfo : public RefObject |
