summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-11-22 12:36:28 -0500
committerGitHub <noreply@github.com>2022-11-22 09:36:28 -0800
commit6178cb601368e977c4aa82e0ae25b8eb1e875d84 (patch)
treec7dc6df96c5bb4d0f4fd598ae40158e06a082fd1
parentd9b014cba6803dbfcd37ed8ac3e7560a5191e3cf (diff)
Refactor Auto-diff passes (#2526)
* Initial refactor * Refactor passes tests * Removed Differential Bottom references from the IR side
-rw-r--r--build/visual-studio/slang/slang.vcxproj10
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters30
-rw-r--r--source/slang/slang-emit.cpp7
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp (renamed from source/slang/slang-ir-diff-jvp.cpp)1474
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h43
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp182
-rw-r--r--source/slang/slang-ir-autodiff-pairs.h21
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp832
-rw-r--r--source/slang/slang-ir-autodiff-rev.h25
-rw-r--r--source/slang/slang-ir-autodiff.cpp408
-rw-r--r--source/slang/slang-ir-autodiff.h210
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp2
-rw-r--r--source/slang/slang-ir-diff-jvp.h174
-rw-r--r--source/slang/slang-ir-link.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
15 files changed, 1777 insertions, 1645 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 5244f0efb..fe1922d29 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -333,6 +333,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-intrinsic-expand.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-any-value-marshalling.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-byte-address-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-check-differentiability.h" />
@@ -343,7 +347,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-constexpr.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dce.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h" />
- <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-export.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h" />
@@ -505,6 +508,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-intrinsic-expand.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-any-value-marshalling.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-bind-existentials.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-byte-address-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-check-differentiability.cpp" />
@@ -516,7 +523,6 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-dce.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-deduplicate.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp" />
- <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-export.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 0a2d1fe3f..6fa42287c 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -132,6 +132,18 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -162,9 +174,6 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-diff-call.h">
<Filter>Header Files</Filter>
</ClInclude>
- <ClInclude Include="..\..\..\source\slang\slang-ir-diff-jvp.h">
- <Filter>Header Files</Filter>
- </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-export.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -644,6 +653,18 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-augment-make-existential.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-bind-existentials.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -677,9 +698,6 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-diff-call.cpp">
<Filter>Source Files</Filter>
</ClCompile>
- <ClCompile Include="..\..\..\source\slang\slang-ir-diff-jvp.cpp">
- <Filter>Source Files</Filter>
- </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-export.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 69ea29c7a..ca55a68bc 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -11,7 +11,7 @@
#include "slang-ir-cleanup-void.h"
#include "slang-ir-dce.h"
#include "slang-ir-diff-call.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
@@ -377,10 +377,7 @@ Result linkAndOptimizeIR(
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
- // Process higher-order calles to auto-diff passes.
- processDifferentiableFuncs(irModule, sink);
-
- stripAutoDiffDecorations(irModule);
+ processAutodiffCalls(irModule, sink);
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c9ca687e4..03e81c5b5 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1,5 +1,6 @@
-// slang-ir-diff-jvp.cpp
-#include "slang-ir-diff-jvp.h"
+// slang-ir-autodiff-fwd.cpp
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
@@ -7,314 +8,9 @@
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
-// origX, primalX, diffX
-// origX -> primalX (cloneEnv)
-// origX -> diffX (instMapD)
-
namespace Slang
{
-namespace
-{
-
-IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
-{
- if (auto witnessTable = as<IRWitnessTable>(witness))
- {
- for (auto entry : witnessTable->getEntries())
- {
- if (entry->getRequirementKey() == requirementKey)
- return entry->getSatisfyingVal();
- }
- }
- else if (auto witnessTableParam = as<IRParam>(witness))
- {
- return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
- witnessTableParam,
- requirementKey);
- }
- return nullptr;
-}
-
-}
-
-struct DifferentialPairTypeBuilder
-{
-
- IRStructField* findField(IRInst* type, IRStructKey* key)
- {
- if (auto irStructType = as<IRStructType>(type))
- {
- for (auto field : irStructType->getFields())
- {
- if (field->getKey() == key)
- {
- return field;
- }
- }
- }
- else if (auto irSpecialize = as<IRSpecialize>(type))
- {
- if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase()))
- {
- if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric)))
- {
- return findField(irGenericStructType, key);
- }
- }
- }
-
- return nullptr;
- }
-
- IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
- {
- // Get base generic that's being specialized.
- auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase());
- SLANG_ASSERT(genericType);
-
- // Find the index of genericParam in the base generic.
- int paramIndex = -1;
- int currentIndex = 0;
- for (auto param : genericType->getParams())
- {
- if (param == genericParam)
- paramIndex = currentIndex;
- currentIndex ++;
- }
-
- SLANG_ASSERT(paramIndex >= 0);
-
- // Return the corresponding operand in the specialization inst.
- return specializeInst->getOperand(1 + paramIndex);
- }
-
- IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
- {
- IRInst* pairType = nullptr;
- if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
- {
- auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
-
- // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
- // especially since we probably don't need diff bottom anymore.
- //
- SLANG_ASSERT(!baseTypeInfo.isTrivial);
-
- pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
- }
- else
- {
- auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
- if (baseTypeInfo.isTrivial)
- {
- if (key == globalPrimalKey)
- return baseInst;
- else
- return builder->getDifferentialBottom();
- }
-
- pairType = baseTypeInfo.loweredType;
- }
-
- if (auto basePairStructType = as<IRStructType>(pairType))
- {
- return as<IRFieldExtract>(builder->emitFieldExtract(
- findField(basePairStructType, key)->getFieldType(),
- baseInst,
- key
- ));
- }
- else if (auto ptrType = as<IRPtrTypeBase>(pairType))
- {
- if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
- {
- auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase()));
- if (auto genericBasePairStructType = as<IRStructType>(genericType))
- {
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findSpecializationForParam(
- ptrInnerSpecializedType,
- findField(ptrInnerSpecializedType, key)->getFieldType())),
- baseInst,
- key
- ));
- }
- }
- else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType()))
- {
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findField(ptrBaseStructType, key)->getFieldType()),
- baseInst,
- key));
- }
- }
- else if (auto specializedType = as<IRSpecialize>(pairType))
- {
- // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
- // type, emit the specialization type.
- //
- auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
- if (auto genericBasePairStructType = as<IRStructType>(genericType))
- {
- return as<IRFieldExtract>(builder->emitFieldExtract(
- (IRType*)findSpecializationForParam(
- specializedType,
- findField(genericBasePairStructType, key)->getFieldType()),
- baseInst,
- key
- ));
- }
- else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
- {
- if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
- {
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType((IRType*)
- findSpecializationForParam(
- specializedType,
- findField(genericPairStructType, key)->getFieldType())),
- baseInst,
- key
- ));
- }
- }
- }
- else
- {
- SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
- }
- return nullptr;
- }
-
- IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
- {
- return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
- }
-
- IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
- {
- return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
- }
-
- IRStructKey* _getOrCreateDiffStructKey()
- {
- if (!this->globalDiffKey)
- {
- IRBuilder builder(sharedContext->sharedBuilder);
- // Insert directly at top level (skip any generic scopes etc.)
- builder.setInsertInto(sharedContext->moduleInst);
-
- this->globalDiffKey = builder.createStructKey();
- builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
- }
-
- return this->globalDiffKey;
- }
-
- IRStructKey* _getOrCreatePrimalStructKey()
- {
- if (!this->globalPrimalKey)
- {
- // Insert directly at top level (skip any generic scopes etc.)
- IRBuilder builder(sharedContext->sharedBuilder);
- builder.setInsertInto(sharedContext->moduleInst);
-
- this->globalPrimalKey = builder.createStructKey();
- builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
- }
-
- return this->globalPrimalKey;
- }
-
- IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType)
- {
- switch (origBaseType->getOp())
- {
- case kIROp_lookup_interface_method:
- case kIROp_Specialize:
- case kIROp_Param:
- return nullptr;
- default:
- break;
- }
- if (diffType->getOp() != kIROp_DifferentialBottomType)
- {
- IRBuilder builder(sharedContext->sharedBuilder);
- builder.setInsertBefore(diffType);
-
- auto pairStructType = builder.createStructType();
- builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
- builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
- return pairStructType;
- }
- return origBaseType;
- }
-
- struct LoweredPairTypeInfo
- {
- IRInst* loweredType;
- bool isTrivial;
- };
-
- IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
- {
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
- }
-
- IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
- {
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
- }
-
- LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
- {
- LoweredPairTypeInfo result = {};
-
- if (pairTypeCache.TryGetValue(originalPairType, result))
- return result;
- auto pairType = as<IRDifferentialPairType>(originalPairType);
- if (!pairType)
- {
- result.isTrivial = true;
- result.loweredType = originalPairType;
- return result;
- }
- auto primalType = pairType->getValueType();
- if (as<IRParam>(primalType))
- {
- result.isTrivial = false;
- result.loweredType = nullptr;
- return result;
- }
-
- auto diffType = getDiffTypeFromPairType(builder, pairType);
- if (!diffType)
- return result;
- result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
- result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
- pairTypeCache.Add(originalPairType, result);
-
- return result;
- }
-
- Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache;
-
- IRStructKey* globalPrimalKey = nullptr;
-
- IRStructKey* globalDiffKey = nullptr;
-
- IRInst* genericDiffPairType = nullptr;
-
- List<IRInst*> generatedTypeList;
-
- AutoDiffSharedContext* sharedContext = nullptr;
-};
struct JVPTranscriber
{
@@ -468,17 +164,6 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
- IRWitnessTable* getDifferentialBottomWitness()
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- SLANG_ASSERT(result);
- return result;
- }
-
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
{
@@ -524,10 +209,6 @@ struct JVPTranscriber
{
witness = getDifferentialPairWitness(primalPairType);
}
- else
- {
- witness = getDifferentialBottomWitness();
- }
}
return builder.getDifferentialPairType(
@@ -1825,666 +1506,7 @@ struct JVPTranscriber
}
};
-
-struct BackwardDiffTranscriber
-{
-
- // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
- // their differential values.
- Dictionary<IRInst*, IRInst*> orginalToTranscribed;
-
- // Set of insts currently being transcribed. Used to avoid infinite loops.
- HashSet<IRInst*> instsInProgress;
-
- // Cloning environment to hold mapping from old to new copies for the primal
- // instructions.
- IRCloneEnv cloneEnv;
-
- // Diagnostic sink for error messages.
- DiagnosticSink* sink;
-
- // Type conformance information.
- AutoDiffSharedContext* autoDiffSharedContext;
-
- // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
- DifferentialPairTypeBuilder* pairBuilder;
-
- DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
-
- List<InstPair> followUpFunctionsToTranscribe;
-
- // Map that stores the upper gradient given an IRInst*
- Dictionary<IRInst*, List<IRInst*>> upperGradients;
- Dictionary<IRInst*, IRInst*> primalToDiffPair;
-
- SharedIRBuilder* sharedBuilder;
- // Witness table that `DifferentialBottom:IDifferential`.
- IRWitnessTable* differentialBottomWitness = nullptr;
- Dictionary<InstPair, IRInst*> differentialPairTypes;
-
- BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
- : autoDiffSharedContext(shared)
- , sink(inSink)
- , differentiableTypeConformanceContext(shared)
- , sharedBuilder(inSharedBuilder)
- {}
-
- DiagnosticSink* getSink()
- {
- SLANG_ASSERT(sink);
- return sink;
- }
-
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
- {
- List<IRType*> newParameterTypes;
- IRType* diffReturnType;
-
- for (UIndex i = 0; i < funcType->getParamCount(); i++)
- {
- auto origType = funcType->getParamType(i);
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- newParameterTypes.add(inoutDiffPairType);
- }
- else
- newParameterTypes.add(origType);
- }
-
- newParameterTypes.add(funcType->getResultType());
-
- diffReturnType = builder->getVoidType();
-
- return builder->getFuncType(newParameterTypes, diffReturnType);
- }
-
- IRWitnessTable* getDifferentialBottomWitness()
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- SLANG_ASSERT(result);
- return result;
- }
-
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
- SLANG_ASSERT(diffPairType);
- auto result =
- as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
- builder.getDifferentialBottomType()));
- if (result)
- return result;
-
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
- auto diffType = differentiateType(&builder, diffPairType->getValueType());
- auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
-
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- return table;
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* getOrCreateDiffPairType(IRInst* primalType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- auto witness = as<IRWitnessTable>(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
- if (!witness)
- witness = getDifferentialBottomWitness();
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
-
- IRType* differentiateType(IRBuilder* builder, IRType* origType)
- {
- IRInst* diffType = nullptr;
- if (!orginalToTranscribed.TryGetValue(origType, diffType))
- {
- diffType = _differentiateTypeImpl(builder, origType);
- orginalToTranscribed[origType] = diffType;
- }
- return (IRType*)diffType;
- }
-
- IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
- {
- if (auto ptrType = as<IRPtrTypeBase>(origType))
- return builder->getPtrType(
- origType->getOp(),
- differentiateType(builder, ptrType->getValueType()));
-
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = origType;
-
- // Special case certain compound types (PtrType, FuncType, etc..)
- // otherwise try to lookup a differential definition for the given type.
- // If one does not exist, then we assume it's not differentiable.
- //
- switch (primalType->getOp())
- {
- case kIROp_Param:
- if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
- else if (as<IRWitnessTableType>(primalType->getDataType()))
- return (IRType*)primalType;
-
- case kIROp_ArrayType:
- {
- auto primalArrayType = as<IRArrayType>(primalType);
- if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
- return builder->getArrayType(
- diffElementType,
- primalArrayType->getElementCount());
- else
- return nullptr;
- }
-
- case kIROp_DifferentialPairType:
- {
- auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
- }
-
- case kIROp_FuncType:
- return differentiateFunctionType(builder, as<IRFuncType>(primalType));
-
- case kIROp_OutType:
- if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
- return builder->getOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_InOutType:
- if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
- return builder->getInOutType(diffValueType);
- else
- return nullptr;
-
- case kIROp_TupleType:
- {
- auto tupleType = as<IRTupleType>(primalType);
- List<IRType*> diffTypeList;
- // TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
- diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
- }
-
- default:
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
- }
- }
-
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
- {
- // If this is a PtrType (out, inout, etc..), then create diff pair from
- // value type and re-apply the appropropriate PtrType wrapper.
- //
- if (auto origPtrType = as<IRPtrTypeBase>(primalType))
- {
- if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(primalType->getOp(), diffPairValueType);
- else
- return nullptr;
- }
- auto diffType = differentiateType(builder, primalType);
- if (diffType)
- return (IRType*)getOrCreateDiffPairType(primalType);
- return nullptr;
- }
-
- InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
- {
- auto primalDataType = origParam->getDataType();
- // Do not differentiate generic type (and witness table) parameters
- if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
- {
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
- // its primal and diff parts right now because they would represent a reference
- // to a pair field, which doesn't make sense since pair types are considered mutable.
- // We encode the result as if the param is non-differentiable, and handle it
- // with special care at load/store.
- return InstPair(diffPairParam, nullptr);
- }
-
-
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
-
- // Returns "dp<var-name>" to use as a name hint for parameters.
- // If no primal name is available, returns a blank string.
- //
- String makeDiffPairName(IRInst* origVar)
- {
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
- {
- return ("dp" + String(namehintDecoration->getName()));
- }
-
- return String("");
- }
-
-
- // In differential computation, the 'default' differential value is always zero.
- // This is a consequence of differential computing being inherently linear. As a
- // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
- // The current implementation requires that types are defined linearly.
- //
- IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
- {
- if (auto diffType = differentiateType(builder, primalType))
- {
- switch (diffType->getOp())
- {
- case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
- }
- // Since primalType has a corresponding differential type, we can lookup the
- // definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- SLANG_ASSERT(zeroMethod);
-
- auto emptyArgList = List<IRInst*>();
- return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
- }
- else
- {
- if (isScalarIntegerType(primalType))
- {
- return builder->getIntValue(primalType, 0);
- }
-
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
- }
- }
-
- InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
- {
- IRBuilder subBuilder(builder->getSharedBuilder());
- subBuilder.setInsertLoc(builder->getInsertLoc());
-
- IRBlock* diffBlock = subBuilder.emitBlock();
-
- subBuilder.setInsertInto(diffBlock);
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->copyParam(&subBuilder, param);
-
- // The extra param for input gradient
- auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->copyInst(&subBuilder, child);
-
- auto lastInst = diffBlock->getLastOrdinaryInst();
- List<IRInst*> grads = { gradParam };
- upperGradients.Add(lastInst, grads);
- for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
- {
- auto upperGrads = upperGradients.TryGetValue(child);
- if (!upperGrads)
- continue;
- if (upperGrads->getCount() > 1)
- {
- auto sumGrad = upperGrads->getFirst();
- for (auto i = 1; i < upperGrads->getCount(); i++)
- {
- sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
- }
- this->transcribeInstBackward(&subBuilder, child, sumGrad);
- }
- else
- this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
- }
-
- subBuilder.emitReturn();
-
- return InstPair(diffBlock, diffBlock);
- }
-
- // Create an empty func to represent the transcribed func of `origFunc`.
- InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
- {
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origFunc);
-
- IRFunc* primalFunc = origFunc;
-
- differentiableTypeConformanceContext.setFunc(origFunc);
-
- primalFunc = origFunc;
-
- auto diffFunc = builder.createFunc();
-
- SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
- IRType* diffFuncType = this->differentiateFunctionType(
- &builder,
- as<IRFuncType>(origFunc->getFullType()));
- diffFunc->setFullType(diffFuncType);
-
- if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
- {
- auto originalName = nameHint->getName();
- StringBuilder newNameSb;
- newNameSb << "s_bwd_" << originalName;
- builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
- }
- builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
-
- // Mark the generated derivative function itself as differentiable.
- builder.addBackwardDifferentiableDecoration(diffFunc);
-
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- cloneDecoration(dictDecor, diffFunc);
- }
-
- auto result = InstPair(primalFunc, diffFunc);
- followUpFunctionsToTranscribe.add(result);
- return result;
- }
-
- // Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
- {
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertInto(diffFunc);
-
- differentiableTypeConformanceContext.setFunc(primalFunc);
- // Transcribe children from origFunc into diffFunc
- for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribeBlock(&builder, block);
-
- return InstPair(primalFunc, diffFunc);
- }
-
- IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
- {
- auto primalDataType = origParam->getDataType();
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
- IRInst* diffParam = builder->emitParam(inoutDiffPairType);
-
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
-
- SLANG_ASSERT(diffParam);
- auto paramValue = builder->emitLoad(diffParam);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- orginalToTranscribed.Add(origParam, primal);
- primalToDiffPair.Add(primal, diffParam);
-
- return diffParam;
- }
-
-
- return cloneInst(&cloneEnv, builder, origParam);
- }
-
- InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto origLeft = origArith->getOperand(0);
- auto origRight = origArith->getOperand(1);
-
- IRInst* primalLeft;
- if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
- {
- primalLeft = origLeft;
- }
- IRInst* primalRight;
- if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
- {
- primalRight = origRight;
- }
-
- auto resultType = origArith->getDataType();
- IRInst* newInst;
- switch (origArith->getOp())
- {
- case kIROp_Add:
- newInst = builder->emitAdd(resultType, primalLeft, primalRight);
- break;
- case kIROp_Mul:
- newInst = builder->emitMul(resultType, primalLeft, primalRight);
- break;
- case kIROp_Sub:
- newInst = builder->emitSub(resultType, primalLeft, primalRight);
- break;
- case kIROp_Div:
- newInst = builder->emitDiv(resultType, primalLeft, primalRight);
- break;
- default:
- newInst = nullptr;
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
- orginalToTranscribed.Add(origArith, newInst);
- return InstPair(newInst, nullptr);
- }
-
- IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
-
- auto lhs = origArith->getOperand(0);
- auto rhs = origArith->getOperand(1);
-
- if (as<IRInOutType>(lhs->getDataType()))
- {
- lhs = builder->emitLoad(lhs);
- lhs = builder->emitDifferentialPairGetPrimal(lhs);
- }
- if (as<IRInOutType>(rhs->getDataType()))
- {
- rhs = builder->emitLoad(rhs);
- rhs = builder->emitDifferentialPairGetPrimal(rhs);
- }
-
- IRInst* leftGrad;
- IRInst* rightGrad;
-
-
- switch (origArith->getOp())
- {
- case kIROp_Add:
- leftGrad = grad;
- rightGrad = grad;
- break;
- case kIROp_Mul:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
- break;
- case kIROp_Sub:
- leftGrad = grad;
- rightGrad = builder->emitNeg(grad->getDataType(), grad);
- break;
- case kIROp_Div:
- leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
- rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
- break;
- default:
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
-
- lhs = origArith->getOperand(0);
- rhs = origArith->getOperand(1);
- if (auto leftGrads = upperGradients.TryGetValue(lhs))
- {
- leftGrads->add(leftGrad);
- }
- else
- {
- upperGradients.Add(lhs, leftGrad);
- }
- if (auto rightGrads = upperGradients.TryGetValue(rhs))
- {
- rightGrads->add(rightGrad);
- }
- else
- {
- upperGradients.Add(rhs, rightGrad);
- }
-
- return nullptr;
- }
-
- InstPair copyInst(IRBuilder* builder, IRInst* origInst)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transcribeParam(builder, as<IRParam>(origInst));
-
- case kIROp_Return:
- return InstPair(nullptr, nullptr);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return copyBinaryArith(builder, origInst);
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
- }
-
- return InstPair(nullptr, nullptr);
- }
-
- IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
- {
- IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
- auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
- auto paramValue = builder->emitLoad(param);
- auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
- auto diff = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- paramValue
- );
- auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
- auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
- auto store = builder->emitStore(param, updatedParam);
-
- return store;
- }
-
- IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
- {
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transcribeParamBackward(builder, as<IRParam>(origInst), grad);
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return transcribeBinaryArithBackward(builder, origInst, grad);
-
- case kIROp_DifferentialPairGetPrimal:
- {
- if (auto param = primalToDiffPair.TryGetValue(origInst))
- {
- if (auto leftGrads = upperGradients.TryGetValue(*param))
- {
- leftGrads->add(grad);
- }
- else
- {
- upperGradients.Add(*param, grad);
- }
- }
- else
- SLANG_ASSERT(0);
- return nullptr;
- }
-
- default:
- // Not yet implemented
- SLANG_ASSERT(0);
- }
-
- return nullptr;
- }
-};
-
-
-struct JVPDerivativeContext : public InstPassBase
+struct ForwardDerivativePass : public InstPassBase
{
DiagnosticSink* getSink()
@@ -2494,18 +1516,10 @@ struct JVPDerivativeContext : public InstPassBase
bool processModule()
{
- // We start by initializing our shared IR building state,
- // since we will re-use that state for any code we
- // generate along the way.
- //
- SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
- sharedBuilder->init(module);
- sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
-
// TODO(sai): Move this call.
transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
- IRBuilder builderStorage(sharedBuilderStorage);
+ IRBuilder builderStorage(this->autodiffContext->sharedBuilder);
IRBuilder* builder = &builderStorage;
// Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
@@ -2513,20 +1527,6 @@ struct JVPDerivativeContext : public InstPassBase
//
bool modified = processReferencedFunctions(builder);
- // Replaces IRDifferentialPairType with an auto-generated struct,
- // IRDifferentialPairGetDifferential with 'differential' field access,
- // IRDifferentialPairGetPrimal with 'primal' field access, and
- // IRMakeDifferentialPair with an IRMakeStruct.
- //
- modified |= simplifyDifferentialBottomType(builder);
-
- // De-duplicate any remaining types.
- sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
-
- modified |= processPairTypes(builder, module->getModuleInst());
-
- modified |= eliminateDifferentialBottomType(builder);
-
return modified;
}
@@ -2553,7 +1553,6 @@ struct JVPDerivativeContext : public InstPassBase
switch (inst->getOp())
{
case kIROp_ForwardDifferentiate:
- case kIROp_BackwardDifferentiate:
// Only process now if the operand is a materialized function.
switch (inst->getOperand(0)->getOp())
{
@@ -2577,7 +1576,6 @@ struct JVPDerivativeContext : public InstPassBase
// differentiated functions.
transcriberStorage.followUpFunctionsToTranscribe.clear();
- backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
for (auto differentiateInst : autoDiffWorkList)
{
@@ -2613,28 +1611,6 @@ struct JVPDerivativeContext : public InstPassBase
}
}
- else if (as<IRBackwardDifferentiate>(differentiateInst))
- {
- if (isMarkedForBackwardDifferentiation(baseInst))
- {
- if (as<IRFunc>(baseInst))
- {
- IRInst* diffFunc =
- backwardTranscriberStorage
- .transcribeFuncHeader(builder, (IRFunc*)baseInst)
- .differential;
- SLANG_ASSERT(diffFunc);
- differentiateInst->replaceUsesWith(diffFunc);
- differentiateInst->removeAndDeallocate();
- }
- else
- {
- getSink()->diagnose(differentiateInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
- }
- }
- }
}
// Actually synthesize the derivatives.
List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe);
@@ -2647,16 +1623,6 @@ struct JVPDerivativeContext : public InstPassBase
transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
}
- followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
- for (auto task : followUpWorkList)
- {
- auto diffFunc = as<IRFunc>(task.differential);
- SLANG_ASSERT(diffFunc);
- auto primalFunc = as<IRFunc>(task.primal);
- SLANG_ASSERT(primalFunc);
-
- backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
- }
// Transcribing the function body really shouldn't produce more follow up function body work.
// However it may produce new `ForwardDifferentiate` instructions, which we collect and process
@@ -2667,281 +1633,6 @@ struct JVPDerivativeContext : public InstPassBase
return true;
}
- IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
- {
- builder->setInsertBefore(pairType);
- auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType(
- builder,
- pairType);
- if (isTrivial)
- *isTrivial = loweredPairTypeInfo.isTrivial;
- return loweredPairTypeInfo.loweredType;
- }
-
- IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
- {
-
- if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
- {
- bool isTrivial = false;
- auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
- if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
- {
- builder->setInsertBefore(makePairInst);
- IRInst* result = nullptr;
- if (isTrivial)
- {
- result = makePairInst->getPrimalValue();
- }
- else
- {
- IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
- result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
- }
- makePairInst->replaceUsesWith(result);
- makePairInst->removeAndDeallocate();
- return result;
- }
- }
-
- return nullptr;
- }
-
- IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
- {
- if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
- {
- auto pairType = getDiffInst->getBase()->getDataType();
- if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
- {
- pairType = pairPtrType->getValueType();
- }
-
- if (lowerPairType(builder, pairType, nullptr))
- {
- builder->setInsertBefore(getDiffInst);
- IRInst* diffFieldExtract = nullptr;
- diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
- getDiffInst->replaceUsesWith(diffFieldExtract);
- getDiffInst->removeAndDeallocate();
- return diffFieldExtract;
- }
- }
- else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
- {
- auto pairType = getPrimalInst->getBase()->getDataType();
- if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
- {
- pairType = pairPtrType->getValueType();
- }
-
- if (lowerPairType(builder, pairType, nullptr))
- {
- builder->setInsertBefore(getPrimalInst);
-
- IRInst* primalFieldExtract = nullptr;
- primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
- getPrimalInst->replaceUsesWith(primalFieldExtract);
- getPrimalInst->removeAndDeallocate();
- return primalFieldExtract;
- }
- }
-
- return nullptr;
- }
-
- bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
- {
- bool modified = false;
- // Hoist all pair types to global scope when possible.
- auto moduleInst = module->getModuleInst();
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
- {
- if (originalPairType->parent != moduleInst)
- {
- originalPairType->removeFromParent();
- ShortList<IRInst*> operands;
- for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
- {
- operands.add(originalPairType->getOperand(i));
- }
- auto newPairType = builder->findOrEmitHoistableInst(
- originalPairType->getFullType(),
- originalPairType->getOp(),
- originalPairType->getOperandCount(),
- operands.getArrayView().getBuffer());
- originalPairType->replaceUsesWith(newPairType);
- originalPairType->removeAndDeallocate();
- }
- });
-
- sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
-
- processAllInsts([&](IRInst* inst)
- {
- // Make sure the builder is at the right level.
- builder->setInsertInto(instWithChildren);
-
- switch (inst->getOp())
- {
- case kIROp_DifferentialPairGetDifferential:
- case kIROp_DifferentialPairGetPrimal:
- lowerPairAccess(builder, inst);
- modified = true;
- break;
-
- case kIROp_MakeDifferentialPair:
- lowerMakePair(builder, inst);
- modified = true;
- break;
-
- default:
- break;
- }
- });
-
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
- {
- if (auto loweredType = lowerPairType(builder, inst))
- {
- inst->replaceUsesWith(loweredType);
- inst->removeAndDeallocate();
- }
- });
- return modified;
- }
-
- bool simplifyDifferentialBottomType(IRBuilder* builder)
- {
- bool modified = false;
- auto diffBottom = builder->getDifferentialBottom();
-
- bool changed = true;
- List<IRUse*> uses;
- while (changed)
- {
- changed = false;
- // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`.
- processAllInsts([&](IRInst* inst)
- {
- if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType)
- {
- if (inst != diffBottom)
- {
- inst->replaceUsesWith(diffBottom);
- inst->removeAndDeallocate();
- modified = true;
- }
- }
- });
- // Go through all uses of diffBottom and run simplification.
- processAllInsts([&](IRInst* inst)
- {
- if (!inst->hasUses())
- return;
-
- builder->setInsertBefore(inst);
- IRInst* valueToReplace = nullptr;
- switch (inst->getOp())
- {
- case kIROp_Store:
- if (as<IRStore>(inst)->getVal() == diffBottom)
- {
- inst->removeAndDeallocate();
- changed = true;
- }
- return;
- case kIROp_MakeDifferentialPair:
- // Our simplification could lead to a situation where
- // bottom is used to make a pair that has a non-bottom differential type,
- // in this case we should use zero instead.
- if (inst->getOperand(1) == diffBottom)
- {
- // Only apply if we are the second operand.
- auto pairType = as<IRDifferentialPairType>(inst->getDataType());
- if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType)
- {
- auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType());
- inst->setOperand(1, zero);
- changed = true;
- }
- }
- return;
- case kIROp_DifferentialPairGetDifferential:
- if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
- {
- valueToReplace = inst->getOperand(0)->getOperand(1);
- }
- break;
- case kIROp_DifferentialPairGetPrimal:
- if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
- {
- valueToReplace = inst->getOperand(0)->getOperand(0);
- }
- break;
- case kIROp_Add:
- if (inst->getOperand(0) == diffBottom)
- {
- valueToReplace = inst->getOperand(1);
- }
- else if (inst->getOperand(1) == diffBottom)
- {
- valueToReplace = inst->getOperand(0);
- }
- break;
- case kIROp_Sub:
- if (inst->getOperand(0) == diffBottom)
- {
- // If left is bottom, and right is not bottom, then we should return -right.
- // However we can't possibly run into that case since both side of - operator
- // must be at the same order of differentiation.
- valueToReplace = diffBottom;
- }
- else if (inst->getOperand(1) == diffBottom)
- {
- valueToReplace = inst->getOperand(0);
- }
- break;
- case kIROp_Mul:
- case kIROp_Div:
- if (inst->getOperand(0) == diffBottom)
- {
- valueToReplace = diffBottom;
- }
- else if (inst->getOperand(1) == diffBottom)
- {
- valueToReplace = diffBottom;
- }
- break;
- default:
- break;
- }
- if (valueToReplace)
- {
- inst->replaceUsesWith(valueToReplace);
- changed = true;
- }
- });
- modified |= changed;
- }
-
- return modified;
- }
-
- bool eliminateDifferentialBottomType(IRBuilder* builder)
- {
- simplifyDifferentialBottomType(builder);
-
- bool modified = false;
- auto diffBottom = builder->getDifferentialBottom();
- auto diffBottomType = diffBottom->getDataType();
- diffBottom->replaceUsesWith(builder->getVoidValue());
- diffBottom->removeAndDeallocate();
- diffBottomType->replaceUsesWith(builder->getVoidType());
-
- return modified;
- }
-
// Checks decorators to see if the function should
// be differentiated (kIROp_ForwardDifferentiableDecoration)
//
@@ -2968,45 +1659,16 @@ struct JVPDerivativeContext : public InstPassBase
return name;
}
- // Checks decorators to see if the function should
- // be differentiated (kIROp_ForwardDifferentiableDecoration)
- //
- bool isMarkedForBackwardDifferentiation(IRInst* callable)
- {
- return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
- }
-
- IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
- {
- IRBuilder builder(&sharedBuilderStorage);
- builder.setInsertBefore(func);
-
- IRStringLit* name = nullptr;
- if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
- {
- name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
- }
- else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
- {
- name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
- }
-
- return name;
- }
-
- JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
- InstPassBase(module),
+ ForwardDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
+ InstPassBase(context->moduleInst->getModule()),
sink(sink),
- autoDiffSharedContextStorage(module->getModuleInst()),
- transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage),
- backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink)
+ transcriberStorage(context, context->sharedBuilder),
+ pairBuilderStorage(context),
+ autodiffContext(context)
{
- autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage;
- pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
- transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
+ transcriberStorage.autoDiffSharedContext = context;
transcriberStorage.pairBuilder = &(pairBuilderStorage);
- backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
}
protected:
@@ -3015,15 +1677,12 @@ protected:
//
JVPTranscriber transcriberStorage;
- BackwardDiffTranscriber backwardTranscriberStorage;
-
// Diagnostic object from the compile request for
// error messages.
- DiagnosticSink* sink;
+ DiagnosticSink* sink;
- // Context to find and manage the witness tables for types
- // implementing `IDifferentiable`
- AutoDiffSharedContext autoDiffSharedContextStorage;
+ // Shared context.
+ AutoDiffSharedContext* autodiffContext;
// Builder for dealing with differential pair types.
DifferentialPairTypeBuilder pairBuilderStorage;
@@ -3032,106 +1691,16 @@ protected:
// Set up context and call main process method.
//
-bool processDifferentiableFuncs(
- IRModule* module,
+bool processForwardDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
DiagnosticSink* sink,
- IRJVPDerivativePassOptions const&)
+ ForwardDerivativePassOptions const&)
{
- // Simplify module to remove dead code.
- IRDeadCodeEliminationOptions options;
- options.keepExportsAlive = true;
- options.keepLayoutsAlive = true;
- eliminateDeadCode(module, options);
-
- JVPDerivativeContext context(module, sink);
- bool changed = context.processModule();
+ ForwardDerivativePass fwdPass(autodiffContext, sink);
+ bool changed = fwdPass.processModule();
return changed;
}
-void stripAutoDiffDecorationsFromChildren(IRInst* parent)
-{
- for (auto inst : parent->getChildren())
- {
- for (auto decor = inst->getFirstDecoration(); decor; )
- {
- auto next = decor->getNextDecoration();
- switch (decor->getOp())
- {
- case kIROp_ForwardDerivativeDecoration:
- case kIROp_DerivativeMemberDecoration:
- case kIROp_DifferentiableTypeDictionaryDecoration:
- decor->removeAndDeallocate();
- break;
- default:
- break;
- }
- decor = next;
- }
-
- if (inst->getFirstChild() != nullptr)
- {
- stripAutoDiffDecorationsFromChildren(inst);
- }
- }
-}
-
-void stripAutoDiffDecorations(IRModule* module)
-{
- stripAutoDiffDecorationsFromChildren(module->getModuleInst());
-}
-
-AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
- : moduleInst(inModuleInst)
-{
- differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
- if (differentiableInterfaceType)
- {
- differentialAssocTypeStructKey = findDifferentialTypeStructKey();
- differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
- zeroMethodStructKey = findZeroMethodStructKey();
- addMethodStructKey = findAddMethodStructKey();
- mulMethodStructKey = findMulMethodStructKey();
-
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
- }
-}
-
-IRInst* AutoDiffSharedContext::findDifferentiableInterface()
-{
- if (auto module = as<IRModuleInst>(moduleInst))
- {
- for (auto globalInst : module->getGlobalInsts())
- {
- // TODO: This seems like a particularly dangerous way to look for an interface.
- // See if we can lower IDifferentiable to a separate IR inst.
- //
- if (globalInst->getOp() == kIROp_InterfaceType &&
- as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
- {
- return globalInst;
- }
- }
- }
- return nullptr;
-}
-
-IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
-{
- if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
- {
- // Assume for now that IDifferentiable has exactly five fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
- return as<IRStructKey>(entry->getRequirementKey());
- else
- {
- SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
- }
- }
-
- return nullptr;
-}
void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
{
@@ -3148,11 +1717,6 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
if (existingItem)
{
- if (auto witness = as<IRWitnessTable>(item->getWitness()))
- {
- if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
- continue;
- }
*existingItem = item->getWitness();
}
else
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
new file mode 100644
index 000000000..6b261ecd0
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -0,0 +1,43 @@
+// slang-ir-autodiff-fwd.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+namespace Slang
+{
+
+ template<typename P, typename D>
+ struct DiffInstPair
+ {
+ P primal;
+ D differential;
+ DiffInstPair() = default;
+ DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
+ {}
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+ bool operator ==(const DiffInstPair& other) const
+ {
+ return primal == other.primal && differential == other.differential;
+ }
+ };
+
+ typedef DiffInstPair<IRInst*, IRInst*> InstPair;
+
+ struct ForwardDerivativePassOptions
+ {
+ // Nothing for now..
+ };
+
+ bool processForwardDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ ForwardDerivativePassOptions const& options = ForwardDerivativePassOptions());
+
+}
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
new file mode 100644
index 000000000..1dbb1bd7c
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -0,0 +1,182 @@
+#include "slang-ir-autodiff-pairs.h"
+
+namespace Slang
+{
+
+struct DiffPairLoweringPass : InstPassBase
+{
+ DiffPairLoweringPass(AutoDiffSharedContext* context) :
+ InstPassBase(context->moduleInst->getModule()),
+ pairBuilderStorage(context),
+ autodiffContext(context)
+ {
+ pairBuilder = &pairBuilderStorage;
+ }
+
+ IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
+ {
+ builder->setInsertBefore(pairType);
+ auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType(
+ builder,
+ pairType);
+ if (isTrivial)
+ *isTrivial = loweredPairTypeInfo.isTrivial;
+ return loweredPairTypeInfo.loweredType;
+ }
+
+ IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
+ {
+
+ if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
+ {
+ bool isTrivial = false;
+ auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
+ if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
+ {
+ builder->setInsertBefore(makePairInst);
+ IRInst* result = nullptr;
+ if (isTrivial)
+ {
+ result = makePairInst->getPrimalValue();
+ }
+ else
+ {
+ IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
+ result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+ }
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
+ }
+ }
+
+ return nullptr;
+ }
+
+ IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
+ {
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ {
+ auto pairType = getDiffInst->getBase()->getDataType();
+ if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
+ {
+ builder->setInsertBefore(getDiffInst);
+ IRInst* diffFieldExtract = nullptr;
+ diffFieldExtract = pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ getDiffInst->replaceUsesWith(diffFieldExtract);
+ getDiffInst->removeAndDeallocate();
+ return diffFieldExtract;
+ }
+ }
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ {
+ auto pairType = getPrimalInst->getBase()->getDataType();
+ if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
+ {
+ builder->setInsertBefore(getPrimalInst);
+
+ IRInst* primalFieldExtract = nullptr;
+ primalFieldExtract = pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ getPrimalInst->replaceUsesWith(primalFieldExtract);
+ getPrimalInst->removeAndDeallocate();
+ return primalFieldExtract;
+ }
+ }
+
+ return nullptr;
+ }
+
+ bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
+ {
+ bool modified = false;
+ // Hoist all pair types to global scope when possible.
+ auto moduleInst = module->getModuleInst();
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
+ {
+ if (originalPairType->parent != moduleInst)
+ {
+ originalPairType->removeFromParent();
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ {
+ operands.add(originalPairType->getOperand(i));
+ }
+ auto newPairType = builder->findOrEmitHoistableInst(
+ originalPairType->getFullType(),
+ originalPairType->getOp(),
+ originalPairType->getOperandCount(),
+ operands.getArrayView().getBuffer());
+ originalPairType->replaceUsesWith(newPairType);
+ originalPairType->removeAndDeallocate();
+ }
+ });
+
+ autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
+ processAllInsts([&](IRInst* inst)
+ {
+ // Make sure the builder is at the right level.
+ builder->setInsertInto(instWithChildren);
+
+ switch (inst->getOp())
+ {
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ lowerPairAccess(builder, inst);
+ modified = true;
+ break;
+
+ case kIROp_MakeDifferentialPair:
+ lowerMakePair(builder, inst);
+ modified = true;
+ break;
+
+ default:
+ break;
+ }
+ });
+
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = lowerPairType(builder, inst))
+ {
+ inst->replaceUsesWith(loweredType);
+ inst->removeAndDeallocate();
+ }
+ });
+ return modified;
+ }
+
+ bool processModule()
+ {
+ IRBuilder builder(autodiffContext->sharedBuilder);
+ return processInstWithChildren(&builder, module->getModuleInst());
+ }
+
+ private:
+
+ AutoDiffSharedContext* autodiffContext;
+
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+};
+
+bool processPairTypes(AutoDiffSharedContext* context)
+{
+ DiffPairLoweringPass pairLoweringPass(context);
+ return pairLoweringPass.processModule();
+}
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-pairs.h b/source/slang/slang-ir-autodiff-pairs.h
new file mode 100644
index 000000000..44321ae9b
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-pairs.h
@@ -0,0 +1,21 @@
+// slang-ir-autodiff-pairs.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+
+bool processPairTypes(AutoDiffSharedContext* context);
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
new file mode 100644
index 000000000..52567e887
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -0,0 +1,832 @@
+#include "slang-ir-autodiff-rev.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+
+namespace Slang
+{
+struct BackwardDiffTranscriber
+{
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> orginalToTranscribed;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ // Map that stores the upper gradient given an IRInst*
+ Dictionary<IRInst*, List<IRInst*>> upperGradients;
+ Dictionary<IRInst*, IRInst*> primalToDiffPair;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : autoDiffSharedContext(shared)
+ , sink(inSink)
+ , differentiableTypeConformanceContext(shared)
+ , sharedBuilder(inSharedBuilder)
+ {}
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List<IRType*> newParameterTypes;
+ IRType* diffReturnType;
+
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto origType = funcType->getParamType(i);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ newParameterTypes.add(inoutDiffPairType);
+ }
+ else
+ newParameterTypes.add(origType);
+ }
+
+ newParameterTypes.add(funcType->getResultType());
+
+ diffReturnType = builder->getVoidType();
+
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto diffType = differentiateType(&builder, diffPairType->getValueType());
+ auto differentialType = builder.getDifferentialPairType(diffType, nullptr);
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ {
+ IRInst* diffType = nullptr;
+ if (!orginalToTranscribed.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ orginalToTranscribed[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = origType;
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+ }
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+ {
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+ }
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRBlock* diffBlock = subBuilder.emitBlock();
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->copyParam(&subBuilder, param);
+
+ // The extra param for input gradient
+ auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->copyInst(&subBuilder, child);
+
+ auto lastInst = diffBlock->getLastOrdinaryInst();
+ List<IRInst*> grads = { gradParam };
+ upperGradients.Add(lastInst, grads);
+ for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ {
+ auto upperGrads = upperGradients.TryGetValue(child);
+ if (!upperGrads)
+ continue;
+ if (upperGrads->getCount() > 1)
+ {
+ auto sumGrad = upperGrads->getFirst();
+ for (auto i = 1; i < upperGrads->getCount(); i++)
+ {
+ sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
+ }
+ this->transcribeInstBackward(&subBuilder, child, sumGrad);
+ }
+ else
+ this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
+ }
+
+ subBuilder.emitReturn();
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ primalFunc = origFunc;
+
+ auto diffFunc = builder.createFunc();
+
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_bwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder.addBackwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribeBlock(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ IRInst* diffParam = builder->emitParam(inoutDiffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffParam);
+ auto paramValue = builder->emitLoad(diffParam);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ orginalToTranscribed.Add(origParam, primal);
+ primalToDiffPair.Add(primal, diffParam);
+
+ return diffParam;
+ }
+
+
+ return cloneInst(&cloneEnv, builder, origParam);
+ }
+
+ InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+
+ IRInst* primalLeft;
+ if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
+ {
+ primalLeft = origLeft;
+ }
+ IRInst* primalRight;
+ if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
+ {
+ primalRight = origRight;
+ }
+
+ auto resultType = origArith->getDataType();
+ IRInst* newInst;
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ newInst = builder->emitAdd(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Mul:
+ newInst = builder->emitMul(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Sub:
+ newInst = builder->emitSub(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Div:
+ newInst = builder->emitDiv(resultType, primalLeft, primalRight);
+ break;
+ default:
+ newInst = nullptr;
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+ orginalToTranscribed.Add(origArith, newInst);
+ return InstPair(newInst, nullptr);
+ }
+
+ IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto lhs = origArith->getOperand(0);
+ auto rhs = origArith->getOperand(1);
+
+ if (as<IRInOutType>(lhs->getDataType()))
+ {
+ lhs = builder->emitLoad(lhs);
+ lhs = builder->emitDifferentialPairGetPrimal(lhs);
+ }
+ if (as<IRInOutType>(rhs->getDataType()))
+ {
+ rhs = builder->emitLoad(rhs);
+ rhs = builder->emitDifferentialPairGetPrimal(rhs);
+ }
+
+ IRInst* leftGrad;
+ IRInst* rightGrad;
+
+
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ leftGrad = grad;
+ rightGrad = grad;
+ break;
+ case kIROp_Mul:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
+ break;
+ case kIROp_Sub:
+ leftGrad = grad;
+ rightGrad = builder->emitNeg(grad->getDataType(), grad);
+ break;
+ case kIROp_Div:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+
+ lhs = origArith->getOperand(0);
+ rhs = origArith->getOperand(1);
+ if (auto leftGrads = upperGradients.TryGetValue(lhs))
+ {
+ leftGrads->add(leftGrad);
+ }
+ else
+ {
+ upperGradients.Add(lhs, leftGrad);
+ }
+ if (auto rightGrads = upperGradients.TryGetValue(rhs))
+ {
+ rightGrads->add(rightGrad);
+ }
+ else
+ {
+ upperGradients.Add(rhs, rightGrad);
+ }
+
+ return nullptr;
+ }
+
+ InstPair copyInst(IRBuilder* builder, IRInst* origInst)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParam(builder, as<IRParam>(origInst));
+
+ case kIROp_Return:
+ return InstPair(nullptr, nullptr);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return copyBinaryArith(builder, origInst);
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ {
+ IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
+ auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
+ auto paramValue = builder->emitLoad(param);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ auto diff = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ paramValue
+ );
+ auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
+ auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
+ auto store = builder->emitStore(param, updatedParam);
+
+ return store;
+ }
+
+ IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParamBackward(builder, as<IRParam>(origInst), grad);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArithBackward(builder, origInst, grad);
+
+ case kIROp_DifferentialPairGetPrimal:
+ {
+ if (auto param = primalToDiffPair.TryGetValue(origInst))
+ {
+ if (auto leftGrads = upperGradients.TryGetValue(*param))
+ {
+ leftGrads->add(grad);
+ }
+ else
+ {
+ upperGradients.Add(*param, grad);
+ }
+ }
+ else
+ SLANG_ASSERT(0);
+ return nullptr;
+ }
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return nullptr;
+ }
+
+
+};
+
+struct ReverseDerivativePass : public InstPassBase
+{
+
+ DiagnosticSink* getSink()
+ {
+ return sink;
+ }
+
+ bool processModule()
+ {
+
+ IRBuilder builderStorage(autodiffContext->sharedBuilder);
+ IRBuilder* builder = &builderStorage;
+
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
+ // generating derivative code for the referenced function.
+ //
+ bool modified = processReferencedFunctions(builder);
+
+ return modified;
+ }
+
+ IRInst* lookupJVPReference(IRInst* primalFunction)
+ {
+ if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ return jvpDefinition->getForwardDerivativeFunc();
+ return nullptr;
+ }
+
+ // Recursively process instructions looking for JVP calls (kIROp_ForwardDifferentiate),
+ // then check that the referenced function is marked correctly for differentiation.
+ //
+ bool processReferencedFunctions(IRBuilder* builder)
+ {
+ List<IRInst*> autoDiffWorkList;
+
+ for (;;)
+ {
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_BackwardDifferentiate:
+ // Only process now if the operand is a materialized function.
+ switch (inst->getOperand(0)->getOp())
+ {
+ case kIROp_Func:
+ case kIROp_Specialize:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ break;
+ default:
+ break;
+ }
+ });
+
+ if (autoDiffWorkList.getCount() == 0)
+ break;
+
+ // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
+ // differentiated functions.
+
+ backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
+
+ for (auto differentiateInst : autoDiffWorkList)
+ {
+ IRInst* baseInst = differentiateInst->getOperand(0);
+ if (as<IRBackwardDifferentiate>(differentiateInst))
+ {
+ if (isMarkedForBackwardDifferentiation(baseInst))
+ {
+ if (as<IRFunc>(baseInst))
+ {
+ IRInst* diffFunc =
+ backwardTranscriberStorage
+ .transcribeFuncHeader(builder, (IRFunc*)baseInst)
+ .differential;
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
+ }
+ }
+ }
+
+ auto followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as<IRFunc>(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ }
+
+ // Transcribing the function body really shouldn't produce more follow up function body work.
+ // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
+ // in the next iteration.
+ SLANG_RELEASE_ASSERT(backwardTranscriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
+
+ }
+ return true;
+ }
+
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
+ //
+ bool isMarkedForBackwardDifferentiation(IRInst* callable)
+ {
+ return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr;
+ }
+
+ IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
+ }
+
+ return name;
+ }
+
+ ReverseDerivativePass(AutoDiffSharedContext* context, DiagnosticSink* sink) :
+ InstPassBase(context->moduleInst->getModule()),
+ sink(sink),
+ backwardTranscriberStorage(context, context->sharedBuilder, sink),
+ autodiffContext(context),
+ pairBuilderStorage(context)
+ {
+ backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
+ }
+
+protected:
+ // A transcriber object that handles the main job of
+ // processing instructions while maintaining state.
+ //
+ BackwardDiffTranscriber backwardTranscriberStorage;
+
+ // Diagnostic object from the compile request for
+ // error messages.
+ DiagnosticSink* sink;
+
+ // Builder for dealing with differential pair types.
+ DifferentialPairTypeBuilder pairBuilderStorage;
+
+ // Autodiff Shared Context
+ AutoDiffSharedContext* autodiffContext;
+};
+
+bool processReverseDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ IRReverseDerivativePassOptions const&)
+{
+ ReverseDerivativePass revPass(autodiffContext, sink);
+ bool changed = revPass.processModule();
+ return changed;
+}
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
new file mode 100644
index 000000000..c3d31e2a9
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -0,0 +1,25 @@
+// slang-ir-autodiff-rev.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
+
+namespace Slang
+{
+
+struct IRReverseDerivativePassOptions
+{
+ // Nothing for now..
+};
+
+bool processReverseDerivativeCalls(
+ AutoDiffSharedContext* autodiffContext,
+ DiagnosticSink* sink,
+ IRReverseDerivativePassOptions const& options = IRReverseDerivativePassOptions());
+
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
new file mode 100644
index 000000000..313760d85
--- /dev/null
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -0,0 +1,408 @@
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-rev.h"
+#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-pairs.h"
+
+namespace Slang
+{
+
+// TODO: Put into a nameless namespace.
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+{
+ if (auto witnessTable = as<IRWitnessTable>(witness))
+ {
+ for (auto entry : witnessTable->getEntries())
+ {
+ if (entry->getRequirementKey() == requirementKey)
+ return entry->getSatisfyingVal();
+ }
+ }
+ else if (auto witnessTableParam = as<IRParam>(witness))
+ {
+ return builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ witnessTableParam,
+ requirementKey);
+ }
+ return nullptr;
+}
+
+IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key)
+{
+ if (auto irStructType = as<IRStructType>(type))
+ {
+ for (auto field : irStructType->getFields())
+ {
+ if (field->getKey() == key)
+ {
+ return field;
+ }
+ }
+ }
+ else if (auto irSpecialize = as<IRSpecialize>(type))
+ {
+ if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase()))
+ {
+ if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric)))
+ {
+ return findField(irGenericStructType, key);
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+IRInst* DifferentialPairTypeBuilder::findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
+{
+ // Get base generic that's being specialized.
+ auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase());
+ SLANG_ASSERT(genericType);
+
+ // Find the index of genericParam in the base generic.
+ int paramIndex = -1;
+ int currentIndex = 0;
+ for (auto param : genericType->getParams())
+ {
+ if (param == genericParam)
+ paramIndex = currentIndex;
+ currentIndex ++;
+ }
+
+ SLANG_ASSERT(paramIndex >= 0);
+
+ // Return the corresponding operand in the specialization inst.
+ return specializeInst->getOperand(1 + paramIndex);
+}
+
+IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
+{
+ IRInst* pairType = nullptr;
+ if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ {
+ auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
+
+ // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
+ // especially since we probably don't need diff bottom anymore.
+ //
+ SLANG_ASSERT(!baseTypeInfo.isTrivial);
+
+ pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
+ }
+ else
+ {
+ auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
+ pairType = baseTypeInfo.loweredType;
+ }
+
+ if (auto basePairStructType = as<IRStructType>(pairType))
+ {
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ findField(basePairStructType, key)->getFieldType(),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(pairType))
+ {
+ if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
+ {
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ ptrInnerSpecializedType,
+ findField(ptrInnerSpecializedType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
+ }
+ else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findField(ptrBaseStructType, key)->getFieldType()),
+ baseInst,
+ key));
+ }
+ }
+ else if (auto specializedType = as<IRSpecialize>(pairType))
+ {
+ // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
+ // type, emit the specialization type.
+ //
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ (IRType*)findSpecializationForParam(
+ specializedType,
+ findField(genericBasePairStructType, key)->getFieldType()),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
+ {
+ if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ specializedType,
+ findField(genericPairStructType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
+ }
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
+ }
+ return nullptr;
+}
+
+IRInst* DifferentialPairTypeBuilder::emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+{
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+}
+
+IRInst* DifferentialPairTypeBuilder::emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+{
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+}
+
+IRStructKey* DifferentialPairTypeBuilder::_getOrCreateDiffStructKey()
+{
+ if (!this->globalDiffKey)
+ {
+ IRBuilder builder(sharedContext->sharedBuilder);
+ // Insert directly at top level (skip any generic scopes etc.)
+ builder.setInsertInto(sharedContext->moduleInst);
+
+ this->globalDiffKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
+ }
+
+ return this->globalDiffKey;
+}
+
+IRStructKey* DifferentialPairTypeBuilder::_getOrCreatePrimalStructKey()
+{
+ if (!this->globalPrimalKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertInto(sharedContext->moduleInst);
+
+ this->globalPrimalKey = builder.createStructKey();
+ builder.addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
+ }
+
+ return this->globalPrimalKey;
+}
+
+IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, IRType* diffType)
+{
+ switch (origBaseType->getOp())
+ {
+ case kIROp_lookup_interface_method:
+ case kIROp_Specialize:
+ case kIROp_Param:
+ return nullptr;
+ default:
+ break;
+ }
+
+ IRBuilder builder(sharedContext->sharedBuilder);
+ builder.setInsertBefore(diffType);
+
+ auto pairStructType = builder.createStructType();
+ builder.createStructField(pairStructType, _getOrCreatePrimalStructKey(), origBaseType);
+ builder.createStructField(pairStructType, _getOrCreateDiffStructKey(), (IRType*)diffType);
+ return pairStructType;
+}
+
+IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
+}
+
+IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+}
+
+DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType(
+ IRBuilder* builder, IRType* originalPairType)
+{
+ LoweredPairTypeInfo result = {};
+
+ if (pairTypeCache.TryGetValue(originalPairType, result))
+ return result;
+ auto pairType = as<IRDifferentialPairType>(originalPairType);
+ if (!pairType)
+ {
+ result.isTrivial = true;
+ result.loweredType = originalPairType;
+ return result;
+ }
+ auto primalType = pairType->getValueType();
+ if (as<IRParam>(primalType))
+ {
+ result.isTrivial = false;
+ result.loweredType = nullptr;
+ return result;
+ }
+
+ auto diffType = getDiffTypeFromPairType(builder, pairType);
+ if (!diffType)
+ return result;
+ result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
+ result.isTrivial = false;
+ pairTypeCache.Add(originalPairType, result);
+
+ return result;
+}
+
+
+AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
+{
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
+ {
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
+
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
+ }
+}
+
+IRInst* AutoDiffSharedContext::findDifferentiableInterface()
+{
+ if (auto module = as<IRModuleInst>(moduleInst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: This seems like a particularly dangerous way to look for an interface.
+ // See if we can lower IDifferentiable to a separate IR inst.
+ //
+ if (globalInst->getOp() == kIROp_InterfaceType &&
+ as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
+ {
+ return globalInst;
+ }
+ }
+ }
+ return nullptr;
+}
+
+IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index)
+{
+ if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
+ {
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
+ return as<IRStructKey>(entry->getRequirementKey());
+ else
+ {
+ SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
+ }
+ }
+
+ return nullptr;
+}
+
+
+void stripAutoDiffDecorationsFromChildren(IRInst* parent)
+{
+ for (auto inst : parent->getChildren())
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_ForwardDerivativeDecoration:
+ case kIROp_DerivativeMemberDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+
+ if (inst->getFirstChild() != nullptr)
+ {
+ stripAutoDiffDecorationsFromChildren(inst);
+ }
+ }
+}
+
+void stripAutoDiffDecorations(IRModule* module)
+{
+ stripAutoDiffDecorationsFromChildren(module->getModuleInst());
+}
+
+bool processAutodiffCalls(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRAutodiffPassOptions const&)
+{
+ // Simplify module to remove dead code.
+ IRDeadCodeEliminationOptions dceOptions;
+ dceOptions.keepExportsAlive = true;
+ dceOptions.keepLayoutsAlive = true;
+ eliminateDeadCode(module, dceOptions);
+
+ bool modified = false;
+
+ // Create shared context for all auto-diff related passes
+ AutoDiffSharedContext autodiffContext(module->getModuleInst());
+
+ // We start by initializing our shared IR building state,
+ // since we will re-use that state for any code we
+ // generate along the way.
+ //
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.init(module);
+ sharedBuilder.deduplicateAndRebuildGlobalNumberingMap();
+
+ autodiffContext.sharedBuilder = &sharedBuilder;
+
+ // Process forward derivative calls.
+ modified |= processForwardDerivativeCalls(&autodiffContext, sink);
+
+ // Process reverse derivative calls.
+ modified |= processReverseDerivativeCalls(&autodiffContext, sink);
+
+ // Replaces IRDifferentialPairType with an auto-generated struct,
+ // IRDifferentialPairGetDifferential with 'differential' field access,
+ // IRDifferentialPairGetPrimal with 'primal' field access, and
+ // IRMakeDifferentialPair with an IRMakeStruct.
+ //
+ modified |= processPairTypes(&autodiffContext);
+
+ // Remove auto-diff related decorations.
+ stripAutoDiffDecorations(module);
+
+ return modified;
+}
+
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
new file mode 100644
index 000000000..e470044a4
--- /dev/null
+++ b/source/slang/slang-ir-autodiff.h
@@ -0,0 +1,210 @@
+// slang-ir-autodiff-fwd.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+struct AutoDiffSharedContext
+{
+ IRModuleInst* moduleInst = nullptr;
+
+ SharedIRBuilder* sharedBuilder = nullptr;
+
+ // A reference to the builtin IDifferentiable interface type.
+ // We use this to look up all the other types (and type exprs)
+ // that conform to a base type.
+ //
+ IRInterfaceType* differentiableInterfaceType = nullptr;
+
+ // The struct key for the 'Differential' associated type
+ // defined inside IDifferential. We use this to lookup the differential
+ // type in the conformance table associated with the concrete type.
+ //
+ IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // The struct key for the witness that `Differential` associated type conforms to
+ // `IDifferential`.
+ IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+
+
+ // The struct key for the 'zero()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of zero() for a given type.
+ //
+ IRStructKey* zeroMethodStructKey = nullptr;
+
+ // The struct key for the 'add()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of add() for a given type.
+ //
+ IRStructKey* addMethodStructKey = nullptr;
+
+ IRStructKey* mulMethodStructKey = nullptr;
+
+
+ // Modules that don't use differentiable types
+ // won't have the IDifferentiable interface type available.
+ // Set to false to indicate that we are uninitialized.
+ //
+ bool isInterfaceAvailable = false;
+
+
+ AutoDiffSharedContext(IRModuleInst* inModuleInst);
+
+private:
+
+ IRInst* findDifferentiableInterface();
+
+ IRStructKey* findDifferentialTypeStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(0);
+ }
+
+ IRStructKey* findDifferentialTypeWitnessStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(1);
+ }
+
+ IRStructKey* findZeroMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(2);
+ }
+
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(3);
+ }
+
+ IRStructKey* findMulMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(4);
+ }
+
+ IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
+};
+
+struct DifferentiableTypeConformanceContext
+{
+ AutoDiffSharedContext* sharedContext;
+
+ IRGlobalValueWithCode* parentFunc = nullptr;
+ OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+
+ DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
+ : sharedContext(shared)
+ {}
+
+ void setFunc(IRGlobalValueWithCode* func);
+
+ void buildGlobalWitnessDictionary();
+
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type);
+
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
+ }
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
+ }
+
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
+ }
+
+};
+
+struct DifferentialPairTypeBuilder
+{
+ struct LoweredPairTypeInfo
+ {
+ IRInst* loweredType;
+ bool isTrivial;
+ };
+
+ DifferentialPairTypeBuilder() = default;
+
+ DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {}
+
+ IRStructField* findField(IRInst* type, IRStructKey* key);
+
+ IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam);
+
+ IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key);
+
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst);
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst);
+
+ IRStructKey* _getOrCreateDiffStructKey();
+
+ IRStructKey* _getOrCreatePrimalStructKey();
+
+ IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType);
+
+ IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
+
+ IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type);
+
+ LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType);
+
+
+ Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache;
+
+ IRStructKey* globalPrimalKey = nullptr;
+
+ IRStructKey* globalDiffKey = nullptr;
+
+ IRInst* genericDiffPairType = nullptr;
+
+ List<IRInst*> generatedTypeList;
+
+ AutoDiffSharedContext* sharedContext = nullptr;
+};
+
+void stripAutoDiffDecorations(IRModule* module);
+
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
+
+struct IRAutodiffPassOptions
+{
+ // Nothing for now...
+};
+
+bool processAutodiffCalls(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRAutodiffPassOptions const& options = IRAutodiffPassOptions());
+
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 44c6324e3..f4f61d7e9 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -1,6 +1,6 @@
#include "slang-ir-check-differentiability.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-ir-inst-pass-base.h"
namespace Slang
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
deleted file mode 100644
index 5e2a7f44f..000000000
--- a/source/slang/slang-ir-diff-jvp.h
+++ /dev/null
@@ -1,174 +0,0 @@
-// slang-ir-diff-jvp.h
-#pragma once
-
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-#include "slang-compiler.h"
-
-namespace Slang
-{
- template<typename P, typename D>
- struct DiffInstPair
- {
- P primal;
- D differential;
- DiffInstPair() = default;
- DiffInstPair(P primal, D differential) : primal(primal), differential(differential)
- {}
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher << primal << differential;
- return hasher.getResult();
- }
- bool operator ==(const DiffInstPair& other) const
- {
- return primal == other.primal && differential == other.differential;
- }
- };
-
- typedef DiffInstPair<IRInst*, IRInst*> InstPair;
-
- struct AutoDiffSharedContext
- {
- IRModuleInst* moduleInst = nullptr;
-
- SharedIRBuilder* sharedBuilder = nullptr;
-
- // A reference to the builtin IDifferentiable interface type.
- // We use this to look up all the other types (and type exprs)
- // that conform to a base type.
- //
- IRInterfaceType* differentiableInterfaceType = nullptr;
-
- // The struct key for the 'Differential' associated type
- // defined inside IDifferential. We use this to lookup the differential
- // type in the conformance table associated with the concrete type.
- //
- IRStructKey* differentialAssocTypeStructKey = nullptr;
-
- // The struct key for the witness that `Differential` associated type conforms to
- // `IDifferential`.
- IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
-
-
- // The struct key for the 'zero()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of zero() for a given type.
- //
- IRStructKey* zeroMethodStructKey = nullptr;
-
- // The struct key for the 'add()' associated type
- // defined inside IDifferential. We use this to lookup the
- // implementation of add() for a given type.
- //
- IRStructKey* addMethodStructKey = nullptr;
-
- IRStructKey* mulMethodStructKey = nullptr;
-
-
- // Modules that don't use differentiable types
- // won't have the IDifferentiable interface type available.
- // Set to false to indicate that we are uninitialized.
- //
- bool isInterfaceAvailable = false;
-
-
- AutoDiffSharedContext(IRModuleInst* inModuleInst);
-
- private:
-
- IRInst* findDifferentiableInterface();
-
- IRStructKey* findDifferentialTypeStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(0);
- }
-
- IRStructKey* findDifferentialTypeWitnessStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(1);
- }
-
- IRStructKey* findZeroMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(2);
- }
-
- IRStructKey* findAddMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(3);
- }
-
- IRStructKey* findMulMethodStructKey()
- {
- return getIDifferentiableStructKeyAtIndex(4);
- }
-
- IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index);
- };
-
- struct DifferentiableTypeConformanceContext
- {
- AutoDiffSharedContext* sharedContext;
-
- IRGlobalValueWithCode* parentFunc = nullptr;
- OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary;
-
- DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
- : sharedContext(shared)
- {}
-
- void setFunc(IRGlobalValueWithCode* func);
-
- void buildGlobalWitnessDictionary();
-
- // Lookup a witness table for the concreteType. One should exist if concreteType
- // inherits (successfully) from IDifferentiable.
- //
- IRInst* lookUpConformanceForType(IRInst* type);
-
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key);
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- switch (origType->getOp())
- {
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return origType;
- }
- return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
- }
-
- IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
- }
-
- IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
- }
-
- };
-
- struct IRJVPDerivativePassOptions
- {
- // Nothing for now..
- };
-
- bool processDifferentiableFuncs(
- IRModule* module,
- DiagnosticSink* sink,
- IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
-
- void stripAutoDiffDecorations(IRModule* module);
-}
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 8559103ae..3e85cc40e 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -6,7 +6,7 @@
#include "slang-ir-insts.h"
#include "slang-mangle.h"
#include "slang-ir-string-hash.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-module-library.h"
#include "../compiler-core/slang-artifact.h"
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 61b6fcb76..57267a9ea 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8,7 +8,7 @@
#include "slang-ir-constexpr.h"
#include "slang-ir-dce.h"
#include "slang-ir-diff-call.h"
-#include "slang-ir-diff-jvp.h"
+#include "slang-ir-autodiff.h"
#include "slang-ir-inline.h"
#include "slang-ir-insts.h"
#include "slang-ir-check-differentiability.h"