diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 94 |
1 files changed, 50 insertions, 44 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 3135f300d..574db2036 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -681,17 +681,6 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) { @@ -699,20 +688,23 @@ struct JVPTranscriber builder.setInsertInto(inDiffPairType->parent); auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); SLANG_ASSERT(diffPairType); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - if (result) - return result; - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); auto diffType = differentiateType(&builder, diffPairType->getValueType()); - auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; + IRInst* tableInst = nullptr; + if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst)) + { + IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + // The witness that `diffType` + auto differentialType = builder.getDifferentialPairType( + diffType, + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType] + .GetValue()); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + tableInst = table; + } + return as<IRWitnessTable>(tableInst); } IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) @@ -730,8 +722,10 @@ struct JVPTranscriber builder.setInsertInto(primalType->parent); auto witness = as<IRWitnessTable>( differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness) - witness = getDifferentialBottomWitness(); + if (!witness && as<IRDifferentialPairType>(primalType)) + { + witness = getDifferentialPairWitness(primalType); + } return builder.getDifferentialPairType( (IRType*)primalType, witness); @@ -2205,29 +2199,41 @@ struct JVPDerivativeContext : public InstPassBase bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Hoist all pair types to global scope when possible. + // Hoist and deduplicate all pair types to global scope when possible. + // This avoids emitting different struct types for equivalent pair types. auto moduleInst = module->getModuleInst(); - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) - { - if (originalPairType->parent != moduleInst) + Dictionary<IRInst*, IRInst*> diffPairTypes; + for (;;) + { + bool changed = false; + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType) { - originalPairType->removeFromParent(); - ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + IRInst* finalType = nullptr; + if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType)) { - operands.add(originalPairType->getOperand(i)); + if (finalType != originalPairType) + { + originalPairType->replaceUsesWith(finalType); + originalPairType->removeAndDeallocate(); + changed = true; + return; + } } - auto newPairType = builder->findOrEmitHoistableInst( - originalPairType->getFullType(), - originalPairType->getOp(), - originalPairType->getOperandCount(), - operands.getArrayView().getBuffer()); - originalPairType->replaceUsesWith(newPairType); - originalPairType->removeAndDeallocate(); - } - }); - - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + diffPairTypes[originalPairType->getValueType()] = originalPairType; + if (originalPairType->parent != moduleInst) + { + if (originalPairType->getValueType()->getParent() != originalPairType->getParent()) + { + originalPairType->insertAfter(originalPairType->getValueType()); + changed = true; + return; + } + } + }); + if (!changed) + break; + } processAllInsts([&](IRInst* inst) { |
