summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-23 21:45:59 -0700
committerGitHub <noreply@github.com>2024-08-23 21:45:59 -0700
commitb2ca2d5a4efeae807d3c3f48f60235e47413b559 (patch)
tree643d2bab5776e5f8f7cfa722975af9e826d77c9d /source/slang/slang-ir-autodiff.cpp
parente4088cd602bd4d5a72fea67a787b1319acfc044d (diff)
Make variadic generics work with interfaces and forward autodiff. (#4905)
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp156
1 files changed, 150 insertions, 6 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 6b275179c..b7c2037e5 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -44,6 +44,13 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return entry->getRequirementVal();
}
}
+ else if (as<IRMakeWitnessPack>(witness))
+ {
+ // We are looking up a witness from a type pack.
+ // This is only allowed if we are looking up a differential type.
+ // We should turn this into an actual witness table for the type pack/tuple type.
+ SLANG_UNEXPECTED("looking up from a witness pack is invalid and should have been lowered.");
+ }
else
{
return builder->emitLookupInterfaceMethodInst(
@@ -434,10 +441,33 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
}
else
{
- differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ auto witness = item->getWitness();
// Also register the type's differential type with the same witness.
+ auto concreteType = item->getConcreteType();
IRBuilder subBuilder(item->getConcreteType());
+ if (as<IRTypePack>(concreteType) || as<IRTupleType>(concreteType))
+ {
+ // For tuple types, register the differential type for each element, but don't register for the
+ // tuple/typepack itself.
+ auto witnessPack = as<IRMakeWitnessPack>(witness);
+ SLANG_ASSERT(witnessPack);
+
+ for (UInt i = 0; i < concreteType->getOperandCount(); i++)
+ {
+ auto element = concreteType->getOperand(i);
+ auto elementWitness = witnessPack->getOperand(i);
+ differentiableWitnessDictionary.addIfNotExists(
+ (IRType*)element,
+ _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey));
+ }
+ return;
+ }
+ else
+ {
+ differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+
if (!as<IRInterfaceType>(item->getConcreteType()))
{
differentiableWitnessDictionary.addIfNotExists(
@@ -768,16 +798,18 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build
SLANG_UNIMPLEMENTED_X("Impl");
}
+ case kIROp_TypePack:
case kIROp_TupleType:
{
- auto tupleType = as<IRTupleType>(primalType);
List<IRType*> diffTypeList;
// TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ for (UIndex ii = 0; ii < primalType->getOperandCount(); ii++)
diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
+ differentiateType(builder, (IRType*)primalType->getOperand(ii)));
+ if (primalType->getOp() == kIROp_TupleType)
+ return builder->getTupleType(diffTypeList);
+ else
+ return builder->getTypePack((UInt)diffTypeList.getCount(), diffTypeList.getBuffer());
}
default:
@@ -795,6 +827,12 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
{
SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType));
}
+ if (as<IRMakeWitnessPack>(witness))
+ {
+ // If registered witness is a witness pack for a type pack,
+ // we should reconstruct the true witness table.
+ witness = nullptr;
+ }
if (!witness)
{
@@ -811,6 +849,14 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil
{
witness = getExtractExistensialTypeWitness(builder, extractExistential);
}
+ else if (auto typePack = as<IRTypePack>(primalType))
+ {
+ witness = getTupleWitness(builder, typePack);
+ }
+ else if (auto tupleType = as<IRTupleType>(primalType))
+ {
+ witness = getTupleWitness(builder, tupleType);
+ }
}
return witness;
}
@@ -963,6 +1009,104 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder
return table;
}
+IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType)
+{
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType);
+
+ if (!diffTupleType)
+ return nullptr;
+
+ auto addMethod = builder->createFunc();
+ auto zeroMethod = builder->createFunc();
+
+ auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType);
+
+ // And place it in the synthesized witness table.
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType);
+ builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table);
+ builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod);
+ builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod);
+
+ // Fill in differential method implementations.
+ {
+ // Add method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ IRType* paramTypes[2] = { diffTupleType, diffTupleType };
+ addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType));
+ b.emitBlock();
+ auto p0 = b.emitParam(diffTupleType);
+ auto p1 = b.emitParam(diffTupleType);
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
+ {
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey);
+ auto iVal = b.getIntValue(b.getIntType(), i);
+ IRInst* args[2] = {
+ b.emitGetTupleElement(diffElementType, p0, iVal),
+ b.emitGetTupleElement(diffElementType, p1, iVal) };
+ elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args);
+ }
+ results.add(elementResult);
+ }
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
+ else
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
+ }
+ {
+ // Zero method.
+ IRBuilder b = *builder;
+ b.setInsertInto(addMethod);
+ b.addBackwardDifferentiableDecoration(addMethod);
+ addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType));
+ b.emitBlock();
+ List<IRInst*> results;
+ for (UInt i = 0; i < inTupleType->getOperandCount(); i++)
+ {
+ auto elementType = inTupleType->getOperand(i);
+ auto diffElementType = (IRType*)diffTupleType->getOperand(i);
+ auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType);
+ IRInst* elementResult = nullptr;
+ if (!innerWitness)
+ {
+ elementResult = b.getVoidValue();
+ }
+ else
+ {
+ auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey);
+ elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr);
+ }
+ results.add(elementResult);
+ }
+ IRInst* resultVal = nullptr;
+ if (diffTupleType->getOp() == kIROp_TupleType)
+ resultVal = b.emitMakeTuple(diffTupleType, results);
+ else
+ resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer());
+ b.emitReturn(resultVal);
+ }
+
+ // Record this in the context for future lookups
+ differentiableWitnessDictionary[(IRType*)inTupleType] = table;
+
+ return table;
+}
+
IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(
IRBuilder* builder,
IRExtractExistentialType* extractExistentialType)