diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1177 |
1 files changed, 1027 insertions, 150 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4edd8eabe..7507e2fac 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -230,9 +230,6 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( } else if (auto specializedType = as<IRSpecialize>(pairType)) { - // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's - // type, emit the specialization type. - // auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase())); if (auto genericBasePairStructType = as<IRStructType>(genericType)) { @@ -263,14 +260,142 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor( return nullptr; } -IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) +bool isExistentialOrRuntimeInst(IRInst* inst) { - return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + if (auto lookup = as<IRLookupWitnessMethod>(inst)) + { + return isExistentialOrRuntimeInst(lookup->getWitnessTable()); + } + + return as<IRExtractExistentialType>(inst) || as<IRExtractExistentialWitnessTable>(inst) || + as<IRMakeExistential>(inst) || as<IRInterfaceType>(inst->getDataType()); } -IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) +bool isRuntimeType(IRType* type) { - return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + if (as<IRExtractExistentialType>(type)) + return true; + + if (auto lookup = as<IRLookupWitnessMethod>(type)) + { + return isExistentialOrRuntimeInst(lookup->getWitnessTable()); + } + + return false; +} + +IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type) +{ + if (auto lookupWitnessMethod = as<IRLookupWitnessMethod>(type)) + { + return lookupWitnessMethod->getWitnessTable(); + } + else if (auto extractExistentialType = as<IRExtractExistentialType>(type)) + { + return builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0)); + } + else + { + SLANG_UNEXPECTED("Unexpected existential type"); + } +} + +IRInst* getCacheKey(IRBuilder* builder, IRInst* primalType) +{ + if (auto lookupWitness = as<IRLookupWitnessMethod>(primalType)) + return lookupWitness->getRequirementKey(); + else if (auto extractExistentialType = as<IRExtractExistentialType>(primalType)) + { + auto interfaceType = extractExistentialType->getOperand(0)->getDataType(); + + // We will cache on the interface's this-type, since the interface type itself can be + // deallocated during the lowering process. + // + return builder->getThisType(interfaceType); + } + + return primalType; +} + +IRInst* DifferentialPairTypeBuilder::emitExistentialMakePair( + IRBuilder* builder, + IRInst* pairType, + IRInst* primalInst, + IRInst* diffInst) +{ + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)pairType); + + auto pairTypeKey = cast<IRLookupWitnessMethod>(pairType)->getRequirementKey(); + auto makePairKey = makePairKeyMap[pairTypeKey]; + + auto makePairMethod = builder->emitLookupInterfaceMethodInst( + makePairFuncTypeMap[makePairKey], + baseWitness, + makePairKey); + + List<IRInst*> args; + args.add(primalInst); + args.add(diffInst); + + auto makePairVal = builder->emitCallInst((IRType*)pairType, makePairMethod, args); + + return makePairVal; +} + +IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess( + IRBuilder* builder, + IRType* loweredPairType, + IRInst* baseInst) +{ + if (isRuntimeType(loweredPairType)) + { + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType); + + auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey(); + auto getPrimalKey = getPrimalKeyMap[pairTypeKey]; + + auto primalFieldMethod = builder->emitLookupInterfaceMethodInst( + getPrimalFuncTypeMap[getPrimalKey], + baseWitness, + getPrimalKey); + + auto primalFieldVal = + builder->emitCallInst(primalTypeMap[loweredPairType], primalFieldMethod, baseInst); + + return primalFieldVal; + } + else + { + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); + } +} + +IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess( + IRBuilder* builder, + IRType* loweredPairType, + IRInst* baseInst) +{ + if (isRuntimeType(loweredPairType)) + { + auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType); + + auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey(); + auto getDiffKey = getDiffKeyMap[pairTypeKey]; + + auto diffFieldMethod = builder->emitLookupInterfaceMethodInst( + getDiffFuncTypeMap[getDiffKey], + baseWitness, + getDiffKey); + + auto diffFieldVal = + builder->emitCallInst(diffTypeMap[loweredPairType], diffFieldMethod, baseInst); + + return diffFieldVal; + } + else + { + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); + } } IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey() @@ -307,6 +432,380 @@ IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey() return this->globalPrimalKey; } +IRInst* DifferentialPairTypeBuilder::getOrCreateCommonDiffPairInterface(IRBuilder* builder) +{ + if (!this->commonDiffPairInterface) + { + this->commonDiffPairInterface = builder->createInterfaceType(0, nullptr); + builder->addNameHintDecoration( + this->commonDiffPairInterface, + UnownedStringSlice("IDiffPair")); + } + + return this->commonDiffPairInterface; +} + +IRInst* DifferentialPairTypeBuilder::_createDiffPairInterfaceRequirement( + IRType* origBaseType, + IRType*) +{ + // We will create an interface requirement for the type's pair & then create implementations in + // all the implementing witness tables. + // + + IRBuilder builder(sharedContext->moduleInst); + + // Find the right interface to put the requirement in. + IRInterfaceType* interfaceType = nullptr; + + // Find the effective type to put in the requirement entry + // for the base type + // + IRType* requirementBaseType = nullptr; + + // Requirement key (only used for associated types) + // + IRInst* requirementKey = nullptr; + + // Add a name hint to the key. + StringBuilder nameBuilderReqKey; + nameBuilderReqKey << "DiffPair_Req_"; + + if (auto lookup = as<IRLookupWitnessMethod>(origBaseType)) + { + interfaceType = + cast<IRInterfaceType>(cast<IRWitnessTableType>(lookup->getWitnessTable()->getDataType()) + ->getConformanceType()); + + requirementBaseType = + cast<IRType>(findInterfaceRequirement(interfaceType, lookup->getRequirementKey())); + + requirementKey = lookup->getRequirementKey(); + + if (auto nameHint = lookup->getRequirementKey()->findDecoration<IRNameHintDecoration>()) + { + nameBuilderReqKey << nameHint->getName(); + } + else + { + nameBuilderReqKey << "unknown_assoc_type"; + } + } + else if (auto extractType = as<IRExtractExistentialType>(origBaseType)) + { + auto existentialType = extractType->getOperand(0); + interfaceType = cast<IRInterfaceType>(existentialType->getDataType()); + requirementBaseType = builder.getThisType(interfaceType); + + requirementKey = nullptr; + + if (auto nameHint = interfaceType->findDecoration<IRNameHintDecoration>()) + { + nameBuilderReqKey << nameHint->getName(); + } + else + { + nameBuilderReqKey << "unknown_interface_type"; + } + } + else + { + SLANG_UNEXPECTED("Unexpected type for differential pair interface requirement"); + } + + auto diffPairInterfaceType = + cast<IRInterfaceType>(getOrCreateCommonDiffPairInterface(&builder)); + + // Add 4 requirements to the interface: + // the associated pair type, getPrimal, getDiff & makePair + // + builder.setInsertInto(interfaceType); + IRStructKey* diffPairRequirementKey = builder.createStructKey(); + IRStructKey* getPrimalRequirementKey = builder.createStructKey(); + IRStructKey* getDiffRequirementKey = builder.createStructKey(); + IRStructKey* makePairRequirementKey = builder.createStructKey(); + + makePairKeyMap[diffPairRequirementKey] = makePairRequirementKey; + getPrimalKeyMap[diffPairRequirementKey] = getPrimalRequirementKey; + getDiffKeyMap[diffPairRequirementKey] = getDiffRequirementKey; + + List<IRInst*> entries; + + // Add all the old requirements to the new interface. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + entries.add(interfaceType->getOperand(i)); + + // + // Create the new requirement entries. + // + + { + // Create & insert the requirement key. + List<IRInterfaceType*> constraintTypes; + constraintTypes.add(diffPairInterfaceType); + auto entry = builder.createInterfaceRequirementEntry( + diffPairRequirementKey, + builder.getAssociatedType(constraintTypes.getArrayView())); + + builder.addNameHintDecoration(diffPairRequirementKey, nameBuilderReqKey.getUnownedSlice()); + entries.add(entry); + } + + { + // Create & insert the getPrimal requirement. + + List<IRType*> paramTypes; + List<IRInterfaceType*> paramConstraintTypes; + paramConstraintTypes.add(diffPairInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + auto entryFuncType = builder.getFuncType(paramTypes, requirementBaseType); + auto entry = + builder.createInterfaceRequirementEntry(getPrimalRequirementKey, entryFuncType); + + getPrimalFuncTypeMap[getPrimalRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getPrimal"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create & insert the getDiff requirement. + + List<IRType*> paramTypes; + List<IRInterfaceType*> paramConstraintTypes; + paramConstraintTypes.add(diffPairInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + List<IRInterfaceType*> resultConstraintTypes; + resultConstraintTypes.add(sharedContext->differentiableInterfaceType); + auto resultType = builder.getAssociatedType(resultConstraintTypes.getArrayView()); + + auto entryFuncType = builder.getFuncType(paramTypes, resultType); + auto entry = builder.createInterfaceRequirementEntry(getDiffRequirementKey, entryFuncType); + + getDiffFuncTypeMap[getDiffRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getDiff"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create & insert the makePair requirement. + + List<IRType*> paramTypes; + paramTypes.add(requirementBaseType); + + List<IRInterfaceType*> paramConstraintTypes; + paramConstraintTypes.add(sharedContext->differentiableInterfaceType); + paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView())); + + List<IRInterfaceType*> resultConstraintTypes; + resultConstraintTypes.add(diffPairInterfaceType); + auto entryFuncType = builder.getFuncType( + paramTypes, + builder.getAssociatedType(resultConstraintTypes.getArrayView())); + auto entry = builder.createInterfaceRequirementEntry(makePairRequirementKey, entryFuncType); + + makePairFuncTypeMap[makePairRequirementKey] = entryFuncType; + + StringBuilder entryNameBuilder; + entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_makePair"; + builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice()); + + entries.add(entry); + } + + { + // Create the new interface type. + + auto newInterfaceType = + builder.createInterfaceType(entries.getCount(), entries.getBuffer()); + + // Transfer decorations from the old interface to the new one. + interfaceType->transferDecorationsTo(newInterfaceType); + interfaceType->replaceUsesWith(newInterfaceType); + + // Replace the interface maps in the caches. + if (this->pairTypeCache.containsKey(interfaceType)) + this->pairTypeCache[newInterfaceType] = this->pairTypeCache[interfaceType]; + + if (this->existentialPairTypeCache.containsKey(interfaceType)) + this->existentialPairTypeCache[newInterfaceType] = + this->existentialPairTypeCache[interfaceType]; + + interfaceType->removeAndDeallocate(); + interfaceType = newInterfaceType; + } + + // + // Implement the requirements in all the witness tables. + // + + // Collect all witness tables of the given interfaceType. + List<IRWitnessTable*> concreteWitnessTables; + auto witnessTableType = builder.getWitnessTableType(interfaceType); + for (auto use = witnessTableType->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as<IRWitnessTable>(use->getUser())) + { + if (use->getUser()->getFullType() == witnessTableType) + concreteWitnessTables.add(witnessTable); + } + } + + DifferentiableTypeConformanceContext ctx(sharedContext); + ctx.buildGlobalWitnessDictionary(); + + for (auto concreteWitnessTable : concreteWitnessTables) + { + IRType* concretePrimalType = nullptr; + + // What requirement are we trying to satisfy? + if (as<IRThisType>(requirementBaseType)) + { + // For this types, we should lower the concrete type of the witness table itself. + concretePrimalType = concreteWitnessTable->getConcreteType(); + } + else if (as<IRAssociatedType>(requirementBaseType)) + { + // For associated types, look it up in the witness table. + concretePrimalType = + (IRType*)findWitnessTableEntry(concreteWitnessTable, requirementKey); + } + else + { + // We shouldn't see any other case here. + SLANG_UNEXPECTED("Unexpected requirement base type"); + } + + // Create the pair type. + auto witness = ctx.tryGetDifferentiableWitness( + &builder, + concretePrimalType, + DiffConformanceKind::Value); + + // Really should not see a case where the original interface is differentiable, but + // we can't find the witness table. + // + SLANG_ASSERT(witness); + + auto concretePairType = builder.getDifferentialPairType( + concretePrimalType, + witness); // TODO: Need to handle the other conformance kinds + auto concreteDiffType = + (IRType*)_getDiffTypeFromPairType(sharedContext, &builder, concretePairType); + + auto loweredStructType = (IRType*)lowerDiffPairType(&builder, concretePairType); + + // Create an (empty) witness table for loweredStuctType : IDiffPair_... + // This is just so that there is a bound on the any-value-size for each group of pair types. + // + auto witnessTable = builder.createWitnessTable(diffPairInterfaceType, loweredStructType); + builder.addKeepAliveDecoration(witnessTable); + + builder.setInsertInto(concreteWitnessTable); + + // Create the associated type entry. + { + builder.createWitnessTableEntry( + concreteWitnessTable, + diffPairRequirementKey, + loweredStructType); + } + + // Create the getPrimal method. + { + auto primalMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_getPrimal"; + builder.addNameHintDecoration(primalMethod, nameBuilder.getUnownedSlice()); + + primalMethod->setFullType(builder.getFuncType( + List<IRType*>({(IRType*)loweredStructType}), + concretePrimalType)); + + builder.setInsertInto(primalMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param = builder.emitParam((IRType*)loweredStructType); + builder.emitReturn( + builder.emitFieldExtract(concretePrimalType, param, _getOrCreatePrimalStructKey())); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + getPrimalRequirementKey, + primalMethod); + } + + // Create the getDiff method. + { + auto diffMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_getDiff"; + builder.addNameHintDecoration(diffMethod, nameBuilder.getUnownedSlice()); + + diffMethod->setFullType( + builder.getFuncType(List<IRType*>({(IRType*)loweredStructType}), concreteDiffType)); + + builder.setInsertInto(diffMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param = builder.emitParam((IRType*)loweredStructType); + builder.emitReturn( + builder.emitFieldExtract(concreteDiffType, param, _getOrCreateDiffStructKey())); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + getDiffRequirementKey, + diffMethod); + } + + // Create the makePair method. + { + auto makePairMethod = builder.createFunc(); + + StringBuilder nameBuilder; + getTypeNameHint(nameBuilder, loweredStructType); + nameBuilder << "_makePair"; + builder.addNameHintDecoration(makePairMethod, nameBuilder.getUnownedSlice()); + + makePairMethod->setFullType(builder.getFuncType( + List<IRType*>({concretePrimalType, concreteDiffType}), + (IRType*)loweredStructType)); + + builder.setInsertInto(makePairMethod); + auto block = builder.emitBlock(); + builder.setInsertInto(block); + auto param1 = builder.emitParam(concretePrimalType); + auto param2 = builder.emitParam(concreteDiffType); + List<IRInst*> args = {param1, param2}; + auto pair = builder.emitMakeStruct((IRType*)loweredStructType, args); + builder.emitReturn(pair); + + builder.setInsertInto(concreteWitnessTable); + builder.createWitnessTableEntry( + concreteWitnessTable, + makePairRequirementKey, + makePairMethod); + } + } + + return diffPairRequirementKey; +} + IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType) { switch (origBaseType->getOp()) @@ -333,6 +832,7 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I return pairStructType; } + IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) { IRInst* result = nullptr; @@ -352,26 +852,119 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRTyp // purposes. auto primalType = pairType->getValueType(); - if (pairTypeCache.tryGetValue(primalType, result)) - return result; - if (!pairType) + + if (isRuntimeType(primalType)) { - result = originalPairType; + // Existential case. + auto cacheKey = getCacheKey(builder, primalType); + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); + + IRInst* pairReqKey = nullptr; + if (!existentialPairTypeCache.tryGetValue(cacheKey, pairReqKey)) + { + pairReqKey = _createDiffPairInterfaceRequirement(primalType, (IRType*)diffType); + existentialPairTypeCache.add(cacheKey, pairReqKey); + } + + auto baseWitnessTable = getExistentialBaseWitnessTable(builder, primalType); + result = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + baseWitnessTable, + pairReqKey); + + primalTypeMap[result] = primalType; + diffTypeMap[result] = (IRType*)diffType; + return result; } - if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType)) + else if (auto typePack = as<IRTypePack>(primalType)) { - result = nullptr; - return result; + // Lower DiffPair(TypePack(a_0, a_1, ...), MakeWitnessPack(w_0, w_1, ...)) as + // TypePack(DiffPair(a_0, w_0), DiffPair(a_1, w_1), ...) + // + auto cacheKey = primalType; + if (pairTypeCache.tryGetValue(cacheKey, result)) + return result; + + auto packWitness = pairType->getWitness(); + + // Right now we only support concrete witness tables for type packs. + auto concretePackWitness = as<IRWitnessTable>(packWitness); + SLANG_ASSERT(concretePackWitness); + + // Get diff type pack. + IRTypePack* diffTypePack = nullptr; + + if (concretePackWitness->getConformanceType() == + this->sharedContext->differentiableInterfaceType) + diffTypePack = as<IRTypePack>(findWitnessTableEntry( + concretePackWitness, + this->sharedContext->differentialAssocTypeStructKey)); + else if ( + concretePackWitness->getConformanceType() == + this->sharedContext->differentiablePtrInterfaceType) + diffTypePack = as<IRTypePack>(findWitnessTableEntry( + concretePackWitness, + this->sharedContext->differentialAssocRefTypeStructKey)); + else + SLANG_UNEXPECTED("Unexpected witness table"); + + SLANG_ASSERT(diffTypePack); + + List<IRType*> args; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto type = (IRType*)typePack->getOperand(i); + auto diffType = (IRType*)typePack->getOperand(i); + + if (pairTypeCache.tryGetValue(type, result)) + { + args.add((IRType*)result); + continue; + } + + // Lower the diff pair type. + auto loweredPairType = (IRType*)_createDiffPairType(type, diffType); + + pairTypeCache.add(type, loweredPairType); + args.add(loweredPairType); + } + + auto loweredTypePack = builder->getTypePack(args.getCount(), args.getBuffer()); + // TODO: Unify the cache between the three cases. + pairTypeCache.add(cacheKey, loweredTypePack); + + return loweredTypePack; } + else + { + auto cacheKey = primalType; + if (pairTypeCache.tryGetValue(primalType, result)) + return result; - auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); - if (!diffType) - return result; - result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - pairTypeCache.add(primalType, result); + if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType)) + { + result = nullptr; + return result; + } + + if (as<IRThisType>(primalType) || as<IRAssociatedType>(primalType)) + { + List<IRInterfaceType*> constraintTypes; + constraintTypes.add(this->commonDiffPairInterface); + return builder->getAssociatedType(constraintTypes.getArrayView()); + } + + auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType); + if (!diffType) + return result; + + // Concrete case. + result = _createDiffPairType(primalType, (IRType*)diffType); + pairTypeCache.add(cacheKey, result); - return result; + return result; + } } IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) @@ -550,6 +1143,13 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit auto innerWitnessTableType = cast<IRWitnessTableType>(operand); return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); } + else if (auto genericWitness = as<IRGeneric>(witness)) + { + // This is a generic witness table. + auto innerWitness = getGenericReturnVal(genericWitness); + SLANG_ASSERT(as<IRWitnessTableType>(innerWitness->getDataType())); + return getConformanceTypeFromWitness(innerWitness); + } else { SLANG_UNEXPECTED("Unexpected witness type"); @@ -558,81 +1158,134 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit return diffInterfaceType; } +List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations( + IRGlobalValueWithCode* code) +{ + // Scan function for all IRDifferentiableTypeAnnotation insts. + List<IRDifferentiableTypeAnnotation*> annotations; + for (auto block : code->getBlocks()) + { + for (auto child : block->getChildren()) + { + if (auto annotation = as<IRDifferentiableTypeAnnotation>(child)) + { + annotations.add(annotation); + } + } + } + + return annotations; +} + +List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations( + IRModuleInst* module) +{ + // Scan module for all IRDifferentiableTypeAnnotation insts. + List<IRDifferentiableTypeAnnotation*> annotations; + for (auto globalInst : module->getGlobalInsts()) + { + if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst)) + { + annotations.add(annotation); + } + } + + return annotations; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; + List<IRDifferentiableTypeAnnotation*> annotations = getAnnotations(func); - auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) + // Go up the parents of func & add the annotations of any IRGeneric or IRModule parent: + IRInst* parent = func; + while (parent) { - if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) + if (auto upperFunc = as<IRGlobalValueWithCode>(parent)) { - IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + // TODO: Cache this. + auto parentAnnotations = getAnnotations(upperFunc); + annotations.addRange(parentAnnotations); + } + else if (auto module = as<IRModuleInst>(parent)) + { + // TODO: Cache this. + auto parentAnnotations = getAnnotations(module); + annotations.addRange(parentAnnotations); + } + parent = parent->getParent(); + } - SLANG_ASSERT( - diffInterfaceType == sharedContext->differentiableInterfaceType || - diffInterfaceType == sharedContext->differentiablePtrInterfaceType); + for (auto item : annotations) + { + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); - auto existingItem = - differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); - if (existingItem) - { - *existingItem = item->getWitness(); - } - else - { - auto witness = item->getWitness(); + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType || + diffInterfaceType == sharedContext->differentiablePtrInterfaceType); - // Also register the type's differential type with the same witness. - auto concreteType = item->getConcreteType(); - IRBuilder subBuilder(item->getConcreteType()); - if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType)) + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getBaseType()); + if (existingItem) + { + *existingItem = item->getWitness(); + } + else + { + auto witness = item->getWitness(); + + // Also register the type's differential type with the same witness. + auto concreteType = item->getBaseType(); + IRBuilder subBuilder(item->getBaseType()); + if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType)) + { + // For tuple types with concrete element types, + // register the differential type for each element, but don't register for the + // tuple/typepack itself. + if (auto witnessPack = as<IRMakeWitnessPack>(witness)) { - // For tuple types with concrete element types, - // register the differential type for each element, but don't register for the - // tuple/typepack itself. - if (auto witnessPack = as<IRMakeWitnessPack>(witness)) + + for (UInt i = 0; i < concreteType->getOperandCount(); i++) { + auto element = concreteType->getOperand(i); + auto elementWitness = witnessPack->getOperand(i); - for (UInt i = 0; i < concreteType->getOperandCount(); i++) - { - auto element = concreteType->getOperand(i); - auto elementWitness = witnessPack->getOperand(i); - - if (diffInterfaceType == sharedContext->differentiableInterfaceType) - addTypeToDictionary((IRType*)element, elementWitness); - else if ( - diffInterfaceType == sharedContext->differentiablePtrInterfaceType) - addTypeToDictionary((IRType*)element, elementWitness); - } - return; + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary((IRType*)element, elementWitness); + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) + addTypeToDictionary((IRType*)element, elementWitness); } + return; } + } - addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness()); - if (!as<IRInterfaceType>(item->getConcreteType())) - { - addTypeToDictionary( - (IRType*)_lookupWitness( - &subBuilder, - item->getWitness(), - sharedContext->differentialAssocTypeStructKey, - subBuilder.getTypeKind()), - item->getWitness()); - } + // TODO: Is this really needed? + if (!as<IRInterfaceType>(item->getBaseType()) && + !as<IRAssociatedType>(item->getBaseType())) + { + addTypeToDictionary( + (IRType*)_lookupWitness( + &subBuilder, + item->getWitness(), + sharedContext->differentialAssocTypeStructKey, + subBuilder.getTypeKind()), + item->getWitness()); + } - if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getConcreteType())) - { - // For differential pair types, register the differential type as well. - IRBuilder builder(diffPairType); - builder.setInsertAfter(diffPairType->getWitness()); + // TODO: Is this really needed? + if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getBaseType())) + { + // For differential pair types, register the differential type as well. + IRBuilder builder(diffPairType); + builder.setInsertAfter(diffPairType->getWitness()); - // TODO(sai): lot of this logic is duplicated. need to refactor. + // TODO(sai): lot of this logic is duplicated. need to refactor. + if (!as<IRInterfaceType>(diffPairType->getValueType()) && + !as<IRAssociatedType>(diffPairType->getValueType())) + { auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? _lookupWitness( @@ -665,12 +1318,28 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) } } +IRWitnessTable* findGlobalWitness(IRInterfaceType* interface, IRInst* type) +{ + for (auto use = type->firstUse; use; use = use->nextUse) + { + if (auto witnessTable = as<IRWitnessTable>(use->getUser())) + { + if (witnessTable->getConcreteType() == type && + witnessTable->getConformanceType() == interface) + return witnessTable; + } + } + + return nullptr; +} + IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType( IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) return nullptr; @@ -791,8 +1460,8 @@ IRInst* DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table -// of `supType`. +// Given an interface type, return the lookup path from a witness table of `type` to a witness +// table of `supType`. static bool _findInterfaceLookupPathImpl( HashSet<IRInst*>& processedTypes, IRInterfaceType* supType, @@ -967,6 +1636,11 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); } + + if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst)) + { + addTypeToDictionary((IRType*)annotation->getBaseType(), annotation->getWitness()); + } } } @@ -1071,6 +1745,20 @@ IRType* DifferentiableTypeConformanceContext::differentiateType( } } +IRType* getAssociatedTypeForKey(IRInst* key) +{ + for (auto use = key->firstUse; use; use = use->nextUse) + { + if (auto interfaceReq = as<IRInterfaceRequirementEntry>(key)) + { + if (auto assocType = as<IRAssociatedType>(interfaceReq->getRequirementVal())) + return assocType; + } + } + + return nullptr; +} + IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness( IRBuilder* builder, IRInst* primalType, @@ -1118,8 +1806,9 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness( } else if (auto lookup = as<IRLookupWitnessMethod>(primalType)) { - // For types that are lookups from a table, we can simply lookup the witness from the same - // table + // Trivial cases: For types that are lookups from a table, we can simply lookup the + // witness from the same table + // if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { witness = builder->emitLookupInterfaceMethodInst( @@ -1203,8 +1892,8 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( auto p0 = b.emitParam(diffDiffPairType); auto p1 = b.emitParam(diffDiffPairType); - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value - // type == diff type. + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that + // value type == diff type. auto innerAdd = _lookupWitness( &b, innerWitness, @@ -1325,8 +2014,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( auto p0 = b.emitParam(diffArrayType); auto p1 = b.emitParam(diffArrayType); - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value - // type == diff type. + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that + // value type == diff type. auto innerAdd = _lookupWitness( &b, innerWitness, @@ -1566,6 +2255,143 @@ IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness return nullptr; } +IRInst* DifferentiableTypeConformanceContext::emitDAddOfDiffInstType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2) +{ + if (auto arrayType = as<IRArrayType>(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffElementType = + (IRType*)this->getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto arraySize = arrayType->getElementCount(); + + if (auto constArraySize = as<IRIntLit>(arraySize)) + { + List<IRInst*> args; + for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++) + { + auto index = builder->getIntValue(builder->getIntType(), i); + auto op1Val = builder->emitElementExtract(diffElementType, op1, index); + auto op2Val = builder->emitElementExtract(diffElementType, op2, index); + args.add( + emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val)); + } + auto diffArrayType = + builder->getArrayType(diffElementType, arrayType->getElementCount()); + return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer()); + } + else + { + // TODO: insert a runtime loop here. + SLANG_UNIMPLEMENTED_X("dadd of dynamic array."); + } + } + else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffType = (IRType*)this->getDiffTypeFromPairType(builder, diffPairUserType); + auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType); + + auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1); + auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2); + auto primal = + emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2); + + auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1); + auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2); + auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2); + + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff); + } + else if (as<IRInterfaceType>(primalType)) + { + // If our type is existential, we need to handle the case where + // one or both of our operands are null-type. + // + return emitDAddForExistentialType(builder, primalType, op1, op2); + } + else if (as<IRAssociatedType>(primalType)) + { + // Should not happen. associated type does not have any additional info, we can't + // lookup the necessary methods. + // + SLANG_UNEXPECTED("unexpected associated type during transposition"); + } + + auto addMethod = this->getAddMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(addMethod); + + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + addMethod, + List<IRInst*>(op1, op2)); +} + +IRInst* DifferentiableTypeConformanceContext::emitDAddForExistentialType( + IRBuilder* builder, + IRType* primalType, + IRInst* op1, + IRInst* op2) +{ + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + this->getOrCreateExistentialDAddMethod(), + List<IRInst*>({op1, op2})); +} + +IRInst* DifferentiableTypeConformanceContext::emitDZeroOfDiffInstType( + IRBuilder* builder, + IRType* primalType) +{ + if (auto arrayType = as<IRArrayType>(primalType)) + { + // TODO: This case should really not be necessary anymore + auto diffElementType = + (IRType*)this->getDifferentialForType(builder, arrayType->getElementType()); + SLANG_RELEASE_ASSERT(diffElementType); + auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount()); + auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType()); + return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero); + } + else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType)) + { + // TODO: This case should really not be necessary anymore. + auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType()); + auto diffZero = primalZero; + auto diffType = primalZero->getFullType(); + auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType); + auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); + return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero); + } + else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType)) + { + // Pack a null value into an existential type. + auto existentialZero = builder->emitMakeExistential( + this->sharedContext->differentiableInterfaceType, + this->emitNullDifferential(builder), + this->sharedContext->nullDifferentialWitness); + + return existentialZero; + } + + auto zeroMethod = this->getZeroMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(zeroMethod); + + return builder->emitCallInst( + (IRType*)this->getDifferentialForType(builder, primalType), + zeroMethod, + List<IRInst*>()); +} + void copyCheckpointHints( IRBuilder* builder, IRGlobalValueWithCode* oldInst, @@ -1883,6 +2709,7 @@ struct AutoDiffPass : public InstPassBase { bool result = false; OrderedHashSet<IRInst*> loweredIntermediateTypes; + Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap; // Replace all `BackwardDiffIntermediateContextType` insts with the struct type // that we generated during backward diff pass. @@ -1906,6 +2733,38 @@ struct AutoDiffPass : public InstPassBase if (type) { loweredIntermediateTypes.add(type); + + auto func = differentiateInst->getFunc(); + + if (auto spec = as<IRSpecialize>(func)) + func = spec->getBase(); + + if (auto generic = as<IRGeneric>(func)) + { + func = + cast<IRGlobalValueWithCode>(findGenericReturnVal(generic)); + + auto bwdFuncDecor = func->findDecoration< + IRBackwardDerivativePropagateDecoration>(); + + typeToBwdFuncMap.add( + type, + cast<IRGlobalValueWithCode>( + as<IRSpecialize>( + bwdFuncDecor->getBackwardDerivativePropagateFunc()) + ->getBase())); + } + else + { + auto bwdFuncDecor = func->findDecoration< + IRBackwardDerivativePropagateDecoration>(); + + typeToBwdFuncMap.add( + type, + cast<IRGlobalValueWithCode>( + bwdFuncDecor->getBackwardDerivativePropagateFunc())); + } + inst->replaceUsesWith(type); inst->removeAndDeallocate(); changed = true; @@ -1922,7 +2781,9 @@ struct AutoDiffPass : public InstPassBase } // Now we generate the differential type for the intermediate context type // to allow higher order differentiation. - generateDifferentialImplementationForContextType(loweredIntermediateTypes); + generateDifferentialImplementationForContextType( + loweredIntermediateTypes, + typeToBwdFuncMap); return result; } @@ -1977,22 +2838,13 @@ struct AutoDiffPass : public InstPassBase IRInst* addMethod = nullptr; }; - // Register the differential type for an intermediate context type to the derivative functions - // that uses the type. + // Register the differential type for an intermediate context type to the derivative + // functions that uses the type. void registerDiffContextType( IRBuilder& builder, - IRDifferentiableTypeDictionaryDecoration* diffDecor, OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, IRInst* origType) { - HashSet<IRInst*> registeredType; - for (auto entry : diffDecor->getChildren()) - { - if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry)) - { - registeredType.add(e->getOperand(0)); - } - } // Use a work list to recursively walk through all sub fields of the struct type. List<IRInst*> wlist; wlist.add(origType); @@ -2002,10 +2854,13 @@ struct AutoDiffPass : public InstPassBase IntermediateContextTypeDifferentialInfo diffInfo; if (!diffTypes.tryGetValue(t, diffInfo)) continue; - if (registeredType.add(t)) - builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness); - else - continue; + + IRInst* args[] = {t, diffInfo.diffWitness}; + builder.emitIntrinsicInst( + builder.getVoidType(), + kIROp_DifferentiableTypeAnnotation, + 2, + args); if (auto structType = as<IRStructType>(getResolvedInstForDecorations(t))) { @@ -2017,7 +2872,9 @@ struct AutoDiffPass : public InstPassBase } } - void generateDifferentialImplementationForContextType(OrderedHashSet<IRInst*>& contextTypes) + void generateDifferentialImplementationForContextType( + OrderedHashSet<IRInst*>& contextTypes, + Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap) { // First we are going to topology sort all intermediate context types. OrderedHashSet<IRInst*> sortedContextTypes; @@ -2043,6 +2900,10 @@ struct AutoDiffPass : public InstPassBase IRBuilder builder(module); for (auto t : sortedContextTypes) { + auto func = typeToBwdFuncMap[t]; + DifferentiableTypeConformanceContext ctx(this->autodiffContext); + ctx.setFunc(func); + if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType) { // For generics/struct types, we will generate a new generic/struct type @@ -2050,7 +2911,7 @@ struct AutoDiffPass : public InstPassBase SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module); builder.setInsertBefore(t); - auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t); + auto diffInfo = fillDifferentialTypeImplementation(&ctx, diffTypes, t); diffTypes[t] = diffInfo; } else if (auto specialize = as<IRSpecialize>(t)) @@ -2085,30 +2946,29 @@ struct AutoDiffPass : public InstPassBase // function without a intermediate-type via an interface. SLANG_RELEASE_ASSERT(diffTypes.containsKey(t)); } - } - // Register the differential types into the conformance dictionaries of the functions that - // uses them. - for (auto t : diffTypes) - { + if (!diffTypes.containsKey(t)) + continue; + + // If we created a new differential type, we need to place into the contexts of all + // functions that use it. + // HashSet<IRFunc*> registeredFuncs; - for (auto use = t.key->firstUse; use; use = use->nextUse) + for (auto use = t->firstUse; use; use = use->nextUse) { auto parentFunc = getParentFunc(use->getUser()); if (!parentFunc) continue; if (!registeredFuncs.add(parentFunc)) continue; - if (auto dictDecor = - parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - registerDiffContextType(builder, dictDecor, diffTypes, t.key); - } + + registerDiffContextType(builder, diffTypes, t); } } } IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct( + DifferentiableTypeConformanceContext* ctx, OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, IRStructType* originalType, IRStructType* diffType) @@ -2122,6 +2982,7 @@ struct AutoDiffPass : public InstPassBase // Generate the fields for all differentiable members of the original struct type. struct FieldInfo { + IRType* primalType; IRStructField* field; IRInst* witness; }; @@ -2130,30 +2991,30 @@ struct AutoDiffPass : public InstPassBase for (auto field : originalType->getFields()) { IRInst* diffFieldWitness = nullptr; - if (auto diffDecor = - field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>()) - { - diffFieldWitness = diffDecor->getDifferentialWitness(); - } - else + + diffFieldWitness = ctx->tryGetDifferentiableWitness( + &builder, + field->getFieldType(), + DiffConformanceKind::Value); + + if (!diffFieldWitness) { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; diffTypes.tryGetValue(field->getFieldType(), diffFieldTypeInfo); diffFieldWitness = diffFieldTypeInfo.diffWitness; } + if (diffFieldWitness) { FieldInfo info; IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness( - &keyBuilder, - diffFieldWitness, - autodiffContext->differentialAssocTypeStructKey, - builder.getTypeKind()); + auto diffFieldType = ctx->getDifferentialForType(&builder, field->getFieldType()); + info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; + info.primalType = field->getFieldType(); builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey); diffFields.add(info); @@ -2172,16 +3033,10 @@ struct AutoDiffPass : public InstPassBase builder.setInsertInto(zeroMethod); builder.emitBlock(); List<IRInst*> fieldVals; + for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness( - &builder, - info.witness, - autodiffContext->zeroMethodStructKey, - autodiffContext->zeroMethodType); - IRInst* val = - builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); - fieldVals.add(val); + fieldVals.add(ctx->emitDZeroOfDiffInstType(&builder, info.primalType)); } builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } @@ -2203,20 +3058,15 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness( - &builder, - info.witness, - autodiffContext->addMethodStructKey, - autodiffContext->addMethodType); IRInst* args[2] = { builder .emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder .emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), }; - IRInst* val = - builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args); - fieldVals.add(val); + + fieldVals.add( + ctx->emitDAddOfDiffInstType(&builder, info.primalType, args[0], args[1])); } builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals)); } @@ -2265,6 +3115,7 @@ struct AutoDiffPass : public InstPassBase } IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation( + DifferentiableTypeConformanceContext* ctx, OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, IRInst* originalType) { @@ -2274,6 +3125,7 @@ struct AutoDiffPass : public InstPassBase builder.setInsertBefore(originalType); auto diffType = builder.createStructType(); return fillDifferentialTypeImplementationForStruct( + ctx, diffTypes, as<IRStructType>(originalType), as<IRStructType>(diffType)); @@ -2286,7 +3138,7 @@ struct AutoDiffPass : public InstPassBase auto structType = as<IRStructType>(findGenericReturnVal(genType)); SLANG_RELEASE_ASSERT(structType); - auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType); + auto innerResult = fillDifferentialTypeImplementation(ctx, diffTypes, structType); IRBuilder builder(originalType); builder.setInsertBefore(originalType); @@ -2421,7 +3273,8 @@ struct AutoDiffPass : public InstPassBase { bool changed = false; List<IRInst*> autoDiffWorkList; - // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph. + // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call + // graph. processAllReachableInsts( [&](IRInst* inst) { @@ -2438,6 +3291,7 @@ struct AutoDiffPass : public InstPassBase case kIROp_Func: case kIROp_Specialize: case kIROp_LookupWitness: + case kIROp_Generic: if (auto innerFunc = as<IRFunc>(getResolvedInstForDecorations(inst->getOperand(0)))) { @@ -2519,8 +3373,8 @@ struct AutoDiffPass : public InstPassBase } // Run transcription logic to generate the body of forward/backward derivatives - // functions. While doing so, we may discover new functions to differentiate, so we keep - // running until the worklist goes dry. + // functions. While doing so, we may discover new functions to differentiate, so we + // keep running until the worklist goes dry. List<IRFunc*> autodiffCleanupList; while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0) { @@ -2582,10 +3436,10 @@ struct AutoDiffPass : public InstPassBase hasChanges = true; // We have done transcribing the functions, now it is time to demote all - // DifferentialPair types and their operations down to DifferentialPairUserCodeType and - // *UserCode operations so they can be treated just like normal types with no special - // semantics in future processing, and won't be confused with the semantics of a - // DifferentialPair type during future autodiff code gen. + // DifferentialPair types and their operations down to DifferentialPairUserCodeType + // and *UserCode operations so they can be treated just like normal types with no + // special semantics in future processing, and won't be confused with the semantics + // of a DifferentialPair type during future autodiff code gen. rewriteDifferentialPairToUserCode(module); hasChanges |= changed; @@ -2693,8 +3547,8 @@ void checkAutodiffPatterns(TargetProgram* target, IRModule* module, DiagnosticSi if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions func->findDecoration<IRPreferRecomputeDecoration>()) { - // If we don't have any side-effect behavior, we should warn (note: read-none is a - // stronger guarantee than no-side-effect) + // If we don't have any side-effect behavior, we should warn (note: read-none is + // a stronger guarantee than no-side-effect) // if (func->findDecoration<IRNoSideEffectDecoration>() || func->findDecoration<IRReadNoneDecoration>()) @@ -2759,6 +3613,27 @@ void removeDetachInsts(IRModule* module) pass.processModule(); } + +struct RemoveTypeAnnotationInstsPass : InstPassBase +{ + RemoveTypeAnnotationInstsPass(IRModule* module) + : InstPassBase(module) + { + } + void processModule() + { + processInstsOfType<IRDifferentiableTypeAnnotation>( + kIROp_DifferentiableTypeAnnotation, + [&](IRDifferentiableTypeAnnotation* annotation) { annotation->removeAndDeallocate(); }); + } +}; + +void removeTypeAnnotations(IRModule* module) +{ + RemoveTypeAnnotationInstsPass pass(module); + pass.processModule(); +} + struct LowerNullCheckPass : InstPassBase { LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context) @@ -2841,6 +3716,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module) removeDetachInsts(module); + removeTypeAnnotations(module); + lowerNullCheckInsts(module, &autodiffContext); stripNoDiffTypeAttribute(module); |
