From 86fc50c5092fbccf6072dcf7bbdfafb8915f02c8 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 8 Mar 2023 21:52:34 -0800 Subject: Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. (#2691) * Add support for `[PrimalSubstitute]` and `[PrimalSubstituteOf]`. * Fix * Fix. * Cleanup. --------- Co-authored-by: Yong He --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 39 ++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 73d9b6ba6..ed122c862 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I return true; if (isChildInstOf(currentParent, origInst->getParent())) return true; + + // If origInst is defined in the first block of the same function as current inst (e.g. a param), + // we can use it as primal. + // More generally, we should test if origInst dominates currentParent, but that requires calculating + // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient. + auto parentFunc = getParentFunc(currentParent); + if (parentFunc && origInst->parent == parentFunc->getFirstBlock()) + return true; return false; } @@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: + case kIROp_PrimalSubstituteDecoration: break; default: if (!outInstsToSkip.Contains(use->getUser())) @@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene return InstPair(primalGeneric, diffGeneric); } +IRInst* getActualInstToTranscribe(IRInst* inst) +{ + if (auto gen = as(inst)) + { + auto retVal = findGenericReturnVal(gen); + if (retVal->getOp() != kIROp_Func) + return inst; + if (auto primalSubst = retVal->findDecoration()) + { + auto spec = as(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(spec); + return spec->getBase(); + } + } + else if (auto func = as(inst)) + { + if (auto primalSubst = func->findDecoration()) + { + auto actualFunc = as(primalSubst->getPrimalSubstituteFunc()); + SLANG_RELEASE_ASSERT(actualFunc); + return actualFunc; + } + } + return inst; +} + IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst) { // If a differential instruction is already mapped for @@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst // depending on the op-code. // instsInProgress.Add(origInst); - - InstPair pair = transcribeInst(builder, origInst); + auto actualInstToTranscribe = getActualInstToTranscribe(origInst); + InstPair pair = transcribeInst(builder, actualInstToTranscribe); instsInProgress.Remove(origInst); -- cgit v1.2.3