diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-01-10 03:16:24 +0530 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-09 13:46:24 -0800 |
| commit | 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch) | |
| tree | 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-pairs.cpp | |
| parent | 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff) | |
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-pairs.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 95 |
1 files changed, 76 insertions, 19 deletions
diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index a49a2f762..c732263f0 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -13,7 +13,6 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerPairType(IRBuilder* builder, IRType* pairType) { - builder->setInsertBefore(pairType); auto loweredPairType = pairBuilder->lowerDiffPairType(builder, pairType); return loweredPairType; } @@ -22,26 +21,81 @@ struct DiffPairLoweringPass : InstPassBase { if (auto makePairInst = as<IRMakeDifferentialPairBase>(inst)) { - bool isTrivial = false; auto pairType = as<IRDifferentialPairTypeBase>(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType)) + builder->setInsertBefore(makePairInst); + if (auto loweredPairType = (IRType*)lowerPairType(builder, pairType)) { - builder->setInsertBefore(makePairInst); - IRInst* result = nullptr; - if (isTrivial) + if (isRuntimeType(pairType->getValueType())) { - result = makePairInst->getPrimalValue(); + auto result = pairBuilder->emitExistentialMakePair( + builder, + loweredPairType, + makePairInst->getPrimalValue(), + makePairInst->getDifferentialValue()); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; + } + else if (auto typePack = as<IRTypePack>(pairType->getValueType())) + { + // TODO: Do we need to flatten the packs here? + + // If the type is a type pack, then the value must be in + // MakePair(MakeValuePack(p_0, p_1, ...), MakeValuePack(d_0, d_1, ...)) form + // Convert it to MakeValuePack(MakePair(p_0, d_0), MakePair(p_1, d_1), ...) + // and lower each MakePair. + // + + // Primal pack + auto primalValue = as<IRMakeValuePack>(makePairInst->getPrimalValue()); + SLANG_ASSERT(primalValue); + + // Differential pack + auto diffValue = as<IRMakeValuePack>(makePairInst->getDifferentialValue()); + SLANG_ASSERT(diffValue); + + // Expect the lowered pair type to be a type pack of pair types. + SLANG_ASSERT(as<IRTypePack>(loweredPairType)); + + List<IRInst*> newValues; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto primalElement = primalValue->getOperand(i); + auto diffElement = diffValue->getOperand(i); + + auto loweredElementPairType = (IRType*)loweredPairType->getOperand(i); + + IRInst* operands[] = {primalElement, diffElement}; + + auto loweredMakePair = + builder->emitMakeStruct((IRType*)loweredElementPairType, 2, operands); + + newValues.add(loweredMakePair); + } + + auto newPack = builder->emitMakeValuePack( + loweredPairType, + newValues.getCount(), + newValues.getBuffer()); + + makePairInst->replaceUsesWith(newPack); + makePairInst->removeAndDeallocate(); + return newPack; } else { + IRInst* result = nullptr; + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue()}; result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; } - makePairInst->replaceUsesWith(result); - makePairInst->removeAndDeallocate(); - return result; } } @@ -58,12 +112,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getDiffInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; - diffFieldExtract = - pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase()); + diffFieldExtract = pairBuilder->emitDiffFieldAccess( + builder, + (IRType*)loweredType, + getDiffInst->getBase()); getDiffInst->replaceUsesWith(diffFieldExtract); getDiffInst->removeAndDeallocate(); return diffFieldExtract; @@ -77,13 +133,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getPrimalInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getPrimalInst); - IRInst* primalFieldExtract = nullptr; - primalFieldExtract = - pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + primalFieldExtract = pairBuilder->emitPrimalFieldAccess( + builder, + (IRType*)loweredType, + getPrimalInst->getBase()); getPrimalInst->replaceUsesWith(primalFieldExtract); getPrimalInst->removeAndDeallocate(); return primalFieldExtract; |
