summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir-autodiff-rev.cpp
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp75
1 files changed, 1 insertions, 74 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index d7cce7c53..328af4867 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -2,14 +2,12 @@
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
-#include "slang-ir-eliminate-phis.h"
#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-single-return.h"
-#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
#include "slang-ir-redundancy-removal.h"
@@ -516,65 +514,6 @@ namespace Slang
builder.emitBranch(firstBlock);
}
- void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
- {
- IRBuilder builder(module);
- auto firstBlock = func->getFirstBlock();
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
-
- OrderedDictionary<IRParam*, IRVar*> mapParamToTempVar;
- List<IRParam*> params;
- for (auto param : firstBlock->getParams())
- {
- if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
- {
- params.add(param);
- }
- }
-
- for (auto param : params)
- {
- auto ptrType = as<IRPtrTypeBase>(param->getDataType());
- auto tempVar = builder.emitVar(ptrType->getValueType());
- param->replaceUsesWith(tempVar);
- mapParamToTempVar[param] = tempVar;
- if (ptrType->getOp() != kIROp_OutType)
- {
- builder.emitStore(tempVar, builder.emitLoad(param));
- }
- else
- {
- builder.emitStore(tempVar, builder.emitDefaultConstruct(ptrType->getValueType()));
- }
- }
-
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getChildren())
- {
- if (inst->getOp() == kIROp_Return)
- {
- builder.setInsertBefore(inst);
- for (auto& kv : mapParamToTempVar)
- {
- builder.emitStore(kv.Key, builder.emitLoad(kv.Value));
- }
- }
- }
- }
- }
-
-
- struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy
- {
- DifferentiableTypeConformanceContext* diffTypeContext;
-
- virtual bool shouldConvertAddrInst(IRInst*) override
- {
- return true;
- }
- };
-
SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
{
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
@@ -592,19 +531,7 @@ namespace Slang
IRCFGNormalizationPass cfgPass = {this->getSink()};
normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func);
- insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
-
- AutoDiffAddressConversionPolicy cvtPolicty;
- cvtPolicty.diffTypeContext = &diffTypeContext;
- auto result = eliminateAddressInsts(&cvtPolicty, func, sink);
-
- if (SLANG_SUCCEEDED(result))
- {
- disableIRValidationAtInsert();
- simplifyFunc(func);
- enableIRValidationAtInsert();
- }
- return result;
+ return SLANG_OK;
}
// Create a copy of originalFunc's forward derivative in the same generic context (if any) of