diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-25 17:27:40 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-25 17:27:40 -0800 |
| commit | 1f4c7cab13341c2e9d48df2b01ed2c048c17c152 (patch) | |
| tree | ed85dda63e1c939cf474961b965b7cc1883940bb /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | aa6814be1f7dea20597ae34d477e79e53d4a543f (diff) | |
Unify UpdateField and UpdateElement with access chain. (#2611)
* Unify UpdateField and UpdateElement with access chain.
* Fix warnings.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 42 |
1 files changed, 40 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 000921c7e..fce2043eb 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -7,6 +7,9 @@ #include "slang-ir-inst-pass-base.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" +#include "slang-ir-single-return.h" +#include "slang-ir-addr-inst-elimination.h" +#include "slang-ir-eliminate-multilevel-break.h" namespace Slang { @@ -483,6 +486,39 @@ namespace Slang builder.emitBranch(firstBlock); } + struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy + { + DifferentiableTypeConformanceContext* diffTypeContext; + + virtual bool shouldConvertAddrInst(IRInst* addrInst) override + { + if (isDifferentiableType(*diffTypeContext, addrInst->getDataType())) + return true; + return false; + } + }; + + SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func) + { + DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); + diffTypeContext.setFunc(func); + + if (!isSingleReturnFunc(func)) + { + convertFuncToSingleReturnForm(func->getModule(), func); + } + eliminateMultiLevelBreakForFunc(func->getModule(), func); + + AutoDiffAddressConversionPolicy cvtPolicty; + cvtPolicty.diffTypeContext = &diffTypeContext; + auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); + if (SLANG_SUCCEEDED(result)) + { + simplifyFunc(func); + } + return result; + } + // Create a copy of originalFunc's forward derivative in the same generic context (if any) of // `diffPropagateFunc`. IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( @@ -501,8 +537,10 @@ namespace Slang stripDerivativeDecorations(primalFunc); eliminateDeadCode(primalOuterParent); - // Perform simplification. - simplifyFunc(primalFunc); + // Perform required transformations and simplifications on the original func to make it + // reversible. + if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc))) + return diffPropagateFunc; // Forward transcribe the clone of the original func. ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( |
