diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-23 17:16:32 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-23 17:16:32 -0700 |
| commit | 6e4eae1050ab9282b460a33a013652c387c1e585 (patch) | |
| tree | 98510cd3d5cedc5b6d878d9e7d173c1a346ae3fe /source | |
| parent | 50e7d9797d9bf4b98a056d5df128c24dde6e78bd (diff) | |
Hack handling of primal insts that has a function type. (#2728)
* Update diff-bwd material test
* Minor update
* Hack handling of primal insts that has a function type.
---------
Co-authored-by: winmad <winmad.wlf@gmail.com>
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 27 |
2 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 2bc67e561..552ac762c 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1116,9 +1116,12 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst } else { - if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() - && !as<IRConstant>(pair.differential)) + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()) { + if (as<IRConstant>(pair.differential)) + break; + if (as<IRType>(pair.differential)) + break; auto mixedType = (IRType*)(pair.primal->getDataType()); builder->markInstAsMixedDifferential(pair.primal, mixedType); } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 653589933..bbdb01290 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -950,8 +950,25 @@ struct DiffTransposePass // Slang doesn't support function values. So if we see a func-typed inst // it's proabably a reference to a function. // - if (as<IRFuncType>(child->getDataType())) + switch (child->getOp()) + { + /* + TODO: need a better way to move specialize, lookupwitness, extractExistentialType/Value/Witness + insts to a proper location that dominates all their use sites. Create copies of these insts + when necessary. + case kIROp_Specialize: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + */ + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: typeInsts.add(child); + break; + } } for (auto inst : typeInsts) @@ -959,7 +976,6 @@ struct DiffTransposePass inst->insertAtEnd(revBlock); } - // Then, go backwards through the regular instructions, and transpose them into the new // rev block. // Note the 'reverse' traversal here. @@ -2221,6 +2237,10 @@ struct DiffTransposePass case kIROp_ifElse: case kIROp_loop: case kIROp_Switch: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. @@ -3474,6 +3494,9 @@ struct DiffTransposePass List<IRUse*> primalUsesToHoist; Dictionary<IRStore*, IRBlock*> mapStoreToDefBlock; + + IRCloneEnv typeInstCloneEnv = {}; + }; |
