diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-17 20:21:01 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-17 17:21:01 -0800 |
| commit | 1a486813ef0bc7f7a2eb6eaeec2921fd71a2bd05 (patch) | |
| tree | 5d64c73b7859a1657f71af95da7bc9e78fc58bf2 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 2c437498d3a09b58de17a8865242814d9ea92fde (diff) | |
Added switch-case support; fixed non-diff parameter transposition (#2596)
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8d6419cf2..de4fbe182 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -666,14 +666,15 @@ namespace Slang builder->setInsertInto(fwdDiffParameterBlock); - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> - for (auto child = fwdDiffParameterBlock->getFirstParam(); child;) + List<IRParam*> fwdParams; + for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) { - IRParam* nextChild = child->getNextParam(); + fwdParams.add(child); + } - auto fwdParam = as<IRParam>(child); - SLANG_ASSERT(fwdParam); - + // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + for (auto fwdParam : fwdParams) + { // TODO: Handle ptr<pair> types. if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) { @@ -690,10 +691,11 @@ namespace Slang else { // Default case (parameter has nothing to do with differentiation) - // Do nothing. + // Simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); } - - child = nextChild; } auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); |
