diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-11-29 20:01:41 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-29 17:01:41 -0800 |
| commit | f5581786a1891cedb165adb1afe71fe34f26e030 (patch) | |
| tree | 86da2f1acbaec920ac0c38349897b293b405c021 | |
| parent | af7f40063dfed1c651d33b93956c7623a7d2c050 (diff) | |
Refactored reverse-mode implementation to use 4 separate passes. (#2531)
* Added partial implementation for reverse-mode
* Fixing several compile and runtime errors.
* Fixed several issues with reverse-mode passes.
* Fixed more issues. Basic reverse-mode tests passing
Co-authored-by: Edward Liu <shiqiu1105@gmail.com>
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 3 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 2550 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 165 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-propagate.h | 102 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 196 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 420 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 110 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 14 |
11 files changed, 2294 insertions, 1336 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index fe1922d29..ff8a09599 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -335,7 +335,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transpose.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-unzip.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-byte-address-legalize.h" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 6fa42287c..3c29fd21b 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -138,9 +138,18 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transpose.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-unzip.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h">
<Filter>Header Files</Filter>
</ClInclude>
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c93522565..0ad9ce87c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -12,167 +12,129 @@ namespace Slang { -struct JVPTranscriber +DiagnosticSink* ForwardDerivativeTranscriber::getSink() { + SLANG_ASSERT(sink); + return sink; +} - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary<IRInst*, IRInst*> instMapD; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet<IRInst*> instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List<InstPair> followUpFunctionsToTranscribe; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. 
- IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary<InstPair, IRInst*> differentialPairTypes; - - JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) - : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) - { - - } - - DiagnosticSink* getSink() - { - SLANG_ASSERT(sink); - return sink; - } - - void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) +void ForwardDerivativeTranscriber::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) +{ + if (hasDifferentialInst(origInst)) { - if (hasDifferentialInst(origInst)) + if (lookupDiffInst(origInst) != diffInst) { - if (lookupDiffInst(origInst) != diffInst) - { - SLANG_UNEXPECTED("Inconsistent differential mappings"); - } - } - else - { - instMapD.Add(origInst, diffInst); + SLANG_UNEXPECTED("Inconsistent differential mappings"); } } - - void mapPrimalInst(IRInst* origInst, IRInst* primalInst) + else { - if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) - { - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "inconsistent primal instruction for original"); - } - else - { - cloneEnv.mapOldValToNew[origInst] = primalInst; - } + instMapD.Add(origInst, diffInst); } +} - IRInst* lookupDiffInst(IRInst* origInst) +void ForwardDerivativeTranscriber::mapPrimalInst(IRInst* origInst, IRInst* primalInst) +{ + if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) { - return instMapD[origInst]; + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "inconsistent primal instruction for original"); } - - IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst) + else { - return (hasDifferentialInst(origInst)) ? 
instMapD[origInst] : defaultInst; + cloneEnv.mapOldValToNew[origInst] = primalInst; } +} - bool hasDifferentialInst(IRInst* origInst) - { - return instMapD.ContainsKey(origInst); - } +IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst) +{ + return instMapD[origInst]; +} - bool shouldUseOriginalAsPrimal(IRInst* origInst) - { - if (as<IRGlobalValueWithCode>(origInst)) - return true; - if (origInst->parent && origInst->parent->getOp() == kIROp_Module) - return true; - return false; - } +IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst, IRInst* defaultInst) +{ + return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst; +} - IRInst* lookupPrimalInst(IRInst* origInst) - { - if (!origInst) - return nullptr; - if (shouldUseOriginalAsPrimal(origInst)) - return origInst; - return cloneEnv.mapOldValToNew[origInst]; - } +bool ForwardDerivativeTranscriber::hasDifferentialInst(IRInst* origInst) +{ + return instMapD.ContainsKey(origInst); +} - IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) - { - if (!origInst) - return nullptr; - return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; - } +bool ForwardDerivativeTranscriber::shouldUseOriginalAsPrimal(IRInst* origInst) +{ + if (as<IRGlobalValueWithCode>(origInst)) + return true; + if (origInst->parent && origInst->parent->getOp() == kIROp_Module) + return true; + return false; +} + +IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst) +{ + if (!origInst) + return nullptr; + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; + return cloneEnv.mapOldValToNew[origInst]; +} + +IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) +{ + if (!origInst) + return nullptr; + return (hasPrimalInst(origInst)) ? 
lookupPrimalInst(origInst) : defaultInst; +} - bool hasPrimalInst(IRInst* origInst) +bool ForwardDerivativeTranscriber::hasPrimalInst(IRInst* origInst) +{ + if (!origInst) + return true; + if (shouldUseOriginalAsPrimal(origInst)) + return true; + return cloneEnv.mapOldValToNew.ContainsKey(origInst); +} + +IRInst* ForwardDerivativeTranscriber::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) +{ + if (!hasDifferentialInst(origInst)) { - if (!origInst) - return true; - if (shouldUseOriginalAsPrimal(origInst)) - return true; - return cloneEnv.mapOldValToNew.ContainsKey(origInst); + transcribe(builder, origInst); + SLANG_ASSERT(hasDifferentialInst(origInst)); } - IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) - { - if (!hasDifferentialInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasDifferentialInst(origInst)); - } + return lookupDiffInst(origInst); +} - return lookupDiffInst(origInst); - } +IRInst* ForwardDerivativeTranscriber::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) +{ + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; - IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) + if (!hasPrimalInst(origInst)) { - if (shouldUseOriginalAsPrimal(origInst)) - return origInst; + transcribe(builder, origInst); + SLANG_ASSERT(hasPrimalInst(origInst)); + } - if (!hasPrimalInst(origInst)) - { - transcribe(builder, origInst); - SLANG_ASSERT(hasPrimalInst(origInst)); - } + return lookupPrimalInst(origInst); +} - return lookupPrimalInst(origInst); - } +IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) +{ + List<IRType*> newParameterTypes; + IRType* diffReturnType; - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + for (UIndex i = 0; i < funcType->getParamCount(); i++) { - List<IRType*> newParameterTypes; - IRType* diffReturnType; - - for (UIndex i = 0; i < 
funcType->getParamCount(); i++) - { - auto origType = funcType->getParamType(i); - origType = (IRType*) lookupPrimalInst(origType, origType); - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - newParameterTypes.add(diffPairType); - else - newParameterTypes.add(origType); - } + auto origType = funcType->getParamType(i); + origType = (IRType*) lookupPrimalInst(origType, origType); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + newParameterTypes.add(diffPairType); + else + newParameterTypes.add(origType); + } // Transcribe return type to a pair. // This will be void if the primal return type is non-differentiable. @@ -183,562 +145,562 @@ struct JVPTranscriber else diffReturnType = origResultType; - return builder->getFuncType(newParameterTypes, diffReturnType); - } + return builder->getFuncType(newParameterTypes, diffReturnType); +} - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); - SLANG_ASSERT(diffPairType); +// Get or construct `:IDifferentiable` conformance for a DifferentiablePair. 
+IRWitnessTable* ForwardDerivativeTranscriber::getDifferentialPairWitness(IRInst* inDiffPairType) +{ + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = differentiateType(&builder, diffPairType); + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = differentiateType(&builder, diffPairType); - // And place it in the synthesized witness table. - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + // And place it in the synthesized witness table. + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. 
- // Record this in the context for future lookups - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + // Record this in the context for future lookups + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; - } + return table; +} - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } +IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) +{ + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); +} - IRType* getOrCreateDiffPairType(IRInst* primalType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as<IRWitnessTable>( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); +IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType) +{ + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness) + if (!witness) + { + if (auto primalPairType = as<IRDifferentialPairType>(primalType)) { - if (auto primalPairType = as<IRDifferentialPairType>(primalType)) - { - witness = getDifferentialPairWitness(primalPairType); - } + witness = getDifferentialPairWitness(primalPairType); } - - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); } - IRType* differentiateType(IRBuilder* builder, IRType* origType) + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); +} + +IRType* ForwardDerivativeTranscriber::differentiateType(IRBuilder* builder, IRType* 
origType) +{ + IRInst* diffType = nullptr; + if (!instMapD.TryGetValue(origType, diffType)) { - IRInst* diffType = nullptr; - if (!instMapD.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - instMapD[origType] = diffType; - } - return (IRType*)diffType; + diffType = _differentiateTypeImpl(builder, origType); + instMapD[origType] = diffType; } + return (IRType*)diffType; +} - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) - { - if (auto ptrType = as<IRPtrTypeBase>(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); +IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, IRType* origType) +{ + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = lookupPrimalInst(origType, origType); - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = lookupPrimalInst(origType, origType); + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. 
+ // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: { - case kIROp_Param: - if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as<IRWitnessTableType>(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as<IRArrayType>(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } - case kIROp_DifferentialPairType: - { - auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); - case kIROp_OutType: - if (auto diffValueType = 
differentiateType(builder, as<IROutType>(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; - case kIROp_TupleType: - { - auto tupleType = as<IRTupleType>(primalType); - List<IRType*> diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + return builder->getTupleType(diffTypeList); } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); } +} - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) - { - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. 
- // - if (auto origPtrType = as<IRPtrTypeBase>(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; +IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType) +{ + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. + // + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; +} - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) +InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRParam* origParam) +{ + auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) { - auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType()); - // Do not differentiate generic type (and witness table) parameters - if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } - // Is this param a phi node or a function parameter? 
- auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); - bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); - if (isFuncParam) + // Is this param a phi node or a function parameter? + auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) + { + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); + IRInst* diffPairParam = builder->emitParam(diffPairType); - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) - { - auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as<IRDifferentialPairType>(diffPairType)) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + 
builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); } - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - else - { - auto primal = cloneInst(&cloneEnv, builder, origParam); - IRInst* diff = nullptr; - if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) + else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) { - diff = builder->emitParam(diffType); + auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); + + return InstPair( + builder->emitDifferentialPairAddressPrimal(diffPairParam), + builder->emitDifferentialPairAddressDifferential( + builder->getPtrType( + kIROp_PtrType, + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), + diffPairParam)); } - return InstPair(primal, diff); } - - } - // Returns "d<var-name>" to use as a name hint for variables and parameters. - // If no primal name is available, returns a blank string. - // - String getJVPVarName(IRInst* origVar) + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + else { - if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) { - return ("d" + String(namehintDecoration->getName())); + diff = builder->emitParam(diffType); } - - return String(""); + return InstPair(primal, diff); } + +} - // Returns "dp<var-name>" to use as a name hint for parameters. - // If no primal name is available, returns a blank string. - // - String makeDiffPairName(IRInst* origVar) +// Returns "d<var-name>" to use as a name hint for variables and parameters. +// If no primal name is available, returns a blank string. 
+// +String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar) +{ + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { - if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); + return ("d" + String(namehintDecoration->getName())); } - InstPair transcribeVar(IRBuilder* builder, IRVar* origVar) + return String(""); +} + +// Returns "dp<var-name>" to use as a name hint for parameters. +// If no primal name is available, returns a blank string. +// +String ForwardDerivativeTranscriber::makeDiffPairName(IRInst* origVar) +{ + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { - if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) - { - IRVar* diffVar = builder->emitVar(diffType); - SLANG_ASSERT(diffVar); + return ("dp" + String(namehintDecoration->getName())); + } - auto diffNameHint = getJVPVarName(origVar); - if (diffNameHint.getLength() > 0) - builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); + return String(""); +} - return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); - } +InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) +{ + if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) + { + IRVar* diffVar = builder->emitVar(diffType); + SLANG_ASSERT(diffVar); - return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); + auto diffNameHint = getJVPVarName(origVar); + if (diffNameHint.getLength() > 0) + builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); + + return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); } - InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); + return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); +} 
- IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); +InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) +{ + SLANG_ASSERT(origArith->getOperandCount() == 2); - auto primalLeft = findOrTranscribePrimalInst(builder, origLeft); - auto primalRight = findOrTranscribePrimalInst(builder, origRight); + IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); - auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); - auto diffRight = findOrTranscribeDiffInst(builder, origRight); + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); + auto primalLeft = findOrTranscribePrimalInst(builder, origLeft); + auto primalRight = findOrTranscribePrimalInst(builder, origRight); - if (diffLeft || diffRight) - { - diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); - diffRight = diffRight ? 
diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); + auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); + auto diffRight = findOrTranscribeDiffInst(builder, origRight); - auto resultType = primalArith->getDataType(); - switch(origArith->getOp()) - { - case kIROp_Add: - return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); - case kIROp_Mul: - return InstPair(primalArith, builder->emitAdd(resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight))); - case kIROp_Sub: - return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); - case kIROp_Div: - return InstPair(primalArith, builder->emitDiv(resultType, - builder->emitSub( - resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight)), - builder->emitMul( - primalRight->getDataType(), primalRight, primalRight - ))); - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - } - - return InstPair(primalArith, nullptr); - } - - InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) + if (diffLeft || diffRight) { - SLANG_ASSERT(origLogic->getOperandCount() == 2); + diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType()); + diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); - // TODO: Check other boolean cases. - if (as<IRBoolType>(origLogic->getDataType())) + auto resultType = primalArith->getDataType(); + switch(origArith->getOp()) { - // Boolean operations are not differentiable. For the linearization - // pass, we do not need to do anything but copy them over to the ne - // function. 
- auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); - return InstPair(primalLogic, nullptr); + case kIROp_Add: + return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); + case kIROp_Mul: + return InstPair(primalArith, builder->emitAdd(resultType, + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight))); + case kIROp_Sub: + return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); + case kIROp_Div: + return InstPair(primalArith, builder->emitDiv(resultType, + builder->emitSub( + resultType, + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight)), + builder->emitMul( + primalRight->getDataType(), primalRight, primalRight + ))); + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); } - - SLANG_UNEXPECTED("Logical operation with non-boolean result"); } - InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) + return InstPair(primalArith, nullptr); +} + + +InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) +{ + SLANG_ASSERT(origLogic->getOperandCount() == 2); + + // TODO: Check other boolean cases. + if (as<IRBoolType>(origLogic->getDataType())) { - auto origPtr = origLoad->getPtr(); - auto primalPtr = lookupPrimalInst(origPtr, nullptr); - auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); + // Boolean operations are not differentiable. For the linearization + // pass, we do not need to do anything but copy them over to the ne + // function. 
+ auto primalLogic = cloneInst(&cloneEnv, builder, origLogic); + return InstPair(primalLogic, nullptr); + } + + SLANG_UNEXPECTED("Logical operation with non-boolean result"); +} - if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) - { - // Special case load from an `out` param, which will not have corresponding `diff` and - // `primal` insts yet. - - auto load = builder->emitLoad(primalPtr); - auto primalElement = builder->emitDifferentialPairGetPrimal(load); - auto diffElement = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); - return InstPair(primalElement, diffElement); - } +InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad) +{ + auto origPtr = origLoad->getPtr(); + auto primalPtr = lookupPrimalInst(origPtr, nullptr); + auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); - auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); - IRInst* diffLoad = nullptr; - if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) - { - // Default case, we're loading from a known differential inst. - diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); - } - return InstPair(primalLoad, diffLoad); + if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) + { + // Special case load from an `out` param, which will not have corresponding `diff` and + // `primal` insts yet. 
+ + auto load = builder->emitLoad(primalPtr); + auto primalElement = builder->emitDifferentialPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + return InstPair(primalElement, diffElement); } - InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) + auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; + if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { - IRInst* origStoreLocation = origStore->getPtr(); - IRInst* origStoreVal = origStore->getVal(); - auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); - auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); - auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); - auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + // Default case, we're loading from a known differential inst. + diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); + } + return InstPair(primalLoad, diffLoad); +} - if (!diffStoreLocation) +InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore) +{ + IRInst* origStoreLocation = origStore->getPtr(); + IRInst* origStoreVal = origStore->getVal(); + auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); + auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); + auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); + auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + + if (!diffStoreLocation) + { + auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType()); + if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) { - auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType()); - if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) - { - auto valToStore = 
builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); - auto store = builder->emitStore(primalStoreLocation, valToStore); - return InstPair(store, nullptr); - } + auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); + auto store = builder->emitStore(primalStoreLocation, valToStore); + return InstPair(store, nullptr); } + } - auto primalStore = cloneInst(&cloneEnv, builder, origStore); - - IRInst* diffStore = nullptr; + auto primalStore = cloneInst(&cloneEnv, builder, origStore); - // If the stored value has a differential version, - // emit a store instruction for the differential parameter. - // Otherwise, emit nothing since there's nothing to load. - // - if (diffStoreLocation && diffStoreVal) - { - // Default case, storing the entire type (and not a member) - diffStore = as<IRStore>( - builder->emitStore(diffStoreLocation, diffStoreVal)); - - return InstPair(primalStore, diffStore); - } + IRInst* diffStore = nullptr; - return InstPair(primalStore, nullptr); + // If the stored value has a differential version, + // emit a store instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. 
+ // + if (diffStoreLocation && diffStoreVal) + { + // Default case, storing the entire type (and not a member) + diffStore = as<IRStore>( + builder->emitStore(diffStoreLocation, diffStoreVal)); + + return InstPair(primalStore, diffStore); } - InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) + return InstPair(primalStore, nullptr); +} + +InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRReturn* origReturn) +{ + IRInst* origReturnVal = origReturn->getVal(); + + auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); + if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) { - IRInst* origReturnVal = origReturn->getVal(); + // If the return value is itself a function, generic or a struct then this + // is likely to be a generic scope. In this case, we lookup the differential + // and return that. + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + + // Neither of these should be nullptr. + SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); + IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); - auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType()); - if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal)) - { - // If the return value is itself a function, generic or a struct then this - // is likely to be a generic scope. In this case, we lookup the differential - // and return that. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - - // Neither of these should be nullptr. 
- SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); - IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); - - return InstPair(diffReturn, diffReturn); - } - else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) - { - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); - if(!diffReturnVal) - diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); - - // If the pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffReturnVal); - - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); - IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); - return InstPair(pairReturn, pairReturn); - } - else - { - // If the return type is not differentiable, emit the primal value only. - IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + return InstPair(diffReturn, diffReturn); + } + else if (auto pairType = tryGetDiffPairType(builder, returnDataType)) + { + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + if(!diffReturnVal) + diffReturnVal = getDifferentialZeroOfType(builder, returnDataType); - IRInst* primalReturn = builder->emitReturn(primalReturnVal); - return InstPair(primalReturn, nullptr); - - } + // If the pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffReturnVal); + + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); + IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); + return InstPair(pairReturn, pairReturn); } + else + { + // If the return type is not differentiable, emit the primal value only. 
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); - // Since int/float literals are sometimes nested inside an IRConstructor - // instruction, we check to make sure that the nested instr is a constant - // and then return nullptr. Literals do not need to be differentiated. - // - InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) - { - IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + IRInst* primalReturn = builder->emitReturn(primalReturnVal); + return InstPair(primalReturn, nullptr); - // Check if the output type can be differentiated. If it cannot be - // differentiated, don't differentiate the inst - // - auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); - if (auto diffConstructType = differentiateType(builder, primalConstructType)) - { - UCount operandCount = origConstruct->getOperandCount(); - - List<IRInst*> diffOperands; - for (UIndex ii = 0; ii < operandCount; ii++) - { - // If the operand has a differential version, replace the original with - // the differential. Otherwise, use a zero. - // - if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) - diffOperands.add(diffInst); - else - { - auto operandDataType = origConstruct->getOperand(ii)->getDataType(); - operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); - diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); - } - } - - return InstPair( - primalConstruct, - builder->emitIntrinsicInst( - diffConstructType, - origConstruct->getOp(), - operandCount, - diffOperands.getBuffer())); - } - else - { - return InstPair(primalConstruct, nullptr); - } } +} - // Differentiating a call instruction here is primarily about generating - // an appropriate call list based on whichever parameters have differentials - // in the current transcription context. 
+// Since int/float literals are sometimes nested inside an IRConstructor +// instruction, we check to make sure that the nested instr is a constant +// and then return nullptr. Literals do not need to be differentiated. +// +InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) +{ + IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + + // Check if the output type can be differentiated. If it cannot be + // differentiated, don't differentiate the inst // - InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - - IRInst* origCallee = origCall->getCallee(); + auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType()); + if (auto diffConstructType = differentiateType(builder, primalConstructType)) + { + UCount operandCount = origConstruct->getOperandCount(); - if (!origCallee) + List<IRInst*> diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) { - // Note that this can only happen if the callee is a result - // of a higher-order operation. For now, we assume that we cannot - // differentiate such calls safely. - // TODO(sai): Should probably get checked in the front-end. - // - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::internalCompilerError, - "attempting to differentiate unresolved callee"); - - return InstPair(nullptr, nullptr); + // If the operand has a differential version, replace the original with + // the differential. Otherwise, use a zero. 
+ // + if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + { + auto operandDataType = origConstruct->getOperand(ii)->getDataType(); + operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } } + + return InstPair( + primalConstruct, + builder->emitIntrinsicInst( + diffConstructType, + origConstruct->getOp(), + operandCount, + diffOperands.getBuffer())); + } + else + { + return InstPair(primalConstruct, nullptr); + } +} - // Since concrete functions are globals, the primal callee is the same - // as the original callee. +// Differentiating a call instruction here is primarily about generating +// an appropriate call list based on whichever parameters have differentials +// in the current transcription context. +// +InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall) +{ + + IRInst* origCallee = origCall->getCallee(); + + if (!origCallee) + { + // Note that this can only happen if the callee is a result + // of a higher-order operation. For now, we assume that we cannot + // differentiate such calls safely. + // TODO(sai): Should probably get checked in the front-end. // - auto primalCallee = origCallee; + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::internalCompilerError, + "attempting to differentiate unresolved callee"); + + return InstPair(nullptr, nullptr); + } - IRInst* diffCallee = nullptr; + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; - if (instMapD.TryGetValue(origCallee, diffCallee)) - { - } - else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) - { - // If the user has already provided an differentiated implementation, use that. 
- diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); - } - else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>()) - { - // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass - // to generate the implementation. - diffCallee = builder->emitForwardDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), - primalCallee); - } - else - { - // The callee is non differentiable, just return primal value with null diff value. - IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); - return InstPair(primalCall, nullptr); - } + IRInst* diffCallee = nullptr; - List<IRInst*> args; - // Go over the parameter list and create pairs for each input (if required) - for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) - { - auto origArg = origCall->getArg(ii); - auto primalArg = findOrTranscribePrimalInst(builder, origArg); - SLANG_ASSERT(primalArg); + if (instMapD.TryGetValue(origCallee, diffCallee)) + { + } + else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); + } + else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitForwardDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. 
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + + List<IRInst*> args; + // Go over the parameter list and create pairs for each input (if required) + for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) + { + auto origArg = origCall->getArg(ii); + auto primalArg = findOrTranscribePrimalInst(builder, origArg); + SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); if (auto pairType = tryGetDiffPairType(builder, primalType)) @@ -762,218 +724,218 @@ struct JVPTranscriber IRType* diffReturnType = nullptr; diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - if (!diffReturnType) - { - SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); - diffReturnType = builder->getVoidType(); - } + if (!diffReturnType) + { + SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); + diffReturnType = builder->getVoidType(); + } - auto callInst = builder->emitCallInst( - diffReturnType, - diffCallee, - args); + auto callInst = builder->emitCallInst( + diffReturnType, + diffCallee, + args); - if (diffReturnType->getOp() != kIROp_VoidType) - { - IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); - auto diffType = differentiateType(builder, origCall->getFullType()); - IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); - return InstPair(primalResultValue, diffResultValue); - } - else - { - // Return the inst itself if the return value is void. - // This is fine since these values should never actually be used anywhere. 
- // - return InstPair(callInst, callInst); - } + if (diffReturnType->getOp() != kIROp_VoidType) + { + IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); + auto diffType = differentiateType(builder, origCall->getFullType()); + IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); + return InstPair(primalResultValue, diffResultValue); } - - InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) + else { - IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); + // Return the inst itself if the return value is void. + // This is fine since these values should never actually be used anywhere. + // + return InstPair(callInst, callInst); + } +} - if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) - { - List<IRInst*> swizzleIndices; - for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) - swizzleIndices.add(origSwizzle->getElementIndex(ii)); - - return InstPair( - primalSwizzle, - builder->emitSwizzle( - differentiateType(builder, primalSwizzle->getDataType()), - diffBase, - origSwizzle->getElementCount(), - swizzleIndices.getBuffer())); - } +InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) +{ + IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); - return InstPair(primalSwizzle, nullptr); + if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) + { + List<IRInst*> swizzleIndices; + for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) + swizzleIndices.add(origSwizzle->getElementIndex(ii)); + + return InstPair( + primalSwizzle, + builder->emitSwizzle( + differentiateType(builder, primalSwizzle->getDataType()), + diffBase, + origSwizzle->getElementCount(), + swizzleIndices.getBuffer())); } - InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) - { - IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); + return InstPair(primalSwizzle, 
nullptr); +} - UCount operandCount = origInst->getOperandCount(); +InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) +{ + IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); - List<IRInst*> diffOperands; - for (UIndex ii = 0; ii < operandCount; ii++) - { - // If the operand has a differential version, replace the original with the - // differential. - // Otherwise, abandon the differentiation attempt and assume that origInst - // cannot (or does not need to) be differentiated. - // - if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr)) - diffOperands.add(diffInst); - else - return InstPair(primalInst, nullptr); - } - - return InstPair( - primalInst, - builder->emitIntrinsicInst( - differentiateType(builder, primalInst->getDataType()), - origInst->getOp(), - operandCount, - diffOperands.getBuffer())); + UCount operandCount = origInst->getOperandCount(); + + List<IRInst*> diffOperands; + for (UIndex ii = 0; ii < operandCount; ii++) + { + // If the operand has a differential version, replace the original with the + // differential. + // Otherwise, abandon the differentiation attempt and assume that origInst + // cannot (or does not need to) be differentiated. 
+ // + if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr)) + diffOperands.add(diffInst); + else + return InstPair(primalInst, nullptr); } + + return InstPair( + primalInst, + builder->emitIntrinsicInst( + differentiateType(builder, primalInst->getDataType()), + origInst->getOp(), + operandCount, + diffOperands.getBuffer())); +} - InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst) +{ + switch(origInst->getOp()) { - switch(origInst->getOp()) - { - case kIROp_unconditionalBranch: - case kIROp_loop: - auto origBranch = as<IRUnconditionalBranch>(origInst); + case kIROp_unconditionalBranch: + case kIROp_loop: + auto origBranch = as<IRUnconditionalBranch>(origInst); - // Grab the differentials for any phi nodes. - List<IRInst*> newArgs; - for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) - { - auto origArg = origBranch->getArg(ii); - auto primalArg = lookupPrimalInst(origArg); - newArgs.add(primalArg); + // Grab the differentials for any phi nodes. 
+ List<IRInst*> newArgs; + for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) + { + auto origArg = origBranch->getArg(ii); + auto primalArg = lookupPrimalInst(origArg); + newArgs.add(primalArg); - if (differentiateType(builder, primalArg->getDataType())) - { - auto diffArg = lookupDiffInst(origArg, nullptr); - if (diffArg) - newArgs.add(diffArg); - } + if (differentiateType(builder, primalArg->getDataType())) + { + auto diffArg = lookupDiffInst(origArg, nullptr); + if (diffArg) + newArgs.add(diffArg); } + } - IRInst* diffBranch = nullptr; - if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) + IRInst* diffBranch = nullptr; + if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) + { + if (auto origLoop = as<IRLoop>(origInst)) { - if (auto origLoop = as<IRLoop>(origInst)) - { - auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - List<IRInst*> operands; - operands.add(breakBlock); - operands.add(continueBlock); - operands.addRange(newArgs); - diffBranch = builder->emitIntrinsicInst( - nullptr, - kIROp_loop, - operands.getCount(), - operands.getBuffer()); - } - else - { - diffBranch = builder->emitBranch( - as<IRBlock>(diffBlock), - newArgs.getCount(), - newArgs.getBuffer()); - } + auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + List<IRInst*> operands; + operands.add(breakBlock); + operands.add(continueBlock); + operands.addRange(newArgs); + diffBranch = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + operands.getCount(), + operands.getBuffer()); } + else + { + diffBranch = builder->emitBranch( + as<IRBlock>(diffBlock), + newArgs.getCount(), + newArgs.getBuffer()); + } + } - // For now, every block in the original fn must have a corresponding - // block 
to compute *both* primals and derivatives (i.e linearized block) - SLANG_ASSERT(diffBranch); + // For now, every block in the original fn must have a corresponding + // block to compute *both* primals and derivatives (i.e linearized block) + SLANG_ASSERT(diffBranch); - return InstPair(diffBranch, diffBranch); - } + return InstPair(diffBranch, diffBranch); + } - getSink()->diagnose( - origInst->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unhandled control flow"); + getSink()->diagnose( + origInst->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unhandled control flow"); - return InstPair(nullptr, nullptr); - } + return InstPair(nullptr, nullptr); +} - InstPair transcribeConst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst) +{ + switch(origInst->getOp()) { - switch(origInst->getOp()) - { - case kIROp_FloatLit: - return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); - case kIROp_VoidLit: - return InstPair(origInst, origInst); - case kIROp_IntLit: - return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); - } + case kIROp_FloatLit: + return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); + case kIROp_VoidLit: + return InstPair(origInst, origInst); + case kIROp_IntLit: + return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); + } - getSink()->diagnose( - origInst->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unhandled const type"); + getSink()->diagnose( + origInst->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unhandled const type"); - return InstPair(nullptr, nullptr); - } + return InstPair(nullptr, nullptr); +} - IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key) +IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key) +{ + for 
(UInt i = 0; i < type->getOperandCount(); i++) { - for (UInt i = 0; i < type->getOperandCount(); i++) + if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) { - if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) - { - if (req->getRequirementKey() == key) - return req->getRequirementVal(); - } + if (req->getRequirementKey() == key) + return req->getRequirementVal(); } - return nullptr; } + return nullptr; +} - InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) +InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) +{ + auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); + List<IRInst*> primalArgs; + for (UInt i = 0; i < origSpecialize->getArgCount(); i++) { - auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); - List<IRInst*> primalArgs; - for (UInt i = 0; i < origSpecialize->getArgCount(); i++) - { - primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i))); - } - auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType()); - auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( - (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); + primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i))); + } + auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType()); + auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( + (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); - IRInst* diffBase = nullptr; - if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) + IRInst* diffBase = nullptr; + if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { - List<IRInst*> args; - for (UInt i = 0; i < primalSpecialize->getArgCount(); 
i++) - { - args.add(primalSpecialize->getArg(i)); - } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); - return InstPair(primalSpecialize, diffSpecialize); + args.add(primalSpecialize->getArg(i)); } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } - auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); - // Look for an IRForwardDerivativeDecoration on the specialize inst. - // (Normally, this would be on the inner IRFunc, but in this case only the JVP func - // can be specialized, so we put a decoration on the IRSpecialize) - // - if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) - { - auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); + auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + // Look for an IRForwardDerivativeDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) + // + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) + { + auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); - // Make sure this isn't itself a specialize . - SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + // Make sure this isn't itself a specialize . 
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); return InstPair(primalSpecialize, jvpFunc); } @@ -1007,611 +969,610 @@ struct JVPTranscriber } } - InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) - { - auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable()); - auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey()); - auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); - auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); +InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) +{ + auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable()); + auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey()); + auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); + auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); - auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()); - if (!interfaceType) - { - return InstPair(primal, nullptr); - } - auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); - if (!dict) - { - return InstPair(primal, nullptr); - } + auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()); + if (!interfaceType) + { + return InstPair(primal, nullptr); + } + auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dict) + { + return InstPair(primal, nullptr); + } - for (auto child : dict->getChildren()) + for (auto child : dict->getChildren()) + { + if (auto item = 
as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child)) { - if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child)) + if (item->getOperand(0) == lookupInst->getRequirementKey()) { - if (item->getOperand(0) == lookupInst->getRequirementKey()) + auto diffKey = item->getOperand(1); + if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) { - auto diffKey = item->getOperand(1); - if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) - { - auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); - return InstPair(primal, diff); - } - break; + auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); + return InstPair(primal, diff); } + break; } } - return InstPair(primal, nullptr); } + return InstPair(primal, nullptr); +} - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) +// In differential computation, the 'default' differential value is always zero. +// This is a consequence of differential computing being inherently linear. As a +// result, it's useful to have a method to generate zero literals of any (arithmetic) type. +// The current implementation requires that types are defined linearly. 
+// +IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) +{ + if (auto diffType = differentiateType(builder, primalType)) { - if (auto diffType = differentiateType(builder, primalType)) + switch (diffType->getOp()) { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); - - auto emptyArgList = List<IRInst*>(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - } - else + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). 
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + if (isScalarIntegerType(primalType)) { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; + return builder->getIntValue(primalType, 0); } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; } +} - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) - { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); - - IRInst* diffBlock = subBuilder.emitBlock(); - - // Note: for blocks, we setup the mapping _before_ - // processing the children since we could encounter - // a lookup while processing the children. - // - mapPrimalInst(origBlock, diffBlock); - mapDifferentialInst(origBlock, diffBlock); +InstPair ForwardDerivativeTranscriber::transcribeBlock(IRBuilder* builder, IRBlock* origBlock) +{ + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRInst* diffBlock = subBuilder.emitBlock(); + + // Note: for blocks, we setup the mapping _before_ + // processing the children since we could encounter + // a lookup while processing the children. + // + mapPrimalInst(origBlock, diffBlock); + mapDifferentialInst(origBlock, diffBlock); - subBuilder.setInsertInto(diffBlock); + subBuilder.setInsertInto(diffBlock); - // First transcribe every parameter in the block. 
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->transcribe(&subBuilder, param); + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->transcribe(&subBuilder, param); - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->transcribe(&subBuilder, child); + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->transcribe(&subBuilder, child); - return InstPair(diffBlock, diffBlock); - } + return InstPair(diffBlock, diffBlock); +} - InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) +InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) +{ + SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst)); + + IRInst* origBase = originalInst->getOperand(0); + auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto field = originalInst->getOperand(1); + auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>(); + auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); + + IRInst* primalOperands[] = { primalBase, field }; + IRInst* primalFieldExtract = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 2, + primalOperands); + + if (!derivativeRefDecor) { - SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst)); - - IRInst* origBase = originalInst->getOperand(0); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto field = originalInst->getOperand(1); - auto derivativeRefDecor 
= field->findDecoration<IRDerivativeMemberDecoration>(); - auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); + return InstPair(primalFieldExtract, nullptr); + } - IRInst* primalOperands[] = { primalBase, field }; - IRInst* primalFieldExtract = builder->emitIntrinsicInst( - primalType, - originalInst->getOp(), - 2, - primalOperands); + IRInst* diffFieldExtract = nullptr; - if (!derivativeRefDecor) + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { - return InstPair(primalFieldExtract, nullptr); + IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; + diffFieldExtract = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 2, + diffOperands); } - - IRInst* diffFieldExtract = nullptr; - - if (auto diffType = differentiateType(builder, primalType)) - { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) - { - IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; - diffFieldExtract = builder->emitIntrinsicInst( - diffType, - originalInst->getOp(), - 2, - diffOperands); - } - } - return InstPair(primalFieldExtract, diffFieldExtract); } + return InstPair(primalFieldExtract, diffFieldExtract); +} - InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) - { - SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); +InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) +{ + SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); - IRInst* origBase = origGetElementPtr->getOperand(0); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); + IRInst* origBase = origGetElementPtr->getOperand(0); 
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase); + auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1)); - auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); + auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType()); - IRInst* primalOperands[] = {primalBase, primalIndex}; - IRInst* primalGetElementPtr = builder->emitIntrinsicInst( - primalType, - origGetElementPtr->getOp(), - 2, - primalOperands); + IRInst* primalOperands[] = {primalBase, primalIndex}; + IRInst* primalGetElementPtr = builder->emitIntrinsicInst( + primalType, + origGetElementPtr->getOp(), + 2, + primalOperands); - IRInst* diffGetElementPtr = nullptr; + IRInst* diffGetElementPtr = nullptr; - if (auto diffType = differentiateType(builder, primalType)) + if (auto diffType = differentiateType(builder, primalType)) + { + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { - if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) - { - IRInst* diffOperands[] = {diffBase, primalIndex}; - diffGetElementPtr = builder->emitIntrinsicInst( - diffType, - origGetElementPtr->getOp(), - 2, - diffOperands); - } + IRInst* diffOperands[] = {diffBase, primalIndex}; + diffGetElementPtr = builder->emitIntrinsicInst( + diffType, + origGetElementPtr->getOp(), + 2, + diffOperands); } - - return InstPair(primalGetElementPtr, diffGetElementPtr); } - InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) - { - // The loop comes with three blocks.. we just need to transcribe each one - // and assemble the new loop instruction. 
- - // Transcribe the target block (this is the 'condition' part of the loop, which - // will branch into the loop body) - auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); - - // Transcribe the break block (this is the block after the exiting the loop) - auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - - // Transcribe the continue block (this is the 'update' part of the loop, which will - // branch into the condition block) - auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + return InstPair(primalGetElementPtr, diffGetElementPtr); +} - - List<IRInst*> diffLoopOperands; - diffLoopOperands.add(diffTargetBlock); - diffLoopOperands.add(diffBreakBlock); - diffLoopOperands.add(diffContinueBlock); +InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop) +{ + // The loop comes with three blocks.. we just need to transcribe each one + // and assemble the new loop instruction. + + // Transcribe the target block (this is the 'condition' part of the loop, which + // will branch into the loop body) + auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock()); - // If there are any other operands, use their primal versions. 
- for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); - diffLoopOperands.add(primalOperand); - } + // Transcribe the break block (this is the block after the exiting the loop) + auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); - IRInst* diffLoop = builder->emitIntrinsicInst( - nullptr, - kIROp_loop, - diffLoopOperands.getCount(), - diffLoopOperands.getBuffer()); + // Transcribe the continue block (this is the 'update' part of the loop, which will + // branch into the condition block) + auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); - return InstPair(diffLoop, diffLoop); - } + + List<IRInst*> diffLoopOperands; + diffLoopOperands.add(diffTargetBlock); + diffLoopOperands.add(diffBreakBlock); + diffLoopOperands.add(diffContinueBlock); - InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) + // If there are any other operands, use their primal versions. + for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++) { - // IfElse Statements come with 4 blocks. We transcribe each block into it's - // linear form, and then wire them up in the same way as the original if-else - - // Transcribe condition block - auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); - SLANG_ASSERT(primalConditionBlock); - - // Transcribe 'true' block (condition block branches into this if true) - auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); - SLANG_ASSERT(diffTrueBlock); - - // Transcribe 'false' block (condition block branches into this if true) - // TODO (sai): What happens if there's no false block? 
- auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); - SLANG_ASSERT(diffFalseBlock); - - // Transcribe 'after' block (true and false blocks branch into this) - auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); - SLANG_ASSERT(diffAfterBlock); - - List<IRInst*> diffIfElseArgs; - diffIfElseArgs.add(primalConditionBlock); - diffIfElseArgs.add(diffTrueBlock); - diffIfElseArgs.add(diffFalseBlock); - diffIfElseArgs.add(diffAfterBlock); - - // If there are any other operands, use their primal versions. - for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) - { - auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); - diffIfElseArgs.add(primalOperand); - } - - IRInst* diffLoop = builder->emitIntrinsicInst( - nullptr, - kIROp_ifElse, - diffIfElseArgs.getCount(), - diffIfElseArgs.getBuffer()); - - return InstPair(diffLoop, diffLoop); + auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii)); + diffLoopOperands.add(primalOperand); } - InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) - { - auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); - SLANG_ASSERT(primalVal); - auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); - SLANG_ASSERT(diffPrimalVal); - auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); - SLANG_ASSERT(primalDiffVal); - auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); - SLANG_ASSERT(diffDiffVal); + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + diffLoopOperands.getCount(), + diffLoopOperands.getBuffer()); - auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); - auto diffPair = builder->emitMakeDifferentialPair( - differentiateType(builder, 
origInst->getDataType()), - primalDiffVal, - diffDiffVal); - return InstPair(primalPair, diffPair); - } + return InstPair(diffLoop, diffLoop); +} - InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse) +{ + // IfElse Statements come with 4 blocks. We transcribe each block into it's + // linear form, and then wire them up in the same way as the original if-else + + // Transcribe condition block + auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition()); + SLANG_ASSERT(primalConditionBlock); + + // Transcribe 'true' block (condition block branches into this if true) + auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock()); + SLANG_ASSERT(diffTrueBlock); + + // Transcribe 'false' block (condition block branches into this if true) + // TODO (sai): What happens if there's no false block? + auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); + SLANG_ASSERT(diffFalseBlock); + + // Transcribe 'after' block (true and false blocks branch into this) + auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock()); + SLANG_ASSERT(diffAfterBlock); + + List<IRInst*> diffIfElseArgs; + diffIfElseArgs.add(primalConditionBlock); + diffIfElseArgs.add(diffTrueBlock); + diffIfElseArgs.add(diffFalseBlock); + diffIfElseArgs.add(diffAfterBlock); + + // If there are any other operands, use their primal versions. 
+ for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++) { - auto primal = cloneInst(&cloneEnv, builder, origInst); - return InstPair(primal, nullptr); + auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii)); + diffIfElseArgs.add(primalOperand); } - InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) - { - SLANG_ASSERT( - origInst->getOp() == kIROp_DifferentialPairGetDifferential || - origInst->getOp() == kIROp_DifferentialPairGetPrimal); + IRInst* diffLoop = builder->emitIntrinsicInst( + nullptr, + kIROp_ifElse, + diffIfElseArgs.getCount(), + diffIfElseArgs.getBuffer()); - auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); - SLANG_ASSERT(primalVal); - - auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); - SLANG_ASSERT(diffVal); + return InstPair(diffLoop, diffLoop); +} - auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); +InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) +{ + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalVal); + auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffDiffVal); + + auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); + auto diffPair = builder->emitMakeDifferentialPair( + differentiateType(builder, origInst->getDataType()), + primalDiffVal, + diffDiffVal); + return InstPair(primalPair, diffPair); +} - auto diffValPairType = 
as<IRDifferentialPairType>(diffVal->getDataType()); - IRInst* diffResultType = nullptr; - if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) - diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); - else - diffResultType = diffValPairType->getValueType(); - auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); - return InstPair(primalResult, diffResult); - } +InstPair ForwardDerivativeTranscriber::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) +{ + auto primal = cloneInst(&cloneEnv, builder, origInst); + return InstPair(primal, nullptr); +} - // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); +InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) +{ + SLANG_ASSERT( + origInst->getOp() == kIROp_DifferentialPairGetDifferential || + origInst->getOp() == kIROp_DifferentialPairGetPrimal); + + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(primalVal); + + auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(diffVal); + + auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); + + auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType()); + IRInst* diffResultType = nullptr; + if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) + diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); + else + diffResultType = diffValPairType->getValueType(); + auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); + return InstPair(primalResult, diffResult); +} - IRFunc* primalFunc = origFunc; +// Create an empty 
func to represent the transcribed func of `origFunc`. +InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) +{ + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); - differentiableTypeConformanceContext.setFunc(origFunc); + IRFunc* primalFunc = origFunc; - primalFunc = origFunc; + differentiableTypeConformanceContext.setFunc(origFunc); - auto diffFunc = builder.createFunc(); + primalFunc = origFunc; - SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); - IRType* diffFuncType = this->differentiateFunctionType( - &builder, - as<IRFuncType>(origFunc->getFullType())); - diffFunc->setFullType(diffFuncType); + auto diffFunc = builder.createFunc(); - if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) - { - auto originalName = nameHint->getName(); - StringBuilder newNameSb; - newNameSb << "s_fwd_" << originalName; - builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); - } - builder.addForwardDerivativeDecoration(origFunc, diffFunc); + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + &builder, + as<IRFuncType>(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); - // Mark the generated derivative function itself as differentiable. - builder.addForwardDifferentiableDecoration(diffFunc); + if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_fwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } + builder.addForwardDerivativeDecoration(origFunc, diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. 
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - cloneDecoration(dictDecor, diffFunc); - } + // Mark the generated derivative function itself as differentiable. + builder.addForwardDifferentiableDecoration(diffFunc); - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + cloneDecoration(dictDecor, diffFunc); } - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertInto(diffFunc); + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; +} - differentiableTypeConformanceContext.setFunc(primalFunc); - // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); +// Transcribe a function definition. 
+InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) +{ + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); - return InstPair(primalFunc, diffFunc); - } + differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(&builder, block); + + return InstPair(primalFunc, diffFunc); +} - // Transcribe a generic definition - InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) +// Transcribe a generic definition +InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) +{ + auto innerVal = findInnerMostGenericReturnVal(origGeneric); + if (auto innerFunc = as<IRFunc>(innerVal)) { - auto innerVal = findInnerMostGenericReturnVal(origGeneric); - if (auto innerFunc = as<IRFunc>(innerVal)) - { - differentiableTypeConformanceContext.setFunc(innerFunc); - } - else if (auto funcType = as<IRFuncType>(innerVal)) - { - } - else - { - return InstPair(origGeneric, nullptr); - } + differentiableTypeConformanceContext.setFunc(innerFunc); + } + else if (auto funcType = as<IRFuncType>(innerVal)) + { + } + else + { + return InstPair(origGeneric, nullptr); + } - IRGeneric* primalGeneric = origGeneric; + IRGeneric* primalGeneric = origGeneric; - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origGeneric); + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origGeneric); - auto diffGeneric = builder.emitGeneric(); + auto diffGeneric = builder.emitGeneric(); - // Process type of generic. If the generic is a function, then it's type will also be a - // generic and this logic will transcribe that generic first before continuing with the - // function itself. 
- // - auto primalType = primalGeneric->getFullType(); + // Process type of generic. If the generic is a function, then it's type will also be a + // generic and this logic will transcribe that generic first before continuing with the + // function itself. + // + auto primalType = primalGeneric->getFullType(); - IRType* diffType = nullptr; - if (primalType) - { - diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); - } + IRType* diffType = nullptr; + if (primalType) + { + diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); + } - diffGeneric->setFullType(diffType); + diffGeneric->setFullType(diffType); // Transcribe children from origFunc into diffFunc. builder.setInsertInto(diffGeneric); for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) this->transcribe(&builder, block); - return InstPair(primalGeneric, diffGeneric); - } + return InstPair(primalGeneric, diffGeneric); +} - IRInst* transcribe(IRBuilder* builder, IRInst* origInst) +IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* origInst) +{ + // If a differential intstruction is already mapped for + // this original inst, return that. + // + if (auto diffInst = lookupDiffInst(origInst, nullptr)) { - // If a differential intstruction is already mapped for - // this original inst, return that. - // - if (auto diffInst = lookupDiffInst(origInst, nullptr)) - { - SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. - return diffInst; - } + SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check. + return diffInst; + } - // Otherwise, dispatch to the appropriate method - // depending on the op-code. - // - instsInProgress.Add(origInst); - InstPair pair = transcribeInst(builder, origInst); - instsInProgress.Remove(origInst); + // Otherwise, dispatch to the appropriate method + // depending on the op-code. 
+ // + instsInProgress.Add(origInst); + InstPair pair = transcribeInst(builder, origInst); + instsInProgress.Remove(origInst); - if (auto primalInst = pair.primal) + if (auto primalInst = pair.primal) + { + mapPrimalInst(origInst, pair.primal); + mapDifferentialInst(origInst, pair.differential); + if (pair.differential) { - mapPrimalInst(origInst, pair.primal); - mapDifferentialInst(origInst, pair.differential); - if (pair.differential) + switch (pair.differential->getOp()) { - switch (pair.differential->getOp()) + case kIROp_Func: + case kIROp_Generic: + case kIROp_Block: + // Don't generate again for these. + // Functions already have their names generated in `transcribeFuncHeader`. + break; + default: + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) { - case kIROp_Func: - case kIROp_Generic: - case kIROp_Block: - // Don't generate again for these. - // Functions already have their names generated in `transcribeFuncHeader`. - break; - default: - // Generate name hint for the inst. 
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) - { - StringBuilder sb; - sb << "s_diff_" << primalNameHint->getName(); - builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); - } - break; + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } + break; } - return pair.differential; } - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "failed to transcibe instruction"); - return nullptr; + return pair.differential; } + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "failed to transcibe instruction"); + return nullptr; +} - InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) +InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst* origInst) +{ + // Handle common SSA-style operations + switch (origInst->getOp()) { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as<IRParam>(origInst)); - - case kIROp_Var: - return transcribeVar(builder, as<IRVar>(origInst)); - - case kIROp_Load: - return transcribeLoad(builder, as<IRLoad>(origInst)); - - case kIROp_Store: - return transcribeStore(builder, as<IRStore>(origInst)); + case kIROp_Param: + return transcribeParam(builder, as<IRParam>(origInst)); - case kIROp_Return: - return transcribeReturn(builder, as<IRReturn>(origInst)); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transcribeBinaryArith(builder, origInst); - - case kIROp_Less: - case kIROp_Greater: - case kIROp_And: - case kIROp_Or: - case kIROp_Geq: - case kIROp_Leq: - return transcribeBinaryLogic(builder, origInst); - - case kIROp_Construct: - return transcribeConstruct(builder, origInst); - - case kIROp_lookup_interface_method: - return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); - 
- case kIROp_Call: - return transcribeCall(builder, as<IRCall>(origInst)); - - case kIROp_swizzle: - return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); - - case kIROp_constructVectorFromScalar: - case kIROp_MakeTuple: - return transcribeByPassthrough(builder, origInst); + case kIROp_Var: + return transcribeVar(builder, as<IRVar>(origInst)); - case kIROp_unconditionalBranch: - return transcribeControlFlow(builder, origInst); + case kIROp_Load: + return transcribeLoad(builder, as<IRLoad>(origInst)); - case kIROp_FloatLit: - case kIROp_IntLit: - case kIROp_VoidLit: - return transcribeConst(builder, origInst); + case kIROp_Store: + return transcribeStore(builder, as<IRStore>(origInst)); - case kIROp_Specialize: - return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); + case kIROp_Return: + return transcribeReturn(builder, as<IRReturn>(origInst)); - case kIROp_FieldExtract: - case kIROp_FieldAddress: - return transcribeFieldExtract(builder, origInst); - case kIROp_getElement: - case kIROp_getElementPtr: - return transcribeGetElement(builder, origInst); - - case kIROp_loop: - return transcribeLoop(builder, as<IRLoop>(origInst)); - - case kIROp_ifElse: - return transcribeIfElse(builder, as<IRIfElse>(origInst)); - - case kIROp_MakeDifferentialPair: - return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); - case kIROp_DifferentialPairGetPrimal: - case kIROp_DifferentialPairGetDifferential: - return transcribeDifferentialPairGetElement(builder, origInst); - case kIROp_ExtractExistentialWitnessTable: - case kIROp_ExtractExistentialType: - case kIROp_ExtractExistentialValue: - case kIROp_WrapExistential: - case kIROp_MakeExistential: - case kIROp_MakeExistentialWithRTTI: - return trascribeNonDiffInst(builder, origInst); - case kIROp_StructKey: - return InstPair(origInst, nullptr); - } + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return transcribeBinaryArith(builder, origInst); - // If none of the 
cases have been hit, check if the instruction is a - // type. Only need to explicitly differentiate types if they appear inside a block. - // - if (auto origType = as<IRType>(origInst)) - { - // If this is a generic type, transcibe the parent - // generic and derive the type from the transcribed generic's - // return value. - // - if (as<IRGeneric>(origType->getParent()->getParent()) && - findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && - !instsInProgress.Contains(origType->getParent()->getParent())) - { - auto origGenericType = origType->getParent()->getParent(); - auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); - auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); - return InstPair( - origGenericType, - innerDiffGenericType - ); - } - else if (as<IRBlock>(origType->getParent())) - return InstPair( - cloneInst(&cloneEnv, builder, origType), - differentiateType(builder, origType)); - else - return InstPair( - cloneInst(&cloneEnv, builder, origType), - nullptr); - } + case kIROp_Less: + case kIROp_Greater: + case kIROp_And: + case kIROp_Or: + case kIROp_Geq: + case kIROp_Leq: + return transcribeBinaryLogic(builder, origInst); + + case kIROp_Construct: + return transcribeConstruct(builder, origInst); + + case kIROp_lookup_interface_method: + return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); + + case kIROp_Call: + return transcribeCall(builder, as<IRCall>(origInst)); + + case kIROp_swizzle: + return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); + + case kIROp_constructVectorFromScalar: + case kIROp_MakeTuple: + return transcribeByPassthrough(builder, origInst); + + case kIROp_unconditionalBranch: + return transcribeControlFlow(builder, origInst); + + case kIROp_FloatLit: + case kIROp_IntLit: + case kIROp_VoidLit: + return transcribeConst(builder, origInst); + + case kIROp_Specialize: + return 
transcribeSpecialize(builder, as<IRSpecialize>(origInst)); + + case kIROp_FieldExtract: + case kIROp_FieldAddress: + return transcribeFieldExtract(builder, origInst); + case kIROp_getElement: + case kIROp_getElementPtr: + return transcribeGetElement(builder, origInst); + + case kIROp_loop: + return transcribeLoop(builder, as<IRLoop>(origInst)); + + case kIROp_ifElse: + return transcribeIfElse(builder, as<IRIfElse>(origInst)); + + case kIROp_MakeDifferentialPair: + return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + return transcribeDifferentialPairGetElement(builder, origInst); + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_WrapExistential: + case kIROp_MakeExistential: + case kIROp_MakeExistentialWithRTTI: + return trascribeNonDiffInst(builder, origInst); + case kIROp_StructKey: + return InstPair(origInst, nullptr); + } - // Handle instructions with children - switch (origInst->getOp()) + // If none of the cases have been hit, check if the instruction is a + // type. Only need to explicitly differentiate types if they appear inside a block. + // + if (auto origType = as<IRType>(origInst)) + { + // If this is a generic type, transcibe the parent + // generic and derive the type from the transcribed generic's + // return value. 
+ // + if (as<IRGeneric>(origType->getParent()->getParent()) && + findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType && + !instsInProgress.Contains(origType->getParent()->getParent())) { - case kIROp_Func: - return transcribeFuncHeader(builder, as<IRFunc>(origInst)); - - case kIROp_Block: - return transcribeBlock(builder, as<IRBlock>(origInst)); - - case kIROp_Generic: - return transcribeGeneric(builder, as<IRGeneric>(origInst)); + auto origGenericType = origType->getParent()->getParent(); + auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType); + auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType)); + return InstPair( + origGenericType, + innerDiffGenericType + ); } + else if (as<IRBlock>(origType->getParent())) + return InstPair( + cloneInst(&cloneEnv, builder, origType), + differentiateType(builder, origType)); + else + return InstPair( + cloneInst(&cloneEnv, builder, origType), + nullptr); + } - // If we reach this statement, the instruction type is likely unhandled. - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::unimplemented, - "this instruction cannot be differentiated"); + // Handle instructions with children + switch (origInst->getOp()) + { + case kIROp_Func: + return transcribeFuncHeader(builder, as<IRFunc>(origInst)); - return InstPair(nullptr, nullptr); + case kIROp_Block: + return transcribeBlock(builder, as<IRBlock>(origInst)); + + case kIROp_Generic: + return transcribeGeneric(builder, as<IRGeneric>(origInst)); } -}; + + // If we reach this statement, the instruction type is likely unhandled. + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::unimplemented, + "this instruction cannot be differentiated"); + + return InstPair(nullptr, nullptr); +} struct ForwardDerivativePass : public InstPassBase { @@ -1771,7 +1732,7 @@ protected: // A transcriber object that handles the main job of // processing instructions while maintaining state. 
// - JVPTranscriber transcriberStorage; + ForwardDerivativeTranscriber transcriberStorage; // Diagnostic object from the compile request for // error messages. @@ -1797,61 +1758,4 @@ bool processForwardDerivativeCalls( return changed; } - -void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) -{ - parentFunc = func; - - auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>(); - SLANG_RELEASE_ASSERT(decor); - - // Build lookup dictionary for type witnesses. - for (auto child = decor->getFirstChild(); child; child = child->next) - { - if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) - { - auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); - if (existingItem) - { - *existingItem = item->getWitness(); - } - else - { - differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); - } - } - } -} - - -// Lookup a witness table for the concreteType. One should exist if concreteType -// inherits (successfully) from IDifferentiable. 
-// - -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) -{ - IRInst* foundResult = nullptr; - differentiableWitnessDictionary.TryGetValue(type, foundResult); - return foundResult; -} - -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) -{ - if (auto conformance = lookUpConformanceForType(origType)) - { - return _lookupWitness(builder, conformance, key); - } - return nullptr; -} - -void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() -{ - for (auto globalInst : sharedContext->moduleInst->getChildren()) - { - if (auto pairType = as<IRDifferentialPairType>(globalInst)) - { - differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); - } - } -} } diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 6b261ecd0..ab5d753d6 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -30,6 +30,171 @@ namespace Slang typedef DiffInstPair<IRInst*, IRInst*> InstPair; + +struct ForwardDerivativeTranscriber +{ + + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary<IRInst*, IRInst*> instMapD; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. 
+ AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + List<InstPair> followUpFunctionsToTranscribe; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary<InstPair, IRInst*> differentialPairTypes; + + ForwardDerivativeTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) + : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) + { + + } + + DiagnosticSink* getSink(); + + void mapDifferentialInst(IRInst* origInst, IRInst* diffInst); + + void mapPrimalInst(IRInst* origInst, IRInst* primalInst); + + IRInst* lookupDiffInst(IRInst* origInst); + + IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst); + + bool hasDifferentialInst(IRInst* origInst); + + bool shouldUseOriginalAsPrimal(IRInst* origInst); + + IRInst* lookupPrimalInst(IRInst* origInst); + + IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst); + + bool hasPrimalInst(IRInst* origInst); + + IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst); + + IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst); + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType); + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. 
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType); + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness); + + IRType* getOrCreateDiffPairType(IRInst* primalType); + + IRType* differentiateType(IRBuilder* builder, IRType* origType); + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType); + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam); + + // Returns "d<var-name>" to use as a name hint for variables and parameters. + // If no primal name is available, returns a blank string. + // + String getJVPVarName(IRInst* origVar); + + // Returns "dp<var-name>" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar); + + InstPair transcribeVar(IRBuilder* builder, IRVar* origVar); + + InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith); + + InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic); + + InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad); + + InstPair transcribeStore(IRBuilder* builder, IRStore* origStore); + + InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn); + + // Since int/float literals are sometimes nested inside an IRConstructor + // instruction, we check to make sure that the nested instr is a constant + // and then return nullptr. Literals do not need to be differentiated. + // + InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct); + + // Differentiating a call instruction here is primarily about generating + // an appropriate call list based on whichever parameters have differentials + // in the current transcription context. 
+ // + InstPair transcribeCall(IRBuilder* builder, IRCall* origCall); + + InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle); + + InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeConst(IRBuilder* builder, IRInst* origInst); + + IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); + + InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); + + InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst); + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType); + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock); + + InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst); + + InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr); + + InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop); + + InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse); + + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst); + + InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst); + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc); + + // Transcribe a function definition. 
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc); + + // Transcribe a generic definition + InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric); + + IRInst* transcribe(IRBuilder* builder, IRInst* origInst); + + InstPair transcribeInst(IRBuilder* builder, IRInst* origInst); +}; + struct ForwardDerivativePassOptions { // Nothing for now.. diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h new file mode 100644 index 000000000..9518ccb34 --- /dev/null +++ b/source/slang/slang-ir-autodiff-propagate.h @@ -0,0 +1,102 @@ +// slang-ir-autodiff-propagate.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-autodiff.h" + +namespace Slang +{ + +bool isDifferentialInst(IRInst* inst) +{ + return inst->findDecoration<IRDifferentialInstDecoration>(); +} + +struct DiffPropagationPass : InstPassBase +{ + AutoDiffSharedContext* autodiffContext; + + DiffPropagationPass(AutoDiffSharedContext* autodiffContext) : + autodiffContext(autodiffContext), + InstPassBase(autodiffContext->moduleInst->getModule()) + { + + } + + + bool shouldInstBeMarkedDifferential(IRInst* inst) + { + for (UIndex ii = 0; ii < inst->getOperandCount(); ii ++) + { + if (isDifferentialInst(inst->getOperand(ii))) + { + return true; + } + } + + return false; + } + + void addPendingUsersToWorkList(IRInst* inst) + { + auto use = inst->firstUse; + while (use) + { + if (!isDifferentialInst(use->getUser())) + { + addToWorkList(use->getUser()); + } + use = use->nextUse; + } + } + + // Propagate IRDifferentialInstDecoration for all children of instWithChildren. + void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren) + { + List<IRInst*> initialList; + // Mark 'GetDifferential' insts as differential. 
+ processChildInstsOfType<IRDifferentialPairGetDifferential>( + kIROp_DifferentialPairGetDifferential, + instWithChildren, + [&](IRDifferentialPairGetDifferential* getDifferentialInst) + { + builder->markInstAsDifferential(getDifferentialInst); + initialList.add(getDifferentialInst); + }); + + + workList.clear(); + workListSet.Clear(); + + // Add the marked insts to the work list. + for (auto inst : initialList) + { + // Look for insts marked as differential. + if (isDifferentialInst(inst)) + addPendingUsersToWorkList(inst); + } + + // Propagate to all users.. + while (workList.getCount() != 0) + { + IRInst* inst = pop(); + + // Skip if this is already a differential inst. + if (isDifferentialInst(inst)) + { + continue; + } + + if (shouldInstBeMarkedDifferential(inst)) + { + builder->markInstAsDifferential(inst); + addPendingUsersToWorkList(inst); + } + } + } +}; + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 52567e887..522c995b0 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -6,6 +6,11 @@ #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" +#include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-propagate.h" +#include "slang-ir-autodiff-unzip.h" +#include "slang-ir-autodiff-transpose.h" + namespace Slang { @@ -44,12 +49,33 @@ struct BackwardDiffTranscriber IRWitnessTable* differentialBottomWitness = nullptr; Dictionary<InstPair, IRInst*> differentialPairTypes; + // References to other passes that for reverse-mode transcription. + ForwardDerivativeTranscriber *fwdDiffTranscriber; + DiffTransposePass *diffTransposePass; + DiffPropagationPass *diffPropagationPass; + DiffUnzipPass *diffUnzipPass; + + // Allocate space for the passes. + ForwardDerivativeTranscriber fwdDiffTranscriberStorage; + DiffTransposePass diffTransposePassStorage; + DiffPropagationPass diffPropagationPassStorage; + DiffUnzipPass diffUnzipPassStorage; + + BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) : autoDiffSharedContext(shared) , sink(inSink) , differentiableTypeConformanceContext(shared) , sharedBuilder(inSharedBuilder) - {} + , fwdDiffTranscriberStorage(shared, inSharedBuilder) + , diffTransposePassStorage(shared) + , diffPropagationPassStorage(shared) + , diffUnzipPassStorage(shared) + , fwdDiffTranscriber(&fwdDiffTranscriberStorage) + , diffTransposePass(&diffTransposePassStorage) + , diffPropagationPass(&diffPropagationPassStorage) + , diffUnzipPass(&diffUnzipPassStorage) + { } DiagnosticSink* getSink() { @@ -413,20 +439,166 @@ struct BackwardDiffTranscriber return result; } - // Transcribe a function definition. 
- InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) + // Puts parameters into their own block. + void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func) { IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertInto(diffFunc); - differentiableTypeConformanceContext.setFunc(primalFunc); - // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribeBlock(&builder, block); + auto firstBlock = func->getFirstBlock(); + IRParam* param = func->getFirstParam(); + + builder.setInsertBefore(firstBlock); + + // Note: It looks like emitBlock() doesn't use the current + // builder position, so we're going to manually move the new block + // to before the existing block. + auto paramBlock = builder.emitBlock(); + paramBlock->insertBefore(firstBlock); + builder.setInsertInto(paramBlock); + + while(param) + { + IRParam* nextParam = param->getNextParam(); + + // Copy inst into the new parameter block. + auto clonedParam = cloneInst(&cloneEnv, &builder, param); + param->replaceUsesWith(clonedParam); + param->removeAndDeallocate(); + + param = nextParam; + } + + // Replace this block as the first block. + firstBlock->replaceUsesWith(paramBlock); + + // Add terminator inst. + builder.emitBranch(firstBlock); + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + { + SLANG_ASSERT(primalFunc); + SLANG_ASSERT(diffFunc); + // Reverse-mode transcription uses 4 separate steps: + // TODO(sai): Fill in documentation. + + // Generate a temporary forward derivative function as an intermediate step. + IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(builder, (IRFunc*)primalFunc).differential); + SLANG_ASSERT(fwdDiffFunc); + + // Transcribe the body of the primal function into it's linear (fwd-diff) form. 
+ // TODO(sai): Handle the case when we already have a user-defined fwd-derivative function. + fwdDiffTranscriber->transcribeFunc(builder, primalFunc, as<IRFunc>(fwdDiffFunc)); + + // Split first block into a parameter block. + this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc)); + + // This step adds a decoration to instructions that are computing the differential. + diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc); + + // Copy primal insts to the first block of the unzipped function, copy diff insts to the + // second block of the unzipped function. + // + IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + + // Clone the primal blocks from unzippedFwdDiffFunc + // to the reverse-mode function. + // TODO: This is the spot where we can make a decision to split + // the primal and differential into two different functions + // instead of two blocks in the same function. + // + // Special care needs to be taken for the first block since it holds the parameters + + // Clone all blocks into a temporary diff func. + // We're using a temporary since we don't want to clone decorations, + // only blocks, and right now there's no provision in slang-ir-clone.h + // for that. + // + builder->setInsertInto(diffFunc->getParent()); + auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc)); + + // Move blocks to the diffFunc shell. 
+ { + List<IRBlock*> workList; + for (auto block = tempDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) + workList.add(block); + + for (auto block : workList) + block->insertAtEnd(diffFunc); + } + + // Transpose the first block (parameter block) + transcribeParameterBlock(builder, diffFunc); + + builder->setInsertInto(diffFunc); + + auto dOutParameter = diffFunc->getLastParam(); + + // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the + DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; + diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info); + + // Clean up by deallocating intermediate steps. + tempDiffFunc->removeAndDeallocate(); + unzippedFwdDiffFunc->removeAndDeallocate(); + fwdDiffFunc->removeAndDeallocate(); return InstPair(primalFunc, diffFunc); } + void transcribeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) + { + IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); + + // Find the 'next' block using the terminator inst of the parameter block. + auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator()); + auto nextBlock = fwdParamBlockBranch->getTargetBlock(); + + builder->setInsertInto(fwdDiffParameterBlock); + + // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + for (auto child = fwdDiffParameterBlock->getFirstParam(); child;) + { + IRParam* nextChild = child->getNextParam(); + + auto fwdParam = as<IRParam>(child); + SLANG_ASSERT(fwdParam); + + // TODO: Handle ptr<pair> types. + if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) + { + // Create inout version. + auto inoutDiffPairType = builder->getInOutType(diffPairType); + auto newParam = builder->emitParam(inoutDiffPairType); + + // Map the _load_ of the new parameter as the clone of the old one. 
+ auto newParamLoad = builder->emitLoad(newParam); + newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block. + fwdParam->replaceUsesWith(newParamLoad); + fwdParam->removeAndDeallocate(); + } + else + { + // Default case (parameter has nothing to do with differentiation) + // Do nothing. + } + + child = nextChild; + } + + auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); + + // 2. Add a parameter for 'derivative of the output' (d_out). + // The type is the last parameter type of the function. + // + auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1); + + SLANG_ASSERT(dOutParamType); + + builder->emitParam(dOutParamType); + } + IRInst* copyParam(IRBuilder* builder, IRParam* origParam) { auto primalDataType = origParam->getDataType(); @@ -652,7 +824,6 @@ struct BackwardDiffTranscriber struct ReverseDerivativePass : public InstPassBase { - DiagnosticSink* getSink() { return sink; @@ -664,7 +835,7 @@ struct ReverseDerivativePass : public InstPassBase IRBuilder builderStorage(autodiffContext->sharedBuilder); IRBuilder* builder = &builderStorage; - // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by // generating derivative code for the referenced function. // bool modified = processReferencedFunctions(builder); @@ -800,6 +971,9 @@ struct ReverseDerivativePass : public InstPassBase pairBuilderStorage(context) { backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; + backwardTranscriberStorage.fwdDiffTranscriberStorage.sink = sink; + backwardTranscriberStorage.fwdDiffTranscriberStorage.autoDiffSharedContext = context; + backwardTranscriberStorage.fwdDiffTranscriberStorage.pairBuilder = &(pairBuilderStorage); } protected: @@ -829,4 +1003,4 @@ bool processReverseDerivativeCalls( return changed; } -}
\ No newline at end of file +} diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h new file mode 100644 index 000000000..659131820 --- /dev/null +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -0,0 +1,420 @@ +// slang-ir-autodiff-transpose.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" + +namespace Slang +{ + +struct DiffTransposePass +{ + AutoDiffSharedContext* autodiffContext; + + DifferentialPairTypeBuilder pairBuilder; + + Dictionary<IRInst*, List<IRInst*>> assignmentsMap; + + Dictionary<IRInst*, IRInst*>* primalsMap; + + DiffTransposePass(AutoDiffSharedContext* autodiffContext) : + autodiffContext(autodiffContext), pairBuilder(autodiffContext) + { } + + struct RevAssignment + { + IRInst* lvalue; + IRInst* rvalue; + + RevAssignment(IRInst* lvalue, IRInst* rvalue) : lvalue(lvalue), rvalue(rvalue) + { } + RevAssignment() : lvalue(nullptr), rvalue(nullptr) + { } + }; + + struct TranspositionResult + { + // Holds a set of pairs of + // (original-inst, inst-to-accumulate-for-orig-inst) + List<RevAssignment> revPairs; + + TranspositionResult() + { } + + TranspositionResult(List<RevAssignment> revPairs) : revPairs(revPairs) + { } + }; + + struct FuncTranspositionInfo + { + // Inst that represents the reverse-mode derivative + // of the *output* of the function. + // + IRInst* dOutInst; + + // Mapping between *primal* insts in the forward-mode function, and the + // reverse-mode function + // + Dictionary<IRInst*, IRInst*>* primalsMap; + }; + + void transposeDiffBlocksInFunc( + IRFunc* revDiffFunc, + // TODO: Maybe there's a more elegant way to pass this information. 
+ FuncTranspositionInfo transposeInfo) + { + + // Traverse all instructions/blocks in reverse (starting from the terminator inst) + // look for insts/blocks marked with IRDifferentialInstDecoration, + // and transpose them in the revDiffFunc. + // + IRBuilder builder; + builder.init(autodiffContext->sharedBuilder); + + // Insert after the last block. + builder.setInsertInto(revDiffFunc); + + List<IRBlock*> workList; + + // Build initial list of blocks to process by checking if they're differential blocks. + for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) + { + if (!isDifferentialInst(block)) + { + // Skip blocks that aren't computing differentials. + // At this stage we should have 'unzipped' the function + // into blocks that either entirely deal with primal insts, + // or entirely with differential insts. + continue; + } + workList.add(block); + } + + // TODO: We *might* need a step here that 'sorts' the work list in reverse order starting with 'leaf' + // differential blocks, and following the branches backwards. + // The alternative is to make phi nodes and treat all intermediaries & their gradients as arguments. + + for (auto block : workList) + { + // Set dOutParameter as the transpose gradient for the return inst, if any. + if (auto returnInst = as<IRReturn>(block->getTerminator())) + { + this->addRevAssignmentForFwdInst(returnInst, transposeInfo.dOutInst); + } + + IRBlock* revBlock = builder.emitBlock(); + this->transposeBlock(block, revBlock); + + // TODO: This should only really be used for the transition from + // the 'last' primal block(s) to the first differential block. + // Transitions from differential blocks to + block->replaceUsesWith(revBlock); + block->removeAndDeallocate(); + } + } + + void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock) + { + IRBuilder builder; + builder.init(autodiffContext->sharedBuilder); + + // Insert after the last block. 
+ builder.setInsertInto(revBlock); + + // Note the 'reverse' traversal here. + for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst()) + { + if (as<IRDecoration>(child)) + continue; + + transposeInst(&builder, child); + } + + // After processing the block's instructions, we 'flush' any remaining gradients + // in the assignments map. + // For now, these are only function parameter gradients (or of the form IRLoad(IRParam)) + // TODO: We should be flushing *all* gradients accumulated in this block to some + // function scope variable, since control flow can affect what blocks contribute to + // for a specific inst. + // + for (auto pair : assignmentsMap) + { + if (auto param = as<IRLoad>(pair.Key)) + accumulateGradientsForLoad(&builder, param); + } + + // Emit a terminator inst. + // TODO: need to be a lot smarter here. For now, we assume a single differential + // block, so it should end in a return statement. + if (as<IRReturn>(fwdBlock->getTerminator())) + { + // Emit a void return. + builder.emitReturn(); + } + else + { + SLANG_UNEXPECTED("Unhandled block terminator"); + } + } + + void transposeInst(IRBuilder* builder, IRInst* inst) + { + // Look for assignment entry for this inst. + IRInst* revValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0); + if (hasRevAssignments(inst)) + { + // Emit the aggregate of all the assignments here. This will form the derivative + revValue = emitAggregateValue(builder, popRevAssignments(inst)); + } + + auto transposeResult = transposeInst(builder, inst, revValue); + + // Add the new results to the assignments map. + for (auto pair : transposeResult.revPairs) + { + addRevAssignmentForFwdInst(pair.lvalue, pair.rvalue); + } + } + + TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + // Dispatch logic. 
+ switch(fwdInst->getOp()) + { + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + return transposeArithmetic(builder, fwdInst, revValue); + + case kIROp_Return: + return transposeReturn(builder, as<IRReturn>(fwdInst), revValue); + + case kIROp_MakeDifferentialPair: + return transposeMakePair(builder, as<IRMakeDifferentialPair>(fwdInst), revValue); + + case kIROp_DifferentialPairGetDifferential: + return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue); + + default: + SLANG_ASSERT_FAILURE("Unhandled instruction"); + } + } + + TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) + { + // (P = (A, dA)) -> (dA += dP) + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdMakePair->getDifferentialValue(), + revValue))); + } + + TranspositionResult transposeGetDifferential(IRBuilder*, IRDifferentialPairGetDifferential* fwdGetDiff, IRInst* revValue) + { + // (A = GetDiff(P)) -> (dP.d += dA) + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdGetDiff->getBase(), + revValue))); + } + + // Gather all reverse-mode gradients for parameters, and store to the differential + // + void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) + { + auto revParam = revLoad->getPtr(); + + // Don't currently handle loads from non-param insts. + SLANG_ASSERT(as<IRParam>(revParam)); + + // Assert that param type is of the form IRPtrTypeBase<IRDifferentialPairType<T>> + SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())); + SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType); + + auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()); + auto diffType = (IRType*) pairBuilder.getDiffTypeFromPairType(builder, paramPairType); + + // Gather gradients. 
+ auto gradients = popRevAssignments(revLoad); + if (gradients.getCount() == 0) + { + // Ignore. + return; + } + else + { + // Re-emit a load to get the _current_ value of revParam. + auto revCurrLoad = builder->emitLoad(revParam); + + // Grab the current gradient value. + auto revCurrGrad = builder->emitDifferentialPairGetDifferential(diffType, revCurrLoad); + + // Add the current value to the aggregation list. + gradients.add(revCurrGrad); + + // Get the _total_ value. + auto aggregateGradient = emitAggregateValue(builder, gradients); + + // Grab the current primal value. + auto revCurrPrimal = builder->emitDifferentialPairGetPrimal(revCurrLoad); + + // Make the pair with the new gradient. + auto newDiffPair = builder->emitMakeDifferentialPair(paramPairType, revCurrPrimal, aggregateGradient); + + // Store this back into the parameter. + builder->emitStore(revParam, newDiffPair); + } + } + + TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue) + { + + if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType())) + { + // If the type is a differential pair, we add the reverse-value for the *pair* + // itself. TODO: Signal this through flags in the 'RevAssignment' struct. 
+ // (return (A, dA)) -> (dA += dOut) + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdReturn->getVal(), + revValue))); + } + else + { + // (return A) -> (empty) + return TranspositionResult(); + } + } + + TranspositionResult transposeArithmetic(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + IRType* floatType = builder->getType(kIROp_FloatType); + switch(fwdInst->getOp()) + { + case kIROp_Add: + { + // (Out = dA + dB) -> [(dA += dOut), (dB += dOut)] + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdInst->getOperand(0), + revValue), + RevAssignment( + fwdInst->getOperand(1), + revValue))); + } + case kIROp_Sub: + { + // (Out = dA - dB) -> [(dA += dOut), (dB -= dOut)] + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdInst->getOperand(0), + revValue), + RevAssignment( + fwdInst->getOperand(1), + builder->emitNeg( + revValue->getDataType(), revValue)))); + } + case kIROp_Mul: + { + if (isDifferentialInst(fwdInst->getOperand(0))) + { + // (Out = dA * B) -> (dA += B * dOut) + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdInst->getOperand(0), + builder->emitMul(floatType, fwdInst->getOperand(1), revValue)))); + } + else if (isDifferentialInst(fwdInst->getOperand(1))) + { + // (Out = A * dB) -> (dB += A * dOut) + return TranspositionResult( + List<RevAssignment>( + RevAssignment( + fwdInst->getOperand(1), + builder->emitMul(floatType, fwdInst->getOperand(0), revValue)))); + } + else + { + SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst"); + } + } + + default: + SLANG_ASSERT_FAILURE("Unhandled arithmetic"); + } + } + + IRInst* emitAggregateValue(IRBuilder* builder, List<IRInst*> values) + { + // We're handling the case where the types are all float, + // so we can use a bunch of kIROp_Add insts to add them up. 
+        // If this is an arbitrary type T, we need to lookup and
+        // call T.dadd()
+
+        auto valueCount = values.getCount();
+        if (valueCount == 0)
+        {
+            // If there are no values to add up, emit a 0 value.
+            return builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0);
+        }
+        else if (valueCount == 1)
+        {
+            // If there's only one value to add up, just return it in order
+            // to avoid a stack of 0 + 0 + 0 + ...
+            return values[0];
+        }
+
+        // If there's more than one value, aggregate them by adding them up.
+
+        SLANG_ASSERT(values[0]->getDataType()->getOp() == kIROp_FloatType);
+
+        // Seed the running sum with the first value rather than a literal zero,
+        // so that no redundant `0 + x` add is emitted.
+        IRType* floatType = builder->getType(kIROp_FloatType);
+        IRInst* currentValue = values[0];
+        for (decltype(valueCount) valueIndex = 1; valueIndex < valueCount; ++valueIndex)
+        {
+            currentValue = builder->emitAdd(floatType, currentValue, values[valueIndex]);
+        }
+
+        return currentValue;
+    }
+
+    /// Appends `assignment` to the list recorded for `fwdInst`, creating the
+    /// list first if `fwdInst` has none yet.
+    void addRevAssignmentForFwdInst(IRInst* fwdInst, IRInst* assignment)
+    {
+        if (!hasRevAssignments(fwdInst))
+        {
+            assignmentsMap[fwdInst] = List<IRInst*>();
+        }
+
+        assignmentsMap[fwdInst].GetValue().add(assignment);
+    }
+
+    /// Returns a copy of the assignments recorded for `fwdInst`, or an empty
+    /// list when nothing has been recorded for it.
+    List<IRInst*> getRevAssignments(IRInst* fwdInst)
+    {
+        if (!hasRevAssignments(fwdInst))
+        {
+            return List<IRInst*>();
+        }
+
+        return assignmentsMap[fwdInst].GetValue();
+    }
+
+    /// Removes and returns the assignments recorded for `fwdInst`.
+    ///
+    /// Returns an empty list when nothing has been recorded for it. An entry
+    /// only exists once at least one assignment has been added, so "no entry"
+    /// is the only way a caller's empty-list check can be satisfied.
+    List<IRInst*> popRevAssignments(IRInst* fwdInst)
+    {
+        if (!hasRevAssignments(fwdInst))
+        {
+            return List<IRInst*>();
+        }
+
+        List<IRInst*> val = assignmentsMap[fwdInst].GetValue();
+        assignmentsMap.Remove(fwdInst);
+        return val;
+    }
+
+    /// Returns true when at least one assignment is recorded for `fwdInst`.
+    bool hasRevAssignments(IRInst* fwdInst)
+    {
+        return assignmentsMap.ContainsKey(fwdInst);
+    }
+};
+
+
+}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
new file mode 100644
index 000000000..344a930f2
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -0,0 +1,110 @@
+// slang-ir-autodiff-unzip.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-propagate.h"
+
+namespace Slang
+{
+
+/// Pass that clones a function and splits the clone's body block into a
+/// primal block followed by a differential block.
+struct DiffUnzipPass
+{
+    AutoDiffSharedContext* autodiffContext;
+
+    // Clone environment shared by every cloneInst call made by this pass.
+    IRCloneEnv cloneEnv;
+
+    DiffUnzipPass(AutoDiffSharedContext* autodiffContext) :
+        autodiffContext(autodiffContext)
+    { }
+
+    /// Clones `func` and splits the clone's body into a primal block and a
+    /// differential block.
+    ///
+    /// The clone must consist of exactly two blocks (asserted below): the first
+    /// is skipped, the second is split via `splitBlock` and then removed.
+    /// Returns the clone.
+    IRFunc* unzipDiffInsts(IRFunc* func)
+    {
+        IRBuilder builderStorage;
+        builderStorage.init(autodiffContext->sharedBuilder);
+
+        IRBuilder* builder = &builderStorage;
+
+        // Clone the entire function.
+        // TODO: Maybe don't clone? The reverse-mode process seems to clone several times.
+        // TODO: Looks like we get a copy of the decorations?
+        IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, func));
+
+        builder->setInsertInto(unzippedFunc);
+
+        // Work *only* with two-block functions for now.
+        SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr);
+        SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr);
+        SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock()->getNextBlock() == nullptr);
+
+        // Ignore the first block (this is reserved for parameters), start
+        // at the second block. (For now, we work with only a single block of insts)
+        // TODO: expand to handle multi-block functions later.
+
+        IRBlock* mainBlock = unzippedFunc->getFirstBlock()->getNextBlock();
+
+        // Emit new blocks for the split versions of mainBlock.
+        IRBlock* primalBlock = builder->emitBlock();
+        IRBlock* diffBlock = builder->emitBlock();
+
+        // Mark the differential block as a differential inst.
+        builder->markInstAsDifferential(diffBlock);
+
+        // Split the main block into two. This method should also emit
+        // a branch statement from primalBlock to diffBlock.
+        // TODO: extend this code to split multiple blocks
+        //
+        splitBlock(mainBlock, primalBlock, diffBlock);
+
+        // Replace occurrences of mainBlock with primalBlock
+        mainBlock->replaceUsesWith(primalBlock);
+        mainBlock->removeAndDeallocate();
+
+        return unzippedFunc;
+    }
+
+    /// Distributes the insts of `mainBlock` between `primalBlock` and `diffBlock`.
+    ///
+    /// Insts for which `isDifferentialInst` holds, and terminator insts, are
+    /// cloned into `diffBlock`; all others are cloned into `primalBlock`. Each
+    /// original inst is removed once its uses point at the clone, and a branch
+    /// from `primalBlock` to `diffBlock` is emitted at the end.
+    void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock)
+    {
+        // Make two builders for primal and differential blocks.
+        IRBuilder primalBuilder;
+        primalBuilder.init(autodiffContext->sharedBuilder);
+        primalBuilder.setInsertInto(primalBlock);
+
+        IRBuilder diffBuilder;
+        diffBuilder.init(autodiffContext->sharedBuilder);
+        diffBuilder.setInsertInto(diffBlock);
+
+        // The loop has no increment clause: the successor is fetched up front
+        // because `child` is removed and deallocated inside the body.
+        for (auto child = mainBlock->getFirstChild(); child;)
+        {
+            IRInst* nextChild = child->getNextInst();
+
+            // Differential insts and the terminator go to the differential block.
+            if (isDifferentialInst(child) || as<IRTerminatorInst>(child))
+            {
+                auto newInst = cloneInst(&cloneEnv, &diffBuilder, child);
+                child->replaceUsesWith(newInst);
+                child->removeAndDeallocate();
+            }
+            else
+            {
+                // Everything else goes to the primal block.
+                auto newInst = cloneInst(&cloneEnv, &primalBuilder, child);
+                child->replaceUsesWith(newInst);
+                child->removeAndDeallocate();
+            }
+
+            child = nextChild;
+        }
+
+        // Nothing should be left in the original block.
+        SLANG_ASSERT(mainBlock->getFirstChild() == nullptr);
+
+        // Branch from primal to differential block.
+        // Functionally, the new blocks should produce the same output as the
+        // old block.
+        primalBuilder.emitBranch(diffBlock);
+    }
+};
+
+}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 313760d85..b0dbf62fa 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -327,6 +327,60 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
 }
