summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-01-23 06:59:25 -0800
committerGitHub <noreply@github.com>2023-01-23 06:59:25 -0800
commit46a4d98baa1d43b33717b4377aefeeaf46b9c2ff (patch)
treec89f3a1c416330f859887d00f896b18bcc7488a5 /source/slang/slang-ir-autodiff-rev.cpp
parent263ca18ea516cfce43fda703c0a411aaf1938e42 (diff)
Full address insts elimination for backward autodiff. (#2604)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp18
1 files changed, 15 insertions, 3 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index d3a6137c1..779a4f1a3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -5,10 +5,9 @@
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
-
+#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
-
namespace Slang
{
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType)
@@ -502,6 +501,17 @@ namespace Slang
stripDerivativeDecorations(primalFunc);
eliminateDeadCode(primalOuterParent);
+ // Perform preparation and simplification.
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ if (SLANG_FAILED(eliminateAddressInsts(
+ builder->getSharedBuilder(),
+ differentiableTypeConformanceContext,
+ primalFunc,
+ sink)))
+ return nullptr;
+
+ simplifyFunc(primalFunc);
+
// Forward transcribe the clone of the original func.
ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>(
autoDiffSharedContext->transcriberSet.forwardTranscriber);
@@ -567,7 +577,9 @@ namespace Slang
}
auto fwdDiffFunc = generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
-
+ if (!fwdDiffFunc)
+ return;
+
// Split first block into a paramter block.
this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));