From df02f3f50f977112ca1fbb148cd48ee41d560f41 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 9 Feb 2023 17:40:20 -0500 Subject: Reverse-mode Loop Support (#2635) * Full loop support now working. MaxItersAttr in progress * Lookup table updates? * Fixed the max iters decoration * Minox fixes & remove superfluous code * fixup warnings * Revert "Lookup table updates?" This reverts commit 7d9b0793fb5239f31d1155776e846dcf1892d8d9. * Update 07-autodiff.md * Change maxiters to MaxIters * Added asserts * Update 07-autodiff.md --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 38 +++++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 31a3072c0..10a734d65 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + + + if (pair.primal != pair.differential && + !pair.primal->findDecoration() && + !as(pair.primal)) + { + builder->markInstAsPrimal(pair.primal); + } + if (pair.differential) { switch (pair.differential->getOp()) @@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration() && - !pair.differential->findDecoration() && - !as(pair.differential)) + // Automatically tag the primal and differential results + // if they haven't already been handled by the + // code. + // + if (pair.primal != pair.differential) + { + if (!pair.differential->findDecoration() + && !as(pair.differential)) + { + auto primalType = as(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); + } + } + else { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. - // - auto primalType = as(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + if (!pair.primal->findDecoration() + && !as(pair.differential)) + { + auto mixedType = as(pair.primal->getDataType()); + builder->markInstAsMixedDifferential(pair.primal, mixedType); + } } break; -- cgit v1.2.3