summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-24 04:33:51 -0700
committerGitHub <noreply@github.com>2023-03-24 04:33:51 -0700
commit03c10833beb331e234554808c2a80d3cadecc7c0 (patch)
treeb135201bfb1128409739405ca508a01922a97333 /source
parent56a84a06488afb817f79fbd99e8b470bd587ccd1 (diff)
Fix nested bwdContextType lowering. (#2731)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff.cpp28
1 files changed, 17 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 1909f860c..ddffd0e21 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -693,16 +693,19 @@ struct AutoDiffPass : public InstPassBase
bool lowerIntermediateContextType(IRBuilder* builder)
{
- bool changed = false;
+ bool result = false;
OrderedHashSet<IRInst*> loweredIntermediateTypes;
// Replace all `BackwardDiffIntermediateContextType` insts with the struct type
// that we generated during backward diff pass.
- processAllInsts([&](IRInst* inst)
- {
- switch (inst->getOp())
+ for (;;)
+ {
+ bool changed = false;
+ processAllInsts([&](IRInst* inst)
{
- case kIROp_BackwardDiffIntermediateContextType:
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
{
auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst);
@@ -719,15 +722,18 @@ struct AutoDiffPass : public InstPassBase
}
}
break;
- default:
- break;
- }
- });
-
+ default:
+ break;
+ }
+ });
+ result |= changed;
+ if (!changed)
+ break;
+ }
// Now we generate the differential type for the intermediate context type
// to allow higher order differentiation.
generateDifferentialImplementationForContextType(loweredIntermediateTypes);
- return changed;
+ return result;
}
// Utility function for topology sorting the intermediate context types.