summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-11-29 20:01:41 -0500
committerGitHub <noreply@github.com>2022-11-29 17:01:41 -0800
commitf5581786a1891cedb165adb1afe71fe34f26e030 (patch)
tree86da2f1acbaec920ac0c38349897b293b405c021
parentaf7f40063dfed1c651d33b93956c7623a7d2c050 (diff)
Refactored reverse-mode implementation to use 4 separate passes. (#2531)
* Added partial implementation for reverse-mode * Fixing several compile and runtime errors. * Fixed several issues with reverse-mode passes. * Fixed more issues. Basic reverse-mode tests passing Co-authored-by: Edward Liu <shiqiu1105@gmail.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj3
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters9
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2550
-rw-r--r--source/slang/slang-ir-autodiff-fwd.h165
-rw-r--r--source/slang/slang-ir-autodiff-propagate.h102
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp196
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h420
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h110
-rw-r--r--source/slang/slang-ir-autodiff.cpp54
-rw-r--r--source/slang/slang-ir-inst-defs.h7
-rw-r--r--source/slang/slang-ir-insts.h14
11 files changed, 2294 insertions, 1336 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index fe1922d29..ff8a09599 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -335,7 +335,10 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-augment-make-existential.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transpose.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-unzip.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-bind-existentials.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-byte-address-legalize.h" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 6fa42287c..3c29fd21b 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -138,9 +138,18 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transpose.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-unzip.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff.h">
<Filter>Header Files</Filter>
</ClInclude>
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index c93522565..0ad9ce87c 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -12,167 +12,129 @@ namespace Slang
{
-struct JVPTranscriber
+DiagnosticSink* ForwardDerivativeTranscriber::getSink()
{
+ SLANG_ASSERT(sink);
+ return sink;
+}
- // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
- // their differential values.
- Dictionary<IRInst*, IRInst*> instMapD;
-
- // Set of insts currently being transcribed. Used to avoid infinite loops.
- HashSet<IRInst*> instsInProgress;
-
- // Cloning environment to hold mapping from old to new copies for the primal
- // instructions.
- IRCloneEnv cloneEnv;
-
- // Diagnostic sink for error messages.
- DiagnosticSink* sink;
-
- // Type conformance information.
- AutoDiffSharedContext* autoDiffSharedContext;
-
- // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
- DifferentialPairTypeBuilder* pairBuilder;
-
- DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
-
- List<InstPair> followUpFunctionsToTranscribe;
-
- SharedIRBuilder* sharedBuilder;
- // Witness table that `DifferentialBottom:IDifferential`.
- IRWitnessTable* differentialBottomWitness = nullptr;
- Dictionary<InstPair, IRInst*> differentialPairTypes;
-
- JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
- : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
- {
-
- }
-
- DiagnosticSink* getSink()
- {
- SLANG_ASSERT(sink);
- return sink;
- }
-
- void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
+void ForwardDerivativeTranscriber::mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
+{
+ if (hasDifferentialInst(origInst))
{
- if (hasDifferentialInst(origInst))
+ if (lookupDiffInst(origInst) != diffInst)
{
- if (lookupDiffInst(origInst) != diffInst)
- {
- SLANG_UNEXPECTED("Inconsistent differential mappings");
- }
- }
- else
- {
- instMapD.Add(origInst, diffInst);
+ SLANG_UNEXPECTED("Inconsistent differential mappings");
}
}
-
- void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
+ else
{
- if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
- {
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "inconsistent primal instruction for original");
- }
- else
- {
- cloneEnv.mapOldValToNew[origInst] = primalInst;
- }
+ instMapD.Add(origInst, diffInst);
}
+}
- IRInst* lookupDiffInst(IRInst* origInst)
+void ForwardDerivativeTranscriber::mapPrimalInst(IRInst* origInst, IRInst* primalInst)
+{
+ if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
{
- return instMapD[origInst];
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "inconsistent primal instruction for original");
}
-
- IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
+ else
{
- return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
+ cloneEnv.mapOldValToNew[origInst] = primalInst;
}
+}
- bool hasDifferentialInst(IRInst* origInst)
- {
- return instMapD.ContainsKey(origInst);
- }
+IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst)
+{
+ return instMapD[origInst];
+}
- bool shouldUseOriginalAsPrimal(IRInst* origInst)
- {
- if (as<IRGlobalValueWithCode>(origInst))
- return true;
- if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
- return true;
- return false;
- }
+IRInst* ForwardDerivativeTranscriber::lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
+{
+ return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
+}
- IRInst* lookupPrimalInst(IRInst* origInst)
- {
- if (!origInst)
- return nullptr;
- if (shouldUseOriginalAsPrimal(origInst))
- return origInst;
- return cloneEnv.mapOldValToNew[origInst];
- }
+bool ForwardDerivativeTranscriber::hasDifferentialInst(IRInst* origInst)
+{
+ return instMapD.ContainsKey(origInst);
+}
- IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
- {
- if (!origInst)
- return nullptr;
- return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
- }
+bool ForwardDerivativeTranscriber::shouldUseOriginalAsPrimal(IRInst* origInst)
+{
+ if (as<IRGlobalValueWithCode>(origInst))
+ return true;
+ if (origInst->parent && origInst->parent->getOp() == kIROp_Module)
+ return true;
+ return false;
+}
+
+IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst)
+{
+ if (!origInst)
+ return nullptr;
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
+ return cloneEnv.mapOldValToNew[origInst];
+}
+
+IRInst* ForwardDerivativeTranscriber::lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
+{
+ if (!origInst)
+ return nullptr;
+ return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
+}
- bool hasPrimalInst(IRInst* origInst)
+bool ForwardDerivativeTranscriber::hasPrimalInst(IRInst* origInst)
+{
+ if (!origInst)
+ return true;
+ if (shouldUseOriginalAsPrimal(origInst))
+ return true;
+ return cloneEnv.mapOldValToNew.ContainsKey(origInst);
+}
+
+IRInst* ForwardDerivativeTranscriber::findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
+{
+ if (!hasDifferentialInst(origInst))
{
- if (!origInst)
- return true;
- if (shouldUseOriginalAsPrimal(origInst))
- return true;
- return cloneEnv.mapOldValToNew.ContainsKey(origInst);
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasDifferentialInst(origInst));
}
- IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
- {
- if (!hasDifferentialInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasDifferentialInst(origInst));
- }
+ return lookupDiffInst(origInst);
+}
- return lookupDiffInst(origInst);
- }
+IRInst* ForwardDerivativeTranscriber::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
+{
+ if (shouldUseOriginalAsPrimal(origInst))
+ return origInst;
- IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
+ if (!hasPrimalInst(origInst))
{
- if (shouldUseOriginalAsPrimal(origInst))
- return origInst;
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasPrimalInst(origInst));
+ }
- if (!hasPrimalInst(origInst))
- {
- transcribe(builder, origInst);
- SLANG_ASSERT(hasPrimalInst(origInst));
- }
+ return lookupPrimalInst(origInst);
+}
- return lookupPrimalInst(origInst);
- }
+IRFuncType* ForwardDerivativeTranscriber::differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+{
+ List<IRType*> newParameterTypes;
+ IRType* diffReturnType;
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
- List<IRType*> newParameterTypes;
- IRType* diffReturnType;
-
- for (UIndex i = 0; i < funcType->getParamCount(); i++)
- {
- auto origType = funcType->getParamType(i);
- origType = (IRType*) lookupPrimalInst(origType, origType);
- if (auto diffPairType = tryGetDiffPairType(builder, origType))
- newParameterTypes.add(diffPairType);
- else
- newParameterTypes.add(origType);
- }
+ auto origType = funcType->getParamType(i);
+ origType = (IRType*) lookupPrimalInst(origType, origType);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ newParameterTypes.add(diffPairType);
+ else
+ newParameterTypes.add(origType);
+ }
// Transcribe return type to a pair.
// This will be void if the primal return type is non-differentiable.
@@ -183,562 +145,562 @@ struct JVPTranscriber
else
diffReturnType = origResultType;
- return builder->getFuncType(newParameterTypes, diffReturnType);
- }
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+}
- // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
- IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(inDiffPairType->parent);
- auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
- SLANG_ASSERT(diffPairType);
+// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+IRWitnessTable* ForwardDerivativeTranscriber::getDifferentialPairWitness(IRInst* inDiffPairType)
+{
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
- auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
- // Differentiate the pair type to get it's differential (which is itself a pair)
- auto diffDiffPairType = differentiateType(&builder, diffPairType);
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = differentiateType(&builder, diffPairType);
- // And place it in the synthesized witness table.
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+ // And place it in the synthesized witness table.
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
- // Record this in the context for future lookups
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ // Record this in the context for future lookups
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- return table;
- }
+ return table;
+}
- IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
- }
+IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+{
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
- IRType* getOrCreateDiffPairType(IRInst* primalType)
- {
- IRBuilder builder(sharedBuilder);
- builder.setInsertInto(primalType->parent);
- auto witness = as<IRWitnessTable>(
- differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType)
+{
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
- if (!witness)
+ if (!witness)
+ {
+ if (auto primalPairType = as<IRDifferentialPairType>(primalType))
{
- if (auto primalPairType = as<IRDifferentialPairType>(primalType))
- {
- witness = getDifferentialPairWitness(primalPairType);
- }
+ witness = getDifferentialPairWitness(primalPairType);
}
-
- return builder.getDifferentialPairType(
- (IRType*)primalType,
- witness);
}
- IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+}
+
+IRType* ForwardDerivativeTranscriber::differentiateType(IRBuilder* builder, IRType* origType)
+{
+ IRInst* diffType = nullptr;
+ if (!instMapD.TryGetValue(origType, diffType))
{
- IRInst* diffType = nullptr;
- if (!instMapD.TryGetValue(origType, diffType))
- {
- diffType = _differentiateTypeImpl(builder, origType);
- instMapD[origType] = diffType;
- }
- return (IRType*)diffType;
+ diffType = _differentiateTypeImpl(builder, origType);
+ instMapD[origType] = diffType;
}
+ return (IRType*)diffType;
+}
- IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
- {
- if (auto ptrType = as<IRPtrTypeBase>(origType))
- return builder->getPtrType(
- origType->getOp(),
- differentiateType(builder, ptrType->getValueType()));
+IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+{
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
- // If there is an explicit primal version of this type in the local scope, load that
- // otherwise use the original type.
- //
- IRInst* primalType = lookupPrimalInst(origType, origType);
-
- // Special case certain compound types (PtrType, FuncType, etc..)
- // otherwise try to lookup a differential definition for the given type.
- // If one does not exist, then we assume it's not differentiable.
- //
- switch (primalType->getOp())
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = lookupPrimalInst(origType, origType);
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
{
- case kIROp_Param:
- if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
- builder,
- (IRType*)primalType));
- else if (as<IRWitnessTableType>(primalType->getDataType()))
- return (IRType*)primalType;
-
- case kIROp_ArrayType:
- {
- auto primalArrayType = as<IRArrayType>(primalType);
- if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
- return builder->getArrayType(
- diffElementType,
- primalArrayType->getElementCount());
- else
- return nullptr;
- }
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
- case kIROp_DifferentialPairType:
- {
- auto primalPairType = as<IRDifferentialPairType>(primalType);
- return getOrCreateDiffPairType(
- pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
- pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
- }
-
- case kIROp_FuncType:
- return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
- case kIROp_OutType:
- if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
- return builder->getOutType(diffValueType);
- else
- return nullptr;
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
- case kIROp_InOutType:
- if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
- return builder->getInOutType(diffValueType);
- else
- return nullptr;
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
- case kIROp_TupleType:
- {
- auto tupleType = as<IRTupleType>(primalType);
- List<IRType*> diffTypeList;
- // TODO: what if we have type parameters here?
- for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
- diffTypeList.add(
- differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
-
- return builder->getTupleType(diffTypeList);
- }
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
- default:
- return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ return builder->getTupleType(diffTypeList);
}
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
}
+}
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
- {
- // If this is a PtrType (out, inout, etc..), then create diff pair from
- // value type and re-apply the appropropriate PtrType wrapper.
- //
- if (auto origPtrType = as<IRPtrTypeBase>(primalType))
- {
- if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(primalType->getOp(), diffPairValueType);
- else
- return nullptr;
- }
- auto diffType = differentiateType(builder, primalType);
- if (diffType)
- return (IRType*)getOrCreateDiffPairType(primalType);
- return nullptr;
+IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+{
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
}
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+}
- InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+InstPair ForwardDerivativeTranscriber::transcribeParam(IRBuilder* builder, IRParam* origParam)
+{
+ auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
{
- auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
- // Do not differentiate generic type (and witness table) parameters
- if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
- {
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
- // Is this param a phi node or a function parameter?
- auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
- bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
- if (isFuncParam)
+ // Is this param a phi node or a function parameter?
+ auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
+ bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
+ if (isFuncParam)
+ {
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
{
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
- {
- IRInst* diffPairParam = builder->emitParam(diffPairType);
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
- SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairType))
- {
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
- }
- else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
- {
- auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
-
- return InstPair(
- builder->emitDifferentialPairAddressPrimal(diffPairParam),
- builder->emitDifferentialPairAddressDifferential(
- builder->getPtrType(
- kIROp_PtrType,
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
- diffPairParam));
- }
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairType))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
}
-
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
- }
- else
- {
- auto primal = cloneInst(&cloneEnv, builder, origParam);
- IRInst* diff = nullptr;
- if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
+ else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
{
- diff = builder->emitParam(diffType);
+ auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
+
+ return InstPair(
+ builder->emitDifferentialPairAddressPrimal(diffPairParam),
+ builder->emitDifferentialPairAddressDifferential(
+ builder->getPtrType(
+ kIROp_PtrType,
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
+ diffPairParam));
}
- return InstPair(primal, diff);
}
-
- }
- // Returns "d<var-name>" to use as a name hint for variables and parameters.
- // If no primal name is available, returns a blank string.
- //
- String getJVPVarName(IRInst* origVar)
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+ else
{
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ auto primal = cloneInst(&cloneEnv, builder, origParam);
+ IRInst* diff = nullptr;
+ if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
{
- return ("d" + String(namehintDecoration->getName()));
+ diff = builder->emitParam(diffType);
}
-
- return String("");
+ return InstPair(primal, diff);
}
+
+}
- // Returns "dp<var-name>" to use as a name hint for parameters.
- // If no primal name is available, returns a blank string.
- //
- String makeDiffPairName(IRInst* origVar)
+// Returns "d<var-name>" to use as a name hint for variables and parameters.
+// If no primal name is available, returns a blank string.
+//
+String ForwardDerivativeTranscriber::getJVPVarName(IRInst* origVar)
+{
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
- if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
- {
- return ("dp" + String(namehintDecoration->getName()));
- }
-
- return String("");
+ return ("d" + String(namehintDecoration->getName()));
}
- InstPair transcribeVar(IRBuilder* builder, IRVar* origVar)
+ return String("");
+}
+
+// Returns "dp<var-name>" to use as a name hint for parameters.
+// If no primal name is available, returns a blank string.
+//
+String ForwardDerivativeTranscriber::makeDiffPairName(IRInst* origVar)
+{
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
- if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
- {
- IRVar* diffVar = builder->emitVar(diffType);
- SLANG_ASSERT(diffVar);
+ return ("dp" + String(namehintDecoration->getName()));
+ }
- auto diffNameHint = getJVPVarName(origVar);
- if (diffNameHint.getLength() > 0)
- builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());
+ return String("");
+}
- return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
- }
+InstPair ForwardDerivativeTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar)
+{
+ if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
+ {
+ IRVar* diffVar = builder->emitVar(diffType);
+ SLANG_ASSERT(diffVar);
- return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
+ auto diffNameHint = getJVPVarName(origVar);
+ if (diffNameHint.getLength() > 0)
+ builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());
+
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
}
- InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
- {
- SLANG_ASSERT(origArith->getOperandCount() == 2);
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
+}
- IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
-
- auto origLeft = origArith->getOperand(0);
- auto origRight = origArith->getOperand(1);
+InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
+{
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
- auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
- auto primalRight = findOrTranscribePrimalInst(builder, origRight);
+ IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
- auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
- auto diffRight = findOrTranscribeDiffInst(builder, origRight);
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+ auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
+ auto primalRight = findOrTranscribePrimalInst(builder, origRight);
- if (diffLeft || diffRight)
- {
- diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
- diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
+ auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
+ auto diffRight = findOrTranscribeDiffInst(builder, origRight);
- auto resultType = primalArith->getDataType();
- switch(origArith->getOp())
- {
- case kIROp_Add:
- return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
- case kIROp_Mul:
- return InstPair(primalArith, builder->emitAdd(resultType,
- builder->emitMul(resultType, diffLeft, primalRight),
- builder->emitMul(resultType, primalLeft, diffRight)));
- case kIROp_Sub:
- return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
- case kIROp_Div:
- return InstPair(primalArith, builder->emitDiv(resultType,
- builder->emitSub(
- resultType,
- builder->emitMul(resultType, diffLeft, primalRight),
- builder->emitMul(resultType, primalLeft, diffRight)),
- builder->emitMul(
- primalRight->getDataType(), primalRight, primalRight
- )));
- default:
- getSink()->diagnose(origArith->sourceLoc,
- Diagnostics::unimplemented,
- "this arithmetic instruction cannot be differentiated");
- }
- }
-
- return InstPair(primalArith, nullptr);
- }
-
- InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+ if (diffLeft || diffRight)
{
- SLANG_ASSERT(origLogic->getOperandCount() == 2);
+ diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
+ diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
- // TODO: Check other boolean cases.
- if (as<IRBoolType>(origLogic->getDataType()))
+ auto resultType = primalArith->getDataType();
+ switch(origArith->getOp())
{
- // Boolean operations are not differentiable. For the linearization
- // pass, we do not need to do anything but copy them over to the ne
- // function.
- auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
- return InstPair(primalLogic, nullptr);
+ case kIROp_Add:
+ return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
+ case kIROp_Mul:
+ return InstPair(primalArith, builder->emitAdd(resultType,
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)));
+ case kIROp_Sub:
+ return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
+ case kIROp_Div:
+ return InstPair(primalArith, builder->emitDiv(resultType,
+ builder->emitSub(
+ resultType,
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)),
+ builder->emitMul(
+ primalRight->getDataType(), primalRight, primalRight
+ )));
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
}
-
- SLANG_UNEXPECTED("Logical operation with non-boolean result");
}
- InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
+ return InstPair(primalArith, nullptr);
+}
+
+
+InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+{
+ SLANG_ASSERT(origLogic->getOperandCount() == 2);
+
+ // TODO: Check other boolean cases.
+ if (as<IRBoolType>(origLogic->getDataType()))
{
- auto origPtr = origLoad->getPtr();
- auto primalPtr = lookupPrimalInst(origPtr, nullptr);
- auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
+ // Boolean operations are not differentiable. For the linearization
+ // pass, we do not need to do anything but copy them over to the ne
+ // function.
+ auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
+ return InstPair(primalLogic, nullptr);
+ }
+
+ SLANG_UNEXPECTED("Logical operation with non-boolean result");
+}
- if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
- {
- // Special case load from an `out` param, which will not have corresponding `diff` and
- // `primal` insts yet.
-
- auto load = builder->emitLoad(primalPtr);
- auto primalElement = builder->emitDifferentialPairGetPrimal(load);
- auto diffElement = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
- return InstPair(primalElement, diffElement);
- }
+InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
+{
+ auto origPtr = origLoad->getPtr();
+ auto primalPtr = lookupPrimalInst(origPtr, nullptr);
+ auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
- auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
- IRInst* diffLoad = nullptr;
- if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
- {
- // Default case, we're loading from a known differential inst.
- diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
- }
- return InstPair(primalLoad, diffLoad);
+ if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
+ {
+ // Special case load from an `out` param, which will not have corresponding `diff` and
+ // `primal` insts yet.
+
+ auto load = builder->emitLoad(primalPtr);
+ auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ return InstPair(primalElement, diffElement);
}
- InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
+ auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
+ if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
- IRInst* origStoreLocation = origStore->getPtr();
- IRInst* origStoreVal = origStore->getVal();
- auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
- auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
- auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
- auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+ // Default case, we're loading from a known differential inst.
+ diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
+ }
+ return InstPair(primalLoad, diffLoad);
+}
- if (!diffStoreLocation)
+InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRStore* origStore)
+{
+ IRInst* origStoreLocation = origStore->getPtr();
+ IRInst* origStoreVal = origStore->getVal();
+ auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
+ auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
+ auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
+ auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+
+ if (!diffStoreLocation)
+ {
+ auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType());
+ if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
{
- auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType());
- if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
- {
- auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
- auto store = builder->emitStore(primalStoreLocation, valToStore);
- return InstPair(store, nullptr);
- }
+ auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
+ auto store = builder->emitStore(primalStoreLocation, valToStore);
+ return InstPair(store, nullptr);
}
+ }
- auto primalStore = cloneInst(&cloneEnv, builder, origStore);
-
- IRInst* diffStore = nullptr;
+ auto primalStore = cloneInst(&cloneEnv, builder, origStore);
- // If the stored value has a differential version,
- // emit a store instruction for the differential parameter.
- // Otherwise, emit nothing since there's nothing to load.
- //
- if (diffStoreLocation && diffStoreVal)
- {
- // Default case, storing the entire type (and not a member)
- diffStore = as<IRStore>(
- builder->emitStore(diffStoreLocation, diffStoreVal));
-
- return InstPair(primalStore, diffStore);
- }
+ IRInst* diffStore = nullptr;
- return InstPair(primalStore, nullptr);
+ // If the stored value has a differential version,
+ // emit a store instruction for the differential parameter.
+ // Otherwise, emit nothing since there's nothing to load.
+ //
+ if (diffStoreLocation && diffStoreVal)
+ {
+ // Default case, storing the entire type (and not a member)
+ diffStore = as<IRStore>(
+ builder->emitStore(diffStoreLocation, diffStoreVal));
+
+ return InstPair(primalStore, diffStore);
}
- InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
+ return InstPair(primalStore, nullptr);
+}
+
+InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
+{
+ IRInst* origReturnVal = origReturn->getVal();
+
+ auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
+ if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
{
- IRInst* origReturnVal = origReturn->getVal();
+ // If the return value is itself a function, generic or a struct then this
+ // is likely to be a generic scope. In this case, we lookup the differential
+ // and return that.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+
+ // Neither of these should be nullptr.
+ SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
+ IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
- auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
- if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
- {
- // If the return value is itself a function, generic or a struct then this
- // is likely to be a generic scope. In this case, we lookup the differential
- // and return that.
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
- IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
-
- // Neither of these should be nullptr.
- SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
- IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
-
- return InstPair(diffReturn, diffReturn);
- }
- else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
- {
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
- IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
- if(!diffReturnVal)
- diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
-
- // If the pair type can be formed, this must be non-null.
- SLANG_RELEASE_ASSERT(diffReturnVal);
-
- auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
- IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
- return InstPair(pairReturn, pairReturn);
- }
- else
- {
- // If the return type is not differentiable, emit the primal value only.
- IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ return InstPair(diffReturn, diffReturn);
+ }
+ else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
+ {
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+ if(!diffReturnVal)
+ diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
- IRInst* primalReturn = builder->emitReturn(primalReturnVal);
- return InstPair(primalReturn, nullptr);
-
- }
+ // If the pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffReturnVal);
+
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
+ IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
+ return InstPair(pairReturn, pairReturn);
}
+ else
+ {
+ // If the return type is not differentiable, emit the primal value only.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
- // Since int/float literals are sometimes nested inside an IRConstructor
- // instruction, we check to make sure that the nested instr is a constant
- // and then return nullptr. Literals do not need to be differentiated.
- //
- InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
- {
- IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+ IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ return InstPair(primalReturn, nullptr);
- // Check if the output type can be differentiated. If it cannot be
- // differentiated, don't differentiate the inst
- //
- auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
- if (auto diffConstructType = differentiateType(builder, primalConstructType))
- {
- UCount operandCount = origConstruct->getOperandCount();
-
- List<IRInst*> diffOperands;
- for (UIndex ii = 0; ii < operandCount; ii++)
- {
- // If the operand has a differential version, replace the original with
- // the differential. Otherwise, use a zero.
- //
- if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
- diffOperands.add(diffInst);
- else
- {
- auto operandDataType = origConstruct->getOperand(ii)->getDataType();
- operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
- diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
- }
- }
-
- return InstPair(
- primalConstruct,
- builder->emitIntrinsicInst(
- diffConstructType,
- origConstruct->getOp(),
- operandCount,
- diffOperands.getBuffer()));
- }
- else
- {
- return InstPair(primalConstruct, nullptr);
- }
}
+}
- // Differentiating a call instruction here is primarily about generating
- // an appropriate call list based on whichever parameters have differentials
- // in the current transcription context.
+// Since int/float literals are sometimes nested inside an IRConstructor
+// instruction, we check to make sure that the nested instr is a constant
+// and then return nullptr. Literals do not need to be differentiated.
+//
+InstPair ForwardDerivativeTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
+{
+ IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+
+ // Check if the output type can be differentiated. If it cannot be
+ // differentiated, don't differentiate the inst
//
- InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
- {
-
- IRInst* origCallee = origCall->getCallee();
+ auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
+ if (auto diffConstructType = differentiateType(builder, primalConstructType))
+ {
+ UCount operandCount = origConstruct->getOperandCount();
- if (!origCallee)
+ List<IRInst*> diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
{
- // Note that this can only happen if the callee is a result
- // of a higher-order operation. For now, we assume that we cannot
- // differentiate such calls safely.
- // TODO(sai): Should probably get checked in the front-end.
- //
- getSink()->diagnose(origCall->sourceLoc,
- Diagnostics::internalCompilerError,
- "attempting to differentiate unresolved callee");
-
- return InstPair(nullptr, nullptr);
+ // If the operand has a differential version, replace the original with
+ // the differential. Otherwise, use a zero.
+ //
+ if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ {
+ auto operandDataType = origConstruct->getOperand(ii)->getDataType();
+ operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
}
+
+ return InstPair(
+ primalConstruct,
+ builder->emitIntrinsicInst(
+ diffConstructType,
+ origConstruct->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+ }
+ else
+ {
+ return InstPair(primalConstruct, nullptr);
+ }
+}
- // Since concrete functions are globals, the primal callee is the same
- // as the original callee.
+// Differentiating a call instruction here is primarily about generating
+// an appropriate call list based on whichever parameters have differentials
+// in the current transcription context.
+//
+InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall* origCall)
+{
+
+ IRInst* origCallee = origCall->getCallee();
+
+ if (!origCallee)
+ {
+ // Note that this can only happen if the callee is a result
+ // of a higher-order operation. For now, we assume that we cannot
+ // differentiate such calls safely.
+ // TODO(sai): Should probably get checked in the front-end.
//
- auto primalCallee = origCallee;
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "attempting to differentiate unresolved callee");
+
+ return InstPair(nullptr, nullptr);
+ }
- IRInst* diffCallee = nullptr;
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
- if (instMapD.TryGetValue(origCallee, diffCallee))
- {
- }
- else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
- {
- // If the user has already provided an differentiated implementation, use that.
- diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
- }
- else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>())
- {
- // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
- // to generate the implementation.
- diffCallee = builder->emitForwardDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
- primalCallee);
- }
- else
- {
- // The callee is non differentiable, just return primal value with null diff value.
- IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
- return InstPair(primalCall, nullptr);
- }
+ IRInst* diffCallee = nullptr;
- List<IRInst*> args;
- // Go over the parameter list and create pairs for each input (if required)
- for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
- {
- auto origArg = origCall->getArg(ii);
- auto primalArg = findOrTranscribePrimalInst(builder, origArg);
- SLANG_ASSERT(primalArg);
+ if (instMapD.TryGetValue(origCallee, diffCallee))
+ {
+ }
+ else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc();
+ }
+ else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitForwardDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
+ List<IRInst*> args;
+ // Go over the parameter list and create pairs for each input (if required)
+ for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
+ {
+ auto origArg = origCall->getArg(ii);
+ auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ SLANG_ASSERT(primalArg);
auto primalType = primalArg->getDataType();
if (auto pairType = tryGetDiffPairType(builder, primalType))
@@ -762,218 +724,218 @@ struct JVPTranscriber
IRType* diffReturnType = nullptr;
diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
- if (!diffReturnType)
- {
- SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
- diffReturnType = builder->getVoidType();
- }
+ if (!diffReturnType)
+ {
+ SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
+ diffReturnType = builder->getVoidType();
+ }
- auto callInst = builder->emitCallInst(
- diffReturnType,
- diffCallee,
- args);
+ auto callInst = builder->emitCallInst(
+ diffReturnType,
+ diffCallee,
+ args);
- if (diffReturnType->getOp() != kIROp_VoidType)
- {
- IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
- auto diffType = differentiateType(builder, origCall->getFullType());
- IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
- return InstPair(primalResultValue, diffResultValue);
- }
- else
- {
- // Return the inst itself if the return value is void.
- // This is fine since these values should never actually be used anywhere.
- //
- return InstPair(callInst, callInst);
- }
+ if (diffReturnType->getOp() != kIROp_VoidType)
+ {
+ IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
+ auto diffType = differentiateType(builder, origCall->getFullType());
+ IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
+ return InstPair(primalResultValue, diffResultValue);
}
-
- InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
+ else
{
- IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
+ // Return the inst itself if the return value is void.
+ // This is fine since these values should never actually be used anywhere.
+ //
+ return InstPair(callInst, callInst);
+ }
+}
- if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
- {
- List<IRInst*> swizzleIndices;
- for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
- swizzleIndices.add(origSwizzle->getElementIndex(ii));
-
- return InstPair(
- primalSwizzle,
- builder->emitSwizzle(
- differentiateType(builder, primalSwizzle->getDataType()),
- diffBase,
- origSwizzle->getElementCount(),
- swizzleIndices.getBuffer()));
- }
+InstPair ForwardDerivativeTranscriber::transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
+{
+ IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
- return InstPair(primalSwizzle, nullptr);
+ if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
+ {
+ List<IRInst*> swizzleIndices;
+ for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+ swizzleIndices.add(origSwizzle->getElementIndex(ii));
+
+ return InstPair(
+ primalSwizzle,
+ builder->emitSwizzle(
+ differentiateType(builder, primalSwizzle->getDataType()),
+ diffBase,
+ origSwizzle->getElementCount(),
+ swizzleIndices.getBuffer()));
}
- InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
- {
- IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
+ return InstPair(primalSwizzle, nullptr);
+}
- UCount operandCount = origInst->getOperandCount();
+InstPair ForwardDerivativeTranscriber::transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
+{
+ IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
- List<IRInst*> diffOperands;
- for (UIndex ii = 0; ii < operandCount; ii++)
- {
- // If the operand has a differential version, replace the original with the
- // differential.
- // Otherwise, abandon the differentiation attempt and assume that origInst
- // cannot (or does not need to) be differentiated.
- //
- if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
- diffOperands.add(diffInst);
- else
- return InstPair(primalInst, nullptr);
- }
-
- return InstPair(
- primalInst,
- builder->emitIntrinsicInst(
- differentiateType(builder, primalInst->getDataType()),
- origInst->getOp(),
- operandCount,
- diffOperands.getBuffer()));
+ UCount operandCount = origInst->getOperandCount();
+
+ List<IRInst*> diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with the
+ // differential.
+ // Otherwise, abandon the differentiation attempt and assume that origInst
+ // cannot (or does not need to) be differentiated.
+ //
+ if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ return InstPair(primalInst, nullptr);
}
+
+ return InstPair(
+ primalInst,
+ builder->emitIntrinsicInst(
+ differentiateType(builder, primalInst->getDataType()),
+ origInst->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+}
- InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDerivativeTranscriber::transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
+{
+ switch(origInst->getOp())
{
- switch(origInst->getOp())
- {
- case kIROp_unconditionalBranch:
- case kIROp_loop:
- auto origBranch = as<IRUnconditionalBranch>(origInst);
+ case kIROp_unconditionalBranch:
+ case kIROp_loop:
+ auto origBranch = as<IRUnconditionalBranch>(origInst);
- // Grab the differentials for any phi nodes.
- List<IRInst*> newArgs;
- for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
- {
- auto origArg = origBranch->getArg(ii);
- auto primalArg = lookupPrimalInst(origArg);
- newArgs.add(primalArg);
+ // Grab the differentials for any phi nodes.
+ List<IRInst*> newArgs;
+ for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
+ {
+ auto origArg = origBranch->getArg(ii);
+ auto primalArg = lookupPrimalInst(origArg);
+ newArgs.add(primalArg);
- if (differentiateType(builder, primalArg->getDataType()))
- {
- auto diffArg = lookupDiffInst(origArg, nullptr);
- if (diffArg)
- newArgs.add(diffArg);
- }
+ if (differentiateType(builder, primalArg->getDataType()))
+ {
+ auto diffArg = lookupDiffInst(origArg, nullptr);
+ if (diffArg)
+ newArgs.add(diffArg);
}
+ }
- IRInst* diffBranch = nullptr;
- if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
+ IRInst* diffBranch = nullptr;
+ if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
+ {
+ if (auto origLoop = as<IRLoop>(origInst))
{
- if (auto origLoop = as<IRLoop>(origInst))
- {
- auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
- auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
- List<IRInst*> operands;
- operands.add(breakBlock);
- operands.add(continueBlock);
- operands.addRange(newArgs);
- diffBranch = builder->emitIntrinsicInst(
- nullptr,
- kIROp_loop,
- operands.getCount(),
- operands.getBuffer());
- }
- else
- {
- diffBranch = builder->emitBranch(
- as<IRBlock>(diffBlock),
- newArgs.getCount(),
- newArgs.getBuffer());
- }
+ auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+ auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+ List<IRInst*> operands;
+ operands.add(breakBlock);
+ operands.add(continueBlock);
+ operands.addRange(newArgs);
+ diffBranch = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ operands.getCount(),
+ operands.getBuffer());
}
+ else
+ {
+ diffBranch = builder->emitBranch(
+ as<IRBlock>(diffBlock),
+ newArgs.getCount(),
+ newArgs.getBuffer());
+ }
+ }
- // For now, every block in the original fn must have a corresponding
- // block to compute *both* primals and derivatives (i.e linearized block)
- SLANG_ASSERT(diffBranch);
+ // For now, every block in the original fn must have a corresponding
+ // block to compute *both* primals and derivatives (i.e linearized block)
+ SLANG_ASSERT(diffBranch);
- return InstPair(diffBranch, diffBranch);
- }
+ return InstPair(diffBranch, diffBranch);
+ }
- getSink()->diagnose(
- origInst->sourceLoc,
- Diagnostics::unimplemented,
- "attempting to differentiate unhandled control flow");
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled control flow");
- return InstPair(nullptr, nullptr);
- }
+ return InstPair(nullptr, nullptr);
+}
- InstPair transcribeConst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDerivativeTranscriber::transcribeConst(IRBuilder* builder, IRInst* origInst)
+{
+ switch(origInst->getOp())
{
- switch(origInst->getOp())
- {
- case kIROp_FloatLit:
- return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
- case kIROp_VoidLit:
- return InstPair(origInst, origInst);
- case kIROp_IntLit:
- return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
- }
+ case kIROp_FloatLit:
+ return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
+ case kIROp_VoidLit:
+ return InstPair(origInst, origInst);
+ case kIROp_IntLit:
+ return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
+ }
- getSink()->diagnose(
- origInst->sourceLoc,
- Diagnostics::unimplemented,
- "attempting to differentiate unhandled const type");
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled const type");
- return InstPair(nullptr, nullptr);
- }
+ return InstPair(nullptr, nullptr);
+}
- IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
+IRInst* ForwardDerivativeTranscriber::findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
+{
+ for (UInt i = 0; i < type->getOperandCount(); i++)
{
- for (UInt i = 0; i < type->getOperandCount(); i++)
+ if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
{
- if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i)))
- {
- if (req->getRequirementKey() == key)
- return req->getRequirementVal();
- }
+ if (req->getRequirementKey() == key)
+ return req->getRequirementVal();
}
- return nullptr;
}
+ return nullptr;
+}
- InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+{
+ auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
+ List<IRInst*> primalArgs;
+ for (UInt i = 0; i < origSpecialize->getArgCount(); i++)
{
- auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase());
- List<IRInst*> primalArgs;
- for (UInt i = 0; i < origSpecialize->getArgCount(); i++)
- {
- primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i)));
- }
- auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
- auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
- (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());
+ primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i)));
+ }
+ auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType());
+ auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst(
+ (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer());
- IRInst* diffBase = nullptr;
- if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
+ IRInst* diffBase = nullptr;
+ if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase))
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
{
- List<IRInst*> args;
- for (UInt i = 0; i < primalSpecialize->getArgCount(); i++)
- {
- args.add(primalSpecialize->getArg(i));
- }
- auto diffSpecialize = builder->emitSpecializeInst(
- builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
- return InstPair(primalSpecialize, diffSpecialize);
+ args.add(primalSpecialize->getArg(i));
}
+ auto diffSpecialize = builder->emitSpecializeInst(
+ builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer());
+ return InstPair(primalSpecialize, diffSpecialize);
+ }
- auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
- // Look for an IRForwardDerivativeDecoration on the specialize inst.
- // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
- // can be specialized, so we put a decoration on the IRSpecialize)
- //
- if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
- {
- auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
+ auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+ // Look for an IRForwardDerivativeDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
+ //
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>())
+ {
+ auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc();
- // Make sure this isn't itself a specialize .
- SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
return InstPair(primalSpecialize, jvpFunc);
}
@@ -1007,611 +969,610 @@ struct JVPTranscriber
}
}
- InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
- {
- auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable());
- auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey());
- auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
- auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
+InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst)
+{
+ auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable());
+ auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey());
+ auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType());
+ auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey);
- auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType());
- if (!interfaceType)
- {
- return InstPair(primal, nullptr);
- }
- auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
- if (!dict)
- {
- return InstPair(primal, nullptr);
- }
+ auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType());
+ if (!interfaceType)
+ {
+ return InstPair(primal, nullptr);
+ }
+ auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>();
+ if (!dict)
+ {
+ return InstPair(primal, nullptr);
+ }
- for (auto child : dict->getChildren())
+ for (auto child : dict->getChildren())
+ {
+ if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child))
{
- if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child))
+ if (item->getOperand(0) == lookupInst->getRequirementKey())
{
- if (item->getOperand(0) == lookupInst->getRequirementKey())
+ auto diffKey = item->getOperand(1);
+ if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
{
- auto diffKey = item->getOperand(1);
- if (auto diffType = findInterfaceRequirement(interfaceType, diffKey))
- {
- auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
- return InstPair(primal, diff);
- }
- break;
+ auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey);
+ return InstPair(primal, diff);
}
+ break;
}
}
- return InstPair(primal, nullptr);
}
+ return InstPair(primal, nullptr);
+}
- // In differential computation, the 'default' differential value is always zero.
- // This is a consequence of differential computing being inherently linear. As a
- // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
- // The current implementation requires that types are defined linearly.
- //
- IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+// In differential computation, the 'default' differential value is always zero.
+// This is a consequence of differential computing being inherently linear. As a
+// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+// The current implementation requires that types are defined linearly.
+//
+IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+{
+ if (auto diffType = differentiateType(builder, primalType))
{
- if (auto diffType = differentiateType(builder, primalType))
+ switch (diffType->getOp())
{
- switch (diffType->getOp())
- {
- case kIROp_DifferentialPairType:
- return builder->emitMakeDifferentialPair(
- diffType,
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
- getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
- }
- // Since primalType has a corresponding differential type, we can lookup the
- // definition for zero().
- auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
- SLANG_ASSERT(zeroMethod);
-
- auto emptyArgList = List<IRInst*>();
- return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
- }
- else
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
{
- if (isScalarIntegerType(primalType))
- {
- return builder->getIntValue(primalType, 0);
- }
-
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
+ return builder->getIntValue(primalType, 0);
}
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
}
+}
- InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
- {
- IRBuilder subBuilder(builder->getSharedBuilder());
- subBuilder.setInsertLoc(builder->getInsertLoc());
-
- IRInst* diffBlock = subBuilder.emitBlock();
-
- // Note: for blocks, we setup the mapping _before_
- // processing the children since we could encounter
- // a lookup while processing the children.
- //
- mapPrimalInst(origBlock, diffBlock);
- mapDifferentialInst(origBlock, diffBlock);
+InstPair ForwardDerivativeTranscriber::transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+{
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRInst* diffBlock = subBuilder.emitBlock();
+
+ // Note: for blocks, we setup the mapping _before_
+ // processing the children since we could encounter
+ // a lookup while processing the children.
+ //
+ mapPrimalInst(origBlock, diffBlock);
+ mapDifferentialInst(origBlock, diffBlock);
- subBuilder.setInsertInto(diffBlock);
+ subBuilder.setInsertInto(diffBlock);
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- this->transcribe(&subBuilder, param);
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->transcribe(&subBuilder, param);
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- this->transcribe(&subBuilder, child);
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->transcribe(&subBuilder, child);
- return InstPair(diffBlock, diffBlock);
- }
+ return InstPair(diffBlock, diffBlock);
+}
- InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
+InstPair ForwardDerivativeTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
+{
+ SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));
+
+ IRInst* origBase = originalInst->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto field = originalInst->getOperand(1);
+ auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
+ auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
+
+ IRInst* primalOperands[] = { primalBase, field };
+ IRInst* primalFieldExtract = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 2,
+ primalOperands);
+
+ if (!derivativeRefDecor)
{
- SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));
-
- IRInst* origBase = originalInst->getOperand(0);
- auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto field = originalInst->getOperand(1);
- auto derivativeRefDecor = field->findDecoration<IRDerivativeMemberDecoration>();
- auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
+ return InstPair(primalFieldExtract, nullptr);
+ }
- IRInst* primalOperands[] = { primalBase, field };
- IRInst* primalFieldExtract = builder->emitIntrinsicInst(
- primalType,
- originalInst->getOp(),
- 2,
- primalOperands);
+ IRInst* diffFieldExtract = nullptr;
- if (!derivativeRefDecor)
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
- return InstPair(primalFieldExtract, nullptr);
+ IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
+ diffFieldExtract = builder->emitIntrinsicInst(
+ diffType,
+ originalInst->getOp(),
+ 2,
+ diffOperands);
}
-
- IRInst* diffFieldExtract = nullptr;
-
- if (auto diffType = differentiateType(builder, primalType))
- {
- if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
- {
- IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
- diffFieldExtract = builder->emitIntrinsicInst(
- diffType,
- originalInst->getOp(),
- 2,
- diffOperands);
- }
- }
- return InstPair(primalFieldExtract, diffFieldExtract);
}
+ return InstPair(primalFieldExtract, diffFieldExtract);
+}
- InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
- {
- SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
+InstPair ForwardDerivativeTranscriber::transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
+{
+ SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
- IRInst* origBase = origGetElementPtr->getOperand(0);
- auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
+ IRInst* origBase = origGetElementPtr->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
- auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
+ auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
- IRInst* primalOperands[] = {primalBase, primalIndex};
- IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
- primalType,
- origGetElementPtr->getOp(),
- 2,
- primalOperands);
+ IRInst* primalOperands[] = {primalBase, primalIndex};
+ IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
+ primalType,
+ origGetElementPtr->getOp(),
+ 2,
+ primalOperands);
- IRInst* diffGetElementPtr = nullptr;
+ IRInst* diffGetElementPtr = nullptr;
- if (auto diffType = differentiateType(builder, primalType))
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
- if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
- {
- IRInst* diffOperands[] = {diffBase, primalIndex};
- diffGetElementPtr = builder->emitIntrinsicInst(
- diffType,
- origGetElementPtr->getOp(),
- 2,
- diffOperands);
- }
+ IRInst* diffOperands[] = {diffBase, primalIndex};
+ diffGetElementPtr = builder->emitIntrinsicInst(
+ diffType,
+ origGetElementPtr->getOp(),
+ 2,
+ diffOperands);
}
-
- return InstPair(primalGetElementPtr, diffGetElementPtr);
}
- InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
- {
- // The loop comes with three blocks.. we just need to transcribe each one
- // and assemble the new loop instruction.
-
- // Transcribe the target block (this is the 'condition' part of the loop, which
- // will branch into the loop body)
- auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
-
- // Transcribe the break block (this is the block after the exiting the loop)
- auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
-
- // Transcribe the continue block (this is the 'update' part of the loop, which will
- // branch into the condition block)
- auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+ return InstPair(primalGetElementPtr, diffGetElementPtr);
+}
-
- List<IRInst*> diffLoopOperands;
- diffLoopOperands.add(diffTargetBlock);
- diffLoopOperands.add(diffBreakBlock);
- diffLoopOperands.add(diffContinueBlock);
+InstPair ForwardDerivativeTranscriber::transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
+{
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body)
+ auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
- // If there are any other operands, use their primal versions.
- for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
- {
- auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
- diffLoopOperands.add(primalOperand);
- }
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
- IRInst* diffLoop = builder->emitIntrinsicInst(
- nullptr,
- kIROp_loop,
- diffLoopOperands.getCount(),
- diffLoopOperands.getBuffer());
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
- return InstPair(diffLoop, diffLoop);
- }
+
+ List<IRInst*> diffLoopOperands;
+ diffLoopOperands.add(diffTargetBlock);
+ diffLoopOperands.add(diffBreakBlock);
+ diffLoopOperands.add(diffContinueBlock);
- InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
{
- // IfElse Statements come with 4 blocks. We transcribe each block into it's
- // linear form, and then wire them up in the same way as the original if-else
-
- // Transcribe condition block
- auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
- SLANG_ASSERT(primalConditionBlock);
-
- // Transcribe 'true' block (condition block branches into this if true)
- auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
- SLANG_ASSERT(diffTrueBlock);
-
- // Transcribe 'false' block (condition block branches into this if true)
- // TODO (sai): What happens if there's no false block?
- auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
- SLANG_ASSERT(diffFalseBlock);
-
- // Transcribe 'after' block (true and false blocks branch into this)
- auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
- SLANG_ASSERT(diffAfterBlock);
-
- List<IRInst*> diffIfElseArgs;
- diffIfElseArgs.add(primalConditionBlock);
- diffIfElseArgs.add(diffTrueBlock);
- diffIfElseArgs.add(diffFalseBlock);
- diffIfElseArgs.add(diffAfterBlock);
-
- // If there are any other operands, use their primal versions.
- for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
- {
- auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
- diffIfElseArgs.add(primalOperand);
- }
-
- IRInst* diffLoop = builder->emitIntrinsicInst(
- nullptr,
- kIROp_ifElse,
- diffIfElseArgs.getCount(),
- diffIfElseArgs.getBuffer());
-
- return InstPair(diffLoop, diffLoop);
+ auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
+ diffLoopOperands.add(primalOperand);
}
- InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
- {
- auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
- SLANG_ASSERT(primalVal);
- auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
- SLANG_ASSERT(diffPrimalVal);
- auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
- SLANG_ASSERT(primalDiffVal);
- auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
- SLANG_ASSERT(diffDiffVal);
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ diffLoopOperands.getCount(),
+ diffLoopOperands.getBuffer());
- auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
- auto diffPair = builder->emitMakeDifferentialPair(
- differentiateType(builder, origInst->getDataType()),
- primalDiffVal,
- diffDiffVal);
- return InstPair(primalPair, diffPair);
- }
+ return InstPair(diffLoop, diffLoop);
+}
- InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDerivativeTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+{
+ // IfElse Statements come with 4 blocks. We transcribe each block into it's
+ // linear form, and then wire them up in the same way as the original if-else
+
+ // Transcribe condition block
+ auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
+ SLANG_ASSERT(primalConditionBlock);
+
+ // Transcribe 'true' block (condition block branches into this if true)
+ auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
+ SLANG_ASSERT(diffTrueBlock);
+
+ // Transcribe 'false' block (condition block branches into this if true)
+ // TODO (sai): What happens if there's no false block?
+ auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
+ SLANG_ASSERT(diffFalseBlock);
+
+ // Transcribe 'after' block (true and false blocks branch into this)
+ auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
+ SLANG_ASSERT(diffAfterBlock);
+
+ List<IRInst*> diffIfElseArgs;
+ diffIfElseArgs.add(primalConditionBlock);
+ diffIfElseArgs.add(diffTrueBlock);
+ diffIfElseArgs.add(diffFalseBlock);
+ diffIfElseArgs.add(diffAfterBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
{
- auto primal = cloneInst(&cloneEnv, builder, origInst);
- return InstPair(primal, nullptr);
+ auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
+ diffIfElseArgs.add(primalOperand);
}
- InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
- {
- SLANG_ASSERT(
- origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
- origInst->getOp() == kIROp_DifferentialPairGetPrimal);
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_ifElse,
+ diffIfElseArgs.getCount(),
+ diffIfElseArgs.getBuffer());
- auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
- SLANG_ASSERT(primalVal);
-
- auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
- SLANG_ASSERT(diffVal);
+ return InstPair(diffLoop, diffLoop);
+}
- auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+InstPair ForwardDerivativeTranscriber::transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
+{
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalVal);
+ auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffPrimalVal);
+ auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalDiffVal);
+ auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffDiffVal);
+
+ auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
+ auto diffPair = builder->emitMakeDifferentialPair(
+ differentiateType(builder, origInst->getDataType()),
+ primalDiffVal,
+ diffDiffVal);
+ return InstPair(primalPair, diffPair);
+}
- auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
- IRInst* diffResultType = nullptr;
- if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
- diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
- else
- diffResultType = diffValPairType->getValueType();
- auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
- return InstPair(primalResult, diffResult);
- }
+InstPair ForwardDerivativeTranscriber::trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst)
+{
+ auto primal = cloneInst(&cloneEnv, builder, origInst);
+ return InstPair(primal, nullptr);
+}
- // Create an empty func to represent the transcribed func of `origFunc`.
- InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
- {
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origFunc);
+InstPair ForwardDerivativeTranscriber::transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
+{
+ SLANG_ASSERT(
+ origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
+ origInst->getOp() == kIROp_DifferentialPairGetPrimal);
+
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(primalVal);
+
+ auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(diffVal);
+
+ auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+
+ auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
+ IRInst* diffResultType = nullptr;
+ if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
+ diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
+ else
+ diffResultType = diffValPairType->getValueType();
+ auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
+ return InstPair(primalResult, diffResult);
+}
- IRFunc* primalFunc = origFunc;
+// Create an empty func to represent the transcribed func of `origFunc`.
+InstPair ForwardDerivativeTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+{
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
- differentiableTypeConformanceContext.setFunc(origFunc);
+ IRFunc* primalFunc = origFunc;
- primalFunc = origFunc;
+ differentiableTypeConformanceContext.setFunc(origFunc);
- auto diffFunc = builder.createFunc();
+ primalFunc = origFunc;
- SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
- IRType* diffFuncType = this->differentiateFunctionType(
- &builder,
- as<IRFuncType>(origFunc->getFullType()));
- diffFunc->setFullType(diffFuncType);
+ auto diffFunc = builder.createFunc();
- if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
- {
- auto originalName = nameHint->getName();
- StringBuilder newNameSb;
- newNameSb << "s_fwd_" << originalName;
- builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
- }
- builder.addForwardDerivativeDecoration(origFunc, diffFunc);
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
- // Mark the generated derivative function itself as differentiable.
- builder.addForwardDifferentiableDecoration(diffFunc);
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_fwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addForwardDerivativeDecoration(origFunc, diffFunc);
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- cloneDecoration(dictDecor, diffFunc);
- }
+ // Mark the generated derivative function itself as differentiable.
+ builder.addForwardDifferentiableDecoration(diffFunc);
- auto result = InstPair(primalFunc, diffFunc);
- followUpFunctionsToTranscribe.add(result);
- return result;
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
}
- // Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
- {
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertInto(diffFunc);
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+}
- differentiableTypeConformanceContext.setFunc(primalFunc);
- // Transcribe children from origFunc into diffFunc
- for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribe(&builder, block);
+// Transcribe a function definition.
+InstPair ForwardDerivativeTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+{
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
- return InstPair(primalFunc, diffFunc);
- }
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+}
- // Transcribe a generic definition
- InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
+// Transcribe a generic definition
+InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric)
+{
+ auto innerVal = findInnerMostGenericReturnVal(origGeneric);
+ if (auto innerFunc = as<IRFunc>(innerVal))
{
- auto innerVal = findInnerMostGenericReturnVal(origGeneric);
- if (auto innerFunc = as<IRFunc>(innerVal))
- {
- differentiableTypeConformanceContext.setFunc(innerFunc);
- }
- else if (auto funcType = as<IRFuncType>(innerVal))
- {
- }
- else
- {
- return InstPair(origGeneric, nullptr);
- }
+ differentiableTypeConformanceContext.setFunc(innerFunc);
+ }
+ else if (auto funcType = as<IRFuncType>(innerVal))
+ {
+ }
+ else
+ {
+ return InstPair(origGeneric, nullptr);
+ }
- IRGeneric* primalGeneric = origGeneric;
+ IRGeneric* primalGeneric = origGeneric;
- IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertBefore(origGeneric);
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origGeneric);
- auto diffGeneric = builder.emitGeneric();
+ auto diffGeneric = builder.emitGeneric();
- // Process type of generic. If the generic is a function, then it's type will also be a
- // generic and this logic will transcribe that generic first before continuing with the
- // function itself.
- //
- auto primalType = primalGeneric->getFullType();
+ // Process type of generic. If the generic is a function, then it's type will also be a
+ // generic and this logic will transcribe that generic first before continuing with the
+ // function itself.
+ //
+ auto primalType = primalGeneric->getFullType();
- IRType* diffType = nullptr;
- if (primalType)
- {
- diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType);
- }
+ IRType* diffType = nullptr;
+ if (primalType)
+ {
+ diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType);
+ }
- diffGeneric->setFullType(diffType);
+ diffGeneric->setFullType(diffType);
// Transcribe children from origFunc into diffFunc.
builder.setInsertInto(diffGeneric);
for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
this->transcribe(&builder, block);
- return InstPair(primalGeneric, diffGeneric);
- }
+ return InstPair(primalGeneric, diffGeneric);
+}
- IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
+IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* origInst)
+{
+ // If a differential intstruction is already mapped for
+ // this original inst, return that.
+ //
+ if (auto diffInst = lookupDiffInst(origInst, nullptr))
{
- // If a differential intstruction is already mapped for
- // this original inst, return that.
- //
- if (auto diffInst = lookupDiffInst(origInst, nullptr))
- {
- SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
- return diffInst;
- }
+ SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ return diffInst;
+ }
- // Otherwise, dispatch to the appropriate method
- // depending on the op-code.
- //
- instsInProgress.Add(origInst);
- InstPair pair = transcribeInst(builder, origInst);
- instsInProgress.Remove(origInst);
+ // Otherwise, dispatch to the appropriate method
+ // depending on the op-code.
+ //
+ instsInProgress.Add(origInst);
+ InstPair pair = transcribeInst(builder, origInst);
+ instsInProgress.Remove(origInst);
- if (auto primalInst = pair.primal)
+ if (auto primalInst = pair.primal)
+ {
+ mapPrimalInst(origInst, pair.primal);
+ mapDifferentialInst(origInst, pair.differential);
+ if (pair.differential)
{
- mapPrimalInst(origInst, pair.primal);
- mapDifferentialInst(origInst, pair.differential);
- if (pair.differential)
+ switch (pair.differential->getOp())
{
- switch (pair.differential->getOp())
+ case kIROp_Func:
+ case kIROp_Generic:
+ case kIROp_Block:
+ // Don't generate again for these.
+ // Functions already have their names generated in `transcribeFuncHeader`.
+ break;
+ default:
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
{
- case kIROp_Func:
- case kIROp_Generic:
- case kIROp_Block:
- // Don't generate again for these.
- // Functions already have their names generated in `transcribeFuncHeader`.
- break;
- default:
- // Generate name hint for the inst.
- if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
- {
- StringBuilder sb;
- sb << "s_diff_" << primalNameHint->getName();
- builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
- }
- break;
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
}
+ break;
}
- return pair.differential;
}
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "failed to transcibe instruction");
- return nullptr;
+ return pair.differential;
}
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "failed to transcibe instruction");
+ return nullptr;
+}
- InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
+InstPair ForwardDerivativeTranscriber::transcribeInst(IRBuilder* builder, IRInst* origInst)
+{
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
{
- // Handle common SSA-style operations
- switch (origInst->getOp())
- {
- case kIROp_Param:
- return transcribeParam(builder, as<IRParam>(origInst));
-
- case kIROp_Var:
- return transcribeVar(builder, as<IRVar>(origInst));
-
- case kIROp_Load:
- return transcribeLoad(builder, as<IRLoad>(origInst));
-
- case kIROp_Store:
- return transcribeStore(builder, as<IRStore>(origInst));
+ case kIROp_Param:
+ return transcribeParam(builder, as<IRParam>(origInst));
- case kIROp_Return:
- return transcribeReturn(builder, as<IRReturn>(origInst));
-
- case kIROp_Add:
- case kIROp_Mul:
- case kIROp_Sub:
- case kIROp_Div:
- return transcribeBinaryArith(builder, origInst);
-
- case kIROp_Less:
- case kIROp_Greater:
- case kIROp_And:
- case kIROp_Or:
- case kIROp_Geq:
- case kIROp_Leq:
- return transcribeBinaryLogic(builder, origInst);
-
- case kIROp_Construct:
- return transcribeConstruct(builder, origInst);
-
- case kIROp_lookup_interface_method:
- return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
-
- case kIROp_Call:
- return transcribeCall(builder, as<IRCall>(origInst));
-
- case kIROp_swizzle:
- return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
-
- case kIROp_constructVectorFromScalar:
- case kIROp_MakeTuple:
- return transcribeByPassthrough(builder, origInst);
+ case kIROp_Var:
+ return transcribeVar(builder, as<IRVar>(origInst));
- case kIROp_unconditionalBranch:
- return transcribeControlFlow(builder, origInst);
+ case kIROp_Load:
+ return transcribeLoad(builder, as<IRLoad>(origInst));
- case kIROp_FloatLit:
- case kIROp_IntLit:
- case kIROp_VoidLit:
- return transcribeConst(builder, origInst);
+ case kIROp_Store:
+ return transcribeStore(builder, as<IRStore>(origInst));
- case kIROp_Specialize:
- return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
+ case kIROp_Return:
+ return transcribeReturn(builder, as<IRReturn>(origInst));
- case kIROp_FieldExtract:
- case kIROp_FieldAddress:
- return transcribeFieldExtract(builder, origInst);
- case kIROp_getElement:
- case kIROp_getElementPtr:
- return transcribeGetElement(builder, origInst);
-
- case kIROp_loop:
- return transcribeLoop(builder, as<IRLoop>(origInst));
-
- case kIROp_ifElse:
- return transcribeIfElse(builder, as<IRIfElse>(origInst));
-
- case kIROp_MakeDifferentialPair:
- return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
- case kIROp_DifferentialPairGetPrimal:
- case kIROp_DifferentialPairGetDifferential:
- return transcribeDifferentialPairGetElement(builder, origInst);
- case kIROp_ExtractExistentialWitnessTable:
- case kIROp_ExtractExistentialType:
- case kIROp_ExtractExistentialValue:
- case kIROp_WrapExistential:
- case kIROp_MakeExistential:
- case kIROp_MakeExistentialWithRTTI:
- return trascribeNonDiffInst(builder, origInst);
- case kIROp_StructKey:
- return InstPair(origInst, nullptr);
- }
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArith(builder, origInst);
- // If none of the cases have been hit, check if the instruction is a
- // type. Only need to explicitly differentiate types if they appear inside a block.
- //
- if (auto origType = as<IRType>(origInst))
- {
- // If this is a generic type, transcibe the parent
- // generic and derive the type from the transcribed generic's
- // return value.
- //
- if (as<IRGeneric>(origType->getParent()->getParent()) &&
- findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
- !instsInProgress.Contains(origType->getParent()->getParent()))
- {
- auto origGenericType = origType->getParent()->getParent();
- auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
- auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
- return InstPair(
- origGenericType,
- innerDiffGenericType
- );
- }
- else if (as<IRBlock>(origType->getParent()))
- return InstPair(
- cloneInst(&cloneEnv, builder, origType),
- differentiateType(builder, origType));
- else
- return InstPair(
- cloneInst(&cloneEnv, builder, origType),
- nullptr);
- }
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ return transcribeBinaryLogic(builder, origInst);
+
+ case kIROp_Construct:
+ return transcribeConstruct(builder, origInst);
+
+ case kIROp_lookup_interface_method:
+ return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
+
+ case kIROp_Call:
+ return transcribeCall(builder, as<IRCall>(origInst));
+
+ case kIROp_swizzle:
+ return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
+
+ case kIROp_constructVectorFromScalar:
+ case kIROp_MakeTuple:
+ return transcribeByPassthrough(builder, origInst);
+
+ case kIROp_unconditionalBranch:
+ return transcribeControlFlow(builder, origInst);
+
+ case kIROp_FloatLit:
+ case kIROp_IntLit:
+ case kIROp_VoidLit:
+ return transcribeConst(builder, origInst);
+
+ case kIROp_Specialize:
+ return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
+
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ return transcribeFieldExtract(builder, origInst);
+ case kIROp_getElement:
+ case kIROp_getElementPtr:
+ return transcribeGetElement(builder, origInst);
+
+ case kIROp_loop:
+ return transcribeLoop(builder, as<IRLoop>(origInst));
+
+ case kIROp_ifElse:
+ return transcribeIfElse(builder, as<IRIfElse>(origInst));
+
+ case kIROp_MakeDifferentialPair:
+ return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferential:
+ return transcribeDifferentialPairGetElement(builder, origInst);
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_WrapExistential:
+ case kIROp_MakeExistential:
+ case kIROp_MakeExistentialWithRTTI:
+ return trascribeNonDiffInst(builder, origInst);
+ case kIROp_StructKey:
+ return InstPair(origInst, nullptr);
+ }
- // Handle instructions with children
- switch (origInst->getOp())
+ // If none of the cases have been hit, check if the instruction is a
+ // type. Only need to explicitly differentiate types if they appear inside a block.
+ //
+ if (auto origType = as<IRType>(origInst))
+ {
+ // If this is a generic type, transcibe the parent
+ // generic and derive the type from the transcribed generic's
+ // return value.
+ //
+ if (as<IRGeneric>(origType->getParent()->getParent()) &&
+ findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
+ !instsInProgress.Contains(origType->getParent()->getParent()))
{
- case kIROp_Func:
- return transcribeFuncHeader(builder, as<IRFunc>(origInst));
-
- case kIROp_Block:
- return transcribeBlock(builder, as<IRBlock>(origInst));
-
- case kIROp_Generic:
- return transcribeGeneric(builder, as<IRGeneric>(origInst));
+ auto origGenericType = origType->getParent()->getParent();
+ auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
+ auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
+ return InstPair(
+ origGenericType,
+ innerDiffGenericType
+ );
}
+ else if (as<IRBlock>(origType->getParent()))
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ differentiateType(builder, origType));
+ else
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ nullptr);
+ }
- // If we reach this statement, the instruction type is likely unhandled.
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::unimplemented,
- "this instruction cannot be differentiated");
+ // Handle instructions with children
+ switch (origInst->getOp())
+ {
+ case kIROp_Func:
+ return transcribeFuncHeader(builder, as<IRFunc>(origInst));
- return InstPair(nullptr, nullptr);
+ case kIROp_Block:
+ return transcribeBlock(builder, as<IRBlock>(origInst));
+
+ case kIROp_Generic:
+ return transcribeGeneric(builder, as<IRGeneric>(origInst));
}
-};
+
+ // If we reach this statement, the instruction type is likely unhandled.
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "this instruction cannot be differentiated");
+
+ return InstPair(nullptr, nullptr);
+}
struct ForwardDerivativePass : public InstPassBase
{
@@ -1771,7 +1732,7 @@ protected:
// A transcriber object that handles the main job of
// processing instructions while maintaining state.
//
- JVPTranscriber transcriberStorage;
+ ForwardDerivativeTranscriber transcriberStorage;
// Diagnostic object from the compile request for
// error messages.
@@ -1797,61 +1758,4 @@ bool processForwardDerivativeCalls(
return changed;
}
-
-void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
-{
- parentFunc = func;
-
- auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
- SLANG_RELEASE_ASSERT(decor);
-
- // Build lookup dictionary for type witnesses.
- for (auto child = decor->getFirstChild(); child; child = child->next)
- {
- if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
- {
- auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
- if (existingItem)
- {
- *existingItem = item->getWitness();
- }
- else
- {
- differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
- }
- }
- }
-}
-
-
-// Lookup a witness table for the concreteType. One should exist if concreteType
-// inherits (successfully) from IDifferentiable.
-//
-
-IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
-{
- IRInst* foundResult = nullptr;
- differentiableWitnessDictionary.TryGetValue(type, foundResult);
- return foundResult;
-}
-
-IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
-{
- if (auto conformance = lookUpConformanceForType(origType))
- {
- return _lookupWitness(builder, conformance, key);
- }
- return nullptr;
-}
-
-void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
-{
- for (auto globalInst : sharedContext->moduleInst->getChildren())
- {
- if (auto pairType = as<IRDifferentialPairType>(globalInst))
- {
- differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
- }
- }
-}
}
diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h
index 6b261ecd0..ab5d753d6 100644
--- a/source/slang/slang-ir-autodiff-fwd.h
+++ b/source/slang/slang-ir-autodiff-fwd.h
@@ -30,6 +30,171 @@ namespace Slang
typedef DiffInstPair<IRInst*, IRInst*> InstPair;
+
+struct ForwardDerivativeTranscriber
+{
+
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> instMapD;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ ForwardDerivativeTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
+ : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
+ {
+
+ }
+
+ DiagnosticSink* getSink();
+
+ void mapDifferentialInst(IRInst* origInst, IRInst* diffInst);
+
+ void mapPrimalInst(IRInst* origInst, IRInst* primalInst);
+
+ IRInst* lookupDiffInst(IRInst* origInst);
+
+ IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst);
+
+ bool hasDifferentialInst(IRInst* origInst);
+
+ bool shouldUseOriginalAsPrimal(IRInst* origInst);
+
+ IRInst* lookupPrimalInst(IRInst* origInst);
+
+ IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst);
+
+ bool hasPrimalInst(IRInst* origInst);
+
+ IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst);
+
+ IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst);
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType);
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType);
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness);
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType);
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType);
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType);
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType);
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam);
+
+ // Returns "d<var-name>" to use as a name hint for variables and parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String getJVPVarName(IRInst* origVar);
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar);
+
+ InstPair transcribeVar(IRBuilder* builder, IRVar* origVar);
+
+ InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith);
+
+ InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic);
+
+ InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad);
+
+ InstPair transcribeStore(IRBuilder* builder, IRStore* origStore);
+
+ InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn);
+
+ // Since int/float literals are sometimes nested inside an IRConstructor
+ // instruction, we check to make sure that the nested instr is a constant
+ // and then return nullptr. Literals do not need to be differentiated.
+ //
+ InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct);
+
+ // Differentiating a call instruction here is primarily about generating
+ // an appropriate call list based on whichever parameters have differentials
+ // in the current transcription context.
+ //
+ InstPair transcribeCall(IRBuilder* builder, IRCall* origCall);
+
+ InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle);
+
+ InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeConst(IRBuilder* builder, IRInst* origInst);
+
+ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
+
+ InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize);
+
+ InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst);
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType);
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock);
+
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst);
+
+ InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr);
+
+ InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop);
+
+ InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse);
+
+ InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst);
+
+ InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst);
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc);
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc);
+
+ // Transcribe a generic definition
+ InstPair transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric);
+
+ IRInst* transcribe(IRBuilder* builder, IRInst* origInst);
+
+ InstPair transcribeInst(IRBuilder* builder, IRInst* origInst);
+};
+
struct ForwardDerivativePassOptions
{
// Nothing for now..
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h
new file mode 100644
index 000000000..9518ccb34
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-propagate.h
@@ -0,0 +1,102 @@
+// slang-ir-autodiff-propagate.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+
+bool isDifferentialInst(IRInst* inst)
+{
+ return inst->findDecoration<IRDifferentialInstDecoration>();
+}
+
+struct DiffPropagationPass : InstPassBase
+{
+ AutoDiffSharedContext* autodiffContext;
+
+ DiffPropagationPass(AutoDiffSharedContext* autodiffContext) :
+ autodiffContext(autodiffContext),
+ InstPassBase(autodiffContext->moduleInst->getModule())
+ {
+
+ }
+
+
+ bool shouldInstBeMarkedDifferential(IRInst* inst)
+ {
+ for (UIndex ii = 0; ii < inst->getOperandCount(); ii ++)
+ {
+ if (isDifferentialInst(inst->getOperand(ii)))
+ {
+ return true;
+ }
+ }
+
+ return false;
+ }
+
+ void addPendingUsersToWorkList(IRInst* inst)
+ {
+ auto use = inst->firstUse;
+ while (use)
+ {
+ if (!isDifferentialInst(use->getUser()))
+ {
+ addToWorkList(use->getUser());
+ }
+ use = use->nextUse;
+ }
+ }
+
+ // Propagate IRDifferentialInstDecoration for all children of instWithChildren.
+ void propagateDiffInstDecoration(IRBuilder* builder, IRInst* instWithChildren)
+ {
+ List<IRInst*> initialList;
+ // Mark 'GetDifferential' insts as differential.
+ processChildInstsOfType<IRDifferentialPairGetDifferential>(
+ kIROp_DifferentialPairGetDifferential,
+ instWithChildren,
+ [&](IRDifferentialPairGetDifferential* getDifferentialInst)
+ {
+ builder->markInstAsDifferential(getDifferentialInst);
+ initialList.add(getDifferentialInst);
+ });
+
+
+ workList.clear();
+ workListSet.Clear();
+
+ // Add the marked insts to the work list.
+ for (auto inst : initialList)
+ {
+ // Look for insts marked as differential.
+ if (isDifferentialInst(inst))
+ addPendingUsersToWorkList(inst);
+ }
+
+ // Propagate to all users..
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = pop();
+
+ // Skip if this is already a differential inst.
+ if (isDifferentialInst(inst))
+ {
+ continue;
+ }
+
+ if (shouldInstBeMarkedDifferential(inst))
+ {
+ builder->markInstAsDifferential(inst);
+ addPendingUsersToWorkList(inst);
+ }
+ }
+ }
+};
+
+} \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 52567e887..522c995b0 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -6,6 +6,11 @@
#include "slang-ir-util.h"
#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-autodiff-fwd.h"
+#include "slang-ir-autodiff-propagate.h"
+#include "slang-ir-autodiff-unzip.h"
+#include "slang-ir-autodiff-transpose.h"
+
namespace Slang
{
@@ -44,12 +49,33 @@ struct BackwardDiffTranscriber
IRWitnessTable* differentialBottomWitness = nullptr;
Dictionary<InstPair, IRInst*> differentialPairTypes;
+ // References to other passes that for reverse-mode transcription.
+ ForwardDerivativeTranscriber *fwdDiffTranscriber;
+ DiffTransposePass *diffTransposePass;
+ DiffPropagationPass *diffPropagationPass;
+ DiffUnzipPass *diffUnzipPass;
+
+ // Allocate space for the passes.
+ ForwardDerivativeTranscriber fwdDiffTranscriberStorage;
+ DiffTransposePass diffTransposePassStorage;
+ DiffPropagationPass diffPropagationPassStorage;
+ DiffUnzipPass diffUnzipPassStorage;
+
+
BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
: autoDiffSharedContext(shared)
, sink(inSink)
, differentiableTypeConformanceContext(shared)
, sharedBuilder(inSharedBuilder)
- {}
+ , fwdDiffTranscriberStorage(shared, inSharedBuilder)
+ , diffTransposePassStorage(shared)
+ , diffPropagationPassStorage(shared)
+ , diffUnzipPassStorage(shared)
+ , fwdDiffTranscriber(&fwdDiffTranscriberStorage)
+ , diffTransposePass(&diffTransposePassStorage)
+ , diffPropagationPass(&diffPropagationPassStorage)
+ , diffUnzipPass(&diffUnzipPassStorage)
+ { }
DiagnosticSink* getSink()
{
@@ -413,20 +439,166 @@ struct BackwardDiffTranscriber
return result;
}
- // Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ // Puts parameters into their own block.
+ void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
{
IRBuilder builder(inBuilder->getSharedBuilder());
- builder.setInsertInto(diffFunc);
- differentiableTypeConformanceContext.setFunc(primalFunc);
- // Transcribe children from origFunc into diffFunc
- for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
- this->transcribeBlock(&builder, block);
+ auto firstBlock = func->getFirstBlock();
+ IRParam* param = func->getFirstParam();
+
+ builder.setInsertBefore(firstBlock);
+
+ // Note: It looks like emitBlock() doesn't use the current
+ // builder position, so we're going to manually move the new block
+ // to before the existing block.
+ auto paramBlock = builder.emitBlock();
+ paramBlock->insertBefore(firstBlock);
+ builder.setInsertInto(paramBlock);
+
+ while(param)
+ {
+ IRParam* nextParam = param->getNextParam();
+
+ // Copy inst into the new parameter block.
+ auto clonedParam = cloneInst(&cloneEnv, &builder, param);
+ param->replaceUsesWith(clonedParam);
+ param->removeAndDeallocate();
+
+ param = nextParam;
+ }
+
+ // Replace this block as the first block.
+ firstBlock->replaceUsesWith(paramBlock);
+
+ // Add terminator inst.
+ builder.emitBranch(firstBlock);
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ SLANG_ASSERT(primalFunc);
+ SLANG_ASSERT(diffFunc);
+ // Reverse-mode transcription uses 4 separate steps:
+ // TODO(sai): Fill in documentation.
+
+ // Generate a temporary forward derivative function as an intermediate step.
+ IRFunc* fwdDiffFunc = as<IRFunc>(fwdDiffTranscriber->transcribeFuncHeader(builder, (IRFunc*)primalFunc).differential);
+ SLANG_ASSERT(fwdDiffFunc);
+
+ // Transcribe the body of the primal function into it's linear (fwd-diff) form.
+ // TODO(sai): Handle the case when we already have a user-defined fwd-derivative function.
+ fwdDiffTranscriber->transcribeFunc(builder, primalFunc, as<IRFunc>(fwdDiffFunc));
+
+ // Split first block into a paramter block.
+ this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc));
+
+ // This steps adds a decoration to instructions that are computing the differential.
+ diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);
+
+ // Copy primal insts to the first block of the unzipped function, copy diff insts to the
+ // second block of the unzipped function.
+ //
+ IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+
+ // Clone the primal blocks from unzippedFwdDiffFunc
+ // to the reverse-mode function.
+ // TODO: This is the spot where we can make a decision to split
+ // the primal and differential into two different funcitons
+ // instead of two blocks in the same function.
+ //
+ // Special care needs to be taken for the first block since it holds the parameters
+
+ // Clone all blocks into a temporary diff func.
+ // We're using a temporary sice we don't want to clone decorations,
+ // only blocks, and right now there's no provision in slang-ir-clone.h
+ // for that.
+ //
+ builder->setInsertInto(diffFunc->getParent());
+ auto tempDiffFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, unzippedFwdDiffFunc));
+
+ // Move blocks to the diffFunc shell.
+ {
+ List<IRBlock*> workList;
+ for (auto block = tempDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
+ workList.add(block);
+
+ for (auto block : workList)
+ block->insertAtEnd(diffFunc);
+ }
+
+ // Transpose the first block (parameter block)
+ transcribeParameterBlock(builder, diffFunc);
+
+ builder->setInsertInto(diffFunc);
+
+ auto dOutParameter = diffFunc->getLastParam();
+
+ // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
+ DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr};
+ diffTransposePass->transposeDiffBlocksInFunc(diffFunc, info);
+
+ // Clean up by deallocating intermediate steps.
+ tempDiffFunc->removeAndDeallocate();
+ unzippedFwdDiffFunc->removeAndDeallocate();
+ fwdDiffFunc->removeAndDeallocate();
return InstPair(primalFunc, diffFunc);
}
+ void transcribeParameterBlock(IRBuilder* builder, IRFunc* diffFunc)
+ {
+ IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock();
+
+ // Find the 'next' block using the terminator inst of the parameter block.
+ auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator());
+ auto nextBlock = fwdParamBlockBranch->getTargetBlock();
+
+ builder->setInsertInto(fwdDiffParameterBlock);
+
+ // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<>
+ for (auto child = fwdDiffParameterBlock->getFirstParam(); child;)
+ {
+ IRParam* nextChild = child->getNextParam();
+
+ auto fwdParam = as<IRParam>(child);
+ SLANG_ASSERT(fwdParam);
+
+ // TODO: Handle ptr<pair> types.
+ if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()))
+ {
+ // Create inout version.
+ auto inoutDiffPairType = builder->getInOutType(diffPairType);
+ auto newParam = builder->emitParam(inoutDiffPairType);
+
+ // Map the _load_ of the new parameter as the clone of the old one.
+ auto newParamLoad = builder->emitLoad(newParam);
+ newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block.
+ fwdParam->replaceUsesWith(newParamLoad);
+ fwdParam->removeAndDeallocate();
+ }
+ else
+ {
+ // Default case (parameter has nothing to do with differentiation)
+ // Do nothing.
+ }
+
+ child = nextChild;
+ }
+
+ auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount();
+
+ // 2. Add a parameter for 'derivative of the output' (d_out).
+ // The type is the last parameter type of the function.
+ //
+ auto dOutParamType = as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1);
+
+ SLANG_ASSERT(dOutParamType);
+
+ builder->emitParam(dOutParamType);
+ }
+
IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
{
auto primalDataType = origParam->getDataType();
@@ -652,7 +824,6 @@ struct BackwardDiffTranscriber
struct ReverseDerivativePass : public InstPassBase
{
-
DiagnosticSink* getSink()
{
return sink;
@@ -664,7 +835,7 @@ struct ReverseDerivativePass : public InstPassBase
IRBuilder builderStorage(autodiffContext->sharedBuilder);
IRBuilder* builder = &builderStorage;
- // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
+ // Process all ForwardDifferentiate instructions (kIROp_ForwardDifferentiate), by
// generating derivative code for the referenced function.
//
bool modified = processReferencedFunctions(builder);
@@ -800,6 +971,9 @@ struct ReverseDerivativePass : public InstPassBase
pairBuilderStorage(context)
{
backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
+ backwardTranscriberStorage.fwdDiffTranscriberStorage.sink = sink;
+ backwardTranscriberStorage.fwdDiffTranscriberStorage.autoDiffSharedContext = context;
+ backwardTranscriberStorage.fwdDiffTranscriberStorage.pairBuilder = &(pairBuilderStorage);
}
protected:
@@ -829,4 +1003,4 @@ bool processReverseDerivativeCalls(
return changed;
}
-} \ No newline at end of file
+}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
new file mode 100644
index 000000000..659131820
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -0,0 +1,420 @@
+// slang-ir-autodiff-transpose.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-fwd.h"
+
+namespace Slang
+{
+
+struct DiffTransposePass
+{
+ AutoDiffSharedContext* autodiffContext;
+
+ DifferentialPairTypeBuilder pairBuilder;
+
+ Dictionary<IRInst*, List<IRInst*>> assignmentsMap;
+
+ Dictionary<IRInst*, IRInst*>* primalsMap;
+
+ DiffTransposePass(AutoDiffSharedContext* autodiffContext) :
+ autodiffContext(autodiffContext), pairBuilder(autodiffContext)
+ { }
+
+ struct RevAssignment
+ {
+ IRInst* lvalue;
+ IRInst* rvalue;
+
+ RevAssignment(IRInst* lvalue, IRInst* rvalue) : lvalue(lvalue), rvalue(rvalue)
+ { }
+ RevAssignment() : lvalue(nullptr), rvalue(nullptr)
+ { }
+ };
+
+ struct TranspositionResult
+ {
+ // Holds a set of pairs of
+ // (original-inst, inst-to-accumulate-for-orig-inst)
+ List<RevAssignment> revPairs;
+
+ TranspositionResult()
+ { }
+
+ TranspositionResult(List<RevAssignment> revPairs) : revPairs(revPairs)
+ { }
+ };
+
+ struct FuncTranspositionInfo
+ {
+ // Inst that represents the reverse-mode derivative
+ // of the *output* of the function.
+ //
+ IRInst* dOutInst;
+
+ // Mapping between *primal* insts in the forward-mode function, and the
+ // reverse-mode function
+ //
+ Dictionary<IRInst*, IRInst*>* primalsMap;
+ };
+
+ void transposeDiffBlocksInFunc(
+ IRFunc* revDiffFunc,
+ // TODO: Maybe there's a more elegant way to pass this information.
+ FuncTranspositionInfo transposeInfo)
+ {
+
+ // Traverse all instructions/blocks in reverse (starting from the terminator inst)
+ // look for insts/blocks marked with IRDifferentialInstDecoration,
+ // and transpose them in the revDiffFunc.
+ //
+ IRBuilder builder;
+ builder.init(autodiffContext->sharedBuilder);
+
+ // Insert after the last block.
+ builder.setInsertInto(revDiffFunc);
+
+ List<IRBlock*> workList;
+
+ // Build initial list of blocks to process by checking if they're differential blocks.
+ for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ if (!isDifferentialInst(block))
+ {
+ // Skip blocks that aren't computing differentials.
+ // At this stage we should have 'unzipped' the function
+ // into blocks that either entirely deal with primal insts,
+ // or entirely with differential insts.
+ continue;
+ }
+ workList.add(block);
+ }
+
+ // TODO: We *might* need a step here that 'sorts' the work list in reverse order starting with 'leaf'
+ // differential blocks, and following the branches backwards.
+ // The alternative is to make phi nodes and treat all intermediaries & their gradients as arguments.
+
+ for (auto block : workList)
+ {
+ // Set dOutParameter as the transpose gradient for the return inst, if any.
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ this->addRevAssignmentForFwdInst(returnInst, transposeInfo.dOutInst);
+ }
+
+ IRBlock* revBlock = builder.emitBlock();
+ this->transposeBlock(block, revBlock);
+
+ // TODO: This should only really be used for the transition from
+ // the 'last' primal block(s) to the first differential block.
+ // Transitions from differential blocks to
+ block->replaceUsesWith(revBlock);
+ block->removeAndDeallocate();
+ }
+ }
+
+ void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
+ {
+ IRBuilder builder;
+ builder.init(autodiffContext->sharedBuilder);
+
+ // Insert after the last block.
+ builder.setInsertInto(revBlock);
+
+ // Note the 'reverse' traversal here.
+ for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
+ {
+ if (as<IRDecoration>(child))
+ continue;
+
+ transposeInst(&builder, child);
+ }
+
+ // After processing the block's instructions, we 'flush' any remaining gradients
+ // in the assignments map.
+ // For now, these are only function parameter gradients (or of the form IRLoad(IRParam))
+ // TODO: We should be flushing *all* gradients accumulated in this block to some
+ // function scope variable, since control flow can affect what blocks contribute to
+ // for a specific inst.
+ //
+ for (auto pair : assignmentsMap)
+ {
+ if (auto param = as<IRLoad>(pair.Key))
+ accumulateGradientsForLoad(&builder, param);
+ }
+
+ // Emit a terminator inst.
+ // TODO: need a be a lot smarter here. For now, we assume a single differential
+ // block, so it should end in a return statement.
+ if (as<IRReturn>(fwdBlock->getTerminator()))
+ {
+ // Emit a void return.
+ builder.emitReturn();
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unhandled block terminator");
+ }
+ }
+
+ void transposeInst(IRBuilder* builder, IRInst* inst)
+ {
+ // Look for assignment entry for this inst.
+ IRInst* revValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0);
+ if (hasRevAssignments(inst))
+ {
+ // Emit the aggregate of all the assignments here. This will form the derivative
+ revValue = emitAggregateValue(builder, popRevAssignments(inst));
+ }
+
+ auto transposeResult = transposeInst(builder, inst, revValue);
+
+ // Add the new results to the assignments map.
+ for (auto pair : transposeResult.revPairs)
+ {
+ addRevAssignmentForFwdInst(pair.lvalue, pair.rvalue);
+ }
+ }
+
+ TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ // Dispatch logic.
+ switch(fwdInst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ return transposeArithmetic(builder, fwdInst, revValue);
+
+ case kIROp_Return:
+ return transposeReturn(builder, as<IRReturn>(fwdInst), revValue);
+
+ case kIROp_MakeDifferentialPair:
+ return transposeMakePair(builder, as<IRMakeDifferentialPair>(fwdInst), revValue);
+
+ case kIROp_DifferentialPairGetDifferential:
+ return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue);
+
+ default:
+ SLANG_ASSERT_FAILURE("Unhandled instruction");
+ }
+ }
+
+ TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue)
+ {
+ // (P = (A, dA)) -> (dA += dP)
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdMakePair->getDifferentialValue(),
+ revValue)));
+ }
+
+ TranspositionResult transposeGetDifferential(IRBuilder*, IRDifferentialPairGetDifferential* fwdGetDiff, IRInst* revValue)
+ {
+ // (A = GetDiff(P)) -> (dP.d += dA)
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdGetDiff->getBase(),
+ revValue)));
+ }
+
+ // Gather all reverse-mode gradients for parameters, and store to the differential
+ //
+ void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad)
+ {
+ auto revParam = revLoad->getPtr();
+
+ // Don't currently handle loads from non-param insts.
+ SLANG_ASSERT(as<IRParam>(revParam));
+
+ // Assert that param type is of the form IRPtrTypeBase<IRDifferentialPairType<T>>
+ SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType()));
+ SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType);
+
+ auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revParam->getDataType())->getValueType());
+ auto diffType = (IRType*) pairBuilder.getDiffTypeFromPairType(builder, paramPairType);
+
+ // Gather gradients.
+ auto gradients = popRevAssignments(revLoad);
+ if (gradients.getCount() == 0)
+ {
+ // Ignore.
+ return;
+ }
+ else
+ {
+ // Re-emit a load to get the _current_ value of revParam.
+ auto revCurrLoad = builder->emitLoad(revParam);
+
+ // Grab the current gradient value.
+ auto revCurrGrad = builder->emitDifferentialPairGetDifferential(diffType, revCurrLoad);
+
+ // Add the current value to the aggregation list.
+ gradients.add(revCurrGrad);
+
+ // Get the _total_ value.
+ auto aggregateGradient = emitAggregateValue(builder, gradients);
+
+ // Grab the current primal value.
+ auto revCurrPrimal = builder->emitDifferentialPairGetPrimal(revCurrLoad);
+
+ // Make the pair with the new gradient.
+ auto newDiffPair = builder->emitMakeDifferentialPair(paramPairType, revCurrPrimal, aggregateGradient);
+
+ // Store this back into the parameter.
+ builder->emitStore(revParam, newDiffPair);
+ }
+ }
+
+ TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue)
+ {
+
+ if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType()))
+ {
+ // If the type is a differential pair, we add the reverse-value for the *pair*
+ // itself. TODO: Signal this through flags in the 'RevAssignment' struct.
+ // (return (A, dA)) -> (dA += dOut)
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdReturn->getVal(),
+ revValue)));
+ }
+ else
+ {
+ // (return A) -> (empty)
+ return TranspositionResult();
+ }
+ }
+
+ TranspositionResult transposeArithmetic(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
+ {
+ IRType* floatType = builder->getType(kIROp_FloatType);
+ switch(fwdInst->getOp())
+ {
+ case kIROp_Add:
+ {
+ // (Out = dA + dB) -> [(dA += dOut), (dB += dOut)]
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdInst->getOperand(0),
+ revValue),
+ RevAssignment(
+ fwdInst->getOperand(1),
+ revValue)));
+ }
+ case kIROp_Sub:
+ {
+ // (Out = dA - dB) -> [(dA += dOut), (dB -= dOut)]
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdInst->getOperand(0),
+ revValue),
+ RevAssignment(
+ fwdInst->getOperand(1),
+ builder->emitNeg(
+ revValue->getDataType(), revValue))));
+ }
+ case kIROp_Mul:
+ {
+ if (isDifferentialInst(fwdInst->getOperand(0)))
+ {
+ // (Out = dA * B) -> (dA += B * dOut)
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdInst->getOperand(0),
+ builder->emitMul(floatType, fwdInst->getOperand(1), revValue))));
+ }
+ else if (isDifferentialInst(fwdInst->getOperand(1)))
+ {
+ // (Out = A * dB) -> (dB += A * dOut)
+ return TranspositionResult(
+ List<RevAssignment>(
+ RevAssignment(
+ fwdInst->getOperand(1),
+ builder->emitMul(floatType, fwdInst->getOperand(0), revValue))));
+ }
+ else
+ {
+ SLANG_ASSERT_FAILURE("Neither operand of a mul instruction is a differential inst");
+ }
+ }
+
+ default:
+ SLANG_ASSERT_FAILURE("Unhandled arithmetic");
+ }
+ }
+
+ IRInst* emitAggregateValue(IRBuilder* builder, List<IRInst*> values)
+ {
+ // We're handling the case where the types are all float,
+ // so we can use a bunch of kIROp_Add insts to add them up.
+ // If this is an arbitrary type T, we need to lookup and
+ // call T.dadd()
+
+ IRInst* initialValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0);
+ if (values.getCount() == 0)
+ {
+ // If there's not values to add up, emit a 0 value.
+ return initialValue;
+ }
+ else if (values.getCount() == 1)
+ {
+ // If there's only one value to add up, just return it in order
+ // to avoid a stack of 0 + 0 + 0 + ...
+ return values[0];
+ }
+
+ // If there's more than one value, aggregate them by adding them up.
+
+ SLANG_ASSERT(values[0]->getDataType()->getOp() == kIROp_FloatType);
+
+ IRInst* currentValue = initialValue;
+ for (auto value : values)
+ {
+ currentValue = builder->emitAdd(
+ builder->getType(kIROp_FloatType), currentValue, value);
+ }
+
+ return currentValue;
+ }
+
+ void addRevAssignmentForFwdInst(IRInst* fwdInst, IRInst* assignment)
+ {
+ if (!hasRevAssignments(fwdInst))
+ {
+ assignmentsMap[fwdInst] = List<IRInst*>();
+ }
+
+ assignmentsMap[fwdInst].GetValue().add(assignment);
+ }
+
+ List<IRInst*> getRevAssignments(IRInst* fwdInst)
+ {
+ return assignmentsMap[fwdInst];
+ }
+
+ List<IRInst*> popRevAssignments(IRInst* fwdInst)
+ {
+ List<IRInst*> val = assignmentsMap[fwdInst].GetValue();
+ assignmentsMap.Remove(fwdInst);
+ return val;
+ }
+
+ bool hasRevAssignments(IRInst* fwdInst)
+ {
+ return assignmentsMap.ContainsKey(fwdInst);
+ }
+};
+
+
+}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
new file mode 100644
index 000000000..344a930f2
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -0,0 +1,110 @@
+// slang-ir-autodiff-unzip.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-compiler.h"
+
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-propagate.h"
+
+namespace Slang
+{
+
+struct DiffUnzipPass
+{
+ AutoDiffSharedContext* autodiffContext;
+
+ IRCloneEnv cloneEnv;
+
+ DiffUnzipPass(AutoDiffSharedContext* autodiffContext) :
+ autodiffContext(autodiffContext)
+ { }
+
+ IRFunc* unzipDiffInsts(IRFunc* func)
+ {
+ IRBuilder builderStorage;
+ builderStorage.init(autodiffContext->sharedBuilder);
+
+ IRBuilder* builder = &builderStorage;
+
+ // Clone the entire function.
+ // TODO: Maybe don't clone? The reverse-mode process seems to clone several times.
+ // TODO: Looks like we get a copy of the decorations?
+ IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&cloneEnv, builder, func));
+
+ builder->setInsertInto(unzippedFunc);
+
+ // Work *only* with two-block functions for now.
+ SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr);
+ SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr);
+ SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock()->getNextBlock() == nullptr);
+
+ // Ignore the first block (this is reserved for parameters), start
+ // at the second block. (For now, we work with only a single block of insts)
+ // TODO: expand to handle multi-block functions later.
+
+ IRBlock* mainBlock = unzippedFunc->getFirstBlock()->getNextBlock();
+
+ // Emit new blocks for split vesions of mainblock.
+ IRBlock* primalBlock = builder->emitBlock();
+ IRBlock* diffBlock = builder->emitBlock();
+
+ // Mark the differential block as a differential inst.
+ builder->markInstAsDifferential(diffBlock);
+
+ // Split the main block into two. This method should also emit
+ // a branch statement from primalBlock to diffBlock.
+ // TODO: extend this code to split multiple blocks
+ //
+ splitBlock(mainBlock, primalBlock, diffBlock);
+
+ // Replace occurences of mainBlock with primalBlock
+ mainBlock->replaceUsesWith(primalBlock);
+ mainBlock->removeAndDeallocate();
+
+ return unzippedFunc;
+ }
+
+ void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock)
+ {
+ // Make two builders for primal and differential blocks.
+ IRBuilder primalBuilder;
+ primalBuilder.init(autodiffContext->sharedBuilder);
+ primalBuilder.setInsertInto(primalBlock);
+
+ IRBuilder diffBuilder;
+ diffBuilder.init(autodiffContext->sharedBuilder);
+ diffBuilder.setInsertInto(diffBlock);
+
+ for (auto child = mainBlock->getFirstChild(); child;)
+ {
+ IRInst* nextChild = child->getNextInst();
+
+ if (isDifferentialInst(child) || as<IRTerminatorInst>(child))
+ {
+ auto newInst = cloneInst(&cloneEnv, &diffBuilder, child);
+ child->replaceUsesWith(newInst);
+ child->removeAndDeallocate();
+ }
+ else
+ {
+ auto newInst = cloneInst(&cloneEnv, &primalBuilder, child);
+ child->replaceUsesWith(newInst);
+ child->removeAndDeallocate();
+ }
+
+ child = nextChild;
+ }
+
+ // Nothing should be left in the original block.
+ SLANG_ASSERT(mainBlock->getFirstChild() == nullptr);
+
+ // Branch from primal to differential block.
+ // Functionally, the new blocks should produce the same output as the
+ // old block.
+ primalBuilder.emitBranch(diffBlock);
+ }
+};
+
+}
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 313760d85..b0dbf62fa 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -327,6 +327,60 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde
}
+
+void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func)
+{
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
+
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
+ {
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
+ {
+ *existingItem = item->getWitness();
+ }
+ else
+ {
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
+ }
+ }
+ }
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type)
+{
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
+}
+
+IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+{
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
+}
+
+void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
+{
+ for (auto globalInst : sharedContext->moduleInst->getChildren())
+ {
+ if (auto pairType = as<IRDifferentialPairType>(globalInst))
+ {
+ differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
+ }
+ }
+}
+
+
void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
for (auto inst : parent->getChildren())
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 3ca9f1d41..4aca291f9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -727,10 +727,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
- /// Used by the auto-diff pass to hold a reference to the
- /// generated derivative function.
+ /// Decorated function is marked for the reverse-mode differentiation pass.
INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
+ /// Used by the auto-diff pass to mark insts that compute
+ /// a differential value.
+ INST(DifferentialInstDecoration, diffInstDecoration, 0, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 7cf8b3032..a1249aff9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -584,6 +584,15 @@ struct IRBackwardDerivativeDecoration : IRDecoration
IRInst* getBackwardDerivativeFunc() { return getOperand(0); }
};
+struct IRDifferentialInstDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_DifferentialInstDecoration
+ };
+ IR_LEAF_ISA(DifferentialInstDecoration)
+};
+
struct IRBackwardDifferentiableDecoration : IRDecoration
{
enum
@@ -3335,6 +3344,11 @@ public:
addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn);
}
+ void markInstAsDifferential(IRInst* value)
+ {
+ addDecoration(value, kIROp_DifferentialInstDecoration);
+ }
+
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);