summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-pairs.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-13 10:57:28 -0700
committerGitHub <noreply@github.com>2023-03-13 10:57:28 -0700
commita911ca6e06ce41e403b80fe6054162393491c8ac (patch)
tree6c8d56a3060b1887e7fd3126fe54a1241160eddd /source/slang/slang-ir-autodiff-pairs.cpp
parent3fea56ef77a33273bf5af6f432163b30c0a0e1dc (diff)
Support high order diff pattern: `bwd_diff(fwd_diff(f))`. (#2695)
* Support high order diff pattern: `bwd_diff(fwd_diff(f))`. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-pairs.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp123
1 files changed, 114 insertions, 9 deletions
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index 9d761764c..7b16c0213 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -24,10 +24,10 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
- if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
+ if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst))
{
bool isTrivial = false;
- auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
+ auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType());
if (auto loweredPairType = lowerPairType(builder, pairType))
{
builder->setInsertBefore(makePairInst);
@@ -52,7 +52,7 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
{
- if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferentialBase>(inst))
{
auto pairType = getDiffInst->getBase()->getDataType();
if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
@@ -70,7 +70,7 @@ struct DiffPairLoweringPass : InstPassBase
return diffFieldExtract;
}
}
- else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimalBase>(inst))
{
auto pairType = getPrimalInst->getBase()->getDataType();
if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
@@ -106,10 +106,12 @@ struct DiffPairLoweringPass : InstPassBase
{
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferentialUserCode:
+ case kIROp_DifferentialPairGetPrimalUserCode:
lowerPairAccess(builder, inst);
break;
- case kIROp_MakeDifferentialPair:
+ case kIROp_MakeDifferentialPairUserCode:
lowerMakePair(builder, inst);
break;
@@ -119,12 +121,15 @@ struct DiffPairLoweringPass : InstPassBase
});
OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ processAllInsts([&](IRInst* inst)
{
- if (auto loweredType = lowerPairType(builder, inst))
+ if (auto pairType = as<IRDifferentialPairTypeBase>(inst))
{
- pendingReplacements.Add(inst, loweredType);
- modified = true;
+ if (auto loweredType = lowerPairType(builder, pairType))
+ {
+ pendingReplacements.Add(pairType, loweredType);
+ modified = true;
+ }
}
});
for (auto replacement : pendingReplacements)
@@ -158,4 +163,104 @@ bool processPairTypes(AutoDiffSharedContext* context)
return pairLoweringPass.processModule();
}
+struct DifferentialPairUserCodeTranscribePass : public InstPassBase
+{
+ DifferentialPairUserCodeTranscribePass(IRModule* module)
+ :InstPassBase(module)
+ {}
+
+ IRInst* rewritePairType(IRBuilder* builder, IRType* pairType)
+ {
+ builder->setInsertBefore(pairType);
+ auto originalPairType = as<IRDifferentialPairType>(pairType);
+ return builder->getDifferentialPairUserCodeType(originalPairType->getValueType(), originalPairType->getWitness());
+ }
+
+ IRInst* rewriteMakePair(IRBuilder* builder, IRMakeDifferentialPair* inst)
+ {
+ auto pairType = as<IRDifferentialPairType>(inst->getFullType());
+ builder->setInsertBefore(inst);
+ auto newInst = builder->emitMakeDifferentialPairUserCode(
+ (IRType*)pairType, inst->getPrimalValue(), inst->getDifferentialValue());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ return newInst;
+ }
+
+ IRInst* rewritePairAccess(IRBuilder* builder, IRInst* inst)
+ {
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ {
+ builder->setInsertBefore(inst);
+
+ auto newInst = builder->emitDifferentialPairGetDifferentialUserCode(
+ (IRType*)inst->getFullType(), getDiffInst->getBase());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ }
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ {
+ builder->setInsertBefore(inst);
+ auto newInst = builder->emitDifferentialPairGetPrimalUserCode(getPrimalInst->getBase());
+ inst->replaceUsesWith(newInst);
+ inst->removeAndDeallocate();
+ }
+ return inst;
+ }
+
+ bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren)
+ {
+ SLANG_UNUSED(instWithChildren);
+
+ bool modified = false;
+
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ rewritePairAccess(builder, inst);
+ break;
+
+ case kIROp_MakeDifferentialPair:
+ rewriteMakePair(builder, as<IRMakeDifferentialPair>(inst));
+ break;
+
+ default:
+ break;
+ }
+ });
+
+ OrderedDictionary<IRInst*, IRInst*> pendingReplacements;
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = rewritePairType(builder, inst))
+ {
+ pendingReplacements.Add(inst, loweredType);
+ modified = true;
+ }
+ });
+ for (auto replacement : pendingReplacements)
+ {
+ replacement.Key->replaceUsesWith(replacement.Value);
+ replacement.Key->removeAndDeallocate();
+ }
+
+ return modified;
+ }
+
+ bool processModule()
+ {
+ IRBuilder builder(module);
+ return processInstWithChildren(&builder, module->getModuleInst());
+ }
+};
+
+void rewriteDifferentialPairToUserCode(IRModule* module)
+{
+ DifferentialPairUserCodeTranscribePass pairRewritePass(module);
+ pairRewritePass.processModule();
+}
+
}