summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp1177
1 files changed, 1027 insertions, 150 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 4edd8eabe..7507e2fac 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -230,9 +230,6 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
}
else if (auto specializedType = as<IRSpecialize>(pairType))
{
- // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
- // type, emit the specialization type.
- //
auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
if (auto genericBasePairStructType = as<IRStructType>(genericType))
{
@@ -263,14 +260,142 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
return nullptr;
}
-IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+bool isExistentialOrRuntimeInst(IRInst* inst)
{
- return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ if (auto lookup = as<IRLookupWitnessMethod>(inst))
+ {
+ return isExistentialOrRuntimeInst(lookup->getWitnessTable());
+ }
+
+ return as<IRExtractExistentialType>(inst) || as<IRExtractExistentialWitnessTable>(inst) ||
+ as<IRMakeExistential>(inst) || as<IRInterfaceType>(inst->getDataType());
}
-IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+bool isRuntimeType(IRType* type)
{
- return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ if (as<IRExtractExistentialType>(type))
+ return true;
+
+ if (auto lookup = as<IRLookupWitnessMethod>(type))
+ {
+ return isExistentialOrRuntimeInst(lookup->getWitnessTable());
+ }
+
+ return false;
+}
+
+IRInst* getExistentialBaseWitnessTable(IRBuilder* builder, IRType* type)
+{
+ if (auto lookupWitnessMethod = as<IRLookupWitnessMethod>(type))
+ {
+ return lookupWitnessMethod->getWitnessTable();
+ }
+ else if (auto extractExistentialType = as<IRExtractExistentialType>(type))
+ {
+ return builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0));
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected existential type");
+ }
+}
+
+IRInst* getCacheKey(IRBuilder* builder, IRInst* primalType)
+{
+ if (auto lookupWitness = as<IRLookupWitnessMethod>(primalType))
+ return lookupWitness->getRequirementKey();
+ else if (auto extractExistentialType = as<IRExtractExistentialType>(primalType))
+ {
+ auto interfaceType = extractExistentialType->getOperand(0)->getDataType();
+
+ // We will cache on the interface's this-type, since the interface type itself can be
+ // deallocated during the lowering process.
+ //
+ return builder->getThisType(interfaceType);
+ }
+
+ return primalType;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitExistentialMakePair(
+ IRBuilder* builder,
+ IRInst* pairType,
+ IRInst* primalInst,
+ IRInst* diffInst)
+{
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)pairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(pairType)->getRequirementKey();
+ auto makePairKey = makePairKeyMap[pairTypeKey];
+
+ auto makePairMethod = builder->emitLookupInterfaceMethodInst(
+ makePairFuncTypeMap[makePairKey],
+ baseWitness,
+ makePairKey);
+
+ List<IRInst*> args;
+ args.add(primalInst);
+ args.add(diffInst);
+
+ auto makePairVal = builder->emitCallInst((IRType*)pairType, makePairMethod, args);
+
+ return makePairVal;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(
+ IRBuilder* builder,
+ IRType* loweredPairType,
+ IRInst* baseInst)
+{
+ if (isRuntimeType(loweredPairType))
+ {
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey();
+ auto getPrimalKey = getPrimalKeyMap[pairTypeKey];
+
+ auto primalFieldMethod = builder->emitLookupInterfaceMethodInst(
+ getPrimalFuncTypeMap[getPrimalKey],
+ baseWitness,
+ getPrimalKey);
+
+ auto primalFieldVal =
+ builder->emitCallInst(primalTypeMap[loweredPairType], primalFieldMethod, baseInst);
+
+ return primalFieldVal;
+ }
+ else
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ }
+}
+
+IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(
+ IRBuilder* builder,
+ IRType* loweredPairType,
+ IRInst* baseInst)
+{
+ if (isRuntimeType(loweredPairType))
+ {
+ auto baseWitness = getExistentialBaseWitnessTable(builder, (IRType*)loweredPairType);
+
+ auto pairTypeKey = cast<IRLookupWitnessMethod>(loweredPairType)->getRequirementKey();
+ auto getDiffKey = getDiffKeyMap[pairTypeKey];
+
+ auto diffFieldMethod = builder->emitLookupInterfaceMethodInst(
+ getDiffFuncTypeMap[getDiffKey],
+ baseWitness,
+ getDiffKey);
+
+ auto diffFieldVal =
+ builder->emitCallInst(diffTypeMap[loweredPairType], diffFieldMethod, baseInst);
+
+ return diffFieldVal;
+ }
+ else
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ }
}
IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey()
@@ -307,6 +432,380 @@ IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey()
return this->globalPrimalKey;
}
+IRInst* DifferentialPairTypeBuilder::getOrCreateCommonDiffPairInterface(IRBuilder* builder)
+{
+ if (!this->commonDiffPairInterface)
+ {
+ this->commonDiffPairInterface = builder->createInterfaceType(0, nullptr);
+ builder->addNameHintDecoration(
+ this->commonDiffPairInterface,
+ UnownedStringSlice("IDiffPair"));
+ }
+
+ return this->commonDiffPairInterface;
+}
+
+IRInst* DifferentialPairTypeBuilder::_createDiffPairInterfaceRequirement(
+ IRType* origBaseType,
+ IRType*)
+{
+ // We will create an interface requirement for the type's pair & then create implementations in
+ // all the implementing witness tables.
+ //
+
+ IRBuilder builder(sharedContext->moduleInst);
+
+ // Find the right interface to put the requirement in.
+ IRInterfaceType* interfaceType = nullptr;
+
+ // Find the effective type to put in the requirement entry
+ // for the base type
+ //
+ IRType* requirementBaseType = nullptr;
+
+ // Requirement key (only used for associated types)
+ //
+ IRInst* requirementKey = nullptr;
+
+ // Add a name hint to the key.
+ StringBuilder nameBuilderReqKey;
+ nameBuilderReqKey << "DiffPair_Req_";
+
+ if (auto lookup = as<IRLookupWitnessMethod>(origBaseType))
+ {
+ interfaceType =
+ cast<IRInterfaceType>(cast<IRWitnessTableType>(lookup->getWitnessTable()->getDataType())
+ ->getConformanceType());
+
+ requirementBaseType =
+ cast<IRType>(findInterfaceRequirement(interfaceType, lookup->getRequirementKey()));
+
+ requirementKey = lookup->getRequirementKey();
+
+ if (auto nameHint = lookup->getRequirementKey()->findDecoration<IRNameHintDecoration>())
+ {
+ nameBuilderReqKey << nameHint->getName();
+ }
+ else
+ {
+ nameBuilderReqKey << "unknown_assoc_type";
+ }
+ }
+ else if (auto extractType = as<IRExtractExistentialType>(origBaseType))
+ {
+ auto existentialType = extractType->getOperand(0);
+ interfaceType = cast<IRInterfaceType>(existentialType->getDataType());
+ requirementBaseType = builder.getThisType(interfaceType);
+
+ requirementKey = nullptr;
+
+ if (auto nameHint = interfaceType->findDecoration<IRNameHintDecoration>())
+ {
+ nameBuilderReqKey << nameHint->getName();
+ }
+ else
+ {
+ nameBuilderReqKey << "unknown_interface_type";
+ }
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected type for differential pair interface requirement");
+ }
+
+ auto diffPairInterfaceType =
+ cast<IRInterfaceType>(getOrCreateCommonDiffPairInterface(&builder));
+
+ // Add 4 requirements to the interface:
+ // the associated pair type, getPrimal, getDiff & makePair
+ //
+ builder.setInsertInto(interfaceType);
+ IRStructKey* diffPairRequirementKey = builder.createStructKey();
+ IRStructKey* getPrimalRequirementKey = builder.createStructKey();
+ IRStructKey* getDiffRequirementKey = builder.createStructKey();
+ IRStructKey* makePairRequirementKey = builder.createStructKey();
+
+ makePairKeyMap[diffPairRequirementKey] = makePairRequirementKey;
+ getPrimalKeyMap[diffPairRequirementKey] = getPrimalRequirementKey;
+ getDiffKeyMap[diffPairRequirementKey] = getDiffRequirementKey;
+
+ List<IRInst*> entries;
+
+ // Add all the old requirements to the new interface.
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ entries.add(interfaceType->getOperand(i));
+
+ //
+ // Create the new requirement entries.
+ //
+
+ {
+ // Create & insert the requirement key.
+ List<IRInterfaceType*> constraintTypes;
+ constraintTypes.add(diffPairInterfaceType);
+ auto entry = builder.createInterfaceRequirementEntry(
+ diffPairRequirementKey,
+ builder.getAssociatedType(constraintTypes.getArrayView()));
+
+ builder.addNameHintDecoration(diffPairRequirementKey, nameBuilderReqKey.getUnownedSlice());
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the getPrimal requirement.
+
+ List<IRType*> paramTypes;
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(diffPairInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ auto entryFuncType = builder.getFuncType(paramTypes, requirementBaseType);
+ auto entry =
+ builder.createInterfaceRequirementEntry(getPrimalRequirementKey, entryFuncType);
+
+ getPrimalFuncTypeMap[getPrimalRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getPrimal";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the getDiff requirement.
+
+ List<IRType*> paramTypes;
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(diffPairInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ List<IRInterfaceType*> resultConstraintTypes;
+ resultConstraintTypes.add(sharedContext->differentiableInterfaceType);
+ auto resultType = builder.getAssociatedType(resultConstraintTypes.getArrayView());
+
+ auto entryFuncType = builder.getFuncType(paramTypes, resultType);
+ auto entry = builder.createInterfaceRequirementEntry(getDiffRequirementKey, entryFuncType);
+
+ getDiffFuncTypeMap[getDiffRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_getDiff";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create & insert the makePair requirement.
+
+ List<IRType*> paramTypes;
+ paramTypes.add(requirementBaseType);
+
+ List<IRInterfaceType*> paramConstraintTypes;
+ paramConstraintTypes.add(sharedContext->differentiableInterfaceType);
+ paramTypes.add(builder.getAssociatedType(paramConstraintTypes.getArrayView()));
+
+ List<IRInterfaceType*> resultConstraintTypes;
+ resultConstraintTypes.add(diffPairInterfaceType);
+ auto entryFuncType = builder.getFuncType(
+ paramTypes,
+ builder.getAssociatedType(resultConstraintTypes.getArrayView()));
+ auto entry = builder.createInterfaceRequirementEntry(makePairRequirementKey, entryFuncType);
+
+ makePairFuncTypeMap[makePairRequirementKey] = entryFuncType;
+
+ StringBuilder entryNameBuilder;
+ entryNameBuilder << nameBuilderReqKey.getUnownedSlice() << "_makePair";
+ builder.addNameHintDecoration(entry, entryNameBuilder.getUnownedSlice());
+
+ entries.add(entry);
+ }
+
+ {
+ // Create the new interface type.
+
+ auto newInterfaceType =
+ builder.createInterfaceType(entries.getCount(), entries.getBuffer());
+
+ // Transfer decorations from the old interface to the new one.
+ interfaceType->transferDecorationsTo(newInterfaceType);
+ interfaceType->replaceUsesWith(newInterfaceType);
+
+ // Replace the interface maps in the caches.
+ if (this->pairTypeCache.containsKey(interfaceType))
+ this->pairTypeCache[newInterfaceType] = this->pairTypeCache[interfaceType];
+
+ if (this->existentialPairTypeCache.containsKey(interfaceType))
+ this->existentialPairTypeCache[newInterfaceType] =
+ this->existentialPairTypeCache[interfaceType];
+
+ interfaceType->removeAndDeallocate();
+ interfaceType = newInterfaceType;
+ }
+
+ //
+ // Implement the requirements in all the witness tables.
+ //
+
+ // Collect all witness tables of the given interfaceType.
+ List<IRWitnessTable*> concreteWitnessTables;
+ auto witnessTableType = builder.getWitnessTableType(interfaceType);
+ for (auto use = witnessTableType->firstUse; use; use = use->nextUse)
+ {
+ if (auto witnessTable = as<IRWitnessTable>(use->getUser()))
+ {
+ if (use->getUser()->getFullType() == witnessTableType)
+ concreteWitnessTables.add(witnessTable);
+ }
+ }
+
+ DifferentiableTypeConformanceContext ctx(sharedContext);
+ ctx.buildGlobalWitnessDictionary();
+
+ for (auto concreteWitnessTable : concreteWitnessTables)
+ {
+ IRType* concretePrimalType = nullptr;
+
+ // What requirement are we trying to satisfy?
+ if (as<IRThisType>(requirementBaseType))
+ {
+ // For this types, we should lower the concrete type of the witness table itself.
+ concretePrimalType = concreteWitnessTable->getConcreteType();
+ }
+ else if (as<IRAssociatedType>(requirementBaseType))
+ {
+ // For associated types, look it up in the witness table.
+ concretePrimalType =
+ (IRType*)findWitnessTableEntry(concreteWitnessTable, requirementKey);
+ }
+ else
+ {
+ // We shouldn't see any other case here.
+ SLANG_UNEXPECTED("Unexpected requirement base type");
+ }
+
+ // Create the pair type.
+ auto witness = ctx.tryGetDifferentiableWitness(
+ &builder,
+ concretePrimalType,
+ DiffConformanceKind::Value);
+
+ // Really should not see a case where the original interface is differentiable, but
+ // we can't find the witness table.
+ //
+ SLANG_ASSERT(witness);
+
+ auto concretePairType = builder.getDifferentialPairType(
+ concretePrimalType,
+ witness); // TODO: Need to handle the other conformance kinds
+ auto concreteDiffType =
+ (IRType*)_getDiffTypeFromPairType(sharedContext, &builder, concretePairType);
+
+ auto loweredStructType = (IRType*)lowerDiffPairType(&builder, concretePairType);
+
+ // Create an (empty) witness table for loweredStuctType : IDiffPair_...
+ // This is just so that there is a bound on the any-value-size for each group of pair types.
+ //
+ auto witnessTable = builder.createWitnessTable(diffPairInterfaceType, loweredStructType);
+ builder.addKeepAliveDecoration(witnessTable);
+
+ builder.setInsertInto(concreteWitnessTable);
+
+ // Create the associated type entry.
+ {
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ diffPairRequirementKey,
+ loweredStructType);
+ }
+
+ // Create the getPrimal method.
+ {
+ auto primalMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_getPrimal";
+ builder.addNameHintDecoration(primalMethod, nameBuilder.getUnownedSlice());
+
+ primalMethod->setFullType(builder.getFuncType(
+ List<IRType*>({(IRType*)loweredStructType}),
+ concretePrimalType));
+
+ builder.setInsertInto(primalMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param = builder.emitParam((IRType*)loweredStructType);
+ builder.emitReturn(
+ builder.emitFieldExtract(concretePrimalType, param, _getOrCreatePrimalStructKey()));
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ getPrimalRequirementKey,
+ primalMethod);
+ }
+
+ // Create the getDiff method.
+ {
+ auto diffMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_getDiff";
+ builder.addNameHintDecoration(diffMethod, nameBuilder.getUnownedSlice());
+
+ diffMethod->setFullType(
+ builder.getFuncType(List<IRType*>({(IRType*)loweredStructType}), concreteDiffType));
+
+ builder.setInsertInto(diffMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param = builder.emitParam((IRType*)loweredStructType);
+ builder.emitReturn(
+ builder.emitFieldExtract(concreteDiffType, param, _getOrCreateDiffStructKey()));
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ getDiffRequirementKey,
+ diffMethod);
+ }
+
+ // Create the makePair method.
+ {
+ auto makePairMethod = builder.createFunc();
+
+ StringBuilder nameBuilder;
+ getTypeNameHint(nameBuilder, loweredStructType);
+ nameBuilder << "_makePair";
+ builder.addNameHintDecoration(makePairMethod, nameBuilder.getUnownedSlice());
+
+ makePairMethod->setFullType(builder.getFuncType(
+ List<IRType*>({concretePrimalType, concreteDiffType}),
+ (IRType*)loweredStructType));
+
+ builder.setInsertInto(makePairMethod);
+ auto block = builder.emitBlock();
+ builder.setInsertInto(block);
+ auto param1 = builder.emitParam(concretePrimalType);
+ auto param2 = builder.emitParam(concreteDiffType);
+ List<IRInst*> args = {param1, param2};
+ auto pair = builder.emitMakeStruct((IRType*)loweredStructType, args);
+ builder.emitReturn(pair);
+
+ builder.setInsertInto(concreteWitnessTable);
+ builder.createWitnessTableEntry(
+ concreteWitnessTable,
+ makePairRequirementKey,
+ makePairMethod);
+ }
+ }
+
+ return diffPairRequirementKey;
+}
+
IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType)
{
switch (origBaseType->getOp())
@@ -333,6 +832,7 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I
return pairStructType;
}
+
IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
@@ -352,26 +852,119 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(IRBuilder* builder, IRTyp
// purposes.
auto primalType = pairType->getValueType();
- if (pairTypeCache.tryGetValue(primalType, result))
- return result;
- if (!pairType)
+
+ if (isRuntimeType(primalType))
{
- result = originalPairType;
+ // Existential case.
+ auto cacheKey = getCacheKey(builder, primalType);
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
+
+ IRInst* pairReqKey = nullptr;
+ if (!existentialPairTypeCache.tryGetValue(cacheKey, pairReqKey))
+ {
+ pairReqKey = _createDiffPairInterfaceRequirement(primalType, (IRType*)diffType);
+ existentialPairTypeCache.add(cacheKey, pairReqKey);
+ }
+
+ auto baseWitnessTable = getExistentialBaseWitnessTable(builder, primalType);
+ result = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ baseWitnessTable,
+ pairReqKey);
+
+ primalTypeMap[result] = primalType;
+ diffTypeMap[result] = (IRType*)diffType;
+
return result;
}
- if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType))
+ else if (auto typePack = as<IRTypePack>(primalType))
{
- result = nullptr;
- return result;
+ // Lower DiffPair(TypePack(a_0, a_1, ...), MakeWitnessPack(w_0, w_1, ...)) as
+ // TypePack(DiffPair(a_0, w_0), DiffPair(a_1, w_1), ...)
+ //
+ auto cacheKey = primalType;
+ if (pairTypeCache.tryGetValue(cacheKey, result))
+ return result;
+
+ auto packWitness = pairType->getWitness();
+
+ // Right now we only support concrete witness tables for type packs.
+ auto concretePackWitness = as<IRWitnessTable>(packWitness);
+ SLANG_ASSERT(concretePackWitness);
+
+ // Get diff type pack.
+ IRTypePack* diffTypePack = nullptr;
+
+ if (concretePackWitness->getConformanceType() ==
+ this->sharedContext->differentiableInterfaceType)
+ diffTypePack = as<IRTypePack>(findWitnessTableEntry(
+ concretePackWitness,
+ this->sharedContext->differentialAssocTypeStructKey));
+ else if (
+ concretePackWitness->getConformanceType() ==
+ this->sharedContext->differentiablePtrInterfaceType)
+ diffTypePack = as<IRTypePack>(findWitnessTableEntry(
+ concretePackWitness,
+ this->sharedContext->differentialAssocRefTypeStructKey));
+ else
+ SLANG_UNEXPECTED("Unexpected witness table");
+
+ SLANG_ASSERT(diffTypePack);
+
+ List<IRType*> args;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto type = (IRType*)typePack->getOperand(i);
+ auto diffType = (IRType*)typePack->getOperand(i);
+
+ if (pairTypeCache.tryGetValue(type, result))
+ {
+ args.add((IRType*)result);
+ continue;
+ }
+
+ // Lower the diff pair type.
+ auto loweredPairType = (IRType*)_createDiffPairType(type, diffType);
+
+ pairTypeCache.add(type, loweredPairType);
+ args.add(loweredPairType);
+ }
+
+ auto loweredTypePack = builder->getTypePack(args.getCount(), args.getBuffer());
+ // TODO: Unify the cache between the three cases.
+ pairTypeCache.add(cacheKey, loweredTypePack);
+
+ return loweredTypePack;
}
+ else
+ {
+ auto cacheKey = primalType;
+ if (pairTypeCache.tryGetValue(primalType, result))
+ return result;
- auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
- if (!diffType)
- return result;
- result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- pairTypeCache.add(primalType, result);
+ if (as<IRParam, IRDynamicCastBehavior::NoUnwrap>(primalType))
+ {
+ result = nullptr;
+ return result;
+ }
+
+ if (as<IRThisType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ List<IRInterfaceType*> constraintTypes;
+ constraintTypes.add(this->commonDiffPairInterface);
+ return builder->getAssociatedType(constraintTypes.getArrayView());
+ }
+
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
+ if (!diffType)
+ return result;
+
+ // Concrete case.
+ result = _createDiffPairType(primalType, (IRType*)diffType);
+ pairTypeCache.add(cacheKey, result);
- return result;
+ return result;
+ }
}
IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst)
@@ -550,6 +1143,13 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit
auto innerWitnessTableType = cast<IRWitnessTableType>(operand);
return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
}
+ else if (auto genericWitness = as<IRGeneric>(witness))
+ {
+ // This is a generic witness table.
+ auto innerWitness = getGenericReturnVal(genericWitness);
+ SLANG_ASSERT(as<IRWitnessTableType>(innerWitness->getDataType()));
+ return getConformanceTypeFromWitness(innerWitness);
+ }
else
{
SLANG_UNEXPECTED("Unexpected witness type");
@@ -558,81 +1158,134 @@ IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWit
return diffInterfaceType;
}
+List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations(
+ IRGlobalValueWithCode* code)
+{
+ // Scan function for all IRDifferentiableTypeAnnotation insts.
+ List<IRDifferentiableTypeAnnotation*> annotations;
+ for (auto block : code->getBlocks())
+ {
+ for (auto child : block->getChildren())
+ {
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(child))
+ {
+ annotations.add(annotation);
+ }
+ }
+ }
+
+ return annotations;
+}
+
+List<IRDifferentiableTypeAnnotation*> DifferentiableTypeConformanceContext::getAnnotations(
+ IRModuleInst* module)
+{
+ // Scan module for all IRDifferentiableTypeAnnotation insts.
+ List<IRDifferentiableTypeAnnotation*> annotations;
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst))
+ {
+ annotations.add(annotation);
+ }
+ }
+
+ return annotations;
+}
+
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
+ List<IRDifferentiableTypeAnnotation*> annotations = getAnnotations(func);
- auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(decor);
-
- // Build lookup dictionary for type witnesses.
- for (auto child = decor->getFirstChild(); child; child = child->next)
+ // Go up the parents of func & add the annotations of any IRGeneric or IRModule parent:
+ IRInst* parent = func;
+ while (parent)
{
- if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ if (auto upperFunc = as<IRGlobalValueWithCode>(parent))
{
- IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
+ // TODO: Cache this.
+ auto parentAnnotations = getAnnotations(upperFunc);
+ annotations.addRange(parentAnnotations);
+ }
+ else if (auto module = as<IRModuleInst>(parent))
+ {
+ // TODO: Cache this.
+ auto parentAnnotations = getAnnotations(module);
+ annotations.addRange(parentAnnotations);
+ }
+ parent = parent->getParent();
+ }
- SLANG_ASSERT(
- diffInterfaceType == sharedContext->differentiableInterfaceType ||
- diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
+ for (auto item : annotations)
+ {
+ IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
- auto existingItem =
- differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType());
- if (existingItem)
- {
- *existingItem = item->getWitness();
- }
- else
- {
- auto witness = item->getWitness();
+ SLANG_ASSERT(
+ diffInterfaceType == sharedContext->differentiableInterfaceType ||
+ diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
- // Also register the type's differential type with the same witness.
- auto concreteType = item->getConcreteType();
- IRBuilder subBuilder(item->getConcreteType());
- if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getBaseType());
+ if (existingItem)
+ {
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ auto witness = item->getWitness();
+
+ // Also register the type's differential type with the same witness.
+ auto concreteType = item->getBaseType();
+ IRBuilder subBuilder(item->getBaseType());
+ if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ {
+ // For tuple types with concrete element types,
+ // register the differential type for each element, but don't register for the
+ // tuple/typepack itself.
+ if (auto witnessPack = as<IRMakeWitnessPack>(witness))
{
- // For tuple types with concrete element types,
- // register the differential type for each element, but don't register for the
- // tuple/typepack itself.
- if (auto witnessPack = as<IRMakeWitnessPack>(witness))
+
+ for (UInt i = 0; i < concreteType->getOperandCount(); i++)
{
+ auto element = concreteType->getOperand(i);
+ auto elementWitness = witnessPack->getOperand(i);
- for (UInt i = 0; i < concreteType->getOperandCount(); i++)
- {
- auto element = concreteType->getOperand(i);
- auto elementWitness = witnessPack->getOperand(i);
-
- if (diffInterfaceType == sharedContext->differentiableInterfaceType)
- addTypeToDictionary((IRType*)element, elementWitness);
- else if (
- diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
- addTypeToDictionary((IRType*)element, elementWitness);
- }
- return;
+ if (diffInterfaceType == sharedContext->differentiableInterfaceType)
+ addTypeToDictionary((IRType*)element, elementWitness);
+ else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
+ addTypeToDictionary((IRType*)element, elementWitness);
}
+ return;
}
+ }
- addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness());
+ addTypeToDictionary((IRType*)item->getBaseType(), item->getWitness());
- if (!as<IRInterfaceType>(item->getConcreteType()))
- {
- addTypeToDictionary(
- (IRType*)_lookupWitness(
- &subBuilder,
- item->getWitness(),
- sharedContext->differentialAssocTypeStructKey,
- subBuilder.getTypeKind()),
- item->getWitness());
- }
+ // TODO: Is this really needed?
+ if (!as<IRInterfaceType>(item->getBaseType()) &&
+ !as<IRAssociatedType>(item->getBaseType()))
+ {
+ addTypeToDictionary(
+ (IRType*)_lookupWitness(
+ &subBuilder,
+ item->getWitness(),
+ sharedContext->differentialAssocTypeStructKey,
+ subBuilder.getTypeKind()),
+ item->getWitness());
+ }
- if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getConcreteType()))
- {
- // For differential pair types, register the differential type as well.
- IRBuilder builder(diffPairType);
- builder.setInsertAfter(diffPairType->getWitness());
+ // TODO: Is this really needed?
+ if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getBaseType()))
+ {
+ // For differential pair types, register the differential type as well.
+ IRBuilder builder(diffPairType);
+ builder.setInsertAfter(diffPairType->getWitness());
- // TODO(sai): lot of this logic is duplicated. need to refactor.
+ // TODO(sai): lot of this logic is duplicated. need to refactor.
+ if (!as<IRInterfaceType>(diffPairType->getValueType()) &&
+ !as<IRAssociatedType>(diffPairType->getValueType()))
+ {
auto diffType =
(diffInterfaceType == sharedContext->differentiableInterfaceType)
? _lookupWitness(
@@ -665,12 +1318,28 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
}
+IRWitnessTable* findGlobalWitness(IRInterfaceType* interface, IRInst* type)
+{
+ for (auto use = type->firstUse; use; use = use->nextUse)
+ {
+ if (auto witnessTable = as<IRWitnessTable>(use->getUser()))
+ {
+ if (witnessTable->getConcreteType() == type &&
+ witnessTable->getConformanceType() == interface)
+ return witnessTable;
+ }
+ }
+
+ return nullptr;
+}
+
IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(
IRInst* type,
DiffConformanceKind kind)
{
IRInst* foundResult = nullptr;
differentiableTypeWitnessDictionary.tryGetValue(type, foundResult);
+
if (!foundResult)
return nullptr;
@@ -791,8 +1460,8 @@ IRInst* DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
return nullptr;
}
-// Given an interface type, return the lookup path from a witness table of `type` to a witness table
-// of `supType`.
+// Given an interface type, return the lookup path from a witness table of `type` to a witness
+// table of `supType`.
static bool _findInterfaceLookupPathImpl(
HashSet<IRInst*>& processedTypes,
IRInterfaceType* supType,
@@ -967,6 +1636,11 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
}
+
+ if (auto annotation = as<IRDifferentiableTypeAnnotation>(globalInst))
+ {
+ addTypeToDictionary((IRType*)annotation->getBaseType(), annotation->getWitness());
+ }
}
}
@@ -1071,6 +1745,20 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
}
}
+IRType* getAssociatedTypeForKey(IRInst* key)
+{
+ for (auto use = key->firstUse; use; use = use->nextUse)
+ {
+ if (auto interfaceReq = as<IRInterfaceRequirementEntry>(key))
+ {
+ if (auto assocType = as<IRAssociatedType>(interfaceReq->getRequirementVal()))
+ return assocType;
+ }
+ }
+
+ return nullptr;
+}
+
IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(
IRBuilder* builder,
IRInst* primalType,
@@ -1118,8 +1806,9 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(
}
else if (auto lookup = as<IRLookupWitnessMethod>(primalType))
{
- // For types that are lookups from a table, we can simply lookup the witness from the same
- // table
+ // Trivial cases: For types that are lookups from a table, we can simply lookup the
+ // witness from the same table
+ //
if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey)
{
witness = builder->emitLookupInterfaceMethodInst(
@@ -1203,8 +1892,8 @@ IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
auto p0 = b.emitParam(diffDiffPairType);
auto p1 = b.emitParam(diffDiffPairType);
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value
- // type == diff type.
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that
+ // value type == diff type.
auto innerAdd = _lookupWitness(
&b,
innerWitness,
@@ -1325,8 +2014,8 @@ IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
auto p0 = b.emitParam(diffArrayType);
auto p1 = b.emitParam(diffArrayType);
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value
- // type == diff type.
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that
+ // value type == diff type.
auto innerAdd = _lookupWitness(
&b,
innerWitness,
@@ -1566,6 +2255,143 @@ IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness
return nullptr;
}
+IRInst* DifferentiableTypeConformanceContext::emitDAddOfDiffInstType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2)
+{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffElementType =
+ (IRType*)this->getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto arraySize = arrayType->getElementCount();
+
+ if (auto constArraySize = as<IRIntLit>(arraySize))
+ {
+ List<IRInst*> args;
+ for (IRIntegerValue i = 0; i < constArraySize->getValue(); i++)
+ {
+ auto index = builder->getIntValue(builder->getIntType(), i);
+ auto op1Val = builder->emitElementExtract(diffElementType, op1, index);
+ auto op2Val = builder->emitElementExtract(diffElementType, op2, index);
+ args.add(
+ emitDAddOfDiffInstType(builder, arrayType->getElementType(), op1Val, op2Val));
+ }
+ auto diffArrayType =
+ builder->getArrayType(diffElementType, arrayType->getElementCount());
+ return builder->emitMakeArray(diffArrayType, (UInt)args.getCount(), args.getBuffer());
+ }
+ else
+ {
+ // TODO: insert a runtime loop here.
+ SLANG_UNIMPLEMENTED_X("dadd of dynamic array.");
+ }
+ }
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffType = (IRType*)this->getDiffTypeFromPairType(builder, diffPairUserType);
+ auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+
+ auto primal1 = builder->emitDifferentialPairGetPrimalUserCode(op1);
+ auto primal2 = builder->emitDifferentialPairGetPrimalUserCode(op2);
+ auto primal =
+ emitDAddOfDiffInstType(builder, diffPairUserType->getValueType(), primal1, primal2);
+
+ auto diff1 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op1);
+ auto diff2 = builder->emitDifferentialPairGetDifferentialUserCode(diffType, op2);
+ auto diff = emitDAddOfDiffInstType(builder, diffType, diff1, diff2);
+
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff);
+ }
+ else if (as<IRInterfaceType>(primalType))
+ {
+ // If our type is existential, we need to handle the case where
+ // one or both of our operands are null-type.
+ //
+ return emitDAddForExistentialType(builder, primalType, op1, op2);
+ }
+ else if (as<IRAssociatedType>(primalType))
+ {
+ // Should not happen. associated type does not have any additional info, we can't
+ // lookup the necessary methods.
+ //
+ SLANG_UNEXPECTED("unexpected associated type during transposition");
+ }
+
+ auto addMethod = this->getAddMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(addMethod);
+
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ addMethod,
+ List<IRInst*>(op1, op2));
+}
+
+IRInst* DifferentiableTypeConformanceContext::emitDAddForExistentialType(
+ IRBuilder* builder,
+ IRType* primalType,
+ IRInst* op1,
+ IRInst* op2)
+{
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ this->getOrCreateExistentialDAddMethod(),
+ List<IRInst*>({op1, op2}));
+}
+
+IRInst* DifferentiableTypeConformanceContext::emitDZeroOfDiffInstType(
+ IRBuilder* builder,
+ IRType* primalType)
+{
+ if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore
+ auto diffElementType =
+ (IRType*)this->getDifferentialForType(builder, arrayType->getElementType());
+ SLANG_RELEASE_ASSERT(diffElementType);
+ auto diffArrayType = builder->getArrayType(diffElementType, arrayType->getElementCount());
+ auto diffElementZero = emitDZeroOfDiffInstType(builder, arrayType->getElementType());
+ return builder->emitMakeArrayFromElement(diffArrayType, diffElementZero);
+ }
+ else if (auto diffPairUserType = as<IRDifferentialPairUserCodeType>(primalType))
+ {
+ // TODO: This case should really not be necessary anymore.
+ auto primalZero = emitDZeroOfDiffInstType(builder, diffPairUserType->getValueType());
+ auto diffZero = primalZero;
+ auto diffType = primalZero->getFullType();
+ auto diffWitness = this->getDiffTypeWitnessFromPairType(builder, diffPairUserType);
+ auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness);
+ return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero);
+ }
+ else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType))
+ {
+ // Pack a null value into an existential type.
+ auto existentialZero = builder->emitMakeExistential(
+ this->sharedContext->differentiableInterfaceType,
+ this->emitNullDifferential(builder),
+ this->sharedContext->nullDifferentialWitness);
+
+ return existentialZero;
+ }
+
+ auto zeroMethod = this->getZeroMethodForType(builder, primalType);
+
+ // Should exist.
+ SLANG_ASSERT(zeroMethod);
+
+ return builder->emitCallInst(
+ (IRType*)this->getDifferentialForType(builder, primalType),
+ zeroMethod,
+ List<IRInst*>());
+}
+
void copyCheckpointHints(
IRBuilder* builder,
IRGlobalValueWithCode* oldInst,
@@ -1883,6 +2709,7 @@ struct AutoDiffPass : public InstPassBase
{
bool result = false;
OrderedHashSet<IRInst*> loweredIntermediateTypes;
+ Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap;
// Replace all `BackwardDiffIntermediateContextType` insts with the struct type
// that we generated during backward diff pass.
@@ -1906,6 +2733,38 @@ struct AutoDiffPass : public InstPassBase
if (type)
{
loweredIntermediateTypes.add(type);
+
+ auto func = differentiateInst->getFunc();
+
+ if (auto spec = as<IRSpecialize>(func))
+ func = spec->getBase();
+
+ if (auto generic = as<IRGeneric>(func))
+ {
+ func =
+ cast<IRGlobalValueWithCode>(findGenericReturnVal(generic));
+
+ auto bwdFuncDecor = func->findDecoration<
+ IRBackwardDerivativePropagateDecoration>();
+
+ typeToBwdFuncMap.add(
+ type,
+ cast<IRGlobalValueWithCode>(
+ as<IRSpecialize>(
+ bwdFuncDecor->getBackwardDerivativePropagateFunc())
+ ->getBase()));
+ }
+ else
+ {
+ auto bwdFuncDecor = func->findDecoration<
+ IRBackwardDerivativePropagateDecoration>();
+
+ typeToBwdFuncMap.add(
+ type,
+ cast<IRGlobalValueWithCode>(
+ bwdFuncDecor->getBackwardDerivativePropagateFunc()));
+ }
+
inst->replaceUsesWith(type);
inst->removeAndDeallocate();
changed = true;
@@ -1922,7 +2781,9 @@ struct AutoDiffPass : public InstPassBase
}
// Now we generate the differential type for the intermediate context type
// to allow higher order differentiation.
- generateDifferentialImplementationForContextType(loweredIntermediateTypes);
+ generateDifferentialImplementationForContextType(
+ loweredIntermediateTypes,
+ typeToBwdFuncMap);
return result;
}
@@ -1977,22 +2838,13 @@ struct AutoDiffPass : public InstPassBase
IRInst* addMethod = nullptr;
};
- // Register the differential type for an intermediate context type to the derivative functions
- // that uses the type.
+ // Register the differential type for an intermediate context type to the derivative
+ // functions that uses the type.
void registerDiffContextType(
IRBuilder& builder,
- IRDifferentiableTypeDictionaryDecoration* diffDecor,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRInst* origType)
{
- HashSet<IRInst*> registeredType;
- for (auto entry : diffDecor->getChildren())
- {
- if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry))
- {
- registeredType.add(e->getOperand(0));
- }
- }
// Use a work list to recursively walk through all sub fields of the struct type.
List<IRInst*> wlist;
wlist.add(origType);
@@ -2002,10 +2854,13 @@ struct AutoDiffPass : public InstPassBase
IntermediateContextTypeDifferentialInfo diffInfo;
if (!diffTypes.tryGetValue(t, diffInfo))
continue;
- if (registeredType.add(t))
- builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness);
- else
- continue;
+
+ IRInst* args[] = {t, diffInfo.diffWitness};
+ builder.emitIntrinsicInst(
+ builder.getVoidType(),
+ kIROp_DifferentiableTypeAnnotation,
+ 2,
+ args);
if (auto structType = as<IRStructType>(getResolvedInstForDecorations(t)))
{
@@ -2017,7 +2872,9 @@ struct AutoDiffPass : public InstPassBase
}
}
- void generateDifferentialImplementationForContextType(OrderedHashSet<IRInst*>& contextTypes)
+ void generateDifferentialImplementationForContextType(
+ OrderedHashSet<IRInst*>& contextTypes,
+ Dictionary<IRInst*, IRGlobalValueWithCode*> typeToBwdFuncMap)
{
// First we are going to topology sort all intermediate context types.
OrderedHashSet<IRInst*> sortedContextTypes;
@@ -2043,6 +2900,10 @@ struct AutoDiffPass : public InstPassBase
IRBuilder builder(module);
for (auto t : sortedContextTypes)
{
+ auto func = typeToBwdFuncMap[t];
+ DifferentiableTypeConformanceContext ctx(this->autodiffContext);
+ ctx.setFunc(func);
+
if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType)
{
// For generics/struct types, we will generate a new generic/struct type
@@ -2050,7 +2911,7 @@ struct AutoDiffPass : public InstPassBase
SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module);
builder.setInsertBefore(t);
- auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t);
+ auto diffInfo = fillDifferentialTypeImplementation(&ctx, diffTypes, t);
diffTypes[t] = diffInfo;
}
else if (auto specialize = as<IRSpecialize>(t))
@@ -2085,30 +2946,29 @@ struct AutoDiffPass : public InstPassBase
// function without a intermediate-type via an interface.
SLANG_RELEASE_ASSERT(diffTypes.containsKey(t));
}
- }
- // Register the differential types into the conformance dictionaries of the functions that
- // uses them.
- for (auto t : diffTypes)
- {
+ if (!diffTypes.containsKey(t))
+ continue;
+
+ // If we created a new differential type, we need to place into the contexts of all
+ // functions that use it.
+ //
HashSet<IRFunc*> registeredFuncs;
- for (auto use = t.key->firstUse; use; use = use->nextUse)
+ for (auto use = t->firstUse; use; use = use->nextUse)
{
auto parentFunc = getParentFunc(use->getUser());
if (!parentFunc)
continue;
if (!registeredFuncs.add(parentFunc))
continue;
- if (auto dictDecor =
- parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- registerDiffContextType(builder, dictDecor, diffTypes, t.key);
- }
+
+ registerDiffContextType(builder, diffTypes, t);
}
}
}
IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct(
+ DifferentiableTypeConformanceContext* ctx,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRStructType* originalType,
IRStructType* diffType)
@@ -2122,6 +2982,7 @@ struct AutoDiffPass : public InstPassBase
// Generate the fields for all differentiable members of the original struct type.
struct FieldInfo
{
+ IRType* primalType;
IRStructField* field;
IRInst* witness;
};
@@ -2130,30 +2991,30 @@ struct AutoDiffPass : public InstPassBase
for (auto field : originalType->getFields())
{
IRInst* diffFieldWitness = nullptr;
- if (auto diffDecor =
- field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>())
- {
- diffFieldWitness = diffDecor->getDifferentialWitness();
- }
- else
+
+ diffFieldWitness = ctx->tryGetDifferentiableWitness(
+ &builder,
+ field->getFieldType(),
+ DiffConformanceKind::Value);
+
+ if (!diffFieldWitness)
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
diffTypes.tryGetValue(field->getFieldType(), diffFieldTypeInfo);
diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
+
if (diffFieldWitness)
{
FieldInfo info;
IRBuilder keyBuilder = builder;
keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
auto diffKey = keyBuilder.createStructKey();
- auto diffFieldType = _lookupWitness(
- &keyBuilder,
- diffFieldWitness,
- autodiffContext->differentialAssocTypeStructKey,
- builder.getTypeKind());
+ auto diffFieldType = ctx->getDifferentialForType(&builder, field->getFieldType());
+
info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
info.witness = diffFieldWitness;
+ info.primalType = field->getFieldType();
builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey);
diffFields.add(info);
@@ -2172,16 +3033,10 @@ struct AutoDiffPass : public InstPassBase
builder.setInsertInto(zeroMethod);
builder.emitBlock();
List<IRInst*> fieldVals;
+
for (auto info : diffFields)
{
- auto innerZeroMethod = _lookupWitness(
- &builder,
- info.witness,
- autodiffContext->zeroMethodStructKey,
- autodiffContext->zeroMethodType);
- IRInst* val =
- builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr);
- fieldVals.add(val);
+ fieldVals.add(ctx->emitDZeroOfDiffInstType(&builder, info.primalType));
}
builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
@@ -2203,20 +3058,15 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerAddMethod = _lookupWitness(
- &builder,
- info.witness,
- autodiffContext->addMethodStructKey,
- autodiffContext->addMethodType);
IRInst* args[2] = {
builder
.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()),
builder
.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()),
};
- IRInst* val =
- builder.emitCallInst(info.field->getFieldType(), innerAddMethod, 2, args);
- fieldVals.add(val);
+
+ fieldVals.add(
+ ctx->emitDAddOfDiffInstType(&builder, info.primalType, args[0], args[1]));
}
builder.emitReturn(builder.emitMakeStruct(diffType, fieldVals));
}
@@ -2265,6 +3115,7 @@ struct AutoDiffPass : public InstPassBase
}
IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation(
+ DifferentiableTypeConformanceContext* ctx,
OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes,
IRInst* originalType)
{
@@ -2274,6 +3125,7 @@ struct AutoDiffPass : public InstPassBase
builder.setInsertBefore(originalType);
auto diffType = builder.createStructType();
return fillDifferentialTypeImplementationForStruct(
+ ctx,
diffTypes,
as<IRStructType>(originalType),
as<IRStructType>(diffType));
@@ -2286,7 +3138,7 @@ struct AutoDiffPass : public InstPassBase
auto structType = as<IRStructType>(findGenericReturnVal(genType));
SLANG_RELEASE_ASSERT(structType);
- auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType);
+ auto innerResult = fillDifferentialTypeImplementation(ctx, diffTypes, structType);
IRBuilder builder(originalType);
builder.setInsertBefore(originalType);
@@ -2421,7 +3273,8 @@ struct AutoDiffPass : public InstPassBase
{
bool changed = false;
List<IRInst*> autoDiffWorkList;
- // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call graph.
+ // Collect all `ForwardDifferentiate`/`BackwardDifferentiate` insts from the call
+ // graph.
processAllReachableInsts(
[&](IRInst* inst)
{
@@ -2438,6 +3291,7 @@ struct AutoDiffPass : public InstPassBase
case kIROp_Func:
case kIROp_Specialize:
case kIROp_LookupWitness:
+ case kIROp_Generic:
if (auto innerFunc =
as<IRFunc>(getResolvedInstForDecorations(inst->getOperand(0))))
{
@@ -2519,8 +3373,8 @@ struct AutoDiffPass : public InstPassBase
}
// Run transcription logic to generate the body of forward/backward derivatives
- // functions. While doing so, we may discover new functions to differentiate, so we keep
- // running until the worklist goes dry.
+ // functions. While doing so, we may discover new functions to differentiate, so we
+ // keep running until the worklist goes dry.
List<IRFunc*> autodiffCleanupList;
while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0)
{
@@ -2582,10 +3436,10 @@ struct AutoDiffPass : public InstPassBase
hasChanges = true;
// We have done transcribing the functions, now it is time to demote all
- // DifferentialPair types and their operations down to DifferentialPairUserCodeType and
- // *UserCode operations so they can be treated just like normal types with no special
- // semantics in future processing, and won't be confused with the semantics of a
- // DifferentialPair type during future autodiff code gen.
+ // DifferentialPair types and their operations down to DifferentialPairUserCodeType
+ // and *UserCode operations so they can be treated just like normal types with no
+ // special semantics in future processing, and won't be confused with the semantics
+ // of a DifferentialPair type during future autodiff code gen.
rewriteDifferentialPairToUserCode(module);
hasChanges |= changed;
@@ -2693,8 +3547,8 @@ void checkAutodiffPatterns(TargetProgram* target, IRModule* module, DiagnosticSi
if (func->sourceLoc.isValid() && // Don't diagnose for synthesized functions
func->findDecoration<IRPreferRecomputeDecoration>())
{
- // If we don't have any side-effect behavior, we should warn (note: read-none is a
- // stronger guarantee than no-side-effect)
+ // If we don't have any side-effect behavior, we should warn (note: read-none is
+ // a stronger guarantee than no-side-effect)
//
if (func->findDecoration<IRNoSideEffectDecoration>() ||
func->findDecoration<IRReadNoneDecoration>())
@@ -2759,6 +3613,27 @@ void removeDetachInsts(IRModule* module)
pass.processModule();
}
+
+struct RemoveTypeAnnotationInstsPass : InstPassBase
+{
+ RemoveTypeAnnotationInstsPass(IRModule* module)
+ : InstPassBase(module)
+ {
+ }
+ void processModule()
+ {
+ processInstsOfType<IRDifferentiableTypeAnnotation>(
+ kIROp_DifferentiableTypeAnnotation,
+ [&](IRDifferentiableTypeAnnotation* annotation) { annotation->removeAndDeallocate(); });
+ }
+};
+
+void removeTypeAnnotations(IRModule* module)
+{
+ RemoveTypeAnnotationInstsPass pass(module);
+ pass.processModule();
+}
+
struct LowerNullCheckPass : InstPassBase
{
LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context)
@@ -2841,6 +3716,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module)
removeDetachInsts(module);
+ removeTypeAnnotations(module);
+
lowerNullCheckInsts(module, &autodiffContext);
stripNoDiffTypeAttribute(module);