diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-24 04:33:51 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-24 04:33:51 -0700 |
| commit | 03c10833beb331e234554808c2a80d3cadecc7c0 (patch) | |
| tree | b135201bfb1128409739405ca508a01922a97333 /source | |
| parent | 56a84a06488afb817f79fbd99e8b470bd587ccd1 (diff) | |
Fix nested bwdContextType lowering. (#2731)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 28 |
1 files changed, 17 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 1909f860c..ddffd0e21 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -693,16 +693,19 @@ struct AutoDiffPass : public InstPassBase bool lowerIntermediateContextType(IRBuilder* builder) { - bool changed = false; + bool result = false; OrderedHashSet<IRInst*> loweredIntermediateTypes; // Replace all `BackwardDiffIntermediateContextType` insts with the struct type // that we generated during backward diff pass. - processAllInsts([&](IRInst* inst) - { - switch (inst->getOp()) + for (;;) + { + bool changed = false; + processAllInsts([&](IRInst* inst) { - case kIROp_BackwardDiffIntermediateContextType: + switch (inst->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: { auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst); @@ -719,15 +722,18 @@ struct AutoDiffPass : public InstPassBase } } break; - default: - break; - } - }); - + default: + break; + } + }); + result |= changed; + if (!changed) + break; + } // Now we generate the differential type for the intermediate context type // to allow higher order differentiation. generateDifferentialImplementationForContextType(loweredIntermediateTypes); - return changed; + return result; } // Utility function for topology sorting the intermediate context types. |
