summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp38
1 files changed, 29 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 31a3072c0..10a734d65 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
+
+
+ if (pair.primal != pair.differential &&
+ !pair.primal->findDecoration<IRAutodiffInstDecoration>() &&
+ !as<IRConstant>(pair.primal))
+ {
+ builder->markInstAsPrimal(pair.primal);
+ }
+
if (pair.differential)
{
switch (pair.differential->getOp())
@@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
}
- // Tag the differential inst using a decoration (if it doesn't have one)
- if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() &&
- !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() &&
- !as<IRConstant>(pair.differential))
+ // Automatically tag the primal and differential results
+ // if they haven't already been handled by the
+ // code.
+ //
+ if (pair.primal != pair.differential)
+ {
+ if (!pair.differential->findDecoration<IRAutodiffInstDecoration>()
+ && !as<IRConstant>(pair.differential))
+ {
+ auto primalType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsDifferential(pair.differential, primalType);
+ }
+ }
+ else
{
- // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential
- // instead.
- //
- auto primalType = as<IRType>(pair.primal->getDataType());
- builder->markInstAsDifferential(pair.differential, primalType);
+ if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()
+ && !as<IRConstant>(pair.differential))
+ {
+ auto mixedType = as<IRType>(pair.primal->getDataType());
+ builder->markInstAsMixedDifferential(pair.primal, mixedType);
+ }
}
break;