diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 793 |
1 files changed, 527 insertions, 266 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 07a6a76fb..94a605a68 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func) return false; } -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr) { if (auto witnessTable = as<IRWitnessTable>(witness)) { @@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } else { + SLANG_ASSERT(resultType); return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), + resultType, witness, requirementKey); } return nullptr; } -static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witness = type->getWitness(); SLANG_RELEASE_ASSERT(witness); @@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType())) { // The differential type is the IDifferentiable interface type. - return sharedContext->differentiableInterfaceType; + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return sharedContext->differentiableInterfaceType; + else if (as<IRDifferentialPtrPairType>(type)) + return sharedContext->differentiablePtrInterfaceType; + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } - return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocTypeStructKey, + builder->getTypeKind()); + else if (as<IRDifferentialPtrPairType>(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocRefTypeStructKey, + builder->getTypeKind()); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType); + else if (as<IRDifferentialPtrPairType>(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } bool isNoDiffType(IRType* paramType) @@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } +IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) +{ + for (auto inst : moduleInst->getGlobalInsts()) + { + if (auto interfaceType = as<IRInterfaceType>(inst)) + { + if (auto decor = interfaceType->findDecoration<IRNameHintDecoration>()) + { + if (decor->getName() == "IDifferentiablePtrType") + { + return interfaceType; + } + } + } + } + return nullptr; +} + AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst) : moduleInst(inModuleInst), targetProgram(target) { @@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType(); zeroMethodStructKey = findZeroMethodStructKey(); + zeroMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal()); addMethodStructKey = findAddMethodStructKey(); + addMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal()); mulMethodStructKey = findMulMethodStructKey(); nullDifferentialStructType = findNullDifferentialStructType(); nullDifferentialWitness = findNullDifferentialWitness(); - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; + isInterfaceAvailable = true; + } + + differentiablePtrInterfaceType = as<IRInterfaceType>(findDifferentiableRefInterface(inModuleInst)); + + if (differentiablePtrInterfaceType) + { + differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey(); + differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey(); + differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType(); + + isPtrInterfaceAvailable = true; } } @@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness() } -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index) { - if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) + if (as<IRModuleInst>(moduleInst) && interface) { // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) - return as<IRStructKey>(entry->getRequirementKey()); + // SLANG_ASSERT(interface->getOperandCount() == 5); + if (auto entry = as<IRInterfaceRequirementEntry>(interface->getOperand(index))) + return entry; else { SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); @@ -421,6 +485,50 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde return nullptr; } +// Extracts conformance interface from a witness inst while accounting for some +// quirks in the type system around interfaces that conform to other interfaces. +// +IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness) +{ + IRInterfaceType* diffInterfaceType = nullptr; + if (auto witnessTableType = as<IRWitnessTableType>(witness->getDataType())) + { + diffInterfaceType = cast<IRInterfaceType>(witnessTableType->getConformanceType()); + } + else if (auto structKey = as<IRStructKey>(witness)) + { + // We currently assume that a struct key is used uniquely for a single interface-requirement-entry. + // Find that entry + for (IRUse* use = structKey->firstUse; use; use = use->nextUse) + { + if (auto entry = as<IRInterfaceRequirementEntry>(use->getUser())) + { + auto innerWitnessTableType = cast<IRWitnessTableType>(entry->getRequirementVal()); + diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + break; + } + } + } + else if (auto interfaceRequirementEntry = as<IRInterfaceRequirementEntry>(witness)) + { + auto innerWitnessTableType = cast<IRWitnessTableType>(interfaceRequirementEntry->getRequirementVal()); + diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + } + else if (auto tupleType = as<IRTupleType>(witness->getDataType())) + { + SLANG_ASSERT(tupleType->getOperandCount() >= 1); + auto operand = tupleType->getOperand(0); + auto innerWitnessTableType = cast<IRWitnessTableType>(operand); + return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } + + return diffInterfaceType; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; @@ -434,7 +542,13 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) { - auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType()); + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType + || diffInterfaceType == sharedContext->differentiablePtrInterfaceType); + + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); if (existingItem) { *existingItem = item->getWitness(); @@ -458,20 +572,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { auto element = concreteType->getOperand(i); auto elementWitness = witnessPack->getOperand(i); - differentiableWitnessDictionary.addIfNotExists( - (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); } return; } } - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); if (!as<IRInterfaceType>(item->getConcreteType())) { - differentiableWitnessDictionary.addIfNotExists( - (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey), + addTypeToDictionary( + (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()), item->getWitness()); } @@ -480,29 +600,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) // For differential pair types, register the differential type as well. IRBuilder builder(diffPairType); builder.setInsertAfter(diffPairType->getWitness()); - auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey); - auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey); - if (diffType && diffWitness) - { - differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness); - } + + // TODO(sai): lot of this logic is duplicated. need to refactor. + auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) : + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind()); + auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType) : + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + + addTypeToDictionary((IRType*)diffType, diffWitness); } } } } } -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; - differentiableWitnessDictionary.tryGetValue(type, foundResult); - return foundResult; + differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) + return nullptr; + + if (kind == DiffConformanceKind::Any) + return foundResult; + + if (auto baseType = getConformanceTypeFromWitness(foundResult)) + { + if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value) + return foundResult; + else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr) + return foundResult; + } + + return nullptr; } -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType)) - return _lookupWitness(builder, conformance, key); + if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -514,7 +660,7 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { - return _getDiffTypeFromPairType(sharedContext, builder, type); + return this->differentiateType(builder, type->getValueType()); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -525,20 +671,34 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); } IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType); +} + +void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness) +{ + auto conformanceType = getConformanceTypeFromWitness(witness); + + if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable) + return; + + SLANG_ASSERT( + conformanceType == sharedContext->differentiableInterfaceType || + conformanceType == sharedContext->differentiablePtrInterfaceType); + + differentiableTypeWitnessDictionary.addIfNotExists(type, witness); } IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) { SLANG_RELEASE_ASSERT(interfaceType); - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + List<IRInterfaceRequirementEntry*> lookupKeyPath = findInterfaceLookupPath( sharedContext->differentiableInterfaceType, interfaceType); IRInst* differentialTypeWitness = witnessTable; @@ -549,6 +709,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface { differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); // Lookup insts are always primal values. + builder->markInstAsPrimal(differentialTypeWitness); } return differentialTypeWitness; @@ -557,10 +718,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -static bool _findDifferentiableInterfaceLookupPathImpl( +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`. +static bool _findInterfaceLookupPathImpl( HashSet<IRInst*>& processedTypes, - IRInterfaceType* idiffType, + IRInterfaceType* supType, IRInterfaceType* type, List<IRInterfaceRequirementEntry*>& currentPath) { @@ -576,13 +737,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl( if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) { currentPath.add(entry); - if (wt->getConformanceType() == idiffType) + if (wt->getConformanceType() == supType) { return true; } else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath)) return true; } currentPath.removeLast(); @@ -591,11 +752,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl( return false; } -List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type) +List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type) { List<IRInterfaceRequirementEntry*> currentPath; HashSet<IRInst*> processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath); return currentPath; } @@ -722,7 +883,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst)) { - differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); + addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); } } } @@ -762,9 +923,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build case kIROp_DifferentialPairType: { auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - builder, - getDiffTypeFromPairType(builder, primalPairType), + return builder->getDifferentialPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), getDiffTypeWitnessFromPairType(builder, primalPairType)); } @@ -776,6 +936,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build getDiffTypeWitnessFromPairType(builder, primalPairType)); } + case kIROp_DifferentialPtrPairType: + { + auto primalPairType = as<IRDifferentialPtrPairType>(primalType); + return builder->getDifferentialPtrPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + case kIROp_FuncType: { SLANG_UNIMPLEMENTED_X("Impl"); @@ -817,12 +985,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build } } -IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) +IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind) { if (isNoDiffType((IRType*)primalType)) return nullptr; - - IRInst* witness = lookUpConformanceForType((IRType*)primalType); + + IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind); if (witness) { SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType)); @@ -834,31 +1002,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil witness = nullptr; } - if (!witness) + if (witness) + return witness; + + // If a witness is not already mapped, build one if possible. + SLANG_RELEASE_ASSERT(primalType); + if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) { - SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) - { - witness = getOrCreateDifferentiablePairWitness(builder, primalPairType); - } - else if (auto arrayType = as<IRArrayType>(primalType)) - { - witness = getArrayWitness(builder, arrayType); - } - else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) - { - witness = getExtractExistensialTypeWitness(builder, extractExistential); - } - else if (auto typePack = as<IRTypePack>(primalType)) + witness = buildDifferentiablePairWitness(builder, primalPairType, kind); + } + else if (auto arrayType = as<IRArrayType>(primalType)) + { + witness = buildArrayWitness(builder, arrayType, kind); + } + else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + { + witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind); + } + else if (auto typePack = as<IRTypePack>(primalType)) + { + witness = buildTupleWitness(builder, typePack, kind); + } + else if (auto tupleType = as<IRTupleType>(primalType)) + { + witness = buildTupleWitness(builder, tupleType, kind); + } + else if (auto lookup = as<IRLookupWitnessMethod>(primalType)) + { + // For types that are lookups from a table, we can simply lookup the witness from the same table + if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { - witness = getTupleWitness(builder, typePack); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocTypeWitnessStructKey); } - else if (auto tupleType = as<IRTupleType>(primalType)) + + if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey) { - witness = getTupleWitness(builder, tupleType); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocRefTypeWitnessStructKey); } } - return witness; + + // If we created a witness, register it. + if (witness) + { + addTypeToDictionary((IRType*)primalType, witness); + return witness; + } + + // Failed. Type is either non-differentiable, or unhandled. + return nullptr; } IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) @@ -868,77 +1065,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* witness); } -IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType) +IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target) { - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; - - // Fill in differential method implementations. - auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType(); - auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness(); - - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); - b.emitBlock(); - auto p0 = b.emitParam(diffDiffPairType); - auto p1 = b.emitParam(diffDiffPairType); - - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - IRInst* argsPrimal[2] = { - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; - auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); - IRInst* argsDiff[2] = { - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; - auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) - : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); - b.emitReturn(retVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); - b.emitBlock(); - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) - : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); - b.emitReturn(retVal); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; + + // Fill in differential method implementations. + auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType(); + auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness(); + + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + table = builder->createWitnessTable( + sharedContext->differentiablePtrInterfaceType, + (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } - - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)pairType] = table; return table; } -IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType) +IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( + IRBuilder* builder, + IRArrayType* arrayType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType); @@ -946,70 +1163,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder if (!diffArrayType) return nullptr; - auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType()); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType(), DiffConformanceKind::Value); - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); + auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); - // Fill in differential method implementations. + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffArrayType, diffArrayType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); - b.emitBlock(); - auto p0 = b.emitParam(diffArrayType); - auto p1 = b.emitParam(diffArrayType); + SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto resultVar = b.emitVar(diffArrayType); - IRBlock* loopBodyBlock = nullptr; - IRBlock* loopBreakBlock = nullptr; - auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); - b.setInsertBefore(loopBodyBlock->getTerminator()); + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType); - IRInst* args[2] = { - b.emitElementExtract(p0, loopCounter), - b.emitElementExtract(p1, loopCounter) }; - auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); - auto addr = b.emitElementAddress(resultVar, loopCounter); - b.emitStore(addr, elementResult); - b.setInsertInto(loopBreakBlock); - b.emitReturn(b.emitLoad(resultVar)); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } + else { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); - b.emitBlock(); - - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); - b.emitReturn(retVal); + SLANG_UNEXPECTED("Invalid conformance kind for synthesis"); } - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)arrayType] = table; - return table; } -IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( + IRBuilder* builder, + IRInst* inTupleType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); @@ -1017,100 +1253,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder if (!diffTupleType) return nullptr; - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - // Fill in differential method implementations. - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffTupleType, diffTupleType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); - b.emitBlock(); - auto p0 = b.emitParam(diffTupleType); - auto p1 = b.emitParam(diffTupleType); - List<IRInst*> results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto iVal = b.getIntValue(b.getIntType(), i); - IRInst* args[2] = { - b.emitGetTupleElement(diffElementType, p0, iVal), - b.emitGetTupleElement(diffElementType, p1, iVal) }; - elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); - b.emitBlock(); - List<IRInst*> results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); } + else if (target == DiffConformanceKind::Ptr) + { + SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)inTupleType] = table; + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); + } return table; } -IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( +IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness( IRBuilder* builder, - IRExtractExistentialType* extractExistentialType) + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target) { + SLANG_UNUSED(target); // logic is the same for both value and ptr + // Check that the type's base is differentiable if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType())) { @@ -1310,12 +1562,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* if (context.isDifferentiableType((IRType*)typeInst)) return true; + // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) + for (auto type : context.differentiableTypeWitnessDictionary) { if (isTypeEqual(type.key, (IRType*)typeInst)) { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value; + context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value; return true; } } @@ -1672,7 +1925,7 @@ struct AutoDiffPass : public InstPassBase IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey); + auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind()); info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); @@ -1695,7 +1948,11 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> fieldVals; for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey); + auto innerZeroMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->zeroMethodStructKey, + autodiffContext->zeroMethodType); IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); fieldVals.add(val); } @@ -1719,7 +1976,11 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey); + auto innerAddMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->addMethodStructKey, + autodiffContext->addMethodType); IRInst* args[2] = { builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), |
