diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-09 17:40:20 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-09 17:40:20 -0500 |
| commit | df02f3f50f977112ca1fbb148cd48ee41d560f41 (patch) | |
| tree | 7732e8fec9f33aff9666b3710c7adb899788c4be /source/slang/slang-ir-autodiff-transcriber-base.cpp | |
| parent | d911e1bed9572664b1d0554feb3c7d1a2a880518 (diff) | |
Reverse-mode Loop Support (#2635)
* Full loop support now working. MaxItersAttr in progress
* Lookup table updates?
* Fixed the max iters decoration
* Minox fixes & remove superfluous code
* fixup warnings
* Revert "Lookup table updates?"
This reverts commit 7d9b0793fb5239f31d1155776e846dcf1892d8d9.
* Update 07-autodiff.md
* Change maxiters to MaxIters
* Added asserts
* Update 07-autodiff.md
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 38 |
1 files changed, 29 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 31a3072c0..10a734d65 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -901,6 +901,15 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + + + if (pair.primal != pair.differential && + !pair.primal->findDecoration<IRAutodiffInstDecoration>() && + !as<IRConstant>(pair.primal)) + { + builder->markInstAsPrimal(pair.primal); + } + if (pair.differential) { switch (pair.differential->getOp()) @@ -920,16 +929,27 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration (if it doesn't have one) - if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && - !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>() && - !as<IRConstant>(pair.differential)) + // Automatically tag the primal and differential results + // if they haven't already been handled by the + // code. + // + if (pair.primal != pair.differential) + { + if (!pair.differential->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto primalType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsDifferential(pair.differential, primalType); + } + } + else { - // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential - // instead. - // - auto primalType = as<IRType>(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() + && !as<IRConstant>(pair.differential)) + { + auto mixedType = as<IRType>(pair.primal->getDataType()); + builder->markInstAsMixedDifferential(pair.primal, mixedType); + } } break; |
