diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 82 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 261 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 11 |
5 files changed, 80 insertions, 282 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 08b946cdd..5d11b7fb3 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -9,7 +9,7 @@ void applyCheckpointSet( CheckpointSetInfo* checkpointInfo, IRGlobalValueWithCode* func, HoistedPrimalsInfo* hoistInfo, - HashSet<IRUse*> pendingUses, + HashSet<IRUse*>& pendingUses, Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, IROutOfOrderCloneContext* cloneCtx); @@ -561,7 +561,7 @@ void applyCheckpointSet( CheckpointSetInfo* checkpointInfo, IRGlobalValueWithCode* func, HoistedPrimalsInfo* hoistInfo, - HashSet<IRUse*> pendingUses, + HashSet<IRUse*>& pendingUses, Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, IROutOfOrderCloneContext* cloneCtx) { @@ -847,11 +847,64 @@ static int getInstRegionNestLevel( return (int)result; } + +/// Legalizes all accesses to primal insts from recompute and diff blocks. +/// RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( HoistedPrimalsInfo* hoistInfo, IRGlobalValueWithCode* func, Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) { + // In general, after checkpointing, we can have a function like the following: + // ``` + // void func() + // { + // primal: + // for (int i = 0; i < 5; i++) + // { + // float x = g(i); + // use(x); + // } + // recompute: + // ... + // diff: + // for (int i = 5; i >= 0; i--) + // { + // recompute: + // ... + // diff: + // use_diff(x); // def of x is not dominating this location! + // } + // } + // ``` + // This function will legalize the access to x in the dff block by creating + // a proper local variable and insert store/loads, so that the above function + // will be transformed to: + // ``` + // void func() + // { + // primal: + // float x_storage[5]; + // + // for (int i = 0; i < 5; i++) + // { + // float x = g(i); + // x_storage[i] = x; + // use(x); + // } + // recompute: + // ... + // diff: + // for (int i = 5; i >= 0; i--) + // { + // recompute: + // ... + // diff: + // use_diff(x_storage[i]); + // } + // } + // + RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock(); @@ -1027,7 +1080,6 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( return hoistInfo; } - void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) { if (info->status != IndexTrackingInfo::CountStatus::Unresolved) @@ -1042,7 +1094,6 @@ void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info) } } - IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type) { builder->setInsertInto(block); @@ -1175,6 +1226,13 @@ void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalC } } +// Insert iteration counters for all loops to form indexed regions. For loops in +// primal blocks, the counter is incremented from 0. For loops in reverse +// blocks, the counter is decremented from the final value in primal block +// downto 0. Returns a mapping from each block to a list of their enclosing loop +// regions. A loop region records the iteration counter for the corresponding +// loop in the primal block and the reverse block. +// void buildIndexedBlocks( Dictionary<IRBlock*, List<IndexTrackingInfo>>& info, IRGlobalValueWithCode* func) @@ -1218,23 +1276,37 @@ void buildIndexedBlocks( } } +// For each primal inst that is used in reverse blocks, decide if we should recompute or store +// its value, then make them accessible in reverse blocks based the decision. +// RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) { sortBlocksInFunc(func); + // Insert loop counters and establish loop regions. + // Also makes the reverse loops counting downwards from the final iteration count. + // Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo; buildIndexedBlocks(indexedBlockInfo, func); + // Create recompute blocks for each region following the same control flow structure + // as in primal code. + // RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext(); auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo, cloneCtx); sortBlocksInFunc(func); + // Determine the strategy we should use to make a primal inst available. + // If we decide to recompute the inst, emit the recompute inst in the corresponding recompute block. + // RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule()); chkPolicy->preparePolicy(func); - auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx); + // Legalize the primal inst accesses by introducing local variables / arrays and emitting + // necessary load/store logic. + // primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); return primalsInfo; } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index ef1bdaf1e..e3575aceb 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -715,8 +715,8 @@ namespace Slang eliminateDeadCode(diffPropagateFunc); - // Extracts the primal computations into its own func, and replace the primal insts - // with the intermediate results computed from the extracted func. + // Extracts the primal computations into its own func, turn all accesses to stored primal insts into + // explicit intermediate data structure reads and writes. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType); @@ -779,7 +779,7 @@ namespace Slang initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric))); initializeLocalVariables(builder->getModule(), diffPropagateFunc); - // insertVariableForRecomputedPrimalInsts(diffPropagateFunc); + stripTempDecorations(diffPropagateFunc); sortBlocksInFunc(diffPropagateFunc); diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 86a6f2846..845372ba7 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -114,8 +114,6 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); - void insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc); - void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc); InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index bcc494fa9..8a734446d 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -514,7 +514,6 @@ struct DiffTransposePass // (i.e. not store per-func info in 'this') // since it is reused for every reverse-mode call. // - primalVarsToHoist.clear(); // Grab all differentiable type information. diffTypeContext.setFunc(revDiffFunc); @@ -663,9 +662,6 @@ struct DiffTransposePass for (auto block : workList) block->removeFromParent(); - finishHoistingPrimals(revDiffFunc); - - // At this point, the only block left without terminator insts // should be the last one. Add a void return to complete it. // @@ -972,259 +968,6 @@ struct DiffTransposePass } - struct InvInstPair - { - IRInst* inst; - IRInst* invInst; - - InvInstPair(IRInst* inst, IRInst* invInst) : - inst(inst), invInst(invInst) - { } - - InvInstPair() : inst(nullptr), invInst(nullptr) - { } - }; - - List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo) - { - SLANG_RELEASE_ASSERT(invInfo.requiredOperands.getCount() == 1); - SLANG_RELEASE_ASSERT(invInfo.targetInsts.getCount() == 1); - - auto invOutput = invInfo.requiredOperands[0]; - - auto invTargetInst = invInfo.targetInsts[0]; - - switch (primalInst->getOp()) - { - case kIROp_Add: - { - SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); - return List<InvInstPair>( - InvInstPair( - invTargetInst, - builder->emitSub( - primalInst->getOperand(0)->getDataType(), - invOutput, - primalInst->getOperand(1)))); - } - case kIROp_Sub: - { - SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); - return List<InvInstPair>( - InvInstPair( - invTargetInst, - builder->emitAdd( - primalInst->getOperand(0)->getDataType(), - invOutput, - primalInst->getOperand(1)))); - } - - default: - SLANG_UNEXPECTED("Unhandled arithmetic inst for inversion"); - } - } - - // Go through loop block phi-args, and look for loop counter - // arguments, which for a loop means inserting a check into - // loop condition block. - // This method also adds logic to skip the first iteration. - // (a 'do-while' loop) - // - void invertLoopCondition(IRBuilder* builder, IRLoop* loopInst) - { - auto firstLoopBlock = loopInst->getTargetBlock(); - - IRBlock* revLoopCondBlock = revBlockMap[firstLoopBlock]; - builder->setInsertBefore(revLoopCondBlock->getTerminator()); - - // Add a terminating condition based on the loop counter's initial primal value - - IRParam* loopCounterParam = nullptr; - UIndex loopCounterParamIndex = 0; - for (auto param : firstLoopBlock->getParams()) - { - if (param->findDecoration<IRLoopCounterDecoration>()) - { - // There really not should be two (or more) loop counter params. - SLANG_RELEASE_ASSERT(loopCounterParam == nullptr); - loopCounterParam = param; - } - else - { - loopCounterParamIndex++; - } - } - - // Should see atleast one loop counter parameter on the first loop block. - SLANG_RELEASE_ASSERT(loopCounterParam); - - IRInst* loopCounterInitVal = loopInst->getArg(loopCounterParamIndex); - - auto paramBoundsCheck = builder->emitIntrinsicInst( - builder->getBoolType(), - kIROp_Neq, - 2, - List<IRInst*>( - loopCounterParam, - loopCounterInitVal).getBuffer()); - - as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck); - } - - IRInst* lookupInstInPrimalBlock(IRInst* invInst) - { - // Lookup the inst in the primal block whose value we can use as an operand - // for the inverted inst. - // - // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst]; - return invInst; - } - - bool doesInstRequireHoisting(IRInst* inst) - { - if (as<IRModuleInst>(inst->getParent())) - return false; - - if (as<IRBlock>(inst) || - as<IRGlobalValueWithCode>(inst) || - as<IRConstant>(inst)) - return false; - - if (as<IRTerminatorInst>(inst)) - return false; - - if (as<IRDecoration>(inst)) - return doesInstRequireHoisting(getInstInBlock(inst)); - - // We're looking for primal insts in differential blocks - // that have not yet been moved to the 'active' blocks - // (i.e in diff blocks that do not have parents) - // - return (!isDifferentialInst(inst) && - (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) && - getBlock(inst)->getParent() == nullptr); - } - - IRBlock* walkToEndOfRegion(IRBlock* block) - { - IRBlock* currBlock = block; - - bool keepGoing = true; - while (keepGoing) - { - auto terminator = currBlock->getTerminator(); - switch (terminator->getOp()) - { - case kIROp_Return: - keepGoing = false; - break; - - case kIROp_unconditionalBranch: - { - auto nextBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock(); - - HashSet<IRBlock*> predecessorSet; - for (auto predecessor : nextBlock->getPredecessors()) - predecessorSet.add(predecessor); - - if (predecessorSet.getCount() > 1) - { - keepGoing = false; - break; - } - - currBlock = nextBlock; - break; - } - - case kIROp_ifElse: - { - for (auto predecessor : currBlock->getPredecessors()) - { - if (as<IRLoop>(predecessor->getTerminator())) - { - keepGoing = false; - break; - } - } - - currBlock = as<IRIfElse>(terminator)->getAfterBlock(); - break; - } - - case kIROp_Switch: - currBlock = as<IRSwitch>(terminator)->getBreakLabel(); - break; - - case kIROp_loop: - currBlock = as<IRLoop>(terminator)->getBreakBlock(); - break; - } - } - - return currBlock; - } - - void finishHoistingPrimals(IRGlobalValueWithCode* func) - { - auto varBlock = func->getFirstBlock()->getNextBlock(); - - for (auto inst : primalVarsToHoist) - { - if (!doesInstRequireHoisting(inst)) - continue; - - List<IRUse*> relevantUses; - - IRBlock* defBlock = nullptr; - if (auto varToHoist = as<IRVar>(inst)) - { - varToHoist->insertBefore(varBlock->getFirstOrdinaryInst()); - auto uniqueStoreUse = findUniqueStoredVal(varToHoist); - if (uniqueStoreUse) - { - inst = uniqueStoreUse->getUser(); - SLANG_ASSERT(inst); - - defBlock = getBlock(inst); - } - else - { - defBlock = getBlock(inst); - } - } - else - { - defBlock = getBlock(inst); - } - - if (!doesInstRequireHoisting(inst)) - continue; - - // Move this inst to after it's diff uses. - // - { - - IRBlock* currTopBlock = revBlockMap[walkToEndOfRegion(defBlock)]; - - SLANG_RELEASE_ASSERT(currTopBlock); - - // More consistency checks - SLANG_RELEASE_ASSERT(currTopBlock->getFirstOrdinaryInst() != nullptr); - SLANG_RELEASE_ASSERT(currTopBlock->getParent() != nullptr); - SLANG_RELEASE_ASSERT(isDifferentialInst(currTopBlock)); - - // Insert at top. (disabling validation since the operands of - // this inst might not be hoisted to the right place yet) - // - disableIRValidationAtInsert(); - inst->insertBefore(currTopBlock->getFirstOrdinaryInst()); - enableIRValidationAtInsert(); - } - } - } - - void transposeInst(IRBuilder* builder, IRInst* inst) { switch (inst->getOp()) @@ -1386,8 +1129,6 @@ struct DiffTransposePass auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType(); auto tempVar = builder->emitVar(pairType); auto primalVal = builder->emitLoad(instPair->getPrimal()); - auto primalVar = instPair->getPrimal(); - primalVarsToHoist.add(primalVar); auto diffVal = builder->emitLoad(instPair->getDiff()); auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); @@ -3002,8 +2743,6 @@ struct DiffTransposePass DifferentialPairTypeBuilder pairBuilder; - List<IRInst*> primalVarsToHoist; - IRBlock* tempInvBlock; Dictionary<IRInst*, List<RevGradient>> gradientsMap; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index e0723dcdd..34f0f6c9b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -313,17 +313,6 @@ struct DiffUnzipPass auto primalArg = lookupPrimalInst(arg); auto diffArg = lookupDiffInst(arg); - if (auto primalVar = as<IRVar>(primalArg)) - { - primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType()); - if (auto storeUse = findUniqueStoredVal(primalVar)) - { - auto storeInst = diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal()); - storeInst->insertAfter(storeUse->getUser()); - primalArg->insertBefore(storeInst); - } - } - // If arg is a mixed differential (pair), it should have already been split. SLANG_ASSERT(primalArg); SLANG_ASSERT(diffArg); |
