diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 39 |
1 files changed, 37 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 73d9b6ba6..ed122c862 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I return true; if (isChildInstOf(currentParent, origInst->getParent())) return true; + + // If origInst is defined in the first block of the same function as current inst (e.g. a param), + // we can use it as primal. + // More generally, we should test if origInst dominates currentParent, but that requires calculating + // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient. + auto parentFunc = getParentFunc(currentParent); + if (parentFunc && origInst->parent == parentFunc->getFirstBlock()) + return true; return false; } @@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_PrimalSubstituteDecoration: break; default: if (!outInstsToSkip.Contains(use->getUser())) @@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene return InstPair(primalGeneric, diffGeneric); } +IRInst* getActualInstToTranscribe(IRInst* inst) +{ + if (auto gen = as<IRGeneric>(inst)) + { + auto retVal = findGenericReturnVal(gen); + if (retVal->getOp() != kIROp_Func) + return inst; + if (auto primalSubst = retVal->findDecoration<IRPrimalSubstituteDecoration>()) + { + auto spec = as<IRSpecialize>(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(spec); + return spec->getBase(); + } + } + else if (auto func = as<IRFunc>(inst)) + { + if (auto primalSubst = func->findDecoration<IRPrimalSubstituteDecoration>()) + { + auto actualFunc = as<IRFunc>(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(actualFunc); + return actualFunc; + } + } + return inst; +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // depending on the op-code. // instsInProgress.Add(origInst); - - InstPair pair = transcribeInst(builder, origInst); + auto actualInstToTranscribe = getActualInstToTranscribe(origInst); + InstPair pair = transcribeInst(builder, actualInstToTranscribe); instsInProgress.Remove(origInst); |
