summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-pairs.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-01-10 03:16:24 +0530
committerGitHub <noreply@github.com>2025-01-09 13:46:24 -0800
commit87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-pairs.cpp
parent6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-pairs.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-pairs.cpp95
1 files changed, 76 insertions, 19 deletions
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp
index a49a2f762..c732263f0 100644
--- a/source/slang/slang-ir-autodiff-pairs.cpp
+++ b/source/slang/slang-ir-autodiff-pairs.cpp
@@ -13,7 +13,6 @@ struct DiffPairLoweringPass : InstPassBase
IRInst* lowerPairType(IRBuilder* builder, IRType* pairType)
{
- builder->setInsertBefore(pairType);
auto loweredPairType = pairBuilder->lowerDiffPairType(builder, pairType);
return loweredPairType;
}
@@ -22,26 +21,81 @@ struct DiffPairLoweringPass : InstPassBase
{
if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst))
{
- bool isTrivial = false;
auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType());
- if (auto loweredPairType = lowerPairType(builder, pairType))
+ builder->setInsertBefore(makePairInst);
+ if (auto loweredPairType = (IRType*)lowerPairType(builder, pairType))
{
- builder->setInsertBefore(makePairInst);
- IRInst* result = nullptr;
- if (isTrivial)
+ if (isRuntimeType(pairType->getValueType()))
{
- result = makePairInst->getPrimalValue();
+ auto result = pairBuilder->emitExistentialMakePair(
+ builder,
+ loweredPairType,
+ makePairInst->getPrimalValue(),
+ makePairInst->getDifferentialValue());
+
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
+ }
+ else if (auto typePack = as<IRTypePack>(pairType->getValueType()))
+ {
+ // TODO: Do we need to flatten the packs here?
+
+ // If the type is a type pack, then the value must be in
+ // MakePair(MakeValuePack(p_0, p_1, ...), MakeValuePack(d_0, d_1, ...)) form
+ // Convert it to MakeValuePack(MakePair(p_0, d_0), MakePair(p_1, d_1), ...)
+ // and lower each MakePair.
+ //
+
+ // Primal pack
+ auto primalValue = as<IRMakeValuePack>(makePairInst->getPrimalValue());
+ SLANG_ASSERT(primalValue);
+
+ // Differential pack
+ auto diffValue = as<IRMakeValuePack>(makePairInst->getDifferentialValue());
+ SLANG_ASSERT(diffValue);
+
+ // Expect the lowered pair type to be a type pack of pair types.
+ SLANG_ASSERT(as<IRTypePack>(loweredPairType));
+
+ List<IRInst*> newValues;
+ for (UInt i = 0; i < typePack->getOperandCount(); i++)
+ {
+ auto primalElement = primalValue->getOperand(i);
+ auto diffElement = diffValue->getOperand(i);
+
+ auto loweredElementPairType = (IRType*)loweredPairType->getOperand(i);
+
+ IRInst* operands[] = {primalElement, diffElement};
+
+ auto loweredMakePair =
+ builder->emitMakeStruct((IRType*)loweredElementPairType, 2, operands);
+
+ newValues.add(loweredMakePair);
+ }
+
+ auto newPack = builder->emitMakeValuePack(
+ loweredPairType,
+ newValues.getCount(),
+ newValues.getBuffer());
+
+ makePairInst->replaceUsesWith(newPack);
+ makePairInst->removeAndDeallocate();
+ return newPack;
}
else
{
+ IRInst* result = nullptr;
+
IRInst* operands[2] = {
makePairInst->getPrimalValue(),
makePairInst->getDifferentialValue()};
result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+
+ makePairInst->replaceUsesWith(result);
+ makePairInst->removeAndDeallocate();
+ return result;
}
- makePairInst->replaceUsesWith(result);
- makePairInst->removeAndDeallocate();
- return result;
}
}
@@ -58,12 +112,14 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType))
+ builder->setInsertBefore(getDiffInst);
+ if (auto loweredType = lowerPairType(builder, pairType))
{
- builder->setInsertBefore(getDiffInst);
IRInst* diffFieldExtract = nullptr;
- diffFieldExtract =
- pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ diffFieldExtract = pairBuilder->emitDiffFieldAccess(
+ builder,
+ (IRType*)loweredType,
+ getDiffInst->getBase());
getDiffInst->replaceUsesWith(diffFieldExtract);
getDiffInst->removeAndDeallocate();
return diffFieldExtract;
@@ -77,13 +133,14 @@ struct DiffPairLoweringPass : InstPassBase
pairType = pairPtrType->getValueType();
}
- if (lowerPairType(builder, pairType))
+ builder->setInsertBefore(getPrimalInst);
+ if (auto loweredType = lowerPairType(builder, pairType))
{
- builder->setInsertBefore(getPrimalInst);
-
IRInst* primalFieldExtract = nullptr;
- primalFieldExtract =
- pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ primalFieldExtract = pairBuilder->emitPrimalFieldAccess(
+ builder,
+ (IRType*)loweredType,
+ getPrimalInst->getBase());
getPrimalInst->replaceUsesWith(primalFieldExtract);
getPrimalInst->removeAndDeallocate();
return primalFieldExtract;