diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-11-22 12:36:28 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-22 09:36:28 -0800 |
| commit | 6178cb601368e977c4aa82e0ae25b8eb1e875d84 (patch) | |
| tree | c7dc6df96c5bb4d0f4fd598ae40158e06a082fd1 | |
| parent | d9b014cba6803dbfcd37ed8ac3e7560a5191e3cf (diff) | |
Refactor Auto-diff passes (#2526)
* Initial refactor
* Refactor passes tests
* Removed Differential Bottom references from the IR side
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj | 10 | ||||
| -rw-r--r-- | build/visual-studio/slang/slang.vcxproj.filters | 30 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp (renamed from source/slang/slang-ir-diff-jvp.cpp) | 1474 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 43 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 182 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 832 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 408 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 210 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 174 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 2 |
15 files changed, 1777 insertions, 1645 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 5244f0efb..fe1922d29 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -333,6 +333,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-byte-address-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-check-differentiability.h" />
@@ -343,7 +347,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-constexpr.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dce.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h" />
- <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-export.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h" />
@@ -505,6 +508,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-bind-existentials.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-byte-address-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-check-differentiability.cpp" />
@@ -516,7 +523,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-dce.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp" />
- <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-export.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 0a2d1fe3f..6fa42287c 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -132,6 +132,18 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -162,9 +174,6 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h">
<Filter>Header Files</Filter>
</ClInclude>
- <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h">
- <Filter>Header Files</Filter>
- </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-export.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -644,6 +653,18 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-bind-existentials.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -677,9 +698,6 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp">
<Filter>Source Files</Filter>
</ClCompile>
- <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp">
- <Filter>Source Files</Filter>
- </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-export.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 69ea29c7a..ca55a68bc 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -11,7 +11,7 @@ #include "slang-ir-cleanup-void.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" @@ -377,10 +377,7 @@ Result linkAndOptimizeIR( dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); - // Process higher-order calles to auto-diff passes. - processDifferentiableFuncs(irModule, sink); - - stripAutoDiffDecorations(irModule); + processAutodiffCalls(irModule, sink); dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c9ca687e4..03e81c5b5 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1,5 +1,6 @@ -// slang-ir-diff-jvp.cpp -#include "slang-ir-diff-jvp.h" +// slang-ir-autodiff-fwd.cpp +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" @@ -7,314 +8,9 @@ #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" -// origX, primalX, diffX -// origX -> primalX (cloneEnv) -// origX -> diffX (instMapD) - namespace Slang { -namespace -{ - -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) -{ - if (auto witnessTable = as<IRWitnessTable>(witness)) - { - for (auto entry : witnessTable->getEntries()) - { - if (entry->getRequirementKey() == requirementKey) - return entry->getSatisfyingVal(); - } - } - else if (auto witnessTableParam = as<IRParam>(witness)) - { - return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), - witnessTableParam, - requirementKey); - } - return nullptr; -} - -} - -struct DifferentialPairTypeBuilder -{ - - IRStructField* findField(IRInst* type, IRStructKey* key) - { - if (auto irStructType = as<IRStructType>(type)) - { - for (auto field : irStructType->getFields()) - { - if (field->getKey() == key) - { - return field; - } - } - } - else if (auto irSpecialize = as<IRSpecialize>(type)) - { - if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase())) - { - if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric))) - { - return findField(irGenericStructType, key); - } - } - } - - return nullptr; - } - - IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) - { - // Get base generic that's being specialized. - auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase()); - SLANG_ASSERT(genericType); - - // Find the index of genericParam in the base generic. - int paramIndex = -1; - int currentIndex = 0; - for (auto param : genericType->getParams()) - { - if (param == genericParam) - paramIndex = currentIndex; - currentIndex ++; - } - - SLANG_ASSERT(paramIndex >= 0); - - // Return the corresponding operand in the specialization inst. - return specializeInst->getOperand(1 + paramIndex); - } - - IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) - { - IRInst* pairType = nullptr; - if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType())) - { - auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); - - // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, - // especially since we probably don't need diff bottom anymore. - // - SLANG_ASSERT(!baseTypeInfo.isTrivial); - - pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); - } - else - { - auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); - if (baseTypeInfo.isTrivial) - { - if (key == globalPrimalKey) - return baseInst; - else - return builder->getDifferentialBottom(); - } - - pairType = baseTypeInfo.loweredType; - } - - if (auto basePairStructType = as<IRStructType>(pairType)) - { - return as<IRFieldExtract>(builder->emitFieldExtract( - findField(basePairStructType, key)->getFieldType(), - baseInst, - key - )); - } - else if (auto ptrType = as<IRPtrTypeBase>(pairType)) - { - if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) - { - auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase())); - if (auto genericBasePairStructType = as<IRStructType>(genericType)) - { - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findSpecializationForParam( - ptrInnerSpecializedType, - findField(ptrInnerSpecializedType, key)->getFieldType())), - baseInst, - key - )); - } - } - else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType())) - { - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findField(ptrBaseStructType, key)->getFieldType()), - baseInst, - key)); - } - } - else if (auto specializedType = as<IRSpecialize>(pairType)) - { - // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's - // type, emit the specialization type. - // - auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase())); - if (auto genericBasePairStructType = as<IRStructType>(genericType)) - { - return as<IRFieldExtract>(builder->emitFieldExtract( - (IRType*)findSpecializationForParam( - specializedType, - findField(genericBasePairStructType, key)->getFieldType()), - baseInst, - key - )); - } - else if (auto genericPtrType = as<IRPtrTypeBase>(genericType)) - { - if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType())) - { - return as<IRFieldAddress>(builder->emitFieldAddress( - builder->getPtrType((IRType*) - findSpecializationForParam( - specializedType, - findField(genericPairStructType, key)->getFieldType())), - baseInst, - key - )); - } - } - } - else - { - SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); - } - return nullptr; - } - - IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) - { - return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); - } - - IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) - { - return emitFieldAccessor(builder, baseInst, this->globalDiffKey); - } - - IRStructKey* _getOrCreateDiffStructKey() - { - if (!this->globalDiffKey) - { - IRBuilder builder(sharedContext->sharedBuilder); - // Insert directly at top level (skip any generic scopes etc.) - builder.setInsertInto(sharedContext->moduleInst); - - this->globalDiffKey = builder.createStructKey(); - builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); - } - - return this->globalDiffKey; - } - - IRStructKey* _getOrCreatePrimalStructKey() - { - if (!this->globalPrimalKey) - { - // Insert directly at top level (skip any generic scopes etc.) - IRBuilder builder(sharedContext->sharedBuilder); - builder.setInsertInto(sharedContext->moduleInst); - - this->globalPrimalKey = builder.createStructKey(); - builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); - } - - return this->globalPrimalKey; - } - - IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType) - { - switch (origBaseType->getOp()) - { - case kIROp_lookup_interface_method: - case kIROp_Specialize: - case kIROp_Param: - return nullptr; - default: - break; - } - if (diffType->getOp() != kIROp_DifferentialBottomType) - { - IRBuilder builder(sharedContext->sharedBuilder); - builder.setInsertBefore(diffType); - - auto pairStructType = builder.createStructType(); - builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); - builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); - return pairStructType; - } - return origBaseType; - } - - struct LoweredPairTypeInfo - { - IRInst* loweredType; - bool isTrivial; - }; - - IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) - { - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) - { - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); - } - - LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) - { - LoweredPairTypeInfo result = {}; - - if (pairTypeCache.TryGetValue(originalPairType, result)) - return result; - auto pairType = as<IRDifferentialPairType>(originalPairType); - if (!pairType) - { - result.isTrivial = true; - result.loweredType = originalPairType; - return result; - } - auto primalType = pairType->getValueType(); - if (as<IRParam>(primalType)) - { - result.isTrivial = false; - result.loweredType = nullptr; - return result; - } - - auto diffType = getDiffTypeFromPairType(builder, pairType); - if (!diffType) - return result; - result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); - pairTypeCache.Add(originalPairType, result); - - return result; - } - - Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache; - - IRStructKey* globalPrimalKey = nullptr; - - IRStructKey* globalDiffKey = nullptr; - - IRInst* genericDiffPairType = nullptr; - - List<IRInst*> generatedTypeList; - - AutoDiffSharedContext* sharedContext = nullptr; -}; struct JVPTranscriber { @@ -468,17 +164,6 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) { @@ -524,10 +209,6 @@ struct JVPTranscriber { witness = getDifferentialPairWitness(primalPairType); } - else - { - witness = getDifferentialBottomWitness(); - } } return builder.getDifferentialPairType( @@ -1825,666 +1506,7 @@ struct JVPTranscriber } }; - -struct BackwardDiffTranscriber -{ - - // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent - // their differential values. - Dictionary<IRInst*, IRInst*> orginalToTranscribed; - - // Set of insts currently being transcribed. Used to avoid infinite loops. - HashSet<IRInst*> instsInProgress; - - // Cloning environment to hold mapping from old to new copies for the primal - // instructions. - IRCloneEnv cloneEnv; - - // Diagnostic sink for error messages. - DiagnosticSink* sink; - - // Type conformance information. - AutoDiffSharedContext* autoDiffSharedContext; - - // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct - DifferentialPairTypeBuilder* pairBuilder; - - DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - - List<InstPair> followUpFunctionsToTranscribe; - - // Map that stores the upper gradient given an IRInst* - Dictionary<IRInst*, List<IRInst*>> upperGradients; - Dictionary<IRInst*, IRInst*> primalToDiffPair; - - SharedIRBuilder* sharedBuilder; - // Witness table that `DifferentialBottom:IDifferential`. - IRWitnessTable* differentialBottomWitness = nullptr; - Dictionary<InstPair, IRInst*> differentialPairTypes; - - BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) - : autoDiffSharedContext(shared) - , sink(inSink) - , differentiableTypeConformanceContext(shared) - , sharedBuilder(inSharedBuilder) - {} - - DiagnosticSink* getSink() - { - SLANG_ASSERT(sink); - return sink; - } - - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) - { - List<IRType*> newParameterTypes; - IRType* diffReturnType; - - for (UIndex i = 0; i < funcType->getParamCount(); i++) - { - auto origType = funcType->getParamType(i); - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - newParameterTypes.add(inoutDiffPairType); - } - else - newParameterTypes.add(origType); - } - - newParameterTypes.add(funcType->getResultType()); - - diffReturnType = builder->getVoidType(); - - return builder->getFuncType(newParameterTypes, diffReturnType); - } - - IRWitnessTable* getDifferentialBottomWitness() - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - SLANG_ASSERT(result); - return result; - } - - // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. - IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(inDiffPairType->parent); - auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); - SLANG_ASSERT(diffPairType); - auto result = - as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( - builder.getDifferentialBottomType())); - if (result) - return result; - - auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); - auto diffType = differentiateType(&builder, diffPairType->getValueType()); - auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); - builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); - // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. - - differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; - return table; - } - - IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* getOrCreateDiffPairType(IRInst* primalType) - { - IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as<IRWitnessTable>( - differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); - if (!witness) - witness = getDifferentialBottomWitness(); - return builder.getDifferentialPairType( - (IRType*)primalType, - witness); - } - - IRType* differentiateType(IRBuilder* builder, IRType* origType) - { - IRInst* diffType = nullptr; - if (!orginalToTranscribed.TryGetValue(origType, diffType)) - { - diffType = _differentiateTypeImpl(builder, origType); - orginalToTranscribed[origType] = diffType; - } - return (IRType*)diffType; - } - - IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) - { - if (auto ptrType = as<IRPtrTypeBase>(origType)) - return builder->getPtrType( - origType->getOp(), - differentiateType(builder, ptrType->getValueType())); - - // If there is an explicit primal version of this type in the local scope, load that - // otherwise use the original type. - // - IRInst* primalType = origType; - - // Special case certain compound types (PtrType, FuncType, etc..) - // otherwise try to lookup a differential definition for the given type. - // If one does not exist, then we assume it's not differentiable. - // - switch (primalType->getOp()) - { - case kIROp_Param: - if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); - else if (as<IRWitnessTableType>(primalType->getDataType())) - return (IRType*)primalType; - - case kIROp_ArrayType: - { - auto primalArrayType = as<IRArrayType>(primalType); - if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) - return builder->getArrayType( - diffElementType, - primalArrayType->getElementCount()); - else - return nullptr; - } - - case kIROp_DifferentialPairType: - { - auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - pairBuilder->getDiffTypeFromPairType(builder, primalPairType), - pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); - } - - case kIROp_FuncType: - return differentiateFunctionType(builder, as<IRFuncType>(primalType)); - - case kIROp_OutType: - if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) - return builder->getOutType(diffValueType); - else - return nullptr; - - case kIROp_InOutType: - if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) - return builder->getInOutType(diffValueType); - else - return nullptr; - - case kIROp_TupleType: - { - auto tupleType = as<IRTupleType>(primalType); - List<IRType*> diffTypeList; - // TODO: what if we have type parameters here? - for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) - diffTypeList.add( - differentiateType(builder, (IRType*)tupleType->getOperand(ii))); - - return builder->getTupleType(diffTypeList); - } - - default: - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); - } - } - - IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) - { - // If this is a PtrType (out, inout, etc..), then create diff pair from - // value type and re-apply the appropropriate PtrType wrapper. - // - if (auto origPtrType = as<IRPtrTypeBase>(primalType)) - { - if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) - return builder->getPtrType(primalType->getOp(), diffPairValueType); - else - return nullptr; - } - auto diffType = differentiateType(builder, primalType); - if (diffType) - return (IRType*)getOrCreateDiffPairType(primalType); - return nullptr; - } - - InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - // Do not differentiate generic type (and witness table) parameters - if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) - { - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); - } - - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); - } - - // Returns "dp<var-name>" to use as a name hint for parameters. - // If no primal name is available, returns a blank string. - // - String makeDiffPairName(IRInst* origVar) - { - if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); - } - - - // In differential computation, the 'default' differential value is always zero. - // This is a consequence of differential computing being inherently linear. As a - // result, it's useful to have a method to generate zero literals of any (arithmetic) type. - // The current implementation requires that types are defined linearly. - // - IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) - { - if (auto diffType = differentiateType(builder, primalType)) - { - switch (diffType->getOp()) - { - case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); - } - // Since primalType has a corresponding differential type, we can lookup the - // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); - - auto emptyArgList = List<IRInst*>(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); - } - else - { - if (isScalarIntegerType(primalType)) - { - return builder->getIntValue(primalType, 0); - } - - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } - } - - InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) - { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); - - IRBlock* diffBlock = subBuilder.emitBlock(); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->copyParam(&subBuilder, param); - - // The extra param for input gradient - auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType()); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->copyInst(&subBuilder, child); - - auto lastInst = diffBlock->getLastOrdinaryInst(); - List<IRInst*> grads = { gradParam }; - upperGradients.Add(lastInst, grads); - for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) - { - auto upperGrads = upperGradients.TryGetValue(child); - if (!upperGrads) - continue; - if (upperGrads->getCount() > 1) - { - auto sumGrad = upperGrads->getFirst(); - for (auto i = 1; i < upperGrads->getCount(); i++) - { - sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); - } - this->transcribeInstBackward(&subBuilder, child, sumGrad); - } - else - this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); - } - - subBuilder.emitReturn(); - - return InstPair(diffBlock, diffBlock); - } - - // Create an empty func to represent the transcribed func of `origFunc`. - InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origFunc); - - IRFunc* primalFunc = origFunc; - - differentiableTypeConformanceContext.setFunc(origFunc); - - primalFunc = origFunc; - - auto diffFunc = builder.createFunc(); - - SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); - IRType* diffFuncType = this->differentiateFunctionType( - &builder, - as<IRFuncType>(origFunc->getFullType())); - diffFunc->setFullType(diffFuncType); - - if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) - { - auto originalName = nameHint->getName(); - StringBuilder newNameSb; - newNameSb << "s_bwd_" << originalName; - builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); - } - builder.addBackwardDerivativeDecoration(origFunc, diffFunc); - - // Mark the generated derivative function itself as differentiable. - builder.addBackwardDifferentiableDecoration(diffFunc); - - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - cloneDecoration(dictDecor, diffFunc); - } - - auto result = InstPair(primalFunc, diffFunc); - followUpFunctionsToTranscribe.add(result); - return result; - } - - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) - { - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertInto(diffFunc); - - differentiableTypeConformanceContext.setFunc(primalFunc); - // Transcribe children from origFunc into diffFunc - for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribeBlock(&builder, block); - - return InstPair(primalFunc, diffFunc); - } - - IRInst* copyParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - IRInst* diffParam = builder->emitParam(inoutDiffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffParam); - auto paramValue = builder->emitLoad(diffParam); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - orginalToTranscribed.Add(origParam, primal); - primalToDiffPair.Add(primal, diffParam); - - return diffParam; - } - - - return cloneInst(&cloneEnv, builder, origParam); - } - - InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); - - IRInst* primalLeft; - if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) - { - primalLeft = origLeft; - } - IRInst* primalRight; - if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) - { - primalRight = origRight; - } - - auto resultType = origArith->getDataType(); - IRInst* newInst; - switch (origArith->getOp()) - { - case kIROp_Add: - newInst = builder->emitAdd(resultType, primalLeft, primalRight); - break; - case kIROp_Mul: - newInst = builder->emitMul(resultType, primalLeft, primalRight); - break; - case kIROp_Sub: - newInst = builder->emitSub(resultType, primalLeft, primalRight); - break; - case kIROp_Div: - newInst = builder->emitDiv(resultType, primalLeft, primalRight); - break; - default: - newInst = nullptr; - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - orginalToTranscribed.Add(origArith, newInst); - return InstPair(newInst, nullptr); - } - - IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto lhs = origArith->getOperand(0); - auto rhs = origArith->getOperand(1); - - if (as<IRInOutType>(lhs->getDataType())) - { - lhs = builder->emitLoad(lhs); - lhs = builder->emitDifferentialPairGetPrimal(lhs); - } - if (as<IRInOutType>(rhs->getDataType())) - { - rhs = builder->emitLoad(rhs); - rhs = builder->emitDifferentialPairGetPrimal(rhs); - } - - IRInst* leftGrad; - IRInst* rightGrad; - - - switch (origArith->getOp()) - { - case kIROp_Add: - leftGrad = grad; - rightGrad = grad; - break; - case kIROp_Mul: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); - break; - case kIROp_Sub: - leftGrad = grad; - rightGrad = builder->emitNeg(grad->getDataType(), grad); - break; - case kIROp_Div: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad - break; - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - - lhs = origArith->getOperand(0); - rhs = origArith->getOperand(1); - if (auto leftGrads = upperGradients.TryGetValue(lhs)) - { - leftGrads->add(leftGrad); - } - else - { - upperGradients.Add(lhs, leftGrad); - } - if (auto rightGrads = upperGradients.TryGetValue(rhs)) - { - rightGrads->add(rightGrad); - } - else - { - upperGradients.Add(rhs, rightGrad); - } - - return nullptr; - } - - InstPair copyInst(IRBuilder* builder, IRInst* origInst) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as<IRParam>(origInst)); - - case kIROp_Return: - return InstPair(nullptr, nullptr); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return copyBinaryArith(builder, origInst); - - default: - // Not yet implemented - SLANG_ASSERT(0); - } - - return InstPair(nullptr, nullptr); - } - - IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) - { - IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); - auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); - auto paramValue = builder->emitLoad(param); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - auto diff = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - paramValue - ); - auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); - auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); - auto store = builder->emitStore(param, updatedParam); - - return store; - } - - IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParamBackward(builder, as<IRParam>(origInst), grad); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transcribeBinaryArithBackward(builder, origInst, grad); - - case kIROp_DifferentialPairGetPrimal: - { - if (auto param = primalToDiffPair.TryGetValue(origInst)) - { - if (auto leftGrads = upperGradients.TryGetValue(*param)) - { - leftGrads->add(grad); - } - else - { - upperGradients.Add(*param, grad); - } - } - else - SLANG_ASSERT(0); - return nullptr; - } - - default: - // Not yet implemented - SLANG_ASSERT(0); - } - - return nullptr; - } -}; - - -struct JVPDerivativeContext : public InstPassBase +struct ForwardDerivativePass : public InstPassBase { DiagnosticSink* getSink() @@ -2494,18 +1516,10 @@ struct JVPDerivativeContext : public InstPassBase bool processModule() { - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->init(module); - sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - // TODO(sai): Move this call. transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary(); - IRBuilder builderStorage(sharedBuilderStorage); + IRBuilder builderStorage(this->autodiffContext->sharedBuilder); IRBuilder* builder = &builderStorage; // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by @@ -2513,20 +1527,6 @@ struct JVPDerivativeContext : public InstPassBase // bool modified = processReferencedFunctions(builder); - // Replaces IRDifferentialPairType with an auto-generated struct, - // IRDifferentialPairGetDifferential with 'differential' field access, - // IRDifferentialPairGetPrimal with 'primal' field access, and - // IRMakeDifferentialPair with an IRMakeStruct. - // - modified |= simplifyDifferentialBottomType(builder); - - // De-duplicate any remaining types. - sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - - modified |= processPairTypes(builder, module->getModuleInst()); - - modified |= eliminateDifferentialBottomType(builder); - return modified; } @@ -2553,7 +1553,6 @@ struct JVPDerivativeContext : public InstPassBase switch (inst->getOp()) { case kIROp_ForwardDifferentiate: - case kIROp_BackwardDifferentiate: // Only process now if the operand is a materialized function. switch (inst->getOperand(0)->getOp()) { @@ -2577,7 +1576,6 @@ struct JVPDerivativeContext : public InstPassBase // differentiated functions. transcriberStorage.followUpFunctionsToTranscribe.clear(); - backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); for (auto differentiateInst : autoDiffWorkList) { @@ -2613,28 +1611,6 @@ struct JVPDerivativeContext : public InstPassBase } } - else if (as<IRBackwardDifferentiate>(differentiateInst)) - { - if (isMarkedForBackwardDifferentiation(baseInst)) - { - if (as<IRFunc>(baseInst)) - { - IRInst* diffFunc = - backwardTranscriberStorage - .transcribeFuncHeader(builder, (IRFunc*)baseInst) - .differential; - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } - } - } } // Actually synthesize the derivatives. List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); @@ -2647,16 +1623,6 @@ struct JVPDerivativeContext : public InstPassBase transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); } - followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); - for (auto task : followUpWorkList) - { - auto diffFunc = as<IRFunc>(task.differential); - SLANG_ASSERT(diffFunc); - auto primalFunc = as<IRFunc>(task.primal); - SLANG_ASSERT(primalFunc); - - backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); - } // Transcribing the function body really shouldn't produce more follow up function body work. // However it may produce new `ForwardDifferentiate` instructions, which we collect and process @@ -2667,281 +1633,6 @@ struct JVPDerivativeContext : public InstPassBase return true; } - IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) - { - builder->setInsertBefore(pairType); - auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType( - builder, - pairType); - if (isTrivial) - *isTrivial = loweredPairTypeInfo.isTrivial; - return loweredPairTypeInfo.loweredType; - } - - IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) - { - - if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) - { - bool isTrivial = false; - auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) - { - builder->setInsertBefore(makePairInst); - IRInst* result = nullptr; - if (isTrivial) - { - result = makePairInst->getPrimalValue(); - } - else - { - IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; - result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); - } - makePairInst->replaceUsesWith(result); - makePairInst->removeAndDeallocate(); - return result; - } - } - - return nullptr; - } - - IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) - { - if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) - { - auto pairType = getDiffInst->getBase()->getDataType(); - if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) - { - pairType = pairPtrType->getValueType(); - } - - if (lowerPairType(builder, pairType, nullptr)) - { - builder->setInsertBefore(getDiffInst); - IRInst* diffFieldExtract = nullptr; - diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); - getDiffInst->replaceUsesWith(diffFieldExtract); - getDiffInst->removeAndDeallocate(); - return diffFieldExtract; - } - } - else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) - { - auto pairType = getPrimalInst->getBase()->getDataType(); - if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) - { - pairType = pairPtrType->getValueType(); - } - - if (lowerPairType(builder, pairType, nullptr)) - { - builder->setInsertBefore(getPrimalInst); - - IRInst* primalFieldExtract = nullptr; - primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); - getPrimalInst->replaceUsesWith(primalFieldExtract); - getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; - } - } - - return nullptr; - } - - bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) - { - bool modified = false; - // Hoist all pair types to global scope when possible. - auto moduleInst = module->getModuleInst(); - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) - { - if (originalPairType->parent != moduleInst) - { - originalPairType->removeFromParent(); - ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPairType->getOperandCount(); i++) - { - operands.add(originalPairType->getOperand(i)); - } - auto newPairType = builder->findOrEmitHoistableInst( - originalPairType->getFullType(), - originalPairType->getOp(), - originalPairType->getOperandCount(), - operands.getArrayView().getBuffer()); - originalPairType->replaceUsesWith(newPairType); - originalPairType->removeAndDeallocate(); - } - }); - - sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - - processAllInsts([&](IRInst* inst) - { - // Make sure the builder is at the right level. - builder->setInsertInto(instWithChildren); - - switch (inst->getOp()) - { - case kIROp_DifferentialPairGetDifferential: - case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, inst); - modified = true; - break; - - case kIROp_MakeDifferentialPair: - lowerMakePair(builder, inst); - modified = true; - break; - - default: - break; - } - }); - - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) - { - if (auto loweredType = lowerPairType(builder, inst)) - { - inst->replaceUsesWith(loweredType); - inst->removeAndDeallocate(); - } - }); - return modified; - } - - bool simplifyDifferentialBottomType(IRBuilder* builder) - { - bool modified = false; - auto diffBottom = builder->getDifferentialBottom(); - - bool changed = true; - List<IRUse*> uses; - while (changed) - { - changed = false; - // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`. - processAllInsts([&](IRInst* inst) - { - if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType) - { - if (inst != diffBottom) - { - inst->replaceUsesWith(diffBottom); - inst->removeAndDeallocate(); - modified = true; - } - } - }); - // Go through all uses of diffBottom and run simplification. - processAllInsts([&](IRInst* inst) - { - if (!inst->hasUses()) - return; - - builder->setInsertBefore(inst); - IRInst* valueToReplace = nullptr; - switch (inst->getOp()) - { - case kIROp_Store: - if (as<IRStore>(inst)->getVal() == diffBottom) - { - inst->removeAndDeallocate(); - changed = true; - } - return; - case kIROp_MakeDifferentialPair: - // Our simplification could lead to a situation where - // bottom is used to make a pair that has a non-bottom differential type, - // in this case we should use zero instead. - if (inst->getOperand(1) == diffBottom) - { - // Only apply if we are the second operand. - auto pairType = as<IRDifferentialPairType>(inst->getDataType()); - if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType) - { - auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType()); - inst->setOperand(1, zero); - changed = true; - } - } - return; - case kIROp_DifferentialPairGetDifferential: - if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) - { - valueToReplace = inst->getOperand(0)->getOperand(1); - } - break; - case kIROp_DifferentialPairGetPrimal: - if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) - { - valueToReplace = inst->getOperand(0)->getOperand(0); - } - break; - case kIROp_Add: - if (inst->getOperand(0) == diffBottom) - { - valueToReplace = inst->getOperand(1); - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = inst->getOperand(0); - } - break; - case kIROp_Sub: - if (inst->getOperand(0) == diffBottom) - { - // If left is bottom, and right is not bottom, then we should return -right. - // However we can't possibly run into that case since both side of - operator - // must be at the same order of differentiation. - valueToReplace = diffBottom; - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = inst->getOperand(0); - } - break; - case kIROp_Mul: - case kIROp_Div: - if (inst->getOperand(0) == diffBottom) - { - valueToReplace = diffBottom; - } - else if (inst->getOperand(1) == diffBottom) - { - valueToReplace = diffBottom; - } - break; - default: - break; - } - if (valueToReplace) - { - inst->replaceUsesWith(valueToReplace); - changed = true; - } - }); - modified |= changed; - } - - return modified; - } - - bool eliminateDifferentialBottomType(IRBuilder* builder) - { - simplifyDifferentialBottomType(builder); - - bool modified = false; - auto diffBottom = builder->getDifferentialBottom(); - auto diffBottomType = diffBottom->getDataType(); - diffBottom->replaceUsesWith(builder->getVoidValue()); - diffBottom->removeAndDeallocate(); - diffBottomType->replaceUsesWith(builder->getVoidType()); - - return modified; - } - // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // @@ -2968,45 +1659,16 @@ struct JVPDerivativeContext : public InstPassBase return name; } - // Checks decorators to see if the function should - // be differentiated (kIROp_ForwardDifferentiableDecoration) - // - bool isMarkedForBackwardDifferentiation(IRInst* callable) - { - return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; - } - - IRStringLit* getBackwardDerivativeFuncName(IRInst* func) - { - IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(func); - - IRStringLit* name = nullptr; - if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) - { - name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); - } - else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) - { - name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); - } - - return name; - } - - JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - InstPassBase(module), + ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : + InstPassBase(context->moduleInst->getModule()), sink(sink), - autoDiffSharedContextStorage(module->getModuleInst()), - transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage), - backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink) + transcriberStorage(context, context->sharedBuilder), + pairBuilderStorage(context), + autodiffContext(context) { - autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage; - pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; transcriberStorage.sink = sink; - transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); + transcriberStorage.autoDiffSharedContext = context; transcriberStorage.pairBuilder = &(pairBuilderStorage); - backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; } protected: @@ -3015,15 +1677,12 @@ protected: // JVPTranscriber transcriberStorage; - BackwardDiffTranscriber backwardTranscriberStorage; - // Diagnostic object from the compile request for // error messages. - DiagnosticSink* sink; + DiagnosticSink* sink; - // Context to find and manage the witness tables for types - // implementing `IDifferentiable` - AutoDiffSharedContext autoDiffSharedContextStorage; + // Shared context. + AutoDiffSharedContext* autodiffContext; // Builder for dealing with differential pair types. DifferentialPairTypeBuilder pairBuilderStorage; @@ -3032,106 +1691,16 @@ protected: // Set up context and call main process method. // -bool processDifferentiableFuncs( - IRModule* module, +bool processForwardDerivativeCalls( + AutoDiffSharedContext* autodiffContext, DiagnosticSink* sink, - IRJVPDerivativePassOptions const&) + ForwardDerivativePassOptions const&) { - // Simplify module to remove dead code. - IRDeadCodeEliminationOptions options; - options.keepExportsAlive = true; - options.keepLayoutsAlive = true; - eliminateDeadCode(module, options); - - JVPDerivativeContext context(module, sink); - bool changed = context.processModule(); + ForwardDerivativePass fwdPass(autodiffContext, sink); + bool changed = fwdPass.processModule(); return changed; } -void stripAutoDiffDecorationsFromChildren(IRInst* parent) -{ - for (auto inst : parent->getChildren()) - { - for (auto decor = inst->getFirstDecoration(); decor; ) - { - auto next = decor->getNextDecoration(); - switch (decor->getOp()) - { - case kIROp_ForwardDerivativeDecoration: - case kIROp_DerivativeMemberDecoration: - case kIROp_DifferentiableTypeDictionaryDecoration: - decor->removeAndDeallocate(); - break; - default: - break; - } - decor = next; - } - - if (inst->getFirstChild() != nullptr) - { - stripAutoDiffDecorationsFromChildren(inst); - } - } -} - -void stripAutoDiffDecorations(IRModule* module) -{ - stripAutoDiffDecorationsFromChildren(module->getModuleInst()); -} - -AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) - : moduleInst(inModuleInst) -{ - differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); - if (differentiableInterfaceType) - { - differentialAssocTypeStructKey = findDifferentialTypeStructKey(); - differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); - zeroMethodStructKey = findZeroMethodStructKey(); - addMethodStructKey = findAddMethodStructKey(); - mulMethodStructKey = findMulMethodStructKey(); - - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; - } -} - -IRInst* AutoDiffSharedContext::findDifferentiableInterface() -{ - if (auto module = as<IRModuleInst>(moduleInst)) - { - for (auto globalInst : module->getGlobalInsts()) - { - // TODO: This seems like a particularly dangerous way to look for an interface. - // See if we can lower IDifferentiable to a separate IR inst. - // - if (globalInst->getOp() == kIROp_InterfaceType && - as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") - { - return globalInst; - } - } - } - return nullptr; -} - -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) -{ - if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) - { - // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) - return as<IRStructKey>(entry->getRequirementKey()); - else - { - SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); - } - } - - return nullptr; -} void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { @@ -3148,11 +1717,6 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType()); if (existingItem) { - if (auto witness = as<IRWitnessTable>(item->getWitness())) - { - if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType) - continue; - } *existingItem = item->getWitness(); } else diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h new file mode 100644 index 000000000..6b261ecd0 --- /dev/null +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -0,0 +1,43 @@ +// slang-ir-autodiff-fwd.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +namespace Slang +{ + + template<typename P, typename D> + struct DiffInstPair + { + P primal; + D differential; + DiffInstPair() = default; + DiffInstPair(P primal, D differential) : primal(primal), differential(differential) + {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const DiffInstPair& other) const + { + return primal == other.primal && differential == other.differential; + } + }; + + typedef DiffInstPair<IRInst*, IRInst*> InstPair; + + struct ForwardDerivativePassOptions + { + // Nothing for now.. + }; + + bool processForwardDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions()); + +} diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp new file mode 100644 index 000000000..1dbb1bd7c --- /dev/null +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -0,0 +1,182 @@ +#include "slang-ir-autodiff-pairs.h" + +namespace Slang +{ + +struct DiffPairLoweringPass : InstPassBase +{ + DiffPairLoweringPass(AutoDiffSharedContext* context) : + InstPassBase(context->moduleInst->getModule()), + pairBuilderStorage(context), + autodiffContext(context) + { + pairBuilder = &pairBuilderStorage; + } + + IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) + { + builder->setInsertBefore(pairType); + auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType( + builder, + pairType); + if (isTrivial) + *isTrivial = loweredPairTypeInfo.isTrivial; + return loweredPairTypeInfo.loweredType; + } + + IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) + { + + if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) + { + bool isTrivial = false; + auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); + if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) + { + builder->setInsertBefore(makePairInst); + IRInst* result = nullptr; + if (isTrivial) + { + result = makePairInst->getPrimalValue(); + } + else + { + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; + result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + } + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; + } + } + + return nullptr; + } + + IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) + { + if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) + { + auto pairType = getDiffInst->getBase()->getDataType(); + if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) + { + builder->setInsertBefore(getDiffInst); + IRInst* diffFieldExtract = nullptr; + diffFieldExtract = pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase()); + getDiffInst->replaceUsesWith(diffFieldExtract); + getDiffInst->removeAndDeallocate(); + return diffFieldExtract; + } + } + else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) + { + auto pairType = getPrimalInst->getBase()->getDataType(); + if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) + { + pairType = pairPtrType->getValueType(); + } + + if (lowerPairType(builder, pairType, nullptr)) + { + builder->setInsertBefore(getPrimalInst); + + IRInst* primalFieldExtract = nullptr; + primalFieldExtract = pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + getPrimalInst->replaceUsesWith(primalFieldExtract); + getPrimalInst->removeAndDeallocate(); + return primalFieldExtract; + } + } + + return nullptr; + } + + bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren) + { + bool modified = false; + // Hoist all pair types to global scope when possible. + auto moduleInst = module->getModuleInst(); + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) + { + originalPairType->removeFromParent(); + ShortList<IRInst*> operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + { + operands.add(originalPairType->getOperand(i)); + } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); + + autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + + processAllInsts([&](IRInst* inst) + { + // Make sure the builder is at the right level. + builder->setInsertInto(instWithChildren); + + switch (inst->getOp()) + { + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + lowerPairAccess(builder, inst); + modified = true; + break; + + case kIROp_MakeDifferentialPair: + lowerMakePair(builder, inst); + modified = true; + break; + + default: + break; + } + }); + + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = lowerPairType(builder, inst)) + { + inst->replaceUsesWith(loweredType); + inst->removeAndDeallocate(); + } + }); + return modified; + } + + bool processModule() + { + IRBuilder builder(autodiffContext->sharedBuilder); + return processInstWithChildren(&builder, module->getModuleInst()); + } + + private: + + AutoDiffSharedContext* autodiffContext; + + DifferentialPairTypeBuilder* pairBuilder; + + DifferentialPairTypeBuilder pairBuilderStorage; + +}; + +bool processPairTypes(AutoDiffSharedContext* context) +{ + DiffPairLoweringPass pairLoweringPass(context); + return pairLoweringPass.processModule(); +} + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h new file mode 100644 index 000000000..44321ae9b --- /dev/null +++ b/source/slang/slang-ir-autodiff-pairs.h @@ -0,0 +1,21 @@ +// slang-ir-autodiff-pairs.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +#include "slang-ir-autodiff.h" + +namespace Slang +{ + +bool processPairTypes(AutoDiffSharedContext* context); + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp new file mode 100644 index 000000000..52567e887 --- /dev/null +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -0,0 +1,832 @@ +#include "slang-ir-autodiff-rev.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + + +namespace Slang +{ +struct BackwardDiffTranscriber +{ + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. + Dictionary<IRInst*, IRInst*> orginalToTranscribed; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. + AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + List<InstPair> followUpFunctionsToTranscribe; + + // Map that stores the upper gradient given an IRInst* + Dictionary<IRInst*, List<IRInst*>> upperGradients; + Dictionary<IRInst*, IRInst*> primalToDiffPair; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary<InstPair, IRInst*> differentialPairTypes; + + BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : autoDiffSharedContext(shared) + , sink(inSink) + , differentiableTypeConformanceContext(shared) + , sharedBuilder(inSharedBuilder) + {} + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List<IRType*> newParameterTypes; + IRType* diffReturnType; + + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + auto origType = funcType->getParamType(i); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + newParameterTypes.add(inoutDiffPairType); + } + else + newParameterTypes.add(origType); + } + + newParameterTypes.add(funcType->getResultType()); + + diffReturnType = builder->getVoidType(); + + return builder->getFuncType(newParameterTypes, diffReturnType); + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto diffType = differentiateType(&builder, diffPairType->getValueType()); + auto differentialType = builder.getDifferentialPairType(diffType, nullptr); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* differentiateType(IRBuilder* builder, IRType* origType) + { + IRInst* diffType = nullptr; + if (!orginalToTranscribed.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + orginalToTranscribed[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = origType; + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. + // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? + for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + } + } + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) + { + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. + // + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; + } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; + } + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + + + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + // Returns "dp<var-name>" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); + } + + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). + auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRBlock* diffBlock = subBuilder.emitBlock(); + + subBuilder.setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->copyParam(&subBuilder, param); + + // The extra param for input gradient + auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType()); + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. + // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->copyInst(&subBuilder, child); + + auto lastInst = diffBlock->getLastOrdinaryInst(); + List<IRInst*> grads = { gradParam }; + upperGradients.Add(lastInst, grads); + for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) + { + auto upperGrads = upperGradients.TryGetValue(child); + if (!upperGrads) + continue; + if (upperGrads->getCount() > 1) + { + auto sumGrad = upperGrads->getFirst(); + for (auto i = 1; i < upperGrads->getCount(); i++) + { + sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); + } + this->transcribeInstBackward(&subBuilder, child, sumGrad); + } + else + this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); + } + + subBuilder.emitReturn(); + + return InstPair(diffBlock, diffBlock); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + primalFunc = origFunc; + + auto diffFunc = builder.createFunc(); + + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + &builder, + as<IRFuncType>(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_bwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } + builder.addBackwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. + builder.addBackwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + cloneDecoration(dictDecor, diffFunc); + } + + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); + + differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribeBlock(&builder, block); + + return InstPair(primalFunc, diffFunc); + } + + IRInst* copyParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + IRInst* diffParam = builder->emitParam(inoutDiffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffParam); + auto paramValue = builder->emitLoad(diffParam); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + orginalToTranscribed.Add(origParam, primal); + primalToDiffPair.Add(primal, diffParam); + + return diffParam; + } + + + return cloneInst(&cloneEnv, builder, origParam); + } + + InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); + + IRInst* primalLeft; + if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) + { + primalLeft = origLeft; + } + IRInst* primalRight; + if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) + { + primalRight = origRight; + } + + auto resultType = origArith->getDataType(); + IRInst* newInst; + switch (origArith->getOp()) + { + case kIROp_Add: + newInst = builder->emitAdd(resultType, primalLeft, primalRight); + break; + case kIROp_Mul: + newInst = builder->emitMul(resultType, primalLeft, primalRight); + break; + case kIROp_Sub: + newInst = builder->emitSub(resultType, primalLeft, primalRight); + break; + case kIROp_Div: + newInst = builder->emitDiv(resultType, primalLeft, primalRight); + break; + default: + newInst = nullptr; + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + orginalToTranscribed.Add(origArith, newInst); + return InstPair(newInst, nullptr); + } + + IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto lhs = origArith->getOperand(0); + auto rhs = origArith->getOperand(1); + + if (as<IRInOutType>(lhs->getDataType())) + { + lhs = builder->emitLoad(lhs); + lhs = builder->emitDifferentialPairGetPrimal(lhs); + } + if (as<IRInOutType>(rhs->getDataType())) + { + rhs = builder->emitLoad(rhs); + rhs = builder->emitDifferentialPairGetPrimal(rhs); + } + + IRInst* leftGrad; + IRInst* rightGrad; + + + switch (origArith->getOp()) + { + case kIROp_Add: + leftGrad = grad; + rightGrad = grad; + break; + case kIROp_Mul: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); + break; + case kIROp_Sub: + leftGrad = grad; + rightGrad = builder->emitNeg(grad->getDataType(), grad); + break; + case kIROp_Div: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad + break; + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + + lhs = origArith->getOperand(0); + rhs = origArith->getOperand(1); + if (auto leftGrads = upperGradients.TryGetValue(lhs)) + { + leftGrads->add(leftGrad); + } + else + { + upperGradients.Add(lhs, leftGrad); + } + if (auto rightGrads = upperGradients.TryGetValue(rhs)) + { + rightGrads->add(rightGrad); + } + else + { + upperGradients.Add(rhs, rightGrad); + } + + return nullptr; + } + + InstPair copyInst(IRBuilder* builder, IRInst* origInst) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParam(builder, as<IRParam>(origInst)); + + case kIROp_Return: + return InstPair(nullptr, nullptr); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return copyBinaryArith(builder, origInst); + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return InstPair(nullptr, nullptr); + } + + IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) + { + IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); + auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); + auto paramValue = builder->emitLoad(param); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + auto diff = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + paramValue + ); + auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); + auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); + auto store = builder->emitStore(param, updatedParam); + + return store; + } + + IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParamBackward(builder, as<IRParam>(origInst), grad); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return transcribeBinaryArithBackward(builder, origInst, grad); + + case kIROp_DifferentialPairGetPrimal: + { + if (auto param = primalToDiffPair.TryGetValue(origInst)) + { + if (auto leftGrads = upperGradients.TryGetValue(*param)) + { + leftGrads->add(grad); + } + else + { + upperGradients.Add(*param, grad); + } + } + else + SLANG_ASSERT(0); + return nullptr; + } + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return nullptr; + } + + +}; + +struct ReverseDerivativePass : public InstPassBase +{ + + DiagnosticSink* getSink() + { + return sink; + } + + bool processModule() + { + + IRBuilder builderStorage(autodiffContext->sharedBuilder); + IRBuilder* builder = &builderStorage; + + // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by + // generating derivative code for the referenced function. + // + bool modified = processReferencedFunctions(builder); + + return modified; + } + + IRInst* lookupJVPReference(IRInst* primalFunction) + { + if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) + return jvpDefinition->getForwardDerivativeFunc(); + return nullptr; + } + + // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate), + // then check that the referenced function is marked correctly for differentiation. + // + bool processReferencedFunctions(IRBuilder* builder) + { + List<IRInst*> autoDiffWorkList; + + for (;;) + { + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_BackwardDifferentiate: + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + autoDiffWorkList.add(inst); + break; + default: + break; + } + break; + default: + break; + } + }); + + if (autoDiffWorkList.getCount() == 0) + break; + + // Process collected `ForwardDifferentiate` insts and replace them with placeholders for + // differentiated functions. + + backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); + + for (auto differentiateInst : autoDiffWorkList) + { + IRInst* baseInst = differentiateInst->getOperand(0); + if (as<IRBackwardDifferentiate>(differentiateInst)) + { + if (isMarkedForBackwardDifferentiation(baseInst)) + { + if (as<IRFunc>(baseInst)) + { + IRInst* diffFunc = + backwardTranscriberStorage + .transcribeFuncHeader(builder, (IRFunc*)baseInst) + .differential; + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } + } + } + } + + auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) + { + auto diffFunc = as<IRFunc>(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as<IRFunc>(task.primal); + SLANG_ASSERT(primalFunc); + + backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); + } + + // Transcribing the function body really shouldn't produce more follow up function body work. + // However it may produce new `ForwardDifferentiate` instructions, which we collect and process + // in the next iteration. + SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0); + + } + return true; + } + + // Checks decorators to see if the function should + // be differentiated (kIROp_ForwardDifferentiableDecoration) + // + bool isMarkedForBackwardDifferentiation(IRInst* callable) + { + return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; + } + + IRStringLit* getBackwardDerivativeFuncName(IRInst* func) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) + { + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + { + name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); + } + + return name; + } + + ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) : + InstPassBase(context->moduleInst->getModule()), + sink(sink), + backwardTranscriberStorage(context, context->sharedBuilder, sink), + autodiffContext(context), + pairBuilderStorage(context) + { + backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; + } + +protected: + // A transcriber object that handles the main job of + // processing instructions while maintaining state. + // + BackwardDiffTranscriber backwardTranscriberStorage; + + // Diagnostic object from the compile request for + // error messages. + DiagnosticSink* sink; + + // Builder for dealing with differential pair types. + DifferentialPairTypeBuilder pairBuilderStorage; + + // Autodiff Shared Context + AutoDiffSharedContext* autodiffContext; +}; + +bool processReverseDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + IRReverseDerivativePassOptions const&) +{ + ReverseDerivativePass revPass(autodiffContext, sink); + bool changed = revPass.processModule(); + return changed; +} + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h new file mode 100644 index 000000000..c3d31e2a9 --- /dev/null +++ b/source/slang/slang-ir-autodiff-rev.h @@ -0,0 +1,25 @@ +// slang-ir-autodiff-rev.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-fwd.h" + +namespace Slang +{ + +struct IRReverseDerivativePassOptions +{ + // Nothing for now.. +}; + +bool processReverseDerivativeCalls( + AutoDiffSharedContext* autodiffContext, + DiagnosticSink* sink, + IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions()); + + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp new file mode 100644 index 000000000..313760d85 --- /dev/null +++ b/source/slang/slang-ir-autodiff.cpp @@ -0,0 +1,408 @@ +#include "slang-ir-autodiff.h" +#include "slang-ir-autodiff-rev.h" +#include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-pairs.h" + +namespace Slang +{ + +// TODO: Put into a nameless namespace. +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +{ + if (auto witnessTable = as<IRWitnessTable>(witness)) + { + for (auto entry : witnessTable->getEntries()) + { + if (entry->getRequirementKey() == requirementKey) + return entry->getSatisfyingVal(); + } + } + else if (auto witnessTableParam = as<IRParam>(witness)) + { + return builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + witnessTableParam, + requirementKey); + } + return nullptr; +} + +IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key) +{ + if (auto irStructType = as<IRStructType>(type)) + { + for (auto field : irStructType->getFields()) + { + if (field->getKey() == key) + { + return field; + } + } + } + else if (auto irSpecialize = as<IRSpecialize>(type)) + { + if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase())) + { + if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric))) + { + return findField(irGenericStructType, key); + } + } + } + + return nullptr; +} + +IRInst* DifferentialPairTypeBuilder::findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam) +{ + // Get base generic that's being specialized. + auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase()); + SLANG_ASSERT(genericType); + + // Find the index of genericParam in the base generic. + int paramIndex = -1; + int currentIndex = 0; + for (auto param : genericType->getParams()) + { + if (param == genericParam) + paramIndex = currentIndex; + currentIndex ++; + } + + SLANG_ASSERT(paramIndex >= 0); + + // Return the corresponding operand in the specialization inst. + return specializeInst->getOperand(1 + paramIndex); +} + +IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) +{ + IRInst* pairType = nullptr; + if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType())) + { + auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); + + // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, + // especially since we probably don't need diff bottom anymore. + // + SLANG_ASSERT(!baseTypeInfo.isTrivial); + + pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); + } + else + { + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + pairType = baseTypeInfo.loweredType; + } + + if (auto basePairStructType = as<IRStructType>(pairType)) + { + return as<IRFieldExtract>(builder->emitFieldExtract( + findField(basePairStructType, key)->getFieldType(), + baseInst, + key + )); + } + else if (auto ptrType = as<IRPtrTypeBase>(pairType)) + { + if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) + { + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + ptrInnerSpecializedType, + findField(ptrInnerSpecializedType, key)->getFieldType())), + baseInst, + key + )); + } + } + else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findField(ptrBaseStructType, key)->getFieldType()), + baseInst, + key)); + } + } + else if (auto specializedType = as<IRSpecialize>(pairType)) + { + // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's + // type, emit the specialization type. + // + auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase())); + if (auto genericBasePairStructType = as<IRStructType>(genericType)) + { + return as<IRFieldExtract>(builder->emitFieldExtract( + (IRType*)findSpecializationForParam( + specializedType, + findField(genericBasePairStructType, key)->getFieldType()), + baseInst, + key + )); + } + else if (auto genericPtrType = as<IRPtrTypeBase>(genericType)) + { + if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType())) + { + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType((IRType*) + findSpecializationForParam( + specializedType, + findField(genericPairStructType, key)->getFieldType())), + baseInst, + key + )); + } + } + } + else + { + SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor"); + } + return nullptr; +} + +IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) +{ + return emitFieldAccessor(builder, baseInst, this->globalPrimalKey); +} + +IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) +{ + return emitFieldAccessor(builder, baseInst, this->globalDiffKey); +} + +IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey() +{ + if (!this->globalDiffKey) + { + IRBuilder builder(sharedContext->sharedBuilder); + // Insert directly at top level (skip any generic scopes etc.) + builder.setInsertInto(sharedContext->moduleInst); + + this->globalDiffKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential")); + } + + return this->globalDiffKey; +} + +IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey() +{ + if (!this->globalPrimalKey) + { + // Insert directly at top level (skip any generic scopes etc.) + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertInto(sharedContext->moduleInst); + + this->globalPrimalKey = builder.createStructKey(); + builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal")); + } + + return this->globalPrimalKey; +} + +IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType) +{ + switch (origBaseType->getOp()) + { + case kIROp_lookup_interface_method: + case kIROp_Specialize: + case kIROp_Param: + return nullptr; + default: + break; + } + + IRBuilder builder(sharedContext->sharedBuilder); + builder.setInsertBefore(diffType); + + auto pairStructType = builder.createStructType(); + builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType); + builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType); + return pairStructType; +} + +IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); +} + +IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); +} + +DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType( + IRBuilder* builder, IRType* originalPairType) +{ + LoweredPairTypeInfo result = {}; + + if (pairTypeCache.TryGetValue(originalPairType, result)) + return result; + auto pairType = as<IRDifferentialPairType>(originalPairType); + if (!pairType) + { + result.isTrivial = true; + result.loweredType = originalPairType; + return result; + } + auto primalType = pairType->getValueType(); + if (as<IRParam>(primalType)) + { + result.isTrivial = false; + result.loweredType = nullptr; + return result; + } + + auto diffType = getDiffTypeFromPairType(builder, pairType); + if (!diffType) + return result; + result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); + result.isTrivial = false; + pairTypeCache.Add(originalPairType, result); + + return result; +} + + +AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) + : moduleInst(inModuleInst) +{ + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) + { + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + zeroMethodStructKey = findZeroMethodStructKey(); + addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); + + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; + } +} + +IRInst* AutoDiffSharedContext::findDifferentiableInterface() +{ + if (auto module = as<IRModuleInst>(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: This seems like a particularly dangerous way to look for an interface. + // See if we can lower IDifferentiable to a separate IR inst. + // + if (globalInst->getOp() == kIROp_InterfaceType && + as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") + { + return globalInst; + } + } + } + return nullptr; +} + +IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +{ + if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) + { + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) + return as<IRStructKey>(entry->getRequirementKey()); + else + { + SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); + } + } + + return nullptr; +} + + +void stripAutoDiffDecorationsFromChildren(IRInst* parent) +{ + for (auto inst : parent->getChildren()) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_DerivativeMemberDecoration: + case kIROp_DifferentiableTypeDictionaryDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + + if (inst->getFirstChild() != nullptr) + { + stripAutoDiffDecorationsFromChildren(inst); + } + } +} + +void stripAutoDiffDecorations(IRModule* module) +{ + stripAutoDiffDecorationsFromChildren(module->getModuleInst()); +} + +bool processAutodiffCalls( + IRModule* module, + DiagnosticSink* sink, + IRAutodiffPassOptions const&) +{ + // Simplify module to remove dead code. + IRDeadCodeEliminationOptions dceOptions; + dceOptions.keepExportsAlive = true; + dceOptions.keepLayoutsAlive = true; + eliminateDeadCode(module, dceOptions); + + bool modified = false; + + // Create shared context for all auto-diff related passes + AutoDiffSharedContext autodiffContext(module->getModuleInst()); + + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + autodiffContext.sharedBuilder = &sharedBuilder; + + // Process forward derivative calls. + modified |= processForwardDerivativeCalls(&autodiffContext, sink); + + // Process reverse derivative calls. + modified |= processReverseDerivativeCalls(&autodiffContext, sink); + + // Replaces IRDifferentialPairType with an auto-generated struct, + // IRDifferentialPairGetDifferential with 'differential' field access, + // IRDifferentialPairGetPrimal with 'primal' field access, and + // IRMakeDifferentialPair with an IRMakeStruct. + // + modified |= processPairTypes(&autodiffContext); + + // Remove auto-diff related decorations. + stripAutoDiffDecorations(module); + + return modified; +} + + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h new file mode 100644 index 000000000..e470044a4 --- /dev/null +++ b/source/slang/slang-ir-autodiff.h @@ -0,0 +1,210 @@ +// slang-ir-autodiff-fwd.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct AutoDiffSharedContext +{ + IRModuleInst* moduleInst = nullptr; + + SharedIRBuilder* sharedBuilder = nullptr; + + // A reference to the builtin IDifferentiable interface type. + // We use this to look up all the other types (and type exprs) + // that conform to a base type. + // + IRInterfaceType* differentiableInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferential. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferential`. + IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + + + // The struct key for the 'zero()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of zero() for a given type. + // + IRStructKey* zeroMethodStructKey = nullptr; + + // The struct key for the 'add()' associated type + // defined inside IDifferential. We use this to lookup the + // implementation of add() for a given type. + // + IRStructKey* addMethodStructKey = nullptr; + + IRStructKey* mulMethodStructKey = nullptr; + + + // Modules that don't use differentiable types + // won't have the IDifferentiable interface type available. + // Set to false to indicate that we are uninitialized. + // + bool isInterfaceAvailable = false; + + + AutoDiffSharedContext(IRModuleInst* inModuleInst); + +private: + + IRInst* findDifferentiableInterface(); + + IRStructKey* findDifferentialTypeStructKey() + { + return getIDifferentiableStructKeyAtIndex(0); + } + + IRStructKey* findDifferentialTypeWitnessStructKey() + { + return getIDifferentiableStructKeyAtIndex(1); + } + + IRStructKey* findZeroMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(2); + } + + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(3); + } + + IRStructKey* findMulMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(4); + } + + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); +}; + +struct DifferentiableTypeConformanceContext +{ + AutoDiffSharedContext* sharedContext; + + IRGlobalValueWithCode* parentFunc = nullptr; + OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary; + + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) + : sharedContext(shared) + {} + + void setFunc(IRGlobalValueWithCode* func); + + void buildGlobalWitnessDictionary(); + + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type); + + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; + } + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + } + + IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + } + + IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) + { + return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + } + +}; + +struct DifferentialPairTypeBuilder +{ + struct LoweredPairTypeInfo + { + IRInst* loweredType; + bool isTrivial; + }; + + DifferentialPairTypeBuilder() = default; + + DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {} + + IRStructField* findField(IRInst* type, IRStructKey* key); + + IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam); + + IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key); + + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst); + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst); + + IRStructKey* _getOrCreateDiffStructKey(); + + IRStructKey* _getOrCreatePrimalStructKey(); + + IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType); + + IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type); + + IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type); + + LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); + + + Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache; + + IRStructKey* globalPrimalKey = nullptr; + + IRStructKey* globalDiffKey = nullptr; + + IRInst* genericDiffPairType = nullptr; + + List<IRInst*> generatedTypeList; + + AutoDiffSharedContext* sharedContext = nullptr; +}; + +void stripAutoDiffDecorations(IRModule* module); + +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); + +struct IRAutodiffPassOptions +{ + // Nothing for now... +}; + +bool processAutodiffCalls( + IRModule* module, + DiagnosticSink* sink, + IRAutodiffPassOptions const& options = IRAutodiffPassOptions()); + +};
\ No newline at end of file diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 44c6324e3..f4f61d7e9 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -1,6 +1,6 @@ #include "slang-ir-check-differentiability.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-inst-pass-base.h" namespace Slang diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h deleted file mode 100644 index 5e2a7f44f..000000000 --- a/source/slang/slang-ir-diff-jvp.h +++ /dev/null @@ -1,174 +0,0 @@ -// slang-ir-diff-jvp.h -#pragma once - -#include "slang-ir.h" -#include "slang-ir-insts.h" -#include "slang-compiler.h" - -namespace Slang -{ - template<typename P, typename D> - struct DiffInstPair - { - P primal; - D differential; - DiffInstPair() = default; - DiffInstPair(P primal, D differential) : primal(primal), differential(differential) - {} - HashCode getHashCode() const - { - Hasher hasher; - hasher << primal << differential; - return hasher.getResult(); - } - bool operator ==(const DiffInstPair& other) const - { - return primal == other.primal && differential == other.differential; - } - }; - - typedef DiffInstPair<IRInst*, IRInst*> InstPair; - - struct AutoDiffSharedContext - { - IRModuleInst* moduleInst = nullptr; - - SharedIRBuilder* sharedBuilder = nullptr; - - // A reference to the builtin IDifferentiable interface type. - // We use this to look up all the other types (and type exprs) - // that conform to a base type. - // - IRInterfaceType* differentiableInterfaceType = nullptr; - - // The struct key for the 'Differential' associated type - // defined inside IDifferential. We use this to lookup the differential - // type in the conformance table associated with the concrete type. - // - IRStructKey* differentialAssocTypeStructKey = nullptr; - - // The struct key for the witness that `Differential` associated type conforms to - // `IDifferential`. - IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; - - - // The struct key for the 'zero()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of zero() for a given type. - // - IRStructKey* zeroMethodStructKey = nullptr; - - // The struct key for the 'add()' associated type - // defined inside IDifferential. We use this to lookup the - // implementation of add() for a given type. - // - IRStructKey* addMethodStructKey = nullptr; - - IRStructKey* mulMethodStructKey = nullptr; - - - // Modules that don't use differentiable types - // won't have the IDifferentiable interface type available. - // Set to false to indicate that we are uninitialized. - // - bool isInterfaceAvailable = false; - - - AutoDiffSharedContext(IRModuleInst* inModuleInst); - - private: - - IRInst* findDifferentiableInterface(); - - IRStructKey* findDifferentialTypeStructKey() - { - return getIDifferentiableStructKeyAtIndex(0); - } - - IRStructKey* findDifferentialTypeWitnessStructKey() - { - return getIDifferentiableStructKeyAtIndex(1); - } - - IRStructKey* findZeroMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(2); - } - - IRStructKey* findAddMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(3); - } - - IRStructKey* findMulMethodStructKey() - { - return getIDifferentiableStructKeyAtIndex(4); - } - - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); - }; - - struct DifferentiableTypeConformanceContext - { - AutoDiffSharedContext* sharedContext; - - IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary; - - DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) - : sharedContext(shared) - {} - - void setFunc(IRGlobalValueWithCode* func); - - void buildGlobalWitnessDictionary(); - - // Lookup a witness table for the concreteType. One should exist if concreteType - // inherits (successfully) from IDifferentiable. - // - IRInst* lookUpConformanceForType(IRInst* type); - - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); - - // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. - // Note that inside a generic block, this will be a witness table lookup instruction - // that gets resolved during the specialization pass. - // - IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) - { - switch (origType->getOp()) - { - case kIROp_FloatType: - case kIROp_HalfType: - case kIROp_DoubleType: - case kIROp_VectorType: - return origType; - } - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); - } - - IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); - } - - IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) - { - return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); - } - - }; - - struct IRJVPDerivativePassOptions - { - // Nothing for now.. - }; - - bool processDifferentiableFuncs( - IRModule* module, - DiagnosticSink* sink, - IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); - - void stripAutoDiffDecorations(IRModule* module); -} diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 8559103ae..3e85cc40e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -6,7 +6,7 @@ #include "slang-ir-insts.h" #include "slang-mangle.h" #include "slang-ir-string-hash.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-module-library.h" #include "../compiler-core/slang-artifact.h" diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 61b6fcb76..57267a9ea 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8,7 +8,7 @@ #include "slang-ir-constexpr.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" -#include "slang-ir-diff-jvp.h" +#include "slang-ir-autodiff.h" #include "slang-ir-inline.h" #include "slang-ir-insts.h" #include "slang-ir-check-differentiability.h" |
