summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp39
1 files changed, 37 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 73d9b6ba6..ed122c862 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -72,6 +72,14 @@ bool AutoDiffTranscriberBase::shouldUseOriginalAsPrimal(IRInst* currentParent, I
return true;
if (isChildInstOf(currentParent, origInst->getParent()))
return true;
+
+ // If origInst is defined in the first block of the same function as current inst (e.g. a param),
+ // we can use it as primal.
+ // More generally, we should test if origInst dominates currentParent, but that requires calculating
+ // a dom tree on the fly. Right now just testing if it is first block for parameters seems sufficient.
+ auto parentFunc = getParentFunc(currentParent);
+ if (parentFunc && origInst->parent == parentFunc->getFirstBlock())
+ return true;
return false;
}
@@ -802,6 +810,7 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS
case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_BackwardDerivativePrimalDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
+ case kIROp_PrimalSubstituteDecoration:
break;
default:
if (!outInstsToSkip.Contains(use->getUser()))
@@ -876,6 +885,32 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene
return InstPair(primalGeneric, diffGeneric);
}
+IRInst* getActualInstToTranscribe(IRInst* inst)
+{
+ if (auto gen = as<IRGeneric>(inst))
+ {
+ auto retVal = findGenericReturnVal(gen);
+ if (retVal->getOp() != kIROp_Func)
+ return inst;
+ if (auto primalSubst = retVal->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ auto spec = as<IRSpecialize>(primalSubst->getPrimalSubstituteFunc());
+ SLANG_RELEASE_ASSERT(spec);
+ return spec->getBase();
+ }
+ }
+ else if (auto func = as<IRFunc>(inst))
+ {
+ if (auto primalSubst = func->findDecoration<IRPrimalSubstituteDecoration>())
+ {
+ auto actualFunc = as<IRFunc>(primalSubst->getPrimalSubstituteFunc());
+ SLANG_RELEASE_ASSERT(actualFunc);
+ return actualFunc;
+ }
+ }
+ return inst;
+}
+
IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst)
{
// If a differential instruction is already mapped for
@@ -891,8 +926,8 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
// depending on the op-code.
//
instsInProgress.Add(origInst);
-
- InstPair pair = transcribeInst(builder, origInst);
+ auto actualInstToTranscribe = getActualInstToTranscribe(origInst);
+ InstPair pair = transcribeInst(builder, actualInstToTranscribe);
instsInProgress.Remove(origInst);