diff options
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 451 |
1 files changed, 170 insertions, 281 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 3d02d4fc0..d0bf8f347 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -6,6 +6,7 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" // origX, primalX, diffX // origX -> primalX (cloneEnv) @@ -26,11 +27,9 @@ struct Pair typedef Pair<IRInst*, IRInst*> InstPair; -struct DifferentiableTypeConformanceContext +struct AutoDiffSharedContext { - Dictionary<IRInst*, IRInst*> witnessTableMap; - - IRInst* inst = nullptr; + IRModuleInst* moduleInst = nullptr; // A reference to the builtin IDifferentiable interface type. // We use this to look up all the other types (and type exprs) @@ -62,114 +61,27 @@ struct DifferentiableTypeConformanceContext // bool isInterfaceAvailable = false; - // For handling generic blocks, we use a parent pointer to allow - // looking up types in all relevant scopes. - DifferentiableTypeConformanceContext* parent = nullptr; - DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst) + AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) { - if (parent) + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) { - differentiableInterfaceType = parent->differentiableInterfaceType; - differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey; - zeroMethodStructKey = parent->zeroMethodStructKey; - addMethodStructKey = parent->addMethodStructKey; - - isInterfaceAvailable = parent->isInterfaceAvailable; - } - else - { - differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } - } - } - - DifferentiableTypeConformanceContext(IRInst* inst) : - DifferentiableTypeConformanceContext(nullptr, inst) - {} + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type) - { - SLANG_ASSERT(isInterfaceAvailable); - // TODO: Cache the returned value to avoid repeatedly scanning through - // blocks looking for the type entries. - // - if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent())) - { - return irWitness; - } - - return nullptr; - } - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) - { - if (auto conformance = lookUpConformanceForType(builder, origType)) - { - if (auto witnessTable = as<IRWitnessTable>(conformance)) - { - for (auto entry : witnessTable->getEntries()) - { - if (entry->getRequirementKey() == key) - return entry->getSatisfyingVal(); - } - } - else if (auto witnessTableParam = as<IRParam>(conformance)) - { - return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), - witnessTableParam, - key); - } - } - - return nullptr; - } - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; } - return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, addMethodStructKey); } private: IRInst* findDifferentiableInterface() { - if (auto module = as<IRModuleInst>(inst)) + if (auto module = as<IRModuleInst>(moduleInst)) { for (auto globalInst : module->getGlobalInsts()) { @@ -203,7 +115,7 @@ struct DifferentiableTypeConformanceContext IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) { - if (as<IRModuleInst>(inst) && differentiableInterfaceType) + if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) { // Assume for now that IDifferentiable has exactly four fields. SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); @@ -217,110 +129,126 @@ struct DifferentiableTypeConformanceContext return nullptr; } +}; - void loadWitnessTablesForInterface(IRInst* interfaceType) +namespace +{ + +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +{ + if (auto witnessTable = as<IRWitnessTable>(witness)) { - - if (auto module = as<IRModuleInst>(inst)) + for (auto entry : witnessTable->getEntries()) { - for (auto globalInst : module->getGlobalInsts()) - { - if (globalInst->getOp() == kIROp_WitnessTable && - cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() == - interfaceType) - { - // TODO: Can we have multiple conformances for the same pair of types? - // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can - // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)' - witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst); - } - } + if (entry->getRequirementKey() == requirementKey) + return entry->getSatisfyingVal(); } - else if (auto generic = as<IRGeneric>(inst)) - { - List<IRParam*> typeParams; + } + else if (auto witnessTableParam = as<IRParam>(witness)) + { + return builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + witnessTableParam, + requirementKey); + } + return nullptr; +} + +} + +struct DifferentiableTypeConformanceContext +{ + AutoDiffSharedContext* sharedContext; + + IRGlobalValueWithCode* parentFunc = nullptr; + Dictionary<IRType*, IRInst*> differentiableWitnessDictionary; + + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) + : sharedContext(shared) + {} + + void setFunc(IRGlobalValueWithCode* func) + { + parentFunc = func; + + auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); + SLANG_RELEASE_ASSERT(decor); - auto genericParam = generic->getFirstParam(); - while (genericParam) + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) { - if (as<IRTypeType>(genericParam->getDataType())) + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) { - typeParams.add(genericParam); + if (auto witness = as<IRWitnessTable>(item->getWitness())) + { + if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) + continue; + } + *existingItem = item->getWitness(); } else - break; - - genericParam = genericParam->getNextParam(); - } - - Count tableIndex = 0; - while (genericParam) - { - SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType())); - - if (tableIndex >= typeParams.getCount()) - break; - - if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType())) { - // TODO(sai): Heavily flawed way to find the right witness table. - // Rewrite this part - if (witnessTableType->getConformanceType() == differentiableInterfaceType) - witnessTableMap.Add(typeParams[tableIndex], genericParam); + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); } - else - break; - - tableIndex += 1; - genericParam = genericParam->getNextParam(); } - } - } -}; - -IRInst* findGlobal(IRInst* inst) -{ - if (inst->getParent() != inst->getModule()->getModuleInst()) + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type) { - return findGlobal(inst->getParent()); + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; } - return inst; -} - -void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst) -{ - HashSet<IRInst*> globalsOfUses; - for (auto use = globalInst->firstUse; use; use = use->nextUse) + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { - globalsOfUses.Add(findGlobal(use->getUser())); + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; } - IRInst* earliestUse = nullptr; - for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst()) - { - if (globalsOfUses.Contains(cursor)) + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + switch (origType->getOp()) { - earliestUse = cursor; + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; } + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); } - if (earliestUse) + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - globalInst->insertBefore(earliestUse); + return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); } -} + +}; struct DifferentialPairTypeBuilder { - - DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) : - diffConformanceContext(diffConformanceContext) - {} IRStructField* findField(IRInst* type, IRStructKey* key) { @@ -454,14 +382,6 @@ struct DifferentialPairTypeBuilder return emitFieldAccessor(builder, baseInst, this->globalDiffKey); } - void relocateNewTypes(IRBuilder* builder) - { - for (auto typeInst : generatedTypeList) - { - moveGlobalToBeforeUses(builder, typeInst); - } - } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) { if (!this->globalDiffKey) @@ -496,27 +416,23 @@ struct DifferentialPairTypeBuilder return this->globalPrimalKey; } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) { - if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) - { - SLANG_ASSERT(!as<IRParam>(origBaseType)); - - auto pairStructType = builder->createStructType(); - builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); - builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType); + SLANG_ASSERT(!as<IRParam>(origBaseType)); + SLANG_ASSERT(diffType); + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); - return pairStructType; - } - return nullptr; + return pairStructType; } - IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) + IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) { if (pairTypeCache.ContainsKey(origBaseType)) return pairTypeCache[origBaseType]; - auto pairType = _createDiffPairType(builder, origBaseType); + auto pairType = _createDiffPairType(builder, origBaseType, diffType); pairTypeCache.Add(origBaseType, pairType); return pairType; @@ -524,8 +440,6 @@ struct DifferentialPairTypeBuilder Dictionary<IRInst*, IRInst*> pairTypeCache; - DifferentiableTypeConformanceContext* diffConformanceContext; - IRStructKey* globalPrimalKey = nullptr; IRStructKey* globalDiffKey = nullptr; @@ -553,11 +467,17 @@ struct JVPTranscriber DiagnosticSink* sink; // Type conformance information. - DifferentiableTypeConformanceContext* diffConformanceContext; + AutoDiffSharedContext* autoDiffSharedContext; // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct DifferentialPairTypeBuilder* pairBuilder; + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + JVPTranscriber(AutoDiffSharedContext* shared) + : differentiableTypeConformanceContext(shared) + {} + DiagnosticSink* getSink() { SLANG_ASSERT(sink); @@ -692,7 +612,7 @@ struct JVPTranscriber { case kIROp_Param: if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(diffConformanceContext->getDifferentialForType( + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( builder, (IRType*)primalType)); else if (as<IRWitnessTableType>(primalType->getDataType())) @@ -737,7 +657,7 @@ struct JVPTranscriber } default: - return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType)); + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); } } @@ -753,8 +673,10 @@ struct JVPTranscriber else return nullptr; } - - return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType); + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType); + return nullptr; } InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) @@ -1325,7 +1247,7 @@ struct JVPTranscriber { // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). - auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType); + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); SLANG_ASSERT(zeroMethod); auto emptyArgList = List<IRInst*>(); @@ -1333,6 +1255,11 @@ struct JVPTranscriber } else { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + getSink()->diagnose(primalType->sourceLoc, Diagnostics::internalCompilerError, "could not generate zero value for given type"); @@ -1359,17 +1286,6 @@ struct JVPTranscriber for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) this->transcribe(builder, param); - // Look for the differentiable type dictionary and clone it (and anything else we might need). - // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement) - // that are operands. - // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks. - if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock)) - { - auto clonedDict = cloneInst(&cloneEnv, builder, origDict); - mapPrimalInst(origDict, clonedDict); - mapDifferentialInst(origDict, clonedDict); - } - // Then, run through every instruction and use the transcriber to generate the appropriate // derivative code. // @@ -1547,6 +1463,8 @@ struct JVPTranscriber { IRFunc* primalFunc = nullptr; + differentiableTypeConformanceContext.setFunc(origFunc); + auto oldLoc = builder->getInsertLoc(); // If this is a top-level function, there is no need to clone it @@ -1602,6 +1520,16 @@ struct JVPTranscriber // Transcribe a generic definition InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric) { + auto innerVal = findInnerMostGenericReturnVal(origGeneric); + if (auto innerFunc = as<IRFunc>(innerVal)) + { + differentiableTypeConformanceContext.setFunc(innerFunc); + } + else + { + return InstPair(origGeneric, nullptr); + } + // For now, we assume there's only one generic layer. So this inst must be top level bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr); SLANG_RELEASE_ASSERT(isTopLevel); @@ -1757,10 +1685,6 @@ struct JVPTranscriber case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); - case kIROp_DifferentiableTypeDictionary: - // Ignore dictionary insts. - return InstPair(nullptr, nullptr); - } // If none of the cases have been hit, check if the instruction is a @@ -1885,11 +1809,8 @@ struct JVPDerivativeContext // IRDifferentialPairGetPrimal with 'primal' field access, and // IRMakeDifferentialPair with an IRMakeStruct. // - modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage)); + modified |= processPairTypes(builder, module->getModuleInst()); - // Temporary fix: Move generated types, if any, to before their use locations. - (&pairBuilderStorage)->relocateNewTypes(builder); - return modified; } @@ -1981,7 +1902,7 @@ struct JVPDerivativeContext return true; } - IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*) + IRInst* lowerPairType(IRBuilder* builder, IRType* type) { if (auto pairType = as<IRDifferentialPairType>(type)) @@ -1990,13 +1911,18 @@ struct JVPDerivativeContext if (!as<IRType>(pairType->getValueType())) { - // Do not handle non-concrete types. return nullptr; } - + auto witness = pairType->getWitness(); + auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey); + if (!diffType) + { + return nullptr; + } auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( builder, - pairType->getValueType()); + pairType->getValueType(), + (IRType*)(diffType)); pairType->replaceUsesWith(diffPairStructType); pairType->removeAndDeallocate(); @@ -2017,12 +1943,12 @@ struct JVPDerivativeContext return nullptr; } - IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext)) + if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType())) { builder->setInsertBefore(makePairInst); @@ -2041,11 +1967,11 @@ struct JVPDerivativeContext return nullptr; } - IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) { if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext)) + if (lowerPairType(builder, getDiffInst->getBase()->getDataType())) { builder->setInsertBefore(getDiffInst); @@ -2057,7 +1983,7 @@ struct JVPDerivativeContext } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext)) + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType())) { builder->setInsertBefore(getPrimalInst); @@ -2072,16 +1998,10 @@ struct JVPDerivativeContext return nullptr; } - bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext) + bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; - // Create a new sub-context to scan witness tables inside workItem - // (mainly relevant if instWithChildren is a generic scope) - // - auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren); - (&pairBuilderStorage)->diffConformanceContext = (&subContext); - for (auto child = instWithChildren->getFirstChild(); child; ) { // Make sure the builder is at the right level. @@ -2092,53 +2012,21 @@ struct JVPDerivativeContext switch (child->getOp()) { case kIROp_DifferentialPairType: - lowerPairType(builder, as<IRType>(child), &subContext); + lowerPairType(builder, as<IRType>(child)); break; case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, child, &subContext); + lowerPairAccess(builder, child); break; case kIROp_MakeDifferentialPair: - lowerMakePair(builder, child, &subContext); + lowerMakePair(builder, child); break; default: if (child->getFirstChild()) - modified = processPairTypes(builder, child, (&subContext)) | modified; - } - - child = nextChild; - } - - // Reset the context back to the parent. - (&pairBuilderStorage)->diffConformanceContext = diffContext; - - return modified; - } - - bool stripDiffTypeInformation(IRInst* parent) - { - bool modified = false; - - auto child = parent->getFirstChild(); - while (child) - { - auto nextChild = child->getNextInst(); - - switch (child->getOp()) - { - case kIROp_DifferentiableTypeDictionary: - child->removeAndDeallocate(); - child = nextChild; - modified = true; - continue; - } - - if (child->getFirstChild() != nullptr) - { - modified |= stripDiffTypeInformation(child); + modified = processPairTypes(builder, child) | modified; } child = nextChild; @@ -2186,12 +2074,13 @@ struct JVPDerivativeContext } JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - module(module), sink(sink), - diffConformanceContextStorage(module->getModuleInst()), - pairBuilderStorage(&diffConformanceContextStorage) + module(module), + sink(sink), + autoDiffSharedContextStorage(module->getModuleInst()), + transcriberStorage(&autoDiffSharedContextStorage) { transcriberStorage.sink = sink; - transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage); + transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); transcriberStorage.pairBuilder = &(pairBuilderStorage); } @@ -2221,7 +2110,7 @@ struct JVPDerivativeContext // Context to find and manage the witness tables for types // implementing `IDifferentiable` - DifferentiableTypeConformanceContext diffConformanceContextStorage; + AutoDiffSharedContext autoDiffSharedContextStorage; // Builder for dealing with differential pair types. DifferentialPairTypeBuilder pairBuilderStorage; @@ -2243,7 +2132,6 @@ bool processForwardDifferentiableFuncs( JVPDerivativeContext context(module, sink); bool changed = context.processModule(); - changed |= context.stripDiffTypeInformation(module->getModuleInst()); return changed; } @@ -2258,6 +2146,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) { case kIROp_ForwardDerivativeDecoration: case kIROp_DerivativeMemberDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: decor->removeAndDeallocate(); break; default: |
