summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp94
1 files changed, 50 insertions, 44 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 3135f300d..574db2036 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -681,17 +681,6 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
- IRWitnessTable* getDifferentialBottomWitness()
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- SLANG_ASSERT(result);
- return result;
- }
-
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
{
@@ -699,20 +688,23 @@ struct JVPTranscriber
builder.setInsertInto(inDiffPairType->parent);
auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
SLANG_ASSERT(diffPairType);
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- if (result)
- return result;
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
auto diffType = differentiateType(&builder, diffPairType->getValueType());
- auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- return table;
+ IRInst* tableInst = nullptr;
+ if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst))
+ {
+ IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ // The witness that `diffType`
+ auto differentialType = builder.getDifferentialPairType(
+ diffType,
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType]
+ .GetValue());
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ tableInst = table;
+ }
+ return as<IRWitnessTable>(tableInst);
}
IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
@@ -730,8 +722,10 @@ struct JVPTranscriber
builder.setInsertInto(primalType->parent);
auto witness = as<IRWitnessTable>(
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
- if (!witness)
- witness = getDifferentialBottomWitness();
+ if (!witness && as<IRDifferentialPairType>(primalType))
+ {
+ witness = getDifferentialPairWitness(primalType);
+ }
return builder.getDifferentialPairType(
(IRType*)primalType,
witness);
@@ -2205,29 +2199,41 @@ struct JVPDerivativeContext : public InstPassBase
bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
- // Hoist all pair types to global scope when possible.
+ // Hoist and deduplicate all pair types to global scope when possible.
+ // This avoids emitting different struct types for equivalent pair types.
auto moduleInst = module->getModuleInst();
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
- {
- if (originalPairType->parent != moduleInst)
+ Dictionary<IRInst*, IRInst*> diffPairTypes;
+ for (;;)
+ {
+ bool changed = false;
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType)
{
- originalPairType->removeFromParent();
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ IRInst* finalType = nullptr;
+ if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType))
{
- operands.add(originalPairType->getOperand(i));
+ if (finalType != originalPairType)
+ {
+ originalPairType->replaceUsesWith(finalType);
+ originalPairType->removeAndDeallocate();
+ changed = true;
+ return;
+ }
}
- auto newPairType = builder->findOrEmitHoistableInst(
- originalPairType->getFullType(),
- originalPairType->getOp(),
- originalPairType->getOperandCount(),
- operands.getArrayView().getBuffer());
- originalPairType->replaceUsesWith(newPairType);
- originalPairType->removeAndDeallocate();
- }
- });
-
- sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+ diffPairTypes[originalPairType->getValueType()] = originalPairType;
+ if (originalPairType->parent != moduleInst)
+ {
+ if (originalPairType->getValueType()->getParent() != originalPairType->getParent())
+ {
+ originalPairType->insertAfter(originalPairType->getValueType());
+ changed = true;
+ return;
+ }
+ }
+ });
+ if (!changed)
+ break;
+ }
processAllInsts([&](IRInst* inst)
{