+
+// Records `func` as the parent function and fills the witness dictionary from
+// the items of its IRDifferentiableTypeDictionaryDecoration. The decoration must
+// be present (release assert). An item whose concrete type is already in the
+// dictionary overwrites the stored witness.
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+    parentFunc = func;
+
+    auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+    SLANG_RELEASE_ASSERT(decor);
+
+    // Build lookup dictionary for type witnesses.
+    for (auto child = decor->getFirstChild(); child; child = child->getNextInst())
+    {
+        if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+        {
+            // Overwrite the witness when the concrete type is already registered;
+            // otherwise insert a new entry.
+            auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+            if (existingItem)
+            {
+                *existingItem = item->getWitness();
+            }
+            else
+            {
+                differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+            }
+        }
+    }
+}
+
+// Returns the witness registered for `type`, or nullptr when there is none.
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+    IRInst* foundResult = nullptr;
+    differentiableWitnessDictionary.TryGetValue(type, foundResult);
+    return foundResult;
+}
+
+// Looks up `key` in the witness registered for `origType`.
+// Returns nullptr when no witness is registered for `origType`.
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+    if (auto conformance = lookUpConformanceForType(origType))
+    {
+        return _lookupWitness(builder, conformance, key);
+    }
+    return nullptr;
+}
+
+// Registers the (value type, witness) pair of every IRDifferentialPairType that
+// is a direct child of the module inst. A value type that already has an entry
+// keeps its existing witness.
+void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
+{
+    for (auto globalInst : sharedContext->moduleInst->getChildren())
+    {
+        if (auto pairType = as<IRDifferentialPairType>(globalInst))
+        {
+            // setFunc never calls Add for a key that is already present; apply the
+            // same guard here, since the value type may already be registered
+            // (by setFunc, or by another pair type over the same value type).
+            auto valueType = pairType->getValueType();
+            if (!differentiableWitnessDictionary.TryGetValue(valueType))
+            {
+                differentiableWitnessDictionary.Add(valueType, pairType->getWitness());
+            }
+        }
+    }
+}
+
+
 void stripAutoDiffDecorationsFromChildren(IRInst* parent)
 {
     for (auto inst : parent->getChildren())
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 3ca9f1d41..4aca291f9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -727,10 +727,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
         /// generated derivative function.
     INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
 
-        /// Used by the auto-diff pass to hold a reference to the
-        /// generated derivative function.
+        /// Decorated function is marked for the reverse-mode differentiation pass.
     INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
 
+        /// Used by the auto-diff pass to mark insts that compute
+        /// a differential value.
+    INST(DifferentialInstDecoration, diffInstDecoration, 0, 0)
+
         /// Used by the auto-diff pass to hold a reference to a
         /// differential member of a type in its associated differential type.
     INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 7cf8b3032..a1249aff9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -584,6 +584,15 @@ struct IRBackwardDerivativeDecoration : IRDecoration
     IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
 };
 
+// Decoration type for the kIROp_DifferentialInstDecoration opcode.
+struct IRDifferentialInstDecoration : IRDecoration
+{
+    enum
+    {
+        kOp = kIROp_DifferentialInstDecoration
+    };
+    IR_LEAF_ISA(DifferentialInstDecoration)
+};
+
 struct IRBackwardDifferentiableDecoration : IRDecoration
 {
     enum
@@ -3335,6 +3344,11 @@ public:
         addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn);
     }
 
+    // Attaches a kIROp_DifferentialInstDecoration to `value`.
+    void markInstAsDifferential(IRInst* value)
+    {
+        addDecoration(value, kIROp_DifferentialInstDecoration);
+    }
+
     void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
     {
         addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
 |
