diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-17 12:03:59 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-17 09:03:59 -0800 |
| commit | f253d15a3b2681dfa40491451fcb3f21f1dbe412 (patch) | |
| tree | 589298ff23ea2b2eb89615694f2c06613f1199a1 /source | |
| parent | 245466d89cfe54b78da486f06d470bc6daaf4625 (diff) | |
Proper reverse-mode loop handling with splitting + inversion steps (#2656)
* Halfway to loop inversion
* More progress towards proper loop inversion
* More progress towards inverse insts. Only thing left is adding `counter>=0` at the right place
* More fixes for inversion step.
* Lots more fixes, added primal inst 'hoisting' mechanism as the central method that ensures primal values are placed in the right spot
* Loop inversion is now functional
* Cleaned up commented code
* rename diffCounterVar -> diffCounterParam
* minor update
* removed some comments and commented code
* Switch `IRBuilder(sharedIRBuilder)` to `IRBuilder(moduleInst)`
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-propagate.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 441 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 328 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 5 |
7 files changed, 636 insertions, 179 deletions
diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h index 4edf20142..8f912ba61 100644 --- a/source/slang/slang-ir-autodiff-propagate.h +++ b/source/slang/slang-ir-autodiff-propagate.h @@ -15,6 +15,11 @@ inline bool isDifferentialInst(IRInst* inst) return inst->findDecoration<IRDifferentialInstDecoration>(); } +inline bool isPrimalInst(IRInst* inst) +{ + return inst->findDecoration<IRPrimalInstDecoration>() || (as<IRConstant>(inst) != nullptr); +} + inline bool isMixedDifferentialInst(IRInst* inst) { return inst->findDecoration<IRMixedDifferentialInstDecoration>(); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 8aca31642..b74416b76 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -227,6 +227,9 @@ struct DiffTransposePass IRBlock* revAfterBlock = revBlockMap[currentBlock]; builder.setInsertInto(revCondBlock); + + hoistPrimalInst(&builder, ifElse->getCondition()); + builder.emitIfElse( ifElse->getCondition(), revTrueEntryBlock, @@ -357,6 +360,8 @@ struct DiffTransposePass // Emit condition into the new cond block. builder.setInsertInto(revCondBlock); + hoistPrimalInst(&builder, ifElse->getCondition()); + builder.emitIfElse( ifElse->getCondition(), revTrueBlock, @@ -442,7 +447,11 @@ struct DiffTransposePass } auto revSwitchBlock = revBlockMap[breakBlock]; + builder.setInsertInto(revSwitchBlock); + + hoistPrimalInst(&builder, switchInst->getCondition()); + builder.emitSwitch( switchInst->getCondition(), revBreakBlock, @@ -588,6 +597,21 @@ struct DiffTransposePass auto firstFwdDiffBlock = branchInst->getTargetBlock(); reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>()); + // Lower any loop-exit-value decorations into initializations for loop intermediate vals, + // and convert loop initial values into terminating conditions. 
+ // + // TODO: We need a way to confirm that all required vars have an initial value + // (is there a built-in dataflow tool for this?) + // + for (auto block : workList) + { + if (auto loopInst = as<IRLoop>(block->getTerminator())) + { + lowerLoopExitValues(&builder, loopInst); + invertLoopCondition(&builder, loopInst); + } + } + // Link the last differential fwd-mode block (which will be the first // rev-mode block) as the successor to the last primal block. // We assume that the original function is in single-return form @@ -686,6 +710,36 @@ struct DiffTransposePass return tempRevVar; } + IRVar* getOrCreateInverseVar(IRInst* primalInst) + { + // No need to store inverse values for constants. + if (as<IRConstant>(primalInst)) + return nullptr; + + // Check if we have a var already. + if (inverseVarMap.ContainsKey(primalInst)) + return inverseVarMap[primalInst]; + + IRBuilder tempVarBuilder(autodiffContext->moduleInst); + + IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())]; + + if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst()) + tempVarBuilder.setInsertBefore(firstInst); + else + tempVarBuilder.setInsertInto(firstDiffBlock); + + auto primalType = primalInst->getDataType(); + + // Emit a var in the top-level differential block to hold the inverse, + // and initialize it. 
+ auto tempInvVar = tempVarBuilder.emitVar(primalType); + + inverseVarMap[primalInst] = tempInvVar; + + return tempInvVar; + } + bool isInstUsedOutsideParentBlock(IRInst* inst) { auto currBlock = inst->getParent(); @@ -707,7 +761,7 @@ struct DiffTransposePass builder.setInsertInto(revBlock); // Check if this block has any 'outputs' (in the form of phi args - // sent to the successor bvock) + // sent to the successor block) // if (auto branchInst = as<IRUnconditionalBranch>(fwdBlock->getTerminator())) { @@ -716,51 +770,48 @@ struct DiffTransposePass auto arg = branchInst->getArg(ii); if (isDifferentialInst(arg)) { + // If the arg is a differential, emit a parameter + // to accept it's reverse-mode differential as an input + // + auto diffType = arg->getDataType(); auto revParam = builder.emitParam(diffType); addRevGradientForFwdInst( - arg, + arg, RevGradient( RevGradient::Flavor::Simple, arg, revParam, nullptr)); } - } - } - - // Some special instructions simply need to be copied over. - // These do not deal with differentials. - // TODO: This will not work if there are any differential - // insts that rely on loop counter vars having a specific - // value. - // The solution is to have primal insts appearing in - // differential blocks be in their own special blocks that are - // ignored entirely, rather than dealing with them one inst - // at a time. - // - for (IRInst* child = fwdBlock->getFirstChild(); child;) - { - auto nextChild = child->getNextInst(); + else if (isPrimalInst(arg)) + { + // If the output arg is a primal, emit a parameter + // to accept it as an _input_ for the reverse-mode + // + auto primalType = arg->getDataType(); + auto primalInvParam = builder.emitParam(primalType); - if (child->findDecoration<IRLoopCounterDecoration>()) - { - // Loop counter insts should not have any gradients. 
- SLANG_ASSERT(!hasRevGradients(child)); - child->insertAtEnd(revBlock); + setInverse(&builder, arg, primalInvParam); + } + else + { + SLANG_UNEXPECTED("Encountered inst not marked as primal or differential"); + } } - - child = nextChild; } // Move pointer & reference insts to the top of the reverse-mode block. List<IRInst*> nonValueInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { - // If the instruction is pointer typed, it's not actually computing a value. + // If the instruction is a variable allocation (or reverse-gradient pair reference), + // move to top. + // TODO: This is hacky.. Need a more principled way to handle this + // (like primal inst hoisting) // - if (as<IRPtrTypeBase>(child->getDataType())) + if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child)) nonValueInsts.add(child); // Slang doesn't support function values. So if we see a func-typed inst @@ -782,11 +833,16 @@ struct DiffTransposePass // for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst()) { + if (child->findDecoration<IRPrimalValueAccessDecoration>()) + continue; + if (as<IRDecoration>(child) || as<IRParam>(child)) continue; - - transposeInst(&builder, child); + if (isDifferentialInst(child)) + transposeInst(&builder, child); + else if (isPrimalInst(child)) + invertInst(&builder, child); } // After processing the block's instructions, we 'flush' any remaining gradients @@ -806,32 +862,47 @@ struct DiffTransposePass List<IRInst*> phiParamRevGradInsts; for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { - // This param might be used outside this block. - // If so, add/get an accumulator. 
- // - if (isInstUsedOutsideParentBlock(param)) + if (isDifferentialInst(param)) { - auto accGradient = extractAccumulatorVarGradient(&builder, param); - addRevGradientForFwdInst( - param, - RevGradient(param, accGradient, nullptr)); + // This param might be used outside this block. + // If so, add/get an accumulator. + // + if (isInstUsedOutsideParentBlock(param)) + { + auto accGradient = extractAccumulatorVarGradient(&builder, param); + addRevGradientForFwdInst( + param, + RevGradient(param, accGradient, nullptr)); + } + if (hasRevGradients(param)) + { + auto gradients = popRevGradients(param); + + auto gradInst = emitAggregateValue( + &builder, + tryGetPrimalTypeFromDiffInst(param), + gradients); + + phiParamRevGradInsts.add(gradInst); + } + else + { + phiParamRevGradInsts.add( + emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); + } } - - if (hasRevGradients(param)) + else if (isPrimalInst(param)) { - auto gradients = popRevGradients(param); - - auto gradInst = emitAggregateValue( - &builder, - tryGetPrimalTypeFromDiffInst(param), - gradients); - - phiParamRevGradInsts.add(gradInst); + if (hasInverse(param)) + phiParamRevGradInsts.add(getInverse(&builder, param)); + else + { + SLANG_UNEXPECTED("param is a primal inst but has no registered inverse"); + } } else { - phiParamRevGradInsts.add( - emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); + SLANG_UNEXPECTED("param is neither differential nor primal"); } } @@ -896,6 +967,266 @@ struct DiffTransposePass } + struct InvInstPair + { + IRInst* inst; + IRInst* invInst; + + InvInstPair(IRInst* inst, IRInst* invInst) : + inst(inst), invInst(invInst) + { } + + InvInstPair() : inst(nullptr), invInst(nullptr) + { } + }; + + List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput) + { + switch (primalInst->getOp()) + { + case kIROp_Add: + { + SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); + return List<InvInstPair>( + 
InvInstPair( + primalInst->getOperand(0), + builder->emitSub( + primalInst->getOperand(0)->getDataType(), + invOutput, + primalInst->getOperand(1)))); + } + case kIROp_Sub: + { + SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1))); + return List<InvInstPair>( + InvInstPair( + primalInst->getOperand(0), + builder->emitAdd( + primalInst->getOperand(0)->getDataType(), + invOutput, + primalInst->getOperand(1)))); + } + + default: + SLANG_UNEXPECTED("Unhandled arithmetic inst for inversion"); + } + } + + void lowerLoopExitValues(IRBuilder* builder, IRLoop* fwdLoop) + { + for (auto decoration : fwdLoop->getDecorations()) + { + if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration)) + { + IRBlock* revLoopInitBlock = revBlockMap[fwdLoop->getBreakBlock()]; + + if (auto revLoopInst = revLoopInitBlock->getTerminator()) + builder->setInsertBefore(revLoopInst); + else + builder->setInsertInto(revLoopInitBlock); + + hoistPrimalInst(builder, loopExitValueDecoration->getLoopExitValInst()); + + setInverse(builder, loopExitValueDecoration->getTargetInst(), loopExitValueDecoration->getLoopExitValInst()); + } + } + } + + void lowerLoopExitValues(IRBuilder* builder, IRBlock* block) + { + if (auto loopInst = as<IRLoop>(block->getTerminator())) + lowerLoopExitValues(builder, loopInst); + } + + // Go through loop block phi-args, and look for loop counter + // arguments, which for a loop means inserting a check into + // loop condition block. + // This method also adds logic to skip the first iteration. 
+ // (a 'do-while' loop) + // + void invertLoopCondition(IRBuilder* builder, IRLoop* loopInst) + { + auto firstLoopBlock = loopInst->getTargetBlock(); + + IRBlock* revLoopCondBlock = revBlockMap[firstLoopBlock]; + builder->setInsertBefore(revLoopCondBlock->getTerminator()); + + auto loopBaseCondition = as<IRIfElse>(revLoopCondBlock->getTerminator())->getCondition(); + + // Convert the loop from a 'for' into a 'do-while' by skipping the first check + + IRBlock* revLoopStartBlock = revBlockMap[as<IRBlock>(loopInst->getBreakBlock())]; + builder->setInsertBefore(revLoopStartBlock->getTerminator()); + + auto firstLoopCheckSkipVar = builder->emitVar(builder->getBoolType()); + builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(true)); + + builder->setInsertBefore(revLoopCondBlock->getTerminator()); + auto firstLoopCheckSkipVal = builder->emitLoad(firstLoopCheckSkipVar); + + builder->emitStore(firstLoopCheckSkipVar, builder->getBoolValue(false)); + + loopBaseCondition = builder->emitIntrinsicInst( + builder->getBoolType(), + kIROp_Or, + 2, + List<IRInst*>(firstLoopCheckSkipVal, loopBaseCondition).getBuffer()); + + // Add a terminating condition based on the loop counter's initial primal value + + IRParam* loopCounterParam = nullptr; + UIndex loopCounterParamIndex = 0; + for (auto param : firstLoopBlock->getParams()) + { + if (param->findDecoration<IRLoopCounterDecoration>()) + { + // There really should not be two (or more) loop counter params. + SLANG_RELEASE_ASSERT(loopCounterParam == nullptr); + loopCounterParam = param; + } + else + { + loopCounterParamIndex++; + } + } + + // Should see at least one loop counter parameter on the first loop block. 
+ SLANG_RELEASE_ASSERT(loopCounterParam); + + IRInst* loopCounterInitVal = loopInst->getArg(loopCounterParamIndex); + + auto paramBoundsCheck = builder->emitIntrinsicInst( + builder->getBoolType(), + kIROp_Neq, + 2, + List<IRInst*>( + hoistPrimalInst(builder, loopCounterParam), + hoistPrimalInst(builder, loopCounterInitVal)).getBuffer()); + + loopBaseCondition = builder->emitIntrinsicInst( + builder->getBoolType(), + kIROp_And, + 2, + List<IRInst*>(paramBoundsCheck, loopBaseCondition).getBuffer()); + + + as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(loopBaseCondition); + } + + List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput) + { + switch (primalInst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + return invertArithmetic(builder, primalInst, invOutput); + + default: + SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion"); + } + } + + bool hasInverse(IRInst* primalInst) + { + if (getOrCreateInverseVar(primalInst)) + return true; + else + return false; + } + + IRInst* getInverse(IRBuilder* builder, IRInst* primalInst) + { + // Note: There are other possible cases here, although not important + // right now. For example, a value is available to load from the primal block. + // + if (auto invVar = getOrCreateInverseVar(primalInst)) + return builder->emitLoad(invVar); + + return nullptr; + } + + void setInverse(IRBuilder* builder, IRInst* inst, IRInst* invInst) + { + if (auto invVar = getOrCreateInverseVar(inst)) + builder->emitStore(invVar, invInst); + } + + IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst) + { + SLANG_RELEASE_ASSERT(isPrimalInst(inst)); + + // Are the operands of this primal inst also available in the reverse-mode context? + // If not, move/load them. 
+ // + hoistPrimalOperands(revBuilder, inst); + + if (isPrimalInst(inst) && + as<IRBlock>(inst->getParent()) && + isDifferentialInst(as<IRBlock>(inst->getParent()))) + { + if (!inst->findDecoration<IRPrimalValueAccessDecoration>()) + { + return getInverse(revBuilder, inst); + } + else + { + auto block = as<IRBlock>(inst->getParent()); + SLANG_RELEASE_ASSERT(block); + + if (block == revBuilder->getBlock()) + { + // Already in block.. + return inst; + } + + // Otherwise, move our inst to the current builder location. + inst->removeFromParent(); + revBuilder->addInst(inst); + + return inst; + } + } + + return inst; + } + + void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst) + { + for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++) + { + // For now we'll only hoist primal operands that are + // generated in differential blocks. + // Eventually, we also want this method to move primal access + // insts to the reverse-mode blocks (i.e. this method will + // make sure all required primal insts are moved to the right + // place) + // + if (isPrimalInst(fwdInst->getOperand(ii))) + { + auto hoistedPrimalInst = hoistPrimalInst(revBuilder, fwdInst->getOperand(ii)); + fwdInst->setOperand(ii, hoistedPrimalInst); + } + } + } + + void invertInst(IRBuilder* builder, IRInst* primalInst) + { + // Look for an available inverse entry for this primalInst's *output* + if (hasInverse(primalInst)) + { + auto invOutput = getInverse(builder, primalInst); + + auto invEntries = invertInst(builder, primalInst, invOutput); + + for (auto entry : invEntries) + setInverse(builder, entry.inst, entry.invInst); + } + else + { + SLANG_UNEXPECTED("Could not find value for the output of inst. 
Unable to invert"); + } + } + void transposeInst(IRBuilder* builder, IRInst* inst) { switch (inst->getOp()) @@ -930,7 +1261,7 @@ struct DiffTransposePass if (auto pairType = as<IRDifferentialPairType>(loadInst->getDataType())) { primalType = pairType->getValueType(); - } + } } } @@ -948,6 +1279,11 @@ struct DiffTransposePass SLANG_ASSERT(gradients.getCount() == 0); } + // Ensure primal operands are replaced with insts accessible in the + // reverse-mode context. + // + hoistPrimalOperands(builder, inst); + // Is this inst used in another differential block? // Emit a function-scope accumulator variable, and include it's value. // Also, we ignore this if it's a load since those are turned into stores @@ -2457,6 +2793,8 @@ struct DiffTransposePass Dictionary<IRInst*, IRVar*> revAccumulatorVarMap; + Dictionary<IRInst*, IRVar*> inverseVarMap; + List<IRInst*> usedPtrs; Dictionary<IRBlock*, IRBlock*> revBlockMap; @@ -2468,9 +2806,8 @@ struct DiffTransposePass List<PendingBlockTerminatorEntry> pendingBlocks; Dictionary<IRBlock*, List<IRInst*>> phiGradsMap; - - Dictionary<IRBlock*, IRBlock*> initializerBlockMap; + Dictionary<IRInst*, IRInst*> inverseValueMap; }; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 68326fd54..c3af52d8a 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -10,6 +10,7 @@ #include "slang-ir-autodiff-propagate.h" #include "slang-ir-autodiff-transcriber-base.h" #include "slang-ir-validate.h" +#include "slang-ir-ssa.h" namespace Slang { @@ -55,8 +56,10 @@ struct DiffUnzipPass // After lowering, store references to the count // variables associated with this region // - IRVar* primalCountVar = nullptr; - IRVar* diffCountVar = nullptr; + IRInst* primalCountParam = nullptr; + IRInst* diffCountParam = nullptr; + + IRVar* primalCountLastVar = nullptr; enum CountStatus { @@ -76,8 +79,8 @@ struct DiffUnzipPass firstBlock(nullptr), breakBlock(nullptr), 
continueBlock(nullptr), - primalCountVar(nullptr), - diffCountVar(nullptr), + primalCountParam(nullptr), + diffCountParam(nullptr), status(CountStatus::Unresolved), maxIters(-1) { } @@ -93,8 +96,8 @@ struct DiffUnzipPass firstBlock(firstBlock), breakBlock(breakBlock), continueBlock(continueBlock), - primalCountVar(nullptr), - diffCountVar(nullptr), + primalCountParam(nullptr), + diffCountParam(nullptr), status(CountStatus::Unresolved), maxIters(-1) { } @@ -254,20 +257,15 @@ struct DiffUnzipPass // for (auto block : mixedBlocks) { - auto primalBlock = primalMap[block]; - if (isBlockIndexed(block)) - { processIndexedFwdBlock(block); - } } - + // Swap the first block's occurences out for the first primal block. firstBlock->replaceUsesWith(firstPrimalBlock); cleanupIndexRegionInfo(); - // Remove old blocks. for (auto block : mixedBlocks) block->removeAndDeallocate(); } @@ -319,17 +317,78 @@ struct DiffUnzipPass } } - // Make a primal value *available* to the differential block. - // This can get quite involved, and we're going to rely on - // constructSSA to do most of the heavy-lifting & optimization - // For now, we'll simply create a variable in the top-most - // primal block, then load it in the last primal block - // - //void hoistValue(IRInst* primalInst) - //{ - // IRBlock* terminalPrimalBlock = getTerminalPrimalBlock(); - // IRBlock* firstPrimalBlock = getFirstPrimalBlock(); - //} + UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) + { + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); + + auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()); + List<IRInst*> phiArgs; + + for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) + phiArgs.add(branchInst->getArg(ii)); + + phiArgs.add(arg); + + builder->setInsertInto(block); + switch (branchInst->getOp()) + { + case kIROp_unconditionalBranch: + builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer()); + break; + + case 
kIROp_loop: + builder->emitLoop( + as<IRLoop>(branchInst)->getTargetBlock(), + as<IRLoop>(branchInst)->getBreakBlock(), + as<IRLoop>(branchInst)->getContinueBlock(), + phiArgs.getCount(), + phiArgs.getBuffer()); + break; + + default: + break; + } + + branchInst->removeAndDeallocate(); + return phiArgs.getCount() - 1; + } + + IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type) + { + builder->setInsertInto(block); + return builder->emitParam(type); + } + + IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index) + { + List<IRParam*> params; + for (auto param : block->getParams()) + params.add(param); + + SLANG_RELEASE_ASSERT(index == (UCount)params.getCount()); + + return addPhiInputParam(builder, block, type); + } + + IRBlock* getBlock(IRInst* inst) + { + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst)) + return block; + + return getBlock(inst->getParent()); + } + + IRInst* getInstInBlock(IRInst* inst) + { + SLANG_RELEASE_ASSERT(inst); + + if (auto block = as<IRBlock>(inst->getParent())) + return inst; + + return getInstInBlock(inst->getParent()); + } void lowerIndexedRegions() { @@ -337,114 +396,131 @@ struct DiffUnzipPass for (auto region : indexRegions) { - - //IRBlock* initializerBlock = getInitializerBlock(region); - IRBlock* breakBlock = region->breakBlock; - // Grab first primal block. IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]); - - // Make variable in the top-most block (so it's visible to diff blocks) builder.setInsertBefore(firstPrimalBlock->getTerminator()); - region->primalCountVar = builder.emitVar(builder.getIntType()); - builder.emitStore( - region->primalCountVar, - builder.getIntValue(builder.getIntType(), 0)); - - // NOTE: This is a hacky shortcut we're taking here. 
- // Technically the unzip pass should not affect the - // correctness (it must still compute the proper fwd-mode derivative) - // However, we're currently making the loop counter go backwards to - // make it easier on the transposition pass, so the output from - // the unzip pass is neither fwd-mode or rev-mode until the transposition - // step is complete. - // - // TODO: Ideally this needs to be replaced with a small inversion step - // within the transposition pass. - // - // Emit the diff counter into the diff *break* block ( - // which we're praying turns into the reverse initializer block) - // initialized to the final value of the primal counter. - // - builder.setInsertBefore(as<IRBlock>(diffMap[breakBlock])->getTerminator()); - //auto primalCounterValue = builder.emitLoad(region->primalCountVar); - auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); - auto primalCounterLastValue = builder.emitSub( - primalCounterCurrValue->getDataType(), - primalCounterCurrValue, - builder.getIntValue(builder.getIntType(), 1)); - - region->diffCountVar = builder.emitVar(builder.getIntType()); - auto diffCountInit = builder.emitStore(region->diffCountVar, primalCounterLastValue); - - builder.addLoopCounterDecoration(diffCountInit); - builder.addLoopCounterDecoration(region->diffCountVar); - builder.addLoopCounterDecoration(primalCounterCurrValue); - builder.addLoopCounterDecoration(primalCounterLastValue); - - IRBlock* updateBlock = getUpdateBlock(region); + + // Make variable in the top-most block (so it's visible to diff blocks) + region->primalCountLastVar = builder.emitVar(builder.getIntType()); { - // TODO: Figure out if the counter update needs to go before or after - // the rest of the update block. 
- // - builder.setInsertBefore(as<IRBlock>(primalMap[updateBlock])->getTerminator()); + IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); + + auto primalCondBlock = as<IRUnconditionalBranch>( + primalInitBlock->getTerminator())->getTargetBlock(); + builder.setInsertBefore(primalCondBlock->getTerminator()); + + auto phiCounterArgLoopEntryIndex = addPhiOutputArg( + &builder, + primalInitBlock, + builder.getIntValue(builder.getIntType(), 0)); + + region->primalCountParam = addPhiInputParam( + &builder, + primalCondBlock, + builder.getIntType(), + phiCounterArgLoopEntryIndex); + builder.addLoopCounterDecoration(region->primalCountParam); + builder.markInstAsPrimal(region->primalCountParam); + + IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[getUpdateBlock(region)]); + builder.setInsertBefore(primalUpdateBlock->getTerminator()); - auto counterVal = builder.emitLoad(region->primalCountVar); auto incCounterVal = builder.emitAdd( builder.getIntType(), - counterVal, + region->primalCountParam, builder.getIntValue(builder.getIntType(), 1)); + builder.markInstAsPrimal(incCounterVal); + + auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, incCounterVal); - auto incStore = builder.emitStore(region->primalCountVar, incCounterVal); + SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); - builder.addLoopCounterDecoration(counterVal); - builder.addLoopCounterDecoration(incCounterVal); - builder.addLoopCounterDecoration(incStore); + IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->breakBlock]); + builder.setInsertBefore(primalBreakBlock->getTerminator()); + + builder.emitStore(region->primalCountLastVar, region->primalCountParam); } { - IRBlock* firstLoopBlock = getFirstLoopBodyBlock(region); - auto diffFirstLoopBlock = as<IRBlock>(diffMap[firstLoopBlock]); + IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->initBlock]); + + auto diffCondBlock = as<IRUnconditionalBranch>( + 
diffInitBlock->getTerminator())->getTargetBlock(); + builder.setInsertBefore(diffCondBlock->getTerminator()); - builder.setInsertBefore(diffFirstLoopBlock->getTerminator()); + auto phiCounterArgLoopEntryIndex = addPhiOutputArg( + &builder, + diffInitBlock, + builder.getIntValue(builder.getIntType(), 0)); + + region->diffCountParam = addPhiInputParam( + &builder, + diffCondBlock, + builder.getIntType(), + phiCounterArgLoopEntryIndex); + builder.addLoopCounterDecoration(region->diffCountParam); + builder.markInstAsPrimal(region->diffCountParam); + + IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[getUpdateBlock(region)]); + builder.setInsertBefore(diffUpdateBlock->getTerminator()); - auto counterVal = builder.emitLoad(region->diffCountVar); - auto decCounterVal = builder.emitSub( + auto incCounterVal = builder.emitAdd( builder.getIntType(), - counterVal, + region->diffCountParam, builder.getIntValue(builder.getIntType(), 1)); + builder.markInstAsPrimal(incCounterVal); - auto decStore = builder.emitStore(region->diffCountVar, decCounterVal); + auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, incCounterVal); - // Mark insts as loop counter insts to avoid removing them. 
- // - builder.addLoopCounterDecoration(counterVal); - builder.addLoopCounterDecoration(decCounterVal); - builder.addLoopCounterDecoration(decStore); + SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex); - // TODO: - // This is another hack here to avoid the counter from going negative - // (since they are not valid indices) - // - IRBlock* diffCondBlock = as<IRBlock>(diffMap[region->firstBlock]); + auto loopInst = as<IRLoop>(diffInitBlock->getTerminator()); - builder.setInsertBefore(diffCondBlock->getTerminator()); - IRInst* diffCounterVal = builder.emitLoad(region->diffCountVar); - IRInst* diffCounterCmp = builder.emitIntrinsicInst( - builder.getBoolType(), - kIROp_Geq, - 2, - List<IRInst*>( - diffCounterVal, - builder.getIntValue(builder.getIntType(), 0)).getBuffer()); - - as<IRIfElse>(diffCondBlock->getTerminator())->condition.set(diffCounterCmp); + builder.setInsertBefore(loopInst); + + auto primalCounterLastVal = builder.emitLoad(region->primalCountLastVar); + builder.markInstAsPrimal(primalCounterLastVal); + builder.addPrimalValueAccessDecoration(primalCounterLastVal); - builder.addLoopCounterDecoration(diffCounterVal); - builder.addLoopCounterDecoration(diffCounterCmp); + builder.addLoopExitPrimalValueDecoration(loopInst, region->diffCountParam, primalCounterLastVal); } + } + } + void tagNewParams(IRBuilder* builder, IRFunc* func) + { + for (auto block : func->getBlocks()) + { + for (auto param = block->getFirstParam(); param; param = param->getNextParam()) + if (!param->findDecoration<IRAutodiffInstDecoration>()) + builder->markInstAsPrimal(param); + } + } + + void setInsertBeforeOrdinaryInst(IRBuilder* builder, IRInst* inst) + { + if (as<IRParam>(inst)) + { + SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); + builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst()); + } + else + { + builder->setInsertBefore(inst); + } + } + + void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) + { + 
if (as<IRParam>(inst)) + { + SLANG_RELEASE_ASSERT(as<IRBlock>(inst->getParent())); + builder->setInsertBefore(as<IRBlock>(inst->getParent())->getFirstOrdinaryInst()); + } + else + { + builder->setInsertAfter(inst); } } @@ -520,10 +596,15 @@ struct DiffUnzipPass auto storageVar = builder.emitVar(arrayType); + // TODO(sai) STOPPED HERE: For some reason, we still have a direct param access + // when trying to cover up the access to last value of loop counter. + // Maybe we need a different way to access this? (use a var) + // Special case? + // 3. Store current value into the array and replace uses with a load. // TODO: If an index is missing, use the 'last' value of the primal index. { - builder.setInsertAfter(inst); + setInsertAfterOrdinaryInst(&builder, inst); IRInst* storeAddr = storageVar; IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); @@ -535,12 +616,12 @@ struct DiffUnzipPass storeAddr = builder.emitElementAddress( builder.getPtrType(currType), storeAddr, - builder.emitLoad(region->primalCountVar)); + region->primalCountParam); } builder.emitStore(storeAddr, inst); } - + // 4. Replace uses in differential blocks with loads from the array. 
List<IRInst*> instsToTag; { @@ -548,17 +629,20 @@ struct DiffUnzipPass for (auto use = inst->firstUse; use; use = use->nextUse) { if (as<IRDecoration>(use->getUser())) - continue; + { + if (!as<IRLoopExitPrimalValueDecoration>(use->getUser())) + continue; + } - IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + IRBlock* useBlock = getBlock(use->getUser()); if (useBlock && isDifferentialInst(useBlock)) diffUses.add(use); } for (auto use : diffUses) { - IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); - builder.setInsertBefore(use->getUser()); + IRBlock* useBlock = getBlock(use->getUser()); + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); IRInst* loadAddr = storageVar; IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType(); @@ -583,8 +667,7 @@ struct DiffUnzipPass // If the use-block is under the same region, use the // differential counter variable // - auto diffCounterCurrValue = builder.emitLoad(region->diffCountVar); - instsToTag.add(diffCounterCurrValue); + auto diffCounterCurrValue = region->diffCountParam; loadAddr = builder.emitElementAddress( builder.getPtrType(currType), @@ -596,7 +679,7 @@ struct DiffUnzipPass // If the use-block is outside this region, use the // last available value (by indexing with primal counter minus 1) // - auto primalCounterCurrValue = builder.emitLoad(region->primalCountVar); + auto primalCounterCurrValue = builder.emitLoad(region->primalCountLastVar); auto primalCounterLastValue = builder.emitSub( primalCounterCurrValue->getDataType(), primalCounterCurrValue, @@ -621,11 +704,11 @@ struct DiffUnzipPass } } - // TODO: Loop-counter is not really the right decoration.. - // replace with primal-inst when it's ready. 
- // for (auto instToTag : instsToTag) - builder.addLoopCounterDecoration(instToTag); + { + builder.addPrimalValueAccessDecoration(instToTag); + builder.markInstAsPrimal(instToTag); + } } } @@ -1306,11 +1389,6 @@ struct DiffUnzipPass // Nothing should be left in the original block. SLANG_ASSERT(block->getFirstChild() == block->getTerminator()); - - // Branch from primal to differential block. - // Functionally, the new blocks should produce the same output as the - // old block. - // primalBuilder.emitBranch(diffBlock); } }; diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 1232cf50d..97cdb644e 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -463,6 +463,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_PrimalInstDecoration: case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: + case kIROp_PrimalValueAccessDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDerivativeIntermediateTypeDecoration: case kIROp_BackwardDerivativePropagateDecoration: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 35877d680..f2107aa62 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -598,6 +598,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LayoutDecoration, layout, 1, 0) INST(LoopControlDecoration, loopControl, 1, 0) INST(LoopMaxItersDecoration, loopMaxIters, 1, 0) + INST(LoopExitPrimalValueDecoration, loopExitPrimalValue, 2, 0) INST(IntrinsicOpDecoration, intrinsicOp, 1, 0) /* TargetSpecificDecoration */ INST(TargetDecoration, target, 1, 0) @@ -769,6 +770,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) + INST(PrimalValueAccessDecoration, primalValueAccessDecoration, 0, 0) /* 
Auto-diff inst decorations */ /// Used by the auto-diff pass to mark insts that compute diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 0eef9cb43..fe20f17f5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -683,6 +683,18 @@ struct IRLoopCounterDecoration : IRDecoration IR_LEAF_ISA(LoopCounterDecoration) }; +struct IRLoopExitPrimalValueDecoration : IRDecoration +{ + enum + { + kOp = kIROp_LoopExitPrimalValueDecoration + }; + IR_LEAF_ISA(LoopExitPrimalValueDecoration) + + IRInst* getTargetInst() { return getOperand(0); } + IRInst* getLoopExitValInst() { return getOperand(1); } +}; + struct IRAutodiffInstDecoration : IRDecoration { IR_PARENT_ISA(AutodiffInstDecoration) @@ -712,7 +724,6 @@ struct IRPrimalInstDecoration : IRAutodiffInstDecoration IR_LEAF_ISA(PrimalInstDecoration) }; - struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration { enum @@ -726,6 +737,16 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration IRType* getPairType() { return as<IRType>(getOperand(0)); } }; +struct IRPrimalValueAccessDecoration : IRAutodiffInstDecoration +{ + enum + { + kOp = kIROp_PrimalValueAccessDecoration + }; + + IR_LEAF_ISA(PrimalValueAccessDecoration) +}; + struct IRPrimalValueStructKeyDecoration : IRDecoration { enum @@ -3613,6 +3634,16 @@ public: addDecoration(value, kIROp_LoopCounterDecoration); } + void addLoopExitPrimalValueDecoration(IRInst* value, IRInst* primalInst, IRInst* exitValue) + { + addDecoration(value, kIROp_LoopExitPrimalValueDecoration, primalInst, exitValue); + } + + void addPrimalValueAccessDecoration(IRInst* value) + { + addDecoration(value, kIROp_PrimalValueAccessDecoration); + } + void markInstAsPrimal(IRInst* value) { addDecoration(value, kIROp_PrimalInstDecoration); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index d8246edae..20a8d7d13 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -1056,7 
+1056,10 @@ bool constructSSA(ConstructSSAContext* context) // Figure out what variables we can promote to // SSA temporaries. - identifyPromotableVars(context); + if (!(context->promotableVars.getCount() > 0)) + { + identifyPromotableVars(context); + } // If none of the variables are promote-able, // then we can exit without making any changes |
