summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp793
1 files changed, 527 insertions, 266 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 07a6a76fb..94a605a68 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func)
return false;
}
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr)
{
if (auto witnessTable = as<IRWitnessTable>(witness))
{
@@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
}
else
{
+ SLANG_ASSERT(resultType);
return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
+ resultType,
witness,
requirementKey);
}
return nullptr;
}
-static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witness = type->getWitness();
SLANG_RELEASE_ASSERT(witness);
@@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB
if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType()))
{
// The differential type is the IDifferentiable interface type.
- return sharedContext->differentiableInterfaceType;
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return sharedContext->differentiableInterfaceType;
+ else if (as<IRDifferentialPtrPairType>(type))
+ return sharedContext->differentiablePtrInterfaceType;
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
- return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return _lookupWitness(
+ builder,
+ witness,
+ sharedContext->differentialAssocTypeStructKey,
+ builder->getTypeKind());
+ else if (as<IRDifferentialPtrPairType>(type))
+ return _lookupWitness(
+ builder,
+ witness,
+ sharedContext->differentialAssocRefTypeStructKey,
+ builder->getTypeKind());
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+
+ if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type))
+ return _lookupWitness(
+ builder,
+ witnessTable,
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ sharedContext->differentialAssocTypeWitnessTableType);
+ else if (as<IRDifferentialPtrPairType>(type))
+ return _lookupWitness(
+ builder,
+ witnessTable,
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ sharedContext->differentialAssocRefTypeWitnessTableType);
+ else
+ SLANG_UNEXPECTED("Unexpected differential pair type");
}
bool isNoDiffType(IRType* paramType)
@@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}
+IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst)
+{
+ for (auto inst : moduleInst->getGlobalInsts())
+ {
+ if (auto interfaceType = as<IRInterfaceType>(inst))
+ {
+ if (auto decor = interfaceType->findDecoration<IRNameHintDecoration>())
+ {
+ if (decor->getName() == "IDifferentiablePtrType")
+ {
+ return interfaceType;
+ }
+ }
+ }
+ }
+ return nullptr;
+}
+
AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst)
: moduleInst(inModuleInst), targetProgram(target)
{
@@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType();
zeroMethodStructKey = findZeroMethodStructKey();
+ zeroMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal());
addMethodStructKey = findAddMethodStructKey();
+ addMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal());
mulMethodStructKey = findMulMethodStructKey();
nullDifferentialStructType = findNullDifferentialStructType();
nullDifferentialWitness = findNullDifferentialWitness();
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
+ isInterfaceAvailable = true;
+ }
+
+ differentiablePtrInterfaceType = as<IRInterfaceType>(findDifferentiableRefInterface(inModuleInst));
+
+ if (differentiablePtrInterfaceType)
+ {
+ differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey();
+ differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey();
+ differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType();
+
+ isPtrInterfaceAvailable = true;
}
}
@@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness()
}
-IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index)
{
- if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
+ if (as<IRModuleInst>(moduleInst) && interface)
{
// Assume for now that IDifferentiable has exactly five fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
- return as<IRStructKey>(entry->getRequirementKey());
+ // SLANG_ASSERT(interface->getOperandCount() == 5);
+ if (auto entry = as<IRInterfaceRequirementEntry>(interface->getOperand(index)))
+ return entry;
else
{
SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
@@ -421,6 +485,50 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
return nullptr;
}
+// Extracts conformance interface from a witness inst while accounting for some
+// quirks in the type system around interfaces that conform to other interfaces.
+//
+IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness)
+{
+ IRInterfaceType* diffInterfaceType = nullptr;
+ if (auto witnessTableType = as<IRWitnessTableType>(witness->getDataType()))
+ {
+ diffInterfaceType = cast<IRInterfaceType>(witnessTableType->getConformanceType());
+ }
+ else if (auto structKey = as<IRStructKey>(witness))
+ {
+ // We currently assume that a struct key is used uniquely for a single interface-requirement-entry.
+ // Find that entry
+ for (IRUse* use = structKey->firstUse; use; use = use->nextUse)
+ {
+ if (auto entry = as<IRInterfaceRequirementEntry>(use->getUser()))
+ {
+ auto innerWitnessTableType = cast<IRWitnessTableType>(entry->getRequirementVal());
+ diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ break;
+ }
+ }
+ }
+ else if (auto interfaceRequirementEntry = as<IRInterfaceRequirementEntry>(witness))
+ {
+ auto innerWitnessTableType = cast<IRWitnessTableType>(interfaceRequirementEntry->getRequirementVal());
+ diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ }
+ else if (auto tupleType = as<IRTupleType>(witness->getDataType()))
+ {
+ SLANG_ASSERT(tupleType->getOperandCount() >= 1);
+ auto operand = tupleType->getOperand(0);
+ auto innerWitnessTableType = cast<IRWitnessTableType>(operand);
+ return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType());
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected witness type");
+ }
+
+ return diffInterfaceType;
+}
+
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
@@ -434,7 +542,13 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
{
- auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType());
+ IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness());
+
+ SLANG_ASSERT(
+ diffInterfaceType == sharedContext->differentiableInterfaceType
+ || diffInterfaceType == sharedContext->differentiablePtrInterfaceType);
+
+ auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType());
if (existingItem)
{
*existingItem = item->getWitness();
@@ -458,20 +572,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
auto element = concreteType->getOperand(i);
auto elementWitness = witnessPack->getOperand(i);
- differentiableWitnessDictionary.addIfNotExists(
- (IRType*)element,
- _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+
+ if (diffInterfaceType == sharedContext->differentiableInterfaceType)
+ addTypeToDictionary(
+ (IRType*)element,
+ elementWitness);
+ else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType)
+ addTypeToDictionary(
+ (IRType*)element,
+ elementWitness);
}
return;
}
}
- differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness());
if (!as<IRInterfaceType>(item->getConcreteType()))
{
- differentiableWitnessDictionary.addIfNotExists(
- (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey),
+ addTypeToDictionary(
+ (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()),
item->getWitness());
}
@@ -480,29 +600,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
// For differential pair types, register the differential type as well.
IRBuilder builder(diffPairType);
builder.setInsertAfter(diffPairType->getWitness());
- auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey);
- auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey);
- if (diffType && diffWitness)
- {
- differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness);
- }
+
+ // TODO(sai): lot of this logic is duplicated. need to refactor.
+ auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ?
+ _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) :
+ _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind());
+ auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ?
+ _lookupWitness(
+ &builder,
+ diffPairType->getWitness(),
+ sharedContext->differentialAssocTypeWitnessStructKey,
+ sharedContext->differentialAssocTypeWitnessTableType) :
+ _lookupWitness(
+ &builder,
+ diffPairType->getWitness(),
+ sharedContext->differentialAssocRefTypeWitnessStructKey,
+ sharedContext->differentialAssocRefTypeWitnessTableType);
+
+ addTypeToDictionary((IRType*)diffType, diffWitness);
}
}
}
}
}
-IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind)
{
IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.tryGetValue(type, foundResult);
- return foundResult;
+ differentiableTypeWitnessDictionary.tryGetValue(type, foundResult);
+ if (!foundResult)
+ return nullptr;
+
+ if (kind == DiffConformanceKind::Any)
+ return foundResult;
+
+ if (auto baseType = getConformanceTypeFromWitness(foundResult))
+ {
+ if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value)
+ return foundResult;
+ else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr)
+ return foundResult;
+ }
+
+ return nullptr;
}
-IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType)
{
- if (auto conformance = tryGetDifferentiableWitness(builder, origType))
- return _lookupWitness(builder, conformance, key);
+ if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any))
+ return _lookupWitness(builder, conformance, key, resultType);
return nullptr;
}
@@ -514,7 +660,7 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp
IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
- return _getDiffTypeFromPairType(sharedContext, builder, type);
+ return this->differentiateType(builder, type->getValueType());
}
IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
@@ -525,20 +671,34 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB
IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey);
+ return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
}
IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
{
auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey);
+ return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+}
+
+void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness)
+{
+ auto conformanceType = getConformanceTypeFromWitness(witness);
+
+ if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable)
+ return;
+
+ SLANG_ASSERT(
+ conformanceType == sharedContext->differentiableInterfaceType ||
+ conformanceType == sharedContext->differentiablePtrInterfaceType);
+
+ differentiableTypeWitnessDictionary.addIfNotExists(type, witness);
}
IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable)
{
SLANG_RELEASE_ASSERT(interfaceType);
- List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = findInterfaceLookupPath(
sharedContext->differentiableInterfaceType, interfaceType);
IRInst* differentialTypeWitness = witnessTable;
@@ -549,6 +709,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
{
differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey());
// Lookup insts are always primal values.
+
builder->markInstAsPrimal(differentialTypeWitness);
}
return differentialTypeWitness;
@@ -557,10 +718,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface
return nullptr;
}
-// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
-static bool _findDifferentiableInterfaceLookupPathImpl(
+// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`.
+static bool _findInterfaceLookupPathImpl(
HashSet<IRInst*>& processedTypes,
- IRInterfaceType* idiffType,
+ IRInterfaceType* supType,
IRInterfaceType* type,
List<IRInterfaceRequirementEntry*>& currentPath)
{
@@ -576,13 +737,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl(
if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
{
currentPath.add(entry);
- if (wt->getConformanceType() == idiffType)
+ if (wt->getConformanceType() == supType)
{
return true;
}
else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
{
- if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
+ if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath))
return true;
}
currentPath.removeLast();
@@ -591,11 +752,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl(
return false;
}
-List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type)
+List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type)
{
List<IRInterfaceRequirementEntry*> currentPath;
HashSet<IRInst*> processedTypes;
- _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
+ _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath);
return currentPath;
}
@@ -722,7 +883,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
- differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness());
+ addTypeToDictionary(pairType->getValueType(), pairType->getWitness());
}
}
}
@@ -762,9 +923,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
case kIROp_DifferentialPairType:
{
auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- builder,
- getDiffTypeFromPairType(builder, primalPairType),
+ return builder->getDifferentialPairType(
+ (IRType*)getDiffTypeFromPairType(builder, primalPairType),
getDiffTypeWitnessFromPairType(builder, primalPairType));
}
@@ -776,6 +936,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
getDiffTypeWitnessFromPairType(builder, primalPairType));
}
+ case kIROp_DifferentialPtrPairType:
+ {
+ auto primalPairType = as<IRDifferentialPtrPairType>(primalType);
+ return builder->getDifferentialPtrPairType(
+ (IRType*)getDiffTypeFromPairType(builder, primalPairType),
+ getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
case kIROp_FuncType:
{
SLANG_UNIMPLEMENTED_X("Impl");
@@ -817,12 +985,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
}
}
-IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType)
+IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind)
{
if (isNoDiffType((IRType*)primalType))
return nullptr;
-
- IRInst* witness = lookUpConformanceForType((IRType*)primalType);
+
+ IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind);
if (witness)
{
SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType));
@@ -834,31 +1002,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
witness = nullptr;
}
- if (!witness)
+ if (witness)
+ return witness;
+
+ // If a witness is not already mapped, build one if possible.
+ SLANG_RELEASE_ASSERT(primalType);
+ if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
{
- SLANG_RELEASE_ASSERT(primalType);
- if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType))
- {
- witness = getOrCreateDifferentiablePairWitness(builder, primalPairType);
- }
- else if (auto arrayType = as<IRArrayType>(primalType))
- {
- witness = getArrayWitness(builder, arrayType);
- }
- else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
- {
- witness = getExtractExistensialTypeWitness(builder, extractExistential);
- }
- else if (auto typePack = as<IRTypePack>(primalType))
+ witness = buildDifferentiablePairWitness(builder, primalPairType, kind);
+ }
+ else if (auto arrayType = as<IRArrayType>(primalType))
+ {
+ witness = buildArrayWitness(builder, arrayType, kind);
+ }
+ else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ {
+ witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind);
+ }
+ else if (auto typePack = as<IRTypePack>(primalType))
+ {
+ witness = buildTupleWitness(builder, typePack, kind);
+ }
+ else if (auto tupleType = as<IRTupleType>(primalType))
+ {
+ witness = buildTupleWitness(builder, tupleType, kind);
+ }
+ else if (auto lookup = as<IRLookupWitnessMethod>(primalType))
+ {
+ // For types that are lookups from a table, we can simply lookup the witness from the same table
+ if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey)
{
- witness = getTupleWitness(builder, typePack);
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ sharedContext->differentialAssocTypeWitnessStructKey);
}
- else if (auto tupleType = as<IRTupleType>(primalType))
+
+ if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey)
{
- witness = getTupleWitness(builder, tupleType);
+ witness = builder->emitLookupInterfaceMethodInst(
+ lookup->getWitnessTable()->getDataType(),
+ lookup->getWitnessTable(),
+ sharedContext->differentialAssocRefTypeWitnessStructKey);
}
}
- return witness;
+
+ // If we created a witness, register it.
+ if (witness)
+ {
+ addTypeToDictionary((IRType*)primalType, witness);
+ return witness;
+ }
+
+ // Failed. Type is either non-differentiable, or unhandled.
+ return nullptr;
}
IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
@@ -868,77 +1065,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder*
witness);
}
-IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType)
+IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness(
+ IRBuilder* builder,
+ IRDifferentialPairTypeBase* pairType,
+ DiffConformanceKind target)
{
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
-
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
-
- bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
-
- // Fill in differential method implementations.
- auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType();
- auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness();
-
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
- b.emitBlock();
- auto p0 = b.emitParam(diffDiffPairType);
- auto p1 = b.emitParam(diffDiffPairType);
-
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- IRInst* argsPrimal[2] = {
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
- isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
- auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
- IRInst* argsDiff[2] = {
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
- isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
- auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
- : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
- b.emitReturn(retVal);
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
- b.emitBlock();
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal =
- isUserCodeType
- ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
- : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
- b.emitReturn(retVal);
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false;
+
+ // Fill in differential method implementations.
+ auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType();
+ auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness();
+
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffDiffPairType);
+ auto p1 = b.emitParam(diffDiffPairType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ IRInst* argsPrimal[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0),
+ isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) };
+ auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal);
+ IRInst* argsDiff[2] = {
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0),
+ isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)};
+ auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart)
+ : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart);
+ b.emitReturn(retVal);
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType));
+ b.emitBlock();
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal =
+ isUserCodeType
+ ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal)
+ : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal);
+ b.emitReturn(retVal);
+ }
+ }
+ else if (target == DiffConformanceKind::Ptr)
+ {
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType);
+
+ table = builder->createWitnessTable(
+ sharedContext->differentiablePtrInterfaceType,
+ (IRType*)pairType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
}
-
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)pairType] = table;
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType)
+IRInst* DifferentiableTypeConformanceContext::buildArrayWitness(
+ IRBuilder* builder,
+ IRArrayType* arrayType,
+ DiffConformanceKind target)
{
// Differentiate the pair type to get it's differential (which is itself a pair)
auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType);
@@ -946,70 +1163,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder
if (!diffArrayType)
return nullptr;
- auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType());
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType));
+ auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType(), DiffConformanceKind::Value);
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
- auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType);
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType);
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
- auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
+ auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType();
- // Fill in differential method implementations.
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffArrayType, diffArrayType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffArrayType);
+ auto p1 = b.emitParam(diffArrayType);
+
+ // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ auto resultVar = b.emitVar(diffArrayType);
+ IRBlock* loopBodyBlock = nullptr;
+ IRBlock* loopBreakBlock = nullptr;
+ auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
+ b.setInsertBefore(loopBodyBlock->getTerminator());
+
+ IRInst* args[2] = {
+ b.emitElementExtract(p0, loopCounter),
+ b.emitElementExtract(p1, loopCounter) };
+ auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
+ auto addr = b.emitElementAddress(resultVar, loopCounter);
+ b.emitStore(addr, elementResult);
+ b.setInsertInto(loopBreakBlock);
+ b.emitReturn(b.emitLoad(resultVar));
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(zeroMethod);
+ zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
+ b.emitBlock();
+
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
+ auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
+ b.emitReturn(retVal);
+ }
+ }
+ else if (target == DiffConformanceKind::Ptr)
{
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffArrayType, diffArrayType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType));
- b.emitBlock();
- auto p0 = b.emitParam(diffArrayType);
- auto p1 = b.emitParam(diffArrayType);
+ SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType));
- // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type.
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- auto resultVar = b.emitVar(diffArrayType);
- IRBlock* loopBodyBlock = nullptr;
- IRBlock* loopBreakBlock = nullptr;
- auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock);
- b.setInsertBefore(loopBodyBlock->getTerminator());
+ table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType);
- IRInst* args[2] = {
- b.emitElementExtract(p0, loopCounter),
- b.emitElementExtract(p1, loopCounter) };
- auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args);
- auto addr = b.emitElementAddress(resultVar, loopCounter);
- b.emitStore(addr, elementResult);
- b.setInsertInto(loopBreakBlock);
- b.emitReturn(b.emitLoad(resultVar));
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
}
+ else
{
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(zeroMethod);
- zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType));
- b.emitBlock();
-
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr);
- auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal);
- b.emitReturn(retVal);
+ SLANG_UNEXPECTED("Invalid conformance kind for synthesis");
}
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)arrayType] = table;
-
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType)
+IRInst* DifferentiableTypeConformanceContext::buildTupleWitness(
+ IRBuilder* builder,
+ IRInst* inTupleType,
+ DiffConformanceKind target)
{
// Differentiate the pair type to get it's differential (which is itself a pair)
auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType);
@@ -1017,100 +1253,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder
if (!diffTupleType)
return nullptr;
- auto addMethod = builder->createFunc();
- auto zeroMethod = builder->createFunc();
-
- auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
-
- // And place it in the synthesized witness table.
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
- builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
- builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
- builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
-
- // Fill in differential method implementations.
- {
- // Add method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- IRType* paramTypes[2] = { diffTupleType, diffTupleType };
- addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
- b.emitBlock();
- auto p0 = b.emitParam(diffTupleType);
- auto p1 = b.emitParam(diffTupleType);
- List<IRInst*> results;
- for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
- {
- auto elementType = inTupleType->getOperand(i);
- auto diffElementType = (IRType*)diffTupleType->getOperand(i);
- auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
- IRInst* elementResult = nullptr;
- if (!innerWitness)
+ IRWitnessTable* table = nullptr;
+ if (target == DiffConformanceKind::Value)
+ {
+ SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType));
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffTupleType, diffTupleType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffTupleType);
+ auto p1 = b.emitParam(diffTupleType);
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
{
- elementResult = b.getVoidValue();
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType);
+ auto iVal = b.getIntValue(b.getIntType(), i);
+ IRInst* args[2] = {
+ b.emitGetTupleElement(diffElementType, p0, iVal),
+ b.emitGetTupleElement(diffElementType, p1, iVal) };
+ elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
+ }
+ results.add(elementResult);
}
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
else
- {
- auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
- auto iVal = b.getIntValue(b.getIntType(), i);
- IRInst* args[2] = {
- b.emitGetTupleElement(diffElementType, p0, iVal),
- b.emitGetTupleElement(diffElementType, p1, iVal) };
- elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
- }
- results.add(elementResult);
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
}
- IRInst* resultVal = nullptr;
- if (diffTupleType->getOp() == kIROp_TupleType)
- resultVal = b.emitMakeTuple(diffTupleType, results);
- else
- resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
- b.emitReturn(resultVal);
- }
- {
- // Zero method.
- IRBuilder b = *builder;
- b.setInsertInto(addMethod);
- b.addBackwardDifferentiableDecoration(addMethod);
- addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
- b.emitBlock();
- List<IRInst*> results;
- for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
- {
- auto elementType = inTupleType->getOperand(i);
- auto diffElementType = (IRType*)diffTupleType->getOperand(i);
- auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
- IRInst* elementResult = nullptr;
- if (!innerWitness)
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+ b.emitBlock();
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
{
- elementResult = b.getVoidValue();
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType);
+ elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
+ }
+ results.add(elementResult);
}
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
else
- {
- auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
- elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
- }
- results.add(elementResult);
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
}
- IRInst* resultVal = nullptr;
- if (diffTupleType->getOp() == kIROp_TupleType)
- resultVal = b.emitMakeTuple(diffTupleType, results);
- else
- resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
- b.emitReturn(resultVal);
}
+ else if (target == DiffConformanceKind::Ptr)
+ {
+ SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType));
- // Record this in the context for future lookups
- differentiableWitnessDictionary[(IRType*)inTupleType] = table;
+ table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table);
+ }
return table;
}
-IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
+IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness(
IRBuilder* builder,
- IRExtractExistentialType* extractExistentialType)
+ IRExtractExistentialType* extractExistentialType,
+ DiffConformanceKind target)
{
+ SLANG_UNUSED(target); // logic is the same for both value and ptr
+
// Check that the type's base is differentiable
if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType()))
{
@@ -1310,12 +1562,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst*
if (context.isDifferentiableType((IRType*)typeInst))
return true;
+
// Look for equivalent types.
- for (auto type : context.differentiableWitnessDictionary)
+ for (auto type : context.differentiableTypeWitnessDictionary)
{
if (isTypeEqual(type.key, (IRType*)typeInst))
{
- context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value;
+ context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value;
return true;
}
}
@@ -1672,7 +1925,7 @@ struct AutoDiffPass : public InstPassBase
IRBuilder keyBuilder = builder;
keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType));
auto diffKey = keyBuilder.createStructKey();
- auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey);
+ auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind());
info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType);
info.witness = diffFieldWitness;
builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey);
@@ -1695,7 +1948,11 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey);
+ auto innerZeroMethod = _lookupWitness(
+ &builder,
+ info.witness,
+ autodiffContext->zeroMethodStructKey,
+ autodiffContext->zeroMethodType);
IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr);
fieldVals.add(val);
}
@@ -1719,7 +1976,11 @@ struct AutoDiffPass : public InstPassBase
List<IRInst*> fieldVals;
for (auto info : diffFields)
{
- auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey);
+ auto innerAddMethod = _lookupWitness(
+ &builder,
+ info.witness,
+ autodiffContext->addMethodStructKey,
+ autodiffContext->addMethodType);
IRInst* args[2] = {
builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()),
builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()),