diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-29 17:05:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-29 17:05:07 -0700 |
| commit | 082c48d96c5f8f6b4f560d705fe731da14409cb4 (patch) | |
| tree | fe9860aea3326cd321365bc5530a917fcef94718 /source/slang/slang-ir-autodiff.cpp | |
| parent | a862f5b7007ef50b5def30506f0cea138b73c710 (diff) | |
Update checkpoint policy to make obvious recompute decisions. (#2753)
* Update checkpoint policy to make obvious recompute decisions.
Also adds an optimization to fold updateElement chains on the same array or struct into a single makeArray or makeStruct.
* Bug fixes around array types with different int typed count.
* change test.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 34 |
1 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 10c751d52..024d31fd8 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -282,15 +282,29 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; - if (pairTypeCache.TryGetValue(originalPairType, result)) - return result; auto pairType = as<IRDifferentialPairTypeBase>(originalPairType); if (!pairType) + return originalPairType; + + // We make our type cache keyed on the primal type, not the pair type. + // This is because there may be duplicate pair types for the same + // primal type but different witness tables, and we don't want to treat + // them as distinct. + // We might want to consider making witness tables part of IR + // deduplication (make them HOISTABLE insts), but that is a bigger + // change. Another alternative is to make the witness operand of + // `IRDifferentialPairTypeBase` be child instead of an operand + // so that it is not considered part of the type for deduplication + // purposes. + + auto primalType = pairType->getValueType(); + if (pairTypeCache.TryGetValue(primalType, result)) + return result; + if (!pairType) { result = originalPairType; return result; } - auto primalType = pairType->getValueType(); if (as<IRParam>(primalType)) { result = nullptr; @@ -301,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( if (!diffType) return result; result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - pairTypeCache.Add(originalPairType, result); + pairTypeCache.Add(primalType, result); return result; } @@ -1820,4 +1834,16 @@ bool isDerivativeContextVar(IRVar* var) return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); } +bool isDiffInst(IRInst* inst) +{ + if (inst->findDecoration<IRDifferentialInstDecoration>() || + inst->findDecoration<IRMixedDifferentialInstDecoration>()) + return true; + + if (auto block = as<IRBlock>(inst->getParent())) + return isDiffInst(block); + + return false; +} + } |
