diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 27 |
2 files changed, 30 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 2bc67e561..552ac762c 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1116,9 +1116,12 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst } else { - if (!pair.primal->findDecoration<IRAutodiffInstDecoration>() - && !as<IRConstant>(pair.differential)) + if (!pair.primal->findDecoration<IRAutodiffInstDecoration>()) { + if (as<IRConstant>(pair.differential)) + break; + if (as<IRType>(pair.differential)) + break; auto mixedType = (IRType*)(pair.primal->getDataType()); builder->markInstAsMixedDifferential(pair.primal, mixedType); } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 653589933..bbdb01290 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -950,8 +950,25 @@ struct DiffTransposePass // Slang doesn't support function values. So if we see a func-typed inst // it's proabably a reference to a function. // - if (as<IRFuncType>(child->getDataType())) + switch (child->getOp()) + { + /* + TODO: need a better way to move specialize, lookupwitness, extractExistentialType/Value/Witness + insts to a proper location that dominates all their use sites. Create copies of these insts + when necessary. + case kIROp_Specialize: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: + */ + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: typeInsts.add(child); + break; + } } for (auto inst : typeInsts) @@ -959,7 +976,6 @@ struct DiffTransposePass inst->insertAtEnd(revBlock); } - // Then, go backwards through the regular instructions, and transpose them into the new // rev block. // Note the 'reverse' traversal here. @@ -2221,6 +2237,10 @@ struct DiffTransposePass case kIROp_ifElse: case kIROp_loop: case kIROp_Switch: + case kIROp_LookupWitness: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_ExtractExistentialWitnessTable: { // Ignore. transposeBlock() should take care of adding the // appropriate branch instruction. @@ -3474,6 +3494,9 @@ struct DiffTransposePass List<IRUse*> primalUsesToHoist; Dictionary<IRStore*, IRBlock*> mapStoreToDefBlock; + + IRCloneEnv typeInstCloneEnv = {}; + }; |
