From 6178cb601368e977c4aa82e0ae25b8eb1e875d84 Mon Sep 17 00:00:00 2001
From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>
Date: Tue, 22 Nov 2022 12:36:28 -0500
Subject: Refactor Auto-diff passes (#2526)
* Initial refactor
* Refactor passes tests
* Removed Differential Bottom references from the IR side
---
build/visual-studio/slang/slang.vcxproj | 10 +-
build/visual-studio/slang/slang.vcxproj.filters | 30 +-
source/slang/slang-emit.cpp | 7 +-
source/slang/slang-ir-autodiff-fwd.cpp | 1761 ++++++++++++
source/slang/slang-ir-autodiff-fwd.h | 43 +
source/slang/slang-ir-autodiff-pairs.cpp | 182 ++
source/slang/slang-ir-autodiff-pairs.h | 21 +
source/slang/slang-ir-autodiff-rev.cpp | 832 ++++++
source/slang/slang-ir-autodiff-rev.h | 25 +
source/slang/slang-ir-autodiff.cpp | 408 +++
source/slang/slang-ir-autodiff.h | 210 ++
source/slang/slang-ir-check-differentiability.cpp | 2 +-
source/slang/slang-ir-diff-jvp.cpp | 3197 ---------------------
source/slang/slang-ir-diff-jvp.h | 174 --
source/slang/slang-ir-link.cpp | 2 +-
source/slang/slang-lower-to-ir.cpp | 2 +-
16 files changed, 3519 insertions(+), 3387 deletions(-)
create mode 100644 source/slang/slang-ir-autodiff-fwd.cpp
create mode 100644 source/slang/slang-ir-autodiff-fwd.h
create mode 100644 source/slang/slang-ir-autodiff-pairs.cpp
create mode 100644 source/slang/slang-ir-autodiff-pairs.h
create mode 100644 source/slang/slang-ir-autodiff-rev.cpp
create mode 100644 source/slang/slang-ir-autodiff-rev.h
create mode 100644 source/slang/slang-ir-autodiff.cpp
create mode 100644 source/slang/slang-ir-autodiff.h
delete mode 100644 source/slang/slang-ir-diff-jvp.cpp
delete mode 100644 source/slang/slang-ir-diff-jvp.h
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 5244f0efb..fe1922d29 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -333,6 +333,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
+
+
+
+
@@ -343,7 +347,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
-
@@ -505,6 +508,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
+
+
+
+
@@ -516,7 +523,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
-
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 0a2d1fe3f..6fa42287c 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -132,6 +132,18 @@
Header Files
+
+ Header Files
+
+
+ Header Files
+
+
+ Header Files
+
+
+ Header Files
+
Header Files
@@ -162,9 +174,6 @@
Header Files
-
- Header Files
-
Header Files
@@ -644,6 +653,18 @@
Source Files
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
+
+ Source Files
+
Source Files
@@ -677,9 +698,6 @@
Source Files
-
- Source Files
-
Source Files
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 69ea29c7a..ca55a68bc 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -11,7 +11,7 @@
#include "slang-ir-cleanup-void.h"
#include "slang-ir-dce.h"
#include "slang-ir-diff-call.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
@@ -377,10 +377,7 @@ Result linkAndOptimizeIR(
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
- // Process higher-order calles to auto-diff passes.
- processDifferentiableFuncs(irModule, sink);
-
- stripAutoDiffDecorations(irModule);
+ processAutodiffCalls(irModule, sink);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
new file mode 100644
index 000000000..03e81c5b5
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -0,0 +1,1761 @@
+// slang-ir-autodiff-fwd.cpp
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+
+struct JVPTranscriber
+{
+
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary instMapD;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List followUpFunctionsToTranscribe;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary differentialPairTypes;
+
+ JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
+ : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
+ {
+
+ }
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
+
+ void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
+ {
+ if (hasDifferentialInst(origInst))
+ {
+ if (lookupDiffInst(origInst) != diffInst)
+ {
+ SLANG_UNEXPECTED("Inconsistent differential mappings");
+ }
+ }
+ else
+ {
+ instMapD.Add(origInst, diffInst);
+ }
+ }
+
+ void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
+ {
+ if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
+ {
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "inconsistent primal instruction for original");
+ }
+ else
+ {
+ cloneEnv.mapOldValToNew[origInst] = primalInst;
+ }
+ }
+
+ IRInst* lookupDiffInst(IRInst* origInst)
+ {
+ return instMapD[origInst];
+ }
+
+ IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
+ {
+ return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
+ }
+
+ bool hasDifferentialInst(IRInst* origInst)
+ {
+ return instMapD.ContainsKey(origInst);
+ }
+
+ IRInst* lookupPrimalInst(IRInst* origInst)
+ {
+ return cloneEnv.mapOldValToNew[origInst];
+ }
+
+ IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
+ {
+ return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
+ }
+
+ bool hasPrimalInst(IRInst* origInst)
+ {
+ return cloneEnv.mapOldValToNew.ContainsKey(origInst);
+ }
+
+ IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
+ {
+ if (!hasDifferentialInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasDifferentialInst(origInst));
+ }
+
+ return lookupDiffInst(origInst);
+ }
+
+ IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
+ {
+ if (!hasPrimalInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasPrimalInst(origInst));
+ }
+
+ return lookupPrimalInst(origInst);
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List newParameterTypes;
+ IRType* diffReturnType;
+
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto origType = funcType->getParamType(i);
+ origType = (IRType*) lookupPrimalInst(origType, origType);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ newParameterTypes.add(diffPairType);
+ else
+ newParameterTypes.add(origType);
+ }
+
+ // Transcribe return type to a pair.
+ // This will be void if the primal return type is non-differentiable.
+ //
+ auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
+ if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
+ diffReturnType = returnPairType;
+ else
+ diffReturnType = builder->getVoidType();
+
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = differentiateType(&builder, diffPairType);
+
+ // And place it in the synthesized witness table.
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ // Record this in the context for future lookups
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+
+ if (!witness)
+ {
+ if (auto primalPairType = as(primalType))
+ {
+ witness = getDifferentialPairWitness(primalPairType);
+ }
+ }
+
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ {
+ IRInst* diffType = nullptr;
+ if (!instMapD.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ instMapD[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
+ if (auto ptrType = as(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = lookupPrimalInst(origType, origType);
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as(primalType);
+ List diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+ }
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+ {
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+ }
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
+ // Do not differentiate generic type (and witness table) parameters
+ if (as(primalDataType) || as(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Is this param a phi node or a function parameter?
+ auto func = as(origParam->getParent()->getParent());
+ bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
+ if (isFuncParam)
+ {
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as(diffPairType))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ else if (auto pairPtrType = as(diffPairType))
+ {
+ auto ptrInnerPairType = as(pairPtrType->getValueType());
+
+ return InstPair(
+ builder->emitDifferentialPairAddressPrimal(diffPairParam),
+ builder->emitDifferentialPairAddressDifferential(
+ builder->getPtrType(
+ kIROp_PtrType,
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
+ diffPairParam));
+ }
+ }
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+ else
+ {
+ auto primal = cloneInst(&cloneEnv, builder, origParam);
+ IRInst* diff = nullptr;
+ if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
+ {
+ diff = builder->emitParam(diffType);
+ }
+ return InstPair(primal, diff);
+ }
+
+ }
+
+ // Returns "d" to use as a name hint for variables and parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String getJVPVarName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration())
+ {
+ return ("d" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+ // Returns "dp" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+ InstPair transcribeVar(IRBuilder* builder, IRVar* origVar)
+ {
+ if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
+ {
+ IRVar* diffVar = builder->emitVar(diffType);
+ SLANG_ASSERT(diffVar);
+
+ auto diffNameHint = getJVPVarName(origVar);
+ if (diffNameHint.getLength() > 0)
+ builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());
+
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
+ }
+
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
+ }
+
+ InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
+
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+
+ auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
+ auto primalRight = findOrTranscribePrimalInst(builder, origRight);
+
+ auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
+ auto diffRight = findOrTranscribeDiffInst(builder, origRight);
+
+
+ if (diffLeft || diffRight)
+ {
+ diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
+ diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
+
+ auto resultType = primalArith->getDataType();
+ switch(origArith->getOp())
+ {
+ case kIROp_Add:
+ return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
+ case kIROp_Mul:
+ return InstPair(primalArith, builder->emitAdd(resultType,
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)));
+ case kIROp_Sub:
+ return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
+ case kIROp_Div:
+ return InstPair(primalArith, builder->emitDiv(resultType,
+ builder->emitSub(
+ resultType,
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)),
+ builder->emitMul(
+ primalRight->getDataType(), primalRight, primalRight
+ )));
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+ }
+
+ return InstPair(primalArith, nullptr);
+ }
+
+
+ InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+ {
+ SLANG_ASSERT(origLogic->getOperandCount() == 2);
+
+ // TODO: Check other boolean cases.
+ if (as(origLogic->getDataType()))
+ {
+ // Boolean operations are not differentiable. For the linearization
+ // pass, we do not need to do anything but copy them over to the ne
+ // function.
+ auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
+ return InstPair(primalLogic, nullptr);
+ }
+
+ SLANG_UNEXPECTED("Logical operation with non-boolean result");
+ }
+
+ InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
+ {
+ auto origPtr = origLoad->getPtr();
+ auto primalPtr = lookupPrimalInst(origPtr, nullptr);
+ auto primalPtrValueType = as(primalPtr->getFullType())->getValueType();
+
+ if (auto diffPairType = as(primalPtrValueType))
+ {
+ // Special case load from an `out` param, which will not have corresponding `diff` and
+ // `primal` insts yet.
+
+ auto load = builder->emitLoad(primalPtr);
+ auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ return InstPair(primalElement, diffElement);
+ }
+
+ auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
+ if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
+ {
+ // Default case, we're loading from a known differential inst.
+ diffLoad = as(builder->emitLoad(diffPtr));
+ }
+ return InstPair(primalLoad, diffLoad);
+ }
+
+ InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
+ {
+ IRInst* origStoreLocation = origStore->getPtr();
+ IRInst* origStoreVal = origStore->getVal();
+ auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
+ auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
+ auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
+ auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+
+ if (!diffStoreLocation)
+ {
+ auto primalLocationPtrType = as(primalStoreLocation->getDataType());
+ if (auto diffPairType = as(primalLocationPtrType->getValueType()))
+ {
+ auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
+ auto store = builder->emitStore(primalStoreLocation, valToStore);
+ return InstPair(store, nullptr);
+ }
+ }
+
+ auto primalStore = cloneInst(&cloneEnv, builder, origStore);
+
+ IRInst* diffStore = nullptr;
+
+ // If the stored value has a differential version,
+ // emit a store instruction for the differential parameter.
+ // Otherwise, emit nothing since there's nothing to load.
+ //
+ if (diffStoreLocation && diffStoreVal)
+ {
+ // Default case, storing the entire type (and not a member)
+ diffStore = as(
+ builder->emitStore(diffStoreLocation, diffStoreVal));
+
+ return InstPair(primalStore, diffStore);
+ }
+
+ return InstPair(primalStore, nullptr);
+ }
+
+ InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
+ {
+ IRInst* origReturnVal = origReturn->getVal();
+
+ auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
+ if (as(origReturnVal) || as(origReturnVal) || as(origReturnVal) || as(origReturnVal))
+ {
+ // If the return value is itself a function, generic or a struct then this
+ // is likely to be a generic scope. In this case, we lookup the differential
+ // and return that.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+
+ // Neither of these should be nullptr.
+ SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
+ IRReturn* diffReturn = as(builder->emitReturn(diffReturnVal));
+
+ return InstPair(diffReturn, diffReturn);
+ }
+ else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
+ {
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+ if(!diffReturnVal)
+ diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
+
+ // If the pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffReturnVal);
+
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
+ IRReturn* pairReturn = as(builder->emitReturn(diffPair));
+ return InstPair(pairReturn, pairReturn);
+ }
+ else
+ {
+ // If the return type is not differentiable, emit the primal value only.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+
+ IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ return InstPair(primalReturn, nullptr);
+
+ }
+ }
+
+ // Since int/float literals are sometimes nested inside an IRConstructor
+ // instruction, we check to make sure that the nested instr is a constant
+ // and then return nullptr. Literals do not need to be differentiated.
+ //
+ InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
+ {
+ IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+
+ // Check if the output type can be differentiated. If it cannot be
+ // differentiated, don't differentiate the inst
+ //
+ auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
+ if (auto diffConstructType = differentiateType(builder, primalConstructType))
+ {
+ UCount operandCount = origConstruct->getOperandCount();
+
+ List diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with
+ // the differential. Otherwise, use a zero.
+ //
+ if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ {
+ auto operandDataType = origConstruct->getOperand(ii)->getDataType();
+ operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
+ }
+
+ return InstPair(
+ primalConstruct,
+ builder->emitIntrinsicInst(
+ diffConstructType,
+ origConstruct->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+ }
+ else
+ {
+ return InstPair(primalConstruct, nullptr);
+ }
+ }
+
+ // Differentiating a call instruction here is primarily about generating
+ // an appropriate call list based on whichever parameters have differentials
+ // in the current transcription context.
+ //
+ InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
+ {
+
+ IRInst* origCallee = origCall->getCallee();
+
+ if (!origCallee)
+ {
+ // Note that this can only happen if the callee is a result
+ // of a higher-order operation. For now, we assume that we cannot
+ // differentiate such calls safely.
+ // TODO(sai): Should probably get checked in the front-end.
+ //
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "attempting to differentiate unresolved callee");
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
+
+ IRInst* diffCallee = nullptr;
+
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
+ }
+ else if (primalCallee->findDecoration())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitForwardDifferentiateInst(
+ differentiateFunctionType(builder, as(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
+ List args;
+ // Go over the parameter list and create pairs for each input (if required)
+ for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
+ {
+ auto origArg = origCall->getArg(ii);
+ auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ SLANG_ASSERT(primalArg);
+
+ auto primalType = primalArg->getDataType();
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
+ if (auto pairType = tryGetDiffPairType(builder, primalType))
+ {
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
+ args.add(diffPair);
+ }
+ else
+ {
+ // Add original/primal argument.
+ args.add(primalArg);
+ }
+ }
+
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+
+ if (!diffReturnType)
+ {
+ SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
+ diffReturnType = builder->getVoidType();
+ }
+
+ auto callInst = builder->emitCallInst(
+ diffReturnType,
+ diffCallee,
+ args);
+
+ if (diffReturnType->getOp() != kIROp_VoidType)
+ {
+ IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
+ auto diffType = differentiateType(builder, origCall->getFullType());
+ IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
+ return InstPair(primalResultValue, diffResultValue);
+ }
+ else
+ {
+ // Return the inst itself if the return value is void.
+ // This is fine since these values should never actually be used anywhere.
+ //
+ return InstPair(callInst, callInst);
+ }
+ }
+
+ InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
+ {
+ IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
+
+ if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
+ {
+ List swizzleIndices;
+ for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+ swizzleIndices.add(origSwizzle->getElementIndex(ii));
+
+ return InstPair(
+ primalSwizzle,
+ builder->emitSwizzle(
+ differentiateType(builder, primalSwizzle->getDataType()),
+ diffBase,
+ origSwizzle->getElementCount(),
+ swizzleIndices.getBuffer()));
+ }
+
+ return InstPair(primalSwizzle, nullptr);
+ }
+
+ InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
+ {
+ IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
+
+ UCount operandCount = origInst->getOperandCount();
+
+ List diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with the
+ // differential.
+ // Otherwise, abandon the differentiation attempt and assume that origInst
+ // cannot (or does not need to) be differentiated.
+ //
+ if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ return InstPair(primalInst, nullptr);
+ }
+
+ return InstPair(
+ primalInst,
+ builder->emitIntrinsicInst(
+ differentiateType(builder, primalInst->getDataType()),
+ origInst->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+ }
+
+ InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
+ {
+ switch(origInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ case kIROp_loop:
+ auto origBranch = as(origInst);
+
+ // Grab the differentials for any phi nodes.
+ List newArgs;
+ for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
+ {
+ auto origArg = origBranch->getArg(ii);
+ auto primalArg = lookupPrimalInst(origArg);
+ newArgs.add(primalArg);
+
+ if (differentiateType(builder, primalArg->getDataType()))
+ {
+ auto diffArg = lookupDiffInst(origArg, nullptr);
+ if (diffArg)
+ newArgs.add(diffArg);
+ }
+ }
+
+ IRInst* diffBranch = nullptr;
+ if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
+ {
+ if (auto origLoop = as(origInst))
+ {
+ auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+ auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+ List operands;
+ operands.add(breakBlock);
+ operands.add(continueBlock);
+ operands.addRange(newArgs);
+ diffBranch = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ operands.getCount(),
+ operands.getBuffer());
+ }
+ else
+ {
+ diffBranch = builder->emitBranch(
+ as(diffBlock),
+ newArgs.getCount(),
+ newArgs.getBuffer());
+ }
+ }
+
+ // For now, every block in the original fn must have a corresponding
+ // block to compute *both* primals and derivatives (i.e linearized block)
+ SLANG_ASSERT(diffBranch);
+
+ return InstPair(diffBranch, diffBranch);
+ }
+
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled control flow");
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ InstPair transcribeConst(IRBuilder* builder, IRInst* origInst)
+ {
+ switch(origInst->getOp())
+ {
+ case kIROp_FloatLit:
+ return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
+ case kIROp_VoidLit:
+ return InstPair(origInst, origInst);
+ case kIROp_IntLit:
+ return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
+ }
+
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled const type");
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize)
+ {
+ // In general, we should not see any specialize insts at this stage.
+ // The exceptions are target intrinsics.
+ auto genericInnerVal = findInnerMostGenericReturnVal(as(origSpecialize->getBase()));
+ if (genericInnerVal->findDecoration())
+ {
+ // Look for an IRForwardDerivativeDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
+ //
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration())
+ {
+ auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
+
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as(jvpFunc));
+
+ return InstPair(jvpFunc, jvpFunc);
+ }
+ }
+ else
+ {
+ getSink()->diagnose(origSpecialize->sourceLoc,
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
+ }
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
+ {
+ // This is slightly counter-intuitive, but we don't perform any differentiation
+ // logic here. We simple clone the original lookup which points to the original function,
+ // or the cloned version in case we're inside a generic scope.
+ // The differentiation logic is inserted later when this is used in an IRCall.
+ // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table))
+ // rather than have Lookup(ForwardDifferentiate(Table))
+ //
+ auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
+ return InstPair(diffLookup, diffLookup);
+ }
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRInst* diffBlock = subBuilder.emitBlock();
+
+ // Note: for blocks, we setup the mapping _before_
+ // processing the children since we could encounter
+ // a lookup while processing the children.
+ //
+ mapPrimalInst(origBlock, diffBlock);
+ mapDifferentialInst(origBlock, diffBlock);
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->transcribe(&subBuilder, param);
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->transcribe(&subBuilder, child);
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
+ {
+ SLANG_ASSERT(as(originalInst) || as(originalInst));
+
+ IRInst* origBase = originalInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto field = originalInst->getOperand(1);
+ auto derivativeRefDecor = field->findDecoration();
+ auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
+
+ IRInst* primalOperands[] = { primalBase, field };
+ IRInst* primalFieldExtract = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 2,
+ primalOperands);
+
+ if (!derivativeRefDecor)
+ {
+ return InstPair(primalFieldExtract, nullptr);
+ }
+
+ IRInst* diffFieldExtract = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
+ diffFieldExtract = builder->emitIntrinsicInst(
+ diffType,
+ originalInst->getOp(),
+ 2,
+ diffOperands);
+ }
+ }
+ return InstPair(primalFieldExtract, diffFieldExtract);
+ }
+
+ InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
+ {
+ SLANG_ASSERT(as(origGetElementPtr) || as(origGetElementPtr));
+
+ IRInst* origBase = origGetElementPtr->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
+
+ auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
+
+ IRInst* primalOperands[] = {primalBase, primalIndex};
+ IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
+ primalType,
+ origGetElementPtr->getOp(),
+ 2,
+ primalOperands);
+
+ IRInst* diffGetElementPtr = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ IRInst* diffOperands[] = {diffBase, primalIndex};
+ diffGetElementPtr = builder->emitIntrinsicInst(
+ diffType,
+ origGetElementPtr->getOp(),
+ 2,
+ diffOperands);
+ }
+ }
+
+ return InstPair(primalGetElementPtr, diffGetElementPtr);
+ }
+
+ InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
+ {
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body)
+ auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
+
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+
+
+ List diffLoopOperands;
+ diffLoopOperands.add(diffTargetBlock);
+ diffLoopOperands.add(diffBreakBlock);
+ diffLoopOperands.add(diffContinueBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
+ diffLoopOperands.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ diffLoopOperands.getCount(),
+ diffLoopOperands.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+ {
+ // IfElse Statements come with 4 blocks. We transcribe each block into it's
+ // linear form, and then wire them up in the same way as the original if-else
+
+ // Transcribe condition block
+ auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
+ SLANG_ASSERT(primalConditionBlock);
+
+ // Transcribe 'true' block (condition block branches into this if true)
+ auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
+ SLANG_ASSERT(diffTrueBlock);
+
+ // Transcribe 'false' block (condition block branches into this if true)
+ // TODO (sai): What happens if there's no false block?
+ auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
+ SLANG_ASSERT(diffFalseBlock);
+
+ // Transcribe 'after' block (true and false blocks branch into this)
+ auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
+ SLANG_ASSERT(diffAfterBlock);
+
+ List diffIfElseArgs;
+ diffIfElseArgs.add(primalConditionBlock);
+ diffIfElseArgs.add(diffTrueBlock);
+ diffIfElseArgs.add(diffFalseBlock);
+ diffIfElseArgs.add(diffAfterBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
+ diffIfElseArgs.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_ifElse,
+ diffIfElseArgs.getCount(),
+ diffIfElseArgs.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
+ {
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalVal);
+ auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffPrimalVal);
+ auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalDiffVal);
+ auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffDiffVal);
+
+ auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
+ auto diffPair = builder->emitMakeDifferentialPair(
+ differentiateType(builder, origInst->getDataType()),
+ primalDiffVal,
+ diffDiffVal);
+ return InstPair(primalPair, diffPair);
+ }
+
+ InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
+ {
+ SLANG_ASSERT(
+ origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
+ origInst->getOp() == kIROp_DifferentialPairGetPrimal);
+
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(primalVal);
+
+ auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(diffVal);
+
+ auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+
+ auto diffValPairType = as(diffVal->getDataType());
+ IRInst* diffResultType = nullptr;
+ if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
+ diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
+ else
+ diffResultType = diffValPairType->getValueType();
+ auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
+ return InstPair(primalResult, diffResult);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ primalFunc = origFunc;
+
+ auto diffFunc = builder.createFunc();
+
+ SLANG_ASSERT(as(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ if (auto nameHint = origFunc->findDecoration())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_fwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addForwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder.addForwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ // Transcribe a generic definition
+ InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
+ {
+ auto innerVal = findInnerMostGenericReturnVal(origGeneric);
+ if (auto innerFunc = as(innerVal))
+ {
+ differentiableTypeConformanceContext.setFunc(innerFunc);
+ }
+ else
+ {
+ return InstPair(origGeneric, nullptr);
+ }
+
+ // For now, we assume there's only one generic layer. So this inst must be top level
+ bool isTopLevel = (as(origGeneric->getParent()) != nullptr);
+ SLANG_RELEASE_ASSERT(isTopLevel);
+
+ IRGeneric* primalGeneric = origGeneric;
+
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origGeneric);
+
+ auto diffGeneric = builder.emitGeneric();
+
+ // Process type of generic. If the generic is a function, then it's type will also be a
+ // generic and this logic will transcribe that generic first before continuing with the
+ // function itself.
+ //
+ auto primalType = primalGeneric->getFullType();
+
+ IRType* diffType = nullptr;
+ if (primalType)
+ {
+ diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType);
+ }
+
+ diffGeneric->setFullType(diffType);
+
+ // TODO(sai): Replace naming scheme
+ // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
+ // builder->addNameHintDecoration(diffFunc, jvpName);
+
+ // Transcribe children from origFunc into diffFunc.
+ builder.setInsertInto(diffGeneric);
+ for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(&builder, block);
+
+ return InstPair(primalGeneric, diffGeneric);
+ }
+
+ IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
+ {
+ // If a differential intstruction is already mapped for
+ // this original inst, return that.
+ //
+ if (auto diffInst = lookupDiffInst(origInst, nullptr))
+ {
+ SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ return diffInst;
+ }
+
+ // Otherwise, dispatch to the appropriate method
+ // depending on the op-code.
+ //
+ instsInProgress.Add(origInst);
+ InstPair pair = transcribeInst(builder, origInst);
+
+ if (auto primalInst = pair.primal)
+ {
+ mapPrimalInst(origInst, pair.primal);
+ mapDifferentialInst(origInst, pair.differential);
+ if (pair.differential)
+ {
+ switch (pair.differential->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Generic:
+ case kIROp_Block:
+ // Don't generate again for these.
+ // Functions already have their names generated in `transcribeFuncHeader`.
+ break;
+ default:
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ }
+ break;
+ }
+ }
+ return pair.differential;
+ }
+ instsInProgress.Remove(origInst);
+
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "failed to transcibe instruction");
+ return nullptr;
+ }
+
+ InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParam(builder, as(origInst));
+
+ case kIROp_Var:
+ return transcribeVar(builder, as(origInst));
+
+ case kIROp_Load:
+ return transcribeLoad(builder, as(origInst));
+
+ case kIROp_Store:
+ return transcribeStore(builder, as(origInst));
+
+ case kIROp_Return:
+ return transcribeReturn(builder, as(origInst));
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArith(builder, origInst);
+
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ return transcribeBinaryLogic(builder, origInst);
+
+ case kIROp_Construct:
+ return transcribeConstruct(builder, origInst);
+
+ case kIROp_Call:
+ return transcribeCall(builder, as(origInst));
+
+ case kIROp_swizzle:
+ return transcribeSwizzle(builder, as(origInst));
+
+ case kIROp_constructVectorFromScalar:
+ case kIROp_MakeTuple:
+ return transcribeByPassthrough(builder, origInst);
+
+ case kIROp_unconditionalBranch:
+ return transcribeControlFlow(builder, origInst);
+
+ case kIROp_FloatLit:
+ case kIROp_IntLit:
+ case kIROp_VoidLit:
+ return transcribeConst(builder, origInst);
+
+ case kIROp_Specialize:
+ return transcribeSpecialize(builder, as(origInst));
+
+ case kIROp_lookup_interface_method:
+ return transcibeLookupInterfaceMethod(builder, as(origInst));
+
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ return transcribeFieldExtract(builder, origInst);
+ case kIROp_getElement:
+ case kIROp_getElementPtr:
+ return transcribeGetElement(builder, origInst);
+
+ case kIROp_loop:
+ return transcribeLoop(builder, as(origInst));
+
+ case kIROp_ifElse:
+ return transcribeIfElse(builder, as(origInst));
+
+ case kIROp_MakeDifferentialPair:
+ return transcribeMakeDifferentialPair(builder, as(origInst));
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferential:
+ return transcribeDifferentialPairGetElement(builder, origInst);
+ }
+
+ // If none of the cases have been hit, check if the instruction is a
+ // type. Only need to explicitly differentiate types if they appear inside a block.
+ //
+ if (auto origType = as(origInst))
+ {
+ // If this is a generic type, transcibe the parent
+ // generic and derive the type from the transcribed generic's
+ // return value.
+ //
+ if (as(origType->getParent()->getParent()) &&
+ findInnerMostGenericReturnVal(as(origType->getParent()->getParent())) == origType &&
+ !instsInProgress.Contains(origType->getParent()->getParent()))
+ {
+ auto origGenericType = origType->getParent()->getParent();
+ auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
+ auto innerDiffGenericType = findInnerMostGenericReturnVal(as(diffGenericType));
+ return InstPair(
+ origGenericType,
+ innerDiffGenericType
+ );
+ }
+ else if (as(origType->getParent()))
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ differentiateType(builder, origType));
+ else
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ nullptr);
+ }
+
+ // Handle instructions with children
+ switch (origInst->getOp())
+ {
+ case kIROp_Func:
+ return transcribeFuncHeader(builder, as(origInst));
+
+ case kIROp_Block:
+ return transcribeBlock(builder, as(origInst));
+
+ case kIROp_Generic:
+ return transcribeGeneric(builder, as(origInst));
+ }
+
+
+ // If we reach this statement, the instruction type is likely unhandled.
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "this instruction cannot be differentiated");
+
+ return InstPair(nullptr, nullptr);
+ }
+};
+
+struct ForwardDerivativePass : public InstPassBase
+{
+
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
+
+ bool processModule()
+ {
+ // TODO(sai): Move this call.
+ transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
+
+ IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
+ IRBuilder* builder = &builderStorage;
+
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
+ // generating derivative code for the referenced function.
+ //
+ bool modified = processReferencedFunctions(builder);
+
+ return modified;
+ }
+
+ IRInst* lookupJVPReference(IRInst* primalFunction)
+ {
+ if (auto jvpDefinition = primalFunction->findDecoration())
+ return jvpDefinition->getForwardDerivativeFunc();
+ return nullptr;
+ }
+
+ // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
+ // then check that the referenced function is marked correctly for differentiation.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ List autoDiffWorkList;
+
+ for (;;)
+ {
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ // Only process now if the operand is a materialized function.
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ });
+
+ if (autoDiffWorkList.getCount() == 0)
+ break;
+
+ // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
+ // differentiated functions.
+
+ transcriberStorage.followUpFunctionsToTranscribe.clear();
+
+ for (auto differentiateInst : autoDiffWorkList)
+ {
+ IRInst* baseInst = differentiateInst->getOperand(0);
+ if (as(differentiateInst))
+ {
+ if (auto existingDiffFunc = lookupJVPReference(baseInst))
+ {
+ differentiateInst->replaceUsesWith(existingDiffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else if (isMarkedForForwardDifferentiation(baseInst))
+ {
+ if (as(baseInst) || as(baseInst))
+ {
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst);
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Requested differentiation on a function that isn't marked as differentiable.");
+ }
+
+ }
+ }
+ // Actually synthesize the derivatives.
+ List followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ }
+
+ // Transcribing the function body really shouldn't produce more follow up function body work.
+ // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
+ // in the next iteration.
+ SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
+
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
+ //
+ bool isMarkedForForwardDifferentiation(IRInst* callable)
+ {
+ return callable->findDecoration() != nullptr;
+ }
+
+ IRStringLit* getForwardDerivativeFuncName(IRInst* func)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice());
+ }
+
+ return name;
+ }
+
+ ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
+ InstPassBase(context->moduleInst->getModule()),
+ sink(sink),
+ transcriberStorage(context, context->sharedBuilder),
+ pairBuilderStorage(context),
+ autodiffContext(context)
+ {
+ transcriberStorage.sink = sink;
+ transcriberStorage.autoDiffSharedContext = context;
+ transcriberStorage.pairBuilder = &(pairBuilderStorage);
+ }
+
+protected:
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ JVPTranscriber transcriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Shared context.
+ AutoDiffSharedContext* autodiffContext;
+
+ // Builder for dealing with differential pair types.
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+};
+
+// Set up context and call main process method.
+//
+bool processForwardDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ ForwardDerivativePassOptions const&)
+{
+ ForwardDerivativePass fwdPass(autodiffContext, sink);
+ bool changed = fwdPass.processModule();
+ return changed;
+}
+
+
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+ parentFunc = func;
+
+ auto decor = func->findDecoration();
+ SLANG_RELEASE_ASSERT(decor);
+
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as(child))
+ {
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
+ {
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+ }
+ }
+}
+
+
+// Lookup a witness table for the concreteType. One should exist if concreteType
+// inherits (successfully) from IDifferentiable.
+//
+
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
+}
+
+void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
+{
+ for (auto globalInst : sharedContext->moduleInst->getChildren())
+ {
+ if (auto pairType = as(globalInst))
+ {
+ differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
+ }
+ }
+}
+}
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
new file mode 100644
index 000000000..6b261ecd0
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -0,0 +1,43 @@
+// slang-ir-autodiff-fwd.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+namespace Slang
+{
+
+ template
+ struct DiffInstPair
+ {
+ P primal;
+ D differential;
+ DiffInstPair() = default;
+ DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
+ {}
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+ bool operator ==(const DiffInstPair& other) const
+ {
+ return primal == other.primal && differential == other.differential;
+ }
+ };
+
+ typedef DiffInstPair InstPair;
+
+ struct ForwardDerivativePassOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processForwardDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions());
+
+}
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
new file mode 100644
index 000000000..1dbb1bd7c
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -0,0 +1,182 @@
+#include "slang-ir-autodiff-pairs.h"
+
+namespace Slang
+{
+
+struct DiffPairLoweringPass : InstPassBase
+{
+ DiffPairLoweringPass(AutoDiffSharedContext* context) :
+ InstPassBase(context->moduleInst->getModule()),
+ pairBuilderStorage(context),
+ autodiffContext(context)
+ {
+ pairBuilder = &pairBuilderStorage;
+ }
+
+ IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
+ {
+ builder->setInsertBefore(pairType);
+ auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType(
+ builder,
+ pairType);
+ if (isTrivial)
+ *isTrivial = loweredPairTypeInfo.isTrivial;
+ return loweredPairTypeInfo.loweredType;
+ }
+
+ IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
+ {
+
+ if (auto makePairInst = as(inst))
+ {
+ bool isTrivial = false;
+ auto pairType = as(makePairInst->getDataType());
+ if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
+ {
+ builder->setInsertBefore(makePairInst);
+ IRInst* result = nullptr;
+ if (isTrivial)
+ {
+ result = makePairInst->getPrimalValue();
+ }
+ else
+ {
+ IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
+ result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+ }
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
+ }
+ }
+
+ return nullptr;
+ }
+
+ IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
+ {
+ if (auto getDiffInst = as(inst))
+ {
+ auto pairType = getDiffInst->getBase()->getDataType();
+ if (auto pairPtrType = as(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
+ {
+ builder->setInsertBefore(getDiffInst);
+ IRInst* diffFieldExtract = nullptr;
+ diffFieldExtract = pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ getDiffInst->replaceUsesWith(diffFieldExtract);
+ getDiffInst->removeAndDeallocate();
+ return diffFieldExtract;
+ }
+ }
+ else if (auto getPrimalInst = as(inst))
+ {
+ auto pairType = getPrimalInst->getBase()->getDataType();
+ if (auto pairPtrType = as(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
+ {
+ builder->setInsertBefore(getPrimalInst);
+
+ IRInst* primalFieldExtract = nullptr;
+ primalFieldExtract = pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ getPrimalInst->replaceUsesWith(primalFieldExtract);
+ getPrimalInst->removeAndDeallocate();
+ return primalFieldExtract;
+ }
+ }
+
+ return nullptr;
+ }
+
+ bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
+ {
+ bool modified = false;
+ // Hoist all pair types to global scope when possible.
+ auto moduleInst = module->getModuleInst();
+ processInstsOfType(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
+ {
+ if (originalPairType->parent != moduleInst)
+ {
+ originalPairType->removeFromParent();
+ ShortList operands;
+ for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ {
+ operands.add(originalPairType->getOperand(i));
+ }
+ auto newPairType = builder->findOrEmitHoistableInst(
+ originalPairType->getFullType(),
+ originalPairType->getOp(),
+ originalPairType->getOperandCount(),
+ operands.getArrayView().getBuffer());
+ originalPairType->replaceUsesWith(newPairType);
+ originalPairType->removeAndDeallocate();
+ }
+ });
+
+ autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
+ processAllInsts([&](IRInst* inst)
+ {
+ // Make sure the builder is at the right level.
+ builder->setInsertInto(instWithChildren);
+
+ switch (inst->getOp())
+ {
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ lowerPairAccess(builder, inst);
+ modified = true;
+ break;
+
+ case kIROp_MakeDifferentialPair:
+ lowerMakePair(builder, inst);
+ modified = true;
+ break;
+
+ default:
+ break;
+ }
+ });
+
+ processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = lowerPairType(builder, inst))
+ {
+ inst->replaceUsesWith(loweredType);
+ inst->removeAndDeallocate();
+ }
+ });
+ return modified;
+ }
+
+ bool processModule()
+ {
+ IRBuilder builder(autodiffContext->sharedBuilder);
+ return processInstWithChildren(&builder, module->getModuleInst());
+ }
+
+ private:
+
+ AutoDiffSharedContext* autodiffContext;
+
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+};
+
+bool processPairTypes(AutoDiffSharedContext* context)
+{
+ DiffPairLoweringPass pairLoweringPass(context);
+ return pairLoweringPass.processModule();
+}
+
+}
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h
new file mode 100644
index 000000000..44321ae9b
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-pairs.h
@@ -0,0 +1,21 @@
+// slang-ir-autodiff-pairs.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+
+bool processPairTypes(AutoDiffSharedContext* context);
+
+}
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
new file mode 100644
index 000000000..52567e887
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -0,0 +1,832 @@
+#include "slang-ir-autodiff-rev.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+
+namespace Slang
+{
+struct BackwardDiffTranscriber
+{
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary orginalToTranscribed;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List followUpFunctionsToTranscribe;
+
+ // Map that stores the upper gradient given an IRInst*
+ Dictionary> upperGradients;
+ Dictionary primalToDiffPair;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary differentialPairTypes;
+
+ BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : autoDiffSharedContext(shared)
+ , sink(inSink)
+ , differentiableTypeConformanceContext(shared)
+ , sharedBuilder(inSharedBuilder)
+ {}
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List newParameterTypes;
+ IRType* diffReturnType;
+
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto origType = funcType->getParamType(i);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ newParameterTypes.add(inoutDiffPairType);
+ }
+ else
+ newParameterTypes.add(origType);
+ }
+
+ newParameterTypes.add(funcType->getResultType());
+
+ diffReturnType = builder->getVoidType();
+
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto diffType = differentiateType(&builder, diffPairType->getValueType());
+ auto differentialType = builder.getDifferentialPairType(diffType, nullptr);
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ {
+ IRInst* diffType = nullptr;
+ if (!orginalToTranscribed.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ orginalToTranscribed[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
+ if (auto ptrType = as(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = origType;
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as(primalType);
+ List diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+ }
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+ {
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+ }
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+ // Do not differentiate generic type (and witness table) parameters
+ if (as(primalDataType) || as(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Returns "dp" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRBlock* diffBlock = subBuilder.emitBlock();
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->copyParam(&subBuilder, param);
+
+ // The extra param for input gradient
+ auto gradParam = subBuilder.emitParam(as(origBlock->getParent()->getFullType())->getResultType());
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->copyInst(&subBuilder, child);
+
+ auto lastInst = diffBlock->getLastOrdinaryInst();
+ List grads = { gradParam };
+ upperGradients.Add(lastInst, grads);
+ for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ {
+ auto upperGrads = upperGradients.TryGetValue(child);
+ if (!upperGrads)
+ continue;
+ if (upperGrads->getCount() > 1)
+ {
+ auto sumGrad = upperGrads->getFirst();
+ for (auto i = 1; i < upperGrads->getCount(); i++)
+ {
+ sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
+ }
+ this->transcribeInstBackward(&subBuilder, child, sumGrad);
+ }
+ else
+ this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
+ }
+
+ subBuilder.emitReturn();
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ primalFunc = origFunc;
+
+ auto diffFunc = builder.createFunc();
+
+ SLANG_ASSERT(as(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ if (auto nameHint = origFunc->findDecoration())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_bwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder.addBackwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribeBlock(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ IRInst* diffParam = builder->emitParam(inoutDiffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffParam);
+ auto paramValue = builder->emitLoad(diffParam);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ orginalToTranscribed.Add(origParam, primal);
+ primalToDiffPair.Add(primal, diffParam);
+
+ return diffParam;
+ }
+
+
+ return cloneInst(&cloneEnv, builder, origParam);
+ }
+
+ InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+
+ IRInst* primalLeft;
+ if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
+ {
+ primalLeft = origLeft;
+ }
+ IRInst* primalRight;
+ if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
+ {
+ primalRight = origRight;
+ }
+
+ auto resultType = origArith->getDataType();
+ IRInst* newInst;
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ newInst = builder->emitAdd(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Mul:
+ newInst = builder->emitMul(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Sub:
+ newInst = builder->emitSub(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Div:
+ newInst = builder->emitDiv(resultType, primalLeft, primalRight);
+ break;
+ default:
+ newInst = nullptr;
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+ orginalToTranscribed.Add(origArith, newInst);
+ return InstPair(newInst, nullptr);
+ }
+
+ IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto lhs = origArith->getOperand(0);
+ auto rhs = origArith->getOperand(1);
+
+ if (as(lhs->getDataType()))
+ {
+ lhs = builder->emitLoad(lhs);
+ lhs = builder->emitDifferentialPairGetPrimal(lhs);
+ }
+ if (as(rhs->getDataType()))
+ {
+ rhs = builder->emitLoad(rhs);
+ rhs = builder->emitDifferentialPairGetPrimal(rhs);
+ }
+
+ IRInst* leftGrad;
+ IRInst* rightGrad;
+
+
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ leftGrad = grad;
+ rightGrad = grad;
+ break;
+ case kIROp_Mul:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
+ break;
+ case kIROp_Sub:
+ leftGrad = grad;
+ rightGrad = builder->emitNeg(grad->getDataType(), grad);
+ break;
+ case kIROp_Div:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+
+ lhs = origArith->getOperand(0);
+ rhs = origArith->getOperand(1);
+ if (auto leftGrads = upperGradients.TryGetValue(lhs))
+ {
+ leftGrads->add(leftGrad);
+ }
+ else
+ {
+ upperGradients.Add(lhs, leftGrad);
+ }
+ if (auto rightGrads = upperGradients.TryGetValue(rhs))
+ {
+ rightGrads->add(rightGrad);
+ }
+ else
+ {
+ upperGradients.Add(rhs, rightGrad);
+ }
+
+ return nullptr;
+ }
+
+ InstPair copyInst(IRBuilder* builder, IRInst* origInst)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParam(builder, as(origInst));
+
+ case kIROp_Return:
+ return InstPair(nullptr, nullptr);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return copyBinaryArith(builder, origInst);
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ {
+ IRInOutType* inoutParam = as(param->getDataType());
+ auto pairType = as(inoutParam->getValueType());
+ auto paramValue = builder->emitLoad(param);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ auto diff = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ paramValue
+ );
+ auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
+ auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
+ auto store = builder->emitStore(param, updatedParam);
+
+ return store;
+ }
+
+ IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParamBackward(builder, as(origInst), grad);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArithBackward(builder, origInst, grad);
+
+ case kIROp_DifferentialPairGetPrimal:
+ {
+ if (auto param = primalToDiffPair.TryGetValue(origInst))
+ {
+ if (auto leftGrads = upperGradients.TryGetValue(*param))
+ {
+ leftGrads->add(grad);
+ }
+ else
+ {
+ upperGradients.Add(*param, grad);
+ }
+ }
+ else
+ SLANG_ASSERT(0);
+ return nullptr;
+ }
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return nullptr;
+ }
+
+
+};
+
+struct ReverseDerivativePass : public InstPassBase
+{
+
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
+
+ bool processModule()
+ {
+
+ IRBuilder builderStorage(autodiffContext->sharedBuilder);
+ IRBuilder* builder = &builderStorage;
+
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
+ // generating derivative code for the referenced function.
+ //
+ bool modified = processReferencedFunctions(builder);
+
+ return modified;
+ }
+
+ IRInst* lookupJVPReference(IRInst* primalFunction)
+ {
+ if (auto jvpDefinition = primalFunction->findDecoration())
+ return jvpDefinition->getForwardDerivativeFunc();
+ return nullptr;
+ }
+
+ // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
+ // then check that the referenced function is marked correctly for differentiation.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ List autoDiffWorkList;
+
+ for (;;)
+ {
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDifferentiate:
+ // Only process now if the operand is a materialized function.
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ });
+
+ if (autoDiffWorkList.getCount() == 0)
+ break;
+
+ // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
+ // differentiated functions.
+
+ backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
+
+ for (auto differentiateInst : autoDiffWorkList)
+ {
+ IRInst* baseInst = differentiateInst->getOperand(0);
+ if (as(differentiateInst))
+ {
+ if (isMarkedForBackwardDifferentiation(baseInst))
+ {
+ if (as(baseInst))
+ {
+ IRInst* diffFunc =
+ backwardTranscriberStorage
+ .transcribeFuncHeader(builder, (IRFunc*)baseInst)
+ .differential;
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
+ }
+ }
+ }
+
+ auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ }
+
+ // Transcribing the function body really shouldn't produce more follow up function body work.
+ // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
+ // in the next iteration.
+ SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
+
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
+ //
+ bool isMarkedForBackwardDifferentiation(IRInst* callable)
+ {
+ return callable->findDecoration() != nullptr;
+ }
+
+ IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
+ }
+
+ return name;
+ }
+
+ ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
+ InstPassBase(context->moduleInst->getModule()),
+ sink(sink),
+ backwardTranscriberStorage(context, context->sharedBuilder, sink),
+ autodiffContext(context),
+ pairBuilderStorage(context)
+ {
+ backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
+ }
+
+protected:
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ BackwardDiffTranscriber backwardTranscriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Builder for dealing with differential pair types.
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+ // Autodiff Shared Context
+ AutoDiffSharedContext* autodiffContext;
+};
+
+bool processReverseDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ IRReverseDerivativePassOptions const&)
+{
+ ReverseDerivativePass revPass(autodiffContext, sink);
+ bool changed = revPass.processModule();
+ return changed;
+}
+
+}
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
new file mode 100644
index 000000000..c3d31e2a9
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -0,0 +1,25 @@
+// slang-ir-autodiff-rev.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
+
+namespace Slang
+{
+
+struct IRReverseDerivativePassOptions
+{
+ // Nothing for now..
+};
+
+bool processReverseDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions());
+
+
+}
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
new file mode 100644
index 000000000..313760d85
--- /dev/null
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -0,0 +1,408 @@
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-rev.h"
+#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-pairs.h"
+
+namespace Slang
+{
+
+// TODO: Put into a nameless namespace.
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+{
+ if (auto witnessTable = as(witness))
+ {
+ for (auto entry : witnessTable->getEntries())
+ {
+ if (entry->getRequirementKey() == requirementKey)
+ return entry->getSatisfyingVal();
+ }
+ }
+ else if (auto witnessTableParam = as(witness))
+ {
+ return builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ witnessTableParam,
+ requirementKey);
+ }
+ return nullptr;
+}
+
+IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key)
+{
+ if (auto irStructType = as(type))
+ {
+ for (auto field : irStructType->getFields())
+ {
+ if (field->getKey() == key)
+ {
+ return field;
+ }
+ }
+ }
+ else if (auto irSpecialize = as(type))
+ {
+ if (auto irGeneric = as(irSpecialize->getBase()))
+ {
+ if (auto irGenericStructType = as(findInnerMostGenericReturnVal(irGeneric)))
+ {
+ return findField(irGenericStructType, key);
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+IRInst* DifferentialPairTypeBuilder::findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
+{
+ // Get base generic that's being specialized.
+ auto genericType = as(as(specializeInst)->getBase());
+ SLANG_ASSERT(genericType);
+
+ // Find the index of genericParam in the base generic.
+ int paramIndex = -1;
+ int currentIndex = 0;
+ for (auto param : genericType->getParams())
+ {
+ if (param == genericParam)
+ paramIndex = currentIndex;
+ currentIndex ++;
+ }
+
+ SLANG_ASSERT(paramIndex >= 0);
+
+ // Return the corresponding operand in the specialization inst.
+ return specializeInst->getOperand(1 + paramIndex);
+}
+
+IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
+{
+ IRInst* pairType = nullptr;
+ if (auto basePtrType = as(baseInst->getDataType()))
+ {
+ auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
+
+ // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
+ // especially since we probably don't need diff bottom anymore.
+ //
+ SLANG_ASSERT(!baseTypeInfo.isTrivial);
+
+ pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
+ }
+ else
+ {
+ auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
+ pairType = baseTypeInfo.loweredType;
+ }
+
+ if (auto basePairStructType = as(pairType))
+ {
+ return as(builder->emitFieldExtract(
+ findField(basePairStructType, key)->getFieldType(),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto ptrType = as(pairType))
+ {
+ if (auto ptrInnerSpecializedType = as(ptrType->getValueType()))
+ {
+ auto genericType = findInnerMostGenericReturnVal(as(ptrInnerSpecializedType->getBase()));
+ if (auto genericBasePairStructType = as(genericType))
+ {
+ return as(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ ptrInnerSpecializedType,
+ findField(ptrInnerSpecializedType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
+ }
+ else if (auto ptrBaseStructType = as(ptrType->getValueType()))
+ {
+ return as(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findField(ptrBaseStructType, key)->getFieldType()),
+ baseInst,
+ key));
+ }
+ }
+ else if (auto specializedType = as(pairType))
+ {
+ // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
+ // type, emit the specialization type.
+ //
+ auto genericType = findInnerMostGenericReturnVal(as(specializedType->getBase()));
+ if (auto genericBasePairStructType = as(genericType))
+ {
+ return as(builder->emitFieldExtract(
+ (IRType*)findSpecializationForParam(
+ specializedType,
+ findField(genericBasePairStructType, key)->getFieldType()),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto genericPtrType = as(genericType))
+ {
+ if (auto genericPairStructType = as(genericPtrType->getValueType()))
+ {
+ return as(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ specializedType,
+ findField(genericPairStructType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
+ }
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
+ }
+ return nullptr;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+{
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+}
+
+IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+{
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+}
+
+IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey()
+{
+ if (!this->globalDiffKey)
+ {
+ IRBuilder builder(sharedContext->sharedBuilder);
+ // Insert directly at top level (skip any generic scopes etc.)
+ builder.setInsertInto(sharedContext->moduleInst);
+
+ this->globalDiffKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
+ }
+
+ return this->globalDiffKey;
+}
+
+IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey()
+{
+ if (!this->globalPrimalKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertInto(sharedContext->moduleInst);
+
+ this->globalPrimalKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
+ }
+
+ return this->globalPrimalKey;
+}
+
+IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType)
+{
+ switch (origBaseType->getOp())
+ {
+ case kIROp_lookup_interface_method:
+ case kIROp_Specialize:
+ case kIROp_Param:
+ return nullptr;
+ default:
+ break;
+ }
+
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertBefore(diffType);
+
+ auto pairStructType = builder.createStructType();
+ builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
+ builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
+ return pairStructType;
+}
+
+IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
+}
+
+IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+}
+
+DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType(
+ IRBuilder* builder, IRType* originalPairType)
+{
+ LoweredPairTypeInfo result = {};
+
+ if (pairTypeCache.TryGetValue(originalPairType, result))
+ return result;
+ auto pairType = as(originalPairType);
+ if (!pairType)
+ {
+ result.isTrivial = true;
+ result.loweredType = originalPairType;
+ return result;
+ }
+ auto primalType = pairType->getValueType();
+ if (as(primalType))
+ {
+ result.isTrivial = false;
+ result.loweredType = nullptr;
+ return result;
+ }
+
+ auto diffType = getDiffTypeFromPairType(builder, pairType);
+ if (!diffType)
+ return result;
+ result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
+ result.isTrivial = false;
+ pairTypeCache.Add(originalPairType, result);
+
+ return result;
+}
+
+
+AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
+{
+ differentiableInterfaceType = as(findDifferentiableInterface());
+ if (differentiableInterfaceType)
+ {
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
+
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
+ }
+}
+
+IRInst* AutoDiffSharedContext::findDifferentiableInterface()
+{
+ if (auto module = as(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: This seems like a particularly dangerous way to look for an interface.
+ // See if we can lower IDifferentiable to a separate IR inst.
+ //
+ if (globalInst->getOp() == kIROp_InterfaceType &&
+ as(globalInst)->findDecoration()->getName() == "IDifferentiable")
+ {
+ return globalInst;
+ }
+ }
+ }
+ return nullptr;
+}
+
+IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+{
+ if (as(moduleInst) && differentiableInterfaceType)
+ {
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
+ if (auto entry = as(differentiableInterfaceType->getOperand(index)))
+ return as(entry->getRequirementKey());
+ else
+ {
+ SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
+ }
+ }
+
+ return nullptr;
+}
+
+
+void stripAutoDiffDecorationsFromChildren(IRInst* parent)
+{
+ for (auto inst : parent->getChildren())
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_DerivativeMemberDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+
+ if (inst->getFirstChild() != nullptr)
+ {
+ stripAutoDiffDecorationsFromChildren(inst);
+ }
+ }
+}
+
+void stripAutoDiffDecorations(IRModule* module)
+{
+ stripAutoDiffDecorationsFromChildren(module->getModuleInst());
+}
+
+bool processAutodiffCalls(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRAutodiffPassOptions const&)
+{
+ // Simplify module to remove dead code.
+ IRDeadCodeEliminationOptions dceOptions;
+ dceOptions.keepExportsAlive = true;
+ dceOptions.keepLayoutsAlive = true;
+ eliminateDeadCode(module, dceOptions);
+
+ bool modified = false;
+
+ // Create shared context for all auto-diff related passes
+ AutoDiffSharedContext autodiffContext(module->getModuleInst());
+
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.init(module);
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ autodiffContext.sharedBuilder = &sharedBuilder;
+
+ // Process forward derivative calls.
+ modified |= processForwardDerivativeCalls(&autodiffContext, sink);
+
+ // Process reverse derivative calls.
+ modified |= processReverseDerivativeCalls(&autodiffContext, sink);
+
+ // Replaces IRDifferentialPairType with an auto-generated struct,
+ // IRDifferentialPairGetDifferential with 'differential' field access,
+ // IRDifferentialPairGetPrimal with 'primal' field access, and
+ // IRMakeDifferentialPair with an IRMakeStruct.
+ //
+ modified |= processPairTypes(&autodiffContext);
+
+ // Remove auto-diff related decorations.
+ stripAutoDiffDecorations(module);
+
+ return modified;
+}
+
+
+}
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
new file mode 100644
index 000000000..e470044a4
--- /dev/null
+++ b/source/slang/slang-ir-autodiff.h
@@ -0,0 +1,210 @@
+// slang-ir-autodiff-fwd.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+struct AutoDiffSharedContext
+{
+ IRModuleInst* moduleInst = nullptr;
+
+ SharedIRBuilder* sharedBuilder = nullptr;
+
+ // A reference to the builtin IDifferentiable interface type.
+ // We use this to look up all the other types (and type exprs)
+ // that conform to a base type.
+ //
+ IRInterfaceType* differentiableInterfaceType = nullptr;
+
+ // The struct key for the 'Differential' associated type
+ // defined inside IDifferential. We use this to lookup the differential
+ // type in the conformance table associated with the concrete type.
+ //
+ IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // The struct key for the witness that `Differential` associated type conforms to
+ // `IDifferential`.
+ IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+
+
+ // The struct key for the 'zero()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of zero() for a given type.
+ //
+ IRStructKey* zeroMethodStructKey = nullptr;
+
+ // The struct key for the 'add()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of add() for a given type.
+ //
+ IRStructKey* addMethodStructKey = nullptr;
+
+ IRStructKey* mulMethodStructKey = nullptr;
+
+
+ // Modules that don't use differentiable types
+ // won't have the IDifferentiable interface type available.
+ // Set to false to indicate that we are uninitialized.
+ //
+ bool isInterfaceAvailable = false;
+
+
+ AutoDiffSharedContext(IRModuleInst* inModuleInst);
+
+private:
+
+ IRInst* findDifferentiableInterface();
+
+ IRStructKey* findDifferentialTypeStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(0);
+ }
+
+ IRStructKey* findDifferentialTypeWitnessStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(1);
+ }
+
+ IRStructKey* findZeroMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(2);
+ }
+
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(3);
+ }
+
+ IRStructKey* findMulMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(4);
+ }
+
+ IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
+};
+
+struct DifferentiableTypeConformanceContext
+{
+ AutoDiffSharedContext* sharedContext;
+
+ IRGlobalValueWithCode* parentFunc = nullptr;
+ OrderedDictionary differentiableWitnessDictionary;
+
+ DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
+ : sharedContext(shared)
+ {}
+
+ void setFunc(IRGlobalValueWithCode* func);
+
+ void buildGlobalWitnessDictionary();
+
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type);
+
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
+ }
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
+ }
+
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ }
+
+};
+
+struct DifferentialPairTypeBuilder
+{
+ struct LoweredPairTypeInfo
+ {
+ IRInst* loweredType;
+ bool isTrivial;
+ };
+
+ DifferentialPairTypeBuilder() = default;
+
+ DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {}
+
+ IRStructField* findField(IRInst* type, IRStructKey* key);
+
+ IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam);
+
+ IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key);
+
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst);
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst);
+
+ IRStructKey* _getOrCreateDiffStructKey();
+
+ IRStructKey* _getOrCreatePrimalStructKey();
+
+ IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType);
+
+ IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
+
+ IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
+
+ LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
+
+
+ Dictionary pairTypeCache;
+
+ IRStructKey* globalPrimalKey = nullptr;
+
+ IRStructKey* globalDiffKey = nullptr;
+
+ IRInst* genericDiffPairType = nullptr;
+
+ List generatedTypeList;
+
+ AutoDiffSharedContext* sharedContext = nullptr;
+};
+
+void stripAutoDiffDecorations(IRModule* module);
+
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
+
+struct IRAutodiffPassOptions
+{
+ // Nothing for now...
+};
+
+bool processAutodiffCalls(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRAutodiffPassOptions const& options = IRAutodiffPassOptions());
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 44c6324e3..f4f61d7e9 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -1,6 +1,6 @@
#include "slang-ir-check-differentiability.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-ir-inst-pass-base.h"
namespace Slang
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
deleted file mode 100644
index c9ca687e4..000000000
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ /dev/null
@@ -1,3197 +0,0 @@
-// slang-ir-diff-jvp.cpp
-#include "slang-ir-diff-jvp.h"
-
-#include "slang-ir-clone.h"
-#include "slang-ir-dce.h"
-#include "slang-ir-eliminate-phis.h"
-#include "slang-ir-util.h"
-#include "slang-ir-inst-pass-base.h"
-
-// origX, primalX, diffX
-// origX -> primalX (cloneEnv)
-// origX -> diffX (instMapD)
-
-namespace Slang
-{
-
-namespace
-{
-
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
-{
- if (auto witnessTable = as(witness))
- {
- for (auto entry : witnessTable->getEntries())
- {
- if (entry->getRequirementKey() == requirementKey)
- return entry->getSatisfyingVal();
- }
- }
- else if (auto witnessTableParam = as(witness))
- {
- return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
- witnessTableParam,
- requirementKey);
- }
- return nullptr;
-}
-
-}
-
-struct DifferentialPairTypeBuilder
-{
-
- IRStructField* findField(IRInst* type, IRStructKey* key)
- {
- if (auto irStructType = as(type))
- {
- for (auto field : irStructType->getFields())
- {
- if (field->getKey() == key)
- {
- return field;
- }
- }
- }
- else if (auto irSpecialize = as(type))
- {
- if (auto irGeneric = as(irSpecialize->getBase()))
- {
- if (auto irGenericStructType = as(findInnerMostGenericReturnVal(irGeneric)))
- {
- return findField(irGenericStructType, key);
- }
- }
- }
-
- return nullptr;
- }
-
- IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
- {
- // Get base generic that's being specialized.
- auto genericType = as(as(specializeInst)->getBase());
- SLANG_ASSERT(genericType);
-
- // Find the index of genericParam in the base generic.
- int paramIndex = -1;
- int currentIndex = 0;
- for (auto param : genericType->getParams())
- {
- if (param == genericParam)
- paramIndex = currentIndex;
- currentIndex ++;
- }
-
- SLANG_ASSERT(paramIndex >= 0);
-
- // Return the corresponding operand in the specialization inst.
- return specializeInst->getOperand(1 + paramIndex);
- }
-
- IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
- {
- IRInst* pairType = nullptr;
- if (auto basePtrType = as(baseInst->getDataType()))
- {
- auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
-
- // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
- // especially since we probably don't need diff bottom anymore.
- //
- SLANG_ASSERT(!baseTypeInfo.isTrivial);
-
- pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
- }
- else
- {
- auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
- if (baseTypeInfo.isTrivial)
- {
- if (key == globalPrimalKey)
- return baseInst;
- else
- return builder->getDifferentialBottom();
- }
-
- pairType = baseTypeInfo.loweredType;
- }
-
- if (auto basePairStructType = as(pairType))
- {
- return as(builder->emitFieldExtract(
- findField(basePairStructType, key)->getFieldType(),
- baseInst,
- key
- ));
- }
- else if (auto ptrType = as(pairType))
- {
- if (auto ptrInnerSpecializedType = as(ptrType->getValueType()))
- {
- auto genericType = findInnerMostGenericReturnVal(as(ptrInnerSpecializedType->getBase()));
- if (auto genericBasePairStructType = as(genericType))
- {
- return as(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findSpecializationForParam(
- ptrInnerSpecializedType,
- findField(ptrInnerSpecializedType, key)->getFieldType())),
- baseInst,
- key
- ));
- }
- }
- else if (auto ptrBaseStructType = as(ptrType->getValueType()))
- {
- return as(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findField(ptrBaseStructType, key)->getFieldType()),
- baseInst,
- key));
- }
- }
- else if (auto specializedType = as(pairType))
- {
- // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
- // type, emit the specialization type.
- //
- auto genericType = findInnerMostGenericReturnVal(as(specializedType->getBase()));
- if (auto genericBasePairStructType = as(genericType))
- {
- return as(builder->emitFieldExtract(
- (IRType*)findSpecializationForParam(
- specializedType,
- findField(genericBasePairStructType, key)->getFieldType()),
- baseInst,
- key
- ));
- }
- else if (auto genericPtrType = as(genericType))
- {
- if (auto genericPairStructType = as(genericPtrType->getValueType()))
- {
- return as(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findSpecializationForParam(
- specializedType,
- findField(genericPairStructType, key)->getFieldType())),
- baseInst,
- key
- ));
- }
- }
- }
- else
- {
- SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
- }
- return nullptr;
- }
-
- IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
- {
- return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
- }
-
- IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
- {
- return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
- }
-
- IRStructKey* _getOrCreateDiffStructKey()
- {
- if (!this->globalDiffKey)
- {
- IRBuilder builder(sharedContext->sharedBuilder);
- // Insert directly at top level (skip any generic scopes etc.)
- builder.setInsertInto(sharedContext->moduleInst);
-
- this->globalDiffKey = builder.createStructKey();
- builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
- }
-
- return this->globalDiffKey;
- }
-
- IRStructKey* _getOrCreatePrimalStructKey()
- {
- if (!this->globalPrimalKey)
- {
- // Insert directly at top level (skip any generic scopes etc.)
- IRBuilder builder(sharedContext->sharedBuilder);
- builder.setInsertInto(sharedContext->moduleInst);
-
- this->globalPrimalKey = builder.createStructKey();
- builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
- }
-
- return this->globalPrimalKey;
- }
-
- IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType)
- {
- switch (origBaseType->getOp())
- {
- case kIROp_lookup_interface_method:
- case kIROp_Specialize:
- case kIROp_Param:
- return nullptr;
- default:
- break;
- }
- if (diffType->getOp() != kIROp_DifferentialBottomType)
- {
- IRBuilder builder(sharedContext->sharedBuilder);
- builder.setInsertBefore(diffType);
-
- auto pairStructType = builder.createStructType();
- builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
- builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
- return pairStructType;
- }
- return origBaseType;
- }
-
- struct LoweredPairTypeInfo
- {
- IRInst* loweredType;
- bool isTrivial;
- };
-
- IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
- {
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
- }
-
- IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
- {
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
- }
-
- LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
- {
- LoweredPairTypeInfo result = {};
-
- if (pairTypeCache.TryGetValue(originalPairType, result))
- return result;
- auto pairType = as(originalPairType);
- if (!pairType)
- {
- result.isTrivial = true;
- result.loweredType = originalPairType;
- return result;
- }
- auto primalType = pairType->getValueType();
- if (as(primalType))
- {
- result.isTrivial = false;
- result.loweredType = nullptr;
- return result;
- }
-
- auto diffType = getDiffTypeFromPairType(builder, pairType);
- if (!diffType)
- return result;
- result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
- pairTypeCache.Add(originalPairType, result);
-
- return result;
- }
-
- Dictionary pairTypeCache;
-
- IRStructKey* globalPrimalKey = nullptr;
-
- IRStructKey* globalDiffKey = nullptr;
-
- IRInst* genericDiffPairType = nullptr;
-
- List generatedTypeList;
-
- AutoDiffSharedContext* sharedContext = nullptr;
-};
-
-struct JVPTranscriber
-{
-
- // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
- // their differential values.
- Dictionary instMapD;
-
- // Set of insts currently being transcribed. Used to avoid infinite loops.
- HashSet instsInProgress;
-
- // Cloning environment to hold mapping from old to new copies for the primal
- // instructions.
- IRCloneEnv cloneEnv;
-
- // Diagnostic sink for error messages.
- DiagnosticSink* sink;
-
- // Type conformance information.
- AutoDiffSharedContext* autoDiffSharedContext;
-
- // Builder to help with creating and accessing the 'DifferentiablePair' struct
- DifferentialPairTypeBuilder* pairBuilder;
-
- DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
-
- List followUpFunctionsToTranscribe;
-
- SharedIRBuilder* sharedBuilder;
- // Witness table that `DifferentialBottom:IDifferential`.
- IRWitnessTable* differentialBottomWitness = nullptr;
- Dictionary differentialPairTypes;
-
- JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
- : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
- {
-
- }
-
- DiagnosticSink* getSink()
- {
- SLANG_ASSERT(sink);
- return sink;
- }
-
- void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
- {
- if (hasDifferentialInst(origInst))
- {
- if (lookupDiffInst(origInst) != diffInst)
- {
- SLANG_UNEXPECTED("Inconsistent differential mappings");
- }
- }
- else
- {
- instMapD.Add(origInst, diffInst);
- }
- }
-
- void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
- {
- if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
- {
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "inconsistent primal instruction for original");
- }
- else
- {
- cloneEnv.mapOldValToNew[origInst] = primalInst;
- }
- }
-
- IRInst* lookupDiffInst(IRInst* origInst)
- {
- return instMapD[origInst];
- }
-
- IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
- {
- return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
- }
-
- bool hasDifferentialInst(IRInst* origInst)
- {
- return instMapD.ContainsKey(origInst);
- }
-
- IRInst* lookupPrimalInst(IRInst* origInst)
- {
- return cloneEnv.mapOldValToNew[origInst];
- }
-
- IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
- {
- return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
- }
-
- bool hasPrimalInst(IRInst* origInst)
- {
- return cloneEnv.mapOldValToNew.ContainsKey(origInst);
- }
-
- IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
- {
- if (!hasDifferentialInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasDifferentialInst(origInst));
- }
-
- return lookupDiffInst(origInst);
- }
-
- IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
- {
- if (!hasPrimalInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasPrimalInst(origInst));
- }
-
- return lookupPrimalInst(origInst);
- }
-
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
- {
- List newParameterTypes;
- IRType* diffReturnType;
-
- for (UIndex i = 0; i < funcType->getParamCount(); i++)
- {
- auto origType = funcType->getParamType(i);
- origType = (IRType*) lookupPrimalInst(origType, origType);
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
- newParameterTypes.add(diffPairType);
- else
- newParameterTypes.add(origType);
- }
-
- // Transcribe return type to a pair.
- // This will be void if the primal return type is non-differentiable.
- //
- auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
- if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
- diffReturnType = returnPairType;
- else
- diffReturnType = builder->getVoidType();
-
- return builder->getFuncType(newParameterTypes, diffReturnType);
- }
-
- IRWitnessTable* getDifferentialBottomWitness()
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
- auto result =
- as(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- SLANG_ASSERT(result);
- return result;
- }
-
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as(inDiffPairType);
- SLANG_ASSERT(diffPairType);
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
-
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = differentiateType(&builder, diffPairType);
-
- // And place it in the synthesized witness table.
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
-
- // Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
-
- return table;
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- auto witness = as(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
-
- if (!witness)
- {
- if (auto primalPairType = as(primalType))
- {
- witness = getDifferentialPairWitness(primalPairType);
- }
- else
- {
- witness = getDifferentialBottomWitness();
- }
- }
-
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* differentiateType(IRBuilder* builder, IRType* origType)
- {
- IRInst* diffType = nullptr;
- if (!instMapD.TryGetValue(origType, diffType))
- {
- diffType = _differentiateTypeImpl(builder, origType);
- instMapD[origType] = diffType;
- }
- return (IRType*)diffType;
- }
-
- IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
- {
- if (auto ptrType = as(origType))
- return builder->getPtrType(
- origType->getOp(),
- differentiateType(builder, ptrType->getValueType()));
-
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = lookupPrimalInst(origType, origType);
-
- // Special case certain compound types (PtrType, FuncType, etc..)
- // otherwise try to lookup a differential definition for the given type.
- // If one does not exist, then we assume it's not differentiable.
- //
- switch (primalType->getOp())
- {
- case kIROp_Param:
- if (as(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
- else if (as(primalType->getDataType()))
- return (IRType*)primalType;
-
- case kIROp_ArrayType:
- {
- auto primalArrayType = as(primalType);
- if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
- return builder->getArrayType(
- diffElementType,
- primalArrayType->getElementCount());
- else
- return nullptr;
- }
-
- case kIROp_DifferentialPairType:
- {
- auto primalPairType = as(primalType);
- return getOrCreateDiffPairType(
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
- }
-
- case kIROp_FuncType:
- return differentiateFunctionType(builder, as(primalType));
-
- case kIROp_OutType:
- if (auto diffValueType = differentiateType(builder, as