summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-12-21 15:25:38 -0800
committerGitHub <noreply@github.com>2022-12-21 15:25:38 -0800
commit6dbdb74dbdc20783a0429229c21604a3d08d28f8 (patch)
tree910e2dd7b7b296ae5c285dbbb73114b381ef529a /source
parent887842933c0734196729d5525de9835eb48b3855 (diff)
Further unify the autodiff passes. (#2574)
* Further unify the autodiff passes. * Fix clang compilation error. * Rename ForwardDerivativeTranscriber->ForwardDiffTranscriber. * Remove unused fields from Transcriber classes. * More small cleanups. * Cleanup. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1090
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h146
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp567
-rw-r--r--source/slang/slang-ir-autodiff-rev.h88
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp847
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h129
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp33
-rw-r--r--source/slang/slang-ir-autodiff.cpp205
-rw-r--r--source/slang/slang-ir-autodiff.h38
-rw-r--r--source/slang/slang-ir-util.cpp14
-rw-r--r--source/slang/slang-ir-util.h24
11 files changed, 1604 insertions, 1577 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index d1e9f91ec..dbf79b5f8 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -11,123 +11,7 @@
namespace Slang
{
-static IRInst* _unwrapAttributedType(IRInst* type)
-{
- while (auto attrType = as<IRAttributedType>(type))
- type = attrType->getBaseType();
- return type;
-}
-
-DiagnosticSink* ForwardDerivativeTranscriber::getSink()
-{
- SLANG_ASSERT(sink);
- return sink;
-}
-
-void ForwardDerivativeTranscriber::mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
-{
- if (hasDifferentialInst(origInst))
- {
- if (lookupDiffInst(origInst) != diffInst)
- {
- SLANG_UNEXPECTED("Inconsistent differential mappings");
- }
- }
- else
- {
- instMapD.Add(origInst, diffInst);
- }
-}
-
-void ForwardDerivativeTranscriber::mapPrimalInst(IRInst* origInst, IRInst* primalInst)
-{
- if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
- {
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "inconsistent primal instruction for original");
- }
- else
- {
- cloneEnv.mapOldValToNew[origInst] = primalInst;
- }
-}
-
-IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst)
-{
- return instMapD[origInst];
-}
-
-IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
-{
- return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
-}
-
-bool ForwardDerivativeTranscriber::hasDifferentialInst(IRInst* origInst)
-{
- return instMapD.ContainsKey(origInst);
-}
-
-bool ForwardDerivativeTranscriber::shouldUseOriginalAsPrimal(IRInst* origInst)
-{
- if (as<IRGlobalValueWithCode>(origInst))
- return true;
- if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
- return true;
- return false;
-}
-
-IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst)
-{
- if (!origInst)
- return nullptr;
- if (shouldUseOriginalAsPrimal(origInst))
- return origInst;
- return cloneEnv.mapOldValToNew[origInst];
-}
-
-IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
-{
- if (!origInst)
- return nullptr;
- return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
-}
-
-bool ForwardDerivativeTranscriber::hasPrimalInst(IRInst* origInst)
-{
- if (!origInst)
- return true;
- if (shouldUseOriginalAsPrimal(origInst))
- return true;
- return cloneEnv.mapOldValToNew.ContainsKey(origInst);
-}
-
-IRInst* ForwardDerivativeTranscriber::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
-{
- if (!hasDifferentialInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasDifferentialInst(origInst));
- }
-
- return lookupDiffInst(origInst);
-}
-
-IRInst* ForwardDerivativeTranscriber::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
-{
- if (shouldUseOriginalAsPrimal(origInst))
- return origInst;
-
- if (!hasPrimalInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasPrimalInst(origInst));
- }
-
- return lookupPrimalInst(origInst);
-}
-
-IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+IRFuncType* ForwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
{
List<IRType*> newParameterTypes;
IRType* diffReturnType;
@@ -135,7 +19,7 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
auto origType = funcType->getParamType(i);
- origType = (IRType*) lookupPrimalInst(origType, origType);
+ origType = (IRType*) findOrTranscribePrimalInst(builder, origType);
if (auto diffPairType = tryGetDiffPairType(builder, origType))
newParameterTypes.add(diffPairType);
else
@@ -145,7 +29,7 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b
// Transcribe return type to a pair.
// This will be void if the primal return type is non-differentiable.
//
- auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
+ auto origResultType = (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType());
if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
diffReturnType = returnPairType;
else
@@ -154,320 +38,10 @@ IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* b
return builder->getFuncType(newParameterTypes, diffReturnType);
}
-// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
-IRWitnessTable* ForwardDerivativeTranscriber::getDifferentialPairWitness(IRInst* inDiffPairType)
-{
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
- SLANG_ASSERT(diffPairType);
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
-
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = differentiateType(&builder, diffPairType);
-
- // And place it in the synthesized witness table.
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
-
- // Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
-
- return table;
-}
-
-IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
-{
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
-}
-
-IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType)
-{
- IRBuilder builder(sharedBuilder);
- if (!primalType->next)
- builder.setInsertInto(primalType->parent);
- else
- builder.setInsertBefore(primalType->next);
-
- IRInst* witness = as<IRWitnessTable>(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
-
- if (!witness)
- {
- if (auto primalPairType = as<IRDifferentialPairType>(primalType))
- {
- witness = getDifferentialPairWitness(primalPairType);
- }
- else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
- {
- differentiateExtractExistentialType(&builder, extractExistential, witness);
- }
- }
-
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
-}
-
-IRType* ForwardDerivativeTranscriber::differentiateType(IRBuilder* builder, IRType* origType)
-{
- IRInst* diffType = nullptr;
- if (!instMapD.TryGetValue(origType, diffType))
- {
- diffType = _differentiateTypeImpl(builder, origType);
- instMapD[origType] = diffType;
- }
- return (IRType*)diffType;
-}
-
-IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, IRType* origType)
-{
- if (auto ptrType = as<IRPtrTypeBase>(origType))
- return builder->getPtrType(
- origType->getOp(),
- differentiateType(builder, ptrType->getValueType()));
-
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = lookupPrimalInst(origType, origType);
-
- // Special case certain compound types (PtrType, FuncType, etc..)
- // otherwise try to lookup a differential definition for the given type.
- // If one does not exist, then we assume it's not differentiable.
- //
- switch (primalType->getOp())
- {
- case kIROp_Param:
- if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
- else if (as<IRWitnessTableType>(primalType->getDataType()))
- return (IRType*)primalType;
-
- case kIROp_ArrayType:
- {
- auto primalArrayType = as<IRArrayType>(primalType);
- if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
- return builder->getArrayType(
- diffElementType,
- primalArrayType->getElementCount());
- else
- return nullptr;
- }
-
- case kIROp_DifferentialPairType:
- {
- auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
- }
-
- case kIROp_FuncType:
- return differentiateFunctionType(builder, as<IRFuncType>(primalType));
-
- case kIROp_OutType:
- if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
- return builder->getOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_InOutType:
- if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
- return builder->getInOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_ExtractExistentialType:
- {
- IRInst* wt = nullptr;
- return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt);
- }
-
- case kIROp_TupleType:
- {
- auto tupleType = as<IRTupleType>(primalType);
- List<IRType*> diffTypeList;
- // TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
- diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
- }
-
- default:
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
- }
-}
-
- // Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
-bool _findDifferentiableInterfaceLookupPathImpl(
- HashSet<IRInst*>& processedTypes,
- IRInterfaceType* idiffType,
- IRInterfaceType* type,
- List<IRInterfaceRequirementEntry*>& currentPath)
-{
- if (processedTypes.Contains(type))
- return false;
- processedTypes.Add(type);
-
- List<IRInterfaceRequirementEntry*> lookupKeyPath;
- for (UInt i = 0; i < type->getOperandCount(); i++)
- {
- auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i));
- if (!entry) continue;
- if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
- {
- currentPath.add(entry);
- if (wt->getConformanceType() == idiffType)
- {
- return true;
- }
- else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
- {
- if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
- return true;
- }
- currentPath.removeLast();
- }
- }
- return false;
-}
-
-List<IRInterfaceRequirementEntry*> _findDifferentiableInterfaceLookupPath(
- IRInterfaceType* idiffType,
- IRInterfaceType* type)
-{
- List<IRInterfaceRequirementEntry*> currentPath;
- HashSet<IRInst*> processedTypes;
- _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
- return currentPath;
-}
-
-IRType* ForwardDerivativeTranscriber::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable)
-{
- witnessTable = nullptr;
-
- // Search for IDifferentiable conformance.
- auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(origType->getOperand(0)->getDataType()));
- if (!interfaceType)
- return nullptr;
- List<IRInterfaceRequirementEntry*> lookupKeyPath = _findDifferentiableInterfaceLookupPath(
- autoDiffSharedContext->differentiableInterfaceType, interfaceType);
-
- if (lookupKeyPath.getCount())
- {
- // `interfaceType` does conform to `IDifferentiable`.
- witnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0));
- for (auto node : lookupKeyPath)
- {
- witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey());
- }
- auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), witnessTable, autoDiffSharedContext->differentialAssocTypeStructKey);
- return (IRType*)diffType;
- }
- return nullptr;
-}
-
-IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
-{
- // If this is a PtrType (out, inout, etc..), then create diff pair from
- // value type and re-apply the appropropriate PtrType wrapper.
- //
- if (auto origPtrType = as<IRPtrTypeBase>(primalType))
- {
- if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(primalType->getOp(), diffPairValueType);
- else
- return nullptr;
- }
- auto diffType = differentiateType(builder, primalType);
- if (diffType)
- return (IRType*)getOrCreateDiffPairType(primalType);
- return nullptr;
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRParam* origParam)
-{
- auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
- // Do not differentiate generic type (and witness table) parameters
- if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
- {
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
-
- // Is this param a phi node or a function parameter?
- auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
- bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
- if (isFuncParam)
- {
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairType))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
- {
- auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
-
- return InstPair(
- builder->emitDifferentialPairAddressPrimal(diffPairParam),
- builder->emitDifferentialPairAddressDifferential(
- builder->getPtrType(
- kIROp_PtrType,
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
- diffPairParam));
- }
- }
-
- auto primalInst = cloneInst(&cloneEnv, builder, origParam);
- if (auto primalParam = as<IRParam>(primalInst))
- {
- SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
- primalParam->removeFromParent();
- builder->getInsertLoc().getBlock()->addParam(primalParam);
- }
- return InstPair(primalInst, nullptr);
- }
- else
- {
- auto primal = cloneInst(&cloneEnv, builder, origParam);
- IRInst* diff = nullptr;
- if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
- {
- diff = builder->emitParam(diffType);
- }
- return InstPair(primal, diff);
- }
-}
-
// Returns "d<var-name>" to use as a name hint for variables and parameters.
// If no primal name is available, returns a blank string.
//
-String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar)
+String ForwardDiffTranscriber::getJVPVarName(IRInst* origVar)
{
if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
@@ -477,20 +51,7 @@ String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar)
return String("");
}
-// Returns "dp<var-name>" to use as a name hint for parameters.
-// If no primal name is available, returns a blank string.
-//
-String ForwardDerivativeTranscriber::makeDiffPairName(IRInst* origVar)
-{
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
- {
- return ("dp" + String(namehintDecoration->getName()));
- }
-
- return String("");
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
+InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
{
if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
{
@@ -507,7 +68,7 @@ InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar*
return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
}
-InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
+InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
@@ -587,7 +148,7 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder,
return InstPair(primalArith, nullptr);
}
-InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
{
SLANG_ASSERT(origLogic->getOperandCount() == 2);
@@ -604,7 +165,7 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder,
SLANG_UNEXPECTED("Logical operation with non-boolean result");
}
-InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
+InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
auto primalPtr = lookupPrimalInst(origPtr, nullptr);
@@ -637,7 +198,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad
return InstPair(primalLoad, diffLoad);
}
-InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore)
+InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore)
{
IRInst* origStoreLocation = origStore->getPtr();
IRInst* origStoreVal = origStore->getVal();
@@ -679,67 +240,18 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto
return InstPair(primalStore, nullptr);
}
-InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
-{
- IRInst* origReturnVal = origReturn->getVal();
-
- auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
- if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
- {
- // If the return value is itself a function, generic or a struct then this
- // is likely to be a generic scope. In this case, we lookup the differential
- // and return that.
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
- IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
-
- // Neither of these should be nullptr.
- SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
- IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
- builder->markInstAsMixedDifferential(diffReturn, nullptr);
-
- return InstPair(diffReturn, diffReturn);
- }
- else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
- {
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
- IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
- if(!diffReturnVal)
- diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
-
- // If the pair type can be formed, this must be non-null.
- SLANG_RELEASE_ASSERT(diffReturnVal);
-
- auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
- builder->markInstAsMixedDifferential(diffPair, pairType);
-
- IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
- builder->markInstAsMixedDifferential(pairReturn, pairType);
-
- return InstPair(pairReturn, pairReturn);
- }
- else
- {
- // If the return type is not differentiable, emit the primal value only.
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
-
- IRInst* primalReturn = builder->emitReturn(primalReturnVal);
- return InstPair(primalReturn, nullptr);
-
- }
-}
-
// Since int/float literals are sometimes nested inside an IRConstructor
// instruction, we check to make sure that the nested instr is a constant
// and then return nullptr. Literals do not need to be differentiated.
//
-InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
+InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{
IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
// Check if the output type can be differentiated. If it cannot be
// differentiated, don't differentiate the inst
//
- auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
+ auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType());
if (auto diffConstructType = differentiateType(builder, primalConstructType))
{
UCount operandCount = origConstruct->getOperandCount();
@@ -755,7 +267,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, I
else
{
auto operandDataType = origConstruct->getOperand(ii)->getDataType();
- operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
+ operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
}
}
@@ -778,7 +290,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, I
// an appropriate call list based on whichever parameters have differentials
// in the current transcription context.
//
-InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall)
+InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall)
{
IRInst* origCallee = origCall->getCallee();
@@ -902,7 +414,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall
}
}
-InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
+InstPair ForwardDiffTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
@@ -924,7 +436,7 @@ InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRS
return InstPair(primalSwizzle, nullptr);
}
-InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
{
IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
@@ -953,7 +465,7 @@ InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builde
diffOperands.getBuffer()));
}
-InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
{
switch(origInst->getOp())
{
@@ -1018,7 +530,7 @@ InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder,
return InstPair(nullptr, nullptr);
}
-InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst)
{
switch(origInst->getOp())
{
@@ -1038,7 +550,7 @@ InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRIns
return InstPair(nullptr, nullptr);
}
-IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
+IRInst* ForwardDiffTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
{
for (UInt i = 0; i < type->getOperandCount(); i++)
{
@@ -1051,7 +563,7 @@ IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType*
return nullptr;
}
-InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
{
auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
List<IRInst*> primalArgs;
@@ -1120,126 +632,7 @@ InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder,
}
}
-InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
-{
- auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable());
- auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey());
- auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
- auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
-
- auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()));
- if (!interfaceType)
- {
- return InstPair(primal, nullptr);
- }
- auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!dict)
- {
- return InstPair(primal, nullptr);
- }
-
- for (auto child : dict->getChildren())
- {
- if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child))
- {
- if (item->getOperand(0) == lookupInst->getRequirementKey())
- {
- auto diffKey = item->getOperand(1);
- if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
- {
- auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
- return InstPair(primal, diff);
- }
- break;
- }
- }
- }
- return InstPair(primal, nullptr);
-}
-
-// In differential computation, the 'default' differential value is always zero.
-// This is a consequence of differential computing being inherently linear. As a
-// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
-// The current implementation requires that types are defined linearly.
-//
-IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
-{
- if (auto diffType = differentiateType(builder, primalType))
- {
- switch (diffType->getOp())
- {
- case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
- }
- // Since primalType has a corresponding differential type, we can lookup the
- // definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- if (!zeroMethod)
- {
- // if the differential type itself comes from a witness lookup, we can just lookup the
- // zero method from the same witness table.
- if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
- {
- auto wt = lookupInterface->getWitnessTable();
- zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
- }
- }
- SLANG_RELEASE_ASSERT(zeroMethod);
-
- auto emptyArgList = List<IRInst*>();
-
- auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
- builder->markInstAsDifferential(callInst, primalType);
-
- return callInst;
- }
- else
- {
- if (isScalarIntegerType(primalType))
- {
- return builder->getIntValue(primalType, 0);
- }
-
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
- }
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
-{
- IRBuilder subBuilder(builder->getSharedBuilder());
- subBuilder.setInsertLoc(builder->getInsertLoc());
-
- IRInst* diffBlock = subBuilder.emitBlock();
-
- // Note: for blocks, we setup the mapping _before_
- // processing the children since we could encounter
- // a lookup while processing the children.
- //
- mapPrimalInst(origBlock, diffBlock);
- mapDifferentialInst(origBlock, diffBlock);
-
- subBuilder.setInsertInto(diffBlock);
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->transcribe(&subBuilder, param);
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->transcribe(&subBuilder, child);
-
- return InstPair(diffBlock, diffBlock);
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
+InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
{
SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));
@@ -1247,7 +640,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
auto field = originalInst->getOperand(1);
auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
- auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, originalInst->getDataType());
IRInst* primalOperands[] = { primalBase, field };
IRInst* primalFieldExtract = builder->emitIntrinsicInst(
@@ -1278,7 +671,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder
return InstPair(primalFieldExtract, diffFieldExtract);
}
-InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
+InstPair ForwardDiffTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
{
SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
@@ -1286,7 +679,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder,
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
- auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origGetElementPtr->getDataType());
IRInst* primalOperands[] = {primalBase, primalIndex};
IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
@@ -1313,7 +706,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder,
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
-InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
+InstPair ForwardDiffTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
{
// The loop comes with three blocks.. we just need to transcribe each one
// and assemble the new loop instruction.
@@ -1351,7 +744,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop
return InstPair(diffLoop, diffLoop);
}
-InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
{
// IfElse Statements come with 4 blocks. We transcribe each block into it's
// linear form, and then wire them up in the same way as the original if-else
@@ -1395,7 +788,7 @@ InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIf
return InstPair(diffLoop, diffLoop);
}
-InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
+InstPair ForwardDiffTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
{
auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
SLANG_ASSERT(primalVal);
@@ -1406,21 +799,16 @@ InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder*
auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
SLANG_ASSERT(diffDiffVal);
- auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
+ auto primalPair = builder->emitMakeDifferentialPair(
+ tryGetDiffPairType(builder, primalVal->getDataType()), primalVal, diffPrimalVal);
auto diffPair = builder->emitMakeDifferentialPair(
- differentiateType(builder, origInst->getDataType()),
+ tryGetDiffPairType(builder, differentiateType(builder, primalVal->getDataType())),
primalDiffVal,
diffDiffVal);
return InstPair(primalPair, diffPair);
}
-InstPair ForwardDerivativeTranscriber::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
-{
- auto primal = cloneInst(&cloneEnv, builder, origInst);
- return InstPair(primal, nullptr);
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
{
SLANG_ASSERT(
origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
@@ -1444,11 +832,88 @@ InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBu
return InstPair(primalResult, diffResult);
}
+InstPair ForwardDiffTranscriber::transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst)
+{
+ IRInst* origBase = origInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType());
+
+ IRInst* primalResult = builder->emitIntrinsicInst(
+ primalType,
+ origInst->getOp(),
+ 1,
+ &primalBase);
+
+ IRInst* diffResult = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ diffResult = builder->emitIntrinsicInst(
+ diffType,
+ origInst->getOp(),
+ 1,
+ &diffBase);
+ }
+ }
+ return InstPair(primalResult, diffResult);
+}
+
+InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, IRInst* origInst)
+{
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType());
+
+ List<IRInst*> primalArgs;
+ for (UInt i = 0; i < origInst->getOperandCount(); i++)
+ {
+ auto primalArg = findOrTranscribePrimalInst(builder, origInst->getOperand(i));
+ primalArgs.add(primalArg);
+ }
+
+ IRInst* primalResult = builder->emitIntrinsicInst(
+ primalType,
+ origInst->getOp(),
+ primalArgs.getCount(),
+ primalArgs.getBuffer());
+
+ IRInst* diffResult = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ List<IRInst*> diffArgs;
+ for (UInt i = 0; i < origInst->getOperandCount(); i++)
+ {
+ auto arg = findOrTranscribeDiffInst(builder, origInst->getOperand(i));
+ if (arg)
+ {
+ diffArgs.add(arg);
+ }
+ else if (i == 0)
+ {
+ // If we can't diff the first operand (base), abort now.
+ break;
+ }
+ }
+ if (diffArgs.getCount())
+ {
+ diffResult = builder->emitIntrinsicInst(
+ diffType,
+ origInst->getOp(),
+ diffArgs.getCount(),
+ diffArgs.getBuffer());
+ }
+ }
+ return InstPair(primalResult, diffResult);
+}
+
// Create an empty func to represent the transcribed func of `origFunc`.
-InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origFunc);
+ if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>())
+ return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc());
+
+ IRBuilder builder = *inBuilder;
IRFunc* primalFunc = origFunc;
@@ -1482,13 +947,17 @@ InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder
cloneDecoration(dictDecor, diffFunc);
}
- auto result = InstPair(primalFunc, diffFunc);
- followUpFunctionsToTranscribe.add(result);
- return result;
+ FuncBodyTranscriptionTask task;
+ task.type = FuncBodyTranscriptionTaskType::Forward;
+ task.originalFunc = primalFunc;
+ task.resultFunc = diffFunc;
+ autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
+
+ return InstPair(primalFunc, diffFunc);
}
// Transcribe a function definition.
-InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
{
IRBuilder builder(inBuilder->getSharedBuilder());
builder.setInsertInto(diffFunc);
@@ -1502,7 +971,7 @@ InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFu
}
// Transcribe a generic definition
-InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
+InstPair ForwardDiffTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
{
auto innerVal = findInnerMostGenericReturnVal(origGeneric);
if (auto innerFunc = as<IRFunc>(innerVal))
@@ -1546,69 +1015,7 @@ InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, I
return InstPair(primalGeneric, diffGeneric);
}
-IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* origInst)
-{
- // If a differential intstruction is already mapped for
- // this original inst, return that.
- //
- if (auto diffInst = lookupDiffInst(origInst, nullptr))
- {
- SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
- return diffInst;
- }
-
- // Otherwise, dispatch to the appropriate method
- // depending on the op-code.
- //
- instsInProgress.Add(origInst);
- InstPair pair = transcribeInst(builder, origInst);
- instsInProgress.Remove(origInst);
-
- if (auto primalInst = pair.primal)
- {
- mapPrimalInst(origInst, pair.primal);
- mapDifferentialInst(origInst, pair.differential);
- if (pair.differential)
- {
- switch (pair.differential->getOp())
- {
- case kIROp_Func:
- case kIROp_Generic:
- case kIROp_Block:
- // Don't generate again for these.
- // Functions already have their names generated in `transcribeFuncHeader`.
- break;
- default:
- // Generate name hint for the inst.
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
- {
- StringBuilder sb;
- sb << "s_diff_" << primalNameHint->getName();
- builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
- }
-
- // Tag the differential inst using a decoration (if it doesn't have one)
- if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
- !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>())
- {
- // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
- // instead.
- //
- builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
- }
-
- break;
- }
- }
- return pair.differential;
- }
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "failed to transcibe instruction");
- return nullptr;
-}
-
-InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
// Handle common SSA-style operations
switch (origInst->getOp())
@@ -1695,252 +1102,35 @@ InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst
case kIROp_DifferentialPairGetPrimal:
case kIROp_DifferentialPairGetDifferential:
return transcribeDifferentialPairGetElement(builder, origInst);
- case kIROp_ExtractExistentialWitnessTable:
- case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
- case kIROp_WrapExistential:
case kIROp_MakeExistential:
- case kIROp_MakeExistentialWithRTTI:
+ return transcribeSingleOperandInst(builder, origInst);
+ case kIROp_ExtractExistentialType:
+ {
+ IRInst* witnessTable;
+ return InstPair(
+ maybeCloneForPrimalInst(builder, origInst),
+ differentiateExtractExistentialType(
+ builder, as<IRExtractExistentialType>(origInst), witnessTable));
+ }
+ case kIROp_ExtractExistentialWitnessTable:
+ return transcribeExtractExistentialWitnessTable(builder, origInst);
+ case kIROp_WrapExistential:
+ return transcribeWrapExistential(builder, origInst);
+ case kIROp_CreateExistentialObject:
+ // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
+ // so we treat this inst as non differentiable.
+ // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
return trascribeNonDiffInst(builder, origInst);
case kIROp_StructKey:
return InstPair(origInst, nullptr);
- }
- // If none of the cases have been hit, check if the instruction is a
- // type. Only need to explicitly differentiate types if they appear inside a block.
- //
- if (auto origType = as<IRType>(origInst))
- {
- // If this is a generic type, transcibe the parent
- // generic and derive the type from the transcribed generic's
- // return value.
- //
- if (as<IRGeneric>(origType->getParent()->getParent()) &&
- findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
- !instsInProgress.Contains(origType->getParent()->getParent()))
- {
- auto origGenericType = origType->getParent()->getParent();
- auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
- auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
- return InstPair(
- origGenericType,
- innerDiffGenericType
- );
- }
- else if (as<IRBlock>(origType->getParent()))
- return InstPair(
- cloneInst(&cloneEnv, builder, origType),
- differentiateType(builder, origType));
- else
- return InstPair(
- cloneInst(&cloneEnv, builder, origType),
- nullptr);
- }
-
- // Handle instructions with children
- switch (origInst->getOp())
- {
- case kIROp_Func:
- return transcribeFuncHeader(builder, as<IRFunc>(origInst));
-
- case kIROp_Block:
- return transcribeBlock(builder, as<IRBlock>(origInst));
-
- case kIROp_Generic:
- return transcribeGeneric(builder, as<IRGeneric>(origInst));
+ case kIROp_MakeExistentialWithRTTI:
+ SLANG_UNEXPECTED("MakeExistentialWithRTTI inst is not expected in autodiff pass.");
+ break;
}
- // If we reach this statement, the instruction type is likely unhandled.
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::unimplemented,
- "this instruction cannot be differentiated");
-
return InstPair(nullptr, nullptr);
}
-struct ForwardDerivativePass : public InstPassBase
-{
-
- DiagnosticSink* getSink()
- {
- return sink;
- }
-
- bool processModule()
- {
- // TODO(sai): Move this call.
- transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
-
- IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
- IRBuilder* builder = &builderStorage;
-
- // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
- // generating derivative code for the referenced function.
- //
- bool modified = processReferencedFunctions(builder);
-
- return modified;
- }
-
- IRInst* lookupJVPReference(IRInst* primalFunction)
- {
- if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
- return jvpDefinition->getForwardDerivativeFunc();
- return nullptr;
- }
-
- // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
- // then check that the referenced function is marked correctly for differentiation.
- //
- bool processReferencedFunctions(IRBuilder* builder)
- {
- bool changed = false;
- List<IRInst*> autoDiffWorkList;
- for (;;)
- {
- // Collect all `ForwardDifferentiate` insts from the module.
- autoDiffWorkList.clear();
- processAllInsts([&](IRInst* inst)
- {
- switch (inst->getOp())
- {
- case kIROp_ForwardDifferentiate:
- // Only process now if the operand is a materialized function.
- switch (inst->getOperand(0)->getOp())
- {
- case kIROp_Func:
- case kIROp_Specialize:
- case kIROp_LookupWitness:
- autoDiffWorkList.add(inst);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- });
-
- if (autoDiffWorkList.getCount() == 0)
- break;
-
- // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
- // differentiated functions.
-
- transcriberStorage.followUpFunctionsToTranscribe.clear();
-
- for (auto differentiateInst : autoDiffWorkList)
- {
- IRInst* baseInst = differentiateInst->getOperand(0);
- if (as<IRForwardDifferentiate>(differentiateInst))
- {
- if (auto existingDiffFunc = lookupJVPReference(baseInst))
- {
- differentiateInst->replaceUsesWith(existingDiffFunc);
- differentiateInst->removeAndDeallocate();
- }
- else
- {
- IRBuilder subBuilder(*builder);
- subBuilder.setInsertBefore(differentiateInst);
- IRInst* diffFunc = transcriberStorage.transcribe(&subBuilder, baseInst);
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- }
- changed = true;
- }
- }
- // Actually synthesize the derivatives.
- List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe);
- for (auto task : followUpWorkList)
- {
- auto diffFunc = as<IRFunc>(task.differential);
- SLANG_ASSERT(diffFunc);
- auto primalFunc = as<IRFunc>(task.primal);
- SLANG_ASSERT(primalFunc);
-
- transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
- }
-
- // Transcribing the function body really shouldn't produce more follow up function body work.
- // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
- // in the next iteration.
- SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
-
- }
- return changed;
- }
-
- // Checks decorators to see if the function should
- // be differentiated (kIROp_ForwardDifferentiableDecoration)
- //
- bool isMarkedForForwardDifferentiation(IRInst* callable)
- {
- if (auto gen = as<IRGeneric>(callable))
- callable = findGenericReturnVal(gen);
- return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr;
- }
-
- IRStringLit* getForwardDerivativeFuncName(IRInst* func)
- {
- IRBuilder builder(&sharedBuilderStorage);
- builder.setInsertBefore(func);
-
- IRStringLit* name = nullptr;
- if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
- {
- name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
- }
- else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
- {
- name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
- }
-
- return name;
- }
-
- ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
- InstPassBase(context->moduleInst->getModule()),
- sink(sink),
- transcriberStorage(context, context->sharedBuilder),
- pairBuilderStorage(context),
- autodiffContext(context)
- {
- transcriberStorage.sink = sink;
- transcriberStorage.autoDiffSharedContext = context;
- transcriberStorage.pairBuilder = &(pairBuilderStorage);
- }
-
-protected:
- // A transcriber object that handles the main job of
- // processing instructions while maintaining state.
- //
- ForwardDerivativeTranscriber transcriberStorage;
-
- // Diagnostic object from the compile request for
- // error messages.
- DiagnosticSink* sink;
-
- // Shared context.
- AutoDiffSharedContext* autodiffContext;
-
- // Builder for dealing with differential pair types.
- DifferentialPairTypeBuilder pairBuilderStorage;
-
-};
-
-// Set up context and call main process method.
-//
-bool processForwardDerivativeCalls(
- AutoDiffSharedContext* autodiffContext,
- DiagnosticSink* sink,
- ForwardDerivativePassOptions const&)
-{
- ForwardDerivativePass fwdPass(autodiffContext, sink);
- bool changed = fwdPass.processModule();
- return changed;
-}
-
}
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 678677625..22ebf9d95 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -1,117 +1,18 @@
// slang-ir-autodiff-fwd.h
#pragma once
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-#include "slang-compiler.h"
+#include "slang-ir-autodiff-transcriber-base.h"
namespace Slang
{
- template<typename P, typename D>
- struct DiffInstPair
- {
- P primal;
- D differential;
- DiffInstPair() = default;
- DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
- {}
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher << primal << differential;
- return hasher.getResult();
- }
- bool operator ==(const DiffInstPair& other) const
- {
- return primal == other.primal && differential == other.differential;
- }
- };
-
- typedef DiffInstPair<IRInst*, IRInst*> InstPair;
-
-
-struct ForwardDerivativeTranscriber
+struct ForwardDiffTranscriber : AutoDiffTranscriberBase
{
-
- // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
- // their differential values.
- Dictionary<IRInst*, IRInst*> instMapD;
-
- // Set of insts currently being transcribed. Used to avoid infinite loops.
- HashSet<IRInst*> instsInProgress;
-
- // Cloning environment to hold mapping from old to new copies for the primal
- // instructions.
- IRCloneEnv cloneEnv;
-
- // Diagnostic sink for error messages.
- DiagnosticSink* sink;
-
- // Type conformance information.
- AutoDiffSharedContext* autoDiffSharedContext;
-
- // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
- DifferentialPairTypeBuilder* pairBuilder;
-
- DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
-
- List<InstPair> followUpFunctionsToTranscribe;
-
- SharedIRBuilder* sharedBuilder;
- // Witness table that `DifferentialBottom:IDifferential`.
- IRWitnessTable* differentialBottomWitness = nullptr;
- Dictionary<InstPair, IRInst*> differentialPairTypes;
-
- ForwardDerivativeTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
- : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
+ ForwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink)
{
-
}
- DiagnosticSink* getSink();
-
- void mapDifferentialInst(IRInst* origInst, IRInst* diffInst);
-
- void mapPrimalInst(IRInst* origInst, IRInst* primalInst);
-
- IRInst* lookupDiffInst(IRInst* origInst);
-
- IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst);
-
- bool hasDifferentialInst(IRInst* origInst);
-
- bool shouldUseOriginalAsPrimal(IRInst* origInst);
-
- IRInst* lookupPrimalInst(IRInst* origInst);
-
- IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst);
-
- bool hasPrimalInst(IRInst* origInst);
-
- IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst);
-
- IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst);
-
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType);
-
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType);
-
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness);
-
- IRType* getOrCreateDiffPairType(IRInst* primalType);
-
- IRType* differentiateType(IRBuilder* builder, IRType* origType);
-
- IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
-
- IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable);
-
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType);
-
- InstPair transcribeParam(IRBuilder* builder, IRParam* origParam);
-
// Returns "d<var-name>" to use as a name hint for variables and parameters.
// If no primal name is available, returns a blank string.
//
@@ -132,8 +33,6 @@ struct ForwardDerivativeTranscriber
InstPair transcribeStore(IRBuilder* builder, IRStore* origStore);
- InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);
-
// Since int/float literals are sometimes nested inside an IRConstructor
// instruction, we check to make sure that the nested instr is a constant
// and then return nullptr. Literals do not need to be differentiated.
@@ -158,17 +57,6 @@ struct ForwardDerivativeTranscriber
InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
- InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst);
-
- // In differential computation, the 'default' differential value is always zero.
- // This is a consequence of differential computing being inherently linear. As a
- // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
- // The current implementation requires that types are defined linearly.
- //
- IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
-
- InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock);
-
InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst);
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr);
@@ -179,12 +67,13 @@ struct ForwardDerivativeTranscriber
InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst);
- InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
-
InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst);
- // Create an empty func to represent the transcribed func of `origFunc`.
- InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc);
+ InstPair transcribeSingleOperandInst(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeWrapExistential(IRBuilder* builder, IRInst* origInst);
+
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override;
// Transcribe a function definition.
InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc);
@@ -192,19 +81,16 @@ struct ForwardDerivativeTranscriber
// Transcribe a generic definition
InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
- IRInst* transcribe(IRBuilder* builder, IRInst* origInst);
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
- InstPair transcribeInst(IRBuilder* builder, IRInst* origInst);
-};
+ virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
- struct ForwardDerivativePassOptions
+ virtual IROp getDifferentiableMethodDictionaryItemOp() override
{
- // Nothing for now..
- };
+ return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ }
- bool processForwardDerivativeCalls(
- AutoDiffSharedContext* autodiffContext,
- DiagnosticSink* sink,
- ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions());
+};
}
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 56002231a..cfee49eb1 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -7,83 +7,11 @@
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-autodiff-fwd.h"
-#include "slang-ir-autodiff-propagate.h"
-#include "slang-ir-autodiff-unzip.h"
-#include "slang-ir-autodiff-transpose.h"
namespace Slang
{
-struct BackwardDiffTranscriber
-{
- // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
- // their differential values.
- Dictionary<IRInst*, IRInst*> orginalToTranscribed;
-
- // Set of insts currently being transcribed. Used to avoid infinite loops.
- HashSet<IRInst*> instsInProgress;
-
- // Cloning environment to hold mapping from old to new copies for the primal
- // instructions.
- IRCloneEnv cloneEnv;
-
- // Diagnostic sink for error messages.
- DiagnosticSink* sink;
-
- // Type conformance information.
- AutoDiffSharedContext* autoDiffSharedContext;
-
- // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
- DifferentialPairTypeBuilder* pairBuilder;
-
- DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
-
- List<InstPair> followUpFunctionsToTranscribe;
-
- // Map that stores the upper gradient given an IRInst*
- Dictionary<IRInst*, List<IRInst*>> upperGradients;
- Dictionary<IRInst*, IRInst*> primalToDiffPair;
-
- SharedIRBuilder* sharedBuilder;
- // Witness table that `DifferentialBottom:IDifferential`.
- IRWitnessTable* differentialBottomWitness = nullptr;
- Dictionary<InstPair, IRInst*> differentialPairTypes;
-
- // References to other passes that for reverse-mode transcription.
- ForwardDerivativeTranscriber *fwdDiffTranscriber;
- DiffTransposePass *diffTransposePass;
- DiffPropagationPass *diffPropagationPass;
- DiffUnzipPass *diffUnzipPass;
-
- // Allocate space for the passes.
- ForwardDerivativeTranscriber fwdDiffTranscriberStorage;
- DiffTransposePass diffTransposePassStorage;
- DiffPropagationPass diffPropagationPassStorage;
- DiffUnzipPass diffUnzipPassStorage;
-
-
- BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
- : autoDiffSharedContext(shared)
- , sink(inSink)
- , differentiableTypeConformanceContext(shared)
- , sharedBuilder(inSharedBuilder)
- , fwdDiffTranscriberStorage(shared, inSharedBuilder)
- , diffTransposePassStorage(shared)
- , diffPropagationPassStorage(shared)
- , diffUnzipPassStorage(shared)
- , fwdDiffTranscriber(&fwdDiffTranscriberStorage)
- , diffTransposePass(&diffTransposePassStorage)
- , diffPropagationPass(&diffPropagationPassStorage)
- , diffUnzipPass(&diffUnzipPassStorage)
- { }
-
- DiagnosticSink* getSink()
- {
- SLANG_ASSERT(sink);
- return sink;
- }
-
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
{
List<IRType*> newParameterTypes;
IRType* diffReturnType;
@@ -123,198 +51,46 @@ struct BackwardDiffTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
- SLANG_ASSERT(diffPairType);
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
- auto diffType = differentiateType(&builder, diffPairType->getValueType());
- auto differentialType = builder.getDifferentialPairType(diffType, nullptr);
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
-
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- return table;
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType)
+ InstPair BackwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- auto witness = as<IRWitnessTable>(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
-
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* differentiateType(IRBuilder* builder, IRType* origType)
- {
- IRInst* diffType = nullptr;
- if (!orginalToTranscribed.TryGetValue(origType, diffType))
- {
- diffType = _differentiateTypeImpl(builder, origType);
- orginalToTranscribed[origType] = diffType;
- }
- return (IRType*)diffType;
- }
-
- IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
- {
- if (auto ptrType = as<IRPtrTypeBase>(origType))
- return builder->getPtrType(
- origType->getOp(),
- differentiateType(builder, ptrType->getValueType()));
-
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = origType;
-
- // Special case certain compound types (PtrType, FuncType, etc..)
- // otherwise try to lookup a differential definition for the given type.
- // If one does not exist, then we assume it's not differentiable.
- //
- switch (primalType->getOp())
+ switch (origInst->getOp())
{
case kIROp_Param:
- if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
- else if (as<IRWitnessTableType>(primalType->getDataType()))
- return (IRType*)primalType;
-
- case kIROp_ArrayType:
- {
- auto primalArrayType = as<IRArrayType>(primalType);
- if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
- return builder->getArrayType(
- diffElementType,
- primalArrayType->getElementCount());
- else
- return nullptr;
- }
-
- case kIROp_DifferentialPairType:
- {
- auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
- }
-
- case kIROp_FuncType:
- return differentiateFunctionType(builder, as<IRFuncType>(primalType));
-
- case kIROp_OutType:
- if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
- return builder->getOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_InOutType:
- if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
- return builder->getInOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_TupleType:
- {
- auto tupleType = as<IRTupleType>(primalType);
- List<IRType*> diffTypeList;
- // TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
- diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
- }
+ return transcribeParam(builder, as<IRParam>(origInst));
- default:
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
- }
- }
+ case kIROp_Return:
+ return transcribeReturn(builder, as<IRReturn>(origInst));
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
- {
- // If this is a PtrType (out, inout, etc..), then create diff pair from
- // value type and re-apply the appropropriate PtrType wrapper.
- //
- if (auto origPtrType = as<IRPtrTypeBase>(primalType))
- {
- if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(primalType->getOp(), diffPairValueType);
- else
- return nullptr;
- }
- auto diffType = differentiateType(builder, primalType);
- if (diffType)
- return (IRType*)getOrCreateDiffPairType(primalType);
- return nullptr;
- }
+ case kIROp_LookupWitness:
+ return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
- InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
- {
- auto primalDataType = origParam->getDataType();
- // Do not differentiate generic type (and witness table) parameters
- if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
- {
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
+ case kIROp_Specialize:
+ return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
+ case kIROp_MakeVectorFromScalar:
+ case kIROp_MakeTuple:
+ case kIROp_FloatLit:
+ case kIROp_IntLit:
+ case kIROp_VoidLit:
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_WrapExistential:
+ case kIROp_MakeExistential:
+ case kIROp_MakeExistentialWithRTTI:
+ return trascribeNonDiffInst(builder, origInst);
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
- // its primal and diff parts right now because they would represent a reference
- // to a pair field, which doesn't make sense since pair types are considered mutable.
- // We encode the result as if the param is non-differentiable, and handle it
- // with special care at load/store.
- return InstPair(diffPairParam, nullptr);
+ case kIROp_StructKey:
+ return InstPair(origInst, nullptr);
}
-
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
+ return InstPair(nullptr, nullptr);
}
// Returns "dp<var-name>" to use as a name hint for parameters.
// If no primal name is available, returns a blank string.
//
- String makeDiffPairName(IRInst* origVar)
+ String BackwardDiffTranscriber::makeDiffPairName(IRInst* origVar)
{
if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
@@ -330,7 +106,7 @@ struct BackwardDiffTranscriber
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
// The current implementation requires that types are defined linearly.
//
- IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ IRInst* BackwardDiffTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
{
if (auto diffType = differentiateType(builder, primalType))
{
@@ -364,7 +140,7 @@ struct BackwardDiffTranscriber
}
}
- InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ InstPair BackwardDiffTranscriber::transposeBlock(IRBuilder* builder, IRBlock* origBlock)
{
IRBuilder subBuilder(builder->getSharedBuilder());
subBuilder.setInsertLoc(builder->getInsertLoc());
@@ -401,10 +177,10 @@ struct BackwardDiffTranscriber
{
sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
}
- this->transcribeInstBackward(&subBuilder, child, sumGrad);
+ this->transposeInstBackward(&subBuilder, child, sumGrad);
}
else
- this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
+ this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst());
}
subBuilder.emitReturn();
@@ -412,9 +188,20 @@ struct BackwardDiffTranscriber
return InstPair(diffBlock, diffBlock);
}
+ static bool isMarkedForBackwardDifferentiation(IRInst* callable)
+ {
+ return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
+ }
+
// Create an empty func to represent the transcribed func of `origFunc`.
- InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
+ if (auto bwdDecor = origFunc->findDecoration<IRBackwardDerivativeDecoration>())
+ return InstPair(origFunc, bwdDecor->getBackwardDerivativeFunc());
+
+ if (!isMarkedForBackwardDifferentiation(origFunc))
+ return InstPair(nullptr, nullptr);
+
IRBuilder builder(inBuilder->getSharedBuilder());
builder.setInsertBefore(origFunc);
@@ -450,13 +237,17 @@ struct BackwardDiffTranscriber
cloneDecoration(dictDecor, diffFunc);
}
- auto result = InstPair(primalFunc, diffFunc);
- followUpFunctionsToTranscribe.add(result);
- return result;
+ FuncBodyTranscriptionTask task;
+ task.originalFunc = primalFunc;
+ task.resultFunc = diffFunc;
+ task.type = FuncBodyTranscriptionTaskType::Backward;
+ autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
+
+ return InstPair(primalFunc, diffFunc);
}
// Puts parameters into their own block.
- void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
+ void BackwardDiffTranscriber::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
{
IRBuilder builder(inBuilder->getSharedBuilder());
@@ -491,7 +282,7 @@ struct BackwardDiffTranscriber
builder.emitBranch(firstBlock);
}
- void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
+ void BackwardDiffTranscriber::cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType)
{
IRStructType* structType = as<IRStructType>(intermediateType);
if (!structType)
@@ -584,7 +375,7 @@ struct BackwardDiffTranscriber
}
// Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ InstPair BackwardDiffTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
{
SLANG_ASSERT(primalFunc);
SLANG_ASSERT(diffFunc);
@@ -592,15 +383,17 @@ struct BackwardDiffTranscriber
// TODO(sai): Fill in documentation.
// Generate a temporary forward derivative function as an intermediate step.
- IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(builder, (IRFunc*)primalFunc).differential);
+ IRBuilder tempBuilder = *builder;
+ tempBuilder.setInsertBefore(diffFunc);
+ IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(&tempBuilder, (IRFunc*)primalFunc).differential);
SLANG_ASSERT(fwdDiffFunc);
// Transcribe the body of the primal function into it's linear (fwd-diff) form.
// TODO(sai): Handle the case when we already have a user-defined fwd-derivative function.
- fwdDiffTranscriber->transcribeFunc(builder, primalFunc, as<IRFunc>(fwdDiffFunc));
+ fwdDiffTranscriber->transcribeFunc(&tempBuilder, primalFunc, as<IRFunc>(fwdDiffFunc));
// Split first block into a paramter block.
- this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc));
+ this->makeParameterBlock(&tempBuilder, as<IRFunc>(fwdDiffFunc));
// This steps adds a decoration to instructions that are computing the differential.
// TODO: This is disabled for now because fwd-mode already adds differential decorations
@@ -642,7 +435,7 @@ struct BackwardDiffTranscriber
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(diffFunc, unzippedFwdDiffFunc, intermediateType);
// Transpose the first block (parameter block)
- transcribeParameterBlock(builder, diffFunc);
+ transposeParameterBlock(builder, diffFunc);
builder->setInsertInto(diffFunc);
@@ -663,7 +456,7 @@ struct BackwardDiffTranscriber
return InstPair(primalFunc, diffFunc);
}
- void transcribeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
+ void BackwardDiffTranscriber::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
{
IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
@@ -715,7 +508,7 @@ struct BackwardDiffTranscriber
builder->emitParam(dOutParamType);
}
- IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
+ IRInst* BackwardDiffTranscriber::copyParam(IRBuilder* builder, IRParam* origParam)
{
auto primalDataType = origParam->getDataType();
@@ -737,11 +530,10 @@ struct BackwardDiffTranscriber
return diffParam;
}
-
return cloneInst(&cloneEnv, builder, origParam);
}
- InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ InstPair BackwardDiffTranscriber::copyBinaryArith(IRBuilder* builder, IRInst* origArith)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
@@ -785,7 +577,7 @@ struct BackwardDiffTranscriber
return InstPair(newInst, nullptr);
}
- IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ IRInst* BackwardDiffTranscriber::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
{
SLANG_ASSERT(origArith->getOperandCount() == 2);
@@ -853,7 +645,7 @@ struct BackwardDiffTranscriber
return nullptr;
}
- InstPair copyInst(IRBuilder* builder, IRInst* origInst)
+ InstPair BackwardDiffTranscriber::copyInst(IRBuilder* builder, IRInst* origInst)
{
// Handle common SSA-style operations
switch (origInst->getOp())
@@ -878,7 +670,7 @@ struct BackwardDiffTranscriber
return InstPair(nullptr, nullptr);
}
- IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ IRInst* BackwardDiffTranscriber::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
{
IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
@@ -895,19 +687,19 @@ struct BackwardDiffTranscriber
return store;
}
- IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ IRInst* BackwardDiffTranscriber::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
{
// Handle common SSA-style operations
switch (origInst->getOp())
{
case kIROp_Param:
- return transcribeParamBackward(builder, as<IRParam>(origInst), grad);
+ return transposeParamBackward(builder, as<IRParam>(origInst), grad);
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
case kIROp_Div:
- return transcribeBinaryArithBackward(builder, origInst, grad);
+ return transposeBinaryArithBackward(builder, origInst, grad);
case kIROp_DifferentialPairGetPrimal:
{
@@ -935,191 +727,72 @@ struct BackwardDiffTranscriber
return nullptr;
}
-
-};
-
-struct ReverseDerivativePass : public InstPassBase
-{
- DiagnosticSink* getSink()
- {
- return sink;
- }
-
- bool processModule()
+ InstPair BackwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
{
+ auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
+ List<IRInst*> primalArgs;
+ for (UInt i = 0; i < origSpecialize->getArgCount(); i++)
+ {
+ primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i)));
+ }
+ auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
+ auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
+ (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());
- IRBuilder builderStorage(autodiffContext->sharedBuilder);
- IRBuilder* builder = &builderStorage;
+ IRInst* diffBase = nullptr;
+ if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
- // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
- // generating derivative code for the referenced function.
+ auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+ // Look for an IRBackwardDerivativeDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
//
- bool modified = processReferencedFunctions(builder);
-
- return modified;
- }
-
- IRInst* lookupJVPReference(IRInst* primalFunction)
- {
- if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
- return jvpDefinition->getForwardDerivativeFunc();
- return nullptr;
- }
-
- // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
- // then check that the referenced function is marked correctly for differentiation.
- //
- bool processReferencedFunctions(IRBuilder* builder)
- {
- bool changed = false;
-
- List<IRInst*> autoDiffWorkList;
-
- for (;;)
+ if (auto backDecor = origSpecialize->findDecoration<IRBackwardDerivativeDecoration>())
{
- // Collect all `ForwardDifferentiate` insts from the module.
- autoDiffWorkList.clear();
- processAllInsts([&](IRInst* inst)
- {
- switch (inst->getOp())
- {
- case kIROp_BackwardDifferentiate:
- // Only process now if the operand is a materialized function.
- switch (inst->getOperand(0)->getOp())
- {
- case kIROp_Func:
- case kIROp_Specialize:
- autoDiffWorkList.add(inst);
- break;
- default:
- break;
- }
- break;
- default:
- break;
- }
- });
-
- if (autoDiffWorkList.getCount() == 0)
- break;
-
- // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
- // differentiated functions.
+ auto derivativeFunc = backDecor->getBackwardDerivativeFunc();
- backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(derivativeFunc));
- for (auto differentiateInst : autoDiffWorkList)
- {
- IRInst* baseInst = differentiateInst->getOperand(0);
- if (as<IRBackwardDifferentiate>(differentiateInst))
- {
- if (isMarkedForBackwardDifferentiation(baseInst))
- {
- if (as<IRFunc>(baseInst))
- {
- IRInst* diffFunc =
- backwardTranscriberStorage
- .transcribeFuncHeader(builder, (IRFunc*)baseInst)
- .differential;
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- changed = true;
- }
- else
- {
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
- }
- }
- }
- }
-
- auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
- for (auto task : followUpWorkList)
+ return InstPair(primalSpecialize, derivativeFunc);
+ }
+ else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRBackwardDerivativeDecoration>())
+ {
+ diffBase = derivativeDecoration->getBackwardDerivativeFunc();
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
- auto diffFunc = as<IRFunc>(task.differential);
- SLANG_ASSERT(diffFunc);
- auto primalFunc = as<IRFunc>(task.primal);
- SLANG_ASSERT(primalFunc);
-
- backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ args.add(primalSpecialize->getArg(i));
}
-
- // Transcribing the function body really shouldn't produce more follow up function body work.
- // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
- // in the next iteration.
- SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
-
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
}
- return changed;
- }
-
- // Checks decorators to see if the function should
- // be differentiated (kIROp_ForwardDifferentiableDecoration)
- //
- bool isMarkedForBackwardDifferentiation(IRInst* callable)
- {
- return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
- }
-
- IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
- {
- IRBuilder builder(&sharedBuilderStorage);
- builder.setInsertBefore(func);
-
- IRStringLit* name = nullptr;
- if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>())
{
- name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
+ {
+ args.add(primalSpecialize->getArg(i));
+ }
+ diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase());
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
}
- else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ else
{
- name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
+ return InstPair(primalSpecialize, nullptr);
}
-
- return name;
- }
-
- ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
- InstPassBase(context->moduleInst->getModule()),
- sink(sink),
- backwardTranscriberStorage(context, context->sharedBuilder, sink),
- autodiffContext(context),
- pairBuilderStorage(context)
- {
- backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
- backwardTranscriberStorage.fwdDiffTranscriberStorage.sink = sink;
- backwardTranscriberStorage.fwdDiffTranscriberStorage.autoDiffSharedContext = context;
- backwardTranscriberStorage.fwdDiffTranscriberStorage.pairBuilder = &(pairBuilderStorage);
}
-
-protected:
- // A transcriber object that handles the main job of
- // processing instructions while maintaining state.
- //
- BackwardDiffTranscriber backwardTranscriberStorage;
-
- // Diagnostic object from the compile request for
- // error messages.
- DiagnosticSink* sink;
-
- // Builder for dealing with differential pair types.
- DifferentialPairTypeBuilder pairBuilderStorage;
-
- // Autodiff Shared Context
- AutoDiffSharedContext* autodiffContext;
-};
-
-bool processReverseDerivativeCalls(
- AutoDiffSharedContext* autodiffContext,
- DiagnosticSink* sink,
- IRReverseDerivativePassOptions const&)
-{
- ReverseDerivativePass revPass(autodiffContext, sink);
- bool changed = revPass.processModule();
- return changed;
-}
-
}
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index c3d31e2a9..f9ca6110c 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -7,6 +7,10 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-transcriber-base.h"
+#include "slang-ir-autodiff-propagate.h"
+#include "slang-ir-autodiff-unzip.h"
+#include "slang-ir-autodiff-transpose.h"
namespace Slang
{
@@ -16,10 +20,84 @@ struct IRReverseDerivativePassOptions
// Nothing for now..
};
-bool processReverseDerivativeCalls(
- AutoDiffSharedContext* autodiffContext,
- DiagnosticSink* sink,
- IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions());
+struct BackwardDiffTranscriber : AutoDiffTranscriberBase
+{
+ // Map that stores the upper gradient given an IRInst*
+ Dictionary<IRInst*, List<IRInst*>> upperGradients;
+ Dictionary<IRInst*, IRInst*> primalToDiffPair;
+ Dictionary<IRInst*, IRInst*> orginalToTranscribed;
+
+ // References to other passes that for reverse-mode transcription.
+ ForwardDiffTranscriber* fwdDiffTranscriber;
+ DiffTransposePass* diffTransposePass;
+ DiffPropagationPass* diffPropagationPass;
+ DiffUnzipPass* diffUnzipPass;
+
+ // Allocate space for the passes.
+ DiffTransposePass diffTransposePassStorage;
+ DiffPropagationPass diffPropagationPassStorage;
+ DiffUnzipPass diffUnzipPassStorage;
+
+ BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink)
+ , diffTransposePassStorage(shared)
+ , diffPropagationPassStorage(shared)
+ , diffUnzipPassStorage(shared)
+ , diffTransposePass(&diffTransposePassStorage)
+ , diffPropagationPass(&diffPropagationPassStorage)
+ , diffUnzipPass(&diffUnzipPassStorage)
+ { }
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar);
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
+
+ InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock);
+
+ // Puts parameters into their own block.
+ void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func);
+
+ void cleanUpUnusedPrimalIntermediate(IRInst* func, IRInst* primalFunc, IRInst* intermediateType);
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc);
+
+ void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc);
+ IRInst* copyParam(IRBuilder* builder, IRParam* origParam);
+
+ InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith);
+
+ IRInst* transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad);
+
+ InstPair copyInst(IRBuilder* builder, IRInst* origInst);
+
+ IRInst* transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad);
+
+ IRInst* transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad);
+
+ InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) override;
+
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) override;
+
+ virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override;
+
+ virtual IROp getDifferentiableMethodDictionaryItemOp() override
+ {
+ return kIROp_ForwardDifferentiableMethodRequirementDictionaryItem;
+ }
+
+};
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
new file mode 100644
index 000000000..da7762908
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -0,0 +1,847 @@
+// slang-ir-autodiff-trascriber-base.cpp
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-transcriber-base.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+DiagnosticSink* AutoDiffTranscriberBase::getSink()
+{
+ SLANG_ASSERT(sink);
+ return sink;
+}
+
+String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
+{
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+}
+
+void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
+{
+ if (hasDifferentialInst(origInst))
+ {
+ if (lookupDiffInst(origInst) != diffInst)
+ {
+ SLANG_UNEXPECTED("Inconsistent differential mappings");
+ }
+ }
+ else
+ {
+ instMapD.Add(origInst, diffInst);
+ }
+}
+
+void AutoDiffTranscriberBase::mapPrimalInst(IRInst* origInst, IRInst* primalInst)
+{
+ if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
+ {
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "inconsistent primal instruction for original");
+ }
+ else
+ {
+ cloneEnv.mapOldValToNew[origInst] = primalInst;
+ }
+}
+
+IRInst* AutoDiffTranscriberBase::lookupDiffInst(IRInst* origInst)
+{
+ return instMapD[origInst];
+}
+
+IRInst* AutoDiffTranscriberBase::lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
+{
+ if (auto lookupResult = instMapD.TryGetValue(origInst))
+ return *lookupResult;
+ return defaultInst;
+}
+
+bool AutoDiffTranscriberBase::hasDifferentialInst(IRInst* origInst)
+{
+ if (!origInst)
+ return false;
+ return instMapD.ContainsKey(origInst);
+}
+
+bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* origInst)
+{
+ if (as<IRGlobalValueWithCode>(origInst))
+ return true;
+ if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
+ return true;
+ return false;
+}
+
+IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst)
+{
+ if (!origInst)
+ return nullptr;
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
+ return cloneEnv.mapOldValToNew[origInst];
+}
+
+IRInst* AutoDiffTranscriberBase::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
+{
+ if (!origInst)
+ return nullptr;
+ return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
+}
+
+bool AutoDiffTranscriberBase::hasPrimalInst(IRInst* origInst)
+{
+ if (!origInst)
+ return false;
+ if (shouldUseOriginalAsPrimal(origInst))
+ return true;
+ return cloneEnv.mapOldValToNew.ContainsKey(origInst);
+}
+
+IRInst* AutoDiffTranscriberBase::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
+{
+ if (!hasDifferentialInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasDifferentialInst(origInst));
+ }
+
+ return lookupDiffInst(origInst);
+}
+
+IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
+{
+ if (!origInst)
+ return origInst;
+
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
+
+ if (!hasPrimalInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasPrimalInst(origInst));
+ }
+
+ return lookupPrimalInst(origInst);
+}
+
+IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst)
+{
+ IRInst* primal = lookupPrimalInst(inst, inst);
+
+ if (primal == inst &&
+ !isChildInstOf(builder->getInsertLoc().getParent(), inst->getParent()))
+ primal = cloneInst(&cloneEnv, builder, inst);
+
+ return primal;
+}
+
+// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+IRWitnessTable* AutoDiffTranscriberBase::getDifferentialPairWitness(IRInst* inDiffPairType)
+{
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = differentiateType(&builder, diffPairType);
+
+ // And place it in the synthesized witness table.
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ // Record this in the context for future lookups
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+
+ return table;
+}
+
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+{
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
+
+IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType)
+{
+ IRBuilder builder(sharedBuilder);
+ if (!primalType->next)
+ builder.setInsertInto(primalType->parent);
+ else
+ builder.setInsertBefore(primalType->next);
+
+ IRInst* witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+
+ if (!witness)
+ {
+ if (auto primalPairType = as<IRDifferentialPairType>(primalType))
+ {
+ witness = getDifferentialPairWitness(primalPairType);
+ }
+ else if (auto extractExistential = as<IRExtractExistentialType>(primalType))
+ {
+ differentiateExtractExistentialType(&builder, extractExistential, witness);
+ }
+ }
+
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
+
+IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
+{
+ return (IRType*)transcribe(builder, origType);
+}
+
+IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+{
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = lookupPrimalInst(origType, origType);
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_ExtractExistentialType:
+ {
+ IRInst* wt = nullptr;
+ return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt);
+ }
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+}
+
+// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`.
+static bool _findDifferentiableInterfaceLookupPathImpl(
+ HashSet<IRInst*>& processedTypes,
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type,
+ List<IRInterfaceRequirementEntry*>& currentPath)
+{
+ if (processedTypes.Contains(type))
+ return false;
+ processedTypes.Add(type);
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath;
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i));
+ if (!entry) continue;
+ if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal()))
+ {
+ currentPath.add(entry);
+ if (wt->getConformanceType() == idiffType)
+ {
+ return true;
+ }
+ else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType()))
+ {
+ if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath))
+ return true;
+ }
+ currentPath.removeLast();
+ }
+ }
+ return false;
+}
+
+List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableInterfaceLookupPath(
+ IRInterfaceType* idiffType,
+ IRInterfaceType* type)
+{
+ List<IRInterfaceRequirementEntry*> currentPath;
+ HashSet<IRInst*> processedTypes;
+ _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath);
+ return currentPath;
+}
+
+InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst)
+{
+ IRInst* witnessTable = nullptr;
+
+ IRInst* origBase = origInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origInst->getDataType());
+
+ IRInst* primalResult = builder->emitIntrinsicInst(
+ primalType,
+ origInst->getOp(),
+ 1,
+ &primalBase);
+
+ // Search for IDifferentiable conformance.
+ auto interfaceType = as<IRInterfaceType>(
+ unwrapAttributedType(cast<IRWitnessTableType>(origInst->getDataType())->getConformanceType()));
+ if (!interfaceType)
+ return InstPair(primalResult, nullptr);
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ autoDiffSharedContext->differentiableInterfaceType, interfaceType);
+
+ if (lookupKeyPath.getCount())
+ {
+ // `interfaceType` does conform to `IDifferentiable`.
+ witnessTable = primalResult;
+ for (auto node : lookupKeyPath)
+ {
+ witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey());
+ }
+ return InstPair(primalResult, witnessTable);
+ }
+ return InstPair(primalResult, nullptr);
+}
+
+
+IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& outWitnessTable)
+{
+ outWitnessTable = nullptr;
+
+ // Search for IDifferentiable conformance.
+ auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
+ if (!interfaceType)
+ return nullptr;
+ List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath(
+ autoDiffSharedContext->differentiableInterfaceType, interfaceType);
+
+ if (lookupKeyPath.getCount())
+ {
+ // `interfaceType` does conform to `IDifferentiable`.
+ outWitnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0));
+ for (auto node : lookupKeyPath)
+ {
+ outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey());
+ }
+ auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey);
+ return (IRType*)diffType;
+ }
+ return nullptr;
+}
+
+IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+{
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+}
+
+IRInst* AutoDiffTranscriberBase::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
+{
+ for (UInt i = 0; i < type->getOperandCount(); i++)
+ {
+ if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
+ {
+ if (req->getRequirementKey() == key)
+ return req->getRequirementVal();
+ }
+ }
+ return nullptr;
+}
+
+InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* origParam)
+{
+ auto primalDataType = findOrTranscribePrimalInst(builder, origParam->getDataType());
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Is this param a phi node or a function parameter?
+ auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
+ bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
+ if (isFuncParam)
+ {
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairType))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
+ {
+ auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
+
+ return InstPair(
+ builder->emitDifferentialPairAddressPrimal(diffPairParam),
+ builder->emitDifferentialPairAddressDifferential(
+ builder->getPtrType(
+ kIROp_PtrType,
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
+ diffPairParam));
+ }
+ }
+
+ auto primalInst = cloneInst(&cloneEnv, builder, origParam);
+ if (auto primalParam = as<IRParam>(primalInst))
+ {
+ SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
+ primalParam->removeFromParent();
+ builder->getInsertLoc().getBlock()->addParam(primalParam);
+ }
+ return InstPair(primalInst, nullptr);
+ }
+ else
+ {
+ auto primal = cloneInst(&cloneEnv, builder, origParam);
+ IRInst* diff = nullptr;
+ if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
+ {
+ diff = builder->emitParam(diffType);
+ }
+ return InstPair(primal, diff);
+ }
+}
+
+InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
+{
+ auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable());
+ auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey());
+ auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
+ auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
+
+ auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()));
+ if (!interfaceType)
+ {
+ return InstPair(primal, nullptr);
+ }
+ auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!dict)
+ {
+ return InstPair(primal, nullptr);
+ }
+
+ for (auto child : dict->getChildren())
+ {
+ if (auto item = as<IRDifferentiableMethodRequirementDictionaryItem>(child))
+ {
+ if (item->getOp() == getDifferentiableMethodDictionaryItemOp())
+ {
+ if (item->getOperand(0) == lookupInst->getRequirementKey())
+ {
+ auto diffKey = item->getOperand(1);
+ if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
+ {
+ auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
+ return InstPair(primal, diff);
+ }
+ break;
+ }
+ }
+ }
+ }
+ return InstPair(primal, nullptr);
+}
+
+// In differential computation, the 'default' differential value is always zero.
+// This is a consequence of differential computing being inherently linear. As a
+// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+// The current implementation requires that types are defined linearly.
+//
+IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+{
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ if (!zeroMethod)
+ {
+ // if the differential type itself comes from a witness lookup, we can just lookup the
+ // zero method from the same witness table.
+ if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType))
+ {
+ auto wt = lookupInterface->getWitnessTable();
+ zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey);
+ }
+ }
+ SLANG_RELEASE_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+
+ auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ builder->markInstAsDifferential(callInst, primalType);
+
+ return callInst;
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+}
+
+InstPair AutoDiffTranscriberBase::transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+{
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRInst* diffBlock = subBuilder.emitBlock();
+
+ // Note: for blocks, we setup the mapping _before_
+ // processing the children since we could encounter
+ // a lookup while processing the children.
+ //
+ mapPrimalInst(origBlock, diffBlock);
+ mapDifferentialInst(origBlock, diffBlock);
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->transcribe(&subBuilder, param);
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->transcribe(&subBuilder, child);
+
+ return InstPair(diffBlock, diffBlock);
+}
+
+InstPair AutoDiffTranscriberBase::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+{
+ auto primal = cloneInst(&cloneEnv, builder, origInst);
+ return InstPair(primal, nullptr);
+}
+
+InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
+{
+ IRInst* origReturnVal = origReturn->getVal();
+
+ auto returnDataType = (IRType*)findOrTranscribePrimalInst(builder, origReturnVal->getDataType());
+ if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
+ {
+ // If the return value is itself a function, generic or a struct then this
+ // is likely to be a generic scope. In this case, we lookup the differential
+ // and return that.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+
+ // Neither of these should be nullptr.
+ SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
+ IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
+ builder->markInstAsMixedDifferential(diffReturn, nullptr);
+
+ return InstPair(diffReturn, diffReturn);
+ }
+ else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
+ {
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+ if (!diffReturnVal)
+ diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
+
+ // If the pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffReturnVal);
+
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
+ builder->markInstAsMixedDifferential(diffPair, pairType);
+
+ IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
+ builder->markInstAsMixedDifferential(pairReturn, pairType);
+
+ return InstPair(pairReturn, pairReturn);
+ }
+ else
+ {
+ // If the return type is not differentiable, emit the primal value only.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+
+ IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ return InstPair(primalReturn, nullptr);
+
+ }
+}
+
+// Transcribe a generic definition
+InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
+{
+ auto innerVal = findInnerMostGenericReturnVal(origGeneric);
+ if (auto innerFunc = as<IRFunc>(innerVal))
+ {
+ differentiableTypeConformanceContext.setFunc(innerFunc);
+ }
+ else if (auto funcType = as<IRFuncType>(innerVal))
+ {
+ }
+ else
+ {
+ return InstPair(origGeneric, nullptr);
+ }
+
+ IRGeneric* primalGeneric = origGeneric;
+
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origGeneric);
+
+ auto diffGeneric = builder.emitGeneric();
+
+ // Process type of generic. If the generic is a function, then it's type will also be a
+ // generic and this logic will transcribe that generic first before continuing with the
+ // function itself.
+ //
+ auto primalType = primalGeneric->getFullType();
+
+ IRType* diffType = nullptr;
+ if (primalType)
+ {
+ diffType = (IRType*)findOrTranscribeDiffInst(&builder, primalType);
+ }
+
+ diffGeneric->setFullType(diffType);
+
+ // Transcribe children from origFunc into diffFunc.
+ builder.setInsertInto(diffGeneric);
+ for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(&builder, block);
+
+ return InstPair(primalGeneric, diffGeneric);
+}
+
+IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
+{
+ // If a differential intstruction is already mapped for
+ // this original inst, return that.
+ //
+ if (auto diffInst = lookupDiffInst(origInst, nullptr))
+ {
+ SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ return diffInst;
+ }
+
+ // Otherwise, dispatch to the appropriate method
+ // depending on the op-code.
+ //
+ instsInProgress.Add(origInst);
+
+ InstPair pair = transcribeInst(builder, origInst);
+
+ instsInProgress.Remove(origInst);
+
+ if (auto primalInst = pair.primal)
+ {
+ mapPrimalInst(origInst, pair.primal);
+ mapDifferentialInst(origInst, pair.differential);
+ if (pair.differential)
+ {
+ switch (pair.differential->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Generic:
+ case kIROp_Block:
+ // Don't generate again for these.
+ // Functions already have their names generated in `transcribeFuncHeader`.
+ break;
+ default:
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ }
+
+ // Tag the differential inst using a decoration (if it doesn't have one)
+ if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
+ !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>())
+ {
+ // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
+ // instead.
+ //
+ builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType()));
+ }
+
+ break;
+ }
+ }
+ return pair.differential;
+ }
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "failed to transcibe instruction");
+ return nullptr;
+}
+
+InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* origInst)
+{
+ // Handle instructions with children
+ switch (origInst->getOp())
+ {
+ case kIROp_Func:
+ return transcribeFuncHeader(builder, as<IRFunc>(origInst));
+
+ case kIROp_Block:
+ return transcribeBlock(builder, as<IRBlock>(origInst));
+
+ case kIROp_Generic:
+ return transcribeGeneric(builder, as<IRGeneric>(origInst));
+ }
+
+ auto result = transcribeInstImpl(builder, origInst);
+
+ if (result.primal == nullptr && result.differential == nullptr)
+ {
+ if (auto origType = as<IRType>(origInst))
+ {
+ // If this is a generic type, transcibe the parent
+ // generic and derive the type from the transcribed generic's
+ // return value.
+ //
+ if (as<IRGeneric>(origType->getParent()->getParent()) &&
+ findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
+ !instsInProgress.Contains(origType->getParent()->getParent()))
+ {
+ auto origGenericType = origType->getParent()->getParent();
+ auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
+ auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
+ result = InstPair(
+ origGenericType,
+ innerDiffGenericType
+ );
+ }
+ else
+ {
+ auto diffType = _differentiateTypeImpl(builder, origType);
+ IRInst* primal = maybeCloneForPrimalInst(builder, origType);
+ result = InstPair(primal, diffType);
+ }
+ }
+ }
+
+ if (result.primal == nullptr && result.differential == nullptr)
+ {
+ // If we reach this statement, the instruction type is likely unhandled.
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "this instruction cannot be differentiated");
+ }
+
+ return result;
+}
+
+}
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
new file mode 100644
index 000000000..8e4b7a901
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -0,0 +1,129 @@
+// slang-ir-autodiff-transcriber-base.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+
+struct AutoDiffTranscriberBase
+{
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> instMapD;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ SharedIRBuilder* sharedBuilder;
+
+ AutoDiffTranscriberBase(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : autoDiffSharedContext(shared)
+ , differentiableTypeConformanceContext(shared)
+ , sharedBuilder(inSharedBuilder)
+ , sink(inSink)
+ {
+
+ }
+
+ DiagnosticSink* getSink();
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar);
+
+ void mapDifferentialInst(IRInst* origInst, IRInst* diffInst);
+
+ void mapPrimalInst(IRInst* origInst, IRInst* primalInst);
+
+ IRInst* lookupDiffInst(IRInst* origInst);
+
+ IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst);
+
+ bool hasDifferentialInst(IRInst* origInst);
+
+ bool shouldUseOriginalAsPrimal(IRInst* origInst);
+
+ IRInst* lookupPrimalInst(IRInst* origInst);
+
+ IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst);
+
+ bool hasPrimalInst(IRInst* origInst);
+
+ IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst);
+
+ IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst);
+
+ IRInst* maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst);
+
+ List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath(
+ IRInterfaceType* idiffType, IRInterfaceType* type);
+
+ InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst);
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType);
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness);
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType);
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType);
+
+ IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable);
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType);
+
+ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
+
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
+
+ InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam);
+
+ InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst);
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock);
+
+ // Transcribe a generic definition
+ InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
+
+ IRInst* transcribe(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeInst(IRBuilder* builder, IRInst* origInst);
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
+
+ virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) = 0;
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ virtual InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) = 0;
+
+ virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0;
+
+ virtual IROp getDifferentiableMethodDictionaryItemOp() = 0;
+};
+
+}
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 8dfedcb94..546d5a6ec 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -64,35 +64,6 @@ struct ExtractPrimalFuncContext
return intermediateType;
}
- // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`.
- // For example:
- // ```
- // int f<T>(T a);
- // ```
- // will be extended into
- // ```
- // struct IntermediateFor_f<T> { T t0; }
- // int f_primal<T>(T a, IntermediateFor_f<T> imm);
- // ```
- // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters
- // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list
- // from user generic).
- //
- IRInst* specializeWithGeneric(
- IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric)
- {
- List<IRInst*> genArgs;
- for (auto param : userGeneric->getFirstBlock()->getParams())
- {
- genArgs.add(param);
- }
- return builder.emitSpecializeInst(
- builder.getTypeKind(),
- genericToSpecialize,
- (UInt)genArgs.getCount(),
- genArgs.getBuffer());
- }
-
IRInst* generatePrimalFuncType(
IRGlobalValueWithCode* destFunc, IRGlobalValueWithCode* fwdFunc, IRInst*& outIntermediateType)
{
@@ -505,8 +476,8 @@ IRGlobalValueWithCode* DiffUnzipPass::extractPrimalFunc(
{
innerFunc = as<IRFunc>(findGenericReturnVal(genFunc));
builder.setInsertBefore(innerFunc);
- specializedIntermediateType = context.specializeWithGeneric(builder, intermediateType, genFunc);
- specializedPrimalFunc = context.specializeWithGeneric(builder, primalFunc, genFunc);
+ specializedIntermediateType = specializeWithGeneric(builder, intermediateType, genFunc);
+ specializedPrimalFunc = specializeWithGeneric(builder, primalFunc, genFunc);
}
SLANG_RELEASE_ASSERT(innerFunc);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 3d42f2922..f0ec1542e 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -5,9 +5,7 @@
namespace Slang
{
-
-// TODO: Put into a nameless namespace.
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
{
if (auto witnessTable = as<IRWitnessTable>(witness))
{
@@ -41,6 +39,13 @@ bool isNoDiffType(IRType* paramType)
return false;
}
+IRInst* lookupForwardDerivativeReference(IRInst* primalFunction)
+{
+ if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ return jvpDefinition->getForwardDerivativeFunc();
+ return nullptr;
+}
+
IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key)
{
if (auto irStructType = as<IRStructType>(type))
@@ -277,7 +282,6 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}
-
AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
: moduleInst(inModuleInst)
{
@@ -331,8 +335,6 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
return nullptr;
}
-
-
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
parentFunc = func;
@@ -385,7 +387,6 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
}
}
-
void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())
@@ -398,6 +399,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_ForwardDerivativeDecoration:
case kIROp_DerivativeMemberDecoration:
case kIROp_DifferentiableTypeDictionaryDecoration:
+ case kIROp_DifferentialInstDecoration:
+ case kIROp_MixedDifferentialInstDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -448,6 +451,187 @@ void stripNoDiffTypeAttribute(IRModule* module)
pass.processModule();
}
+struct AutoDiffPass : public InstPassBase
+{
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
+
+ bool processModule()
+ {
+ // TODO(sai): Move this call.
+ forwardTranscriber.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
+
+ IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
+ IRBuilder* builder = &builderStorage;
+
+ // Process all ForwardDifferentiate and BackwardDifferentiate instructions by
+ // generating derivative code for the referenced function.
+ //
+ bool modified = processReferencedFunctions(builder);
+
+ return modified;
+ }
+
+ // Process all differentiate calls, and recursively generate code for forward and backward
+ // derivative functions.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ bool hasChanges = false;
+ for (;;)
+ {
+ bool changed = false;
+ List<IRInst*> autoDiffWorkList;
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ // Only process now if the operand is a materialized function.
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ case kIROp_LookupWitness:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ });
+
+ // Process collected differentiate insts and replace them with placeholders for
+ // differentiated functions.
+
+ for (auto differentiateInst : autoDiffWorkList)
+ {
+ if (auto diffInst = as<IRForwardDifferentiate>(differentiateInst))
+ {
+ IRBuilder subBuilder(*builder);
+ subBuilder.setInsertBefore(differentiateInst);
+ if (auto diffFunc = forwardTranscriber.transcribe(&subBuilder, diffInst->getBaseFn()))
+ {
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ else if (auto backDiffInst = as<IRBackwardDifferentiate>(differentiateInst))
+ {
+ auto baseInst = backDiffInst->getBaseFn();
+ if (auto diffFunc = backwardTranscriber.transcribe(builder, (IRFunc*)baseInst))
+ {
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ }
+
+ // Run transcription logic to generate the body of forward/backward derivatives functions.
+ // While doing so, we may discover new functions to differentiate, so we keep running until
+ // the worklist goes dry.
+ while (autodiffContext->followUpFunctionsToTranscribe.getCount() != 0)
+ {
+ changed = true;
+ auto followUpWorkList = _Move(autodiffContext->followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as<IRFunc>(task.resultFunc);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.originalFunc);
+ SLANG_ASSERT(primalFunc);
+ switch (task.type)
+ {
+ case FuncBodyTranscriptionTaskType::Forward:
+ forwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
+ break;
+ case FuncBodyTranscriptionTaskType::Backward:
+ backwardTranscriber.transcribeFunc(builder, primalFunc, diffFunc);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ if (!changed)
+ break;
+ hasChanges |= changed;
+ }
+ return hasChanges;
+ }
+
+ IRStringLit* getDerivativeFuncName(IRInst* func, const char* postFix)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + postFix).getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + postFix).getUnownedSlice());
+ }
+
+ return name;
+ }
+
+ IRStringLit* getForwardDerivativeFuncName(IRInst* func)
+ {
+ return getDerivativeFuncName(func, "_fwd_diff");
+ }
+
+ IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
+ {
+ return getDerivativeFuncName(func, "_bwd_diff");
+ }
+
+ AutoDiffPass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
+ InstPassBase(context->moduleInst->getModule()),
+ sink(sink),
+ forwardTranscriber(context, context->sharedBuilder, sink),
+ backwardTranscriber(context, context->sharedBuilder, sink),
+ pairBuilderStorage(context),
+ autodiffContext(context)
+ {
+ forwardTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardTranscriber.pairBuilder = &pairBuilderStorage;
+ backwardTranscriber.fwdDiffTranscriber = &forwardTranscriber;
+ }
+
+protected:
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ ForwardDiffTranscriber forwardTranscriber;
+
+ BackwardDiffTranscriber backwardTranscriber;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Shared context.
+ AutoDiffSharedContext* autodiffContext;
+
+ // Builder for dealing with differential pair types.
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+};
+
bool processAutodiffCalls(
IRModule* module,
DiagnosticSink* sink,
@@ -468,11 +652,9 @@ bool processAutodiffCalls(
autodiffContext.sharedBuilder = &sharedBuilder;
- // Process forward derivative calls.
- modified |= processForwardDerivativeCalls(&autodiffContext, sink);
+ AutoDiffPass pass(&autodiffContext, sink);
- // Process reverse derivative calls.
- modified |= processReverseDerivativeCalls(&autodiffContext, sink);
+ modified |= pass.processModule();
return modified;
}
@@ -505,5 +687,4 @@ bool finalizeAutoDiffPass(IRModule* module)
return false;
}
-
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 25cbe16f4..e0508cef7 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -13,6 +13,39 @@
namespace Slang
{
+template<typename P, typename D>
+struct DiffInstPair
+{
+ P primal;
+ D differential;
+ DiffInstPair() = default;
+ DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
+ {}
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+ bool operator ==(const DiffInstPair& other) const
+ {
+ return primal == other.primal && differential == other.differential;
+ }
+};
+
+typedef DiffInstPair<IRInst*, IRInst*> InstPair;
+
+enum class FuncBodyTranscriptionTaskType
+{
+ Forward, Backward, Primal
+};
+
+struct FuncBodyTranscriptionTask
+{
+ FuncBodyTranscriptionTaskType type;
+ IRFunc* originalFunc;
+ IRFunc* resultFunc;
+};
struct AutoDiffSharedContext
{
@@ -58,6 +91,7 @@ struct AutoDiffSharedContext
//
bool isInterfaceAvailable = false;
+ List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe;
AutoDiffSharedContext(IRModuleInst* inModuleInst);
@@ -195,10 +229,10 @@ struct DifferentialPairTypeBuilder
void stripAutoDiffDecorations(IRModule* module);
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
-
bool isNoDiffType(IRType* paramType);
+IRInst* lookupForwardDerivativeReference(IRInst* primalFunction);
+
struct IRAutodiffPassOptions
{
// Nothing for now...
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5c4590abe..81b5d636a 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -129,4 +129,18 @@ IROp getTypeStyle(BaseType op)
}
}
+IRInst* specializeWithGeneric(IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric)
+{
+ List<IRInst*> genArgs;
+ for (auto param : userGeneric->getFirstBlock()->getParams())
+ {
+ genArgs.add(param);
+ }
+ return builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ genericToSpecialize,
+ (UInt)genArgs.getCount(),
+ genArgs.getBuffer());
+}
+
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 385d05b28..2087ee4a7 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -44,6 +44,30 @@ inline bool isChildInstOf(IRInst* inst, IRInst* parent)
return false;
}
+ // Specialize `genericToSpecialize` with the generic parameters defined in `userGeneric`.
+ // For example:
+ // ```
+ // int f<T>(T a);
+ // ```
+ // will be extended into
+ // ```
+ // struct IntermediateFor_f<T> { T t0; }
+ // int f_primal<T>(T a, IntermediateFor_f<T> imm);
+ // ```
+ // Given a user generic `f_primal<T>` and a used value parameterized on the same set of generic parameters
+ // `IntermediateFor_f`, `genericToSpecialize` constructs `IntermediateFor_f<T>` (using the parameter list
+ // from user generic).
+ //
+IRInst* specializeWithGeneric(
+ IRBuilder& builder, IRInst* genericToSpecialize, IRGeneric* userGeneric);
+
+
+inline IRInst* unwrapAttributedType(IRInst* type)
+{
+ while (auto attrType = as<IRAttributedType>(type))
+ type = attrType->getBaseType();
+ return type;
+}
}
#endif