diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-10 03:16:24 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 13:46:24 -0800 |
| commit | 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch) | |
| tree | 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-rev.cpp | |
| parent | 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff) | |
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 49 |
1 files changed, 32 insertions, 17 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 36093518a..5ac4016d7 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType( IRInst* func, IRFuncType* funcType) { - IRType* intermediateType = - builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + IRType* intermediateType = nullptr; if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) { intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric)); } + else if (as<IRLookupWitnessMethod>(func)) + { + intermediateType = nullptr; + } + else + { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + } return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( IRFunc* primalFunc = origFunc; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - - // The original func may not have a type dictionary if it is not originally marked as - // differentiable, in this case we would have already pulled the necessary types from - // the user-provided derivative function, so we are still fine. - if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - differentiableTypeConformanceContext.setFunc(origFunc); - } + differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); - cloneInst(&cloneEnv, &builder, dictDecor); - } + copyOriginalDecorations(origFunc, diffFunc); builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); @@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration( InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); + InstPair result; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanges). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc)); + if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc)) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + result = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + result = transcribeFuncHeaderImpl(inBuilder, origFunc); + } FuncBodyTranscriptionTask task; task.originalFunc = as<IRFunc>(result.primal); |
