summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-29 17:05:07 -0700
committerGitHub <noreply@github.com>2023-03-29 17:05:07 -0700
commit082c48d96c5f8f6b4f560d705fe731da14409cb4 (patch)
treefe9860aea3326cd321365bc5530a917fcef94718 /source/slang/slang-ir-autodiff.cpp
parenta862f5b7007ef50b5def30506f0cea138b73c710 (diff)
Update checkpoint policy to make obvious recompute decisions. (#2753)
* Update checkpoint policy to make obvious recompute decisions. Also adds an optimization to fold updateElement chains on the same array or struct into a single makeArray or makeStruct. * Bug fixes around array types with different int typed count. * change test. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp34
1 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 10c751d52..024d31fd8 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -282,15 +282,29 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
- if (pairTypeCache.TryGetValue(originalPairType, result))
- return result;
auto pairType = as<IRDifferentialPairTypeBase>(originalPairType);
if (!pairType)
+ return originalPairType;
+
+ // We make our type cache keyed on the primal type, not the pair type.
+ // This is because there may be duplicate pair types for the same
+ // primal type but different witness tables, and we don't want to treat
+ // them as distinct.
+ // We might want to consider making witness tables part of IR
+ // deduplication (make them HOISTABLE insts), but that is a bigger
+ // change. Another alternative is to make the witness operand of
+ // `IRDifferentialPairTypeBase` be child instead of an operand
+ // so that it is not considered part of the type for deduplication
+ // purposes.
+
+ auto primalType = pairType->getValueType();
+ if (pairTypeCache.TryGetValue(primalType, result))
+ return result;
+ if (!pairType)
{
result = originalPairType;
return result;
}
- auto primalType = pairType->getValueType();
if (as<IRParam>(primalType))
{
result = nullptr;
@@ -301,7 +315,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
if (!diffType)
return result;
result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- pairTypeCache.Add(originalPairType, result);
+ pairTypeCache.Add(primalType, result);
return result;
}
@@ -1820,4 +1834,16 @@ bool isDerivativeContextVar(IRVar* var)
return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>();
}
+bool isDiffInst(IRInst* inst)
+{
+ if (inst->findDecoration<IRDifferentialInstDecoration>() ||
+ inst->findDecoration<IRMixedDifferentialInstDecoration>())
+ return true;
+
+ if (auto block = as<IRBlock>(inst->getParent()))
+ return isDiffInst(block);
+
+ return false;
+}
+
}