diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-13 10:57:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-13 10:57:28 -0700 |
| commit | a911ca6e06ce41e403b80fe6054162393491c8ac (patch) | |
| tree | 6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir-autodiff-pairs.cpp | |
| parent | 3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff) | |
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-pairs.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 123 |
1 files changed, 114 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 9d761764c..7b16c0213 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -24,10 +24,10 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { - if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) + if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst)) { bool isTrivial = false; - auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); + auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType()); if (auto loweredPairType = lowerPairType(builder, pairType)) { builder->setInsertBefore(makePairInst); @@ -52,7 +52,7 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst) { - if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) + if (auto getDiffInst = as<IRDifferentialPairGetDifferentialBase>(inst)) { auto pairType = getDiffInst->getBase()->getDataType(); if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) @@ -70,7 +70,7 @@ struct DiffPairLoweringPass : InstPassBase return diffFieldExtract; } } - else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) + else if (auto getPrimalInst = as<IRDifferentialPairGetPrimalBase>(inst)) { auto pairType = getPrimalInst->getBase()->getDataType(); if (auto pairPtrType = as<IRPtrTypeBase>(pairType)) @@ -106,10 +106,12 @@ struct DiffPairLoweringPass : InstPassBase { case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferentialUserCode: + case kIROp_DifferentialPairGetPrimalUserCode: lowerPairAccess(builder, inst); break; - case kIROp_MakeDifferentialPair: + case kIROp_MakeDifferentialPairUserCode: lowerMakePair(builder, inst); break; @@ -119,12 +121,15 @@ struct DiffPairLoweringPass : InstPassBase }); OrderedDictionary<IRInst*, IRInst*> pendingReplacements; - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + processAllInsts([&](IRInst* inst) { - if (auto loweredType = lowerPairType(builder, inst)) + if (auto pairType = as<IRDifferentialPairTypeBase>(inst)) { - pendingReplacements.Add(inst, loweredType); - modified = true; + if (auto loweredType = lowerPairType(builder, pairType)) + { + pendingReplacements.Add(pairType, loweredType); + modified = true; + } } }); for (auto replacement : pendingReplacements) @@ -158,4 +163,104 @@ bool processPairTypes(AutoDiffSharedContext* context) return pairLoweringPass.processModule(); } +struct DifferentialPairUserCodeTranscribePass : public InstPassBase +{ + DifferentialPairUserCodeTranscribePass(IRModule* module) + :InstPassBase(module) + {} + + IRInst* rewritePairType(IRBuilder* builder, IRType* pairType) + { + builder->setInsertBefore(pairType); + auto originalPairType = as<IRDifferentialPairType>(pairType); + return builder->getDifferentialPairUserCodeType(originalPairType->getValueType(), originalPairType->getWitness()); + } + + IRInst* rewriteMakePair(IRBuilder* builder, IRMakeDifferentialPair* inst) + { + auto pairType = as<IRDifferentialPairType>(inst->getFullType()); + builder->setInsertBefore(inst); + auto newInst = builder->emitMakeDifferentialPairUserCode( + (IRType*)pairType, inst->getPrimalValue(), inst->getDifferentialValue()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + return newInst; + } + + IRInst* rewritePairAccess(IRBuilder* builder, IRInst* inst) + { + if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) + { + builder->setInsertBefore(inst); + + auto newInst = builder->emitDifferentialPairGetDifferentialUserCode( + (IRType*)inst->getFullType(), getDiffInst->getBase()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) + { + builder->setInsertBefore(inst); + auto newInst = builder->emitDifferentialPairGetPrimalUserCode(getPrimalInst->getBase()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + return inst; + } + + bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren) + { + SLANG_UNUSED(instWithChildren); + + bool modified = false; + + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + rewritePairAccess(builder, inst); + break; + + case kIROp_MakeDifferentialPair: + rewriteMakePair(builder, as<IRMakeDifferentialPair>(inst)); + break; + + default: + break; + } + }); + + OrderedDictionary<IRInst*, IRInst*> pendingReplacements; + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = rewritePairType(builder, inst)) + { + pendingReplacements.Add(inst, loweredType); + modified = true; + } + }); + for (auto replacement : pendingReplacements) + { + replacement.Key->replaceUsesWith(replacement.Value); + replacement.Key->removeAndDeallocate(); + } + + return modified; + } + + bool processModule() + { + IRBuilder builder(module); + return processInstWithChildren(&builder, module->getModuleInst()); + } +}; + +void rewriteDifferentialPairToUserCode(IRModule* module) +{ + DifferentialPairUserCodeTranscribePass pairRewritePass(module); + pairRewritePass.processModule(); +} + } |
