summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-17 20:21:01 -0500
committerGitHub <noreply@github.com>2023-01-17 17:21:01 -0800
commit1a486813ef0bc7f7a2eb6eaeec2921fd71a2bd05 (patch)
tree5d64c73b7859a1657f71af95da7bc9e78fc58bf2 /source/slang/slang-ir-autodiff-rev.cpp
parent2c437498d3a09b58de17a8865242814d9ea92fde (diff)
Added switch-case support; fixed non-diff parameter transposition (#2596)
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8d6419cf2..de4fbe182 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -666,14 +666,15 @@ namespace Slang
builder->setInsertInto(fwdDiffParameterBlock);
- // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
- for (auto child = fwdDiffParameterBlock->getFirstParam(); child;)
+ List<IRParam*> fwdParams;
+ for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam())
{
- IRParam* nextChild = child->getNextParam();
+ fwdParams.add(child);
+ }
- auto fwdParam = as<IRParam>(child);
- SLANG_ASSERT(fwdParam);
-
+ // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
+ for (auto fwdParam : fwdParams)
+ {
// TODO: Handle ptr<pair> types.
if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
{
@@ -690,10 +691,11 @@ namespace Slang
else
{
// Default case (parameter has nothing to do with differentiation)
- // Do nothing.
+ // Simply move the parameter to the end.
+ //
+ fwdParam->removeFromParent();
+ fwdDiffParameterBlock->addParam(fwdParam);
}
-
- child = nextChild;
}
auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();