summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-25 17:27:40 -0800
committerGitHub <noreply@github.com>2023-01-25 17:27:40 -0800
commit1f4c7cab13341c2e9d48df2b01ed2c048c17c152 (patch)
treeed85dda63e1c939cf474961b965b7cc1883940bb /source/slang/slang-ir-autodiff-rev.cpp
parentaa6814be1f7dea20597ae34d477e79e53d4a543f (diff)
Unify UpdateField and UpdateElement with access chain. (#2611)
* Unify UpdateField and UpdateElement with access chain. * Fix warnings. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp42
1 files changed, 40 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 000921c7e..fce2043eb 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -7,6 +7,9 @@
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-single-return.h"
+#include "slang-ir-addr-inst-elimination.h"
+#include "slang-ir-eliminate-multilevel-break.h"
namespace Slang
{
@@ -483,6 +486,39 @@ namespace Slang
builder.emitBranch(firstBlock);
}
+ struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
+ {
+ DifferentiableTypeConformanceContext* diffTypeContext;
+
+ virtual bool shouldConvertAddrInst(IRInst* addrInst) override
+ {
+ if (isDifferentiableType(*diffTypeContext, addrInst->getDataType()))
+ return true;
+ return false;
+ }
+ };
+
+ SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
+ {
+ DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
+ diffTypeContext.setFunc(func);
+
+ if (!isSingleReturnFunc(func))
+ {
+ convertFuncToSingleReturnForm(func->getModule(), func);
+ }
+ eliminateMultiLevelBreakForFunc(func->getModule(), func);
+
+ AutoDiffAddressConversionPolicy cvtPolicty;
+ cvtPolicty.diffTypeContext = &diffTypeContext;
+ auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink);
+ if (SLANG_SUCCEEDED(result))
+ {
+ simplifyFunc(func);
+ }
+ return result;
+ }
+
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
@@ -501,8 +537,10 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
- // Perform simplification.
- simplifyFunc(primalFunc);
+ // Perform required transformations and simplifications on the original func to make it
+ // reversible.
+ if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
+ return diffPropagateFunc;
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(