summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-unzip.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-unzip.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-unzip.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp16
1 files changed, 16 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 9ee2cb4d2..49c1d9ff7 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -93,6 +93,22 @@ struct ExtractPrimalFuncContext
as<IRGeneric>(findOuterGeneric(destFunc)),
destFunc);
+ if (auto origGeneric = as<IRGeneric>(findOuterGeneric(originalFunc)))
+ {
+ // Clone in everything else except the return value.
+ IRBuilder subBuilder(destFunc);
+ builder.setInsertAfter(findOuterGeneric(destFunc)->getFirstBlock()->getLastParam());
+
+ // Clone in any hoistable insts.
+ for (auto child = origGeneric->getFirstBlock()->getFirstOrdinaryInst(); child;
+ child = child->getNextInst())
+ {
+ if ((child != originalFunc) && !as<IRReturn>(child) &&
+ !as<IRGlobalValueWithCode>(child))
+ migrationContext.cloneInst(&subBuilder, child);
+ }
+ }
+
originalFuncType = as<IRFuncType>(originalFunc->getDataType());
SLANG_RELEASE_ASSERT(originalFuncType);