summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp835
1 files changed, 589 insertions, 246 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index d0bf8f347..8a4fe23d0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -7,6 +7,7 @@
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
// origX, primalX, diffX
// origX -> primalX (cloneEnv)
@@ -20,9 +21,19 @@ struct Pair
{
P primal;
D differential;
-
+ Pair() = default;
Pair(P primal, D differential) : primal(primal), differential(differential)
{}
+ // Computes a hash over both halves of the pair. Together with `operator ==`
+ // this presumably lets a Pair serve as a dictionary key (for example
+ // `Dictionary<InstPair, IRInst*>`) -- confirm against the Dictionary requirements.
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ // Feed the primal half first, then the differential half.
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+    // Two pairs compare equal only when both the primal and the differential
+    // halves compare equal.
+    bool operator ==(const Pair& other) const
+    {
+        // A mismatching primal settles the answer without looking at the differential.
+        if (!(primal == other.primal))
+            return false;
+        return differential == other.differential;
+    }
};
typedef Pair<IRInst*, IRInst*> InstPair;
@@ -43,6 +54,11 @@ struct AutoDiffSharedContext
//
IRStructKey* differentialAssocTypeStructKey = nullptr;
+ // The struct key for the witness that the `Differential` associated type conforms to
+ // `IDifferentiable`.
+ IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+
+
// The struct key for the 'zero()' method requirement
// defined inside IDifferentiable. We use this to look up the
// implementation of zero() for a given type.
@@ -54,6 +70,9 @@ struct AutoDiffSharedContext
// implementation of add() for a given type.
//
IRStructKey* addMethodStructKey = nullptr;
+
+ IRStructKey* mulMethodStructKey = nullptr;
+
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
@@ -69,8 +88,10 @@ struct AutoDiffSharedContext
if (differentiableInterfaceType)
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
zeroMethodStructKey = findZeroMethodStructKey();
addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
if (differentialAssocTypeStructKey)
isInterfaceAvailable = true;
@@ -103,22 +124,32 @@ struct AutoDiffSharedContext
return getIDifferentiableStructKeyAtIndex(0);
}
- IRStructKey* findZeroMethodStructKey()
+ // Returns the requirement key found at operand index 1 of the differentiable
+ // interface type. The caller stores it in `differentialAssocTypeWitnessStructKey`.
+ IRStructKey* findDifferentialTypeWitnessStructKey()
{
return getIDifferentiableStructKeyAtIndex(1);
}
- IRStructKey* findAddMethodStructKey()
+ // Returns the requirement key found at operand index 2 of the differentiable
+ // interface type. The caller stores it in `zeroMethodStructKey`.
+ IRStructKey* findZeroMethodStructKey()
{
return getIDifferentiableStructKeyAtIndex(2);
}
+ // Returns the requirement key found at operand index 3 of the differentiable
+ // interface type. The caller stores it in `addMethodStructKey`.
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(3);
+ }
+
+ // Returns the requirement key found at operand index 4 of the differentiable
+ // interface type. The caller stores it in `mulMethodStructKey`.
+ IRStructKey* findMulMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(4);
+ }
+
IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
{
if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly four fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
else
@@ -300,7 +331,16 @@ struct DifferentialPairTypeBuilder
IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
{
- if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
+ if (baseTypeInfo.isTrivial)
+ {
+ if (key == globalPrimalKey)
+ return baseInst;
+ else
+ return builder->getDifferentialBottom();
+ }
+
+ if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType))
{
return as<IRFieldExtract>(builder->emitFieldExtract(
findField(basePairStructType, key)->getFieldType(),
@@ -308,7 +348,7 @@ struct DifferentialPairTypeBuilder
key
));
}
- else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType))
{
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
@@ -334,7 +374,7 @@ struct DifferentialPairTypeBuilder
key));
}
}
- else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType()))
+ else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType))
{
// TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
// type, emit the specialization type.
@@ -420,25 +460,64 @@ struct DifferentialPairTypeBuilder
{
SLANG_ASSERT(!as<IRParam>(origBaseType));
SLANG_ASSERT(diffType);
- auto pairStructType = builder->createStructType();
- builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
- builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
+ if (diffType->getOp() != kIROp_DifferentialBottomType)
+ {
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
+ return pairStructType;
+ }
+ return origBaseType;
+ }
- return pairStructType;
+ // Result of lowering a (possible) differential pair type; this is also the
+ // value type stored in `pairTypeCache`.
+ struct LoweredPairTypeInfo
+ {
+ // The type the input lowers to. `lowerDiffPairType` leaves this null when
+ // the pair's value type is an `IRParam`, and sets it to the input type
+ // itself when the input is not a differential pair type.
+ IRInst* loweredType;
+ // Set by `lowerDiffPairType` when the input is not a differential pair type,
+ // or when the pair's differential type is the DifferentialBottom type.
+ bool isTrivial;
+ };
+
+    // Resolves the differential type of a pair type by looking up the
+    // `differentialAssocTypeStructKey` entry through the pair type's witness.
+    IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+    {
+        return _lookupWitness(
+            builder,
+            type->getWitness(),
+            sharedContext->differentialAssocTypeStructKey);
}
- IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
+ IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
{
- if (pairTypeCache.ContainsKey(origBaseType))
- return pairTypeCache[origBaseType];
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+ }
- auto pairType = _createDiffPairType(builder, origBaseType, diffType);
- pairTypeCache.Add(origBaseType, pairType);
+ LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
+ {
+ LoweredPairTypeInfo result = {};
+
+ if (pairTypeCache.TryGetValue(originalPairType, result))
+ return result;
+ auto pairType = as<IRDifferentialPairType>(originalPairType);
+ if (!pairType)
+ {
+ result.isTrivial = true;
+ result.loweredType = originalPairType;
+ return result;
+ }
+ auto primalType = pairType->getValueType();
+ if (as<IRParam>(primalType))
+ {
+ result.isTrivial = false;
+ result.loweredType = nullptr;
+ return result;
+ }
+
+ auto diffType = getDiffTypeFromPairType(builder, pairType);
+ result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType);
+ result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
+ pairTypeCache.Add(originalPairType, result);
- return pairType;
+ return result;
}
- Dictionary<IRInst*, IRInst*> pairTypeCache;
+ Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache;
IRStructKey* globalPrimalKey = nullptr;
@@ -447,6 +526,8 @@ struct DifferentialPairTypeBuilder
IRInst* genericDiffPairType = nullptr;
List<IRInst*> generatedTypeList;
+
+ AutoDiffSharedContext* sharedContext = nullptr;
};
struct JVPTranscriber
@@ -474,8 +555,15 @@ struct JVPTranscriber
DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
- JVPTranscriber(AutoDiffSharedContext* shared)
- : differentiableTypeConformanceContext(shared)
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table for the conformance `DifferentialBottom : IDifferentiable`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
+ : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
{}
DiagnosticSink* getSink()
@@ -592,8 +680,75 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+    // Looks up the witness table registered for the DifferentialBottom type.
+    // Asserts that such a conformance exists.
+    IRWitnessTable* getDifferentialBottomWitness()
+    {
+        // Work with a builder positioned at module scope.
+        IRBuilder moduleBuilder(sharedBuilder);
+        moduleBuilder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
+
+        auto conformance = differentiableTypeConformanceContext.lookUpConformanceForType(
+            moduleBuilder.getDifferentialBottomType());
+        auto bottomWitness = as<IRWitnessTable>(conformance);
+        SLANG_ASSERT(bottomWitness);
+        return bottomWitness;
+    }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+    // Returns the differentiability witness table for the differential pair type
+    // `inDiffPairType`, creating and registering a new table when none is found.
+    //
+    // `inDiffPairType` must be an `IRDifferentialPairType` (asserted).
+    IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+    {
+        IRBuilder builder(sharedBuilder);
+        builder.setInsertInto(inDiffPairType->parent);
+        auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+        SLANG_ASSERT(diffPairType);
+
+        // Look for an existing conformance of the pair type itself, i.e. the same key
+        // the new table is registered under at the end of this function. The lookup
+        // used to be keyed on the DifferentialBottom type; since
+        // `getDifferentialBottomWitness()` asserts that such a lookup succeeds, this
+        // function then always returned the bottom witness and never built a table.
+        // NOTE(review): assumes `lookUpConformanceForType` can see entries placed in
+        // `differentiableWitnessDictionary` -- confirm.
+        auto result =
+            as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+                (IRType*)diffPairType));
+        if (result)
+            return result;
+
+        // Build a new witness table whose `Differential` entry is a pair type over the
+        // differentiated value type, using the DifferentialBottom witness.
+        auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+        auto diffType = differentiateType(&builder, diffPairType->getValueType());
+        auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
+        builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+        // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+        // Register the table so later requests for this pair type can reuse it.
+        differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+        return table;
+    }
+
+    // Gets the differential pair type over `primalType` that uses `witness` as its
+    // conformance witness. The builder inserts into the parent of `primalType`.
+    IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+    {
+        IRBuilder pairTypeBuilder(sharedBuilder);
+        pairTypeBuilder.setInsertInto(primalType->parent);
+
+        auto valueType = (IRType*)primalType;
+        return pairTypeBuilder.getDifferentialPairType(valueType, witness);
+    }
+
+    // Gets the differential pair type over `primalType`, using the conformance
+    // witness registered for that type. When no witness table is found, the
+    // DifferentialBottom witness is used instead.
+    IRType* getOrCreateDiffPairType(IRInst* primalType)
+    {
+        IRBuilder pairTypeBuilder(sharedBuilder);
+        pairTypeBuilder.setInsertInto(primalType->parent);
+
+        auto valueType = (IRType*)primalType;
+        auto registeredWitness = as<IRWitnessTable>(
+            differentiableTypeConformanceContext.lookUpConformanceForType(valueType));
+
+        // Fall back to the bottom witness only when the lookup produced no table.
+        auto witness = registeredWitness ? registeredWitness : getDifferentialBottomWitness();
+        return pairTypeBuilder.getDifferentialPairType(valueType, witness);
+    }
+
IRType* differentiateType(IRBuilder* builder, IRType* origType)
{
+    // Memoized front end for `_differentiateTypeImpl`: results are kept in
+    // `instMapD`, keyed on the original type.
+    IRInst* cachedDiffType = nullptr;
+    if (instMapD.TryGetValue(origType, cachedDiffType))
+        return (IRType*)cachedDiffType;
+
+    // Cache miss: compute the type and record it (whatever the result is)
+    // before returning it.
+    IRInst* computedDiffType = _differentiateTypeImpl(builder, origType);
+    instMapD[origType] = computedDiffType;
+    return (IRType*)computedDiffType;
+}
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
if (auto ptrType = as<IRPtrTypeBase>(origType))
return builder->getPtrType(
origType->getOp(),
@@ -628,6 +783,14 @@ struct JVPTranscriber
else
return nullptr;
}
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
case kIROp_FuncType:
return differentiateFunctionType(builder, as<IRFuncType>(primalType));
@@ -660,7 +823,7 @@ struct JVPTranscriber
return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
}
}
-
+
IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
@@ -675,7 +838,7 @@ struct JVPTranscriber
}
auto diffType = differentiateType(builder, primalType);
if (diffType)
- return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType);
+ return (IRType*)getOrCreateDiffPairType(primalType);
return nullptr;
}
@@ -692,7 +855,7 @@ struct JVPTranscriber
if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
{
- IRParam* diffPairParam = builder->emitParam(diffPairType);
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
auto diffPairVarName = makeDiffPairName(origParam);
if (diffPairVarName.getLength() > 0)
@@ -700,9 +863,20 @@ struct JVPTranscriber
SLANG_ASSERT(diffPairParam);
- return InstPair(
- pairBuilder->emitPrimalFieldAccess(builder, diffPairParam),
- pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
}
@@ -826,30 +1000,52 @@ struct JVPTranscriber
InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
-
- auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ auto primalPtr = lookupPrimalInst(origPtr, nullptr);
+ auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
- IRInst* diffLoad = nullptr;
+ if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
+ {
+ // Special case load from an `out` param, which will not have corresponding `diff` and
+ // `primal` insts yet.
+ auto load = builder->emitLoad(primalPtr);
+ auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ return InstPair(primalElement, diffElement);
+ }
+ auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
// Default case, we're loading from a known differential inst.
diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
- return InstPair(primalLoad, diffLoad);
- }
- return InstPair(primalLoad, nullptr);
+ }
+ return InstPair(primalLoad, diffLoad);
}
InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
{
IRInst* origStoreLocation = origStore->getPtr();
IRInst* origStoreVal = origStore->getVal();
-
- auto primalStore = cloneInst(&cloneEnv, builder, origStore);
-
+ auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
+ auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+ if (!diffStoreLocation)
+ {
+ auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType());
+ if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
+ {
+ auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
+ auto store = builder->emitStore(primalStoreLocation, valToStore);
+ return InstPair(store, nullptr);
+ }
+ }
+
+ auto primalStore = cloneInst(&cloneEnv, builder, origStore);
+
IRInst* diffStore = nullptr;
// If the stored value has a differential version,
@@ -1052,8 +1248,9 @@ struct JVPTranscriber
if (diffReturnType->getOp() != kIROp_VoidType)
{
- IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
- IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
+ IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
+ auto diffType = differentiateType(builder, origCall->getFullType());
+ IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
return InstPair(primalResultValue, diffResultValue);
}
else
@@ -1174,14 +1371,16 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
- InstPair transcribeConst(IRBuilder*, IRInst* origInst)
+ InstPair transcribeConst(IRBuilder* builder, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
+ return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
case kIROp_VoidLit:
+ return InstPair(origInst, origInst);
case kIROp_IntLit:
- return InstPair(origInst, nullptr);
+ return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
}
getSink()->diagnose(
@@ -1245,6 +1444,14 @@ struct JVPTranscriber
{
if (auto diffType = differentiateType(builder, primalType))
{
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
@@ -1458,40 +1665,63 @@ struct JVPTranscriber
return InstPair(diffLoop, diffLoop);
}
- // Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc)
+ InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
{
- IRFunc* primalFunc = nullptr;
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalVal);
+ auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffPrimalVal);
+ auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalDiffVal);
+ auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffDiffVal);
- differentiableTypeConformanceContext.setFunc(origFunc);
+ auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
+ auto diffPair = builder->emitMakeDifferentialPair(
+ differentiateType(builder, origInst->getDataType()),
+ primalDiffVal,
+ diffDiffVal);
+ return InstPair(primalPair, diffPair);
+ }
- auto oldLoc = builder->getInsertLoc();
+ InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
+ {
+ SLANG_ASSERT(
+ origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
+ origInst->getOp() == kIROp_DifferentialPairGetPrimal);
- // If this is a top-level function, there is no need to clone it
- // since it is visible in all the scopes.
- // Otherwise, we need to clone it in case of generic scopes.
- //
- // TODO(sai): Is this the correct thing to do? Can a function cloned inside a
- // generic scope but is not the return value of that generic, be used within
- // that scope? Or do we have to call out to the original generic specialized with
- // the current generic params?
- //
- bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr);
- if (isTopLevelFunc)
- {
- builder->setInsertBefore(origFunc);
- primalFunc = origFunc;
- }
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(primalVal);
+
+ auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(diffVal);
+
+ auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+
+ auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
+ IRInst* diffResultType = nullptr;
+ if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
+ diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
else
- {
- // TODO(sai): this might never be called, and it might never make sense
- // to call it either. Potentially remove this.
- primalFunc = as<IRFunc>(
- cloneInst(&cloneEnv, builder, origFunc));
- }
+ diffResultType = diffValPairType->getValueType();
+ auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
+ return InstPair(primalResult, diffResult);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ builder->setInsertBefore(origFunc);
+ primalFunc = origFunc;
auto diffFunc = builder->createFunc();
-
+
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
IRType* diffFuncType = this->differentiateFunctionType(
builder,
@@ -1505,10 +1735,33 @@ struct JVPTranscriber
newNameSb << "s_jvp_" << originalName;
builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
-
+ builder->addForwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder->addForwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ // Reset builder position
+ builder->setInsertLoc(oldLoc);
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
// Transcribe children from origFunc into diffFunc
builder->setInsertInto(diffFunc);
- for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock())
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
this->transcribe(builder, block);
// Reset builder position
@@ -1685,6 +1938,11 @@ struct JVPTranscriber
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
+ case kIROp_MakeDifferentialPair:
+ return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferential:
+ return transcribeDifferentialPairGetElement(builder, origInst);
}
// If none of the cases have been hit, check if the instruction is a
@@ -1722,7 +1980,7 @@ struct JVPTranscriber
switch (origInst->getOp())
{
case kIROp_Func:
- return transcribeFunc(builder, as<IRFunc>(origInst));
+ return transcribeFuncHeader(builder, as<IRFunc>(origInst));
case kIROp_Block:
return transcribeBlock(builder, as<IRBlock>(origInst));
@@ -1741,45 +1999,7 @@ struct JVPTranscriber
}
};
-struct IRWorkQueue
-{
- // Work list to hold the active set of insts whose children
- // need to be looked at.
- //
- List<IRInst*> workList;
- HashSet<IRInst*> workListSet;
-
- void push(IRInst* inst)
- {
- if(!inst) return;
- if(workListSet.Contains(inst)) return;
-
- workList.add(inst);
- workListSet.Add(inst);
- }
-
- IRInst* pop()
- {
- if (workList.getCount() != 0)
- {
- IRInst* topItem = workList.getFirst();
- // TODO(Sai): Repeatedly calling removeAt() can be really slow.
- // Consider a specialized data structure or using removeLast()
- //
- workList.removeAt(0);
- workListSet.Remove(topItem);
- return topItem;
- }
- return nullptr;
- }
-
- IRInst* peek()
- {
- return workList.getFirst();
- }
-};
-
-struct JVPDerivativeContext
+struct JVPDerivativeContext : public InstPassBase
{
DiagnosticSink* getSink()
@@ -1795,6 +2015,7 @@ struct JVPDerivativeContext
//
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
@@ -1809,8 +2030,12 @@ struct JVPDerivativeContext
// IRDifferentialPairGetPrimal with 'primal' field access, and
// IRMakeDifferentialPair with an IRMakeStruct.
//
+ modified |= simplifyDifferentialBottomType(builder);
+
modified |= processPairTypes(builder, module->getModuleInst());
-
+
+ modified |= eliminateDifferentialBottomType(builder);
+
return modified;
}
@@ -1826,121 +2051,92 @@ struct JVPDerivativeContext
//
bool processReferencedFunctions(IRBuilder* builder)
{
- IRWorkQueue* workQueue = &(workQueueStorage);
+ List<IRForwardDifferentiate*> autoDiffWorkList;
- // Put the top-level inst into the queue.
- workQueue->push(module->getModuleInst());
-
- // Keep processing items until the queue is complete.
- while (IRInst* workItem = workQueue->pop())
- {
- for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
+ for (;;)
+ {
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst)
{
- // Either the child instruction has more children (func/block etc..)
- // and we add it to the work list for further processing, or
- // it's an ordinary inst in which case we check if it's a ForwardDifferentiate
- // instruction.
- //
- if (child->getFirstChild() != nullptr)
- workQueue->push(child);
-
- if (auto jvpDiffInst = as<IRForwardDifferentiate>(child))
- {
- auto baseInst = jvpDiffInst->getBaseFn();
+ autoDiffWorkList.add(fwdDiffInst);
+ });
- IRGlobalValueWithCode* baseFunction = nullptr;
+ if (autoDiffWorkList.getCount() == 0)
+ break;
- if (auto specializeInst = as<IRSpecialize>(baseInst))
- {
- // Certain specialize insts come with a derivative
- // reference attached. Skip such instructions.
- //
- if (lookupJVPReference(specializeInst)) continue;
- }
- else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
+ // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
+ // differentiated functions.
+ transcriberStorage.followUpFunctionsToTranscribe.clear();
+
+ for (auto fwdDiffInst : autoDiffWorkList)
+ {
+ auto baseInst = fwdDiffInst->getBaseFn();
+ if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst))
+ {
+ if (auto existingDiffFunc = lookupJVPReference(baseFunction))
{
- baseFunction = globalValWithCode;
+ fwdDiffInst->replaceUsesWith(existingDiffFunc);
+ fwdDiffInst->removeAndDeallocate();
}
-
- SLANG_ASSERT(baseFunction);
-
- // If the JVP Reference already exists, no need to
- // differentiate again.
- //
- if (lookupJVPReference(baseFunction)) continue;
-
- if (isMarkedForForwardDifferentiation(baseFunction))
+ else if (isMarkedForForwardDifferentiation(baseFunction))
{
if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
{
- IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
SLANG_ASSERT(diffFunc);
- builder->addForwardDerivativeDecoration(baseFunction, diffFunc);
- workQueue->push(diffFunc);
- }
+ fwdDiffInst->replaceUsesWith(diffFunc);
+ fwdDiffInst->removeAndDeallocate();
+ }
else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
- getSink()->diagnose(jvpDiffInst->sourceLoc,
+ getSink()->diagnose(fwdDiffInst->sourceLoc,
Diagnostics::internalCompilerError,
"Unexpected instruction. Expected func or generic");
}
}
- else
+ else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
- getSink()->diagnose(jvpDiffInst->sourceLoc,
+ getSink()->diagnose(fwdDiffInst->sourceLoc,
Diagnostics::internalCompilerError,
"Cannot differentiate functions not marked for differentiation");
}
}
}
- }
-
- return true;
- }
-
- IRInst* lowerPairType(IRBuilder* builder, IRType* type)
- {
-
- if (auto pairType = as<IRDifferentialPairType>(type))
- {
- builder->setInsertBefore(pairType);
-
- if (!as<IRType>(pairType->getValueType()))
+ // Actually synthesize the derivatives.
+ List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
{
- return nullptr;
- }
- auto witness = pairType->getWitness();
- auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey);
- if (!diffType)
- {
- return nullptr;
+ auto diffFunc = as<IRFunc>(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
}
- auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
- builder,
- pairType->getValueType(),
- (IRType*)(diffType));
- pairType->replaceUsesWith(diffPairStructType);
- pairType->removeAndDeallocate();
+ // Transcribing the function body really shouldn't produce more follow up function body work.
+ // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
+ // in the next iteration.
+ SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
- return diffPairStructType;
- }
- else if (auto loweredStructType = as<IRStructType>(type))
- {
- // Already lowered to struct.
- return loweredStructType;
- }
- else if (auto specializedStructType = as<IRSpecialize>(type))
- {
- // Already lowered to specialized struct.
- return specializedStructType;
}
-
- return nullptr;
+ return true;
+ }
+
+    // Lowers `pairType` through the pair builder and returns the lowered type.
+    // When `isTrivial` is non-null it receives the `isTrivial` flag of the
+    // lowering result. The builder is positioned before `pairType` first.
+    IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
+    {
+        builder->setInsertBefore(pairType);
+
+        const auto loweredInfo = pairBuilderStorage.lowerDiffPairType(builder, pairType);
+        if (isTrivial != nullptr)
+            *isTrivial = loweredInfo.isTrivial;
+        return loweredInfo.loweredType;
+    }
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
@@ -1948,19 +2144,24 @@ struct JVPDerivativeContext
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType()))
+ bool isTrivial = false;
+ auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
+ if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
{
builder->setInsertBefore(makePairInst);
-
- List<IRInst*> operands;
- operands.add(makePairInst->getPrimalValue());
- operands.add(makePairInst->getDifferentialValue());
-
- auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
- makePairInst->replaceUsesWith(makeStructInst);
+ IRInst* result = nullptr;
+ if (isTrivial)
+ {
+ result = makePairInst->getPrimalValue();
+ }
+ else
+ {
+ IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
+ result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+ }
+ makePairInst->replaceUsesWith(result);
makePairInst->removeAndDeallocate();
-
- return makeStructInst;
+ return result;
}
}
@@ -1971,11 +2172,11 @@ struct JVPDerivativeContext
{
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- if (lowerPairType(builder, getDiffInst->getBase()->getDataType()))
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr))
{
builder->setInsertBefore(getDiffInst);
-
- auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ IRInst* diffFieldExtract = nullptr;
+ diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
getDiffInst->replaceUsesWith(diffFieldExtract);
getDiffInst->removeAndDeallocate();
return diffFieldExtract;
@@ -1983,14 +2184,14 @@ struct JVPDerivativeContext
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- if (lowerPairType(builder, getPrimalInst->getBase()->getDataType()))
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr))
{
builder->setInsertBefore(getPrimalInst);
- auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ IRInst* primalFieldExtract = nullptr;
+ primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
getPrimalInst->replaceUsesWith(primalFieldExtract);
getPrimalInst->removeAndDeallocate();
-
return primalFieldExtract;
}
}
@@ -2001,40 +2202,195 @@ struct JVPDerivativeContext
bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
+ // Hoist all pair types to global scope when possible.
+ auto moduleInst = module->getModuleInst();
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
+ {
+ if (originalPairType->parent != moduleInst)
+ {
+ originalPairType->removeFromParent();
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ {
+ operands.add(originalPairType->getOperand(i));
+ }
+ auto newPairType = builder->findOrEmitHoistableInst(
+ originalPairType->getFullType(),
+ originalPairType->getOp(),
+ originalPairType->getOperandCount(),
+ operands.getArrayView().getBuffer());
+ originalPairType->replaceUsesWith(newPairType);
+ originalPairType->removeAndDeallocate();
+ }
+ });
- for (auto child = instWithChildren->getFirstChild(); child; )
- {
- // Make sure the builder is at the right level.
- builder->setInsertInto(instWithChildren);
-
- auto nextChild = child->getNextInst();
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
- switch (child->getOp())
+ processAllInsts([&](IRInst* inst)
{
- case kIROp_DifferentialPairType:
- lowerPairType(builder, as<IRType>(child));
- break;
-
+ // Make sure the builder is at the right level.
+ builder->setInsertInto(instWithChildren);
+
+ switch (inst->getOp())
+ {
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
- lowerPairAccess(builder, child);
+ lowerPairAccess(builder, inst);
+ modified = true;
break;
-
+
case kIROp_MakeDifferentialPair:
- lowerMakePair(builder, child);
+ lowerMakePair(builder, inst);
+ modified = true;
break;
-
+
default:
- if (child->getFirstChild())
- modified = processPairTypes(builder, child) | modified;
- }
+ break;
+ }
+ });
- child = nextChild;
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = lowerPairType(builder, inst))
+ {
+ inst->replaceUsesWith(loweredType);
+ inst->removeAndDeallocate();
+ }
+ });
+ return modified;
+ }
+
+ bool simplifyDifferentialBottomType(IRBuilder* builder)
+ {
+ bool modified = false;
+ auto diffBottom = builder->getDifferentialBottom();
+
+ bool changed = true;
+ List<IRUse*> uses;
+ while (changed)
+ {
+ changed = false;
+ // Replace all insts whose type is `DifferentialBottomType` with `diffBottom`.
+ processAllInsts([&](IRInst* inst)
+ {
+ if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType)
+ {
+ if (inst != diffBottom)
+ {
+ inst->replaceUsesWith(diffBottom);
+ inst->removeAndDeallocate();
+ modified = true;
+ }
+ }
+ });
+ // Go through all uses of diffBottom and run simplification.
+ processAllInsts([&](IRInst* inst)
+ {
+ if (!inst->hasUses())
+ return;
+
+ builder->setInsertBefore(inst);
+ IRInst* valueToReplace = nullptr;
+ switch (inst->getOp())
+ {
+ case kIROp_Store:
+ if (as<IRStore>(inst)->getVal() == diffBottom)
+ {
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ return;
+ case kIROp_MakeDifferentialPair:
+ // Our simplification could lead to a situation where
+ // bottom is used to make a pair that has a non-bottom differential type;
+ // in this case we should use zero instead.
+ if (inst->getOperand(1) == diffBottom)
+ {
+ // Only apply if we are the second operand.
+ auto pairType = as<IRDifferentialPairType>(inst->getDataType());
+ if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType)
+ {
+ auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType());
+ inst->setOperand(1, zero);
+ changed = true;
+ }
+ }
+ return;
+ case kIROp_DifferentialPairGetDifferential:
+ if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
+ {
+ valueToReplace = inst->getOperand(0)->getOperand(1);
+ }
+ break;
+ case kIROp_DifferentialPairGetPrimal:
+ if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
+ {
+ valueToReplace = inst->getOperand(0)->getOperand(0);
+ }
+ break;
+ case kIROp_Add:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(1);
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(0);
+ }
+ break;
+ case kIROp_Sub:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ // If left is bottom, and right is not bottom, then we should return -right.
+ // However, we can't possibly run into that case since both sides of the - operator
+ // must be at the same order of differentiation.
+ valueToReplace = diffBottom;
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(0);
+ }
+ break;
+ case kIROp_Mul:
+ case kIROp_Div:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ valueToReplace = diffBottom;
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = diffBottom;
+ }
+ break;
+ default:
+ break;
+ }
+ if (valueToReplace)
+ {
+ inst->replaceUsesWith(valueToReplace);
+ changed = true;
+ }
+ });
+ modified |= changed;
}
return modified;
}
+ bool eliminateDifferentialBottomType(IRBuilder* builder)
+ {
+ simplifyDifferentialBottomType(builder);
+
+ bool modified = false;
+ auto diffBottom = builder->getDifferentialBottom();
+ auto diffBottomType = diffBottom->getDataType();
+ diffBottom->replaceUsesWith(builder->getVoidValue());
+ diffBottom->removeAndDeallocate();
+ diffBottomType->replaceUsesWith(builder->getVoidType());
+
+ return modified;
+ }
+
// Checks decorators to see if the function should
// be differentiated (kIROp_ForwardDifferentiableDecoration)
//
@@ -2074,27 +2430,18 @@ struct JVPDerivativeContext
}
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
- module(module),
+ InstPassBase(module),
sink(sink),
autoDiffSharedContextStorage(module->getModuleInst()),
- transcriberStorage(&autoDiffSharedContextStorage)
+ transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage)
{
+ pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
}
- protected:
-
- // This type passes over the module and generates
- // forward-mode derivative versions of functions
- // that are explicitly marked for it.
- //
- IRModule* module;
-
- // Shared builder state for our derivative passes.
- SharedIRBuilder sharedBuilderStorage;
-
+protected:
// A transcriber object that handles the main job of
// processing instructions while maintaining state.
//
@@ -2104,10 +2451,6 @@ struct JVPDerivativeContext
// error messages.
DiagnosticSink* sink;
- // Work queue to hold a stream of instructions that need
- // to be checked for references to derivative functions.
- IRWorkQueue workQueueStorage;
-
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
AutoDiffSharedContext autoDiffSharedContextStorage;