diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.cpp | 155 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-primal-hoist.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-dce.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 2 | ||||
| -rw-r--r-- | tests/autodiff/reverse-checkpoint-1.slang | 42 | ||||
| -rw-r--r-- | tests/autodiff/reverse-checkpoint-1.slang.expected.txt | 5 |
12 files changed, 244 insertions, 107 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 6a9b504a6..1bc3caaba 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1,5 +1,6 @@ #include "slang-ir-autodiff-primal-hoist.h" #include "slang-ir-autodiff-region.h" +#include "slang-ir-simplify-cfg.h" namespace Slang { @@ -9,7 +10,8 @@ void applyCheckpointSet( IRGlobalValueWithCode* func, HoistedPrimalsInfo* hoistInfo, HashSet<IRUse*> pendingUses, - Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock); + Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, + IROutOfOrderCloneContext* cloneCtx); bool containsOperand(IRInst* inst, IRInst* operand) { @@ -68,7 +70,8 @@ static IRBlock* tryGetSubRegionEndBlock(IRInst* terminator) static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( IRGlobalValueWithCode* func, - Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) + Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo, + IROutOfOrderCloneContext* cloneCtx) { IRBlock* firstDiffBlock = nullptr; for (auto block : func->getBlocks()) @@ -136,7 +139,6 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( WorkItem firstWorkItem = { func->getFirstBlock(), firstRecomputeBlock, firstRecomputeBlock, firstDiffBlock }; workList.add(firstWorkItem); - IRCloneEnv recomputeCloneEnv; recomputeBlockMap[func->getFirstBlock()] = firstRecomputeBlock; for (Index i = 0; i < workList.getCount(); i++) @@ -216,7 +218,7 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( { case kIROp_Switch: case kIROp_ifElse: - newTerminator = cloneInst(&recomputeCloneEnv, &builder, primalBlock->getTerminator()); + newTerminator = cloneCtx->cloneInstOutOfOrder(&builder, primalBlock->getTerminator()); break; case kIROp_unconditionalBranch: newTerminator = builder.emitBranch(as<IRUnconditionalBranch>(terminator)->getTargetBlock()); @@ -271,7 +273,8 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks( RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( IRGlobalValueWithCode* func, - Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock) + Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock, + IROutOfOrderCloneContext* cloneCtx) { RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo(); @@ -483,7 +486,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( } RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo(); - applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock); + applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock, cloneCtx); return hoistInfo; } @@ -501,11 +504,6 @@ void applyToInst( return; } - if (hoistInfo->ignoreSet.Contains(inst)) - { - return; - } - bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst); if (isInstRecomputed) { @@ -522,11 +520,10 @@ void applyToInst( // SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported"); } + return; } - else - { - hoistInfo->recomputeSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst)); - } + auto recomputeInst = cloneCtx->cloneInstOutOfOrder(builder, inst); + hoistInfo->recomputeSet.Add(recomputeInst); } bool isInstInverted = checkpointInfo->invertSet.Contains(inst); @@ -553,17 +550,22 @@ void applyToInst( } } +static IRBlock* getParamPreludeBlock(IRGlobalValueWithCode* func) +{ + return func->getFirstBlock()->getNextBlock(); +} + void applyCheckpointSet( CheckpointSetInfo* checkpointInfo, IRGlobalValueWithCode* func, HoistedPrimalsInfo* hoistInfo, HashSet<IRUse*> pendingUses, - Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock) + Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock, + IROutOfOrderCloneContext* cloneCtx) { // Reconstruct diff block map. Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func); - RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext(); for (auto use : pendingUses) cloneCtx->pendingUses.Add(use); @@ -583,10 +585,11 @@ void applyCheckpointSet( }; // Go back over the insts and move/clone them accoridngly. + auto paramPreludeBlock = getParamPreludeBlock(func); for (auto block : func->getBlocks()) { - // Skip parameter block. - if (block == func->getFirstBlock()) + // Skip parameter block and the param prelude block. + if (block == func->getFirstBlock() || block == paramPreludeBlock) continue; if (isDifferentialBlock(block)) @@ -646,7 +649,22 @@ void applyCheckpointSet( for (auto child : block->getChildren()) { + // Determine the insertion point for the recomputeInst. + // Normally we insert recomputeInst into the block's corresponding recomputeBlock. + // The exception is a load(inoutParam), in which case we insert the recomputed load + // at the right beginning of the function to correctly receive the initial parameter + // value. We can't just insert the load at recomputeBlock because at that point the + // primal logic may have already updated the param with a new value, and instead we + // want the original value. builder.setInsertBefore(recomputeInsertBeforeInst); + if (auto load = as<IRLoad>(child)) + { + if (load->getPtr()->getOp() == kIROp_Param && + load->getPtr()->getParent() == func->getFirstBlock()) + { + builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator()); + } + } applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child); } } @@ -833,28 +851,33 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo) { RefPtr<IRDominatorTree> domTree = computeDominatorTree(func); + + IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock(); IRBuilder builder(func->getModule()); - IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock(); - SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock)); + IRBlock* defaultRecomptueVarBlock = nullptr; + for (auto block : func->getBlocks()) + if (isDifferentialOrRecomputeBlock(block)) + { + defaultRecomptueVarBlock = block; + break; + } + SLANG_RELEASE_ASSERT(defaultRecomptueVarBlock); OrderedHashSet<IRInst*> processedStoreSet; - auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet) + auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet, bool isRecomputeInst) { + SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock)); + for (auto instToStore : instSet) { - if (!instSet.Contains(instToStore)) - continue; - - if (hoistInfo->ignoreSet.Contains(instToStore)) - continue; IRBlock* defBlock = nullptr; if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) { auto varInst = as<IRVar>(instToStore); - auto storeUse = findUniqueStoredVal(varInst); + auto storeUse = findEarliestUniqueWriteUse(varInst); defBlock = getBlock(storeUse->getUser()); } @@ -899,19 +922,28 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( if (outOfScopeUses.getCount() == 0) { - processedStoreSet.Add(instToStore); + if (!isRecomputeInst) + processedStoreSet.Add(instToStore); continue; } + auto defBlockIndices = indexedBlockInfo[defBlock].GetValue(); + IRBlock* varBlock = defaultVarBlock; + if (isRecomputeInst) + { + varBlock = defaultRecomptueVarBlock; + if (defBlockIndices.getCount()) + { + varBlock = as<IRBlock>(defBlockIndices[0].diffCountParam->getParent()); + defBlockIndices.clear(); + } + } if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) { - IRVar* varToStore = as<IRVar>(instToStore); SLANG_RELEASE_ASSERT(varToStore); - auto storeUse = findUniqueStoredVal(varToStore); - - List<IndexTrackingInfo>& defBlockIndices = indexedBlockInfo[defBlock]; + auto storeUse = findLatestUniqueWriteUse(varToStore); bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0); @@ -921,7 +953,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( if (!isIndexedStore && isDerivativeContextVar(varToStore)) { varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst()); - processedStoreSet.Add(varToStore); + if (!isRecomputeInst) + processedStoreSet.Add(varToStore); continue; } @@ -929,7 +962,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( IRVar* localVar = storeIndexedValue( &builder, - defaultVarBlock, + varBlock, builder.emitLoad(varToStore), defBlockIndices); @@ -942,8 +975,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices); builder.replaceOperand(use, loadAddr); } - - processedStoreSet.Add(localVar); + if (!isRecomputeInst) + processedStoreSet.Add(localVar); } else { @@ -951,7 +984,6 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( // The only case where there will be a reference of primal loop counter from rev blocks // is the start of a loop in the reverse code. Since loop counters are not considered a // part of their loop region, so we remove the first index info. - List<IndexTrackingInfo> defBlockIndices = indexedBlockInfo[defBlock]; bool isLoopCounter = (instToStore->findDecoration<IRLoopCounterDecoration>() != nullptr); if (isLoopCounter) { @@ -959,7 +991,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( } setInsertAfterOrdinaryInst(&builder, instToStore); - auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices); + auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices); for (auto use : outOfScopeUses) { @@ -974,14 +1006,15 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices)); } - - processedStoreSet.Add(localVar); + if (!isRecomputeInst) + processedStoreSet.Add(localVar); } } }; - ensureInstAvailable(hoistInfo->storeSet); - + ensureInstAvailable(hoistInfo->storeSet, false); + ensureInstAvailable(hoistInfo->recomputeSet, true); + // Replace the old store set with the processed one. hoistInfo->storeSet = processedStoreSet; @@ -1179,27 +1212,23 @@ void buildIndexedBlocks( } } -RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy( - IRGlobalValueWithCode* func, const List<IRInst*>& instsToIgnore) +RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) { sortBlocksInFunc(func); Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo; buildIndexedBlocks(indexedBlockInfo, func); - auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo); + RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext(); + auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo, cloneCtx); sortBlocksInFunc(func); RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule()); chkPolicy->preparePolicy(func); - auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap); + auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap, cloneCtx); - for (auto propagateFuncSpecificInst : instsToIgnore) - { - primalsInfo->ignoreSet.add(propagateFuncSpecificInst); - } primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); return primalsInfo; } @@ -1343,7 +1372,6 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_GetSequentialID: case kIROp_Specialize: case kIROp_LookupWitness: -#if 0 case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -1364,7 +1392,6 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_BitXor: case kIROp_Lsh: case kIROp_Rsh: -#endif return false; case kIROp_GetElement: case kIROp_FieldExtract: @@ -1387,17 +1414,29 @@ static bool shouldStoreInst(IRInst* inst) if (as<IRType>(inst)) return false; - // Only store if the inst has differential inst user. - bool hasDiffUser = doesInstHaveDiffUse(inst); - if (!hasDiffUser) - return false; - return true; } bool canRecompute(IRDominatorTree* domTree, IRUse* use) { SLANG_UNUSED(domTree); + if (auto load = as<IRLoad>(use->get())) + { + // Generally, we cannot recompute a load(ptr), since ptr may be modified + // afterwards. The exceptions are a load of an inout param, since the + // propagation function never actually writes to the primal part of the + // inout param, and we can always just read the original param. + + auto ptr = load->getPtr(); + if (ptr->getOp() == kIROp_Param) + { + if (auto block = as<IRBlock>(ptr->getParent())) + { + return (block == block->getParent()->getFirstBlock()); + } + } + return false; + } auto param = as<IRParam>(use->get()); if (!param) return true; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h index 3b3fb82b1..6e861bc5b 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.h +++ b/source/slang/slang-ir-autodiff-primal-hoist.h @@ -14,10 +14,8 @@ namespace Slang IRCloneEnv cloneEnv; HashSet<IRUse*> pendingUses; - IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst) + void registerClonedInst(IRBuilder* builder, IRInst* inst, IRInst* clonedInst) { - IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst); - UInt operandCount = clonedInst->getOperandCount(); for (UInt ii = 0; ii < operandCount; ++ii) { @@ -31,16 +29,21 @@ namespace Slang for (auto use = inst->firstUse; use;) { auto nextUse = use->nextUse; - + if (pendingUses.Contains(use)) { pendingUses.Remove(use); builder->replaceOperand(use, clonedInst); } - + use = nextUse; } + } + IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst) + { + IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst); + registerClonedInst(builder, inst, clonedInst); return clonedInst; } }; @@ -86,7 +89,6 @@ namespace Slang OrderedHashSet<IRInst*> storeSet; OrderedHashSet<IRInst*> recomputeSet; OrderedHashSet<IRInst*> invertSet; - OrderedHashSet<IRInst*> ignoreSet; OrderedHashSet<IRInst*> instsToInvert; Dictionary<IRInst*, InversionInfo> invertInfoMap; @@ -129,9 +131,6 @@ namespace Slang for (auto inst : info->invertSet) invertSet.Add(inst); - for (auto inst : info->ignoreSet) - ignoreSet.add(inst); - for (auto inst : info->instsToInvert) instsToInvert.Add(inst); @@ -261,7 +260,8 @@ namespace Slang RefPtr<HoistedPrimalsInfo> processFunc( IRGlobalValueWithCode* func, - Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock); + Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock, + IROutOfOrderCloneContext* cloneCtx); // Do pre-processing on the function (mainly for // 'global' checkpointing methods that consider the entire @@ -290,9 +290,5 @@ namespace Slang RefPtr<IRDominatorTree> domTree; }; - RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy( - IRGlobalValueWithCode* func, - const List<IRInst*>& instsToIgnore); - - + RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func); }; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 979eb6343..d7abf1d40 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -711,12 +711,9 @@ namespace Slang // Apply checkpointing policy to legalize cross-scope uses of primal values // using either recompute or store strategies. - auto primalsInfo = applyCheckpointPolicy( - diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts); - + auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc); eliminateDeadCode(diffPropagateFunc); - // Extracts the primal computations into its own func, and replace the primal insts // with the intermediate results computed from the extracted func. @@ -810,10 +807,13 @@ namespace Slang // Find the 'next' block using the terminator inst of the parameter block. auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator()); - auto nextBlock = fwdParamBlockBranch->getTargetBlock(); + // We create a new block after parameter block to hold insts that translates from transposed parameters + // into something that the rest of the function can use. + IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block)); + auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock(); auto nextBlockBuilder = *builder; - nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst()); + nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst()); IRBlock* firstDiffBlock = nullptr; for (auto block : diffFunc->getBlocks()) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 44e981404..a864a74b2 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -216,12 +216,15 @@ struct ExtractPrimalFuncContext { if (as<IRVar>(inst)) { - auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); - builder.setInsertBefore(inst); - auto fieldAddr = builder.emitFieldAddress( - inst->getFullType(), outIntermediary, field->getKey()); - inst->replaceUsesWith(fieldAddr); - builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); + if (inst->hasUses()) + { + auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); + builder.setInsertBefore(inst); + auto fieldAddr = builder.emitFieldAddress( + inst->getFullType(), outIntermediary, field->getKey()); + inst->replaceUsesWith(fieldAddr); + builder.addPrimalValueStructKeyDecoration(inst, field->getKey()); + } } else { @@ -359,7 +362,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( List<IRInst*> instsToRemove; for (auto block : func->getBlocks()) { - for (auto inst : block->getOrdinaryInsts()) + for (auto inst : block->getChildren()) { if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) { @@ -420,6 +423,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( for (auto inst : instsToRemove) { + if (as<IRParam>(inst)) + removePhiArgs(inst); inst->removeAndDeallocate(); } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 65f45ece8..532e63b42 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -316,8 +316,12 @@ struct DiffUnzipPass if (auto primalVar = as<IRVar>(primalArg)) { primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType()); - if (auto storeUse = findUniqueStoredVal(primalVar)) - diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal()); + if (auto storeUse = findUniqueStoredVal(primalVar)) + { + auto storeInst = diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal()); + storeInst->insertAfter(storeUse->getUser()); + primalArg->insertBefore(storeInst); + } } // If arg is a mixed differential (pair), it should have already been split. diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index a8af148d9..656b0e11b 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1868,6 +1868,48 @@ IRUse* findUniqueStoredVal(IRVar* var) } } +// Given a local var that is supposed to have a unique write, find the last inst +// that writes to it. Note: if var is intended for an inout argument, it will +// have exactly one store that sets its initial value and one call that writes +// the final value to it, this method will return the call inst for this case. +IRUse* findLatestUniqueWriteUse(IRVar* var) +{ + IRUse* storeUse = nullptr; + // If no unique store found, try to look for a call. + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as<IRCall>(use->getUser())) + { + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return findUniqueStoredVal(var); +} + +// Given a local var that is supposed to have a unique write, find the last inst +// that writes to it. Note: if var is intended for an inout argument, it will +// have exactly one store that sets its initial value and one call that writes +// the final value to it, this method will return the store inst for this case. +IRUse* findEarliestUniqueWriteUse(IRVar* var) +{ + IRUse* storeUse = findUniqueStoredVal(var); + if (storeUse) + return storeUse; + + // If no unique store found, try to look for a call. + for (auto use = var->firstUse; use; use = use->nextUse) + { + if (auto callInst = as<IRCall>(use->getUser())) + { + SLANG_RELEASE_ASSERT(!storeUse); + storeUse = use; + } + } + return storeUse; +} + + bool isDerivativeContextVar(IRVar* var) { return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index d7d6119d4..52cf346b3 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -343,6 +343,8 @@ IRInst* getInstInBlock(IRInst* inst); UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg); IRUse* findUniqueStoredVal(IRVar* var); +IRUse* findLatestUniqueWriteUse(IRVar* var); +IRUse* findEarliestUniqueWriteUse(IRVar* var); bool isDerivativeContextVar(IRVar* var); diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 364abe68c..1b0ecf521 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -223,25 +223,6 @@ struct DeadCodeEliminationContext return processInst(module->getModuleInst()); } - void removePhiArgs(IRInst* phiParam) - { - auto block = cast<IRBlock>(phiParam->getParent()); - UInt paramIndex = 0; - for (auto p = block->getFirstParam(); p; p = p->getNextParam()) - { - if (p == phiParam) - break; - paramIndex++; - } - for (auto predBlock : block->getPredecessors()) - { - auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator()); - SLANG_ASSERT(paramIndex < termInst->getArgCount()); - termInst->removeArgument(paramIndex); - } - phiRemoved = true; - } - bool eliminateDeadInstsRec(IRInst* inst) { bool changed = false; @@ -266,6 +247,7 @@ struct DeadCodeEliminationContext { // For Phi parameters, we need to update all branch arguments. removePhiArgs(inst); + phiRemoved = true; } inst->removeAndDeallocate(); changed = true; diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 03b74b36a..9348dfe8a 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -741,6 +741,24 @@ void moveParams(IRBlock* dest, IRBlock* src) } } +void removePhiArgs(IRInst* phiParam) +{ + auto block = cast<IRBlock>(phiParam->getParent()); + UInt paramIndex = 0; + for (auto p = block->getFirstParam(); p; p = p->getNextParam()) + { + if (p == phiParam) + break; + paramIndex++; + } + for (auto predBlock : block->getPredecessors()) + { + auto termInst = as<IRUnconditionalBranch>(predBlock->getTerminator()); + SLANG_ASSERT(paramIndex < termInst->getArgCount()); + termInst->removeArgument(paramIndex); + } +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index e7d182604..9405771b1 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -200,6 +200,8 @@ IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key); IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key); void moveParams(IRBlock* dest, IRBlock* src); + +void removePhiArgs(IRInst* phiParam); } #endif diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang new file mode 100644 index 000000000..e503fb50f --- /dev/null +++ b/tests/autodiff/reverse-checkpoint-1.slang @@ -0,0 +1,42 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float g(float x) +{ + return log(x); +} + +[BackwardDifferentiable] +float f(int p, float x) +{ + float y = 1.0; + // Test that phi parameter can be restored. + if (p == 0) + y = g(x); + return y * y; +} + +// Check that there are no calls to primal_g in bwd_f. + +// CHECK: void s_bwd_f_{{[0-9]+}} +// CHECK-NOT: {{[_a-zA-Z0-9]+}} = s_bwd_primal_g_{{[0-9]+}} +// CHECK: return + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + dpfloat dpa = dpfloat(2.0, 0.0); + + __bwd_diff(f)(0, dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 1 +} diff --git a/tests/autodiff/reverse-checkpoint-1.slang.expected.txt b/tests/autodiff/reverse-checkpoint-1.slang.expected.txt new file mode 100644 index 000000000..1ea0454e4 --- /dev/null +++ b/tests/autodiff/reverse-checkpoint-1.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +0.693147 +0.000000 +0.000000 +0.000000 |
