summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp451
1 files changed, 170 insertions, 281 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 3d02d4fc0..d0bf8f347 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
// origX, primalX, diffX
// origX -> primalX (cloneEnv)
@@ -26,11 +27,9 @@ struct Pair
typedef Pair<IRInst*, IRInst*> InstPair;
-struct DifferentiableTypeConformanceContext
+struct AutoDiffSharedContext
{
- Dictionary<IRInst*, IRInst*> witnessTableMap;
-
- IRInst* inst = nullptr;
+ IRModuleInst* moduleInst = nullptr;
// A reference to the builtin IDifferentiable interface type.
// We use this to look up all the other types (and type exprs)
@@ -62,114 +61,27 @@ struct DifferentiableTypeConformanceContext
//
bool isInterfaceAvailable = false;
- // For handling generic blocks, we use a parent pointer to allow
- // looking up types in all relevant scopes.
- DifferentiableTypeConformanceContext* parent = nullptr;
- DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst)
+ AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
{
- if (parent)
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
{
- differentiableInterfaceType = parent->differentiableInterfaceType;
- differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
- zeroMethodStructKey = parent->zeroMethodStructKey;
- addMethodStructKey = parent->addMethodStructKey;
-
- isInterfaceAvailable = parent->isInterfaceAvailable;
- }
- else
- {
- differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
- if (differentiableInterfaceType)
- {
- differentialAssocTypeStructKey = findDifferentialTypeStructKey();
- zeroMethodStructKey = findZeroMethodStructKey();
- addMethodStructKey = findAddMethodStructKey();
-
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
- }
- }
- }
-
- DifferentiableTypeConformanceContext(IRInst* inst) :
- DifferentiableTypeConformanceContext(nullptr, inst)
- {}
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
- // Lookup a witness table for the concreteType. One should exist if concreteType
- // inherits (successfully) from IDifferentiable.
- //
- IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type)
- {
- SLANG_ASSERT(isInterfaceAvailable);
- // TODO: Cache the returned value to avoid repeatedly scanning through
- // blocks looking for the type entries.
- //
- if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent()))
- {
- return irWitness;
- }
-
- return nullptr;
- }
-
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
- {
- if (auto conformance = lookUpConformanceForType(builder, origType))
- {
- if (auto witnessTable = as<IRWitnessTable>(conformance))
- {
- for (auto entry : witnessTable->getEntries())
- {
- if (entry->getRequirementKey() == key)
- return entry->getSatisfyingVal();
- }
- }
- else if (auto witnessTableParam = as<IRParam>(conformance))
- {
- return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
- witnessTableParam,
- key);
- }
- }
-
- return nullptr;
- }
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- switch (origType->getOp())
- {
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return origType;
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
}
- return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
- }
-
- IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey);
- }
-
- IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, addMethodStructKey);
}
private:
IRInst* findDifferentiableInterface()
{
- if (auto module = as<IRModuleInst>(inst))
+ if (auto module = as<IRModuleInst>(moduleInst))
{
for (auto globalInst : module->getGlobalInsts())
{
@@ -203,7 +115,7 @@ struct DifferentiableTypeConformanceContext
IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
{
- if (as<IRModuleInst>(inst) && differentiableInterfaceType)
+ if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
{
// Assume for now that IDifferentiable has exactly four fields.
SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
@@ -217,110 +129,126 @@ struct DifferentiableTypeConformanceContext
return nullptr;
}
+};
- void loadWitnessTablesForInterface(IRInst* interfaceType)
+namespace
+{
+
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+{
+ if (auto witnessTable = as<IRWitnessTable>(witness))
{
-
- if (auto module = as<IRModuleInst>(inst))
+ for (auto entry : witnessTable->getEntries())
{
- for (auto globalInst : module->getGlobalInsts())
- {
- if (globalInst->getOp() == kIROp_WitnessTable &&
- cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() ==
- interfaceType)
- {
- // TODO: Can we have multiple conformances for the same pair of types?
- // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can
- // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)'
- witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst);
- }
- }
+ if (entry->getRequirementKey() == requirementKey)
+ return entry->getSatisfyingVal();
}
- else if (auto generic = as<IRGeneric>(inst))
- {
- List<IRParam*> typeParams;
+ }
+ else if (auto witnessTableParam = as<IRParam>(witness))
+ {
+ return builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ witnessTableParam,
+ requirementKey);
+ }
+ return nullptr;
+}
+
+}
+
+struct DifferentiableTypeConformanceContext
+{
+ AutoDiffSharedContext* sharedContext;
+
+ IRGlobalValueWithCode* parentFunc = nullptr;
+ Dictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+
+ DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
+ : sharedContext(shared)
+ {}
+
+ void setFunc(IRGlobalValueWithCode* func)
+ {
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
- auto genericParam = generic->getFirstParam();
- while (genericParam)
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
{
- if (as<IRTypeType>(genericParam->getDataType()))
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
{
- typeParams.add(genericParam);
+ if (auto witness = as<IRWitnessTable>(item->getWitness()))
+ {
+ if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
+ continue;
+ }
+ *existingItem = item->getWitness();
}
else
- break;
-
- genericParam = genericParam->getNextParam();
- }
-
- Count tableIndex = 0;
- while (genericParam)
- {
- SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
-
- if (tableIndex >= typeParams.getCount())
- break;
-
- if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
{
- // TODO(sai): Heavily flawed way to find the right witness table.
- // Rewrite this part
- if (witnessTableType->getConformanceType() == differentiableInterfaceType)
- witnessTableMap.Add(typeParams[tableIndex], genericParam);
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
}
- else
- break;
-
- tableIndex += 1;
- genericParam = genericParam->getNextParam();
}
-
}
-
}
-};
-
-IRInst* findGlobal(IRInst* inst)
-{
- if (inst->getParent() != inst->getModule()->getModuleInst())
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type)
{
- return findGlobal(inst->getParent());
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
}
- return inst;
-}
-
-void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst)
-{
- HashSet<IRInst*> globalsOfUses;
- for (auto use = globalInst->firstUse; use; use = use->nextUse)
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
- globalsOfUses.Add(findGlobal(use->getUser()));
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
}
- IRInst* earliestUse = nullptr;
- for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst())
- {
- if (globalsOfUses.Contains(cursor))
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ switch (origType->getOp())
{
- earliestUse = cursor;
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
}
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
}
- if (earliestUse)
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
{
- globalInst->insertBefore(earliestUse);
+ return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
}
-}
+
+};
struct DifferentialPairTypeBuilder
{
-
- DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) :
- diffConformanceContext(diffConformanceContext)
- {}
IRStructField* findField(IRInst* type, IRStructKey* key)
{
@@ -454,14 +382,6 @@ struct DifferentialPairTypeBuilder
return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
}
- void relocateNewTypes(IRBuilder* builder)
- {
- for (auto typeInst : generatedTypeList)
- {
- moveGlobalToBeforeUses(builder, typeInst);
- }
- }
-
IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
{
if (!this->globalDiffKey)
@@ -496,27 +416,23 @@ struct DifferentialPairTypeBuilder
return this->globalPrimalKey;
}
- IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
{
- if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
- {
- SLANG_ASSERT(!as<IRParam>(origBaseType));
-
- auto pairStructType = builder->createStructType();
- builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
- builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType);
+ SLANG_ASSERT(!as<IRParam>(origBaseType));
+ SLANG_ASSERT(diffType);
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
- return pairStructType;
- }
- return nullptr;
+ return pairStructType;
}
- IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
{
if (pairTypeCache.ContainsKey(origBaseType))
return pairTypeCache[origBaseType];
- auto pairType = _createDiffPairType(builder, origBaseType);
+ auto pairType = _createDiffPairType(builder, origBaseType, diffType);
pairTypeCache.Add(origBaseType, pairType);
return pairType;
@@ -524,8 +440,6 @@ struct DifferentialPairTypeBuilder
Dictionary<IRInst*, IRInst*> pairTypeCache;
- DifferentiableTypeConformanceContext* diffConformanceContext;
-
IRStructKey* globalPrimalKey = nullptr;
IRStructKey* globalDiffKey = nullptr;
@@ -553,11 +467,17 @@ struct JVPTranscriber
DiagnosticSink* sink;
// Type conformance information.
- DifferentiableTypeConformanceContext* diffConformanceContext;
+ AutoDiffSharedContext* autoDiffSharedContext;
// Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
DifferentialPairTypeBuilder* pairBuilder;
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ JVPTranscriber(AutoDiffSharedContext* shared)
+ : differentiableTypeConformanceContext(shared)
+ {}
+
DiagnosticSink* getSink()
{
SLANG_ASSERT(sink);
@@ -692,7 +612,7 @@ struct JVPTranscriber
{
case kIROp_Param:
if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(diffConformanceContext->getDifferentialForType(
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
builder,
(IRType*)primalType));
else if (as<IRWitnessTableType>(primalType->getDataType()))
@@ -737,7 +657,7 @@ struct JVPTranscriber
}
default:
- return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType));
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
}
}
@@ -753,8 +673,10 @@ struct JVPTranscriber
else
return nullptr;
}
-
- return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType);
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType);
+ return nullptr;
}
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
@@ -1325,7 +1247,7 @@ struct JVPTranscriber
{
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
- auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType);
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
SLANG_ASSERT(zeroMethod);
auto emptyArgList = List<IRInst*>();
@@ -1333,6 +1255,11 @@ struct JVPTranscriber
}
else
{
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
getSink()->diagnose(primalType->sourceLoc,
Diagnostics::internalCompilerError,
"could not generate zero value for given type");
@@ -1359,17 +1286,6 @@ struct JVPTranscriber
for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
this->transcribe(builder, param);
- // Look for the differentiable type dictionary and clone it (and anything else we might need).
- // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement)
- // that are operands.
- // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks.
- if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock))
- {
- auto clonedDict = cloneInst(&cloneEnv, builder, origDict);
- mapPrimalInst(origDict, clonedDict);
- mapDifferentialInst(origDict, clonedDict);
- }
-
// Then, run through every instruction and use the transcriber to generate the appropriate
// derivative code.
//
@@ -1547,6 +1463,8 @@ struct JVPTranscriber
{
IRFunc* primalFunc = nullptr;
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
auto oldLoc = builder->getInsertLoc();
// If this is a top-level function, there is no need to clone it
@@ -1602,6 +1520,16 @@ struct JVPTranscriber
// Transcribe a generic definition
InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric)
{
+ auto innerVal = findInnerMostGenericReturnVal(origGeneric);
+ if (auto innerFunc = as<IRFunc>(innerVal))
+ {
+ differentiableTypeConformanceContext.setFunc(innerFunc);
+ }
+ else
+ {
+ return InstPair(origGeneric, nullptr);
+ }
+
// For now, we assume there's only one generic layer. So this inst must be top level
bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr);
SLANG_RELEASE_ASSERT(isTopLevel);
@@ -1757,10 +1685,6 @@ struct JVPTranscriber
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
- case kIROp_DifferentiableTypeDictionary:
- // Ignore dictionary insts.
- return InstPair(nullptr, nullptr);
-
}
// If none of the cases have been hit, check if the instruction is a
@@ -1885,11 +1809,8 @@ struct JVPDerivativeContext
// IRDifferentialPairGetPrimal with 'primal' field access, and
// IRMakeDifferentialPair with an IRMakeStruct.
//
- modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));
+ modified |= processPairTypes(builder, module->getModuleInst());
- // Temporary fix: Move generated types, if any, to before their use locations.
- (&pairBuilderStorage)->relocateNewTypes(builder);
-
return modified;
}
@@ -1981,7 +1902,7 @@ struct JVPDerivativeContext
return true;
}
- IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*)
+ IRInst* lowerPairType(IRBuilder* builder, IRType* type)
{
if (auto pairType = as<IRDifferentialPairType>(type))
@@ -1990,13 +1911,18 @@ struct JVPDerivativeContext
if (!as<IRType>(pairType->getValueType()))
{
- // Do not handle non-concrete types.
return nullptr;
}
-
+ auto witness = pairType->getWitness();
+ auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey);
+ if (!diffType)
+ {
+ return nullptr;
+ }
auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
builder,
- pairType->getValueType());
+ pairType->getValueType(),
+ (IRType*)(diffType));
pairType->replaceUsesWith(diffPairStructType);
pairType->removeAndDeallocate();
@@ -2017,12 +1943,12 @@ struct JVPDerivativeContext
return nullptr;
}
- IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext))
+ if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType()))
{
builder->setInsertBefore(makePairInst);
@@ -2041,11 +1967,11 @@ struct JVPDerivativeContext
return nullptr;
}
- IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
{
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext))
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType()))
{
builder->setInsertBefore(getDiffInst);
@@ -2057,7 +1983,7 @@ struct JVPDerivativeContext
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext))
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType()))
{
builder->setInsertBefore(getPrimalInst);
@@ -2072,16 +1998,10 @@ struct JVPDerivativeContext
return nullptr;
}
- bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext)
+ bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
- // Create a new sub-context to scan witness tables inside workItem
- // (mainly relevant if instWithChildren is a generic scope)
- //
- auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren);
- (&pairBuilderStorage)->diffConformanceContext = (&subContext);
-
for (auto child = instWithChildren->getFirstChild(); child; )
{
// Make sure the builder is at the right level.
@@ -2092,53 +2012,21 @@ struct JVPDerivativeContext
switch (child->getOp())
{
case kIROp_DifferentialPairType:
- lowerPairType(builder, as<IRType>(child), &subContext);
+ lowerPairType(builder, as<IRType>(child));
break;
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
- lowerPairAccess(builder, child, &subContext);
+ lowerPairAccess(builder, child);
break;
case kIROp_MakeDifferentialPair:
- lowerMakePair(builder, child, &subContext);
+ lowerMakePair(builder, child);
break;
default:
if (child->getFirstChild())
- modified = processPairTypes(builder, child, (&subContext)) | modified;
- }
-
- child = nextChild;
- }
-
- // Reset the context back to the parent.
- (&pairBuilderStorage)->diffConformanceContext = diffContext;
-
- return modified;
- }
-
- bool stripDiffTypeInformation(IRInst* parent)
- {
- bool modified = false;
-
- auto child = parent->getFirstChild();
- while (child)
- {
- auto nextChild = child->getNextInst();
-
- switch (child->getOp())
- {
- case kIROp_DifferentiableTypeDictionary:
- child->removeAndDeallocate();
- child = nextChild;
- modified = true;
- continue;
- }
-
- if (child->getFirstChild() != nullptr)
- {
- modified |= stripDiffTypeInformation(child);
+ modified = processPairTypes(builder, child) | modified;
}
child = nextChild;
@@ -2186,12 +2074,13 @@ struct JVPDerivativeContext
}
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
- module(module), sink(sink),
- diffConformanceContextStorage(module->getModuleInst()),
- pairBuilderStorage(&diffConformanceContextStorage)
+ module(module),
+ sink(sink),
+ autoDiffSharedContextStorage(module->getModuleInst()),
+ transcriberStorage(&autoDiffSharedContextStorage)
{
transcriberStorage.sink = sink;
- transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage);
+ transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
}
@@ -2221,7 +2110,7 @@ struct JVPDerivativeContext
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
- DifferentiableTypeConformanceContext diffConformanceContextStorage;
+ AutoDiffSharedContext autoDiffSharedContextStorage;
// Builder for dealing with differential pair types.
DifferentialPairTypeBuilder pairBuilderStorage;
@@ -2243,7 +2132,6 @@ bool processForwardDifferentiableFuncs(
JVPDerivativeContext context(module, sink);
bool changed = context.processModule();
- changed |= context.stripDiffTypeInformation(module->getModuleInst());
return changed;
}
@@ -2258,6 +2146,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
case kIROp_ForwardDerivativeDecoration:
case kIROp_DerivativeMemberDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
decor->removeAndDeallocate();
break;
default: