From 6178cb601368e977c4aa82e0ae25b8eb1e875d84 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 22 Nov 2022 12:36:28 -0500 Subject: Refactor Auto-diff passes (#2526) * Initial refactor * Refactor passes tests * Removed Differential Bottom references from the IR side --- build/visual-studio/slang/slang.vcxproj | 10 +- build/visual-studio/slang/slang.vcxproj.filters | 30 +- source/slang/slang-emit.cpp | 7 +- source/slang/slang-ir-autodiff-fwd.cpp | 1761 ++++++++++++ source/slang/slang-ir-autodiff-fwd.h | 43 + source/slang/slang-ir-autodiff-pairs.cpp | 182 ++ source/slang/slang-ir-autodiff-pairs.h | 21 + source/slang/slang-ir-autodiff-rev.cpp | 832 ++++++ source/slang/slang-ir-autodiff-rev.h | 25 + source/slang/slang-ir-autodiff.cpp | 408 +++ source/slang/slang-ir-autodiff.h | 210 ++ source/slang/slang-ir-check-differentiability.cpp | 2 +- source/slang/slang-ir-diff-jvp.cpp | 3197 --------------------- source/slang/slang-ir-diff-jvp.h | 174 -- source/slang/slang-ir-link.cpp | 2 +- source/slang/slang-lower-to-ir.cpp | 2 +- 16 files changed, 3519 insertions(+), 3387 deletions(-) create mode 100644 source/slang/slang-ir-autodiff-fwd.cpp create mode 100644 source/slang/slang-ir-autodiff-fwd.h create mode 100644 source/slang/slang-ir-autodiff-pairs.cpp create mode 100644 source/slang/slang-ir-autodiff-pairs.h create mode 100644 source/slang/slang-ir-autodiff-rev.cpp create mode 100644 source/slang/slang-ir-autodiff-rev.h create mode 100644 source/slang/slang-ir-autodiff.cpp create mode 100644 source/slang/slang-ir-autodiff.h delete mode 100644 source/slang/slang-ir-diff-jvp.cpp delete mode 100644 source/slang/slang-ir-diff-jvp.h diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 5244f0efb..fe1922d29 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -333,6 +333,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + + + + @@ -343,7 +347,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla - @@ -505,6 +508,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla + + + + @@ -516,7 +523,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla - diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 0a2d1fe3f..6fa42287c 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -132,6 +132,18 @@ Header Files + + Header Files + + + Header Files + + + Header Files + + + Header Files + Header Files @@ -162,9 +174,6 @@ Header Files - - Header Files - Header Files @@ -644,6 +653,18 @@ Source Files + + Source Files + + + Source Files + + + Source Files + + + Source Files + Source Files @@ -677,9 +698,6 @@ Source Files - - Source Files - Source Files diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 69ea29c7a..ca55a68bc 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -11,7 +11,7 @@ #include "slang-ir-cleanup-void.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" @@ -377,10 +377,7 @@ Result linkAndOptimizeIR( dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - // Process higher-order calles to auto-diff passes. - processDifferentiableFuncs(irModule, sink); - - stripAutoDiffDecorations(irModule); + processAutodiffCalls(irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp new file mode 100644 index 000000000..03e81c5b5 --- /dev/null +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -0,0 +1,1761 @@ +// slang-ir-autodiff-fwd.cpp +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + + +struct JVPTranscriber +{ + + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary instMapD; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. + AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + List followUpFunctionsToTranscribe; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary differentialPairTypes; + + JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) + : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) + { + + } + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } + + void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) + { + if (hasDifferentialInst(origInst)) + { + if (lookupDiffInst(origInst) != diffInst) + { + SLANG_UNEXPECTED("Inconsistent differential mappings"); + } + } + else + { + instMapD.Add(origInst, diffInst); + } + } + + void mapPrimalInst(IRInst* origInst, IRInst* primalInst) + { + if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) + { + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "inconsistent primal instruction for original"); + } + else + { + cloneEnv.mapOldValToNew[origInst] = primalInst; + } + } + + IRInst* lookupDiffInst(IRInst* origInst) + { + return instMapD[origInst]; + } + + IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst) + { + return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst; + } + + bool hasDifferentialInst(IRInst* origInst) + { + return instMapD.ContainsKey(origInst); + } + + IRInst* lookupPrimalInst(IRInst* origInst) + { + return cloneEnv.mapOldValToNew[origInst]; + } + + IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) + { + return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; + } + + bool hasPrimalInst(IRInst* origInst) + { + return cloneEnv.mapOldValToNew.ContainsKey(origInst); + } + + IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) + { + if (!hasDifferentialInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasDifferentialInst(origInst)); + } + + return lookupDiffInst(origInst); + } + + IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) + { + if (!hasPrimalInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasPrimalInst(origInst)); + } + + return lookupPrimalInst(origInst); + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List newParameterTypes; + IRType* diffReturnType; + + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + auto origType = funcType->getParamType(i); + origType = (IRType*) lookupPrimalInst(origType, origType); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + newParameterTypes.add(diffPairType); + else + newParameterTypes.add(origType); + } + + // Transcribe return type to a pair. + // This will be void if the primal return type is non-differentiable. + // + auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); + if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) + diffReturnType = returnPairType; + else + diffReturnType = builder->getVoidType(); + + return builder->getFuncType(newParameterTypes, diffReturnType); + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as(inDiffPairType); + SLANG_ASSERT(diffPairType); + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = differentiateType(&builder, diffPairType); + + // And place it in the synthesized witness table. + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + + if (!witness) + { + if (auto primalPairType = as(primalType)) + { + witness = getDifferentialPairWitness(primalPairType); + } + } + + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* differentiateType(IRBuilder* builder, IRType* origType) + { + IRInst* diffType = nullptr; + if (!instMapD.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + instMapD[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { + if (auto ptrType = as(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = lookupPrimalInst(origType, origType); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as(primalType); + List diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + } + } + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) + { + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. + // + if (auto origPtrType = as(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; + } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; + } + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); + // Do not differentiate generic type (and witness table) parameters + if (as(primalDataType) || as(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + // Is this param a phi node or a function parameter? + auto func = as(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) + { + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as(diffPairType)) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + else if (auto pairPtrType = as(diffPairType)) + { + auto ptrInnerPairType = as(pairPtrType->getValueType()); + + return InstPair( + builder->emitDifferentialPairAddressPrimal(diffPairParam), + builder->emitDifferentialPairAddressDifferential( + builder->getPtrType( + kIROp_PtrType, + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), + diffPairParam)); + } + } + + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + else + { + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) + { + diff = builder->emitParam(diffType); + } + return InstPair(primal, diff); + } + + } + + // Returns "d" to use as a name hint for variables and parameters. + // If no primal name is available, returns a blank string. + // + String getJVPVarName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration()) + { + return ("d" + String(namehintDecoration->getName())); + } + + return String(""); + } + + // Returns "dp" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); + } + + InstPair transcribeVar(IRBuilder* builder, IRVar* origVar) + { + if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) + { + IRVar* diffVar = builder->emitVar(diffType); + SLANG_ASSERT(diffVar); + + auto diffNameHint = getJVPVarName(origVar); + if (diffNameHint.getLength() > 0) + builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); + + return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); + } + + return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); + } + + InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); + + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); + + auto primalLeft = findOrTranscribePrimalInst(builder, origLeft); + auto primalRight = findOrTranscribePrimalInst(builder, origRight); + + auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); + auto diffRight = findOrTranscribeDiffInst(builder, origRight); + + + if (diffLeft || diffRight) + { + diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); + + auto resultType = primalArith->getDataType(); + switch(origArith->getOp()) + { + case kIROp_Add: + return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); + case kIROp_Mul: + return InstPair(primalArith, builder->emitAdd(resultType, + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight))); + case kIROp_Sub: + return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); + case kIROp_Div: + return InstPair(primalArith, builder->emitDiv(resultType, + builder->emitSub( + resultType, + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight)), + builder->emitMul( + primalRight->getDataType(), primalRight, primalRight + ))); + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + } + + return InstPair(primalArith, nullptr); + } + + + InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) + { + SLANG_ASSERT(origLogic->getOperandCount() == 2); + + // TODO: Check other boolean cases. + if (as(origLogic->getDataType())) + { + // Boolean operations are not differentiable. For the linearization + // pass, we do not need to do anything but copy them over to the ne + // function. + auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); + return InstPair(primalLogic, nullptr); + } + + SLANG_UNEXPECTED("Logical operation with non-boolean result"); + } + + InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) + { + auto origPtr = origLoad->getPtr(); + auto primalPtr = lookupPrimalInst(origPtr, nullptr); + auto primalPtrValueType = as(primalPtr->getFullType())->getValueType(); + + if (auto diffPairType = as(primalPtrValueType)) + { + // Special case load from an `out` param, which will not have corresponding `diff` and + // `primal` insts yet. + + auto load = builder->emitLoad(primalPtr); + auto primalElement = builder->emitDifferentialPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + return InstPair(primalElement, diffElement); + } + + auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; + if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) + { + // Default case, we're loading from a known differential inst. + diffLoad = as(builder->emitLoad(diffPtr)); + } + return InstPair(primalLoad, diffLoad); + } + + InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) + { + IRInst* origStoreLocation = origStore->getPtr(); + IRInst* origStoreVal = origStore->getVal(); + auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); + auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); + auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); + auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + + if (!diffStoreLocation) + { + auto primalLocationPtrType = as(primalStoreLocation->getDataType()); + if (auto diffPairType = as(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); + auto store = builder->emitStore(primalStoreLocation, valToStore); + return InstPair(store, nullptr); + } + } + + auto primalStore = cloneInst(&cloneEnv, builder, origStore); + + IRInst* diffStore = nullptr; + + // If the stored value has a differential version, + // emit a store instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + if (diffStoreLocation && diffStoreVal) + { + // Default case, storing the entire type (and not a member) + diffStore = as( + builder->emitStore(diffStoreLocation, diffStoreVal)); + + return InstPair(primalStore, diffStore); + } + + return InstPair(primalStore, nullptr); + } + + InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) + { + IRInst* origReturnVal = origReturn->getVal(); + + auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); + if (as(origReturnVal) || as(origReturnVal) || as(origReturnVal) || as(origReturnVal)) + { + // If the return value is itself a function, generic or a struct then this + // is likely to be a generic scope. In this case, we lookup the differential + // and return that. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + + // Neither of these should be nullptr. + SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); + IRReturn* diffReturn = as(builder->emitReturn(diffReturnVal)); + + return InstPair(diffReturn, diffReturn); + } + else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) + { + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + if(!diffReturnVal) + diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); + + // If the pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffReturnVal); + + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); + IRReturn* pairReturn = as(builder->emitReturn(diffPair)); + return InstPair(pairReturn, pairReturn); + } + else + { + // If the return type is not differentiable, emit the primal value only. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + + IRInst* primalReturn = builder->emitReturn(primalReturnVal); + return InstPair(primalReturn, nullptr); + + } + } + + // Since int/float literals are sometimes nested inside an IRConstructor + // instruction, we check to make sure that the nested instr is a constant + // and then return nullptr. Literals do not need to be differentiated. + // + InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) + { + IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst + // + auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); + if (auto diffConstructType = differentiateType(builder, primalConstructType)) + { + UCount operandCount = origConstruct->getOperandCount(); + + List diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) + { + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. + // + if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + { + auto operandDataType = origConstruct->getOperand(ii)->getDataType(); + operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } + } + + return InstPair( + primalConstruct, + builder->emitIntrinsicInst( + diffConstructType, + origConstruct->getOp(), + operandCount, + diffOperands.getBuffer())); + } + else + { + return InstPair(primalConstruct, nullptr); + } + } + + // Differentiating a call instruction here is primarily about generating + // an appropriate call list based on whichever parameters have differentials + // in the current transcription context. + // + InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) + { + + IRInst* origCallee = origCall->getCallee(); + + if (!origCallee) + { + // Note that this can only happen if the callee is a result + // of a higher-order operation. For now, we assume that we cannot + // differentiate such calls safely. + // TODO(sai): Should probably get checked in the front-end. + // + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::internalCompilerError, + "attempting to differentiate unresolved callee"); + + return InstPair(nullptr, nullptr); + } + + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; + + IRInst* diffCallee = nullptr; + + if (auto derivativeReferenceDecor = primalCallee->findDecoration()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); + } + else if (primalCallee->findDecoration()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitForwardDifferentiateInst( + differentiateFunctionType(builder, as(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. + IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + + List args; + // Go over the parameter list and create pairs for each input (if required) + for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) + { + auto origArg = origCall->getArg(ii); + auto primalArg = findOrTranscribePrimalInst(builder, origArg); + SLANG_ASSERT(primalArg); + + auto primalType = primalArg->getDataType(); + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + + if (auto pairType = tryGetDiffPairType(builder, primalType)) + { + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); + args.add(diffPair); + } + else + { + // Add original/primal argument. + args.add(primalArg); + } + } + + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + + if (!diffReturnType) + { + SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); + diffReturnType = builder->getVoidType(); + } + + auto callInst = builder->emitCallInst( + diffReturnType, + diffCallee, + args); + + if (diffReturnType->getOp() != kIROp_VoidType) + { + IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); + auto diffType = differentiateType(builder, origCall->getFullType()); + IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); + return InstPair(primalResultValue, diffResultValue); + } + else + { + // Return the inst itself if the return value is void. + // This is fine since these values should never actually be used anywhere. + // + return InstPair(callInst, callInst); + } + } + + InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) + { + IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); + + if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) + { + List swizzleIndices; + for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) + swizzleIndices.add(origSwizzle->getElementIndex(ii)); + + return InstPair( + primalSwizzle, + builder->emitSwizzle( + differentiateType(builder, primalSwizzle->getDataType()), + diffBase, + origSwizzle->getElementCount(), + swizzleIndices.getBuffer())); + } + + return InstPair(primalSwizzle, nullptr); + } + + InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) + { + IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); + + UCount operandCount = origInst->getOperandCount(); + + List diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) + { + // If the operand has a differential version, replace the original with the + // differential. + // Otherwise, abandon the differentiation attempt and assume that origInst + // cannot (or does not need to) be differentiated. + // + if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + return InstPair(primalInst, nullptr); + } + + return InstPair( + primalInst, + builder->emitIntrinsicInst( + differentiateType(builder, primalInst->getDataType()), + origInst->getOp(), + operandCount, + diffOperands.getBuffer())); + } + + InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst) + { + switch(origInst->getOp()) + { + case kIROp_unconditionalBranch: + case kIROp_loop: + auto origBranch = as(origInst); + + // Grab the differentials for any phi nodes. + List newArgs; + for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) + { + auto origArg = origBranch->getArg(ii); + auto primalArg = lookupPrimalInst(origArg); + newArgs.add(primalArg); + + if (differentiateType(builder, primalArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (diffArg) + newArgs.add(diffArg); + } + } + + IRInst* diffBranch = nullptr; + if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) + { + if (auto origLoop = as(origInst)) + { + auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + List operands; + operands.add(breakBlock); + operands.add(continueBlock); + operands.addRange(newArgs); + diffBranch = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + operands.getCount(), + operands.getBuffer()); + } + else + { + diffBranch = builder->emitBranch( + as(diffBlock), + newArgs.getCount(), + newArgs.getBuffer()); + } + } + + // For now, every block in the original fn must have a corresponding + // block to compute *both* primals and derivatives (i.e linearized block) + SLANG_ASSERT(diffBranch); + + return InstPair(diffBranch, diffBranch); + } + + getSink()->diagnose( + origInst->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unhandled control flow"); + + return InstPair(nullptr, nullptr); + } + + InstPair transcribeConst(IRBuilder* builder, IRInst* origInst) + { + switch(origInst->getOp()) + { + case kIROp_FloatLit: + return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); + case kIROp_VoidLit: + return InstPair(origInst, origInst); + case kIROp_IntLit: + return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); + } + + getSink()->diagnose( + origInst->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unhandled const type"); + + return InstPair(nullptr, nullptr); + } + + InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize) + { + // In general, we should not see any specialize insts at this stage. + // The exceptions are target intrinsics. + auto genericInnerVal = findInnerMostGenericReturnVal(as(origSpecialize->getBase())); + if (genericInnerVal->findDecoration()) + { + // Look for an IRForwardDerivativeDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) + // + if (auto jvpFuncDecoration = origSpecialize->findDecoration()) + { + auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); + + // Make sure this isn't itself a specialize . + SLANG_RELEASE_ASSERT(!as(jvpFunc)); + + return InstPair(jvpFunc, jvpFunc); + } + } + else + { + getSink()->diagnose(origSpecialize->sourceLoc, + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); + } + + return InstPair(nullptr, nullptr); + } + + InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) + { + // This is slightly counter-intuitive, but we don't perform any differentiation + // logic here. We simple clone the original lookup which points to the original function, + // or the cloned version in case we're inside a generic scope. + // The differentiation logic is inserted later when this is used in an IRCall. + // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table)) + // rather than have Lookup(ForwardDifferentiate(Table)) + // + auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); + return InstPair(diffLookup, diffLookup); + } + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRInst* diffBlock = subBuilder.emitBlock(); + + // Note: for blocks, we setup the mapping _before_ + // processing the children since we could encounter + // a lookup while processing the children. + // + mapPrimalInst(origBlock, diffBlock); + mapDifferentialInst(origBlock, diffBlock); + + subBuilder.setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->transcribe(&subBuilder, param); + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->transcribe(&subBuilder, child); + + return InstPair(diffBlock, diffBlock); + } + + InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) + { + SLANG_ASSERT(as(originalInst) || as(originalInst)); + + IRInst* origBase = originalInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto field = originalInst->getOperand(1); + auto derivativeRefDecor = field->findDecoration(); + auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); + + IRInst* primalOperands[] = { primalBase, field }; + IRInst* primalFieldExtract = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 2, + primalOperands); + + if (!derivativeRefDecor) + { + return InstPair(primalFieldExtract, nullptr); + } + + IRInst* diffFieldExtract = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; + diffFieldExtract = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 2, + diffOperands); + } + } + return InstPair(primalFieldExtract, diffFieldExtract); + } + + InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) + { + SLANG_ASSERT(as(origGetElementPtr) || as(origGetElementPtr)); + + IRInst* origBase = origGetElementPtr->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); + + auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); + + IRInst* primalOperands[] = {primalBase, primalIndex}; + IRInst* primalGetElementPtr = builder->emitIntrinsicInst( + primalType, + origGetElementPtr->getOp(), + 2, + primalOperands); + + IRInst* diffGetElementPtr = nullptr; + + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) + { + IRInst* diffOperands[] = {diffBase, primalIndex}; + diffGetElementPtr = builder->emitIntrinsicInst( + diffType, + origGetElementPtr->getOp(), + 2, + diffOperands); + } + } + + return InstPair(primalGetElementPtr, diffGetElementPtr); + } + + InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) + { + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body) + auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); + + // Transcribe the break block (this is the block after the exiting the loop) + auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + + + List diffLoopOperands; + diffLoopOperands.add(diffTargetBlock); + diffLoopOperands.add(diffBreakBlock); + diffLoopOperands.add(diffContinueBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); + diffLoopOperands.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + diffLoopOperands.getCount(), + diffLoopOperands.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) + { + // IfElse Statements come with 4 blocks. We transcribe each block into it's + // linear form, and then wire them up in the same way as the original if-else + + // Transcribe condition block + auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); + SLANG_ASSERT(primalConditionBlock); + + // Transcribe 'true' block (condition block branches into this if true) + auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); + SLANG_ASSERT(diffTrueBlock); + + // Transcribe 'false' block (condition block branches into this if true) + // TODO (sai): What happens if there's no false block? + auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); + SLANG_ASSERT(diffFalseBlock); + + // Transcribe 'after' block (true and false blocks branch into this) + auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); + SLANG_ASSERT(diffAfterBlock); + + List diffIfElseArgs; + diffIfElseArgs.add(primalConditionBlock); + diffIfElseArgs.add(diffTrueBlock); + diffIfElseArgs.add(diffFalseBlock); + diffIfElseArgs.add(diffAfterBlock); + + // If there are any other operands, use their primal versions. + for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) + { + auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); + diffIfElseArgs.add(primalOperand); + } + + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_ifElse, + diffIfElseArgs.getCount(), + diffIfElseArgs.getBuffer()); + + return InstPair(diffLoop, diffLoop); + } + + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) + { + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalVal); + auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffDiffVal); + + auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); + auto diffPair = builder->emitMakeDifferentialPair( + differentiateType(builder, origInst->getDataType()), + primalDiffVal, + diffDiffVal); + return InstPair(primalPair, diffPair); + } + + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) + { + SLANG_ASSERT( + origInst->getOp() == kIROp_DifferentialPairGetDifferential || + origInst->getOp() == kIROp_DifferentialPairGetPrimal); + + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(primalVal); + + auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(diffVal); + + auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); + + auto diffValPairType = as(diffVal->getDataType()); + IRInst* diffResultType = nullptr; + if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) + diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); + else + diffResultType = diffValPairType->getValueType(); + auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); + return InstPair(primalResult, diffResult); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + primalFunc = origFunc; + + auto diffFunc = builder.createFunc(); + + SLANG_ASSERT(as(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + &builder, + as(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + if (auto nameHint = origFunc->findDecoration()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_fwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } + builder.addForwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. + builder.addForwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration()) + { + cloneDecoration(dictDecor, diffFunc); + } + + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); + + differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(&builder, block); + + return InstPair(primalFunc, diffFunc); + } + + // Transcribe a generic definition + InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) + { + auto innerVal = findInnerMostGenericReturnVal(origGeneric); + if (auto innerFunc = as(innerVal)) + { + differentiableTypeConformanceContext.setFunc(innerFunc); + } + else + { + return InstPair(origGeneric, nullptr); + } + + // For now, we assume there's only one generic layer. So this inst must be top level + bool isTopLevel = (as(origGeneric->getParent()) != nullptr); + SLANG_RELEASE_ASSERT(isTopLevel); + + IRGeneric* primalGeneric = origGeneric; + + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origGeneric); + + auto diffGeneric = builder.emitGeneric(); + + // Process type of generic. If the generic is a function, then it's type will also be a + // generic and this logic will transcribe that generic first before continuing with the + // function itself. + // + auto primalType = primalGeneric->getFullType(); + + IRType* diffType = nullptr; + if (primalType) + { + diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); + } + + diffGeneric->setFullType(diffType); + + // TODO(sai): Replace naming scheme + // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) + // builder->addNameHintDecoration(diffFunc, jvpName); + + // Transcribe children from origFunc into diffFunc. + builder.setInsertInto(diffGeneric); + for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(&builder, block); + + return InstPair(primalGeneric, diffGeneric); + } + + IRInst* transcribe(IRBuilder* builder, IRInst* origInst) + { + // If a differential intstruction is already mapped for + // this original inst, return that. + // + if (auto diffInst = lookupDiffInst(origInst, nullptr)) + { + SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + return diffInst; + } + + // Otherwise, dispatch to the appropriate method + // depending on the op-code. + // + instsInProgress.Add(origInst); + InstPair pair = transcribeInst(builder, origInst); + + if (auto primalInst = pair.primal) + { + mapPrimalInst(origInst, pair.primal); + mapDifferentialInst(origInst, pair.differential); + if (pair.differential) + { + switch (pair.differential->getOp()) + { + case kIROp_Func: + case kIROp_Generic: + case kIROp_Block: + // Don't generate again for these. + // Functions already have their names generated in `transcribeFuncHeader`. + break; + default: + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration()) + { + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + } + break; + } + } + return pair.differential; + } + instsInProgress.Remove(origInst); + + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "failed to transcibe instruction"); + return nullptr; + } + + InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParam(builder, as(origInst)); + + case kIROp_Var: + return transcribeVar(builder, as(origInst)); + + case kIROp_Load: + return transcribeLoad(builder, as(origInst)); + + case kIROp_Store: + return transcribeStore(builder, as(origInst)); + + case kIROp_Return: + return transcribeReturn(builder, as(origInst)); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return transcribeBinaryArith(builder, origInst); + + case kIROp_Less: + case kIROp_Greater: + case kIROp_And: + case kIROp_Or: + case kIROp_Geq: + case kIROp_Leq: + return transcribeBinaryLogic(builder, origInst); + + case kIROp_Construct: + return transcribeConstruct(builder, origInst); + + case kIROp_Call: + return transcribeCall(builder, as(origInst)); + + case kIROp_swizzle: + return transcribeSwizzle(builder, as(origInst)); + + case kIROp_constructVectorFromScalar: + case kIROp_MakeTuple: + return transcribeByPassthrough(builder, origInst); + + case kIROp_unconditionalBranch: + return transcribeControlFlow(builder, origInst); + + case kIROp_FloatLit: + case kIROp_IntLit: + case kIROp_VoidLit: + return transcribeConst(builder, origInst); + + case kIROp_Specialize: + return transcribeSpecialize(builder, as(origInst)); + + case kIROp_lookup_interface_method: + return transcibeLookupInterfaceMethod(builder, as(origInst)); + + case kIROp_FieldExtract: + case kIROp_FieldAddress: + return transcribeFieldExtract(builder, origInst); + case kIROp_getElement: + case kIROp_getElementPtr: + return transcribeGetElement(builder, origInst); + + case kIROp_loop: + return transcribeLoop(builder, as(origInst)); + + case kIROp_ifElse: + return transcribeIfElse(builder, as(origInst)); + + case kIROp_MakeDifferentialPair: + return transcribeMakeDifferentialPair(builder, as(origInst)); + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + return transcribeDifferentialPairGetElement(builder, origInst); + } + + // If none of the cases have been hit, check if the instruction is a + // type. Only need to explicitly differentiate types if they appear inside a block. + // + if (auto origType = as(origInst)) + { + // If this is a generic type, transcibe the parent + // generic and derive the type from the transcribed generic's + // return value. + // + if (as(origType->getParent()->getParent()) && + findInnerMostGenericReturnVal(as(origType->getParent()->getParent())) == origType && + !instsInProgress.Contains(origType->getParent()->getParent())) + { + auto origGenericType = origType->getParent()->getParent(); + auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); + auto innerDiffGenericType = findInnerMostGenericReturnVal(as(diffGenericType)); + return InstPair( + origGenericType, + innerDiffGenericType + ); + } + else if (as(origType->getParent())) + return InstPair( + cloneInst(&cloneEnv, builder, origType), + differentiateType(builder, origType)); + else + return InstPair( + cloneInst(&cloneEnv, builder, origType), + nullptr); + } + + // Handle instructions with children + switch (origInst->getOp()) + { + case kIROp_Func: + return transcribeFuncHeader(builder, as(origInst)); + + case kIROp_Block: + return transcribeBlock(builder, as(origInst)); + + case kIROp_Generic: + return transcribeGeneric(builder, as(origInst)); + } + + + // If we reach this statement, the instruction type is likely unhandled. + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::unimplemented, + "this instruction cannot be differentiated"); + + return InstPair(nullptr, nullptr); + } +}; + +struct ForwardDerivativePass : public InstPassBase +{ + + DiagnosticSink* getSink() + { + return sink; + } + + bool processModule() + { + // TODO(sai): Move this call. + transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); + + IRBuilder builderStorage(this->autodiffContext->sharedBuilder); + IRBuilder* builder = &builderStorage; + + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by + // generating derivative code for the referenced function. + // + bool modified = processReferencedFunctions(builder); + + return modified; + } + + IRInst* lookupJVPReference(IRInst* primalFunction) + { + if (auto jvpDefinition = primalFunction->findDecoration()) + return jvpDefinition->getForwardDerivativeFunc(); + return nullptr; + } + + // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), + // then check that the referenced function is marked correctly for differentiation. + // + bool processReferencedFunctions(IRBuilder* builder) + { + List autoDiffWorkList; + + for (;;) + { + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + autoDiffWorkList.add(inst); + break; + default: + break; + } + break; + default: + break; + } + }); + + if (autoDiffWorkList.getCount() == 0) + break; + + // Process collected `ForwardDifferentiate` insts and replace them with placeholders for + // differentiated functions. + + transcriberStorage.followUpFunctionsToTranscribe.clear(); + + for (auto differentiateInst : autoDiffWorkList) + { + IRInst* baseInst = differentiateInst->getOperand(0); + if (as(differentiateInst)) + { + if (auto existingDiffFunc = lookupJVPReference(baseInst)) + { + differentiateInst->replaceUsesWith(existingDiffFunc); + differentiateInst->removeAndDeallocate(); + } + else if (isMarkedForForwardDifferentiation(baseInst)) + { + if (as(baseInst) || as(baseInst)) + { + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst); + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Requested differentiation on a function that isn't marked as differentiable."); + } + + } + } + // Actually synthesize the derivatives. + List followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) + { + auto diffFunc = as(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as(task.primal); + SLANG_ASSERT(primalFunc); + + transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); + } + + // Transcribing the function body really shouldn't produce more follow up function body work. + // However it may produce new `ForwardDifferentiate` instructions, which we collect and process + // in the next iteration. + SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); + + } + return true; + } + + // Checks decorators to see if the function should + // be differentiated (kIROp_ForwardDifferentiableDecoration) + // + bool isMarkedForForwardDifferentiation(IRInst* callable) + { + return callable->findDecoration() != nullptr; + } + + IRStringLit* getForwardDerivativeFuncName(IRInst* func) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = func->findDecoration()) + { + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration()) + { + name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); + } + + return name; + } + + ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : + InstPassBase(context->moduleInst->getModule()), + sink(sink), + transcriberStorage(context, context->sharedBuilder), + pairBuilderStorage(context), + autodiffContext(context) + { + transcriberStorage.sink = sink; + transcriberStorage.autoDiffSharedContext = context; + transcriberStorage.pairBuilder = &(pairBuilderStorage); + } + +protected: + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + JVPTranscriber transcriberStorage; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Shared context. + AutoDiffSharedContext* autodiffContext; + + // Builder for dealing with differential pair types. + DifferentialPairTypeBuilder pairBuilderStorage; + +}; + +// Set up context and call main process method. +// +bool processForwardDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + ForwardDerivativePassOptions const&) +{ + ForwardDerivativePass fwdPass(autodiffContext, sink); + bool changed = fwdPass.processModule(); + return changed; +} + + +void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) +{ + parentFunc = func; + + auto decor = func->findDecoration(); + SLANG_RELEASE_ASSERT(decor); + + // Build lookup dictionary for type witnesses. + for (auto child = decor->getFirstChild(); child; child = child->next) + { + if (auto item = as(child)) + { + auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); + if (existingItem) + { + *existingItem = item->getWitness(); + } + else + { + differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); + } + } + } +} + + +// Lookup a witness table for the concreteType. One should exist if concreteType +// inherits (successfully) from IDifferentiable. +// + +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +{ + IRInst* foundResult = nullptr; + differentiableWitnessDictionary.TryGetValue(type, foundResult); + return foundResult; +} + +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +{ + if (auto conformance = lookUpConformanceForType(origType)) + { + return _lookupWitness(builder, conformance, key); + } + return nullptr; +} + +void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() +{ + for (auto globalInst : sharedContext->moduleInst->getChildren()) + { + if (auto pairType = as(globalInst)) + { + differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); + } + } +} +} diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h new file mode 100644 index 000000000..6b261ecd0 --- /dev/null +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -0,0 +1,43 @@ +// slang-ir-autodiff-fwd.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +namespace Slang +{ + + template + struct DiffInstPair + { + P primal; + D differential; + DiffInstPair() = default; + DiffInstPair(P primal, D differential) : primal(primal), differential(differential) + {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const DiffInstPair& other) const + { + return primal == other.primal && differential == other.differential; + } + }; + + typedef DiffInstPair InstPair; + + struct ForwardDerivativePassOptions + { + // Nothing for now.. + }; + + bool processForwardDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions()); + +} diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp new file mode 100644 index 000000000..1dbb1bd7c --- /dev/null +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -0,0 +1,182 @@ +#include "slang-ir-autodiff-pairs.h" + +namespace Slang +{ + +struct DiffPairLoweringPass : InstPassBase +{ + DiffPairLoweringPass(AutoDiffSharedContext* context) : + InstPassBase(context->moduleInst->getModule()), + pairBuilderStorage(context), + autodiffContext(context) + { + pairBuilder = &pairBuilderStorage; + } + + IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) + { + builder->setInsertBefore(pairType); + auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType( + builder, + pairType); + if (isTrivial) + *isTrivial = loweredPairTypeInfo.isTrivial; + return loweredPairTypeInfo.loweredType; + } + + IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) + { + + if (auto makePairInst = as(inst)) + { + bool isTrivial = false; + auto pairType = as(makePairInst->getDataType()); + if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) + { + builder->setInsertBefore(makePairInst); + IRInst* result = nullptr; + if (isTrivial) + { + result = makePairInst->getPrimalValue(); + } + else + { + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; + result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + } + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; + } + } + + return nullptr; + } + + IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) + { + if (auto getDiffInst = as(inst)) + { + auto pairType = getDiffInst->getBase()->getDataType(); + if (auto pairPtrType = as(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) + { + builder->setInsertBefore(getDiffInst); + IRInst* diffFieldExtract = nullptr; + diffFieldExtract = pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase()); + getDiffInst->replaceUsesWith(diffFieldExtract); + getDiffInst->removeAndDeallocate(); + return diffFieldExtract; + } + } + else if (auto getPrimalInst = as(inst)) + { + auto pairType = getPrimalInst->getBase()->getDataType(); + if (auto pairPtrType = as(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) + { + builder->setInsertBefore(getPrimalInst); + + IRInst* primalFieldExtract = nullptr; + primalFieldExtract = pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + getPrimalInst->replaceUsesWith(primalFieldExtract); + getPrimalInst->removeAndDeallocate(); + return primalFieldExtract; + } + } + + return nullptr; + } + + bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren) + { + bool modified = false; + // Hoist all pair types to global scope when possible. + auto moduleInst = module->getModuleInst(); + processInstsOfType(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) + { + originalPairType->removeFromParent(); + ShortList operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + { + operands.add(originalPairType->getOperand(i)); + } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); + + autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + processAllInsts([&](IRInst* inst) + { + // Make sure the builder is at the right level. + builder->setInsertInto(instWithChildren); + + switch (inst->getOp()) + { + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + lowerPairAccess(builder, inst); + modified = true; + break; + + case kIROp_MakeDifferentialPair: + lowerMakePair(builder, inst); + modified = true; + break; + + default: + break; + } + }); + + processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = lowerPairType(builder, inst)) + { + inst->replaceUsesWith(loweredType); + inst->removeAndDeallocate(); + } + }); + return modified; + } + + bool processModule() + { + IRBuilder builder(autodiffContext->sharedBuilder); + return processInstWithChildren(&builder, module->getModuleInst()); + } + + private: + + AutoDiffSharedContext* autodiffContext; + + DifferentialPairTypeBuilder* pairBuilder; + + DifferentialPairTypeBuilder pairBuilderStorage; + +}; + +bool processPairTypes(AutoDiffSharedContext* context) +{ + DiffPairLoweringPass pairLoweringPass(context); + return pairLoweringPass.processModule(); +} + +} \ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h new file mode 100644 index 000000000..44321ae9b --- /dev/null +++ b/source/slang/slang-ir-autodiff-pairs.h @@ -0,0 +1,21 @@ +// slang-ir-autodiff-pairs.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +#include "slang-ir-autodiff.h" + +namespace Slang +{ + +bool processPairTypes(AutoDiffSharedContext* context); + +} \ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp new file mode 100644 index 000000000..52567e887 --- /dev/null +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -0,0 +1,832 @@ +#include "slang-ir-autodiff-rev.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + + +namespace Slang +{ +struct BackwardDiffTranscriber +{ + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary orginalToTranscribed; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. + AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + List followUpFunctionsToTranscribe; + + // Map that stores the upper gradient given an IRInst* + Dictionary> upperGradients; + Dictionary primalToDiffPair; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary differentialPairTypes; + + BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : autoDiffSharedContext(shared) + , sink(inSink) + , differentiableTypeConformanceContext(shared) + , sharedBuilder(inSharedBuilder) + {} + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List newParameterTypes; + IRType* diffReturnType; + + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + auto origType = funcType->getParamType(i); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + newParameterTypes.add(inoutDiffPairType); + } + else + newParameterTypes.add(origType); + } + + newParameterTypes.add(funcType->getResultType()); + + diffReturnType = builder->getVoidType(); + + return builder->getFuncType(newParameterTypes, diffReturnType); + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as(inDiffPairType); + SLANG_ASSERT(diffPairType); + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto diffType = differentiateType(&builder, diffPairType->getValueType()); + auto differentialType = builder.getDifferentialPairType(diffType, nullptr); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* differentiateType(IRBuilder* builder, IRType* origType) + { + IRInst* diffType = nullptr; + if (!orginalToTranscribed.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + orginalToTranscribed[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { + if (auto ptrType = as(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = origType; + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as(primalType); + List diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + } + } + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) + { + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. + // + if (auto origPtrType = as(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; + } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; + } + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + // Do not differentiate generic type (and witness table) parameters + if (as(primalDataType) || as(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + + + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + // Returns "dp" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); + } + + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRBlock* diffBlock = subBuilder.emitBlock(); + + subBuilder.setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->copyParam(&subBuilder, param); + + // The extra param for input gradient + auto gradParam = subBuilder.emitParam(as(origBlock->getParent()->getFullType())->getResultType()); + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->copyInst(&subBuilder, child); + + auto lastInst = diffBlock->getLastOrdinaryInst(); + List grads = { gradParam }; + upperGradients.Add(lastInst, grads); + for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) + { + auto upperGrads = upperGradients.TryGetValue(child); + if (!upperGrads) + continue; + if (upperGrads->getCount() > 1) + { + auto sumGrad = upperGrads->getFirst(); + for (auto i = 1; i < upperGrads->getCount(); i++) + { + sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); + } + this->transcribeInstBackward(&subBuilder, child, sumGrad); + } + else + this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); + } + + subBuilder.emitReturn(); + + return InstPair(diffBlock, diffBlock); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + primalFunc = origFunc; + + auto diffFunc = builder.createFunc(); + + SLANG_ASSERT(as(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + &builder, + as(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + if (auto nameHint = origFunc->findDecoration()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_bwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } + builder.addBackwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. + builder.addBackwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration()) + { + cloneDecoration(dictDecor, diffFunc); + } + + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); + + differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribeBlock(&builder, block); + + return InstPair(primalFunc, diffFunc); + } + + IRInst* copyParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + IRInst* diffParam = builder->emitParam(inoutDiffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffParam); + auto paramValue = builder->emitLoad(diffParam); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + orginalToTranscribed.Add(origParam, primal); + primalToDiffPair.Add(primal, diffParam); + + return diffParam; + } + + + return cloneInst(&cloneEnv, builder, origParam); + } + + InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); + + IRInst* primalLeft; + if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) + { + primalLeft = origLeft; + } + IRInst* primalRight; + if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) + { + primalRight = origRight; + } + + auto resultType = origArith->getDataType(); + IRInst* newInst; + switch (origArith->getOp()) + { + case kIROp_Add: + newInst = builder->emitAdd(resultType, primalLeft, primalRight); + break; + case kIROp_Mul: + newInst = builder->emitMul(resultType, primalLeft, primalRight); + break; + case kIROp_Sub: + newInst = builder->emitSub(resultType, primalLeft, primalRight); + break; + case kIROp_Div: + newInst = builder->emitDiv(resultType, primalLeft, primalRight); + break; + default: + newInst = nullptr; + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + orginalToTranscribed.Add(origArith, newInst); + return InstPair(newInst, nullptr); + } + + IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto lhs = origArith->getOperand(0); + auto rhs = origArith->getOperand(1); + + if (as(lhs->getDataType())) + { + lhs = builder->emitLoad(lhs); + lhs = builder->emitDifferentialPairGetPrimal(lhs); + } + if (as(rhs->getDataType())) + { + rhs = builder->emitLoad(rhs); + rhs = builder->emitDifferentialPairGetPrimal(rhs); + } + + IRInst* leftGrad; + IRInst* rightGrad; + + + switch (origArith->getOp()) + { + case kIROp_Add: + leftGrad = grad; + rightGrad = grad; + break; + case kIROp_Mul: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); + break; + case kIROp_Sub: + leftGrad = grad; + rightGrad = builder->emitNeg(grad->getDataType(), grad); + break; + case kIROp_Div: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad + break; + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + + lhs = origArith->getOperand(0); + rhs = origArith->getOperand(1); + if (auto leftGrads = upperGradients.TryGetValue(lhs)) + { + leftGrads->add(leftGrad); + } + else + { + upperGradients.Add(lhs, leftGrad); + } + if (auto rightGrads = upperGradients.TryGetValue(rhs)) + { + rightGrads->add(rightGrad); + } + else + { + upperGradients.Add(rhs, rightGrad); + } + + return nullptr; + } + + InstPair copyInst(IRBuilder* builder, IRInst* origInst) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParam(builder, as(origInst)); + + case kIROp_Return: + return InstPair(nullptr, nullptr); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return copyBinaryArith(builder, origInst); + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return InstPair(nullptr, nullptr); + } + + IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) + { + IRInOutType* inoutParam = as(param->getDataType()); + auto pairType = as(inoutParam->getValueType()); + auto paramValue = builder->emitLoad(param); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + auto diff = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + paramValue + ); + auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); + auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); + auto store = builder->emitStore(param, updatedParam); + + return store; + } + + IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParamBackward(builder, as(origInst), grad); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return transcribeBinaryArithBackward(builder, origInst, grad); + + case kIROp_DifferentialPairGetPrimal: + { + if (auto param = primalToDiffPair.TryGetValue(origInst)) + { + if (auto leftGrads = upperGradients.TryGetValue(*param)) + { + leftGrads->add(grad); + } + else + { + upperGradients.Add(*param, grad); + } + } + else + SLANG_ASSERT(0); + return nullptr; + } + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return nullptr; + } + + +}; + +struct ReverseDerivativePass : public InstPassBase +{ + + DiagnosticSink* getSink() + { + return sink; + } + + bool processModule() + { + + IRBuilder builderStorage(autodiffContext->sharedBuilder); + IRBuilder* builder = &builderStorage; + + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by + // generating derivative code for the referenced function. + // + bool modified = processReferencedFunctions(builder); + + return modified; + } + + IRInst* lookupJVPReference(IRInst* primalFunction) + { + if (auto jvpDefinition = primalFunction->findDecoration()) + return jvpDefinition->getForwardDerivativeFunc(); + return nullptr; + } + + // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), + // then check that the referenced function is marked correctly for differentiation. + // + bool processReferencedFunctions(IRBuilder* builder) + { + List autoDiffWorkList; + + for (;;) + { + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_BackwardDifferentiate: + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + autoDiffWorkList.add(inst); + break; + default: + break; + } + break; + default: + break; + } + }); + + if (autoDiffWorkList.getCount() == 0) + break; + + // Process collected `ForwardDifferentiate` insts and replace them with placeholders for + // differentiated functions. + + backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); + + for (auto differentiateInst : autoDiffWorkList) + { + IRInst* baseInst = differentiateInst->getOperand(0); + if (as(differentiateInst)) + { + if (isMarkedForBackwardDifferentiation(baseInst)) + { + if (as(baseInst)) + { + IRInst* diffFunc = + backwardTranscriberStorage + .transcribeFuncHeader(builder, (IRFunc*)baseInst) + .differential; + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } + } + } + } + + auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) + { + auto diffFunc = as(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as(task.primal); + SLANG_ASSERT(primalFunc); + + backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); + } + + // Transcribing the function body really shouldn't produce more follow up function body work. + // However it may produce new `ForwardDifferentiate` instructions, which we collect and process + // in the next iteration. + SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0); + + } + return true; + } + + // Checks decorators to see if the function should + // be differentiated (kIROp_ForwardDifferentiableDecoration) + // + bool isMarkedForBackwardDifferentiation(IRInst* callable) + { + return callable->findDecoration() != nullptr; + } + + IRStringLit* getBackwardDerivativeFuncName(IRInst* func) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = func->findDecoration()) + { + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration()) + { + name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); + } + + return name; + } + + ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : + InstPassBase(context->moduleInst->getModule()), + sink(sink), + backwardTranscriberStorage(context, context->sharedBuilder, sink), + autodiffContext(context), + pairBuilderStorage(context) + { + backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; + } + +protected: + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + BackwardDiffTranscriber backwardTranscriberStorage; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Builder for dealing with differential pair types. + DifferentialPairTypeBuilder pairBuilderStorage; + + // Autodiff Shared Context + AutoDiffSharedContext* autodiffContext; +}; + +bool processReverseDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + IRReverseDerivativePassOptions const&) +{ + ReverseDerivativePass revPass(autodiffContext, sink); + bool changed = revPass.processModule(); + return changed; +} + +} \ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h new file mode 100644 index 000000000..c3d31e2a9 --- /dev/null +++ b/source/slang/slang-ir-autodiff-rev.h @@ -0,0 +1,25 @@ +// slang-ir-autodiff-rev.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" + +namespace Slang +{ + +struct IRReverseDerivativePassOptions +{ + // Nothing for now.. +}; + +bool processReverseDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions()); + + +} \ No newline at end of file diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp new file mode 100644 index 000000000..313760d85 --- /dev/null +++ b/source/slang/slang-ir-autodiff.cpp @@ -0,0 +1,408 @@ +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-rev.h" +#include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-pairs.h" + +namespace Slang +{ + +// TODO: Put into a nameless namespace. +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +{ + if (auto witnessTable = as(witness)) + { + for (auto entry : witnessTable->getEntries()) + { + if (entry->getRequirementKey() == requirementKey) + return entry->getSatisfyingVal(); + } + } + else if (auto witnessTableParam = as(witness)) + { + return builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + witnessTableParam, + requirementKey); + } + return nullptr; +} + +IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key) +{ + if (auto irStructType = as(type)) + { + for (auto field : irStructType->getFields()) + { + if (field->getKey() == key) + { + return field; + } + } + } + else if (auto irSpecialize = as(type)) + { + if (auto irGeneric = as(irSpecialize->getBase())) + { + if (auto irGenericStructType = as(findInnerMostGenericReturnVal(irGeneric))) + { + return findField(irGenericStructType, key); + } + } + } + + return nullptr; +} + +IRInst* DifferentialPairTypeBuilder::findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) +{ + // Get base generic that's being specialized. + auto genericType = as(as(specializeInst)->getBase()); + SLANG_ASSERT(genericType); + + // Find the index of genericParam in the base generic. + int paramIndex = -1; + int currentIndex = 0; + for (auto param : genericType->getParams()) + { + if (param == genericParam) + paramIndex = currentIndex; + currentIndex ++; + } + + SLANG_ASSERT(paramIndex >= 0); + + // Return the corresponding operand in the specialization inst. + return specializeInst->getOperand(1 + paramIndex); +} + +IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) +{ + IRInst* pairType = nullptr; + if (auto basePtrType = as(baseInst->getDataType())) + { + auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); + + // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, + // especially since we probably don't need diff bottom anymore. + // + SLANG_ASSERT(!baseTypeInfo.isTrivial); + + pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); + } + else + { + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + pairType = baseTypeInfo.loweredType; + } + + if (auto basePairStructType = as(pairType)) + { + return as(builder->emitFieldExtract( + findField(basePairStructType, key)->getFieldType(), + baseInst, + key + )); + } + else if (auto ptrType = as(pairType)) + { + if (auto ptrInnerSpecializedType = as(ptrType->getValueType())) + { + auto genericType = findInnerMostGenericReturnVal(as(ptrInnerSpecializedType->getBase())); + if (auto genericBasePairStructType = as(genericType)) + { + return as(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + ptrInnerSpecializedType, + findField(ptrInnerSpecializedType, key)->getFieldType())), + baseInst, + key + )); + } + } + else if (auto ptrBaseStructType = as(ptrType->getValueType())) + { + return as(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findField(ptrBaseStructType, key)->getFieldType()), + baseInst, + key)); + } + } + else if (auto specializedType = as(pairType)) + { + // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's + // type, emit the specialization type. + // + auto genericType = findInnerMostGenericReturnVal(as(specializedType->getBase())); + if (auto genericBasePairStructType = as(genericType)) + { + return as(builder->emitFieldExtract( + (IRType*)findSpecializationForParam( + specializedType, + findField(genericBasePairStructType, key)->getFieldType()), + baseInst, + key + )); + } + else if (auto genericPtrType = as(genericType)) + { + if (auto genericPairStructType = as(genericPtrType->getValueType())) + { + return as(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + specializedType, + findField(genericPairStructType, key)->getFieldType())), + baseInst, + key + )); + } + } + } + else + { + SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); + } + return nullptr; +} + +IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) +{ + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); +} + +IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) +{ + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); +} + +IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey() +{ + if (!this->globalDiffKey) + { + IRBuilder builder(sharedContext->sharedBuilder); + // Insert directly at top level (skip any generic scopes etc.) + builder.setInsertInto(sharedContext->moduleInst); + + this->globalDiffKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); + } + + return this->globalDiffKey; +} + +IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey() +{ + if (!this->globalPrimalKey) + { + // Insert directly at top level (skip any generic scopes etc.) + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertInto(sharedContext->moduleInst); + + this->globalPrimalKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); + } + + return this->globalPrimalKey; +} + +IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType) +{ + switch (origBaseType->getOp()) + { + case kIROp_lookup_interface_method: + case kIROp_Specialize: + case kIROp_Param: + return nullptr; + default: + break; + } + + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertBefore(diffType); + + auto pairStructType = builder.createStructType(); + builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); + builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); + return pairStructType; +} + +IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); +} + +IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); +} + +DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType( + IRBuilder* builder, IRType* originalPairType) +{ + LoweredPairTypeInfo result = {}; + + if (pairTypeCache.TryGetValue(originalPairType, result)) + return result; + auto pairType = as(originalPairType); + if (!pairType) + { + result.isTrivial = true; + result.loweredType = originalPairType; + return result; + } + auto primalType = pairType->getValueType(); + if (as(primalType)) + { + result.isTrivial = false; + result.loweredType = nullptr; + return result; + } + + auto diffType = getDiffTypeFromPairType(builder, pairType); + if (!diffType) + return result; + result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); + result.isTrivial = false; + pairTypeCache.Add(originalPairType, result); + + return result; +} + + +AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) +{ + differentiableInterfaceType = as(findDifferentiableInterface()); + if (differentiableInterfaceType) + { + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); + + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; + } +} + +IRInst* AutoDiffSharedContext::findDifferentiableInterface() +{ + if (auto module = as(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: This seems like a particularly dangerous way to look for an interface. + // See if we can lower IDifferentiable to a separate IR inst. + // + if (globalInst->getOp() == kIROp_InterfaceType && + as(globalInst)->findDecoration()->getName() == "IDifferentiable") + { + return globalInst; + } + } + } + return nullptr; +} + +IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +{ + if (as(moduleInst) && differentiableInterfaceType) + { + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); + if (auto entry = as(differentiableInterfaceType->getOperand(index))) + return as(entry->getRequirementKey()); + else + { + SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); + } + } + + return nullptr; +} + + +void stripAutoDiffDecorationsFromChildren(IRInst* parent) +{ + for (auto inst : parent->getChildren()) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_DerivativeMemberDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + + if (inst->getFirstChild() != nullptr) + { + stripAutoDiffDecorationsFromChildren(inst); + } + } +} + +void stripAutoDiffDecorations(IRModule* module) +{ + stripAutoDiffDecorationsFromChildren(module->getModuleInst()); +} + +bool processAutodiffCalls( + IRModule* module, + DiagnosticSink* sink, + IRAutodiffPassOptions const&) +{ + // Simplify module to remove dead code. + IRDeadCodeEliminationOptions dceOptions; + dceOptions.keepExportsAlive = true; + dceOptions.keepLayoutsAlive = true; + eliminateDeadCode(module, dceOptions); + + bool modified = false; + + // Create shared context for all auto-diff related passes + AutoDiffSharedContext autodiffContext(module->getModuleInst()); + + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + autodiffContext.sharedBuilder = &sharedBuilder; + + // Process forward derivative calls. + modified |= processForwardDerivativeCalls(&autodiffContext, sink); + + // Process reverse derivative calls. + modified |= processReverseDerivativeCalls(&autodiffContext, sink); + + // Replaces IRDifferentialPairType with an auto-generated struct, + // IRDifferentialPairGetDifferential with 'differential' field access, + // IRDifferentialPairGetPrimal with 'primal' field access, and + // IRMakeDifferentialPair with an IRMakeStruct. + // + modified |= processPairTypes(&autodiffContext); + + // Remove auto-diff related decorations. + stripAutoDiffDecorations(module); + + return modified; +} + + +} \ No newline at end of file diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h new file mode 100644 index 000000000..e470044a4 --- /dev/null +++ b/source/slang/slang-ir-autodiff.h @@ -0,0 +1,210 @@ +// slang-ir-autodiff-fwd.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct AutoDiffSharedContext +{ + IRModuleInst* moduleInst = nullptr; + + SharedIRBuilder* sharedBuilder = nullptr; + + // A reference to the builtin IDifferentiable interface type. + // We use this to look up all the other types (and type exprs) + // that conform to a base type. + // + IRInterfaceType* differentiableInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferential. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferential`. + IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + + + // The struct key for the 'zero()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of zero() for a given type. + // + IRStructKey* zeroMethodStructKey = nullptr; + + // The struct key for the 'add()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of add() for a given type. + // + IRStructKey* addMethodStructKey = nullptr; + + IRStructKey* mulMethodStructKey = nullptr; + + + // Modules that don't use differentiable types + // won't have the IDifferentiable interface type available. + // Set to false to indicate that we are uninitialized. + // + bool isInterfaceAvailable = false; + + + AutoDiffSharedContext(IRModuleInst* inModuleInst); + +private: + + IRInst* findDifferentiableInterface(); + + IRStructKey* findDifferentialTypeStructKey() + { + return getIDifferentiableStructKeyAtIndex(0); + } + + IRStructKey* findDifferentialTypeWitnessStructKey() + { + return getIDifferentiableStructKeyAtIndex(1); + } + + IRStructKey* findZeroMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(2); + } + + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(3); + } + + IRStructKey* findMulMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(4); + } + + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); +}; + +struct DifferentiableTypeConformanceContext +{ + AutoDiffSharedContext* sharedContext; + + IRGlobalValueWithCode* parentFunc = nullptr; + OrderedDictionary differentiableWitnessDictionary; + + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) + : sharedContext(shared) + {} + + void setFunc(IRGlobalValueWithCode* func); + + void buildGlobalWitnessDictionary(); + + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type); + + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; + } + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + } + + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + } + +}; + +struct DifferentialPairTypeBuilder +{ + struct LoweredPairTypeInfo + { + IRInst* loweredType; + bool isTrivial; + }; + + DifferentialPairTypeBuilder() = default; + + DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {} + + IRStructField* findField(IRInst* type, IRStructKey* key); + + IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam); + + IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key); + + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst); + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst); + + IRStructKey* _getOrCreateDiffStructKey(); + + IRStructKey* _getOrCreatePrimalStructKey(); + + IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType); + + IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type); + + IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type); + + LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); + + + Dictionary pairTypeCache; + + IRStructKey* globalPrimalKey = nullptr; + + IRStructKey* globalDiffKey = nullptr; + + IRInst* genericDiffPairType = nullptr; + + List generatedTypeList; + + AutoDiffSharedContext* sharedContext = nullptr; +}; + +void stripAutoDiffDecorations(IRModule* module); + +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); + +struct IRAutodiffPassOptions +{ + // Nothing for now... +}; + +bool processAutodiffCalls( + IRModule* module, + DiagnosticSink* sink, + IRAutodiffPassOptions const& options = IRAutodiffPassOptions()); + +}; \ No newline at end of file diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 44c6324e3..f4f61d7e9 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -1,6 +1,6 @@ #include "slang-ir-check-differentiability.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-inst-pass-base.h" namespace Slang diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp deleted file mode 100644 index c9ca687e4..000000000 --- a/source/slang/slang-ir-diff-jvp.cpp +++ /dev/null @@ -1,3197 +0,0 @@ -// slang-ir-diff-jvp.cpp -#include "slang-ir-diff-jvp.h" - -#include "slang-ir-clone.h" -#include "slang-ir-dce.h" -#include "slang-ir-eliminate-phis.h" -#include "slang-ir-util.h" -#include "slang-ir-inst-pass-base.h" - -// origX, primalX, diffX -// origX -> primalX (cloneEnv) -// origX -> diffX (instMapD) - -namespace Slang -{ - -namespace -{ - -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) -{ - if (auto witnessTable = as(witness)) - { - for (auto entry : witnessTable->getEntries()) - { - if (entry->getRequirementKey() == requirementKey) - return entry->getSatisfyingVal(); - } - } - else if (auto witnessTableParam = as(witness)) - { - return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), - witnessTableParam, - requirementKey); - } - return nullptr; -} - -} - -struct DifferentialPairTypeBuilder -{ - - IRStructField* findField(IRInst* type, IRStructKey* key) - { - if (auto irStructType = as(type)) - { - for (auto field : irStructType->getFields()) - { - if (field->getKey() == key) - { - return field; - } - } - } - else if (auto irSpecialize = as(type)) - { - if (auto irGeneric = as(irSpecialize->getBase())) - { - if (auto irGenericStructType = as(findInnerMostGenericReturnVal(irGeneric))) - { - return findField(irGenericStructType, key); - } - } - } - - return nullptr; - } - - IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) - { - // Get base generic that's being specialized. - auto genericType = as(as(specializeInst)->getBase()); - SLANG_ASSERT(genericType); - - // Find the index of genericParam in the base generic. - int paramIndex = -1; - int currentIndex = 0; - for (auto param : genericType->getParams()) - { - if (param == genericParam) - paramIndex = currentIndex; - currentIndex ++; - } - - SLANG_ASSERT(paramIndex >= 0); - - // Return the corresponding operand in the specialization inst. - return specializeInst->getOperand(1 + paramIndex); - } - - IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) - { - IRInst* pairType = nullptr; - if (auto basePtrType = as(baseInst->getDataType())) - { - auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); - - // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, - // especially since we probably don't need diff bottom anymore. - // - SLANG_ASSERT(!baseTypeInfo.isTrivial); - - pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); - } - else - { - auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); - if (baseTypeInfo.isTrivial) - { - if (key == globalPrimalKey) - return baseInst; - else - return builder->getDifferentialBottom(); - } - - pairType = baseTypeInfo.loweredType; - } - - if (auto basePairStructType = as(pairType)) - { - return as(builder->emitFieldExtract( - findField(basePairStructType, key)->getFieldType(), - baseInst, - key - )); - } - else if (auto ptrType = as(pairType)) - { - if (auto ptrInnerSpecializedType = as(ptrType->getValueType())) - { - auto genericType = findInnerMostGenericReturnVal(as(ptrInnerSpecializedType->getBase())); - if (auto genericBasePairStructType = as(genericType)) - { - return as(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findSpecializationForParam( - ptrInnerSpecializedType, - findField(ptrInnerSpecializedType, key)->getFieldType())), - baseInst, - key - )); - } - } - else if (auto ptrBaseStructType = as(ptrType->getValueType())) - { - return as(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findField(ptrBaseStructType, key)->getFieldType()), - baseInst, - key)); - } - } - else if (auto specializedType = as(pairType)) - { - // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's - // type, emit the specialization type. - // - auto genericType = findInnerMostGenericReturnVal(as(specializedType->getBase())); - if (auto genericBasePairStructType = as(genericType)) - { - return as(builder->emitFieldExtract( - (IRType*)findSpecializationForParam( - specializedType, - findField(genericBasePairStructType, key)->getFieldType()), - baseInst, - key - )); - } - else if (auto genericPtrType = as(genericType)) - { - if (auto genericPairStructType = as(genericPtrType->getValueType())) - { - return as(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findSpecializationForParam( - specializedType, - findField(genericPairStructType, key)->getFieldType())), - baseInst, - key - )); - } - } - } - else - { - SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); - } - return nullptr; - } - - IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) - { - return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); - } - - IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) - { - return emitFieldAccessor(builder, baseInst, this->globalDiffKey); - } - - IRStructKey* _getOrCreateDiffStructKey() - { - if (!this->globalDiffKey) - { - IRBuilder builder(sharedContext->sharedBuilder); - // Insert directly at top level (skip any generic scopes etc.) - builder.setInsertInto(sharedContext->moduleInst); - - this->globalDiffKey = builder.createStructKey(); - builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); - } - - return this->globalDiffKey; - } - - IRStructKey* _getOrCreatePrimalStructKey() - { - if (!this->globalPrimalKey) - { - // Insert directly at top level (skip any generic scopes etc.) - IRBuilder builder(sharedContext->sharedBuilder); - builder.setInsertInto(sharedContext->moduleInst); - - this->globalPrimalKey = builder.createStructKey(); - builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); - } - - return this->globalPrimalKey; - } - - IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType) - { - switch (origBaseType->getOp()) - { - case kIROp_lookup_interface_method: - case kIROp_Specialize: - case kIROp_Param: - return nullptr; - default: - break; - } - if (diffType->getOp() != kIROp_DifferentialBottomType) - { - IRBuilder builder(sharedContext->sharedBuilder); - builder.setInsertBefore(diffType); - - auto pairStructType = builder.createStructType(); - builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); - builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); - return pairStructType; - } - return origBaseType; - } - - struct LoweredPairTypeInfo - { - IRInst* loweredType; - bool isTrivial; - }; - - IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) - { - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) - { - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); - } - - LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) - { - LoweredPairTypeInfo result = {}; - - if (pairTypeCache.TryGetValue(originalPairType, result)) - return result; - auto pairType = as(originalPairType); - if (!pairType) - { - result.isTrivial = true; - result.loweredType = originalPairType; - return result; - } - auto primalType = pairType->getValueType(); - if (as(primalType)) - { - result.isTrivial = false; - result.loweredType = nullptr; - return result; - } - - auto diffType = getDiffTypeFromPairType(builder, pairType); - if (!diffType) - return result; - result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); - pairTypeCache.Add(originalPairType, result); - - return result; - } - - Dictionary pairTypeCache; - - IRStructKey* globalPrimalKey = nullptr; - - IRStructKey* globalDiffKey = nullptr; - - IRInst* genericDiffPairType = nullptr; - - List generatedTypeList; - - AutoDiffSharedContext* sharedContext = nullptr; -}; - -struct JVPTranscriber -{ - - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary instMapD; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List followUpFunctionsToTranscribe; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. - IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary differentialPairTypes; - - JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) - : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) - { - - } - - DiagnosticSink* getSink() - { - SLANG_ASSERT(sink); - return sink; - } - - void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) - { - if (hasDifferentialInst(origInst)) - { - if (lookupDiffInst(origInst) != diffInst) - { - SLANG_UNEXPECTED("Inconsistent differential mappings"); - } - } - else - { - instMapD.Add(origInst, diffInst); - } - } - - void mapPrimalInst(IRInst* origInst, IRInst* primalInst) - { - if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) - { - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "inconsistent primal instruction for original"); - } - else - { - cloneEnv.mapOldValToNew[origInst] = primalInst; - } - } - - IRInst* lookupDiffInst(IRInst* origInst) - { - return instMapD[origInst]; - } - - IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst) - { - return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst; - } - - bool hasDifferentialInst(IRInst* origInst) - { - return instMapD.ContainsKey(origInst); - } - - IRInst* lookupPrimalInst(IRInst* origInst) - { - return cloneEnv.mapOldValToNew[origInst]; - } - - IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) - { - return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; - } - - bool hasPrimalInst(IRInst* origInst) - { - return cloneEnv.mapOldValToNew.ContainsKey(origInst); - } - - IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) - { - if (!hasDifferentialInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasDifferentialInst(origInst)); - } - - return lookupDiffInst(origInst); - } - - IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) - { - if (!hasPrimalInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasPrimalInst(origInst)); - } - - return lookupPrimalInst(origInst); - } - - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) - { - List newParameterTypes; - IRType* diffReturnType; - - for (UIndex i = 0; i < funcType->getParamCount(); i++) - { - auto origType = funcType->getParamType(i); - origType = (IRType*) lookupPrimalInst(origType, origType); - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - newParameterTypes.add(diffPairType); - else - newParameterTypes.add(origType); - } - - // Transcribe return type to a pair. - // This will be void if the primal return type is non-differentiable. - // - auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType()); - if (auto returnPairType = tryGetDiffPairType(builder, origResultType)) - diffReturnType = returnPairType; - else - diffReturnType = builder->getVoidType(); - - return builder->getFuncType(newParameterTypes, diffReturnType); - } - - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as(inDiffPairType); - SLANG_ASSERT(diffPairType); - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = differentiateType(&builder, diffPairType); - - // And place it in the synthesized witness table. - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - - // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - - return table; - } - - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* getOrCreateDiffPairType(IRInst* primalType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - - if (!witness) - { - if (auto primalPairType = as(primalType)) - { - witness = getDifferentialPairWitness(primalPairType); - } - else - { - witness = getDifferentialBottomWitness(); - } - } - - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* differentiateType(IRBuilder* builder, IRType* origType) - { - IRInst* diffType = nullptr; - if (!instMapD.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - instMapD[origType] = diffType; - } - return (IRType*)diffType; - } - - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) - { - if (auto ptrType = as(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); - - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = lookupPrimalInst(origType, origType); - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) - { - case kIROp_Param: - if (as(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } - - case kIROp_DifferentialPairType: - { - auto primalPairType = as(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as(primalType)); - - case kIROp_OutType: - if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; - - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; - - case kIROp_TupleType: - { - auto tupleType = as(primalType); - List diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } - - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); - } - } - - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) - { - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. - // - if (auto origPtrType = as(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; - } - - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); - // Do not differentiate generic type (and witness table) parameters - if (as(primalDataType) || as(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - // Is this param a phi node or a function parameter? - auto func = as(origParam->getParent()->getParent()); - bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); - if (isFuncParam) - { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as(diffPairType)) - { - auto ptrInnerPairType = as(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } - } - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - else - { - auto primal = cloneInst(&cloneEnv, builder, origParam); - IRInst* diff = nullptr; - if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) - { - diff = builder->emitParam(diffType); - } - return InstPair(primal, diff); - } - - } - - // Returns "d" to use as a name hint for variables and parameters. - // If no primal name is available, returns a blank string. - // - String getJVPVarName(IRInst* origVar) - { - if (auto namehintDecoration = origVar->findDecoration()) - { - return ("d" + String(namehintDecoration->getName())); - } - - return String(""); - } - - // Returns "dp" to use as a name hint for parameters. - // If no primal name is available, returns a blank string. - // - String makeDiffPairName(IRInst* origVar) - { - if (auto namehintDecoration = origVar->findDecoration()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); - } - - InstPair transcribeVar(IRBuilder* builder, IRVar* origVar) - { - if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) - { - IRVar* diffVar = builder->emitVar(diffType); - SLANG_ASSERT(diffVar); - - auto diffNameHint = getJVPVarName(origVar); - if (diffNameHint.getLength() > 0) - builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); - - return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); - } - - return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); - } - - InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); - - auto primalLeft = findOrTranscribePrimalInst(builder, origLeft); - auto primalRight = findOrTranscribePrimalInst(builder, origRight); - - auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); - auto diffRight = findOrTranscribeDiffInst(builder, origRight); - - - if (diffLeft || diffRight) - { - diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); - diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); - - auto resultType = primalArith->getDataType(); - switch(origArith->getOp()) - { - case kIROp_Add: - return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); - case kIROp_Mul: - return InstPair(primalArith, builder->emitAdd(resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight))); - case kIROp_Sub: - return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); - case kIROp_Div: - return InstPair(primalArith, builder->emitDiv(resultType, - builder->emitSub( - resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight)), - builder->emitMul( - primalRight->getDataType(), primalRight, primalRight - ))); - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - } - - return InstPair(primalArith, nullptr); - } - - - InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) - { - SLANG_ASSERT(origLogic->getOperandCount() == 2); - - // TODO: Check other boolean cases. - if (as(origLogic->getDataType())) - { - // Boolean operations are not differentiable. For the linearization - // pass, we do not need to do anything but copy them over to the ne - // function. - auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); - return InstPair(primalLogic, nullptr); - } - - SLANG_UNEXPECTED("Logical operation with non-boolean result"); - } - - InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) - { - auto origPtr = origLoad->getPtr(); - auto primalPtr = lookupPrimalInst(origPtr, nullptr); - auto primalPtrValueType = as(primalPtr->getFullType())->getValueType(); - - if (auto diffPairType = as(primalPtrValueType)) - { - // Special case load from an `out` param, which will not have corresponding `diff` and - // `primal` insts yet. - - auto load = builder->emitLoad(primalPtr); - auto primalElement = builder->emitDifferentialPairGetPrimal(load); - auto diffElement = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); - return InstPair(primalElement, diffElement); - } - - auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); - IRInst* diffLoad = nullptr; - if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) - { - // Default case, we're loading from a known differential inst. - diffLoad = as(builder->emitLoad(diffPtr)); - } - return InstPair(primalLoad, diffLoad); - } - - InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) - { - IRInst* origStoreLocation = origStore->getPtr(); - IRInst* origStoreVal = origStore->getVal(); - auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); - auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); - auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); - auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); - - if (!diffStoreLocation) - { - auto primalLocationPtrType = as(primalStoreLocation->getDataType()); - if (auto diffPairType = as(primalLocationPtrType->getValueType())) - { - auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); - auto store = builder->emitStore(primalStoreLocation, valToStore); - return InstPair(store, nullptr); - } - } - - auto primalStore = cloneInst(&cloneEnv, builder, origStore); - - IRInst* diffStore = nullptr; - - // If the stored value has a differential version, - // emit a store instruction for the differential parameter. - // Otherwise, emit nothing since there's nothing to load. - // - if (diffStoreLocation && diffStoreVal) - { - // Default case, storing the entire type (and not a member) - diffStore = as( - builder->emitStore(diffStoreLocation, diffStoreVal)); - - return InstPair(primalStore, diffStore); - } - - return InstPair(primalStore, nullptr); - } - - InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) - { - IRInst* origReturnVal = origReturn->getVal(); - - auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); - if (as(origReturnVal) || as(origReturnVal) || as(origReturnVal) || as(origReturnVal)) - { - // If the return value is itself a function, generic or a struct then this - // is likely to be a generic scope. In this case, we lookup the differential - // and return that. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - - // Neither of these should be nullptr. - SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); - IRReturn* diffReturn = as(builder->emitReturn(diffReturnVal)); - - return InstPair(diffReturn, diffReturn); - } - else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) - { - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - if(!diffReturnVal) - diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); - - // If the pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffReturnVal); - - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); - IRReturn* pairReturn = as(builder->emitReturn(diffPair)); - return InstPair(pairReturn, pairReturn); - } - else - { - // If the return type is not differentiable, emit the primal value only. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - - IRInst* primalReturn = builder->emitReturn(primalReturnVal); - return InstPair(primalReturn, nullptr); - - } - } - - // Since int/float literals are sometimes nested inside an IRConstructor - // instruction, we check to make sure that the nested instr is a constant - // and then return nullptr. Literals do not need to be differentiated. - // - InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) - { - IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); - - // Check if the output type can be differentiated. If it cannot be - // differentiated, don't differentiate the inst - // - auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); - if (auto diffConstructType = differentiateType(builder, primalConstructType)) - { - UCount operandCount = origConstruct->getOperandCount(); - - List diffOperands; - for (UIndex ii = 0; ii < operandCount; ii++) - { - // If the operand has a differential version, replace the original with - // the differential. Otherwise, use a zero. - // - if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) - diffOperands.add(diffInst); - else - { - auto operandDataType = origConstruct->getOperand(ii)->getDataType(); - operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); - diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); - } - } - - return InstPair( - primalConstruct, - builder->emitIntrinsicInst( - diffConstructType, - origConstruct->getOp(), - operandCount, - diffOperands.getBuffer())); - } - else - { - return InstPair(primalConstruct, nullptr); - } - } - - // Differentiating a call instruction here is primarily about generating - // an appropriate call list based on whichever parameters have differentials - // in the current transcription context. - // - InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - - IRInst* origCallee = origCall->getCallee(); - - if (!origCallee) - { - // Note that this can only happen if the callee is a result - // of a higher-order operation. For now, we assume that we cannot - // differentiate such calls safely. - // TODO(sai): Should probably get checked in the front-end. - // - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::internalCompilerError, - "attempting to differentiate unresolved callee"); - - return InstPair(nullptr, nullptr); - } - - // Since concrete functions are globals, the primal callee is the same - // as the original callee. - // - auto primalCallee = origCallee; - - IRInst* diffCallee = nullptr; - - if (auto derivativeReferenceDecor = primalCallee->findDecoration()) - { - // If the user has already provided an differentiated implementation, use that. - diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); - } - else if (primalCallee->findDecoration()) - { - // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass - // to generate the implementation. - diffCallee = builder->emitForwardDifferentiateInst( - differentiateFunctionType(builder, as(primalCallee->getFullType())), - primalCallee); - } - else - { - // The callee is non differentiable, just return primal value with null diff value. - IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); - return InstPair(primalCall, nullptr); - } - - List args; - // Go over the parameter list and create pairs for each input (if required) - for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) - { - auto origArg = origCall->getArg(ii); - auto primalArg = findOrTranscribePrimalInst(builder, origArg); - SLANG_ASSERT(primalArg); - - auto primalType = primalArg->getDataType(); - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - - if (auto pairType = tryGetDiffPairType(builder, primalType)) - { - // If a pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - args.add(diffPair); - } - else - { - // Add original/primal argument. - args.add(primalArg); - } - } - - IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - - if (!diffReturnType) - { - SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); - diffReturnType = builder->getVoidType(); - } - - auto callInst = builder->emitCallInst( - diffReturnType, - diffCallee, - args); - - if (diffReturnType->getOp() != kIROp_VoidType) - { - IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); - auto diffType = differentiateType(builder, origCall->getFullType()); - IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); - return InstPair(primalResultValue, diffResultValue); - } - else - { - // Return the inst itself if the return value is void. - // This is fine since these values should never actually be used anywhere. - // - return InstPair(callInst, callInst); - } - } - - InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) - { - IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); - - if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) - { - List swizzleIndices; - for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) - swizzleIndices.add(origSwizzle->getElementIndex(ii)); - - return InstPair( - primalSwizzle, - builder->emitSwizzle( - differentiateType(builder, primalSwizzle->getDataType()), - diffBase, - origSwizzle->getElementCount(), - swizzleIndices.getBuffer())); - } - - return InstPair(primalSwizzle, nullptr); - } - - InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) - { - IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); - - UCount operandCount = origInst->getOperandCount(); - - List diffOperands; - for (UIndex ii = 0; ii < operandCount; ii++) - { - // If the operand has a differential version, replace the original with the - // differential. - // Otherwise, abandon the differentiation attempt and assume that origInst - // cannot (or does not need to) be differentiated. - // - if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr)) - diffOperands.add(diffInst); - else - return InstPair(primalInst, nullptr); - } - - return InstPair( - primalInst, - builder->emitIntrinsicInst( - differentiateType(builder, primalInst->getDataType()), - origInst->getOp(), - operandCount, - diffOperands.getBuffer())); - } - - InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst) - { - switch(origInst->getOp()) - { - case kIROp_unconditionalBranch: - case kIROp_loop: - auto origBranch = as(origInst); - - // Grab the differentials for any phi nodes. - List newArgs; - for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) - { - auto origArg = origBranch->getArg(ii); - auto primalArg = lookupPrimalInst(origArg); - newArgs.add(primalArg); - - if (differentiateType(builder, primalArg->getDataType())) - { - auto diffArg = lookupDiffInst(origArg, nullptr); - if (diffArg) - newArgs.add(diffArg); - } - } - - IRInst* diffBranch = nullptr; - if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) - { - if (auto origLoop = as(origInst)) - { - auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - List operands; - operands.add(breakBlock); - operands.add(continueBlock); - operands.addRange(newArgs); - diffBranch = builder->emitIntrinsicInst( - nullptr, - kIROp_loop, - operands.getCount(), - operands.getBuffer()); - } - else - { - diffBranch = builder->emitBranch( - as(diffBlock), - newArgs.getCount(), - newArgs.getBuffer()); - } - } - - // For now, every block in the original fn must have a corresponding - // block to compute *both* primals and derivatives (i.e linearized block) - SLANG_ASSERT(diffBranch); - - return InstPair(diffBranch, diffBranch); - } - - getSink()->diagnose( - origInst->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unhandled control flow"); - - return InstPair(nullptr, nullptr); - } - - InstPair transcribeConst(IRBuilder* builder, IRInst* origInst) - { - switch(origInst->getOp()) - { - case kIROp_FloatLit: - return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); - case kIROp_VoidLit: - return InstPair(origInst, origInst); - case kIROp_IntLit: - return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); - } - - getSink()->diagnose( - origInst->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unhandled const type"); - - return InstPair(nullptr, nullptr); - } - - InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize) - { - // In general, we should not see any specialize insts at this stage. - // The exceptions are target intrinsics. - auto genericInnerVal = findInnerMostGenericReturnVal(as(origSpecialize->getBase())); - if (genericInnerVal->findDecoration()) - { - // Look for an IRForwardDerivativeDecoration on the specialize inst. - // (Normally, this would be on the inner IRFunc, but in this case only the JVP func - // can be specialized, so we put a decoration on the IRSpecialize) - // - if (auto jvpFuncDecoration = origSpecialize->findDecoration()) - { - auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); - - // Make sure this isn't itself a specialize . - SLANG_RELEASE_ASSERT(!as(jvpFunc)); - - return InstPair(jvpFunc, jvpFunc); - } - } - else - { - getSink()->diagnose(origSpecialize->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); - } - - return InstPair(nullptr, nullptr); - } - - InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) - { - // This is slightly counter-intuitive, but we don't perform any differentiation - // logic here. We simple clone the original lookup which points to the original function, - // or the cloned version in case we're inside a generic scope. - // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table)) - // rather than have Lookup(ForwardDifferentiate(Table)) - // - auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); - return InstPair(diffLookup, diffLookup); - } - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) - { - if (auto diffType = differentiateType(builder, primalType)) - { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); - - auto emptyArgList = List(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - } - else - { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } - } - - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) - { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); - - IRInst* diffBlock = subBuilder.emitBlock(); - - // Note: for blocks, we setup the mapping _before_ - // processing the children since we could encounter - // a lookup while processing the children. - // - mapPrimalInst(origBlock, diffBlock); - mapDifferentialInst(origBlock, diffBlock); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->transcribe(&subBuilder, param); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->transcribe(&subBuilder, child); - - return InstPair(diffBlock, diffBlock); - } - - InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) - { - SLANG_ASSERT(as(originalInst) || as(originalInst)); - - IRInst* origBase = originalInst->getOperand(0); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto field = originalInst->getOperand(1); - auto derivativeRefDecor = field->findDecoration(); - auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); - - IRInst* primalOperands[] = { primalBase, field }; - IRInst* primalFieldExtract = builder->emitIntrinsicInst( - primalType, - originalInst->getOp(), - 2, - primalOperands); - - if (!derivativeRefDecor) - { - return InstPair(primalFieldExtract, nullptr); - } - - IRInst* diffFieldExtract = nullptr; - - if (auto diffType = differentiateType(builder, primalType)) - { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) - { - IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; - diffFieldExtract = builder->emitIntrinsicInst( - diffType, - originalInst->getOp(), - 2, - diffOperands); - } - } - return InstPair(primalFieldExtract, diffFieldExtract); - } - - InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) - { - SLANG_ASSERT(as(origGetElementPtr) || as(origGetElementPtr)); - - IRInst* origBase = origGetElementPtr->getOperand(0); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); - - auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); - - IRInst* primalOperands[] = {primalBase, primalIndex}; - IRInst* primalGetElementPtr = builder->emitIntrinsicInst( - primalType, - origGetElementPtr->getOp(), - 2, - primalOperands); - - IRInst* diffGetElementPtr = nullptr; - - if (auto diffType = differentiateType(builder, primalType)) - { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) - { - IRInst* diffOperands[] = {diffBase, primalIndex}; - diffGetElementPtr = builder->emitIntrinsicInst( - diffType, - origGetElementPtr->getOp(), - 2, - diffOperands); - } - } - - return InstPair(primalGetElementPtr, diffGetElementPtr); - } - - InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) - { - // The loop comes with three blocks.. we just need to transcribe each one - // and assemble the new loop instruction. - - // Transcribe the target block (this is the 'condition' part of the loop, which - // will branch into the loop body) - auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); - - // Transcribe the break block (this is the block after the exiting the loop) - auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) - auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - - - List diffLoopOperands; - diffLoopOperands.add(diffTargetBlock); - diffLoopOperands.add(diffBreakBlock); - diffLoopOperands.add(diffContinueBlock); - - // If there are any other operands, use their primal versions. - for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); - diffLoopOperands.add(primalOperand); - } - - IRInst* diffLoop = builder->emitIntrinsicInst( - nullptr, - kIROp_loop, - diffLoopOperands.getCount(), - diffLoopOperands.getBuffer()); - - return InstPair(diffLoop, diffLoop); - } - - InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) - { - // IfElse Statements come with 4 blocks. We transcribe each block into it's - // linear form, and then wire them up in the same way as the original if-else - - // Transcribe condition block - auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); - SLANG_ASSERT(primalConditionBlock); - - // Transcribe 'true' block (condition block branches into this if true) - auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); - SLANG_ASSERT(diffTrueBlock); - - // Transcribe 'false' block (condition block branches into this if true) - // TODO (sai): What happens if there's no false block? - auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); - SLANG_ASSERT(diffFalseBlock); - - // Transcribe 'after' block (true and false blocks branch into this) - auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); - SLANG_ASSERT(diffAfterBlock); - - List diffIfElseArgs; - diffIfElseArgs.add(primalConditionBlock); - diffIfElseArgs.add(diffTrueBlock); - diffIfElseArgs.add(diffFalseBlock); - diffIfElseArgs.add(diffAfterBlock); - - // If there are any other operands, use their primal versions. - for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); - diffIfElseArgs.add(primalOperand); - } - - IRInst* diffLoop = builder->emitIntrinsicInst( - nullptr, - kIROp_ifElse, - diffIfElseArgs.getCount(), - diffIfElseArgs.getBuffer()); - - return InstPair(diffLoop, diffLoop); - } - - InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) - { - auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); - SLANG_ASSERT(primalVal); - auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); - SLANG_ASSERT(diffPrimalVal); - auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); - SLANG_ASSERT(primalDiffVal); - auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); - SLANG_ASSERT(diffDiffVal); - - auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); - auto diffPair = builder->emitMakeDifferentialPair( - differentiateType(builder, origInst->getDataType()), - primalDiffVal, - diffDiffVal); - return InstPair(primalPair, diffPair); - } - - InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) - { - SLANG_ASSERT( - origInst->getOp() == kIROp_DifferentialPairGetDifferential || - origInst->getOp() == kIROp_DifferentialPairGetPrimal); - - auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); - SLANG_ASSERT(primalVal); - - auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); - SLANG_ASSERT(diffVal); - - auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); - - auto diffValPairType = as(diffVal->getDataType()); - IRInst* diffResultType = nullptr; - if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) - diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); - else - diffResultType = diffValPairType->getValueType(); - auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); - return InstPair(primalResult, diffResult); - } - - // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); - - IRFunc* primalFunc = origFunc; - - differentiableTypeConformanceContext.setFunc(origFunc); - - primalFunc = origFunc; - - auto diffFunc = builder.createFunc(); - - SLANG_ASSERT(as(origFunc->getFullType())); - IRType* diffFuncType = this->differentiateFunctionType( - &builder, - as(origFunc->getFullType())); - diffFunc->setFullType(diffFuncType); - - if (auto nameHint = origFunc->findDecoration()) - { - auto originalName = nameHint->getName(); - StringBuilder newNameSb; - newNameSb << "s_fwd_" << originalName; - builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); - } - builder.addForwardDerivativeDecoration(origFunc, diffFunc); - - // Mark the generated derivative function itself as differentiable. - builder.addForwardDifferentiableDecoration(diffFunc); - - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration()) - { - cloneDecoration(dictDecor, diffFunc); - } - - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; - } - - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertInto(diffFunc); - - differentiableTypeConformanceContext.setFunc(primalFunc); - // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); - - return InstPair(primalFunc, diffFunc); - } - - // Transcribe a generic definition - InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) - { - auto innerVal = findInnerMostGenericReturnVal(origGeneric); - if (auto innerFunc = as(innerVal)) - { - differentiableTypeConformanceContext.setFunc(innerFunc); - } - else - { - return InstPair(origGeneric, nullptr); - } - - // For now, we assume there's only one generic layer. So this inst must be top level - bool isTopLevel = (as(origGeneric->getParent()) != nullptr); - SLANG_RELEASE_ASSERT(isTopLevel); - - IRGeneric* primalGeneric = origGeneric; - - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origGeneric); - - auto diffGeneric = builder.emitGeneric(); - - // Process type of generic. If the generic is a function, then it's type will also be a - // generic and this logic will transcribe that generic first before continuing with the - // function itself. - // - auto primalType = primalGeneric->getFullType(); - - IRType* diffType = nullptr; - if (primalType) - { - diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); - } - - diffGeneric->setFullType(diffType); - - // TODO(sai): Replace naming scheme - // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) - // builder->addNameHintDecoration(diffFunc, jvpName); - - // Transcribe children from origFunc into diffFunc. - builder.setInsertInto(diffGeneric); - for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); - - return InstPair(primalGeneric, diffGeneric); - } - - IRInst* transcribe(IRBuilder* builder, IRInst* origInst) - { - // If a differential intstruction is already mapped for - // this original inst, return that. - // - if (auto diffInst = lookupDiffInst(origInst, nullptr)) - { - SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. - return diffInst; - } - - // Otherwise, dispatch to the appropriate method - // depending on the op-code. - // - instsInProgress.Add(origInst); - InstPair pair = transcribeInst(builder, origInst); - - if (auto primalInst = pair.primal) - { - mapPrimalInst(origInst, pair.primal); - mapDifferentialInst(origInst, pair.differential); - if (pair.differential) - { - switch (pair.differential->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - case kIROp_Block: - // Don't generate again for these. - // Functions already have their names generated in `transcribeFuncHeader`. - break; - default: - // Generate name hint for the inst. - if (auto primalNameHint = primalInst->findDecoration()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } - break; - } - } - return pair.differential; - } - instsInProgress.Remove(origInst); - - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "failed to transcibe instruction"); - return nullptr; - } - - InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as(origInst)); - - case kIROp_Var: - return transcribeVar(builder, as(origInst)); - - case kIROp_Load: - return transcribeLoad(builder, as(origInst)); - - case kIROp_Store: - return transcribeStore(builder, as(origInst)); - - case kIROp_Return: - return transcribeReturn(builder, as(origInst)); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transcribeBinaryArith(builder, origInst); - - case kIROp_Less: - case kIROp_Greater: - case kIROp_And: - case kIROp_Or: - case kIROp_Geq: - case kIROp_Leq: - return transcribeBinaryLogic(builder, origInst); - - case kIROp_Construct: - return transcribeConstruct(builder, origInst); - - case kIROp_Call: - return transcribeCall(builder, as(origInst)); - - case kIROp_swizzle: - return transcribeSwizzle(builder, as(origInst)); - - case kIROp_constructVectorFromScalar: - case kIROp_MakeTuple: - return transcribeByPassthrough(builder, origInst); - - case kIROp_unconditionalBranch: - return transcribeControlFlow(builder, origInst); - - case kIROp_FloatLit: - case kIROp_IntLit: - case kIROp_VoidLit: - return transcribeConst(builder, origInst); - - case kIROp_Specialize: - return transcribeSpecialize(builder, as(origInst)); - - case kIROp_lookup_interface_method: - return transcibeLookupInterfaceMethod(builder, as(origInst)); - - case kIROp_FieldExtract: - case kIROp_FieldAddress: - return transcribeFieldExtract(builder, origInst); - case kIROp_getElement: - case kIROp_getElementPtr: - return transcribeGetElement(builder, origInst); - - case kIROp_loop: - return transcribeLoop(builder, as(origInst)); - - case kIROp_ifElse: - return transcribeIfElse(builder, as(origInst)); - - case kIROp_MakeDifferentialPair: - return transcribeMakeDifferentialPair(builder, as(origInst)); - case kIROp_DifferentialPairGetPrimal: - case kIROp_DifferentialPairGetDifferential: - return transcribeDifferentialPairGetElement(builder, origInst); - } - - // If none of the cases have been hit, check if the instruction is a - // type. Only need to explicitly differentiate types if they appear inside a block. - // - if (auto origType = as(origInst)) - { - // If this is a generic type, transcibe the parent - // generic and derive the type from the transcribed generic's - // return value. - // - if (as(origType->getParent()->getParent()) && - findInnerMostGenericReturnVal(as(origType->getParent()->getParent())) == origType && - !instsInProgress.Contains(origType->getParent()->getParent())) - { - auto origGenericType = origType->getParent()->getParent(); - auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); - auto innerDiffGenericType = findInnerMostGenericReturnVal(as(diffGenericType)); - return InstPair( - origGenericType, - innerDiffGenericType - ); - } - else if (as(origType->getParent())) - return InstPair( - cloneInst(&cloneEnv, builder, origType), - differentiateType(builder, origType)); - else - return InstPair( - cloneInst(&cloneEnv, builder, origType), - nullptr); - } - - // Handle instructions with children - switch (origInst->getOp()) - { - case kIROp_Func: - return transcribeFuncHeader(builder, as(origInst)); - - case kIROp_Block: - return transcribeBlock(builder, as(origInst)); - - case kIROp_Generic: - return transcribeGeneric(builder, as(origInst)); - } - - - // If we reach this statement, the instruction type is likely unhandled. - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::unimplemented, - "this instruction cannot be differentiated"); - - return InstPair(nullptr, nullptr); - } -}; - - -struct BackwardDiffTranscriber -{ - - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary orginalToTranscribed; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List followUpFunctionsToTranscribe; - - // Map that stores the upper gradient given an IRInst* - Dictionary> upperGradients; - Dictionary primalToDiffPair; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. - IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary differentialPairTypes; - - BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : autoDiffSharedContext(shared) - , sink(inSink) - , differentiableTypeConformanceContext(shared) - , sharedBuilder(inSharedBuilder) - {} - - DiagnosticSink* getSink() - { - SLANG_ASSERT(sink); - return sink; - } - - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) - { - List newParameterTypes; - IRType* diffReturnType; - - for (UIndex i = 0; i < funcType->getParamCount(); i++) - { - auto origType = funcType->getParamType(i); - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - newParameterTypes.add(inoutDiffPairType); - } - else - newParameterTypes.add(origType); - } - - newParameterTypes.add(funcType->getResultType()); - - diffReturnType = builder->getVoidType(); - - return builder->getFuncType(newParameterTypes, diffReturnType); - } - - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as(inDiffPairType); - SLANG_ASSERT(diffPairType); - auto result = - as(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - if (result) - return result; - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - auto diffType = differentiateType(&builder, diffPairType->getValueType()); - auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; - } - - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* getOrCreateDiffPairType(IRInst* primalType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness) - witness = getDifferentialBottomWitness(); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* differentiateType(IRBuilder* builder, IRType* origType) - { - IRInst* diffType = nullptr; - if (!orginalToTranscribed.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - orginalToTranscribed[origType] = diffType; - } - return (IRType*)diffType; - } - - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) - { - if (auto ptrType = as(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); - - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = origType; - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) - { - case kIROp_Param: - if (as(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } - - case kIROp_DifferentialPairType: - { - auto primalPairType = as(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as(primalType)); - - case kIROp_OutType: - if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; - - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; - - case kIROp_TupleType: - { - auto tupleType = as(primalType); - List diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } - - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); - } - } - - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) - { - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. - // - if (auto origPtrType = as(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; - } - - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - // Do not differentiate generic type (and witness table) parameters - if (as(primalDataType) || as(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as(diffPairParam->getDataType())) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); - } - - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - // Returns "dp" to use as a name hint for parameters. - // If no primal name is available, returns a blank string. - // - String makeDiffPairName(IRInst* origVar) - { - if (auto namehintDecoration = origVar->findDecoration()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); - } - - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) - { - if (auto diffType = differentiateType(builder, primalType)) - { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); - - auto emptyArgList = List(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - } - else - { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } - } - - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) - { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); - - IRBlock* diffBlock = subBuilder.emitBlock(); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->copyParam(&subBuilder, param); - - // The extra param for input gradient - auto gradParam = subBuilder.emitParam(as(origBlock->getParent()->getFullType())->getResultType()); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->copyInst(&subBuilder, child); - - auto lastInst = diffBlock->getLastOrdinaryInst(); - List grads = { gradParam }; - upperGradients.Add(lastInst, grads); - for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) - { - auto upperGrads = upperGradients.TryGetValue(child); - if (!upperGrads) - continue; - if (upperGrads->getCount() > 1) - { - auto sumGrad = upperGrads->getFirst(); - for (auto i = 1; i < upperGrads->getCount(); i++) - { - sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); - } - this->transcribeInstBackward(&subBuilder, child, sumGrad); - } - else - this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); - } - - subBuilder.emitReturn(); - - return InstPair(diffBlock, diffBlock); - } - - // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); - - IRFunc* primalFunc = origFunc; - - differentiableTypeConformanceContext.setFunc(origFunc); - - primalFunc = origFunc; - - auto diffFunc = builder.createFunc(); - - SLANG_ASSERT(as(origFunc->getFullType())); - IRType* diffFuncType = this->differentiateFunctionType( - &builder, - as(origFunc->getFullType())); - diffFunc->setFullType(diffFuncType); - - if (auto nameHint = origFunc->findDecoration()) - { - auto originalName = nameHint->getName(); - StringBuilder newNameSb; - newNameSb << "s_bwd_" << originalName; - builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); - } - builder.addBackwardDerivativeDecoration(origFunc, diffFunc); - - // Mark the generated derivative function itself as differentiable. - builder.addBackwardDifferentiableDecoration(diffFunc); - - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration()) - { - cloneDecoration(dictDecor, diffFunc); - } - - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; - } - - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertInto(diffFunc); - - differentiableTypeConformanceContext.setFunc(primalFunc); - // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribeBlock(&builder, block); - - return InstPair(primalFunc, diffFunc); - } - - IRInst* copyParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - IRInst* diffParam = builder->emitParam(inoutDiffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffParam); - auto paramValue = builder->emitLoad(diffParam); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - orginalToTranscribed.Add(origParam, primal); - primalToDiffPair.Add(primal, diffParam); - - return diffParam; - } - - - return cloneInst(&cloneEnv, builder, origParam); - } - - InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); - - IRInst* primalLeft; - if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) - { - primalLeft = origLeft; - } - IRInst* primalRight; - if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) - { - primalRight = origRight; - } - - auto resultType = origArith->getDataType(); - IRInst* newInst; - switch (origArith->getOp()) - { - case kIROp_Add: - newInst = builder->emitAdd(resultType, primalLeft, primalRight); - break; - case kIROp_Mul: - newInst = builder->emitMul(resultType, primalLeft, primalRight); - break; - case kIROp_Sub: - newInst = builder->emitSub(resultType, primalLeft, primalRight); - break; - case kIROp_Div: - newInst = builder->emitDiv(resultType, primalLeft, primalRight); - break; - default: - newInst = nullptr; - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - orginalToTranscribed.Add(origArith, newInst); - return InstPair(newInst, nullptr); - } - - IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto lhs = origArith->getOperand(0); - auto rhs = origArith->getOperand(1); - - if (as(lhs->getDataType())) - { - lhs = builder->emitLoad(lhs); - lhs = builder->emitDifferentialPairGetPrimal(lhs); - } - if (as(rhs->getDataType())) - { - rhs = builder->emitLoad(rhs); - rhs = builder->emitDifferentialPairGetPrimal(rhs); - } - - IRInst* leftGrad; - IRInst* rightGrad; - - - switch (origArith->getOp()) - { - case kIROp_Add: - leftGrad = grad; - rightGrad = grad; - break; - case kIROp_Mul: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); - break; - case kIROp_Sub: - leftGrad = grad; - rightGrad = builder->emitNeg(grad->getDataType(), grad); - break; - case kIROp_Div: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad - break; - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - - lhs = origArith->getOperand(0); - rhs = origArith->getOperand(1); - if (auto leftGrads = upperGradients.TryGetValue(lhs)) - { - leftGrads->add(leftGrad); - } - else - { - upperGradients.Add(lhs, leftGrad); - } - if (auto rightGrads = upperGradients.TryGetValue(rhs)) - { - rightGrads->add(rightGrad); - } - else - { - upperGradients.Add(rhs, rightGrad); - } - - return nullptr; - } - - InstPair copyInst(IRBuilder* builder, IRInst* origInst) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as(origInst)); - - case kIROp_Return: - return InstPair(nullptr, nullptr); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return copyBinaryArith(builder, origInst); - - default: - // Not yet implemented - SLANG_ASSERT(0); - } - - return InstPair(nullptr, nullptr); - } - - IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) - { - IRInOutType* inoutParam = as(param->getDataType()); - auto pairType = as(inoutParam->getValueType()); - auto paramValue = builder->emitLoad(param); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - auto diff = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - paramValue - ); - auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); - auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); - auto store = builder->emitStore(param, updatedParam); - - return store; - } - - IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParamBackward(builder, as(origInst), grad); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transcribeBinaryArithBackward(builder, origInst, grad); - - case kIROp_DifferentialPairGetPrimal: - { - if (auto param = primalToDiffPair.TryGetValue(origInst)) - { - if (auto leftGrads = upperGradients.TryGetValue(*param)) - { - leftGrads->add(grad); - } - else - { - upperGradients.Add(*param, grad); - } - } - else - SLANG_ASSERT(0); - return nullptr; - } - - default: - // Not yet implemented - SLANG_ASSERT(0); - } - - return nullptr; - } -}; - - -struct JVPDerivativeContext : public InstPassBase -{ - - DiagnosticSink* getSink() - { - return sink; - } - - bool processModule() - { - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->init(module); - sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - - // TODO(sai): Move this call. - transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); - - IRBuilder builderStorage(sharedBuilderStorage); - IRBuilder* builder = &builderStorage; - - // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by - // generating derivative code for the referenced function. - // - bool modified = processReferencedFunctions(builder); - - // Replaces IRDifferentialPairType with an auto-generated struct, - // IRDifferentialPairGetDifferential with 'differential' field access, - // IRDifferentialPairGetPrimal with 'primal' field access, and - // IRMakeDifferentialPair with an IRMakeStruct. - // - modified |= simplifyDifferentialBottomType(builder); - - // De-duplicate any remaining types. - sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - - modified |= processPairTypes(builder, module->getModuleInst()); - - modified |= eliminateDifferentialBottomType(builder); - - return modified; - } - - IRInst* lookupJVPReference(IRInst* primalFunction) - { - if (auto jvpDefinition = primalFunction->findDecoration()) - return jvpDefinition->getForwardDerivativeFunc(); - return nullptr; - } - - // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), - // then check that the referenced function is marked correctly for differentiation. - // - bool processReferencedFunctions(IRBuilder* builder) - { - List autoDiffWorkList; - - for (;;) - { - // Collect all `ForwardDifferentiate` insts from the module. - autoDiffWorkList.clear(); - processAllInsts([&](IRInst* inst) - { - switch (inst->getOp()) - { - case kIROp_ForwardDifferentiate: - case kIROp_BackwardDifferentiate: - // Only process now if the operand is a materialized function. - switch (inst->getOperand(0)->getOp()) - { - case kIROp_Func: - case kIROp_Specialize: - autoDiffWorkList.add(inst); - break; - default: - break; - } - break; - default: - break; - } - }); - - if (autoDiffWorkList.getCount() == 0) - break; - - // Process collected `ForwardDifferentiate` insts and replace them with placeholders for - // differentiated functions. - - transcriberStorage.followUpFunctionsToTranscribe.clear(); - backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); - - for (auto differentiateInst : autoDiffWorkList) - { - IRInst* baseInst = differentiateInst->getOperand(0); - if (as(differentiateInst)) - { - if (auto existingDiffFunc = lookupJVPReference(baseInst)) - { - differentiateInst->replaceUsesWith(existingDiffFunc); - differentiateInst->removeAndDeallocate(); - } - else if (isMarkedForForwardDifferentiation(baseInst)) - { - if (as(baseInst) || as(baseInst)) - { - IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst); - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Requested differentiation on a function that isn't marked as differentiable."); - } - - } - else if (as(differentiateInst)) - { - if (isMarkedForBackwardDifferentiation(baseInst)) - { - if (as(baseInst)) - { - IRInst* diffFunc = - backwardTranscriberStorage - .transcribeFuncHeader(builder, (IRFunc*)baseInst) - .differential; - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } - } - } - } - // Actually synthesize the derivatives. - List followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); - for (auto task : followUpWorkList) - { - auto diffFunc = as(task.differential); - SLANG_ASSERT(diffFunc); - auto primalFunc = as(task.primal); - SLANG_ASSERT(primalFunc); - - transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); - } - followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); - for (auto task : followUpWorkList) - { - auto diffFunc = as(task.differential); - SLANG_ASSERT(diffFunc); - auto primalFunc = as(task.primal); - SLANG_ASSERT(primalFunc); - - backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); - } - - // Transcribing the function body really shouldn't produce more follow up function body work. - // However it may produce new `ForwardDifferentiate` instructions, which we collect and process - // in the next iteration. - SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); - - } - return true; - } - - IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) - { - builder->setInsertBefore(pairType); - auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType( - builder, - pairType); - if (isTrivial) - *isTrivial = loweredPairTypeInfo.isTrivial; - return loweredPairTypeInfo.loweredType; - } - - IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) - { - - if (auto makePairInst = as(inst)) - { - bool isTrivial = false; - auto pairType = as(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) - { - builder->setInsertBefore(makePairInst); - IRInst* result = nullptr; - if (isTrivial) - { - result = makePairInst->getPrimalValue(); - } - else - { - IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; - result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); - } - makePairInst->replaceUsesWith(result); - makePairInst->removeAndDeallocate(); - return result; - } - } - - return nullptr; - } - - IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) - { - if (auto getDiffInst = as(inst)) - { - auto pairType = getDiffInst->getBase()->getDataType(); - if (auto pairPtrType = as(pairType)) - { - pairType = pairPtrType->getValueType(); - } - - if (lowerPairType(builder, pairType, nullptr)) - { - builder->setInsertBefore(getDiffInst); - IRInst* diffFieldExtract = nullptr; - diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); - getDiffInst->replaceUsesWith(diffFieldExtract); - getDiffInst->removeAndDeallocate(); - return diffFieldExtract; - } - } - else if (auto getPrimalInst = as(inst)) - { - auto pairType = getPrimalInst->getBase()->getDataType(); - if (auto pairPtrType = as(pairType)) - { - pairType = pairPtrType->getValueType(); - } - - if (lowerPairType(builder, pairType, nullptr)) - { - builder->setInsertBefore(getPrimalInst); - - IRInst* primalFieldExtract = nullptr; - primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); - getPrimalInst->replaceUsesWith(primalFieldExtract); - getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; - } - } - - return nullptr; - } - - bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) - { - bool modified = false; - // Hoist all pair types to global scope when possible. - auto moduleInst = module->getModuleInst(); - processInstsOfType(kIROp_DifferentialPairType, [&](IRInst* originalPairType) - { - if (originalPairType->parent != moduleInst) - { - originalPairType->removeFromParent(); - ShortList operands; - for (UInt i = 0; i < originalPairType->getOperandCount(); i++) - { - operands.add(originalPairType->getOperand(i)); - } - auto newPairType = builder->findOrEmitHoistableInst( - originalPairType->getFullType(), - originalPairType->getOp(), - originalPairType->getOperandCount(), - operands.getArrayView().getBuffer()); - originalPairType->replaceUsesWith(newPairType); - originalPairType->removeAndDeallocate(); - } - }); - - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - - processAllInsts([&](IRInst* inst) - { - // Make sure the builder is at the right level. - builder->setInsertInto(instWithChildren); - - switch (inst->getOp()) - { - case kIROp_DifferentialPairGetDifferential: - case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, inst); - modified = true; - break; - - case kIROp_MakeDifferentialPair: - lowerMakePair(builder, inst); - modified = true; - break; - - default: - break; - } - }); - - processInstsOfType(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) - { - if (auto loweredType = lowerPairType(builder, inst)) - { - inst->replaceUsesWith(loweredType); - inst->removeAndDeallocate(); - } - }); - return modified; - } - - bool simplifyDifferentialBottomType(IRBuilder* builder) - { - bool modified = false; - auto diffBottom = builder->getDifferentialBottom(); - - bool changed = true; - List uses; - while (changed) - { - changed = false; - // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`. - processAllInsts([&](IRInst* inst) - { - if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType) - { - if (inst != diffBottom) - { - inst->replaceUsesWith(diffBottom); - inst->removeAndDeallocate(); - modified = true; - } - } - }); - // Go through all uses of diffBottom and run simplification. - processAllInsts([&](IRInst* inst) - { - if (!inst->hasUses()) - return; - - builder->setInsertBefore(inst); - IRInst* valueToReplace = nullptr; - switch (inst->getOp()) - { - case kIROp_Store: - if (as(inst)->getVal() == diffBottom) - { - inst->removeAndDeallocate(); - changed = true; - } - return; - case kIROp_MakeDifferentialPair: - // Our simplification could lead to a situation where - // bottom is used to make a pair that has a non-bottom differential type, - // in this case we should use zero instead. - if (inst->getOperand(1) == diffBottom) - { - // Only apply if we are the second operand. - auto pairType = as(inst->getDataType()); - if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType) - { - auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType()); - inst->setOperand(1, zero); - changed = true; - } - } - return; - case kIROp_DifferentialPairGetDifferential: - if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) - { - valueToReplace = inst->getOperand(0)->getOperand(1); - } - break; - case kIROp_DifferentialPairGetPrimal: - if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) - { - valueToReplace = inst->getOperand(0)->getOperand(0); - } - break; - case kIROp_Add: - if (inst->getOperand(0) == diffBottom) - { - valueToReplace = inst->getOperand(1); - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = inst->getOperand(0); - } - break; - case kIROp_Sub: - if (inst->getOperand(0) == diffBottom) - { - // If left is bottom, and right is not bottom, then we should return -right. - // However we can't possibly run into that case since both side of - operator - // must be at the same order of differentiation. - valueToReplace = diffBottom; - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = inst->getOperand(0); - } - break; - case kIROp_Mul: - case kIROp_Div: - if (inst->getOperand(0) == diffBottom) - { - valueToReplace = diffBottom; - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = diffBottom; - } - break; - default: - break; - } - if (valueToReplace) - { - inst->replaceUsesWith(valueToReplace); - changed = true; - } - }); - modified |= changed; - } - - return modified; - } - - bool eliminateDifferentialBottomType(IRBuilder* builder) - { - simplifyDifferentialBottomType(builder); - - bool modified = false; - auto diffBottom = builder->getDifferentialBottom(); - auto diffBottomType = diffBottom->getDataType(); - diffBottom->replaceUsesWith(builder->getVoidValue()); - diffBottom->removeAndDeallocate(); - diffBottomType->replaceUsesWith(builder->getVoidType()); - - return modified; - } - - // Checks decorators to see if the function should - // be differentiated (kIROp_ForwardDifferentiableDecoration) - // - bool isMarkedForForwardDifferentiation(IRInst* callable) - { - return callable->findDecoration() != nullptr; - } - - IRStringLit* getForwardDerivativeFuncName(IRInst* func) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(func); - - IRStringLit* name = nullptr; - if (auto linkageDecoration = func->findDecoration()) - { - name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_fwd_diff").getUnownedSlice()); - } - else if (auto namehintDecoration = func->findDecoration()) - { - name = builder.getStringValue((String(namehintDecoration->getName()) + "_fwd_diff").getUnownedSlice()); - } - - return name; - } - - // Checks decorators to see if the function should - // be differentiated (kIROp_ForwardDifferentiableDecoration) - // - bool isMarkedForBackwardDifferentiation(IRInst* callable) - { - return callable->findDecoration() != nullptr; - } - - IRStringLit* getBackwardDerivativeFuncName(IRInst* func) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(func); - - IRStringLit* name = nullptr; - if (auto linkageDecoration = func->findDecoration()) - { - name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); - } - else if (auto namehintDecoration = func->findDecoration()) - { - name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); - } - - return name; - } - - JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - InstPassBase(module), - sink(sink), - autoDiffSharedContextStorage(module->getModuleInst()), - transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage), - backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink) - { - autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage; - pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; - transcriberStorage.sink = sink; - transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); - transcriberStorage.pairBuilder = &(pairBuilderStorage); - backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; - } - -protected: - // A transcriber object that handles the main job of - // processing instructions while maintaining state. - // - JVPTranscriber transcriberStorage; - - BackwardDiffTranscriber backwardTranscriberStorage; - - // Diagnostic object from the compile request for - // error messages. - DiagnosticSink* sink; - - // Context to find and manage the witness tables for types - // implementing `IDifferentiable` - AutoDiffSharedContext autoDiffSharedContextStorage; - - // Builder for dealing with differential pair types. - DifferentialPairTypeBuilder pairBuilderStorage; - -}; - -// Set up context and call main process method. -// -bool processDifferentiableFuncs( - IRModule* module, - DiagnosticSink* sink, - IRJVPDerivativePassOptions const&) -{ - // Simplify module to remove dead code. - IRDeadCodeEliminationOptions options; - options.keepExportsAlive = true; - options.keepLayoutsAlive = true; - eliminateDeadCode(module, options); - - JVPDerivativeContext context(module, sink); - bool changed = context.processModule(); - return changed; -} - -void stripAutoDiffDecorationsFromChildren(IRInst* parent) -{ - for (auto inst : parent->getChildren()) - { - for (auto decor = inst->getFirstDecoration(); decor; ) - { - auto next = decor->getNextDecoration(); - switch (decor->getOp()) - { - case kIROp_ForwardDerivativeDecoration: - case kIROp_DerivativeMemberDecoration: - case kIROp_DifferentiableTypeDictionaryDecoration: - decor->removeAndDeallocate(); - break; - default: - break; - } - decor = next; - } - - if (inst->getFirstChild() != nullptr) - { - stripAutoDiffDecorationsFromChildren(inst); - } - } -} - -void stripAutoDiffDecorations(IRModule* module) -{ - stripAutoDiffDecorationsFromChildren(module->getModuleInst()); -} - -AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) - : moduleInst(inModuleInst) -{ - differentiableInterfaceType = as(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - mulMethodStructKey = findMulMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } -} - -IRInst* AutoDiffSharedContext::findDifferentiableInterface() -{ - if (auto module = as(moduleInst)) - { - for (auto globalInst : module->getGlobalInsts()) - { - // TODO: This seems like a particularly dangerous way to look for an interface. - // See if we can lower IDifferentiable to a separate IR inst. - // - if (globalInst->getOp() == kIROp_InterfaceType && - as(globalInst)->findDecoration()->getName() == "IDifferentiable") - { - return globalInst; - } - } - } - return nullptr; -} - -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) -{ - if (as(moduleInst) && differentiableInterfaceType) - { - // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as(differentiableInterfaceType->getOperand(index))) - return as(entry->getRequirementKey()); - else - { - SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); - } - } - - return nullptr; -} - -void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) -{ - parentFunc = func; - - auto decor = func->findDecoration(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) - { - if (auto item = as(child)) - { - auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); - if (existingItem) - { - if (auto witness = as(item->getWitness())) - { - if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) - continue; - } - *existingItem = item->getWitness(); - } - else - { - differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); - } - } - } -} - - -// Lookup a witness table for the concreteType. One should exist if concreteType -// inherits (successfully) from IDifferentiable. -// - -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) -{ - IRInst* foundResult = nullptr; - differentiableWitnessDictionary.TryGetValue(type, foundResult); - return foundResult; -} - -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) -{ - if (auto conformance = lookUpConformanceForType(origType)) - { - return _lookupWitness(builder, conformance, key); - } - return nullptr; -} - -void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() -{ - for (auto globalInst : sharedContext->moduleInst->getChildren()) - { - if (auto pairType = as(globalInst)) - { - differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); - } - } -} -} diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h deleted file mode 100644 index 5e2a7f44f..000000000 --- a/source/slang/slang-ir-diff-jvp.h +++ /dev/null @@ -1,174 +0,0 @@ -// slang-ir-diff-jvp.h -#pragma once - -#include "slang-ir.h" -#include "slang-ir-insts.h" -#include "slang-compiler.h" - -namespace Slang -{ - template - struct DiffInstPair - { - P primal; - D differential; - DiffInstPair() = default; - DiffInstPair(P primal, D differential) : primal(primal), differential(differential) - {} - HashCode getHashCode() const - { - Hasher hasher; - hasher << primal << differential; - return hasher.getResult(); - } - bool operator ==(const DiffInstPair& other) const - { - return primal == other.primal && differential == other.differential; - } - }; - - typedef DiffInstPair InstPair; - - struct AutoDiffSharedContext - { - IRModuleInst* moduleInst = nullptr; - - SharedIRBuilder* sharedBuilder = nullptr; - - // A reference to the builtin IDifferentiable interface type. - // We use this to look up all the other types (and type exprs) - // that conform to a base type. - // - IRInterfaceType* differentiableInterfaceType = nullptr; - - // The struct key for the 'Differential' associated type - // defined inside IDifferential. We use this to lookup the differential - // type in the conformance table associated with the concrete type. - // - IRStructKey* differentialAssocTypeStructKey = nullptr; - - // The struct key for the witness that `Differential` associated type conforms to - // `IDifferential`. - IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; - - - // The struct key for the 'zero()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of zero() for a given type. - // - IRStructKey* zeroMethodStructKey = nullptr; - - // The struct key for the 'add()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of add() for a given type. - // - IRStructKey* addMethodStructKey = nullptr; - - IRStructKey* mulMethodStructKey = nullptr; - - - // Modules that don't use differentiable types - // won't have the IDifferentiable interface type available. - // Set to false to indicate that we are uninitialized. - // - bool isInterfaceAvailable = false; - - - AutoDiffSharedContext(IRModuleInst* inModuleInst); - - private: - - IRInst* findDifferentiableInterface(); - - IRStructKey* findDifferentialTypeStructKey() - { - return getIDifferentiableStructKeyAtIndex(0); - } - - IRStructKey* findDifferentialTypeWitnessStructKey() - { - return getIDifferentiableStructKeyAtIndex(1); - } - - IRStructKey* findZeroMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(2); - } - - IRStructKey* findAddMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(3); - } - - IRStructKey* findMulMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(4); - } - - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); - }; - - struct DifferentiableTypeConformanceContext - { - AutoDiffSharedContext* sharedContext; - - IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary differentiableWitnessDictionary; - - DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) - : sharedContext(shared) - {} - - void setFunc(IRGlobalValueWithCode* func); - - void buildGlobalWitnessDictionary(); - - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRInst* type); - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; - } - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); - } - - }; - - struct IRJVPDerivativePassOptions - { - // Nothing for now.. - }; - - bool processDifferentiableFuncs( - IRModule* module, - DiagnosticSink* sink, - IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); - - void stripAutoDiffDecorations(IRModule* module); -} diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 8559103ae..3e85cc40e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -6,7 +6,7 @@ #include "slang-ir-insts.h" #include "slang-mangle.h" #include "slang-ir-string-hash.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-module-library.h" #include "../compiler-core/slang-artifact.h" diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 61b6fcb76..57267a9ea 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8,7 +8,7 @@ #include "slang-ir-constexpr.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-inline.h" #include "slang-ir-insts.h" #include "slang-ir-check-differentiability.h" -- cgit v1.2.3