summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-28 21:46:40 -0800
committerGitHub <noreply@github.com>2025-01-29 05:46:40 +0000
commitc18c4365af77fde279abed33876388961a180b3d (patch)
tree5701907d22654b90ce44ee8f8223bbf8c77183a4 /source
parent1c282b80b9fbcfea9dc3dab7f5f546b069143e01 (diff)
Fix loophole in hoisting where an `OpVar`'s uses might not be properly registered for replacement (#6212)
* ix loophole in hoisting where an IRVar's uses might not be properly registered for replacement * fix formatting
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp25
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h6
2 files changed, 30 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index ef5161104..06e3f409d 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -384,7 +384,21 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
processedUses.add(use);
- HoistResult result = this->classify(use);
+ HoistResult result = HoistResult::none();
+
+ // Sometimes, we already have a decision for this val.
+ //
+ // This is a workaround to some of the problems
+ // with the multi-pass approach where we can see an
+ // inst that was already classified, but through a
+ // different use.
+ //
+ if (checkpointInfo->recomputeSet.contains(use.usedVal))
+ result = HoistResult::recompute(use.usedVal);
+ else if (checkpointInfo->storeSet.contains(use.usedVal))
+ result = HoistResult::store(use.usedVal);
+ else
+ result = this->classify(use);
if (result.mode == HoistResult::Mode::Store)
{
@@ -500,6 +514,15 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
callVarWorkList.add(callUser);
}
}
+
+ // This is a bit of a hack.. ideally we need to add the var to the worklist for
+ // further processing rather than replicating those operations here.
+ //
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (isDifferentialInst(use->getUser()))
+ usesToReplace.add(use);
+ }
}
else if (auto call = as<IRCall>(inst))
{
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index 13e2ca078..92bc8197d 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -171,6 +171,10 @@ struct HoistResult
case Mode::Invert:
SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert");
break;
+ case Mode::None:
+ instToStore = nullptr;
+ instToRecompute = nullptr;
+ break;
default:
SLANG_UNEXPECTED("Unhandled hoist mode");
break;
@@ -187,6 +191,8 @@ struct HoistResult
static HoistResult recompute(IRInst* inst) { return HoistResult(Mode::Recompute, inst); }
static HoistResult invert(InversionInfo inst) { return HoistResult(inst); }
+
+ static HoistResult none() { return HoistResult(Mode::None, nullptr); }
};
struct IndexTrackingInfo : public RefObject