summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-rev.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp49
1 files changed, 32 insertions, 17 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 36093518a..5ac4016d7 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(
IRInst* func,
IRFuncType* funcType)
{
- IRType* intermediateType =
- builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ IRType* intermediateType = nullptr;
if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
{
intermediateType =
+ builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ intermediateType =
(IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric));
}
+ else if (as<IRLookupWitnessMethod>(func))
+ {
+ intermediateType = nullptr;
+ }
+ else
+ {
+ intermediateType =
+ builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
+ }
return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
}
@@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
IRFunc* primalFunc = origFunc;
maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);
-
- // The original func may not have a type dictionary if it is not originally marked as
- // differentiable, in this case we would have already pulled the necessary types from
- // the user-provided derivative function, so we are still fine.
- if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- differentiableTypeConformanceContext.setFunc(origFunc);
- }
+ differentiableTypeConformanceContext.setFunc(origFunc);
auto diffFunc = builder.createFunc();
@@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
// Mark the generated derivative function itself as differentiable.
builder.addBackwardDifferentiableDecoration(diffFunc);
- // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
- if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
- {
- builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
- cloneInst(&cloneEnv, &builder, dictDecor);
- }
+
copyOriginalDecorations(origFunc, diffFunc);
builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
return InstPair(primalFunc, diffFunc);
@@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(
InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
- auto result = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ InstPair result;
+
+ // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the
+ // insert location unchanges). If we're transcribing it as a declaration, we should
+ // insert into the module.
+ //
+ auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc));
+ if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc))
+ {
+ // Dealing with a declaration.. insert into module scope.
+ IRBuilder subBuilder = *inBuilder;
+ subBuilder.setInsertInto(inBuilder->getModule());
+ result = transcribeFuncHeaderImpl(&subBuilder, origFunc);
+ }
+ else
+ {
+ result = transcribeFuncHeaderImpl(inBuilder, origFunc);
+ }
FuncBodyTranscriptionTask task;
task.originalFunc = as<IRFunc>(result.primal);