From 082c48d96c5f8f6b4f560d705fe731da14409cb4 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 29 Mar 2023 17:05:07 -0700 Subject: Update checkpoint policy to make obvious recompute decisions. (#2753) * Update checkpoint policy to make obvious recompute decisions. Also adds an optimization to fold updateElement chains on the same array or struct into a single makeArray or makeStruct. * Bug fixes around array types with different int typed count. * change test. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff.cpp | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) (limited to 'source/slang/slang-ir-autodiff.cpp') diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 10c751d52..024d31fd8 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -282,15 +282,29 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; - if (pairTypeCache.TryGetValue(originalPairType, result)) - return result; auto pairType = as(originalPairType); + if (!pairType) + return originalPairType; + + // We make our type cache keyed on the primal type, not the pair type. + // This is because there may be duplicate pair types for the same + // primal type but different witness tables, and we don't want to treat + // them as distinct. + // We might want to consider making witness tables part of IR + // deduplication (make them HOISTABLE insts), but that is a bigger + // change. Another alternative is to make the witness operand of + // `IRDifferentialPairTypeBase` be child instead of an operand + // so that it is not considered part of the type for deduplication + // purposes. + + auto primalType = pairType->getValueType(); + if (pairTypeCache.TryGetValue(primalType, result)) + return result; if (!pairType) { result = originalPairType; return result; } - auto primalType = pairType->getValueType(); if (as(primalType)) { result = nullptr; @@ -301,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( if (!diffType) return result; result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - pairTypeCache.Add(originalPairType, result); + pairTypeCache.Add(primalType, result); return result; } @@ -1820,4 +1834,16 @@ bool isDerivativeContextVar(IRVar* var) return var->findDecoration(); } +bool isDiffInst(IRInst* inst) +{ + if (inst->findDecoration() || + inst->findDecoration()) + return true; + + if (auto block = as(inst->getParent())) + return isDiffInst(block); + + return false; +} + } -- cgit v1.2.3