diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-13 10:57:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-13 10:57:28 -0700 |
| commit | a911ca6e06ce41e403b80fe6054162393491c8ac (patch) | |
| tree | 6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff) | |
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 75 |
1 files changed, 1 insertions, 74 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index d7cce7c53..328af4867 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -2,14 +2,12 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" -#include "slang-ir-eliminate-phis.h" #include "slang-ir-autodiff-cfg-norm.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-single-return.h" -#include "slang-ir-addr-inst-elimination.h" #include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" @@ -516,65 +514,6 @@ namespace Slang builder.emitBranch(firstBlock); } - void insertTempVarForMutableParams(IRModule* module, IRFunc* func) - { - IRBuilder builder(module); - auto firstBlock = func->getFirstBlock(); - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar; - List<IRParam*> params; - for (auto param : firstBlock->getParams()) - { - if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) - { - params.add(param); - } - } - - for (auto param : params) - { - auto ptrType = as<IRPtrTypeBase>(param->getDataType()); - auto tempVar = builder.emitVar(ptrType->getValueType()); - param->replaceUsesWith(tempVar); - mapParamToTempVar[param] = tempVar; - if (ptrType->getOp() != kIROp_OutType) - { - builder.emitStore(tempVar, builder.emitLoad(param)); - } - else - { - builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType())); - } - } - - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - if (inst->getOp() == kIROp_Return) - { - builder.setInsertBefore(inst); - for (auto& kv : mapParamToTempVar) - { - builder.emitStore(kv.Key, builder.emitLoad(kv.Value)); - } - } - } - } - } - - - struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy - { - DifferentiableTypeConformanceContext* diffTypeContext; - - virtual bool shouldConvertAddrInst(IRInst*) override - { - return true; - } - }; - SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func) { DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); @@ -592,19 +531,7 @@ namespace Slang IRCFGNormalizationPass cfgPass = {this->getSink()}; normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func); - insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); - - AutoDiffAddressConversionPolicy cvtPolicty; - cvtPolicty.diffTypeContext = &diffTypeContext; - auto result = eliminateAddressInsts(&cvtPolicty, func, sink); - - if (SLANG_SUCCEEDED(result)) - { - disableIRValidationAtInsert(); - simplifyFunc(func); - enableIRValidationAtInsert(); - } - return result; + return SLANG_OK; } // Create a copy of originalFunc's forward derivative in the same generic context (if any) of |
