From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-rev.cpp | 49 ++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 17 deletions(-) (limited to 'source/slang/slang-ir-autodiff-rev.cpp') diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 36093518a..5ac4016d7 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType( IRInst* func, IRFuncType* funcType) { - IRType* intermediateType = - builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + IRType* intermediateType = nullptr; if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as(outerGeneric)); } + else if (as(func)) + { + intermediateType = nullptr; + } + else + { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + } return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( IRFunc* primalFunc = origFunc; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - - // The original func may not have a type dictionary if it is not originally marked as - // differentiable, in this case we would have already pulled the necessary types from - // the user-provided derivative function, so we are still fine. - if (origFunc->findDecoration()) - { - differentiableTypeConformanceContext.setFunc(origFunc); - } + differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration()) - { - builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); - cloneInst(&cloneEnv, &builder, dictDecor); - } + copyOriginalDecorations(origFunc, diffFunc); builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); @@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration( InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); + InstPair result; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanges). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as(findOuterGeneric(origFunc)); + if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc)) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + result = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + result = transcribeFuncHeaderImpl(inBuilder, origFunc); + } FuncBodyTranscriptionTask task; task.originalFunc = as(result.primal); -- cgit v1.2.3