diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-03-15 22:26:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-15 19:26:58 -0700 |
| commit | 71efd949fa5276e2464416fcf237f8fd2c486281 (patch) | |
| tree | a5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 | |
| parent | 38e62199cc75ce34608491c8dd299eb330bde518 (diff) | |
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
19 files changed, 1967 insertions, 753 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index e97d6a2b1..879da8cdf 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -346,7 +346,9 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-cfg-norm.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-fwd.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-primal-hoist.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-region.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transcriber-base.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-transpose.h" />
@@ -534,6 +536,8 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-cfg-norm.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-fwd.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-primal-hoist.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-region.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-transcriber-base.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-unzip.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 64267db4b..2b223b78d 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -144,9 +144,15 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-pairs.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-primal-hoist.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-propagate.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-region.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-autodiff-rev.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -704,6 +710,12 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-pairs.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-primal-hoist.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-region.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-autodiff-rev.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 7057a5835..247c3ddde 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -308,6 +308,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
         // differentiated, don't differentiate the inst
         //
         auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType());
+        // TODO: Need to update this to generate derivatives on a per-key basis
         if (auto diffConstructType = differentiateType(builder, primalConstructType))
         {
             UCount operandCount = origConstruct->getOperandCount();
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
new file mode 100644
index 000000000..793a8ff07
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -0,0 +1,674 @@
+#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang
+{
+
+// Returns true if `operand` appears anywhere in the operand list of `inst`.
+bool containsOperand(IRInst* inst, IRInst* operand)
+{
+    for (UIndex ii = 0; ii < inst->getOperandCount(); ii++)
+        if (inst->getOperand(ii) == operand)
+            return true;
+
+    return false;
+}
+
+// Walks every primal use reachable from the differential blocks of `func`,
+// asks the policy (`classify`) whether the used value should be stored,
+// recomputed or inverted, and records the answer in a CheckpointSetInfo.
+// The collected sets are then applied with `applyCheckpointSet`.
+//
+// Uses are processed with a work-list: classifying a value as 'recompute'
+// pushes the uses that the recomputation itself depends on.
+//
+RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* splitInfo)
+{
+    RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
+
+    RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+    List<IRUse*> workList;
+    HashSet<IRUse*> processedUses;
+
+    HashSet<IRUse*> usesToReplace;
+
+    auto addPrimalOperandsToWorkList = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration<IRDifferentialInstDecoration>() &&
+                !as<IRFunc>(operand->get()) &&
+                !as<IRBlock>(operand->get()) &&
+                !(as<IRModuleInst>(operand->get()->getParent())) &&
+                !getBlock(operand->get())->findDecoration<IRDifferentialInstDecoration>())
+                workList.add(operand);
+        }
+
+        // Is the type itself computed within our function?
+        // If so, we'll need to consider that too (this is for existential types, specialize insts, etc)
+        // TODO: We might not really need to query the checkpointing algorithm for these
+        // since they _have_ to be classified as 'recompute'
+        //
+        if (inst->getDataType() && (getParentFunc(inst->getDataType()) == func))
+        {
+            if (!getBlock(inst->getDataType())->findDecoration<IRDifferentialInstDecoration>())
+                workList.add(&inst->typeUse);
+        }
+    };
+
+    // Populate recompute/store/invert sets with insts, by applying the policy
+    // to them.
+    //
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (!block->findDecoration<IRDifferentialInstDecoration>())
+            continue;
+
+        for (auto child : block->getChildren())
+        {
+            // Special case: Ignore the primals used to construct the return pair.
+            if (as<IRMakeDifferentialPair>(child) &&
+                as<IRReturn>(child->firstUse->getUser()))
+            {
+                // quick check
+                SLANG_RELEASE_ASSERT(child->firstUse->nextUse == nullptr);
+                continue;
+            }
+
+            addPrimalOperandsToWorkList(child);
+
+            // We'll be conservative with the decorations we consider as differential uses
+            // of a primal inst, in order to avoid weird behaviour with some decorations
+            //
+            for (auto decoration : child->getDecorations())
+            {
+                if (auto primalCtxDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration))
+                    workList.add(&primalCtxDecoration->primalContextVar);
+                else if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
+                    workList.add(&loopExitDecoration->exitVal);
+            }
+        }
+
+        addPrimalOperandsToWorkList(block->getTerminator());
+    }
+
+    while (workList.getCount() > 0)
+    {
+        auto use = workList.getLast();
+        workList.removeLast();
+
+        if (processedUses.Contains(use))
+            continue;
+
+        processedUses.Add(use);
+
+        HoistResult result = this->classify(use);
+
+        if (result.mode == HoistResult::Mode::Store)
+        {
+            SLANG_ASSERT(!checkpointInfo->recomputeSet.Contains(result.instToStore));
+            checkpointInfo->storeSet.Add(result.instToStore);
+        }
+        else if (result.mode == HoistResult::Mode::Recompute)
+        {
+            SLANG_ASSERT(!checkpointInfo->storeSet.Contains(result.instToRecompute));
+            checkpointInfo->recomputeSet.Add(result.instToRecompute);
+
+            if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+                usesToReplace.Add(use);
+
+            if (auto param = as<IRParam>(result.instToRecompute))
+            {
+                // Add in the branch-args of every predecessor block.
+                auto paramBlock = as<IRBlock>(param->getParent());
+                UIndex paramIndex = 0;
+                for (auto _param : paramBlock->getParams())
+                {
+                    if (_param == param) break;
+                    paramIndex ++;
+                }
+
+                for (auto predecessor : paramBlock->getPredecessors())
+                {
+                    // If we hit this, the checkpoint policy is trying to recompute
+                    // values across a loop region boundary (we don't currently support this,
+                    // and in general this is quite inefficient in both compute & memory)
+                    //
+                    SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor));
+
+                    auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator());
+                    SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
+
+                    workList.add(&branchInst->getOperands()[paramIndex]);
+                }
+            }
+            else
+            {
+                if (auto var = as<IRVar>(result.instToRecompute))
+                {
+                    // A recomputed var depends on the value stored into it, so
+                    // that stored value must be classified as well. Only queue
+                    // the use when one was actually found: a null entry would
+                    // be dereferenced by classify().
+                    IRUse* storeUse = findUniqueStoredVal(var);
+                    if (storeUse)
+                        workList.add(storeUse);
+                }
+                else
+                {
+                    addPrimalOperandsToWorkList(result.instToRecompute);
+                }
+            }
+        }
+        else if (result.mode == HoistResult::Mode::Invert)
+        {
+            auto instToInvert = result.inversionInfo.instToInvert;
+
+            SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser()));
+            SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser()));
+
+            if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+                usesToReplace.Add(use);
+
+            checkpointInfo->invertSet.Add(instToInvert);
+
+            if (checkpointInfo->invInfoMap.ContainsKey(instToInvert))
+            {
+                // The same inst was already scheduled for inversion: make sure
+                // the policy asked for the same required operands both times.
+                List<IRInst*> currOperands = checkpointInfo->invInfoMap[instToInvert].GetValue().requiredOperands;
+                for (Index ii = 0; ii < result.inversionInfo.requiredOperands.getCount(); ii++)
+                {
+                    SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands[ii] == currOperands[ii]);
+                }
+            }
+            else
+                checkpointInfo->invInfoMap[instToInvert] = result.inversionInfo;
+        }
+    }
+
+    return applyCheckpointSet(checkpointInfo, func, splitInfo, usesToReplace);
+}
+
+// Applies the checkpointing decision recorded in `checkpointInfo` to a single
+// primal `inst`:
+//  - stored insts are recorded in `hoistInfo->storeSet` and left in place,
+//  - recomputed insts are cloned at the builder's insertion point,
+//  - inverted insts have both the inst and its `instToInvert` cloned, and the
+//    inversion info is re-keyed onto the clone.
+//
+void applyToInst(
+    IRBuilder* builder,
+    CheckpointSetInfo* checkpointInfo,
+    HoistedPrimalsInfo* hoistInfo,
+    IROutOfOrderCloneContext* cloneCtx,
+    IRInst* inst)
+{
+    // Early-out..
+    if (checkpointInfo->storeSet.Contains(inst))
+    {
+        hoistInfo->storeSet.Add(inst);
+        return;
+    }
+
+    bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst);
+    if (isInstRecomputed)
+    {
+        if (as<IRParam>(inst))
+        {
+            // Can completely ignore first block parameters
+            if (getBlock(inst) != getBlock(inst)->getParent()->getFirstBlock())
+            {
+                // TODO: We would need to clone in the control-flow for each region (without nested loops)
+                // prior to this, and then hoist this parameter into the within-region block, otherwise
+                // this parameter will not be visible to transposed insts.
+                // This will also include adding an extra case to 'ensurePrimalAvailability': if both insts
+                // are within the _same_ indexed region, skip the indexed store/load and use a simple var.
+                //
+                SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported");
+            }
+        }
+        else
+        {
+            hoistInfo->recomputeSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+        }
+    }
+
+    bool isInstInverted = checkpointInfo->invertSet.Contains(inst);
+    if (isInstInverted)
+    {
+        InversionInfo info = checkpointInfo->invInfoMap[inst];
+        auto clonedInstToInvert = cloneCtx->cloneInstOutOfOrder(builder, info.instToInvert);
+
+        // Process operand set for the inverse inst.
+        // Operands that have a clone are redirected to it; the rest are kept as-is.
+        List<IRInst*> newOperands;
+        for (auto operand : info.requiredOperands)
+        {
+            if (cloneCtx->cloneEnv.mapOldValToNew.ContainsKey(operand))
+                newOperands.add(cloneCtx->cloneEnv.mapOldValToNew[operand]);
+            else
+                newOperands.add(operand);
+        }
+
+        info.requiredOperands = newOperands;
+
+        hoistInfo->invertInfoMap[clonedInstToInvert] = info;
+        hoistInfo->instsToInvert.Add(clonedInstToInvert);
+        hoistInfo->invertSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+    }
+}
+
+// Visits every non-differential block of `func` (except the parameter block)
+// and applies the checkpoint decision to each of its params and children,
+// inserting clones into the matching differential block from
+// `splitInfo->diffBlockMap`. `pendingUses` seeds the set of uses that must be
+// redirected to a clone once that clone exists.
+//
+RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
+    CheckpointSetInfo* checkpointInfo,
+    IRGlobalValueWithCode* func,
+    BlockSplitInfo* splitInfo,
+    HashSet<IRUse*> pendingUses)
+{
+    RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
+
+    RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
+
+    for (auto use : pendingUses)
+        cloneCtx->pendingUses.Add(use);
+
+    // Populate the clone context with all the primal uses that we may need to replace with
+    // cloned versions. That way any insts we clone into the diff block will automatically replace
+    // their uses.
+    //
+    auto addPrimalUsesToCloneContext = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration<IRDifferentialInstDecoration>())
+                cloneCtx->pendingUses.Add(operand);
+        }
+    };
+
+    // Go back over the insts and move/clone them accordingly.
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (block->findDecoration<IRDifferentialInstDecoration>())
+            continue;
+
+        auto diffBlock = as<IRBlock>(splitInfo->diffBlockMap[block]);
+
+        auto firstDiffInst = as<IRBlock>(splitInfo->diffBlockMap[block])->getFirstOrdinaryInst();
+
+        IRBuilder builder(func->getModule());
+
+        UIndex ii = 0;
+        for (auto param : block->getParams())
+        {
+            builder.setInsertBefore(diffBlock->getFirstOrdinaryInst());
+
+            // Apply checkpoint rule to the parameter itself.
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param);
+
+            // Copy primal branch-arg for predecessor blocks.
+            HashSet<IRBlock*> predecessorSet;
+            for (auto predecessor : block->getPredecessors())
+            {
+                if (predecessorSet.Contains(predecessor))
+                    continue;
+
+                predecessorSet.Add(predecessor);
+
+                // The extra phi argument has to be appended to the branch that
+                // jumps *into* the diff block, i.e. the terminator of the
+                // predecessor's differential counterpart (not of `block`'s own
+                // diff block). The map is only consulted when an argument is
+                // actually added.
+                if (checkpointInfo->recomputeSet.Contains(param))
+                    addPhiOutputArg(&builder,
+                        as<IRBlock>(splitInfo->diffBlockMap[predecessor]),
+                        as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+
+                if (checkpointInfo->invertSet.Contains(param))
+                    addPhiOutputArg(&builder,
+                        as<IRBlock>(splitInfo->diffBlockMap[predecessor]),
+                        as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+            }
+
+            ii++;
+        }
+
+        for (auto child : block->getChildren())
+        {
+            builder.setInsertBefore(firstDiffInst);
+
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
+        }
+    }
+
+    return hoistInfo;
+}
+
+// Wraps `storageType` in one array type per entry of `defBlockIndices`, each
+// with `maxIters + 1` elements. Every index must have a statically known
+// iteration bound.
+//
+IRType* getTypeForLocalStorage(
+    IRBuilder* builder,
+    IRType* storageType,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    for (auto index : defBlockIndices)
+    {
+        SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static);
+        SLANG_ASSERT(index->maxIters >= 0);
+
+        storageType = builder->getArrayType(
+            storageType,
+            builder->getIntValue(
+                builder->getUIntType(),
+                index->maxIters + 1));
+    }
+
+    return storageType;
+}
+
+// Emits a default-initialized local variable at the top of `varBlock` whose
+// type is `baseType` wrapped by `getTypeForLocalStorage`. `baseType` must not
+// be a pointer type.
+//
+IRVar* emitIndexedLocalVar(
+    IRBlock* varBlock,
+    IRType* baseType,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
+
+    IRBuilder varBuilder(varBlock->getModule());
+    varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());
+
+    IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices);
+
+    auto var = varBuilder.emitVar(varType);
+    varBuilder.emitStore(var, varBuilder.emitDefaultConstruct(varType));
+
+    return var;
+}
+
+// Emits the element address inside `localVar` for the current iteration,
+// indexing each array level with the corresponding `primalCountParam`.
+//
+IRInst* emitIndexedStoreAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    IRInst* storeAddr = localVar;
+    IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
+
+    for (auto index : defBlockIndices)
+    {
+        currType = as<IRArrayType>(currType)->getElementType();
+
+        storeAddr = builder->emitElementAddress(
+            builder->getPtrType(currType),
+            storeAddr,
+            index->primalCountParam);
+    }
+
+    return storeAddr;
+}
+
+
+// Emits the element address inside `localVar` from which a use located in a
+// block with `useBlockIndices` should read. Levels shared with the use block
+// are indexed by the differential counter; the others by the last primal
+// counter value minus one.
+//
+IRInst* emitIndexedLoadAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices,
+    List<IndexTrackingInfo*> useBlockIndices)
+{
+    IRInst* loadAddr = localVar;
+    IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
+
+    for (auto index : defBlockIndices)
+    {
+        currType = as<IRArrayType>(currType)->getElementType();
+        if (useBlockIndices.contains(index))
+        {
+            // If the use-block is under the same region, use the
+            // differential counter variable
+            //
+            auto diffCounterCurrValue = index->diffCountParam;
+
+            loadAddr = builder->emitElementAddress(
+                builder->getPtrType(currType),
+                loadAddr,
+                diffCounterCurrValue);
+        }
+        else
+        {
+            // If the use-block is outside this region, use the
+            // last available value (by indexing with primal counter minus 1)
+            //
+            auto primalCounterCurrValue = builder->emitLoad(index->primalCountLastVar);
+            auto primalCounterLastValue = builder->emitSub(
+                primalCounterCurrValue->getDataType(),
+                primalCounterCurrValue,
+                builder->getIntValue(builder->getIntType(), 1));
+
+            loadAddr = builder->emitElementAddress(
+                builder->getPtrType(currType),
+                loadAddr,
+                primalCounterLastValue);
+        }
+    }
+
+    return loadAddr;
+}
+
+// Creates an indexed local variable in `defaultVarBlock` and emits, at the
+// builder's current position, a store of `instToStore` into its slot for the
+// current iteration. Returns the new variable.
+//
+IRVar* storeIndexedValue(
+    IRBuilder* builder,
+    IRBlock* defaultVarBlock,
+    IRInst* instToStore,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
+
+    IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices);
+
+    builder->emitStore(addr, instToStore);
+
+    return localVar;
+}
+
+// Emits a load of the value previously written with `storeIndexedValue`.
+IRInst* loadIndexedValue(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices,
+    List<IndexTrackingInfo*> useBlockIndices)
+{
+    IRInst* addr = emitIndexedLoadAddressForVar(builder, localVar, defBlockIndices, useBlockIndices);
+
+    return builder->emitLoad(addr);
+}
+
+// True when both lists have the same length and identical entries.
+bool areIndicesEqual(
+    List<IndexTrackingInfo*> indicesA,
+    List<IndexTrackingInfo*> indicesB)
+{
+    if (indicesA.getCount() != indicesB.getCount())
+        return false;
+
+    for (Index ii = 0; ii < indicesA.getCount(); ii++)
+    {
+        if (indicesA[ii] != indicesB[ii])
+            return false;
+    }
+
+    return true;
+}
+
+// True when `indicesA` is a prefix of `indicesB` (element-wise comparison
+// over the first `indicesA.getCount()` entries).
+bool areIndicesSubsetOf(
+    List<IndexTrackingInfo*> indicesA,
+    List<IndexTrackingInfo*> indicesB)
+{
+    if (indicesA.getCount() > indicesB.getCount())
+        return false;
+
+    for (Index ii = 0; ii < indicesA.getCount(); ii++)
+    {
+        if (indicesA[ii] != indicesB[ii])
+            return false;
+    }
+
+    return true;
+}
+
+
+// True when `block` carries an IRDifferentialInstDecoration.
+bool isDifferentialBlock(IRBlock* block)
+{
+    return block->findDecoration<IRDifferentialInstDecoration>();
+}
+
+// For every inst in `hoistInfo->storeSet`, finds the uses in differential
+// blocks that cannot see the value directly and reroutes them through an
+// (optionally indexed) local variable created in the block that follows the
+// parameter block. On return `hoistInfo->storeSet` holds the processed set
+// (original insts that needed no work, and the new local variables).
+//
+// NOTE(review): `findUniqueStoredVal` results are dereferenced below before
+// the later `storeUse && ...` check; presumably a unique store always exists
+// for stored vars -- confirm.
+//
+RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
+    HoistedPrimalsInfo* hoistInfo,
+    IRGlobalValueWithCode* func,
+    Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo)
+{
+    RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+    IRBuilder builder(func->getModule());
+    IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
+
+    SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
+
+    HashSet<IRInst*> processedStoreSet;
+
+    // TODO: Also ensure availability of everything in the recompute set (for proper recompute support)
+    for (auto instToStore : hoistInfo->storeSet)
+    {
+        IRBlock* defBlock = nullptr;
+        if (as<IRPtrTypeBase>(instToStore->getDataType()))
+        {
+            // For a variable, the "definition" is where its value is stored.
+            auto varInst = as<IRVar>(instToStore);
+            auto storeUse = findUniqueStoredVal(varInst);
+
+            defBlock = getBlock(storeUse->getUser());
+        }
+        else
+            defBlock = getBlock(instToStore);
+
+        SLANG_RELEASE_ASSERT(defBlock);
+
+        List<IRUse*> outOfScopeUses;
+        for (auto use = instToStore->firstUse; use;)
+        {
+            auto nextUse = use->nextUse;
+
+            // Only consider uses in differential blocks.
+            // This method is not responsible for other blocks.
+            //
+            IRBlock* userBlock = getBlock(use->getUser());
+            if (userBlock->findDecoration<IRDifferentialInstDecoration>())
+            {
+                if (!domTree->dominates(defBlock, userBlock))
+                {
+                    outOfScopeUses.add(use);
+                }
+                else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+                {
+                    outOfScopeUses.add(use);
+                }
+                else if (indexedBlockInfo[defBlock].GetValue().getCount() > 0 &&
+                    !isDifferentialBlock(defBlock))
+                {
+                    outOfScopeUses.add(use);
+                }
+                else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
+                    !isDifferentialBlock(defBlock))
+                {
+                    outOfScopeUses.add(use);
+                }
+            }
+
+            use = nextUse;
+        }
+
+        if (outOfScopeUses.getCount() == 0)
+        {
+            processedStoreSet.Add(instToStore);
+            continue;
+        }
+
+        if (as<IRPtrTypeBase>(instToStore->getDataType()))
+        {
+
+            IRVar* varToStore = as<IRVar>(instToStore);
+            SLANG_RELEASE_ASSERT(varToStore);
+
+            auto storeUse = findUniqueStoredVal(varToStore);
+
+            List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
+
+            bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
+
+            // TODO: There's a slight hackiness here. (Ideally we might just want to emit
+            // additional vars when splitting a call)
+            //
+            if (!isIndexedStore && isDerivativeContextVar(varToStore))
+            {
+                varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
+                processedStoreSet.Add(varToStore);
+                continue;
+            }
+
+            setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser()));
+
+            IRVar* localVar = storeIndexedValue(
+                &builder,
+                defaultVarBlock,
+                builder.emitLoad(varToStore),
+                defBlockIndices);
+
+            for (auto use : outOfScopeUses)
+            {
+                setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+                List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+
+                IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
+                builder.replaceOperand(use, loadAddr);
+            }
+
+            processedStoreSet.Add(localVar);
+        }
+        else
+        {
+            setInsertAfterOrdinaryInst(&builder, instToStore);
+
+            List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
+            auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
+
+            for (auto use : outOfScopeUses)
+            {
+                setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+                List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+                builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
+            }
+
+            processedStoreSet.Add(localVar);
+        }
+    }
+
+    // Replace the old store set with the processed one.
+    hoistInfo->storeSet = processedStoreSet;
+
+    return hoistInfo;
+}
+
+void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode*)
+{
+    // Do nothing.. This is an (almost) always-store policy.
+    return;
+}
+
+// Stores the used value whenever its type can be stored and falls back to
+// recomputation otherwise. Vars whose value type is a specialization are
+// stored only if the type of every specialization argument can be stored;
+// all other vars are always stored.
+//
+HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+{
+    // Store all that we can.. by default, classify will only be called on relevant differential
+    // uses (or on uses in a 'recompute' inst)
+    //
+    if (auto var = as<IRVar>(use->get()))
+    {
+        if (auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType()))
+        {
+            for (UInt i = 0; i < spec->getArgCount(); i++)
+            {
+                if (!canTypeBeStored(spec->getArg(i)->getDataType()))
+                    return HoistResult::recompute(use->get());
+            }
+            return HoistResult::store(use->get());
+        }
+        else // if (canTypeBeStored(as<IRPtrTypeBase>(var->getDataType())->getValueType()));
+        {
+            return HoistResult::store(use->get());
+        }
+    }
+    else
+    {
+        if (canTypeBeStored(use->get()->getDataType()))
+            return HoistResult::store(use->get());
+        else
+            return HoistResult::recompute(use->get());
+    }
+}
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
new file mode 100644
index 000000000..dc85942f6
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -0,0 +1,264 @@
+// slang-ir-autodiff-primal-hoist.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-dominators.h"
+
+
+namespace Slang
+{
+    // Clone helper that tolerates cloning an inst before its operands or
+    // users have been cloned. Uses registered in `pendingUses` are patched to
+    // point at a clone as soon as that clone is created.
+    //
+    struct IROutOfOrderCloneContext : public RefObject
+    {
+        IRCloneEnv cloneEnv;
+        HashSet<IRUse*> pendingUses;
+
+        // Clones `inst` at the builder's insertion point and returns the clone.
+        IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst)
+        {
+            IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst);
+
+            // Any operand of the clone that still refers to the original value
+            // has not been cloned yet: remember the use so that it is replaced
+            // if/when that operand gets cloned later.
+            UInt operandCount = clonedInst->getOperandCount();
+            for (UInt ii = 0; ii < operandCount; ++ii)
+            {
+                auto oldOperand = inst->getOperand(ii);
+                auto newOperand = clonedInst->getOperand(ii);
+
+                if (oldOperand == newOperand)
+                    pendingUses.Add(&clonedInst->getOperands()[ii]);
+            }
+
+            // Redirect every pending use of the original inst to the clone.
+            // `nextUse` is read up-front because replaceOperand() unlinks `use`.
+            for (auto use = inst->firstUse; use;)
+            {
+                auto nextUse = use->nextUse;
+
+                if (pendingUses.Contains(use))
+                {
+                    pendingUses.Remove(use);
+                    builder->replaceOperand(use, clonedInst);
+                }
+
+                use = nextUse;
+            }
+
+            return clonedInst;
+        }
+    };
+
+    // Describes one inversion request: the inst to invert, the operands the
+    // inverse needs, and the insts the inversion is meant to produce.
+    //
+    struct InversionInfo
+    {
+        IRInst* instToInvert;
+        List<IRInst*> requiredOperands;
+        List<IRInst*> targetInsts;
+
+        InversionInfo(
+            IRInst* instToInvert,
+            List<IRInst*> requiredOperands,
+            List<IRInst*> targetInsts) :
+            instToInvert(instToInvert),
+            requiredOperands(requiredOperands),
+            targetInsts(targetInsts)
+        { }
+
+        InversionInfo() : instToInvert(nullptr)
+        { }
+
+        // Returns a copy of this info with every inst translated through
+        // `env->mapOldValToNew`.
+        // NOTE(review): entries without a mapping are dropped (and
+        // `instToInvert` stays null if it is unmapped) -- confirm that this is
+        // the intended behaviour rather than keeping the original inst.
+        //
+        InversionInfo applyMap(IRCloneEnv* env)
+        {
+            InversionInfo newInfo;
+            if (env->mapOldValToNew.ContainsKey(instToInvert))
+                newInfo.instToInvert = env->mapOldValToNew[instToInvert];
+
+            for (auto inst : requiredOperands)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newInfo.requiredOperands.add(env->mapOldValToNew[inst]);
+
+            for (auto inst : targetInsts)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newInfo.targetInsts.add(env->mapOldValToNew[inst]);
+
+            return newInfo;
+        }
+    };
+
+    // Result of applying a checkpoint set to a function: the insts that ended
+    // up stored / recomputed / inverted, plus per-inst inversion details.
+    //
+    struct HoistedPrimalsInfo : public RefObject
+    {
+        HashSet<IRInst*> storeSet;
+        HashSet<IRInst*> recomputeSet;
+        HashSet<IRInst*> invertSet;
+
+        HashSet<IRInst*> instsToInvert;
+
+        Dictionary<IRInst*, InversionInfo> invertInfoMap;
+
+        // Builds a new info object whose insts are translated through
+        // `env->mapOldValToNew`; insts without a mapping are left out.
+        RefPtr<HoistedPrimalsInfo> applyMap(IRCloneEnv* env)
+        {
+            RefPtr<HoistedPrimalsInfo> newPrimalsInfo = new HoistedPrimalsInfo();
+
+            for (auto inst : this->storeSet)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newPrimalsInfo->storeSet.Add(env->mapOldValToNew[inst]);
+
+            for (auto inst : this->recomputeSet)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newPrimalsInfo->recomputeSet.Add(env->mapOldValToNew[inst]);
+
+            for (auto inst : this->invertSet)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newPrimalsInfo->invertSet.Add(env->mapOldValToNew[inst]);
+
+            for (auto inst : this->instsToInvert)
+                if (env->mapOldValToNew.ContainsKey(inst))
+                    newPrimalsInfo->instsToInvert.Add(env->mapOldValToNew[inst]);
+
+            for (auto kvpair : this->invertInfoMap)
+                if (env->mapOldValToNew.ContainsKey(kvpair.Key))
+                    newPrimalsInfo->invertInfoMap[env->mapOldValToNew[kvpair.Key]] = kvpair.Value.applyMap(env);
+
+            return newPrimalsInfo;
+        }
+
+        // Adds every entry of `info` to this object. For `invertInfoMap`,
+        // entries from `info` overwrite existing entries with the same key.
+        void merge(HoistedPrimalsInfo* info)
+        {
+            for (auto inst : info->storeSet)
+                storeSet.Add(inst);
+
+            for (auto inst : info->recomputeSet)
+                recomputeSet.Add(inst);
+
+            for (auto inst : info->invertSet)
+                invertSet.Add(inst);
+
+            for (auto inst : info->instsToInvert)
+                instsToInvert.Add(inst);
+
+            for (auto kvpair : info->invertInfoMap)
+                invertInfoMap[kvpair.Key] = kvpair.Value;
+        }
+    };
+
+    // Decision returned by a checkpoint policy for a single use. Exactly one
+    // of `instToStore`, `instToRecompute` or `inversionInfo` is populated,
+    // depending on `mode`.
+    //
+    struct HoistResult
+    {
+        enum Mode
+        {
+            Store,
+            Recompute,
+            Invert,
+
+            None
+        };
+
+        Mode mode;
+
+        IRInst* instToStore = nullptr;
+        IRInst* instToRecompute = nullptr;
+        InversionInfo inversionInfo;
+
+        // Constructs a Store or Recompute result; any other mode is rejected.
+        HoistResult(Mode mode, IRInst* target) :
+            mode(mode)
+        {
+            switch (mode)
+            {
+            case Mode::Store:
+                instToStore = target;
+                break;
+            case Mode::Recompute:
+                instToRecompute = target;
+                break;
+            case Mode::Invert:
+                SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert");
+                break;
+            default:
+                SLANG_UNEXPECTED("Unhandled hoist mode");
+                break;
+            }
+        }
+
+        // Constructs an Invert result.
+        HoistResult(InversionInfo info) :
+            mode(Mode::Invert), inversionInfo(info)
+        { }
+
+        static HoistResult store(IRInst* inst)
+        {
+            return HoistResult(Mode::Store, inst);
+        }
+
+        static HoistResult recompute(IRInst* inst)
+        {
+            return HoistResult(Mode::Recompute, inst);
+        }
+
+        static HoistResult invert(InversionInfo inst)
+        {
+            return HoistResult(inst);
+        }
+    };
+
+
+    // Information on which insts are to be stored, recomputed
+    // and inverted within a single function.
+    // This data structure also holds a map of raw HoistResult
+    // objects to provide more information to later passes.
+    //
+    struct CheckpointSetInfo : public RefObject
+    {
+        HashSet<IRInst*> storeSet;
+        HashSet<IRInst*> recomputeSet;
+        HashSet<IRInst*> invertSet;
+
+        Dictionary<IRInst*, InversionInfo> invInfoMap;
+    };
+
+    struct BlockSplitInfo : public RefObject
+    {
+        // Maps primal to differential blocks from the unzip step.
+        Dictionary<IRBlock*, IRBlock*> diffBlockMap;
+    };
+
+    // Base class for checkpointing policies. A policy decides, use by use,
+    // whether a primal value needed by differential code is stored,
+    // recomputed or inverted.
+    //
+    class AutodiffCheckpointPolicyBase : public RefObject
+    {
+    public:
+
+        AutodiffCheckpointPolicyBase(IRModule* module) : module(module)
+        { }
+
+        // Runs the policy over `func` and applies the resulting checkpoint set.
+        RefPtr<HoistedPrimalsInfo> processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* info);
+
+        // Do pre-processing on the function (mainly for
+        // 'global' checkpointing methods that consider the entire
+        // function)
+        //
+        virtual void preparePolicy(IRGlobalValueWithCode* func) = 0;
+
+        // Classifies the value referenced by `diffBlockUse`.
+        virtual HoistResult classify(IRUse* diffBlockUse) = 0;
+
+    protected:
+
+        IRModule* module;
+    };
+
+    // Policy that stores whatever can be stored and recomputes the rest.
+    class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
+    {
+    public:
+
+        DefaultCheckpointPolicy(IRModule* module)
+            : AutodiffCheckpointPolicyBase(module)
+        { }
+
+        // `override` makes the compiler verify that these signatures keep
+        // matching the pure virtuals declared on the base class.
+        virtual void preparePolicy(IRGlobalValueWithCode* func) override;
+        virtual HoistResult classify(IRUse* use) override;
+    };
+
+    // Clones/moves primal insts according to `checkpointInfo`; see the
+    // definition in slang-ir-autodiff-primal-hoist.cpp.
+    RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
+        CheckpointSetInfo* checkpointInfo,
+        IRGlobalValueWithCode* func,
+        BlockSplitInfo* splitInfo,
+        HashSet<IRUse*> pendingUses);
+
+    // Makes every stored primal reachable from its differential uses by
+    // routing out-of-scope uses through local variables.
+    RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
+        HoistedPrimalsInfo* hoistInfo,
+        IRGlobalValueWithCode* func,
+        Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo);
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-region.cpp b/source/slang/slang-ir-autodiff-region.cpp
new file mode 100644
index 000000000..98b64f179
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.cpp
@@ -0,0 +1,56 @@
+// slang-ir-autodiff-region.cpp
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang{
+    // Assigns every block reachable from the entry block of `func` to the
+    // loop region that encloses it (nullptr when outside of any loop).
+    RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func)
+    {
+        RefPtr<IndexedRegionMap> result = new IndexedRegionMap;
+
+        // Blocks that already have a region but whose successors are still
+        // to be visited (processed last-in, first-out).
+        List<IRBlock*> pendingBlocks;
+
+        auto entryBlock = func->getFirstBlock();
+        result->mapBlock(entryBlock, nullptr);
+        pendingBlocks.add(entryBlock);
+
+        while (pendingBlocks.getCount() > 0)
+        {
+            auto block = pendingBlocks.getLast();
+            pendingBlocks.removeLast();
+
+            auto blockRegion = result->getRegion(block);
+
+            if (auto loopInst = as<IRLoop>(block->getTerminator()))
+            {
+                // Entering a loop: its condition block starts a new region
+                // nested inside the region of the current block.
+                auto innerRegion = result->newRegion(loopInst, blockRegion);
+                auto loopCondBlock = loopInst->getTargetBlock();
+
+                result->mapBlock(loopCondBlock, innerRegion);
+                pendingBlocks.add(loopCondBlock);
+
+                auto condBranch = as<IRIfElse>(loopCondBlock->getTerminator());
+                SLANG_RELEASE_ASSERT(condBranch);
+
+                // TODO: this is one of the places we'll need to change if we support loops that
+                // loop on either the true or false side. For now, we assume the loop is on the
+                // true side only.
+                //
+                auto exitBlock = condBranch->getFalseBlock();
+                result->mapBlock(exitBlock, blockRegion);
+                pendingBlocks.add(exitBlock);
+            }
+
+            // Remaining successors inherit the region of the current block.
+            for (auto next : block->getSuccessors())
+            {
+                if (!result->hasMapping(next))
+                {
+                    result->mapBlock(next, blockRegion);
+                    pendingBlocks.add(next);
+                }
+            }
+        }
+
+        return result;
+    }
+};
diff --git a/source/slang/slang-ir-autodiff-region.h b/source/slang/slang-ir-autodiff-region.h
new file mode 100644
index 000000000..a4618e257
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.h
@@ -0,0 +1,119 @@
+// slang-ir-autodiff-region.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+// A loop region: the loop inst that opens it and the enclosing region
+// (nullptr for a top-level loop).
+struct IndexedRegion : public RefObject
+{
+    IRLoop* loop;
+    IndexedRegion* parent;
+
+    IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
+    { }
+
+    // Block that contains the loop inst itself.
+    IRBlock* getInitializerBlock() { return as<IRBlock>(loop->getParent()); }
+
+    // Target block of the loop inst; it must end in an if-else.
+    IRBlock* getConditionBlock()
+    {
+        auto target = as<IRBlock>(loop->getTargetBlock());
+        SLANG_RELEASE_ASSERT(as<IRIfElse>(target->getTerminator()));
+        return target;
+    }
+
+    IRBlock* getBreakBlock() { return loop->getBreakBlock(); }
+
+    // Returns the block that jumps back to the condition block, i.e. the
+    // last predecessor of the condition block other than the initializer.
+    IRBlock* getUpdateBlock()
+    {
+        auto loopEntry = getInitializerBlock();
+        auto loopCond = getConditionBlock();
+
+        IRBlock* backEdgeBlock = nullptr;
+        for (auto pred : loopCond->getPredecessors())
+        {
+            if (pred == loopEntry)
+                continue;
+            backEdgeBlock = pred;
+        }
+
+        // There has to be a predecessor besides the initializer block
+        // (the one holding the loop inst): it is the final block of the
+        // loop body, which branches back to the condition.
+        //
+        SLANG_RELEASE_ASSERT(backEdgeBlock);
+
+        return backEdgeBlock;
+    }
+};
+
+struct IndexTrackingInfo : public RefObject
+{
+    // Counter insts/vars attached to this region once lowering
+    // has been performed.
+    //
+    IRInst* primalCountParam = nullptr;
+    IRInst* diffCountParam = nullptr;
+
+    IRVar* primalCountLastVar = nullptr;
+
+    enum CountStatus
+    {
+        Unresolved,
+        Dynamic,
+        Static
+    };
+
+    CountStatus status = CountStatus::Unresolved;
+
+    // Upper bound on the iteration count, as inferred (-1 until set).
+    Count maxIters = -1;
+};
+
+// Block -> region lookup table; also owns the regions it hands out.
+struct IndexedRegionMap : public RefObject
+{
+    Dictionary<IRBlock*, IndexedRegion*> map;
+    List<RefPtr<IndexedRegion>> regions;
+
+    // Creates a region owned by this map.
+    IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
+    {
+        auto created = new IndexedRegion(loop, parent);
+        regions.add(created);
+
+        return created;
+    }
+
+    void mapBlock(IRBlock* block, IndexedRegion* region)
+    {
+        map.Add(block, region);
+    }
+
+    bool hasMapping(IRBlock* block)
+    {
+        return map.ContainsKey(block);
+    }
+
+    IndexedRegion* getRegion(IRBlock* block)
+    {
+        return map[block];
+    }
+
+    // Regions enclosing `block`, innermost first.
+    List<IndexedRegion*> getAllAncestorRegions(IRBlock* block)
+    {
+        List<IndexedRegion*> ancestors;
+
+        IndexedRegion* current = getRegion(block);
+        while (current)
+        {
+            ancestors.add(current);
+            current = current->parent;
+        }
+
+        return ancestors;
+    }
+};
+
+RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func);
+
+
+};
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 328af4867..157011b7c 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -606,106 +606,6 @@ namespace Slang return fwdDiffFunc; } - void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc) - { - RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc); - auto firstBlock = diffPropFunc->getFirstBlock(); - if (!firstBlock) - return; - Dictionary<IRInst*, IRVar*> instVars; - Dictionary<IRBlock*, IRCloneEnv> cloneEnvs; - auto storeInstAsLocalVar = [&](IRInst* inst) - { - IRVar* var = nullptr; - if (instVars.TryGetValue(inst, var)) - return var; - IRBuilder builder(diffPropFunc); - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - var = builder.emitVar(inst->getDataType()); - builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType())); - - setInsertAfterOrdinaryInst(&builder, inst); - builder.emitStore(var, inst); - instVars[inst] = var; - return var; - }; - - IRBuilder builder(diffPropFunc); - List<IRInst*> workList; - for (auto block : diffPropFunc->getBlocks()) - { - if (!block->findDecoration<IRDifferentialInstDecoration>()) - continue; - cloneEnvs[block] = IRCloneEnv(); - for (auto inst : block->getChildren()) - { - workList.add(inst); - } - } - - for (Index i = 0; i < workList.getCount(); i++) - { - auto inst = workList[i]; - for (UInt j = 0; j < inst->getOperandCount(); j++) - { - auto operand = inst->getOperand(j); - if (operand->getOp() == kIROp_Block) - continue; - auto operandParent = inst->getOperand(j)->getParent(); - if (!operandParent) - continue; - if (operandParent->parent != diffPropFunc) - continue; - if (domTree->dominates(operandParent, inst->parent)) - continue; - - // The def site of the operand does not dominate the use. - // We need to insert a local variable to store this var. 
- - IRInst* operandReplacement = nullptr; - if (canTypeBeStored(operand->getDataType())) - { - auto var = storeInstAsLocalVar(operand); - builder.setInsertBefore(inst); - operandReplacement = builder.emitLoad(var); - } - else if (operand->getOp() == kIROp_Var) - { - // Var can just be hoisted to first block. - operand->insertBefore(firstBlock->getFirstOrdinaryInst()); - } - else - { - // For all other insts, we need to copy it to right before this inst. - // Before actually copying it, check if we have already copied it to - // any blocks that dominates this block. - auto dom = as<IRBlock>(inst->getParent()); - while (dom) - { - auto subCloneEnv = cloneEnvs.TryGetValue(dom); - if (!subCloneEnv) break; - if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement)) - { - break; - } - dom = domTree->getImmediateDominator(dom); - } - // We have not found an existing clone in dominators, so we need to copy it - // to this block. - if (!operandReplacement) - { - auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent())); - builder.setInsertBefore(inst); - operandReplacement = cloneInst(subCloneEnv, &builder, operand); - workList.add(operandReplacement); - } - } - if (operandReplacement) - builder.replaceOperand(inst->getOperands() + j, operandReplacement); - } - } - } - InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) { SLANG_UNUSED(primalType); @@ -774,7 +674,7 @@ namespace Slang // Copy primal insts to the first block of the unzipped function, copy diff insts to the // second block of the unzipped function. // - diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); IRFunc* unzippedFwdDiffFunc = fwdDiffFunc; // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. 
@@ -801,8 +701,8 @@ namespace Slang // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the // derivative of the return value. - DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam }; - diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); + DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo }; + diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo); eliminateDeadCode(diffPropagateFunc); @@ -810,7 +710,7 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType); + diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType); // At this point the unzipped func is just an empty shell // and we can simply remove it. @@ -870,7 +770,7 @@ namespace Slang initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getModule(), diffPropagateFunc); - insertVariableForRecomputedPrimalInsts(diffPropagateFunc); + // insertVariableForRecomputedPrimalInsts(diffPropagateFunc); stripTempDecorations(diffPropagateFunc); } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index a92978817..e59f27881 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -8,6 +8,8 @@ #include "slang-ir-autodiff.h" #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-cfg-norm.h" +#include "slang-ir-autodiff-primal-hoist.h" +#include "slang-ir-dominators.h" namespace Slang { @@ -78,6 +80,11 @@ struct DiffTransposePass // of the *output* of the function. 
// IRInst* dOutInst; + + // Information from the unzip pass on how primal insts + // are split across the primal and differential blocks. + // + HoistedPrimalsInfo* hoistedPrimalsInfo; }; struct PendingBlockTerminatorEntry @@ -227,15 +234,17 @@ struct DiffTransposePass IRBlock* revAfterBlock = revBlockMap[currentBlock]; builder.setInsertInto(revCondBlock); - - hoistPrimalInst(&builder, ifElse->getCondition()); - builder.emitIfElse( + //hoistPrimalInst(&builder, ifElse->getCondition()); + + auto newIfElse = builder.emitIfElse( ifElse->getCondition(), revTrueEntryBlock, revFalseEntryBlock, revAfterBlock); + hoistPrimalOperands(&builder, newIfElse); + if (!revTrueRegionInfo.isTrivial) { builder.setInsertInto(revTrueExitBlock); @@ -348,14 +357,18 @@ struct DiffTransposePass // Emit condition into the new cond block. builder.setInsertInto(revCondBlock); - hoistPrimalInst(&builder, ifElse->getCondition()); - builder.emitIfElse( + // TODO: Need to defer this until after the CFG reversal is complete. + //hoistPrimalInst(&builder, ifElse->getCondition()); + + auto newIfElse = builder.emitIfElse( ifElse->getCondition(), revTrueBlock, revFalseBlock, revTrueBlock); + hoistPrimalOperands(&builder, newIfElse); + // Old false-side starting block becomes end block // for the new pre-cond region (which could be empty) // @@ -364,12 +377,13 @@ struct DiffTransposePass { IRBlock* revPreCondEndBlock = revBlockMap[falseBlock]; builder.setInsertInto(revPreCondEndBlock); - builder.emitLoop( + auto revLoop = builder.emitLoop( revCondBlock, revBreakBlock, revLoopEndBlock, getPhiGrads(falseBlock).getCount(), getPhiGrads(falseBlock).getBuffer()); + loop->transferDecorationsTo(revLoop); auto revLoopStartBlock = revBlockMap[breakBlock]; builder.setInsertInto(revLoopStartBlock); @@ -383,12 +397,13 @@ struct DiffTransposePass // Emit loop into rev-version of the break block. 
auto revLoopBlock = revBlockMap[breakBlock]; builder.setInsertInto(revLoopBlock); - builder.emitLoop( + auto revLoop = builder.emitLoop( revPreCondBlock, revBreakBlock, revLoopEndBlock, getPhiGrads(breakBlock).getCount(), getPhiGrads(breakBlock).getBuffer()); + loop->transferDecorationsTo(revLoop); } currentBlock = breakBlock; @@ -463,14 +478,16 @@ struct DiffTransposePass builder.setInsertInto(revSwitchBlock); - hoistPrimalInst(&builder, switchInst->getCondition()); + // hoistPrimalInst(&builder, switchInst->getCondition()); - builder.emitSwitch( + auto newSwitchInst = builder.emitSwitch( switchInst->getCondition(), revBreakBlock, revDefaultRegionEntry, reverseSwitchArgs.getCount(), reverseSwitchArgs.getBuffer()); + + hoistPrimalOperands(&builder, newSwitchInst); currentBlock = breakBlock; break; @@ -504,6 +521,13 @@ struct DiffTransposePass IRFunc* revDiffFunc, FuncTranspositionInfo transposeInfo) { + // TODO (sai): We really to make this method stateless + // (i.e. not store per-func info in 'this') + // since it is reused for every reverse-mode call. + // + + hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo; + // Grab all differentiable type information. diffTypeContext.setFunc(revDiffFunc); @@ -586,6 +610,9 @@ struct DiffTransposePass } } + // Make a temporary block to hold inverted insts. + tempInvBlock = builder.createBlock(); + for (auto block : workList) { // Set dOutParameter as the transpose gradient for the return inst, if any. @@ -620,7 +647,6 @@ struct DiffTransposePass { if (auto loopInst = as<IRLoop>(block->getTerminator())) { - lowerLoopExitValues(&builder, loopInst); invertLoopCondition(&builder, loopInst); } } @@ -655,6 +681,50 @@ struct DiffTransposePass subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } + // TODO: Should move this to before all the transposition, but a lot of the + // transposition logic seems to access the parent of blocks to find the func. + // Replace those uses. 
+ // + for (auto block : workList) + block->removeFromParent(); + + // Mark all primal operands for hoisting. + // TODO: Can we just merge this with finishHoistingPrimalInsts? + // TODO: Some of this logic is replicated in finishHoistingPrimalInsts. Merge it with the + // maybeAddOperandsToWorkList logic there. + // + for (auto block : workList) + { + IRBlock* revBlock = revBlockMap[block]; + + for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst()) + { + hoistPrimalOperands(&builder, child); + + for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration()) + { + if (auto contextDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration)) + hoistPrimalUse(&builder, &contextDecoration->primalContextVar); + + if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration)) + hoistPrimalUse(&builder, &loopExitDecoration->exitVal); + } + + if (auto instType = child->getDataType()) + if (!as<IRModuleInst>(instType->getParent())) + hoistPrimalUse(&builder, &child->typeUse); + } + } + + finishHoistingPrimals(revDiffFunc); + + for (auto block : workList) + { + auto revBlock = as<IRBlock>(revBlockMap[block]); + if (auto revLoop = as<IRLoop>(revBlock->getTerminator())) + lowerLoopExitValues(&builder, revLoop); + } + // At this point, the only block left without terminator insts // should be the last one. Add a void return to complete it. 
// @@ -723,8 +793,25 @@ struct DiffTransposePass return tempRevVar; } + IRVar* lookupInverseVar(IRInst* inst) + { + return inverseVarMap[inst]; + } + + IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func) + { + IRBlock* varBlock = firstRevDiffBlockMap[func]; + return getOrCreateInverseVar(primalInst, varBlock); + } + IRVar* getOrCreateInverseVar(IRInst* primalInst) { + IRBlock* varBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())]; + return getOrCreateInverseVar(primalInst, varBlock); + } + + IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock) + { // No need to store inverse values for constants. if (as<IRConstant>(primalInst)) return nullptr; @@ -734,13 +821,11 @@ struct DiffTransposePass return inverseVarMap[primalInst]; IRBuilder tempVarBuilder(autodiffContext->moduleInst); - - IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())]; - if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst()) + if (auto firstInst = varBlock->getFirstOrdinaryInst()) tempVarBuilder.setInsertBefore(firstInst); else - tempVarBuilder.setInsertInto(firstDiffBlock); + tempVarBuilder.setInsertInto(varBlock); auto primalType = primalInst->getDataType(); @@ -766,6 +851,19 @@ struct DiffTransposePass return false; } + IRParam* getParamAt(IRBlock* block, UIndex ii) + { + UIndex index = 0; + for (auto param : block->getParams()) + { + if (ii == index) + return param; + + index ++; + } + SLANG_UNEXPECTED("ii >= paramCount"); + } + void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock) { IRBuilder builder(autodiffContext->moduleInst); @@ -773,6 +871,10 @@ struct DiffTransposePass // Insert into our reverse block. builder.setInsertInto(revBlock); + // Create an inverse builder to insert insts into the inv-block. 
+ IRBuilder invBuilder(autodiffContext->moduleInst); + + // Check if this block has any 'outputs' (in the form of phi args // sent to the successor block) // @@ -798,25 +900,43 @@ struct DiffTransposePass revParam, nullptr)); } - else if (isPrimalInst(arg)) + else if (hasInverse(arg)) { - // If the output arg is a primal, emit a parameter - // to accept it as an _input_ for the reverse-mode - // - auto primalType = arg->getDataType(); - auto primalInvParam = builder.emitParam(primalType); + InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst]; + if (invInfo.targetInsts.contains(arg)) + { + SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii))); + + // If the output arg is a primal, emit a parameter + // to accept it as an _input_ for the reverse-mode + // + auto primalType = arg->getDataType(); + auto primalInvParam = builder.emitParam(primalType); - setInverse(&builder, arg, primalInvParam); + invBuilder.setInsertBefore(branchInst); + setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam); + } } else { - SLANG_UNEXPECTED("Encountered inst not marked as primal or differential"); + if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii))) + { + auto primalType = arg->getDataType(); + auto primalInvParam = builder.emitParam(primalType); + + invBuilder.setInsertBefore(branchInst); + setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam); + } + else + { + SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion"); + } } } } // Move pointer & reference insts to the top of the reverse-mode block. - List<IRInst*> nonValueInsts; + List<IRInst*> typeInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { // If the instruction is a variable allocation (or reverse-gradient pair reference), @@ -824,17 +944,17 @@ struct DiffTransposePass // TODO: This is hacky.. 
Need a more principled way to handle this // (like primal inst hoisting) // - if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child)) - nonValueInsts.add(child); + //if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child)) + // nonValueInsts.add(child); // Slang doesn't support function values. So if we see a func-typed inst // it's proabably a reference to a function. // if (as<IRFuncType>(child->getDataType())) - nonValueInsts.add(child); + typeInsts.add(child); } - for (auto inst : nonValueInsts) + for (auto inst : typeInsts) { inst->insertAtEnd(revBlock); } @@ -846,9 +966,6 @@ struct DiffTransposePass // for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst()) { - if (child->findDecoration<IRPrimalValueAccessDecoration>()) - continue; - if (as<IRDecoration>(child) || as<IRParam>(child)) continue; if (as<IRType>(child)) @@ -856,8 +973,15 @@ struct DiffTransposePass if (isDifferentialInst(child)) transposeInst(&builder, child); - else if (isPrimalInst(child)) - invertInst(&builder, child); + else if (shouldInstBeInverted(child)) + { + // We'll collect inverse insts in an orphaned block, + // so disable IR validation temporarily. 
+ // + disableIRValidationAtInsert(); + invertInst(&invBuilder, child); + enableIRValidationAtInsert(); + } } // After processing the block's instructions, we 'flush' any remaining gradients @@ -901,23 +1025,18 @@ struct DiffTransposePass phiParamRevGradInsts.add(gradInst); } else - { + { phiParamRevGradInsts.add( emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); } } - else if (isPrimalInst(param)) + else if (hasInverse(param)) { - if (hasInverse(param)) - phiParamRevGradInsts.add(getInverse(&builder, param)); - else - { - SLANG_UNEXPECTED("param is a primal inst but has no registered inverse"); - } + phiParamRevGradInsts.add(param); } else - { - SLANG_UNEXPECTED("param is neither differential nor primal"); + { + SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion"); } } @@ -995,8 +1114,15 @@ struct DiffTransposePass { } }; - List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput) + List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo) { + SLANG_RELEASE_ASSERT(invInfo.requiredOperands.getCount() == 1); + SLANG_RELEASE_ASSERT(invInfo.targetInsts.getCount() == 1); + + auto invOutput = invInfo.requiredOperands[0]; + + auto invTargetInst = invInfo.targetInsts[0]; + switch (primalInst->getOp()) { case kIROp_Add: @@ -1004,7 +1130,7 @@ struct DiffTransposePass SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); return List<InvInstPair>( InvInstPair( - primalInst->getOperand(0), + invTargetInst, builder->emitSub( primalInst->getOperand(0)->getDataType(), invOutput, @@ -1015,7 +1141,7 @@ struct DiffTransposePass SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); return List<InvInstPair>( InvInstPair( - primalInst->getOperand(0), + invTargetInst, builder->emitAdd( primalInst->getOperand(0)->getDataType(), invOutput, @@ -1027,24 +1153,38 @@ struct DiffTransposePass } } - void lowerLoopExitValues(IRBuilder* builder, 
IRLoop* fwdLoop) + // NOTE: This is a workaround for the fact that we expect inverses to use + // single-use variables. The loop exit value will add a + // second store to most inv-variables and mess with the primal hoisting mechanism. + // Instead of emitting into the orphaned inverse block, we'll directly emit into + // the reverse-mode block since we'll be running this _after_ the primal hoisting + // pass. + // + // This workaround is fine for inverting loop counters, but when we want to + // expand to supporting general-purpose adjoints, we would want to use per-region + // inverse vars based on 'invInfo' (enforcing single-use vars) + // + void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop) { - for (auto decoration : fwdLoop->getDecorations()) + List<IRDecoration*> processedDecorations; + for (auto decoration : revLoop->getDecorations()) { if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration)) { - IRBlock* revLoopInitBlock = revBlockMap[fwdLoop->getBreakBlock()]; - - if (auto revLoopInst = revLoopInitBlock->getTerminator()) - builder->setInsertBefore(revLoopInst); - else - builder->setInsertInto(revLoopInitBlock); - - hoistPrimalInst(builder, loopExitValueDecoration->getLoopExitValInst()); + builder->setInsertBefore(revLoop); + setInverse( + builder, + nullptr, + builder->getFunc(), + loopExitValueDecoration->getTargetInst(), + loopExitValueDecoration->getLoopExitValInst()); - setInverse(builder, loopExitValueDecoration->getTargetInst(), loopExitValueDecoration->getLoopExitValInst()); + processedDecorations.add(loopExitValueDecoration); } } + + for (auto decoration : processedDecorations) + decoration->removeAndDeallocate(); } void lowerLoopExitValues(IRBuilder* builder, IRBlock* block) @@ -1094,19 +1234,21 @@ struct DiffTransposePass kIROp_Neq, 2, List<IRInst*>( - hoistPrimalInst(builder, loopCounterParam), - hoistPrimalInst(builder, loopCounterInitVal)).getBuffer()); + loopCounterParam, + 
loopCounterInitVal).getBuffer()); + + hoistPrimalOperands(builder, paramBoundsCheck); as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck); } - List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput) + List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo) { switch (primalInst->getOp()) { case kIROp_Add: case kIROp_Sub: - return invertArithmetic(builder, primalInst, invOutput); + return invertArithmetic(builder, primalInst, invInfo); default: SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion"); @@ -1115,70 +1257,392 @@ struct DiffTransposePass bool hasInverse(IRInst* primalInst) { - if (getOrCreateInverseVar(primalInst)) - return true; - else - return false; + return this->hoistedPrimalsInfo->invertSet.Contains(primalInst); } - IRInst* getInverse(IRBuilder* builder, IRInst* primalInst) + IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst) { // Note: There are other possible cases here, although not important // right now. For example, a value is available to load from the primal block. // - if (auto invVar = getOrCreateInverseVar(primalInst)) + + if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc())) return builder->emitLoad(invVar); return nullptr; } - void setInverse(IRBuilder* builder, IRInst* inst, IRInst* invInst) + IRInst* lookupInstInPrimalBlock(IRInst* invInst) { - if (auto invVar = getOrCreateInverseVar(inst)) - builder->emitStore(invVar, invInst); + // Lookup the inst in the primal block whose value we can use as an operand + // for the inverted inst. 
+ // + // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst]; + return invInst; } - IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst) + void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst) { - if (as<IRBlock>(inst->getParent()) && - isDifferentialInst(as<IRBlock>(inst->getParent()))) + auto instBlock = as<IRBlock>(inst->getParent()); + if (!instBlock) + return; + + disableIRValidationAtInsert(); + if (auto invVar = getOrCreateInverseVar(inst, func)) { - SLANG_RELEASE_ASSERT(isPrimalInst(inst)); + auto invStore = builder->emitStore(invVar, invInst); + mapStoreToDefBlock[as<IRStore>(invStore)] = defBlock; } + enableIRValidationAtInsert(); + } - // Are the operands of this primal inst also available in the reverse-mode context? - // If not, move/load them. + bool shouldInstBeInverted(IRInst* inst) + { + + if (this->hoistedPrimalsInfo->instsToInvert.Contains(inst)) + return true; + + return false; + } + + IRInst* hoistPrimalUse(IRBuilder*, IRUse* use) + { + primalUsesToHoist.add(use); + return use->get(); + } + + bool doesInstRequireHoisting(IRInst* inst) + { + if (as<IRModuleInst>(inst->getParent())) + return false; + + if (as<IRBlock>(inst) || + as<IRGlobalValueWithCode>(inst) || + as<IRConstant>(inst)) + return false; + + if (as<IRTerminatorInst>(inst)) + return false; + + if (as<IRDecoration>(inst)) + return doesInstRequireHoisting(getInstInBlock(inst)); + + // We're looking for primal insts in differential blocks + // that have not yet been moved to the 'active' blocks + // (i.e in diff blocks that do not have parents) // - hoistPrimalOperands(revBuilder, inst); + return (!isDifferentialInst(inst) && + (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) && + getBlock(inst)->getParent() == nullptr); + } - if (isPrimalInst(inst) && - as<IRBlock>(inst->getParent()) && - isDifferentialInst(as<IRBlock>(inst->getParent()))) + // Builds a map from inst to a list 
of uses by primal _inverted_ insts. + Dictionary<IRInst*, List<IRInst*>> buildInvOperandMap() + { + Dictionary<IRInst*, List<IRInst*>> invOperandMap; + for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap) { - if (!inst->findDecoration<IRPrimalValueAccessDecoration>()) + InversionInfo invInfo = kvpair.Value; + + for (auto operand : invInfo.requiredOperands) { - return getInverse(revBuilder, inst); + if (!invOperandMap.ContainsKey(operand)) + invOperandMap[operand] = List<IRInst*>(); + + for (auto target : invInfo.targetInsts) + invOperandMap[operand].GetValue().add(target); } - else + } + + return invOperandMap; + } + + IRBlock* walkToEndOfRegion(IRBlock* block) + { + IRBlock* currBlock = block; + + bool keepGoing = true; + while (keepGoing) + { + auto terminator = currBlock->getTerminator(); + switch (terminator->getOp()) { - auto block = as<IRBlock>(inst->getParent()); - SLANG_RELEASE_ASSERT(block); + case kIROp_Return: + keepGoing = false; + break; - if (block == revBuilder->getBlock()) + case kIROp_unconditionalBranch: { - // Already in block.. - return inst; - } + auto nextBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock(); - // Otherwise, move our inst to the the current builder location. 
- inst->removeFromParent(); - revBuilder->addInst(inst); + HashSet<IRBlock*> predecessorSet; + for (auto predecessor : nextBlock->getPredecessors()) + predecessorSet.Add(predecessor); - return inst; + if (predecessorSet.Count() > 1) + { + keepGoing = false; + break; + } + + currBlock = nextBlock; + break; + } + + case kIROp_ifElse: + { + for (auto predecessor : currBlock->getPredecessors()) + { + if (as<IRLoop>(predecessor->getTerminator())) + { + keepGoing = false; + break; + } + } + + currBlock = as<IRIfElse>(terminator)->getAfterBlock(); + break; + } + + case kIROp_Switch: + currBlock = as<IRSwitch>(terminator)->getBreakLabel(); + break; + + case kIROp_loop: + currBlock = as<IRLoop>(terminator)->getBreakBlock(); + break; } } - return inst; + return currBlock; + } + + void finishHoistingPrimals(IRGlobalValueWithCode* func) + { + List<IRInst*> workList; + + Dictionary<IRInst*, IRInst*> hoistedInstMap; + + RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); + + Dictionary<IRInst*, List<IRInst*>> invOperandMap = buildInvOperandMap(); + + auto varBlock = func->getFirstBlock()->getNextBlock(); + + // Load up pending insts into workList. 
+ for (auto use : primalUsesToHoist) + workList.add(use->get()); + + primalUsesToHoist.clear(); + + + auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst) + { + UIndex opIndex = 0; + for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++) + { + if (doesInstRequireHoisting(operand->get()) && + !hoistedInstMap.ContainsKey(operand->get())) + { + workList.add(operand->get()); + } + } + + if (auto instType = inst->getDataType()) + { + if (doesInstRequireHoisting(instType) && + !hoistedInstMap.ContainsKey(instType)) + workList.add(instType); + } + }; + + auto maybeAddUsersToWorkList = [&](IRInst* inst) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (doesInstRequireHoisting(use->getUser())) + { + if (as<IRVar>(inst) && as<IRStore>(use->getUser())) + continue; + + // Uses that haven't already been hoisted into reverse-mode + // blocks, and are not in the invert-set are pending uses. + // + if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser())) + workList.add(use->getUser()); + } + } + }; + + auto doesInstHavePendingUsers = [&](IRInst* inst) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + if (doesInstRequireHoisting(use->getUser())) + { + if (as<IRVar>(inst) && as<IRStore>(use->getUser())) + continue; + + // Users that haven't already been hoisted into reverse-mode + // blocks are pending users. + // + if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser())) + return true; + } + } + + return false; + }; + + auto isInstHoisted = [&](IRInst* inst) + { + return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst)); + }; + + while (workList.getCount() > 0) + { + // Pop work item + auto inst = workList.getLast(); + workList.removeLast(); + + // Already hoisted to reverse-mode block. + // replace with mapped inst (in case it's different) + // and continue on.. 
(this should actually never be hit) + // + if (hoistedInstMap.ContainsKey(inst)) + continue; + + if (invOperandMap.ContainsKey(inst)) + { + List<IRInst*> pendingInvDependencies; + for (auto dependency : invOperandMap[inst].GetValue()) + { + if (doesInstRequireHoisting(dependency) && + !hoistedInstMap.ContainsKey(dependency)) + pendingInvDependencies.add(dependency); + } + + if (pendingInvDependencies.getCount() > 0) + { + workList.add(inst); + for (auto dependency : pendingInvDependencies) + workList.add(dependency); + + // Skip until all the dependencies have been handled. + continue; + } + } + + // Are the uses of this primal inst already hoisted into the reverse-mode + // blocks? We cannot hoist this inst unless the uses are hoisted. + // + if (doesInstHavePendingUsers(inst)) + { + // Add inst back to work list. + workList.add(inst); + + // Then, add all the pending use to the top of + // list, ensuring they are processed before we see + // inst again. + // + maybeAddUsersToWorkList(inst); + + continue; + } + + // The used inst is marked for inversion, lookup and load + // an inverse. + // + if (this->hoistedPrimalsInfo->invertSet.Contains(inst)) + { + // Replace with inverse. + IRBuilder builder(func->getModule()); + + for (auto use = inst->firstUse; use;) + { + auto nextUse = use->nextUse; + + if (!isInstHoisted(use->getUser())) + { + use = nextUse; + continue; + } + + // TODO: Hacky workaround to prevent the 'key' being overwritten, + // avoid this by adding the decoration on the param instead of the loop + // + if (auto exitValDecoration = as<IRLoopExitPrimalValueDecoration>(use->getUser())) + { + if (&exitValDecoration->target == use) + { + use = nextUse; + continue; + } + } + + + builder.setInsertBefore(getInstInBlock(use->getUser())); + use->set(loadInverse(&builder, inst)); + + use = nextUse; + } + + // If all uses of the invertible inst have been hoisted, + // add the inv-var to the worklist. 
+ // + workList.add(lookupInverseVar(inst)); + hoistedInstMap[inst] = nullptr; + + continue; + } + + // Should not see an inst marked for inversion here. + SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst)); + + List<IRUse*> relevantUses; + + IRBlock* defBlock = nullptr; + if (auto varToHoist = as<IRVar>(inst)) + { + varToHoist->insertBefore(varBlock->getFirstOrdinaryInst()); + inst = findUniqueStoredVal(varToHoist)->getUser(); + SLANG_ASSERT(inst); + + defBlock = getBlock(inst); + } + else + { + defBlock = getBlock(inst); + } + + if (!doesInstRequireHoisting(inst)) + continue; + + // Move this inst to after it's diff uses. + // + { + + IRBlock* currTopBlock = revBlockMap[walkToEndOfRegion(defBlock)]; + + SLANG_RELEASE_ASSERT(currTopBlock); + + // More consistency checks + SLANG_RELEASE_ASSERT(currTopBlock->getFirstOrdinaryInst() != nullptr); + SLANG_RELEASE_ASSERT(currTopBlock->getParent() != nullptr); + SLANG_RELEASE_ASSERT(isDifferentialInst(currTopBlock)); + + // Insert at top. (disabling validation since the operands of + // this inst might not be hoisted to the right place yet) + // + disableIRValidationAtInsert(); + inst->insertBefore(currTopBlock->getFirstOrdinaryInst()); + enableIRValidationAtInsert(); + } + + // Finish up.. 
+ hoistedInstMap[inst] = inst; + maybeAddPrimalOperandsToWorkList(inst); + } } void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst) @@ -1192,10 +1656,9 @@ struct DiffTransposePass // make sure all requried primal insts are moved to the right // place) // - if (isPrimalInst(fwdInst->getOperand(ii))) + if (doesInstRequireHoisting(fwdInst->getOperand(ii))) { - auto hoistedPrimalInst = hoistPrimalInst(revBuilder, fwdInst->getOperand(ii)); - fwdInst->setOperand(ii, hoistedPrimalInst); + hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]); } } } @@ -1203,14 +1666,28 @@ struct DiffTransposePass void invertInst(IRBuilder* builder, IRInst* primalInst) { // Look for an available inverse entry for this primalInst's *output* - if (hasInverse(primalInst)) + if (shouldInstBeInverted(primalInst)) { - auto invOutput = getInverse(builder, primalInst); + // This logic is already handled in transposeBlock() so we skip + // it here. + // + if (as<IRTerminatorInst>(primalInst)) + return; + + auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst]; - auto invEntries = invertInst(builder, primalInst, invOutput); + IRBuilder invBuilder(builder->getModule()); + invBuilder.setInsertAfter(primalInst); + auto invEntries = invertInst(&invBuilder, primalInst, invInfo); + for (auto entry : invEntries) - setInverse(builder, entry.inst, entry.invInst); + setInverse( + &invBuilder, + getBlock(primalInst), + as<IRGlobalValueWithCode>(entry.inst->getParent()->getParent()), + entry.inst, + entry.invInst); } else { @@ -1270,11 +1747,6 @@ struct DiffTransposePass SLANG_ASSERT(gradients.getCount() == 0); } - // Ensure primal operands are replaced with insts accessible in the - // reverse-mode context. - // - hoistPrimalOperands(builder, inst); - // Is this inst used in another differential block? // Emit a function-scope accumulator variable, and include it's value. 
// Also, we ignore this if it's a load since those are turned into stores @@ -1381,7 +1853,9 @@ struct DiffTransposePass // In order to perform the call, we need a temporary var to store the DiffPair. auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType(); auto tempVar = builder->emitVar(pairType); - auto primalVal = builder->emitLoad(hoistPrimalInst(builder, instPair->getPrimal())); + auto primalVal = builder->emitLoad(instPair->getPrimal()); + hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here? + auto diffVal = builder->emitLoad(instPair->getDiff()); auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); builder->emitStore(tempVar, pairVal); @@ -1453,13 +1927,15 @@ struct DiffTransposePass // If the callee provides a primal implementation that produces continuation context for propagation phase // we grab it and pass it as argument to the propagation function. + // if (auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>()) - { - // Ensure availability of the primal context var - auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); - SLANG_RELEASE_ASSERT(primalContextVar); + { + auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar(); + + auto contextLoad = builder->emitLoad(primalContextVar); + hoistPrimalOperands(builder, contextLoad); - args.add(builder->emitLoad(primalContextVar)); + args.add(contextLoad); argTypes.add(as<IRPtrTypeBase>( primalContextVar->getDataType()) ->getValueType()); @@ -1735,6 +2211,7 @@ struct DiffTransposePass return transposeUpdateElement(builder, fwdInst, revValue); case kIROp_LoadReverseGradient: + case kIROp_ReverseGradientDiffPairRef: case kIROp_DefaultConstruct: case kIROp_Specialize: case kIROp_unconditionalBranch: @@ -2255,18 +2732,17 @@ struct DiffTransposePass { // current type should be a scalar. 
SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType())); - - auto targetVectorType = as<IRVectorType>(targetType); - List<IRInst*> operands; - for (Index ii = 0; ii < as<IRIntLit>(targetVectorType->getElementCount())->getValue(); ii++) - { - operands.add(inst); - } + return builder->emitMakeVectorFromScalar(targetType, inst); + } - IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer()); + case kIROp_MatrixType: + { + // current type should be a scalar. + SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType()) && + !as<IRMatrixType>(currentType->getDataType())); - return newInst; + return builder->emitMakeMatrixFromScalar(targetType, inst); } default: @@ -2968,6 +3444,10 @@ struct DiffTransposePass DifferentialPairTypeBuilder pairBuilder; + HoistedPrimalsInfo* hoistedPrimalsInfo; + + IRBlock* tempInvBlock; + Dictionary<IRInst*, List<RevGradient>> gradientsMap; Dictionary<IRInst*, IRVar*> revAccumulatorVarMap; @@ -2987,6 +3467,10 @@ struct DiffTransposePass Dictionary<IRBlock*, List<IRInst*>> phiGradsMap; Dictionary<IRInst*, IRInst*> inverseValueMap; + + List<IRUse*> primalUsesToHoist; + + Dictionary<IRStore*, IRBlock*> mapStoreToDefBlock; }; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 5b59416d4..16862bb19 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -332,7 +332,12 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc( + IRFunc* unzippedFunc, + IRFunc* originalFunc, + HoistedPrimalsInfo* primalsInfo, + HashSet<IRInst*>& primalParams, + IRInst*& outIntermediateType) { IRBuilder builder(module); @@ -375,17 +380,9 @@ struct ExtractPrimalFuncContext // output intermediary struct. 
for (auto inst : block->getChildren()) { - if (shouldStoreInst(inst)) + if (primalsInfo->storeSet.Contains(inst)) { - if (as<IRParam>(inst)) - builder.setInsertBefore(block->getFirstOrdinaryInst()); - else - builder.setInsertAfter(inst); - storeInst(builder, inst, outIntermediary); - } - else if (inst->getOp() == kIROp_Var) - { - if (shouldStoreVar(as<IRVar>(inst))) + if (as<IRVar>(inst)) { auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); builder.setInsertBefore(inst); @@ -394,7 +391,14 @@ struct ExtractPrimalFuncContext inst->replaceUsesWith(fieldAddr); builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); } - + else + { + if (as<IRParam>(inst)) + builder.setInsertBefore(block->getFirstOrdinaryInst()); + else + builder.setInsertAfter(inst); + storeInst(builder, inst, outIntermediary); + } } } } @@ -459,6 +463,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE IRFunc* DiffUnzipPass::extractPrimalFunc( IRFunc* func, IRFunc* originalFunc, + HoistedPrimalsInfo* primalsInfo, ParameterBlockTransposeInfo& paramInfo, IRInst*& intermediateType) { @@ -470,6 +475,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( subEnv.parent = &cloneEnv; auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); + auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv); + // Remove [KeepAlive] decorations in clonedFunc. 
for (auto block : clonedFunc->getBlocks()) for (auto inst : block->getChildren()) @@ -494,7 +501,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( context.init(autodiffContext->moduleInst->getModule(), autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, clonedPrimalsInfo, newPrimalParams, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { @@ -580,6 +587,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { // The primal calls should be marked as no side effect so they can be DCE'd if possible. // We can only do so if the intermediate context of the callee is stored. + // if (primalCtx->getBackwardDerivativePrimalContextVar() ->findDecoration<IRPrimalValueStructKeyDecoration>()) { diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index f2aa1fd29..8b24b122e 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -9,6 +9,8 @@ #include "slang-ir-autodiff-fwd.h" #include "slang-ir-autodiff-propagate.h" #include "slang-ir-autodiff-transcriber-base.h" +#include "slang-ir-autodiff-region.h" +#include "slang-ir-autodiff-primal-hoist.h" #include "slang-ir-validate.h" #include "slang-ir-ssa.h" @@ -36,171 +38,10 @@ struct DiffUnzipPass // might run into an issue here? 
IRBlock* firstDiffBlock; - struct IndexedRegion : public RefObject - { - IRLoop* loop; - IndexedRegion* parent; - - IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent) - { } - - IRBlock* getInitializerBlock() { return as<IRBlock>(loop->getParent()); } - IRBlock* getConditionBlock() - { - auto condBlock = as<IRBlock>(loop->getTargetBlock()); - SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator())); - return condBlock; - } - - IRBlock* getBreakBlock() { return loop->getBreakBlock(); } - - IRBlock* getUpdateBlock() - { - auto initBlock = getInitializerBlock(); - - auto condBlock = getConditionBlock(); - - IRBlock* lastLoopBlock = nullptr; - - for (auto predecessor : condBlock->getPredecessors()) - { - if (predecessor != initBlock) - lastLoopBlock = predecessor; - } - - // Should find atleast one predecessor that is _not_ the - // init block (that contains the loop info). This - // predecessor would be the last block in the loop - // before looping back to the condition. 
- // - SLANG_RELEASE_ASSERT(lastLoopBlock); - - return lastLoopBlock; - } - }; - - - struct IndexedRegionMap : public RefObject - { - Dictionary<IRBlock*, IndexedRegion*> map; - List<RefPtr<IndexedRegion>> regions; - - IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent) - { - auto region = new IndexedRegion(loop, parent); - regions.add(region); - - return region; - } - - void mapBlock(IRBlock* block, IndexedRegion* region) - { - map.Add(block, region); - } - - bool hasMapping(IRBlock* block) - { - return map.ContainsKey(block); - } - - IndexedRegion* getRegion(IRBlock* block) - { - return map[block]; - } - - List<IndexedRegion*> getAllAncestorRegions(IRBlock* block) - { - List<IndexedRegion*> regionList; - - IndexedRegion* region = getRegion(block); - for (; region; region = region->parent) - regionList.add(region); - - return regionList; - } - }; - - RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func) - { - RefPtr<IndexedRegionMap> regionMap = new IndexedRegionMap; - - List<IRBlock*> workList; - - regionMap->mapBlock(func->getFirstBlock(), nullptr); - workList.add(func->getFirstBlock()); - - while (workList.getCount() > 0) - { - auto currentBlock = workList.getLast(); - workList.removeLast(); - - auto terminator = currentBlock->getTerminator(); - auto currentRegion = regionMap->getRegion(currentBlock); - - switch (terminator->getOp()) - { - case kIROp_loop: - { - auto loopRegion = regionMap->newRegion(as<IRLoop>(terminator), currentRegion); - auto condBlock = as<IRLoop>(terminator)->getTargetBlock(); - - regionMap->mapBlock(condBlock, loopRegion); - workList.add(condBlock); - - auto ifElse = as<IRIfElse>(condBlock->getTerminator()); - SLANG_RELEASE_ASSERT(ifElse); - - // TODO: this is one of the places we'll need to change if we support loops that - // loop on either the true or false side. For now, we assume the loop is on the - // true side only. 
- // - regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion); - workList.add(ifElse->getFalseBlock()); - } - } - - for (auto successor : currentBlock->getSuccessors()) - { - // If already mapped, skip. - if (regionMap->hasMapping(successor)) - continue; - regionMap->mapBlock(successor, currentRegion); - workList.add(successor); - } - } - - return regionMap; - } - - RefPtr<IndexedRegionMap> indexRegionMap; - struct IndexTrackingInfo : public RefObject - { - // After lowering, store references to the count - // variables associated with this region - // - IRInst* primalCountParam = nullptr; - IRInst* diffCountParam = nullptr; - - IRVar* primalCountLastVar = nullptr; - - enum CountStatus - { - Unresolved, - Dynamic, - Static - }; - - CountStatus status = CountStatus::Unresolved; - - // Inferred maximum number of iterations. - Count maxIters = -1; - }; - Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap; - DiffUnzipPass( AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext) @@ -217,7 +58,7 @@ struct DiffUnzipPass return diffMap[inst]; } - void unzipDiffInsts(IRFunc* func) + RefPtr<HoistedPrimalsInfo> unzipDiffInsts(IRFunc* func) { diffTypeContext.setFunc(func); @@ -316,7 +157,12 @@ struct DiffUnzipPass // Emit counter variables and other supporting // instructions for all regions. // - lowerIndexedRegions(); + // TODO: Need to have maxIndex in _both_ IndexTrackingInfo & IndexedRegionInfo. + // That way, we can do the various passes _before_ lowerIndexedRegions() + // TODO: Remove the call to lowerIndexedRegions() once checkpointing works properly. + // + RefPtr<HoistedPrimalsInfo> primalsInfo = new HoistedPrimalsInfo(); + lowerIndexedRegions(primalsInfo); // Copy regions from fwd-block to their split blocks // to make it easier to do lookups. 
@@ -338,21 +184,39 @@ struct DiffUnzipPass indexRegionMap->map[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap->map[block]; } } + + // Swap the first block's occurences out for the first primal block. + firstBlock->replaceUsesWith(firstPrimalBlock); - // Process intermediate insts in indexed blocks - // into array loads/stores. - // + RefPtr<BlockSplitInfo> splitInfo = new BlockSplitInfo(); + for (auto block : mixedBlocks) + if (primalMap.ContainsKey(block)) + splitInfo->diffBlockMap[as<IRBlock>(primalMap[block])] = as<IRBlock>(diffMap[block]); + + Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlocksInfo; for (auto block : mixedBlocks) { - if (indexRegionMap->getRegion(block) != nullptr) - processIndexedFwdBlock(block); + indexedBlocksInfo[as<IRBlock>(diffMap[block])] = getIndexInfoList(as<IRBlock>(diffMap[block])); + indexedBlocksInfo[as<IRBlock>(primalMap[block])] = getIndexInfoList(as<IRBlock>(primalMap[block])); } - - // Swap the first block's occurences out for the first primal block. - firstBlock->replaceUsesWith(firstPrimalBlock); for (auto block : mixedBlocks) block->removeAndDeallocate(); + + // Run the three checkpointing passes to hoist/clone primal insts + // to the right spots. 
+ // + { + RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(unzippedFunc->getModule()); + chkPolicy->preparePolicy(func); + + auto chkPrimalsInfo = chkPolicy->processFunc(func, splitInfo); + primalsInfo->merge(chkPrimalsInfo); + + primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlocksInfo); + } + + return primalsInfo; } void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) @@ -375,42 +239,6 @@ struct DiffUnzipPass } } - UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) - { - SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); - - auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()); - List<IRInst*> phiArgs; - - for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) - phiArgs.add(branchInst->getArg(ii)); - - phiArgs.add(arg); - - builder->setInsertInto(block); - switch (branchInst->getOp()) - { - case kIROp_unconditionalBranch: - builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); - break; - - case kIROp_loop: - builder->emitLoop( - as<IRLoop>(branchInst)->getTargetBlock(), - as<IRLoop>(branchInst)->getBreakBlock(), - as<IRLoop>(branchInst)->getContinueBlock(), - phiArgs.getCount(), - phiArgs.getBuffer()); - break; - - default: - break; - } - - branchInst->removeAndDeallocate(); - return phiArgs.getCount() - 1; - } - IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type) { builder->setInsertInto(block); @@ -428,27 +256,7 @@ struct DiffUnzipPass return addPhiInputParam(builder, block, type); } - IRBlock* getBlock(IRInst* inst) - { - SLANG_RELEASE_ASSERT(inst); - - if (auto block = as<IRBlock>(inst)) - return block; - - return getBlock(inst->getParent()); - } - - IRInst* getInstInBlock(IRInst* inst) - { - SLANG_RELEASE_ASSERT(inst); - - if (auto block = as<IRBlock>(inst->getParent())) - return inst; - - return getInstInBlock(inst->getParent()); - } - - void lowerIndexedRegions() + void 
lowerIndexedRegions(HoistedPrimalsInfo* primalsInfo) { IRBuilder builder(autodiffContext->moduleInst->getModule()); @@ -464,6 +272,7 @@ struct DiffUnzipPass // Make variable in the top-most block (so it's visible to diff blocks) info->primalCountLastVar = builder.emitVar(builder.getIntType()); builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var")); + primalsInfo->storeSet.Add(info->primalCountLastVar); { auto primalCondBlock = as<IRUnconditionalBranch>( @@ -546,6 +355,23 @@ struct DiffUnzipPass builder.addPrimalValueAccessDecoration(primalCounterLastVal); builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal); + + // We'll be manually creating the inversion entries for the counters + // TODO: This logic can be moved to the checkpointing alg. + // + primalsInfo->invertSet.Add(info->diffCountParam); + primalsInfo->instsToInvert.Add(incCounterVal); + primalsInfo->invertInfoMap[incCounterVal] = InversionInfo( + incCounterVal, + List<IRInst*>(incCounterVal), + List<IRInst*>(info->diffCountParam)); + + primalsInfo->invertSet.Add(incCounterVal); + primalsInfo->instsToInvert.Add(diffUpdateBlock->getTerminator()); + primalsInfo->invertInfoMap[diffUpdateBlock->getTerminator()] = InversionInfo( + diffUpdateBlock->getTerminator(), + List<IRInst*>(diffUpdateBlock->getTerminator()), + List<IRInst*>(incCounterVal)); } // Try to infer maximum possible number of iterations. @@ -576,282 +402,10 @@ struct DiffUnzipPass return indices; } - void processIndexedFwdBlock(IRBlock* fwdBlock) - { - // Grab first primal block. - IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]); - - // Scan through instructions and identify those that are used - // outside the local block. 
- // - IRBlock* primalBlock = as<IRBlock>(primalMap[fwdBlock]); - - List<IRInst*> primalInsts; - for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst()) - { - // TODO: This might be a decent place to enforce that each load has a single - // corresponding store (i.e. that everything is SSAd properly)? - - // We're only interested in insts that generate values. - if (child->getDataType() == nullptr || - as<IRVoidType>(child->getDataType()) || - as<IRFuncType>(child->getDataType()) || - as<IRTypeKind>(child->getDataType())) - continue; - - primalInsts.add(child); - } - - IRBuilder builder(autodiffContext->moduleInst->getModule()); - - for (auto inst : primalInsts) - { - // 1. Check if we need to store inst (is it used in a differential block?) - - bool shouldStore = false; - for (auto use = inst->firstUse; use; use = use->nextUse) - { - IRBlock* useBlock = getBlock(use->getUser()); - - if (isDifferentialInst(useBlock)) - { - shouldStore = true; - break; - } - } - - if (!shouldStore) continue; - - // 2. If we're dealing with a var, we need to locate the value that - // we actually need to store. We assume everything is SSA form - // so there must be a single IRStore on this var. - // - IRInst* valueToStore = nullptr; - IRBlock* valueBlock = nullptr; - IRType* valueType = nullptr; - - bool isPtrType = false; - bool isIntermediateContext = false; - - if (auto ptrValueType = as<IRPtrTypeBase>(inst->getDataType())) - { - isPtrType = true; - - // Find value to store - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (auto storeInst = as<IRStore>(use->getUser())) - { - // Should not see more than one IRStore - SLANG_RELEASE_ASSERT(!valueToStore); - valueToStore = storeInst->getVal(); - - // Is this the right block to use to determine if the - // store can have multiple values based on the index? 
- // - valueBlock = as<IRBlock>(storeInst->getParent()); - } - } - - if (as<IRBackwardDiffIntermediateContextType>(ptrValueType->getValueType())) - { - isIntermediateContext = true; - - // TODO: This should be the parent block of the `call` associated - // with this context type. The var itself _could_ be in a different place. - // - valueBlock = as<IRBlock>(inst->getParent()); - } - - valueType = ptrValueType->getValueType(); - } - else - { - isPtrType = false; - valueToStore = inst; - valueBlock = as<IRBlock>(inst->getParent()); - valueType = inst->getDataType(); - } - - // What do we do for primal vars that are used in the diff block - // but do not have an IRStore on them? This can happen for 'out' - // primal variables. - // - if (!valueToStore && !isIntermediateContext) - { - // For now, we can ignore them since they are used as inputs - // to 'out' parameters. If their value is every actually used, - // we will see an IRLoad which will be hoisted accordingly. - // - continue; - } - - // Build list of indices that the value's block is affected by. - List<IndexTrackingInfo*> indices = getIndexInfoList(valueBlock); - - // 3. Emit an array to top-level to allocate space. - - builder.setInsertBefore(firstPrimalBlock->getTerminator()); - - IRType* storageType = valueType; - - for (auto index : indices) - { - SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static); - SLANG_ASSERT(index->maxIters >= 0); - - storageType = builder.getArrayType( - storageType, - builder.getIntValue( - builder.getUIntType(), - index->maxIters + 1)); - } - - // Reverse the list since the indices need to be - // emitted in reverse order. - // - indices.reverse(); - - auto storageVar = builder.emitVar(storageType); - if (isIntermediateContext) - builder.addBackwardDerivativePrimalContextDecoration( - storageVar, - storageVar); - - // 4. Store current value into the array and replace uses with a load. - // If an index is missing, use the 'last' value of the primal index. 
- - { - if (!isIntermediateContext) - setInsertAfterOrdinaryInst(&builder, valueToStore); - else - setInsertAfterOrdinaryInst(&builder, inst); - - IRInst* storeAddr = storageVar; - IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); - - for (auto index : indices) - { - currType = as<IRArrayType>(currType)->getElementType(); - - storeAddr = builder.emitElementAddress( - builder.getPtrType(currType), - storeAddr, - index->primalCountParam); - } - - if (!isIntermediateContext) - builder.emitStore(storeAddr, valueToStore); - else - { - List<IRUse*> primalUses; - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (!isDifferentialInst(getBlock(use->getUser()))) - primalUses.add(use); - } - - for (auto use : primalUses) - use->set(storeAddr); - } - } - - - // 5. Replace uses in differential blocks with loads from the array. - List<IRInst*> instsToTag; - { - List<IRUse*> diffUses; - for (auto use = inst->firstUse; use; use = use->nextUse) - { - if (as<IRDecoration>(use->getUser())) - { - if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()) && - !as<IRBackwardDerivativePrimalContextDecoration>(use->getUser())) - continue; - } - - IRBlock* useBlock = getBlock(use->getUser()); - if (useBlock && isDifferentialInst(useBlock)) - diffUses.add(use); - } - - for (auto use : diffUses) - { - IRBlock* useBlock = getBlock(use->getUser()); - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); - - IRInst* loadAddr = storageVar; - IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); - - // Enumerate use block regions. - // TODO: Probably a good idea to do this ahead of time for - // all blocks. 
- // - List<IndexTrackingInfo*> useBlockIndices = getIndexInfoList(useBlock); - - for (auto index : indices) - { - currType = as<IRArrayType>(currType)->getElementType(); - if (useBlockIndices.contains(index)) - { - // If the use-block is under the same region, use the - // differential counter variable - // - auto diffCounterCurrValue = index->diffCountParam; - - loadAddr = builder.emitElementAddress( - builder.getPtrType(currType), - loadAddr, - diffCounterCurrValue); - } - else - { - // If the use-block is outside this region, use the - // last available value (by indexing with primal counter minus 1) - // - auto primalCounterCurrValue = builder.emitLoad(index->primalCountLastVar); - auto primalCounterLastValue = builder.emitSub( - primalCounterCurrValue->getDataType(), - primalCounterCurrValue, - builder.getIntValue(builder.getIntType(), 1)); - - instsToTag.add(primalCounterCurrValue); - instsToTag.add(primalCounterLastValue); - - loadAddr = builder.emitElementAddress( - builder.getPtrType(currType), - loadAddr, - primalCounterLastValue); - } - - instsToTag.add(loadAddr); - } - - if (!isPtrType) - { - auto loadedValue = builder.emitLoad(loadAddr); - instsToTag.add(loadedValue); - - use->set(loadedValue); - } - else - { - use->set(loadAddr); - } - } - } - - for (auto instToTag : instsToTag) - { - builder.addPrimalValueAccessDecoration(instToTag); - builder.markInstAsPrimal(instToTag); - } - } - } - IRFunc* extractPrimalFunc( IRFunc* func, IRFunc* originalFunc, + HoistedPrimalsInfo* primalsInfo, ParameterBlockTransposeInfo& paramInfo, IRInst*& intermediateType); @@ -973,6 +527,13 @@ struct DiffUnzipPass auto primalArg = lookupPrimalInst(arg); auto diffArg = lookupDiffInst(arg); + if (auto primalVar = as<IRVar>(primalArg)) + { + primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType()); + if (auto storeUse = findUniqueStoredVal(primalVar)) + diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal()); + } + // 
If arg is a mixed differential (pair), it should have already been split. SLANG_ASSERT(primalArg); SLANG_ASSERT(diffArg); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 517b9e3ea..a3a7e4b77 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -595,6 +595,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_FloatType: case kIROp_VectorType: case kIROp_MatrixType: + case kIROp_AttributedType: return true; default: return false; @@ -904,7 +905,7 @@ struct AutoDiffPass : public InstPassBase else { IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; - diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo); + diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo); diffFieldWitness = diffFieldTypeInfo.diffWitness; } if (diffFieldWitness) @@ -1429,4 +1430,99 @@ bool finalizeAutoDiffPass(IRModule* module) return false; } +IRBlock* getBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst)) + return block; + + return getBlock(inst->getParent()); +} + +IRInst* getInstInBlock(IRInst* inst) +{ + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst->getParent())) + return inst; + + return getInstInBlock(inst->getParent()); +} + +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) +{ + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); + + auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()); + List<IRInst*> phiArgs; + + for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) + phiArgs.add(branchInst->getArg(ii)); + + phiArgs.add(arg); + + builder->setInsertInto(block); + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + break; + + case kIROp_loop: + builder->emitLoop( + as<IRLoop>(branchInst)->getTargetBlock(), + as<IRLoop>(branchInst)->getBreakBlock(), + 
as<IRLoop>(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + break; + + default: + SLANG_UNEXPECTED("Unexpected branch-type for phi replacement"); + } + + branchInst->removeAndDeallocate(); + return phiArgs.getCount() - 1; +} + +IRUse* findUniqueStoredVal(IRVar* var) +{ + if (isDerivativeContextVar(var)) + { + IRUse* primalCallUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as<IRCall>(use->getUser())) + { + // Should not see more than one IRCall. If we do + // we'll need to pick the primal call. + // + SLANG_RELEASE_ASSERT(!primalCallUse); + primalCallUse = use; + } + } + return primalCallUse; + } + else + { + IRUse* storeUse = nullptr; + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto storeInst = as<IRStore>(use->getUser())) + { + // Should not see more than one IRStore + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return storeUse; + } +} + +bool isDerivativeContextVar(IRVar* var) +{ + return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); +} + } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index e7a841323..d49babc52 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -298,6 +298,8 @@ bool processAutodiffCalls( bool finalizeAutoDiffPass(IRModule* module); +// Utility methods + void stripDerivativeDecorations(IRInst* inst); bool isBackwardDifferentiableFunc(IRInst* func); @@ -322,4 +324,15 @@ inline bool isRelevantDifferentialPair(IRType* type) return false; } +IRBlock* getBlock(IRInst* inst); + +IRInst* getInstInBlock(IRInst* inst); + +UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg); + +IRUse* findUniqueStoredVal(IRVar* var); + +bool isDerivativeContextVar(IRVar* var); + + }; diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0f5c36dcb..9c4c1f4e2 100644 --- a/source/slang/slang-ir-insts.h +++ 
b/source/slang/slang-ir-insts.h @@ -650,6 +650,7 @@ struct IRBackwardDerivativePrimalContextDecoration : IRDecoration }; IR_LEAF_ISA(BackwardDerivativePrimalContextDecoration) + IRUse primalContextVar; IRInst* getBackwardDerivativePrimalContextVar() { return getOperand(0); } }; @@ -703,6 +704,8 @@ struct IRLoopExitPrimalValueDecoration : IRDecoration }; IR_LEAF_ISA(LoopExitPrimalValueDecoration) + IRUse target; + IRUse exitVal; IRInst* getTargetInst() { return getOperand(0); } IRInst* getLoopExitValInst() { return getOperand(1); } }; diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index d8246edae..9b50b9c30 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -802,6 +802,34 @@ IRInst* readVar( return readVarRec(context, blockInfo, var); } +void collectInstsToRemove( + ConstructSSAContext* context, + IRBlock* block) +{ + IRInst* next = nullptr; + for (auto ii = block->getFirstInst(); ii; ii = next) + { + next = ii->getNextInst(); + + switch (ii->getOp()) + { + default: + // Ordinary instruction -> leave as-is + break; + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + auto ptrArg = ii->getOperand(0); + if (auto var = asPromotableVarAccessChain(context, ptrArg)) + { + context->instsToRemove.add(ii); + } + } + break; + } + } +} + void processBlock( ConstructSSAContext* context, IRBlock* block, @@ -877,19 +905,6 @@ void processBlock( } } break; - - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - { - auto ptrArg = ii->getOperand(0); - if (auto var = asPromotableVarAccessChain(context, ptrArg)) - { - context->instsToRemove.add(ii); - } - } - break; - - } } @@ -1078,6 +1093,10 @@ bool constructSSA(ConstructSSAContext* context) context->blockInfos.Add(bb, blockInfo); } + + for(auto bb : globalVal->getBlocks()) + collectInstsToRemove(context, bb); + for(auto bb : globalVal->getBlocks()) { auto blockInfo = * context->blockInfos.TryGetValue(bb); diff --git a/tests/autodiff/reverse-control-flow.slang 
b/tests/autodiff/reverse-control-flow-1.slang index 7d2f518be..7d2f518be 100644 --- a/tests/autodiff/reverse-control-flow.slang +++ b/tests/autodiff/reverse-control-flow-1.slang diff --git a/tests/autodiff/reverse-control-flow.slang.expected.txt b/tests/autodiff/reverse-control-flow-1.slang.expected.txt index 86aa47f11..86aa47f11 100644 --- a/tests/autodiff/reverse-control-flow.slang.expected.txt +++ b/tests/autodiff/reverse-control-flow-1.slang.expected.txt diff --git a/tests/autodiff/reverse-inout-param.slang b/tests/autodiff/reverse-inout-param-1.slang index de0d8f7ed..de0d8f7ed 100644 --- a/tests/autodiff/reverse-inout-param.slang +++ b/tests/autodiff/reverse-inout-param-1.slang diff --git a/tests/autodiff/reverse-inout-param.slang.expected.txt b/tests/autodiff/reverse-inout-param-1.slang.expected.txt index 2df174e2f..2df174e2f 100644 --- a/tests/autodiff/reverse-inout-param.slang.expected.txt +++ b/tests/autodiff/reverse-inout-param-1.slang.expected.txt |
