diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 49 |
1 files changed, 32 insertions, 17 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 36093518a..5ac4016d7 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -203,13 +203,23 @@ IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType( IRInst* func, IRFuncType* funcType) { - IRType* intermediateType = - builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + IRType* intermediateType = nullptr; if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) { intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric)); } + else if (as<IRLookupWitnessMethod>(func)) + { + intermediateType = nullptr; + } + else + { + intermediateType = + builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + } return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -382,14 +392,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( IRFunc* primalFunc = origFunc; maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); - - // The original func may not have a type dictionary if it is not originally marked as - // differentiable, in this case we would have already pulled the necessary types from - // the user-provided derivative function, so we are still fine. - if (origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - differentiableTypeConformanceContext.setFunc(origFunc); - } + differentiableTypeConformanceContext.setFunc(origFunc); auto diffFunc = builder.createFunc(); @@ -414,12 +417,7 @@ InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl( // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. - if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) - { - builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); - cloneInst(&cloneEnv, &builder, dictDecor); - } + copyOriginalDecorations(origFunc, diffFunc); builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); @@ -446,7 +444,24 @@ void BackwardDiffTranscriberBase::addTranscribedFuncDecoration( InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - auto result = transcribeFuncHeaderImpl(inBuilder, origFunc); + InstPair result; + + // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic, keep the + // insert location unchanges). If we're transcribing it as a declaration, we should + // insert into the module. + // + auto origOuterGen = as<IRGeneric>(findOuterGeneric(origFunc)); + if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc)) + { + // Dealing with a declaration.. insert into module scope. + IRBuilder subBuilder = *inBuilder; + subBuilder.setInsertInto(inBuilder->getModule()); + result = transcribeFuncHeaderImpl(&subBuilder, origFunc); + } + else + { + result = transcribeFuncHeaderImpl(inBuilder, origFunc); + } FuncBodyTranscriptionTask task; task.originalFunc = as<IRFunc>(result.primal); |
