From 284cee1f246c072f190c87c8fb60c1d2181e458f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 24 Apr 2023 19:44:23 -0700 Subject: Change AD checkpointing policy to recompute more. (#2836) Co-authored-by: Yong He --- source/slang/slang-ir-autodiff.cpp | 42 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) (limited to 'source/slang/slang-ir-autodiff.cpp') diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a8af148d9..656b0e11b 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1868,6 +1868,48 @@ IRUse* findUniqueStoredVal(IRVar* var) } } +// Given a local var that is supposed to have a unique write, find the last inst +// that writes to it. Note: if var is intended for an inout argument, it will +// have exactly one store that sets its initial value and one call that writes +// the final value to it, this method will return the call inst for this case. +IRUse* findLatestUniqueWriteUse(IRVar* var) +{ + IRUse* storeUse = nullptr; + // If no unique store found, try to look for a call. + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as(use->getUser())) + { + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return findUniqueStoredVal(var); +} + +// Given a local var that is supposed to have a unique write, find the last inst +// that writes to it. Note: if var is intended for an inout argument, it will +// have exactly one store that sets its initial value and one call that writes +// the final value to it, this method will return the store inst for this case. +IRUse* findEarliestUniqueWriteUse(IRVar* var) +{ + IRUse* storeUse = findUniqueStoredVal(var); + if (storeUse) + return storeUse; + + // If no unique store found, try to look for a call. + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as(use->getUser())) + { + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return storeUse; +} + + bool isDerivativeContextVar(IRVar* var) { return var->findDecoration(); -- cgit v1.2.3