diff options
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 31a3072c0..10a734d65 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + + + if (pair.primal != pair.differential && + !pair.primal->findDecoration<IRAutodiffInstDecoration>() && + !as<IRConstant>(pair.primal)) + { + builder->markInstAsPrimal(pair.primal); + } + if (pair.differential) { switch (pair.differential->getOp()) @@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() && - !as<IRConstant>(pair.differential)) + // Automatically tag the primal and differential results + // if they haven't already been handled by the + // code. + // + if (pair.primal != pair.differential) + { + if (!pair.differential->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto primalType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); + } + } + else { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. - // - auto primalType = as<IRType>(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto mixedType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsMixedDifferential(pair.primal, mixedType); + } } break; |
