From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-pairs.cpp | 95 +++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 19 deletions(-) (limited to 'source/slang/slang-ir-autodiff-pairs.cpp') diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index a49a2f762..c732263f0 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -13,7 +13,6 @@ struct DiffPairLoweringPass : InstPassBase IRInst* lowerPairType(IRBuilder* builder, IRType* pairType) { - builder->setInsertBefore(pairType); auto loweredPairType = pairBuilder->lowerDiffPairType(builder, pairType); return loweredPairType; } @@ -22,26 +21,81 @@ struct DiffPairLoweringPass : InstPassBase { if (auto makePairInst = as(inst)) { - bool isTrivial = false; auto pairType = as(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType)) + builder->setInsertBefore(makePairInst); + if (auto loweredPairType = (IRType*)lowerPairType(builder, pairType)) { - builder->setInsertBefore(makePairInst); - IRInst* result = nullptr; - if (isTrivial) + if (isRuntimeType(pairType->getValueType())) { - result = makePairInst->getPrimalValue(); + auto result = pairBuilder->emitExistentialMakePair( + builder, + loweredPairType, + makePairInst->getPrimalValue(), + makePairInst->getDifferentialValue()); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; + } + else if (auto typePack = as(pairType->getValueType())) + { + // TODO: Do we need to flatten the packs here? + + // If the type is a type pack, then the value must be in + // MakePair(MakeValuePack(p_0, p_1, ...), MakeValuePack(d_0, d_1, ...)) form + // Convert it to MakeValuePack(MakePair(p_0, d_0), MakePair(p_1, d_1), ...) + // and lower each MakePair. + // + + // Primal pack + auto primalValue = as(makePairInst->getPrimalValue()); + SLANG_ASSERT(primalValue); + + // Differential pack + auto diffValue = as(makePairInst->getDifferentialValue()); + SLANG_ASSERT(diffValue); + + // Expect the lowered pair type to be a type pack of pair types. + SLANG_ASSERT(as(loweredPairType)); + + List newValues; + for (UInt i = 0; i < typePack->getOperandCount(); i++) + { + auto primalElement = primalValue->getOperand(i); + auto diffElement = diffValue->getOperand(i); + + auto loweredElementPairType = (IRType*)loweredPairType->getOperand(i); + + IRInst* operands[] = {primalElement, diffElement}; + + auto loweredMakePair = + builder->emitMakeStruct((IRType*)loweredElementPairType, 2, operands); + + newValues.add(loweredMakePair); + } + + auto newPack = builder->emitMakeValuePack( + loweredPairType, + newValues.getCount(), + newValues.getBuffer()); + + makePairInst->replaceUsesWith(newPack); + makePairInst->removeAndDeallocate(); + return newPack; } else { + IRInst* result = nullptr; + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue()}; result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + + makePairInst->replaceUsesWith(result); + makePairInst->removeAndDeallocate(); + return result; } - makePairInst->replaceUsesWith(result); - makePairInst->removeAndDeallocate(); - return result; } } @@ -58,12 +112,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getDiffInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; - diffFieldExtract = - pairBuilder->emitDiffFieldAccess(builder, getDiffInst->getBase()); + diffFieldExtract = pairBuilder->emitDiffFieldAccess( + builder, + (IRType*)loweredType, + getDiffInst->getBase()); getDiffInst->replaceUsesWith(diffFieldExtract); getDiffInst->removeAndDeallocate(); return diffFieldExtract; @@ -77,13 +133,14 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType)) + builder->setInsertBefore(getPrimalInst); + if (auto loweredType = lowerPairType(builder, pairType)) { - builder->setInsertBefore(getPrimalInst); - IRInst* primalFieldExtract = nullptr; - primalFieldExtract = - pairBuilder->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + primalFieldExtract = pairBuilder->emitPrimalFieldAccess( + builder, + (IRType*)loweredType, + getPrimalInst->getBase()); getPrimalInst->replaceUsesWith(primalFieldExtract); getPrimalInst->removeAndDeallocate(); return primalFieldExtract; -- cgit v1.2.3