From 6e4eae1050ab9282b460a33a013652c387c1e585 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 23 Mar 2023 17:16:32 -0700 Subject: Hack handling of primal insts that has a function type. (#2728) * Update diff-bwd material test * Minor update * Hack handling of primal insts that has a function type. --------- Co-authored-by: winmad Co-authored-by: Yong He --- .../slang/slang-ir-autodiff-transcriber-base.cpp | 7 ++++-- source/slang/slang-ir-autodiff-transpose.h | 27 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 2bc67e561..552ac762c 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1116,9 +1116,12 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst } else { - if (!pair.primal->findDecoration() - && !as(pair.differential)) + if (!pair.primal->findDecoration()) { + if (as(pair.differential)) + break; + if (as(pair.differential)) + break; auto mixedType = (IRType*)(pair.primal->getDataType()); builder->markInstAsMixedDifferential(pair.primal, mixedType); } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 653589933..bbdb01290 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -950,8 +950,25 @@ struct DiffTransposePass // Slang doesn't support function values. So if we see a func-typed inst // it's proabably a reference to a function. // - if (as(child->getDataType())) + switch (child->getOp()) + { + /* + TODO: need a better way to move specialize, lookupwitness, extractExistentialType/Value/Witness + insts to a proper location that dominates all their use sites. Create copies of these insts + when necessary. + case kIROp_Specialize: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + */ + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: typeInsts.add(child); + break; + } } for (auto inst : typeInsts) @@ -959,7 +976,6 @@ struct DiffTransposePass inst->insertAtEnd(revBlock); } - // Then, go backwards through the regular instructions, and transpose them into the new // rev block. // Note the 'reverse' traversal here. @@ -2221,6 +2237,10 @@ struct DiffTransposePass case kIROp_ifElse: case kIROp_loop: case kIROp_Switch: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. @@ -3474,6 +3494,9 @@ struct DiffTransposePass List primalUsesToHoist; Dictionary mapStoreToDefBlock; + + IRCloneEnv typeInstCloneEnv = {}; + }; -- cgit v1.2.3