diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-23 06:59:25 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-23 06:59:25 -0800 |
| commit | 46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch) | |
| tree | c89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 263ca18ea516cfce43fda703c0a411aaf1938e42 (diff) | |
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index d3a6137c1..779a4f1a3 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -5,10 +5,9 @@ #include "slang-ir-eliminate-phis.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" - +#include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" - namespace Slang { IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType) @@ -502,6 +501,17 @@ namespace Slang stripDerivativeDecorations(primalFunc); eliminateDeadCode(primalOuterParent); + // Perform preparation and simplification. + differentiableTypeConformanceContext.setFunc(primalFunc); + if (SLANG_FAILED(eliminateAddressInsts( + builder->getSharedBuilder(), + differentiableTypeConformanceContext, + primalFunc, + sink))) + return nullptr; + + simplifyFunc(primalFunc); + // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( autoDiffSharedContext->transcriberSet.forwardTranscriber); @@ -567,7 +577,9 @@ namespace Slang } auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc); - + if (!fwdDiffFunc) + return; + // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc)); |
