summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff.cpp28
1 files changed, 17 insertions, 11 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 1909f860c..ddffd0e21 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -693,16 +693,19 @@ struct AutoDiffPass : public InstPassBase
bool lowerIntermediateContextType(IRBuilder* builder)
{
- bool changed = false;
+ bool result = false;
OrderedHashSet<IRInst*> loweredIntermediateTypes;
// Replace all `BackwardDiffIntermediateContextType` insts with the struct type
// that we generated during backward diff pass.
- processAllInsts([&](IRInst* inst)
- {
- switch (inst->getOp())
+ for (;;)
+ {
+ bool changed = false;
+ processAllInsts([&](IRInst* inst)
{
- case kIROp_BackwardDiffIntermediateContextType:
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
{
auto differentiateInst = as<IRBackwardDiffIntermediateContextType>(inst);
@@ -719,15 +722,18 @@ struct AutoDiffPass : public InstPassBase
}
}
break;
- default:
- break;
- }
- });
-
+ default:
+ break;
+ }
+ });
+ result |= changed;
+ if (!changed)
+ break;
+ }
// Now we generate the differential type for the intermediate context type
// to allow higher order differentiation.
generateDifferentialImplementationForContextType(loweredIntermediateTypes);
- return changed;
+ return result;
}
// Utility function for topology sorting the intermediate context types.