summary | refs | log | tree | commit | diff | stats
path: root/source
diff options
context:
space:
mode:
author: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> 2023-03-15 22:26:58 -0400
committer: GitHub <noreply@github.com> 2023-03-15 19:26:58 -0700
commit: 71efd949fa5276e2464416fcf237f8fd2c486281 (patch)
tree: a5b24cd077f2ecc3f74d4dd4671c8260eb6e9b67 /source
parent: 38e62199cc75ce34608491c8dd299eb330bde518 (diff)
AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp674
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h264
-rw-r--r--source/slang/slang-ir-autodiff-region.cpp56
-rw-r--r--source/slang/slang-ir-autodiff-region.h119
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp110
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h724
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp34
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h563
-rw-r--r--source/slang/slang-ir-autodiff.cpp98
-rw-r--r--source/slang/slang-ir-autodiff.h13
-rw-r--r--source/slang/slang-ir-insts.h3
-rw-r--r--source/slang/slang-ir-ssa.cpp45
13 files changed, 1951 insertions, 753 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 7057a5835..247c3ddde 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -308,6 +308,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
// differentiated, don't differentiate the inst
//
auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType());
+ // TODO: Need to update this to generate derivatives on a per-key basis
if (auto diffConstructType = differentiateType(builder, primalConstructType))
{
UCount operandCount = origConstruct->getOperandCount();
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
new file mode 100644
index 000000000..793a8ff07
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -0,0 +1,674 @@
+#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang
+{
+
+// Returns true when `operand` appears (compared by pointer) among the operands of `inst`.
+bool containsOperand(IRInst* inst, IRInst* operand)
+{
+    UIndex operandIndex = 0;
+    while (operandIndex < inst->getOperandCount())
+    {
+        if (inst->getOperand(operandIndex) == operand)
+        {
+            return true;
+        }
+        operandIndex++;
+    }
+
+    return false;
+}
+
+// Drives the checkpointing policy over `func`.
+//
+// Seeds a work list with the primal operand uses of every inst found in the differential
+// blocks, asks `classify()` for a store / recompute / invert decision on each use, and
+// follows the operands of recomputed insts transitively. The resulting sets are handed to
+// `applyCheckpointSet`, together with the uses (held by differential insts) that must be
+// redirected to cloned values.
+//
+// func      - function whose primal and differential blocks are both present.
+// splitInfo - primal-to-differential block map, forwarded to `applyCheckpointSet`.
+// Returns the hoisting info produced by `applyCheckpointSet`.
+//
+RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* splitInfo)
+{
+    RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
+
+    RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+    List<IRUse*> workList;
+    HashSet<IRUse*> processedUses;
+
+    HashSet<IRUse*> usesToReplace;
+
+    // Queues every operand use of `inst` that refers to a primal value defined inside a
+    // non-differential block (functions, blocks and module-level values are skipped).
+    auto addPrimalOperandsToWorkList = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration<IRDifferentialInstDecoration>() &&
+                !as<IRFunc>(operand->get()) &&
+                !as<IRBlock>(operand->get()) &&
+                !(as<IRModuleInst>(operand->get()->getParent())) &&
+                !getBlock(operand->get())->findDecoration<IRDifferentialInstDecoration>())
+                workList.add(operand);
+        }
+
+        // Is the type itself computed within our function?
+        // If so, we'll need to consider that too (this is for existential types, specialize insts, etc)
+        // TODO: We might not really need to query the checkpointing algorithm for these
+        // since they _have_ to be classified as 'recompute'
+        //
+        if (inst->getDataType() && (getParentFunc(inst->getDataType()) == func))
+        {
+            if (!getBlock(inst->getDataType())->findDecoration<IRDifferentialInstDecoration>())
+                workList.add(&inst->typeUse);
+        }
+    };
+
+    // Populate recompute/store/invert sets with insts, by applying the policy
+    // to them.
+    //
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (!block->findDecoration<IRDifferentialInstDecoration>())
+            continue;
+
+        for (auto child : block->getChildren())
+        {
+            // Special case: Ignore the primals used to construct the return pair.
+            // `firstUse` is checked first so that a pair without any use is treated
+            // like an ordinary inst instead of dereferencing a null use.
+            if (as<IRMakeDifferentialPair>(child) &&
+                child->firstUse &&
+                as<IRReturn>(child->firstUse->getUser()))
+            {
+                // quick check
+                SLANG_RELEASE_ASSERT(child->firstUse->nextUse == nullptr);
+                continue;
+            }
+
+            addPrimalOperandsToWorkList(child);
+
+            // We'll be conservative with the decorations we consider as differential uses
+            // of a primal inst, in order to avoid weird behaviour with some decorations
+            //
+            for (auto decoration : child->getDecorations())
+            {
+                if (auto primalCtxDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration))
+                    workList.add(&primalCtxDecoration->primalContextVar);
+                else if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
+                    workList.add(&loopExitDecoration->exitVal);
+            }
+        }
+
+        addPrimalOperandsToWorkList(block->getTerminator());
+    }
+
+    while (workList.getCount() > 0)
+    {
+        auto use = workList.getLast();
+        workList.removeLast();
+
+        // Each use is classified at most once.
+        if (processedUses.Contains(use))
+            continue;
+
+        processedUses.Add(use);
+
+        HoistResult result = this->classify(use);
+
+        if (result.mode == HoistResult::Mode::Store)
+        {
+            SLANG_ASSERT(!checkpointInfo->recomputeSet.Contains(result.instToStore));
+            checkpointInfo->storeSet.Add(result.instToStore);
+        }
+        else if (result.mode == HoistResult::Mode::Recompute)
+        {
+            SLANG_ASSERT(!checkpointInfo->storeSet.Contains(result.instToRecompute));
+            checkpointInfo->recomputeSet.Add(result.instToRecompute);
+
+            if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+                usesToReplace.Add(use);
+
+            if (auto param = as<IRParam>(result.instToRecompute))
+            {
+                // Add in the branch-args of every predecessor block.
+                auto paramBlock = as<IRBlock>(param->getParent());
+                UIndex paramIndex = 0;
+                for (auto _param : paramBlock->getParams())
+                {
+                    if (_param == param) break;
+                    paramIndex ++;
+                }
+
+                for (auto predecessor : paramBlock->getPredecessors())
+                {
+                    // If we hit this, the checkpoint policy is trying to recompute
+                    // values across a loop region boundary (we don't currently support this,
+                    // and in general this is quite inefficient in both compute & memory)
+                    //
+                    SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor));
+
+                    auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator());
+                    // The branch is dereferenced below, so fail loudly if the cast did not succeed.
+                    SLANG_RELEASE_ASSERT(branchInst);
+                    SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
+
+                    workList.add(&branchInst->getOperands()[paramIndex]);
+                }
+            }
+            else
+            {
+                if (auto var = as<IRVar>(result.instToRecompute))
+                {
+                    // For a recomputed variable, follow the value stored into it.
+                    // Only enqueue the use when one was actually found: a null entry
+                    // would be dereferenced by `classify` on the next iteration.
+                    IRUse* storeUse = findUniqueStoredVal(var);
+                    if (storeUse)
+                        workList.add(storeUse);
+                }
+                else
+                {
+                    addPrimalOperandsToWorkList(result.instToRecompute);
+                }
+            }
+        }
+        else if (result.mode == HoistResult::Mode::Invert)
+        {
+            auto instToInvert = result.inversionInfo.instToInvert;
+
+            SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser()));
+            SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser()));
+
+            if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+                usesToReplace.Add(use);
+
+            checkpointInfo->invertSet.Add(instToInvert);
+
+            if (checkpointInfo->invInfoMap.ContainsKey(instToInvert))
+            {
+                // The same inst was already scheduled for inversion: the operands it
+                // requires must agree with what was recorded earlier.
+                List<IRInst*> currOperands = checkpointInfo->invInfoMap[instToInvert].GetValue().requiredOperands;
+
+                // Check the lengths first so the element-wise comparison stays in bounds.
+                SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands.getCount() == currOperands.getCount());
+                for (Index ii = 0; ii < result.inversionInfo.requiredOperands.getCount(); ii++)
+                {
+                    SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands[ii] == currOperands[ii]);
+                }
+            }
+            else
+                checkpointInfo->invInfoMap[instToInvert] = result.inversionInfo;
+        }
+    }
+
+    return applyCheckpointSet(checkpointInfo, func, splitInfo, usesToReplace);
+}
+
+// Applies the checkpoint decision recorded in `checkpointInfo` for a single primal `inst`:
+//  - an inst in the store set is recorded in `hoistInfo->storeSet` and nothing else happens;
+//  - an inst in the recompute set is cloned through `cloneCtx` (block parameters outside
+//    the first block are not supported yet);
+//  - an inst in the invert set gets its inversion info registered against the clone of
+//    the inst to invert.
+void applyToInst(
+    IRBuilder* builder,
+    CheckpointSetInfo* checkpointInfo,
+    HoistedPrimalsInfo* hoistInfo,
+    IROutOfOrderCloneContext* cloneCtx,
+    IRInst* inst)
+{
+    // Stored insts need no cloning.
+    if (checkpointInfo->storeSet.Contains(inst))
+    {
+        hoistInfo->storeSet.Add(inst);
+        return;
+    }
+
+    if (checkpointInfo->recomputeSet.Contains(inst))
+    {
+        if (!as<IRParam>(inst))
+        {
+            hoistInfo->recomputeSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+        }
+        else if (getBlock(inst) != getBlock(inst)->getParent()->getFirstBlock())
+        {
+            // Parameters of the first block can be ignored completely; any other
+            // block parameter lands here.
+            //
+            // TODO: We would need to clone in the control-flow for each region (without nested loops)
+            // prior to this, and then hoist this parameter into the within-region block, otherwise
+            // this parameter will not be visible to transposed insts.
+            // This will also include adding an extra case to 'ensurePrimalAvailability': if both insts
+            // are within the _same_ indexed region, skip the indexed store/load and use a simple var.
+            //
+            SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported");
+        }
+    }
+
+    if (!checkpointInfo->invertSet.Contains(inst))
+        return;
+
+    InversionInfo info = checkpointInfo->invInfoMap[inst];
+    auto clonedInstToInvert = cloneCtx->cloneInstOutOfOrder(builder, info.instToInvert);
+
+    // Translate the operands required by the inverse through the clone map;
+    // operands that were never cloned are kept as they are.
+    List<IRInst*> remappedOperands;
+    for (auto requiredOperand : info.requiredOperands)
+    {
+        if (!cloneCtx->cloneEnv.mapOldValToNew.ContainsKey(requiredOperand))
+        {
+            remappedOperands.add(requiredOperand);
+            continue;
+        }
+        remappedOperands.add(cloneCtx->cloneEnv.mapOldValToNew[requiredOperand]);
+    }
+
+    info.requiredOperands = remappedOperands;
+
+    hoistInfo->invertInfoMap[clonedInstToInvert] = info;
+    hoistInfo->instsToInvert.Add(clonedInstToInvert);
+    hoistInfo->invertSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+}
+
+// Materializes the decisions in `checkpointInfo` on `func`.
+//
+// Walks every primal (non-differential) block except the first one, looks up its
+// differential counterpart in `splitInfo->diffBlockMap`, and calls `applyToInst` on each
+// block parameter and each child inst with the builder positioned inside the
+// differential block.
+//
+// checkpointInfo - store / recompute / invert sets computed by the policy.
+// func           - function being processed.
+// splitInfo      - primal-to-differential block map.
+// pendingUses    - uses that should be redirected to a clone once the used inst is cloned.
+// Returns the sets of stored / recomputed / inverted insts after cloning.
+//
+RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
+    CheckpointSetInfo* checkpointInfo,
+    IRGlobalValueWithCode* func,
+    BlockSplitInfo* splitInfo,
+    HashSet<IRUse*> pendingUses)
+{
+    RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
+
+    RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
+
+    for (auto use : pendingUses)
+        cloneCtx->pendingUses.Add(use);
+
+    // Populate the clone context with all the primal uses that we may need to replace with
+    // cloned versions. That way any insts we clone into the diff block will automatically replace
+    // their uses.
+    //
+    // NOTE(review): this helper is currently not called anywhere in this function.
+    auto addPrimalUsesToCloneContext = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration<IRDifferentialInstDecoration>())
+                cloneCtx->pendingUses.Add(operand);
+        }
+    };
+
+    // Go back over the insts and move/clone them accordingly.
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (block->findDecoration<IRDifferentialInstDecoration>())
+            continue;
+
+        auto diffBlock = as<IRBlock>(splitInfo->diffBlockMap[block]);
+
+        // Captured before anything is inserted, so child clones are placed ahead of
+        // the insts the differential block originally contained.
+        auto firstDiffInst = diffBlock->getFirstOrdinaryInst();
+
+        IRBuilder builder(func->getModule());
+
+        UIndex ii = 0;
+        for (auto param : block->getParams())
+        {
+            builder.setInsertBefore(diffBlock->getFirstOrdinaryInst());
+
+            // Apply checkpoint rule to the parameter itself.
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param);
+
+            const bool isParamRecomputed = checkpointInfo->recomputeSet.Contains(param);
+            const bool isParamInverted = checkpointInfo->invertSet.Contains(param);
+
+            // Copy primal branch-arg for predecessor blocks.
+            if (isParamRecomputed || isParamInverted)
+            {
+                HashSet<IRBlock*> predecessorSet;
+                for (auto predecessor : block->getPredecessors())
+                {
+                    // A predecessor may be listed more than once; handle it a single time.
+                    if (predecessorSet.Contains(predecessor))
+                        continue;
+
+                    predecessorSet.Add(predecessor);
+
+                    // The phi argument belongs on the branch of the *predecessor's*
+                    // differential block, so the lookup is keyed on `predecessor`.
+                    auto diffPredecessor = as<IRBlock>(splitInfo->diffBlockMap[predecessor]);
+                    auto primalBranchArg =
+                        as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii);
+
+                    if (isParamRecomputed)
+                        addPhiOutputArg(&builder, diffPredecessor, primalBranchArg);
+
+                    if (isParamInverted)
+                        addPhiOutputArg(&builder, diffPredecessor, primalBranchArg);
+                }
+            }
+
+            ii++;
+        }
+
+        for (auto child : block->getChildren())
+        {
+            builder.setInsertBefore(firstDiffInst);
+
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
+        }
+    }
+
+    return hoistInfo;
+}
+
+// Wraps `storageType` in one array type per entry of `defBlockIndices`.
+// Each array has `maxIters + 1` elements; every index is asserted to have a
+// statically known, non-negative iteration bound.
+IRType* getTypeForLocalStorage(
+    IRBuilder* builder,
+    IRType* storageType,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    IRType* wrappedType = storageType;
+
+    for (auto trackingInfo : defBlockIndices)
+    {
+        SLANG_ASSERT(trackingInfo->status == IndexTrackingInfo::CountStatus::Static);
+        SLANG_ASSERT(trackingInfo->maxIters >= 0);
+
+        auto elementCount = builder->getIntValue(
+            builder->getUIntType(),
+            trackingInfo->maxIters + 1);
+
+        wrappedType = builder->getArrayType(wrappedType, elementCount);
+    }
+
+    return wrappedType;
+}
+
+// Emits, before the first ordinary inst of `varBlock`, a variable able to hold one
+// `baseType` value per index combination in `defBlockIndices`, followed by a store
+// of a default-constructed value into it. `baseType` must not be a pointer type.
+IRVar* emitIndexedLocalVar(
+    IRBlock* varBlock,
+    IRType* baseType,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
+
+    IRBuilder localBuilder(varBlock->getModule());
+    localBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());
+
+    IRType* storageType = getTypeForLocalStorage(&localBuilder, baseType, defBlockIndices);
+
+    auto storageVar = localBuilder.emitVar(storageType);
+    auto initialValue = localBuilder.emitDefaultConstruct(storageType);
+    localBuilder.emitStore(storageVar, initialValue);
+
+    return storageVar;
+}
+
+// Builds the address inside `localVar` to store to, peeling one array level per
+// entry of `defBlockIndices` and indexing it with that entry's primal counter.
+IRInst* emitIndexedStoreAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    IRType* elementType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
+    IRInst* address = localVar;
+
+    for (auto trackingInfo : defBlockIndices)
+    {
+        // Step into the next array level.
+        elementType = as<IRArrayType>(elementType)->getElementType();
+
+        auto elementPtrType = builder->getPtrType(elementType);
+        address = builder->emitElementAddress(
+            elementPtrType,
+            address,
+            trackingInfo->primalCountParam);
+    }
+
+    return address;
+}
+
+
+// Builds the address inside `localVar` to load from, for a use described by
+// `useBlockIndices`. For every entry of `defBlockIndices`:
+//  - if the use also carries that index, the differential counter is used;
+//  - otherwise the element at (last primal counter value - 1) is addressed.
+IRInst* emitIndexedLoadAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices,
+    List<IndexTrackingInfo*> useBlockIndices)
+{
+    IRType* elementType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
+    IRInst* address = localVar;
+
+    for (auto trackingInfo : defBlockIndices)
+    {
+        elementType = as<IRArrayType>(elementType)->getElementType();
+
+        IRInst* elementIndex = nullptr;
+        if (useBlockIndices.contains(trackingInfo))
+        {
+            // The use-block is under the same region: index with the
+            // differential counter variable.
+            elementIndex = trackingInfo->diffCountParam;
+        }
+        else
+        {
+            // The use-block is outside this region: index with the last
+            // available value (primal counter minus 1).
+            auto lastCount = builder->emitLoad(trackingInfo->primalCountLastVar);
+            elementIndex = builder->emitSub(
+                lastCount->getDataType(),
+                lastCount,
+                builder->getIntValue(builder->getIntType(), 1));
+        }
+
+        address = builder->emitElementAddress(
+            builder->getPtrType(elementType),
+            address,
+            elementIndex);
+    }
+
+    return address;
+}
+
+// Creates an indexed local variable in `defaultVarBlock` for `instToStore` and emits,
+// through `builder`, a store of `instToStore` into the slot selected by `defBlockIndices`.
+// Returns the new variable.
+IRVar* storeIndexedValue(
+    IRBuilder* builder,
+    IRBlock* defaultVarBlock,
+    IRInst* instToStore,
+    List<IndexTrackingInfo*> defBlockIndices)
+{
+    IRVar* storageVar = emitIndexedLocalVar(
+        defaultVarBlock,
+        instToStore->getDataType(),
+        defBlockIndices);
+
+    IRInst* slotAddress = emitIndexedStoreAddressForVar(builder, storageVar, defBlockIndices);
+    builder->emitStore(slotAddress, instToStore);
+
+    return storageVar;
+}
+
+// Emits a load of the element of `localVar` addressed by
+// `emitIndexedLoadAddressForVar` for the given definition / use indices.
+IRInst* loadIndexedValue(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List<IndexTrackingInfo*> defBlockIndices,
+    List<IndexTrackingInfo*> useBlockIndices)
+{
+    IRInst* slotAddress = emitIndexedLoadAddressForVar(
+        builder,
+        localVar,
+        defBlockIndices,
+        useBlockIndices);
+
+    return builder->emitLoad(slotAddress);
+}
+
+// True when both lists have the same length and hold the same
+// IndexTrackingInfo pointers in the same order.
+bool areIndicesEqual(
+    List<IndexTrackingInfo*> indicesA,
+    List<IndexTrackingInfo*> indicesB)
+{
+    const auto countA = indicesA.getCount();
+    if (countA != indicesB.getCount())
+    {
+        return false;
+    }
+
+    Index position = 0;
+    while (position < countA)
+    {
+        if (indicesA[position] != indicesB[position])
+        {
+            return false;
+        }
+        position++;
+    }
+
+    return true;
+}
+
+// True when `indicesA` is a prefix of `indicesB`: it is no longer than `indicesB`
+// and matches it element by element over its whole length.
+bool areIndicesSubsetOf(
+    List<IndexTrackingInfo*> indicesA,
+    List<IndexTrackingInfo*> indicesB)
+{
+    const auto countA = indicesA.getCount();
+    if (countA > indicesB.getCount())
+    {
+        return false;
+    }
+
+    Index position = 0;
+    while (position < countA)
+    {
+        if (indicesA[position] != indicesB[position])
+        {
+            return false;
+        }
+        position++;
+    }
+
+    return true;
+}
+
+
+// True when `block` carries an IRDifferentialInstDecoration.
+bool isDifferentialBlock(IRBlock* block)
+{
+    auto diffDecoration = block->findDecoration<IRDifferentialInstDecoration>();
+    return diffDecoration != nullptr;
+}
+
+// Makes every inst in `hoistInfo->storeSet` reachable from its uses in differential blocks.
+//
+// For each stored inst, the uses located in differential blocks are inspected. A use is
+// treated as out of scope when the defining block does not dominate the user block, when
+// the defining block's indices are not a prefix of the user block's indices, when the
+// defining block is an indexed non-differential block, or when the inst is pointer-typed
+// and defined in a non-differential block. Out-of-scope uses are rewritten to go through
+// a local variable emitted in the second block of `func`.
+//
+// hoistInfo        - its storeSet is read and then replaced by the processed set.
+// func             - function being processed.
+// indexedBlockInfo - per-block list of index tracking infos.
+// Returns `hoistInfo`.
+//
+RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
+ HoistedPrimalsInfo* hoistInfo,
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo)
+{
+ RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+ IRBuilder builder(func->getModule());
+ // Local storage variables are emitted into the block that follows the first block.
+ IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
+
+ SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
+
+ HashSet<IRInst*> processedStoreSet;
+
+ // TODO: Also ensure availability of everything in the recompute set (for proper recompute support)
+ for (auto instToStore : hoistInfo->storeSet)
+ {
+ // For pointer-typed insts the defining block is the block of the inst that
+ // uses the unique stored value; otherwise it is the inst's own block.
+ IRBlock* defBlock = nullptr;
+ if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
+ {
+ auto varInst = as<IRVar>(instToStore);
+ auto storeUse = findUniqueStoredVal(varInst);
+
+ // NOTE(review): `storeUse` is dereferenced without a null check here.
+ defBlock = getBlock(storeUse->getUser());
+ }
+ else
+ defBlock = getBlock(instToStore);
+
+ SLANG_RELEASE_ASSERT(defBlock);
+
+ List<IRUse*> outOfScopeUses;
+ // `nextUse` is read up front on every iteration before the use is examined.
+ for (auto use = instToStore->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ // Only consider uses in differential blocks.
+ // This method is not responsible for other blocks.
+ //
+ IRBlock* userBlock = getBlock(use->getUser());
+ if (userBlock->findDecoration<IRDifferentialInstDecoration>())
+ {
+ if (!domTree->dominates(defBlock, userBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (indexedBlockInfo[defBlock].GetValue().getCount() > 0 &&
+ !isDifferentialBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
+ !isDifferentialBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ }
+
+ use = nextUse;
+ }
+
+ // Nothing to fix up: the inst is kept in the store set unchanged.
+ if (outOfScopeUses.getCount() == 0)
+ {
+ processedStoreSet.Add(instToStore);
+ continue;
+ }
+
+ if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
+ {
+
+ IRVar* varToStore = as<IRVar>(instToStore);
+ SLANG_RELEASE_ASSERT(varToStore);
+
+ auto storeUse = findUniqueStoredVal(varToStore);
+
+ List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
+
+ bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
+
+ // TODO: There's a slight hackiness here. (Ideally we might just want to emit
+ // additional vars when splitting a call)
+ //
+ // A derivative context var without an indexed store is simply moved into
+ // the default var block.
+ if (!isIndexedStore && isDerivativeContextVar(varToStore))
+ {
+ varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
+ processedStoreSet.Add(varToStore);
+ continue;
+ }
+
+ // NOTE(review): `storeUse` may be null at this point (see `isIndexedStore`
+ // above) but is dereferenced unconditionally; confirm callers guarantee it.
+ setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser()));
+
+ // Load the variable's value after the store and keep it in an indexed local var.
+ IRVar* localVar = storeIndexedValue(
+ &builder,
+ defaultVarBlock,
+ builder.emitLoad(varToStore),
+ defBlockIndices);
+
+ // Pointer-typed uses are replaced by the address of the matching slot.
+ for (auto use : outOfScopeUses)
+ {
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+ List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+
+ IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
+ builder.replaceOperand(use, loadAddr);
+ }
+
+ processedStoreSet.Add(localVar);
+ }
+ else
+ {
+ setInsertAfterOrdinaryInst(&builder, instToStore);
+
+ List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
+ auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
+
+ // Value-typed uses are replaced by a load from the matching slot.
+ for (auto use : outOfScopeUses)
+ {
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+ List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+ builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
+ }
+
+ processedStoreSet.Add(localVar);
+ }
+ }
+
+ // Replace the old store set with the processed one.
+ hoistInfo->storeSet = processedStoreSet;
+
+ return hoistInfo;
+}
+
+// The default policy is an (almost) always-store policy and needs no
+// per-function preparation, so the argument is ignored.
+void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode*)
+{
+}
+
+// Store everything that can be stored. By default, classify is only called on relevant
+// differential uses (or on uses in a 'recompute' inst).
+//
+// - Non-variable values are stored when `canTypeBeStored` accepts their type and
+//   recomputed otherwise.
+// - Variables whose value type is a specialization are recomputed as soon as one of the
+//   specialization arguments has a type that cannot be stored; all other variables are stored.
+HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+{
+    IRInst* usedValue = use->get();
+
+    auto var = as<IRVar>(usedValue);
+    if (!var)
+    {
+        if (canTypeBeStored(usedValue->getDataType()))
+            return HoistResult::store(usedValue);
+
+        return HoistResult::recompute(usedValue);
+    }
+
+    auto spec = as<IRSpecialize>(as<IRPtrTypeBase>(var->getDataType())->getValueType());
+    if (!spec)
+    {
+        // Variables of non-specialized value types are always stored.
+        return HoistResult::store(usedValue);
+    }
+
+    for (UInt argIndex = 0; argIndex < spec->getArgCount(); argIndex++)
+    {
+        if (!canTypeBeStored(spec->getArg(argIndex)->getDataType()))
+            return HoistResult::recompute(usedValue);
+    }
+
+    return HoistResult::store(usedValue);
+}
+
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
new file mode 100644
index 000000000..dc85942f6
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -0,0 +1,264 @@
+// slang-ir-autodiff-primal-hoist.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-dominators.h"
+
+
+namespace Slang
+{
+ // Clone helper that tolerates insts being cloned before their operands.
+ // It tracks a set of "pending" uses: uses that should be redirected to the clone of
+ // the inst they refer to, once that inst gets cloned.
+ struct IROutOfOrderCloneContext : public RefObject
+ {
+ // Old-to-new mapping shared by all clones made through this context.
+ IRCloneEnv cloneEnv;
+ // Uses waiting to be redirected to a clone.
+ HashSet<IRUse*> pendingUses;
+
+ // Clones `inst` with `builder` and returns the clone.
+ // Operands of the clone that still equal the original's operands are added to
+ // `pendingUses`; every pending use of `inst` is removed from the set and
+ // pointed at the clone.
+ IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst)
+ {
+ IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst);
+
+ UInt operandCount = clonedInst->getOperandCount();
+ for (UInt ii = 0; ii < operandCount; ++ii)
+ {
+ auto oldOperand = inst->getOperand(ii);
+ auto newOperand = clonedInst->getOperand(ii);
+
+ // The operand was not remapped by the clone: remember the clone's use of it.
+ if (oldOperand == newOperand)
+ pendingUses.Add(&clonedInst->getOperands()[ii]);
+ }
+
+ // `nextUse` is read before `replaceOperand` is called on the current use.
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ if (pendingUses.Contains(use))
+ {
+ pendingUses.Remove(use);
+ builder->replaceOperand(use, clonedInst);
+ }
+
+ use = nextUse;
+ }
+
+ return clonedInst;
+ }
+ };
+
+    // Describes one inversion: the inst to invert, the operands the inverse
+    // requires, and the insts the inversion targets.
+    struct InversionInfo
+    {
+        IRInst* instToInvert;
+        List<IRInst*> requiredOperands;
+        List<IRInst*> targetInsts;
+
+        InversionInfo(
+            IRInst* instToInvert,
+            List<IRInst*> requiredOperands,
+            List<IRInst*> targetInsts) :
+            instToInvert(instToInvert),
+            requiredOperands(requiredOperands),
+            targetInsts(targetInsts)
+        { }
+
+        InversionInfo() : instToInvert(nullptr)
+        { }
+
+        // Returns a copy translated through `env->mapOldValToNew`.
+        // Entries without a mapping are dropped from the lists, and `instToInvert`
+        // stays null when it has no mapping.
+        InversionInfo applyMap(IRCloneEnv* env)
+        {
+            InversionInfo mappedInfo;
+
+            if (env->mapOldValToNew.ContainsKey(instToInvert))
+            {
+                mappedInfo.instToInvert = env->mapOldValToNew[instToInvert];
+            }
+
+            for (auto operand : requiredOperands)
+            {
+                if (!env->mapOldValToNew.ContainsKey(operand))
+                    continue;
+                mappedInfo.requiredOperands.add(env->mapOldValToNew[operand]);
+            }
+
+            for (auto target : targetInsts)
+            {
+                if (!env->mapOldValToNew.ContainsKey(target))
+                    continue;
+                mappedInfo.targetInsts.add(env->mapOldValToNew[target]);
+            }
+
+            return mappedInfo;
+        }
+    };
+
+    // Result of applying a checkpoint set: the insts that ended up stored,
+    // recomputed or inverted, plus the inversion details keyed by inst.
+    struct HoistedPrimalsInfo : public RefObject
+    {
+        HashSet<IRInst*> storeSet;
+        HashSet<IRInst*> recomputeSet;
+        HashSet<IRInst*> invertSet;
+
+        HashSet<IRInst*> instsToInvert;
+
+        Dictionary<IRInst*, InversionInfo> invertInfoMap;
+
+        // Returns a new info object with every inst translated through
+        // `env->mapOldValToNew`; insts without a mapping are left out.
+        RefPtr<HoistedPrimalsInfo> applyMap(IRCloneEnv* env)
+        {
+            RefPtr<HoistedPrimalsInfo> mappedInfo = new HoistedPrimalsInfo();
+
+            // Copies the mapped counterpart of each inst in `source` into `target`.
+            auto remapSet = [&](HashSet<IRInst*>& source, HashSet<IRInst*>& target)
+            {
+                for (auto inst : source)
+                {
+                    if (env->mapOldValToNew.ContainsKey(inst))
+                        target.Add(env->mapOldValToNew[inst]);
+                }
+            };
+
+            remapSet(this->storeSet, mappedInfo->storeSet);
+            remapSet(this->recomputeSet, mappedInfo->recomputeSet);
+            remapSet(this->invertSet, mappedInfo->invertSet);
+            remapSet(this->instsToInvert, mappedInfo->instsToInvert);
+
+            for (auto kvpair : this->invertInfoMap)
+            {
+                if (!env->mapOldValToNew.ContainsKey(kvpair.Key))
+                    continue;
+                mappedInfo->invertInfoMap[env->mapOldValToNew[kvpair.Key]] = kvpair.Value.applyMap(env);
+            }
+
+            return mappedInfo;
+        }
+
+        // Adds every entry of `info` to this object; inversion info for a key
+        // already present is overwritten.
+        void merge(HoistedPrimalsInfo* info)
+        {
+            auto absorb = [](HashSet<IRInst*>& source, HashSet<IRInst*>& target)
+            {
+                for (auto inst : source)
+                    target.Add(inst);
+            };
+
+            absorb(info->storeSet, storeSet);
+            absorb(info->recomputeSet, recomputeSet);
+            absorb(info->invertSet, invertSet);
+            absorb(info->instsToInvert, instsToInvert);
+
+            for (auto kvpair : info->invertInfoMap)
+            {
+                invertInfoMap[kvpair.Key] = kvpair.Value;
+            }
+        }
+    };
+
+    // Decision returned by a checkpoint policy for a single use: store the used
+    // value, recompute it, or invert it.
+    struct HoistResult
+    {
+        enum Mode
+        {
+            Store,
+            Recompute,
+            Invert,
+
+            None
+        };
+
+        Mode mode;
+
+        // Exactly one of the following is meaningful, depending on `mode`.
+        IRInst* instToStore = nullptr;
+        IRInst* instToRecompute = nullptr;
+        InversionInfo inversionInfo;
+
+        // Builds a Store or Recompute result for `target`.
+        // Any other mode is reported through SLANG_UNEXPECTED.
+        HoistResult(Mode mode, IRInst* target) :
+            mode(mode)
+        {
+            if (mode == Mode::Store)
+            {
+                instToStore = target;
+            }
+            else if (mode == Mode::Recompute)
+            {
+                instToRecompute = target;
+            }
+            else if (mode == Mode::Invert)
+            {
+                SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert");
+            }
+            else
+            {
+                SLANG_UNEXPECTED("Unhandled hoist mode");
+            }
+        }
+
+        // Builds an Invert result carrying `info`.
+        HoistResult(InversionInfo info) :
+            mode(Mode::Invert), inversionInfo(info)
+        { }
+
+        static HoistResult store(IRInst* inst)
+        {
+            HoistResult result(Mode::Store, inst);
+            return result;
+        }
+
+        static HoistResult recompute(IRInst* inst)
+        {
+            HoistResult result(Mode::Recompute, inst);
+            return result;
+        }
+
+        static HoistResult invert(InversionInfo inst)
+        {
+            HoistResult result(inst);
+            return result;
+        }
+    };
+
+
+ // Information on which insts are to be stored, recomputed
+ // and inverted within a single function.
+ // This data structure also holds a map from each inst to invert
+ // to its InversionInfo, to provide more information to later passes.
+ //
+ struct CheckpointSetInfo : public RefObject
+ {
+ // Insts classified as 'store'.
+ HashSet<IRInst*> storeSet;
+ // Insts classified as 'recompute'.
+ HashSet<IRInst*> recomputeSet;
+ // Insts classified as 'invert'.
+ HashSet<IRInst*> invertSet;
+
+ // Inversion details, keyed by the inst to invert.
+ Dictionary<IRInst*, InversionInfo> invInfoMap;
+ };
+
+ // Records how blocks were split into primal and differential halves.
+ struct BlockSplitInfo : public RefObject
+ {
+ // Maps primal to differential blocks from the unzip step.
+ Dictionary<IRBlock*, IRBlock*> diffBlockMap;
+ };
+
+ // Base class for checkpointing policies. A policy decides, use by use, whether a
+ // primal value is stored, recomputed or inverted; `processFunc` drives that decision
+ // over a whole function.
+ class AutodiffCheckpointPolicyBase : public RefObject
+ {
+ public:
+
+ AutodiffCheckpointPolicyBase(IRModule* module) : module(module)
+ { }
+
+ // Classifies the primal values used by `func` via `classify` and returns the
+ // resulting hoisting info. `info` carries the primal-to-differential block map.
+ RefPtr<HoistedPrimalsInfo> processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* info);
+
+ // Do pre-processing on the function (mainly for
+ // 'global' checkpointing methods that consider the entire
+ // function)
+ //
+ virtual void preparePolicy(IRGlobalValueWithCode* func) = 0;
+
+ // Returns the store / recompute / invert decision for a single use.
+ virtual HoistResult classify(IRUse* diffBlockUse) = 0;
+
+ protected:
+
+ // Module this policy was created for.
+ IRModule* module;
+ };
+
+    // Default policy: stores every value whose type can be stored and recomputes the rest.
+    class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
+    {
+    public:
+
+        DefaultCheckpointPolicy(IRModule* module)
+            : AutodiffCheckpointPolicyBase(module)
+        { }
+
+        // `override` makes the compiler verify that these signatures keep matching
+        // the pure virtuals declared in AutodiffCheckpointPolicyBase.
+        virtual void preparePolicy(IRGlobalValueWithCode* func) override;
+        virtual HoistResult classify(IRUse* use) override;
+    };
+
+ // Applies the store / recompute / invert sets in `checkpointInfo` to `func`, using
+ // `splitInfo` to find the differential block of each primal block. `pendingUses` are
+ // the uses to redirect to cloned insts.
+ RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
+ CheckpointSetInfo* checkpointInfo,
+ IRGlobalValueWithCode* func,
+ BlockSplitInfo* splitInfo,
+ HashSet<IRUse*> pendingUses);
+
+ // Rewrites out-of-scope differential uses of the insts in `hoistInfo->storeSet` so
+ // that they read from local variables. `indexedBlockInfo` gives the index tracking
+ // infos of every block.
+ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
+ HoistedPrimalsInfo* hoistInfo,
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo);
+
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-region.cpp b/source/slang/slang-ir-autodiff-region.cpp
new file mode 100644
index 000000000..98b64f179
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.cpp
@@ -0,0 +1,56 @@
+// slang-ir-autodiff-region.cpp
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang{
+    // Assigns every block reachable from the first block of `func` to an indexed region.
+    // The first block is mapped to no region (nullptr). A block ending in a loop inst
+    // opens a new region for the loop's target block, while the false-side block of
+    // that target's if-else stays in the enclosing region. All other successors
+    // inherit the region of the block they were reached from.
+    RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func)
+    {
+        RefPtr<IndexedRegionMap> regionMap = new IndexedRegionMap;
+
+        List<IRBlock*> pendingBlocks;
+
+        regionMap->mapBlock(func->getFirstBlock(), nullptr);
+        pendingBlocks.add(func->getFirstBlock());
+
+        while (pendingBlocks.getCount() > 0)
+        {
+            auto block = pendingBlocks.getLast();
+            pendingBlocks.removeLast();
+
+            auto terminator = block->getTerminator();
+            auto enclosingRegion = regionMap->getRegion(block);
+
+            if (terminator->getOp() == kIROp_loop)
+            {
+                auto loopRegion = regionMap->newRegion(as<IRLoop>(terminator), enclosingRegion);
+                auto condBlock = as<IRLoop>(terminator)->getTargetBlock();
+
+                regionMap->mapBlock(condBlock, loopRegion);
+                pendingBlocks.add(condBlock);
+
+                auto ifElse = as<IRIfElse>(condBlock->getTerminator());
+                SLANG_RELEASE_ASSERT(ifElse);
+
+                // TODO: this is one of the places we'll need to change if we support loops that
+                // loop on either the true or false side. For now, we assume the loop is on the
+                // true side only.
+                //
+                auto exitBlock = ifElse->getFalseBlock();
+                regionMap->mapBlock(exitBlock, enclosingRegion);
+                pendingBlocks.add(ifElse->getFalseBlock());
+            }
+
+            for (auto successor : block->getSuccessors())
+            {
+                // Only blocks without a mapping yet are assigned and queued.
+                if (!regionMap->hasMapping(successor))
+                {
+                    regionMap->mapBlock(successor, enclosingRegion);
+                    pendingBlocks.add(successor);
+                }
+            }
+        }
+
+        return regionMap;
+    }
+};
diff --git a/source/slang/slang-ir-autodiff-region.h b/source/slang/slang-ir-autodiff-region.h
new file mode 100644
index 000000000..a4618e257
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.h
@@ -0,0 +1,119 @@
+// slang-ir-autodiff-region.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+// One loop region: the loop inst that opens it and the region that encloses it.
+struct IndexedRegion : public RefObject
+{
+    IRLoop* loop;
+    IndexedRegion* parent;
+
+    IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
+    { }
+
+    // Block that contains the loop inst.
+    IRBlock* getInitializerBlock()
+    {
+        return as<IRBlock>(loop->getParent());
+    }
+
+    // Target block of the loop inst; asserted to end in an if-else.
+    IRBlock* getConditionBlock()
+    {
+        auto targetBlock = as<IRBlock>(loop->getTargetBlock());
+        SLANG_RELEASE_ASSERT(as<IRIfElse>(targetBlock->getTerminator()));
+        return targetBlock;
+    }
+
+    IRBlock* getBreakBlock()
+    {
+        return loop->getBreakBlock();
+    }
+
+    // Returns a predecessor of the condition block other than the initializer
+    // block (the last such predecessor in iteration order).
+    IRBlock* getUpdateBlock()
+    {
+        auto initBlock = getInitializerBlock();
+        auto condBlock = getConditionBlock();
+
+        IRBlock* lastLoopBlock = nullptr;
+        for (auto candidate : condBlock->getPredecessors())
+        {
+            if (candidate == initBlock)
+                continue;
+
+            lastLoopBlock = candidate;
+        }
+
+        // Should find at least one predecessor that is _not_ the
+        // init block (that contains the loop info). This
+        // predecessor would be the last block in the loop
+        // before looping back to the condition.
+        //
+        SLANG_RELEASE_ASSERT(lastLoopBlock);
+
+        return lastLoopBlock;
+    }
+};
+
+// Per-region bookkeeping for the iteration counters used to index stored values.
+struct IndexTrackingInfo : public RefObject
+{
+ // After lowering, store references to the count
+ // variables associated with this region
+ //
+ // Counter used when addressing stores.
+ IRInst* primalCountParam = nullptr;
+ // Counter used when a load comes from a block under the same index.
+ IRInst* diffCountParam = nullptr;
+
+ // Variable loaded (and decremented by one) when a load comes from outside the region.
+ IRVar* primalCountLastVar = nullptr;
+
+ // Whether the iteration count has been resolved, and how.
+ enum CountStatus
+ {
+ Unresolved,
+ Dynamic,
+ Static
+ };
+
+ CountStatus status = CountStatus::Unresolved;
+
+ // Inferred maximum number of iterations.
+ // -1 until a bound has been inferred.
+ Count maxIters = -1;
+};
+
+// Owns all indexed regions of a function and records which region each block belongs to.
+struct IndexedRegionMap : public RefObject
+{
+    Dictionary<IRBlock*, IndexedRegion*> map;
+    List<RefPtr<IndexedRegion>> regions;
+
+    // Creates a region for `loop` nested in `parent`; the region is kept in `regions`.
+    IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
+    {
+        auto createdRegion = new IndexedRegion(loop, parent);
+        regions.add(createdRegion);
+
+        return createdRegion;
+    }
+
+    // Records that `block` belongs to `region`.
+    void mapBlock(IRBlock* block, IndexedRegion* region)
+    {
+        map.Add(block, region);
+    }
+
+    bool hasMapping(IRBlock* block)
+    {
+        return map.ContainsKey(block);
+    }
+
+    IndexedRegion* getRegion(IRBlock* block)
+    {
+        return map[block];
+    }
+
+    // Lists the region of `block` followed by each enclosing region, innermost first.
+    List<IndexedRegion*> getAllAncestorRegions(IRBlock* block)
+    {
+        List<IndexedRegion*> ancestors;
+
+        IndexedRegion* current = getRegion(block);
+        while (current)
+        {
+            ancestors.add(current);
+            current = current->parent;
+        }
+
+        return ancestors;
+    }
+};
+
+// Builds the block-to-region map for `func`.
+RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func);
+
+
+}; \ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 328af4867..157011b7c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -606,106 +606,6 @@ namespace Slang
return fwdDiffFunc;
}
- void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc)
- {
- RefPtr<IRDominatorTree> domTree = computeDominatorTree(diffPropFunc);
- auto firstBlock = diffPropFunc->getFirstBlock();
- if (!firstBlock)
- return;
- Dictionary<IRInst*, IRVar*> instVars;
- Dictionary<IRBlock*, IRCloneEnv> cloneEnvs;
- auto storeInstAsLocalVar = [&](IRInst* inst)
- {
- IRVar* var = nullptr;
- if (instVars.TryGetValue(inst, var))
- return var;
- IRBuilder builder(diffPropFunc);
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
- var = builder.emitVar(inst->getDataType());
- builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType()));
-
- setInsertAfterOrdinaryInst(&builder, inst);
- builder.emitStore(var, inst);
- instVars[inst] = var;
- return var;
- };
-
- IRBuilder builder(diffPropFunc);
- List<IRInst*> workList;
- for (auto block : diffPropFunc->getBlocks())
- {
- if (!block->findDecoration<IRDifferentialInstDecoration>())
- continue;
- cloneEnvs[block] = IRCloneEnv();
- for (auto inst : block->getChildren())
- {
- workList.add(inst);
- }
- }
-
- for (Index i = 0; i < workList.getCount(); i++)
- {
- auto inst = workList[i];
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto operand = inst->getOperand(j);
- if (operand->getOp() == kIROp_Block)
- continue;
- auto operandParent = inst->getOperand(j)->getParent();
- if (!operandParent)
- continue;
- if (operandParent->parent != diffPropFunc)
- continue;
- if (domTree->dominates(operandParent, inst->parent))
- continue;
-
- // The def site of the operand does not dominate the use.
- // We need to insert a local variable to store this var.
-
- IRInst* operandReplacement = nullptr;
- if (canTypeBeStored(operand->getDataType()))
- {
- auto var = storeInstAsLocalVar(operand);
- builder.setInsertBefore(inst);
- operandReplacement = builder.emitLoad(var);
- }
- else if (operand->getOp() == kIROp_Var)
- {
- // Var can just be hoisted to first block.
- operand->insertBefore(firstBlock->getFirstOrdinaryInst());
- }
- else
- {
- // For all other insts, we need to copy it to right before this inst.
- // Before actually copying it, check if we have already copied it to
- // any blocks that dominates this block.
- auto dom = as<IRBlock>(inst->getParent());
- while (dom)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(dom);
- if (!subCloneEnv) break;
- if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement))
- {
- break;
- }
- dom = domTree->getImmediateDominator(dom);
- }
- // We have not found an existing clone in dominators, so we need to copy it
- // to this block.
- if (!operandReplacement)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(as<IRBlock>(inst->getParent()));
- builder.setInsertBefore(inst);
- operandReplacement = cloneInst(subCloneEnv, &builder, operand);
- workList.add(operandReplacement);
- }
- }
- if (operandReplacement)
- builder.replaceOperand(inst->getOperands() + j, operandReplacement);
- }
- }
- }
-
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
SLANG_UNUSED(primalType);
@@ -774,7 +674,7 @@ namespace Slang
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
//
- diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
@@ -801,8 +701,8 @@ namespace Slang
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
// derivative of the return value.
- DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam };
- diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
+ DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo };
+ diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);
eliminateDeadCode(diffPropagateFunc);
@@ -810,7 +710,7 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType);
+ diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType);
// At this point the unzipped func is just an empty shell
// and we can simply remove it.
@@ -870,7 +770,7 @@ namespace Slang
initializeLocalVariables(builder->getModule(), as<IRGlobalValueWithCode>(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
- insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
+ // insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
stripTempDecorations(diffPropagateFunc);
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index a92978817..e59f27881 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -8,6 +8,8 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-cfg-norm.h"
+#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -78,6 +80,11 @@ struct DiffTransposePass
// of the *output* of the function.
//
IRInst* dOutInst;
+
+ // Information from the unzip pass on how primal insts
+ // are split across the primal and differential blocks.
+ //
+ HoistedPrimalsInfo* hoistedPrimalsInfo;
};
struct PendingBlockTerminatorEntry
@@ -227,15 +234,17 @@ struct DiffTransposePass
IRBlock* revAfterBlock = revBlockMap[currentBlock];
builder.setInsertInto(revCondBlock);
-
- hoistPrimalInst(&builder, ifElse->getCondition());
- builder.emitIfElse(
+ //hoistPrimalInst(&builder, ifElse->getCondition());
+
+ auto newIfElse = builder.emitIfElse(
ifElse->getCondition(),
revTrueEntryBlock,
revFalseEntryBlock,
revAfterBlock);
+ hoistPrimalOperands(&builder, newIfElse);
+
if (!revTrueRegionInfo.isTrivial)
{
builder.setInsertInto(revTrueExitBlock);
@@ -348,14 +357,18 @@ struct DiffTransposePass
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
- hoistPrimalInst(&builder, ifElse->getCondition());
- builder.emitIfElse(
+ // TODO: Need to defer this until after the CFG reversal is complete.
+ //hoistPrimalInst(&builder, ifElse->getCondition());
+
+ auto newIfElse = builder.emitIfElse(
ifElse->getCondition(),
revTrueBlock,
revFalseBlock,
revTrueBlock);
+ hoistPrimalOperands(&builder, newIfElse);
+
// Old false-side starting block becomes end block
// for the new pre-cond region (which could be empty)
//
@@ -364,12 +377,13 @@ struct DiffTransposePass
{
IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
builder.setInsertInto(revPreCondEndBlock);
- builder.emitLoop(
+ auto revLoop = builder.emitLoop(
revCondBlock,
revBreakBlock,
revLoopEndBlock,
getPhiGrads(falseBlock).getCount(),
getPhiGrads(falseBlock).getBuffer());
+ loop->transferDecorationsTo(revLoop);
auto revLoopStartBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopStartBlock);
@@ -383,12 +397,13 @@ struct DiffTransposePass
// Emit loop into rev-version of the break block.
auto revLoopBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopBlock);
- builder.emitLoop(
+ auto revLoop = builder.emitLoop(
revPreCondBlock,
revBreakBlock,
revLoopEndBlock,
getPhiGrads(breakBlock).getCount(),
getPhiGrads(breakBlock).getBuffer());
+ loop->transferDecorationsTo(revLoop);
}
currentBlock = breakBlock;
@@ -463,14 +478,16 @@ struct DiffTransposePass
builder.setInsertInto(revSwitchBlock);
- hoistPrimalInst(&builder, switchInst->getCondition());
+ // hoistPrimalInst(&builder, switchInst->getCondition());
- builder.emitSwitch(
+ auto newSwitchInst = builder.emitSwitch(
switchInst->getCondition(),
revBreakBlock,
revDefaultRegionEntry,
reverseSwitchArgs.getCount(),
reverseSwitchArgs.getBuffer());
+
+ hoistPrimalOperands(&builder, newSwitchInst);
currentBlock = breakBlock;
break;
@@ -504,6 +521,13 @@ struct DiffTransposePass
IRFunc* revDiffFunc,
FuncTranspositionInfo transposeInfo)
{
+ // TODO (sai): We really need to make this method stateless
+ // (i.e. not store per-func info in 'this')
+ // since it is reused for every reverse-mode call.
+ //
+
+ hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo;
+
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
@@ -586,6 +610,9 @@ struct DiffTransposePass
}
}
+ // Make a temporary block to hold inverted insts.
+ tempInvBlock = builder.createBlock();
+
for (auto block : workList)
{
// Set dOutParameter as the transpose gradient for the return inst, if any.
@@ -620,7 +647,6 @@ struct DiffTransposePass
{
if (auto loopInst = as<IRLoop>(block->getTerminator()))
{
- lowerLoopExitValues(&builder, loopInst);
invertLoopCondition(&builder, loopInst);
}
}
@@ -655,6 +681,50 @@ struct DiffTransposePass
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
+ // TODO: Should move this to before all the transposition, but a lot of the
+ // transposition logic seems to access the parent of blocks to find the func.
+ // Replace those uses.
+ //
+ for (auto block : workList)
+ block->removeFromParent();
+
+ // Mark all primal operands for hoisting.
+ // TODO: Can we just merge this with finishHoistingPrimalInsts?
+ // TODO: Some of this logic is replicated in finishHoistingPrimalInsts. Merge it with the
+ // maybeAddOperandsToWorkList logic there.
+ //
+ for (auto block : workList)
+ {
+ IRBlock* revBlock = revBlockMap[block];
+
+ for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst())
+ {
+ hoistPrimalOperands(&builder, child);
+
+ for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration())
+ {
+ if (auto contextDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration))
+ hoistPrimalUse(&builder, &contextDecoration->primalContextVar);
+
+ if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
+ hoistPrimalUse(&builder, &loopExitDecoration->exitVal);
+ }
+
+ if (auto instType = child->getDataType())
+ if (!as<IRModuleInst>(instType->getParent()))
+ hoistPrimalUse(&builder, &child->typeUse);
+ }
+ }
+
+ finishHoistingPrimals(revDiffFunc);
+
+ for (auto block : workList)
+ {
+ auto revBlock = as<IRBlock>(revBlockMap[block]);
+ if (auto revLoop = as<IRLoop>(revBlock->getTerminator()))
+ lowerLoopExitValues(&builder, revLoop);
+ }
+
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -723,8 +793,25 @@ struct DiffTransposePass
return tempRevVar;
}
+ // Returns the entry stored in `inverseVarMap` for `inst`. Unlike
+ // getOrCreateInverseVar(), this never creates a variable.
+ IRVar* lookupInverseVar(IRInst* inst)
+ {
+     IRVar* registeredVar = inverseVarMap[inst];
+     return registeredVar;
+ }
+
+ // Overload that takes the owning function: looks up the block that
+ // `firstRevDiffBlockMap` holds for `func` and forwards to the
+ // (inst, block) overload, which receives it as the var block.
+ IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func)
+ {
+     IRBlock* blockForVars = firstRevDiffBlockMap[func];
+     return getOrCreateInverseVar(primalInst, blockForVars);
+ }
+
IRVar* getOrCreateInverseVar(IRInst* primalInst) // convenience overload: derives the function from primalInst itself
{
+ IRBlock* varBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())]; // key is the IRFunc two parent levels above primalInst
+ return getOrCreateInverseVar(primalInst, varBlock); // forwards to the (inst, block) overload
+ }
+
+ IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock)
+ {
// No need to store inverse values for constants.
if (as<IRConstant>(primalInst))
return nullptr;
@@ -734,13 +821,11 @@ struct DiffTransposePass
return inverseVarMap[primalInst];
IRBuilder tempVarBuilder(autodiffContext->moduleInst);
-
- IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())];
- if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst())
+ if (auto firstInst = varBlock->getFirstOrdinaryInst())
tempVarBuilder.setInsertBefore(firstInst);
else
- tempVarBuilder.setInsertInto(firstDiffBlock);
+ tempVarBuilder.setInsertInto(varBlock);
auto primalType = primalInst->getDataType();
@@ -766,6 +851,19 @@ struct DiffTransposePass
return false;
}
+ // Returns the parameter of `block` at zero-based position `ii`.
+ // Running past the last parameter is reported through SLANG_UNEXPECTED.
+ IRParam* getParamAt(IRBlock* block, UIndex ii)
+ {
+     UIndex remaining = ii;
+     for (auto candidate : block->getParams())
+     {
+         if (remaining == 0)
+             return candidate;
+
+         remaining--;
+     }
+     SLANG_UNEXPECTED("ii >= paramCount");
+ }
+
void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
{
IRBuilder builder(autodiffContext->moduleInst);
@@ -773,6 +871,10 @@ struct DiffTransposePass
// Insert into our reverse block.
builder.setInsertInto(revBlock);
+ // Create an inverse builder to insert insts into the inv-block.
+ IRBuilder invBuilder(autodiffContext->moduleInst);
+
+
// Check if this block has any 'outputs' (in the form of phi args
// sent to the successor block)
//
@@ -798,25 +900,43 @@ struct DiffTransposePass
revParam,
nullptr));
}
- else if (isPrimalInst(arg))
+ else if (hasInverse(arg))
{
- // If the output arg is a primal, emit a parameter
- // to accept it as an _input_ for the reverse-mode
- //
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
+ InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst];
+ if (invInfo.targetInsts.contains(arg))
+ {
+ SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii)));
+
+ // If the output arg is a primal, emit a parameter
+ // to accept it as an _input_ for the reverse-mode
+ //
+ auto primalType = arg->getDataType();
+ auto primalInvParam = builder.emitParam(primalType);
- setInverse(&builder, arg, primalInvParam);
+ invBuilder.setInsertBefore(branchInst);
+ setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
+ }
}
else
{
- SLANG_UNEXPECTED("Encountered inst not marked as primal or differential");
+ if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii)))
+ {
+ auto primalType = arg->getDataType();
+ auto primalInvParam = builder.emitParam(primalType);
+
+ invBuilder.setInsertBefore(branchInst);
+ setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
+ }
}
}
}
// Move pointer & reference insts to the top of the reverse-mode block.
- List<IRInst*> nonValueInsts;
+ List<IRInst*> typeInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
// If the instruction is a variable allocation (or reverse-gradient pair reference),
@@ -824,17 +944,17 @@ struct DiffTransposePass
// TODO: This is hacky.. Need a more principled way to handle this
// (like primal inst hoisting)
//
- if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child))
- nonValueInsts.add(child);
+ //if (as<IRVar>(child) || as<IRReverseGradientDiffPairRef>(child))
+ // nonValueInsts.add(child);
// Slang doesn't support function values. So if we see a func-typed inst
// it's probably a reference to a function.
//
if (as<IRFuncType>(child->getDataType()))
- nonValueInsts.add(child);
+ typeInsts.add(child);
}
- for (auto inst : nonValueInsts)
+ for (auto inst : typeInsts)
{
inst->insertAtEnd(revBlock);
}
@@ -846,9 +966,6 @@ struct DiffTransposePass
//
for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
{
- if (child->findDecoration<IRPrimalValueAccessDecoration>())
- continue;
-
if (as<IRDecoration>(child) || as<IRParam>(child))
continue;
if (as<IRType>(child))
@@ -856,8 +973,15 @@ struct DiffTransposePass
if (isDifferentialInst(child))
transposeInst(&builder, child);
- else if (isPrimalInst(child))
- invertInst(&builder, child);
+ else if (shouldInstBeInverted(child))
+ {
+ // We'll collect inverse insts in an orphaned block,
+ // so disable IR validation temporarily.
+ //
+ disableIRValidationAtInsert();
+ invertInst(&invBuilder, child);
+ enableIRValidationAtInsert();
+ }
}
// After processing the block's instructions, we 'flush' any remaining gradients
@@ -901,23 +1025,18 @@ struct DiffTransposePass
phiParamRevGradInsts.add(gradInst);
}
else
- {
+ {
phiParamRevGradInsts.add(
emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
}
}
- else if (isPrimalInst(param))
+ else if (hasInverse(param))
{
- if (hasInverse(param))
- phiParamRevGradInsts.add(getInverse(&builder, param));
- else
- {
- SLANG_UNEXPECTED("param is a primal inst but has no registered inverse");
- }
+ phiParamRevGradInsts.add(param);
}
else
- {
- SLANG_UNEXPECTED("param is neither differential nor primal");
+ {
+ SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion");
}
}
@@ -995,8 +1114,15 @@ struct DiffTransposePass
{ }
};
- List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ List<InvInstPair> invertArithmetic(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
{
+ SLANG_RELEASE_ASSERT(invInfo.requiredOperands.getCount() == 1);
+ SLANG_RELEASE_ASSERT(invInfo.targetInsts.getCount() == 1);
+
+ auto invOutput = invInfo.requiredOperands[0];
+
+ auto invTargetInst = invInfo.targetInsts[0];
+
switch (primalInst->getOp())
{
case kIROp_Add:
@@ -1004,7 +1130,7 @@ struct DiffTransposePass
SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
return List<InvInstPair>(
InvInstPair(
- primalInst->getOperand(0),
+ invTargetInst,
builder->emitSub(
primalInst->getOperand(0)->getDataType(),
invOutput,
@@ -1015,7 +1141,7 @@ struct DiffTransposePass
SLANG_RELEASE_ASSERT(as<IRConstant>(primalInst->getOperand(1)));
return List<InvInstPair>(
InvInstPair(
- primalInst->getOperand(0),
+ invTargetInst,
builder->emitAdd(
primalInst->getOperand(0)->getDataType(),
invOutput,
@@ -1027,24 +1153,38 @@ struct DiffTransposePass
}
}
- void lowerLoopExitValues(IRBuilder* builder, IRLoop* fwdLoop)
+ // NOTE: This is a workaround for the fact that we expect inverses to use
+ // single-use variables. The loop exit value would add a
+ // second store to most inv-variables and interfere with the primal hoisting mechanism.
+ // Instead of emitting into the orphaned inverse block, we directly emit into
+ // the reverse-mode block (right before `revLoop`), since this runs _after_ the
+ // primal hoisting pass.
+ //
+ // This workaround is fine for inverting loop counters, but when we want to
+ // expand to supporting general-purpose adjoints, we would want to use per-region
+ // inverse vars based on 'invInfo' (enforcing single-use vars).
+ //
+ void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop)
{
- for (auto decoration : fwdLoop->getDecorations())
+ List<IRDecoration*> processedDecorations; // collected here, removed only after the loop below finishes
+ for (auto decoration : revLoop->getDecorations())
{
if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
{
- IRBlock* revLoopInitBlock = revBlockMap[fwdLoop->getBreakBlock()];
-
- if (auto revLoopInst = revLoopInitBlock->getTerminator())
- builder->setInsertBefore(revLoopInst);
- else
- builder->setInsertInto(revLoopInitBlock);
-
- hoistPrimalInst(builder, loopExitValueDecoration->getLoopExitValInst());
+ builder->setInsertBefore(revLoop); // the store is emitted immediately before the loop inst
+ setInverse(
+ builder,
+ nullptr, // no def-block is recorded for this store
+ builder->getFunc(),
+ loopExitValueDecoration->getTargetInst(),
+ loopExitValueDecoration->getLoopExitValInst());
- setInverse(builder, loopExitValueDecoration->getTargetInst(), loopExitValueDecoration->getLoopExitValInst());
+ processedDecorations.add(loopExitValueDecoration);
}
}
+
+ for (auto decoration : processedDecorations) // drop every exit-value decoration that was lowered above
+ decoration->removeAndDeallocate();
}
void lowerLoopExitValues(IRBuilder* builder, IRBlock* block)
@@ -1094,19 +1234,21 @@ struct DiffTransposePass
kIROp_Neq,
2,
List<IRInst*>(
- hoistPrimalInst(builder, loopCounterParam),
- hoistPrimalInst(builder, loopCounterInitVal)).getBuffer());
+ loopCounterParam,
+ loopCounterInitVal).getBuffer());
+
+ hoistPrimalOperands(builder, paramBoundsCheck);
as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
}
- List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
{
switch (primalInst->getOp())
{
case kIROp_Add:
case kIROp_Sub:
- return invertArithmetic(builder, primalInst, invOutput);
+ return invertArithmetic(builder, primalInst, invInfo);
default:
SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion");
@@ -1115,70 +1257,392 @@ struct DiffTransposePass
bool hasInverse(IRInst* primalInst) // membership test only; no inverse var is created here
{
- if (getOrCreateInverseVar(primalInst))
- return true;
- else
- return false;
+ return this->hoistedPrimalsInfo->invertSet.Contains(primalInst); // true when primalInst is in the invert-set of the hoisted-primals info
}
- IRInst* getInverse(IRBuilder* builder, IRInst* primalInst)
+ IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst)
{
// Emits a load from the inverse var of `primalInst` at the builder's
// position; returns nullptr when no inverse var is produced for it.
// Note: other cases are possible (e.g. a value available to load from the primal block) but are not handled.
- if (auto invVar = getOrCreateInverseVar(primalInst))
+
+ if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc())) // var is looked up (or created) for the builder's function
return builder->emitLoad(invVar);
return nullptr;
}
- void setInverse(IRBuilder* builder, IRInst* inst, IRInst* invInst)
+ IRInst* lookupInstInPrimalBlock(IRInst* invInst)
{
- if (auto invVar = getOrCreateInverseVar(inst))
- builder->emitStore(invVar, invInst);
+ // Intended to look up the inst in the primal block whose value can be used
+ // as an operand for the inverted inst. Currently it returns `invInst` unchanged.
+ //
+ // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst];
+ return invInst;
}
- IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst)
+ void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst) // stores invInst into the inverse var of inst
{
- if (as<IRBlock>(inst->getParent()) &&
- isDifferentialInst(as<IRBlock>(inst->getParent())))
+ auto instBlock = as<IRBlock>(inst->getParent());
+ if (!instBlock) // insts whose parent is not a block get no inverse store
+ return;
+
+ disableIRValidationAtInsert(); // validation is switched off around the store and restored below
+ if (auto invVar = getOrCreateInverseVar(inst, func))
{
- SLANG_RELEASE_ASSERT(isPrimalInst(inst));
+ auto invStore = builder->emitStore(invVar, invInst);
+ mapStoreToDefBlock[as<IRStore>(invStore)] = defBlock; // remember the caller-supplied def-block for this store (may be null)
}
+ enableIRValidationAtInsert();
+ }
- // Are the operands of this primal inst also available in the reverse-mode context?
- // If not, move/load them.
+ // True when `inst` is listed in the `instsToInvert` set of the
+ // hoisted-primals info.
+ bool shouldInstBeInverted(IRInst* inst)
+ {
+     const bool markedForInversion =
+         this->hoistedPrimalsInfo->instsToInvert.Contains(inst);
+     return markedForInversion;
+ }
+
+ // Queues `use` in `primalUsesToHoist` for later processing and returns
+ // the value it currently refers to. The builder argument is unused.
+ IRInst* hoistPrimalUse(IRBuilder*, IRUse* use)
+ {
+     IRInst* usedValue = use->get();
+     primalUsesToHoist.add(use);
+     return usedValue;
+ }
+
+ bool doesInstRequireHoisting(IRInst* inst) // true when inst still sits in a detached block and must be moved
+ {
+ if (as<IRModuleInst>(inst->getParent())) // direct children of the module are never hoisted
+ return false;
+
+ if (as<IRBlock>(inst) ||
+ as<IRGlobalValueWithCode>(inst) ||
+ as<IRConstant>(inst))
+ return false;
+
+ if (as<IRTerminatorInst>(inst))
+ return false;
+
+ if (as<IRDecoration>(inst)) // a decoration is judged by the inst that getInstInBlock() returns for it
+ return doesInstRequireHoisting(getInstInBlock(inst));
+
+ // We're looking for non-differential insts that sit in a differential
+ // block (or in tempInvBlock) which has no parent, i.e. insts
+ // that have not yet been moved to the 'active' blocks.
//
- hoistPrimalOperands(revBuilder, inst);
+ return (!isDifferentialInst(inst) &&
+ (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
+ getBlock(inst)->getParent() == nullptr);
+ }
- if (isPrimalInst(inst) &&
- as<IRBlock>(inst->getParent()) &&
- isDifferentialInst(as<IRBlock>(inst->getParent())))
+ // Builds a map from each required operand of an inversion to the list of target insts of the inversions that need it.
+ Dictionary<IRInst*, List<IRInst*>> buildInvOperandMap()
+ {
+ Dictionary<IRInst*, List<IRInst*>> invOperandMap;
+ for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap)
{
- if (!inst->findDecoration<IRPrimalValueAccessDecoration>())
+ InversionInfo invInfo = kvpair.Value; // copy of the entry's inversion info
+
+ for (auto operand : invInfo.requiredOperands)
{
- return getInverse(revBuilder, inst);
+ if (!invOperandMap.ContainsKey(operand)) // first time this operand is seen: start with an empty list
+ invOperandMap[operand] = List<IRInst*>();
+
+ for (auto target : invInfo.targetInsts) // the same operand may collect targets from several entries
+ invOperandMap[operand].GetValue().add(target);
}
- else
+ }
+
+ return invOperandMap;
+ }
+
+ // Follows the structured control flow forward from `block` and returns
+ // the block at which the walk stops:
+ //  - a return terminator stops the walk at the current block;
+ //  - an unconditional branch is followed unless its target has more
+ //    than one distinct predecessor, in which case the walk stops;
+ //  - if-else, switch and loop terminators jump to their after/break block;
+ //    an if-else whose block has a loop-terminated predecessor ends the
+ //    walk right after that jump.
+ // Any other terminator is reported as an internal error, because the walk
+ // has no rule for it and would otherwise never advance.
+ //
+ IRBlock* walkToEndOfRegion(IRBlock* block)
+ {
+ IRBlock* currBlock = block;
+
+ bool keepGoing = true;
+ while (keepGoing)
+ {
+ auto terminator = currBlock->getTerminator();
+ switch (terminator->getOp())
{
- auto block = as<IRBlock>(inst->getParent());
- SLANG_RELEASE_ASSERT(block);
+ case kIROp_Return:
+ keepGoing = false;
+ break;
- if (block == revBuilder->getBlock())
+ case kIROp_unconditionalBranch:
{
- // Already in block..
- return inst;
- }
+ auto nextBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock();
- // Otherwise, move our inst to the the current builder location.
- inst->removeFromParent();
- revBuilder->addInst(inst);
+ HashSet<IRBlock*> predecessorSet;
+ for (auto predecessor : nextBlock->getPredecessors())
+ predecessorSet.Add(predecessor);
- return inst;
+ if (predecessorSet.Count() > 1)
+ {
+ keepGoing = false;
+ break;
+ }
+
+ currBlock = nextBlock;
+ break;
+ }
+
+ case kIROp_ifElse:
+ {
+ for (auto predecessor : currBlock->getPredecessors())
+ {
+ if (as<IRLoop>(predecessor->getTerminator()))
+ {
+ keepGoing = false;
+ // This only leaves the predecessor scan; the jump to
+ // the after-block below is still taken.
+ break;
+ }
+ }
+
+ currBlock = as<IRIfElse>(terminator)->getAfterBlock();
+ break;
+ }
+
+ case kIROp_Switch:
+ currBlock = as<IRSwitch>(terminator)->getBreakLabel();
+ break;
+
+ case kIROp_loop:
+ currBlock = as<IRLoop>(terminator)->getBreakBlock();
+ break;
+
+ default:
+ // Without this case neither currBlock nor keepGoing would
+ // change and the loop would spin forever.
+ SLANG_UNEXPECTED("Unhandled terminator while walking to the end of a region");
}
}
- return inst;
+ return currBlock;
+ }
+
+ // Processes every primal use queued by hoistPrimalUse() with a work list:
+ //  - an inst is deferred while inverted insts that depend on it, or any
+ //    of its pending users, still wait to be processed;
+ //  - for an inst in the invert-set, each use by an already hoisted user
+ //    is replaced with a load of its inverse var, and the inverse var
+ //    itself is queued;
+ //  - an IRVar is moved to the top of the var block and the user of its
+ //    unique stored value is processed in its place;
+ //  - any other inst that still requires hoisting is moved to the top of
+ //    the reverse block mapped from walkToEndOfRegion(defBlock), and its
+ //    operands and type are queued.
+ //
+ void finishHoistingPrimals(IRGlobalValueWithCode* func)
+ {
+     List<IRInst*> workList;
+
+     // Insts that are done. Invertible insts are recorded with a null
+     // value, moved insts map to themselves.
+     Dictionary<IRInst*, IRInst*> hoistedInstMap;
+
+     // NOTE(review): computed but not consulted anywhere in this function.
+     RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+     Dictionary<IRInst*, List<IRInst*>> invOperandMap = buildInvOperandMap();
+
+     // Hoisted vars are inserted at the top of the function's second block.
+     auto varBlock = func->getFirstBlock()->getNextBlock();
+
+     // Load up pending insts into workList.
+     for (auto use : primalUsesToHoist)
+         workList.add(use->get());
+
+     primalUsesToHoist.clear();
+
+     // Queues the operands and the type of `inst` that still require
+     // hoisting and have not been handled yet.
+     auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst)
+     {
+         UIndex opIndex = 0;
+         for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+         {
+             if (doesInstRequireHoisting(operand->get()) &&
+                 !hoistedInstMap.ContainsKey(operand->get()))
+             {
+                 workList.add(operand->get());
+             }
+         }
+
+         if (auto instType = inst->getDataType())
+         {
+             if (doesInstRequireHoisting(instType) &&
+                 !hoistedInstMap.ContainsKey(instType))
+                 workList.add(instType);
+         }
+     };
+
+     // A use of `inst` is pending when its user still requires hoisting,
+     // has not been handled yet and is not in the invert-set. Stores into
+     // a var do not count as pending uses of that var.
+     //
+     auto isPendingUse = [&](IRInst* inst, IRUse* use)
+     {
+         auto user = use->getUser();
+         if (!doesInstRequireHoisting(user))
+             return false;
+
+         if (as<IRVar>(inst) && as<IRStore>(user))
+             return false;
+
+         return !hoistedInstMap.ContainsKey(user) && !hasInverse(user);
+     };
+
+     auto maybeAddUsersToWorkList = [&](IRInst* inst)
+     {
+         for (auto use = inst->firstUse; use; use = use->nextUse)
+         {
+             if (isPendingUse(inst, use))
+                 workList.add(use->getUser());
+         }
+     };
+
+     auto doesInstHavePendingUsers = [&](IRInst* inst)
+     {
+         for (auto use = inst->firstUse; use; use = use->nextUse)
+         {
+             if (isPendingUse(inst, use))
+                 return true;
+         }
+
+         return false;
+     };
+
+     // An inst counts as hoisted once its block is differential and has
+     // a parent.
+     auto isInstHoisted = [&](IRInst* inst)
+     {
+         return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst));
+     };
+
+     while (workList.getCount() > 0)
+     {
+         // Pop work item
+         auto inst = workList.getLast();
+         workList.removeLast();
+
+         // Already handled, nothing left to do for this inst
+         // (this should actually never be hit).
+         //
+         if (hoistedInstMap.ContainsKey(inst))
+             continue;
+
+         if (invOperandMap.ContainsKey(inst))
+         {
+             List<IRInst*> pendingInvDependencies;
+             for (auto dependency : invOperandMap[inst].GetValue())
+             {
+                 if (doesInstRequireHoisting(dependency) &&
+                     !hoistedInstMap.ContainsKey(dependency))
+                     pendingInvDependencies.add(dependency);
+             }
+
+             if (pendingInvDependencies.getCount() > 0)
+             {
+                 workList.add(inst);
+                 for (auto dependency : pendingInvDependencies)
+                     workList.add(dependency);
+
+                 // Skip until all the dependencies have been handled.
+                 continue;
+             }
+         }
+
+         // Are the uses of this primal inst already hoisted into the reverse-mode
+         // blocks? We cannot hoist this inst unless the uses are hoisted.
+         //
+         if (doesInstHavePendingUsers(inst))
+         {
+             // Add inst back to work list.
+             workList.add(inst);
+
+             // Then, add all the pending users on top of the list,
+             // ensuring they are processed before we see inst again.
+             //
+             maybeAddUsersToWorkList(inst);
+
+             continue;
+         }
+
+         // The used inst is marked for inversion, lookup and load
+         // an inverse.
+         //
+         if (this->hoistedPrimalsInfo->invertSet.Contains(inst))
+         {
+             // Replace with inverse.
+             IRBuilder builder(func->getModule());
+
+             // `use->set()` relinks the use, so the successor is fetched
+             // before the use is touched.
+             for (auto use = inst->firstUse; use;)
+             {
+                 auto nextUse = use->nextUse;
+
+                 if (!isInstHoisted(use->getUser()))
+                 {
+                     use = nextUse;
+                     continue;
+                 }
+
+                 // TODO: Hacky workaround to prevent the 'key' being overwritten,
+                 // avoid this by adding the decoration on the param instead of the loop
+                 //
+                 if (auto exitValDecoration = as<IRLoopExitPrimalValueDecoration>(use->getUser()))
+                 {
+                     if (&exitValDecoration->target == use)
+                     {
+                         use = nextUse;
+                         continue;
+                     }
+                 }
+
+                 builder.setInsertBefore(getInstInBlock(use->getUser()));
+                 use->set(loadInverse(&builder, inst));
+
+                 use = nextUse;
+             }
+
+             // If all uses of the invertible inst have been hoisted,
+             // add the inv-var to the worklist.
+             //
+             workList.add(lookupInverseVar(inst));
+             hoistedInstMap[inst] = nullptr;
+
+             continue;
+         }
+
+         // Should not see an inst marked for inversion here.
+         SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst));
+
+         if (auto varToHoist = as<IRVar>(inst))
+         {
+             varToHoist->insertBefore(varBlock->getFirstOrdinaryInst());
+
+             // Check the use before dereferencing it, so that a var without
+             // a unique stored value fails an assertion instead of crashing.
+             auto storedValUse = findUniqueStoredVal(varToHoist);
+             SLANG_RELEASE_ASSERT(storedValUse);
+
+             inst = storedValUse->getUser();
+             SLANG_ASSERT(inst);
+         }
+
+         IRBlock* defBlock = getBlock(inst);
+
+         if (!doesInstRequireHoisting(inst))
+             continue;
+
+         // Move this inst to after its diff uses.
+         //
+         {
+             IRBlock* currTopBlock = revBlockMap[walkToEndOfRegion(defBlock)];
+
+             SLANG_RELEASE_ASSERT(currTopBlock);
+
+             // More consistency checks
+             SLANG_RELEASE_ASSERT(currTopBlock->getFirstOrdinaryInst() != nullptr);
+             SLANG_RELEASE_ASSERT(currTopBlock->getParent() != nullptr);
+             SLANG_RELEASE_ASSERT(isDifferentialInst(currTopBlock));
+
+             // Insert at top. (disabling validation since the operands of
+             // this inst might not be hoisted to the right place yet)
+             //
+             disableIRValidationAtInsert();
+             inst->insertBefore(currTopBlock->getFirstOrdinaryInst());
+             enableIRValidationAtInsert();
+         }
+
+         // Finish up..
+         hoistedInstMap[inst] = inst;
+         maybeAddPrimalOperandsToWorkList(inst);
+     }
+ }
void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst)
@@ -1192,10 +1656,9 @@ struct DiffTransposePass
// make sure all requried primal insts are moved to the right
// place)
//
- if (isPrimalInst(fwdInst->getOperand(ii)))
+ if (doesInstRequireHoisting(fwdInst->getOperand(ii)))
{
- auto hoistedPrimalInst = hoistPrimalInst(revBuilder, fwdInst->getOperand(ii));
- fwdInst->setOperand(ii, hoistedPrimalInst);
+ hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]);
}
}
}
@@ -1203,14 +1666,28 @@ struct DiffTransposePass
void invertInst(IRBuilder* builder, IRInst* primalInst)
{
// Look for an available inverse entry for this primalInst's *output*
- if (hasInverse(primalInst))
+ if (shouldInstBeInverted(primalInst))
{
- auto invOutput = getInverse(builder, primalInst);
+ // This logic is already handled in transposeBlock() so we skip
+ // it here.
+ //
+ if (as<IRTerminatorInst>(primalInst))
+ return;
+
+ auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst];
- auto invEntries = invertInst(builder, primalInst, invOutput);
+ IRBuilder invBuilder(builder->getModule());
+ invBuilder.setInsertAfter(primalInst);
+ auto invEntries = invertInst(&invBuilder, primalInst, invInfo);
+
for (auto entry : invEntries)
- setInverse(builder, entry.inst, entry.invInst);
+ setInverse(
+ &invBuilder,
+ getBlock(primalInst),
+ as<IRGlobalValueWithCode>(entry.inst->getParent()->getParent()),
+ entry.inst,
+ entry.invInst);
}
else
{
@@ -1270,11 +1747,6 @@ struct DiffTransposePass
SLANG_ASSERT(gradients.getCount() == 0);
}
- // Ensure primal operands are replaced with insts accessible in the
- // reverse-mode context.
- //
- hoistPrimalOperands(builder, inst);
-
// Is this inst used in another differential block?
// Emit a function-scope accumulator variable, and include it's value.
// Also, we ignore this if it's a load since those are turned into stores
@@ -1381,7 +1853,9 @@ struct DiffTransposePass
// In order to perform the call, we need a temporary var to store the DiffPair.
auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
- auto primalVal = builder->emitLoad(hoistPrimalInst(builder, instPair->getPrimal()));
+ auto primalVal = builder->emitLoad(instPair->getPrimal());
+ hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here?
+
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
builder->emitStore(tempVar, pairVal);
@@ -1453,13 +1927,15 @@ struct DiffTransposePass
// If the callee provides a primal implementation that produces continuation context for propagation phase
// we grab it and pass it as argument to the propagation function.
+ //
if (auto primalContextDecor = fwdCall->findDecoration<IRBackwardDerivativePrimalContextDecoration>())
- {
- // Ensure availability of the primal context var
- auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar());
- SLANG_RELEASE_ASSERT(primalContextVar);
+ {
+ auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar();
+
+ auto contextLoad = builder->emitLoad(primalContextVar);
+ hoistPrimalOperands(builder, contextLoad);
- args.add(builder->emitLoad(primalContextVar));
+ args.add(contextLoad);
argTypes.add(as<IRPtrTypeBase>(
primalContextVar->getDataType())
->getValueType());
@@ -1735,6 +2211,7 @@ struct DiffTransposePass
return transposeUpdateElement(builder, fwdInst, revValue);
case kIROp_LoadReverseGradient:
+ case kIROp_ReverseGradientDiffPairRef:
case kIROp_DefaultConstruct:
case kIROp_Specialize:
case kIROp_unconditionalBranch:
@@ -2255,18 +2732,17 @@ struct DiffTransposePass
{
// current type should be a scalar.
SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType()));
-
- auto targetVectorType = as<IRVectorType>(targetType);
- List<IRInst*> operands;
- for (Index ii = 0; ii < as<IRIntLit>(targetVectorType->getElementCount())->getValue(); ii++)
- {
- operands.add(inst);
- }
+ return builder->emitMakeVectorFromScalar(targetType, inst);
+ }
- IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer());
+ case kIROp_MatrixType:
+ {
+ // current type should be a scalar.
+ SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType()) &&
+ !as<IRMatrixType>(currentType->getDataType()));
- return newInst;
+ return builder->emitMakeMatrixFromScalar(targetType, inst);
}
default:
@@ -2968,6 +3444,10 @@ struct DiffTransposePass
DifferentialPairTypeBuilder pairBuilder;
+ HoistedPrimalsInfo* hoistedPrimalsInfo;
+
+ IRBlock* tempInvBlock;
+
Dictionary<IRInst*, List<RevGradient>> gradientsMap;
Dictionary<IRInst*, IRVar*> revAccumulatorVarMap;
@@ -2987,6 +3467,10 @@ struct DiffTransposePass
Dictionary<IRBlock*, List<IRInst*>> phiGradsMap;
Dictionary<IRInst*, IRInst*> inverseValueMap;
+
+ List<IRUse*> primalUsesToHoist;
+
+ Dictionary<IRStore*, IRBlock*> mapStoreToDefBlock;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 5b59416d4..16862bb19 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -332,7 +332,12 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(
+ IRFunc* unzippedFunc,
+ IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
+ HashSet<IRInst*>& primalParams,
+ IRInst*& outIntermediateType)
{
IRBuilder builder(module);
@@ -375,17 +380,9 @@ struct ExtractPrimalFuncContext
// output intermediary struct.
for (auto inst : block->getChildren())
{
- if (shouldStoreInst(inst))
+ if (primalsInfo->storeSet.Contains(inst))
{
- if (as<IRParam>(inst))
- builder.setInsertBefore(block->getFirstOrdinaryInst());
- else
- builder.setInsertAfter(inst);
- storeInst(builder, inst, outIntermediary);
- }
- else if (inst->getOp() == kIROp_Var)
- {
- if (shouldStoreVar(as<IRVar>(inst)))
+ if (as<IRVar>(inst))
{
auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
builder.setInsertBefore(inst);
@@ -394,7 +391,14 @@ struct ExtractPrimalFuncContext
inst->replaceUsesWith(fieldAddr);
builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
}
-
+ else
+ {
+ if (as<IRParam>(inst))
+ builder.setInsertBefore(block->getFirstOrdinaryInst());
+ else
+ builder.setInsertAfter(inst);
+ storeInst(builder, inst, outIntermediary);
+ }
}
}
}
@@ -459,6 +463,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
IRFunc* DiffUnzipPass::extractPrimalFunc(
IRFunc* func,
IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
ParameterBlockTransposeInfo& paramInfo,
IRInst*& intermediateType)
{
@@ -470,6 +475,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.parent = &cloneEnv;
auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func));
+ auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv);
+
// Remove [KeepAlive] decorations in clonedFunc.
for (auto block : clonedFunc->getBlocks())
for (auto inst : block->getChildren())
@@ -494,7 +501,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
context.init(autodiffContext->moduleInst->getModule(), autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, clonedPrimalsInfo, newPrimalParams, intermediateType);
if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>())
{
@@ -580,6 +587,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
// The primal calls should be marked as no side effect so they can be DCE'd if possible.
// We can only do so if the intermediate context of the callee is stored.
+ //
if (primalCtx->getBackwardDerivativePrimalContextVar()
->findDecoration<IRPrimalValueStructKeyDecoration>())
{
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index f2aa1fd29..8b24b122e 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -9,6 +9,8 @@
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-propagate.h"
#include "slang-ir-autodiff-transcriber-base.h"
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-autodiff-primal-hoist.h"
#include "slang-ir-validate.h"
#include "slang-ir-ssa.h"
@@ -36,171 +38,10 @@ struct DiffUnzipPass
// might run into an issue here?
IRBlock* firstDiffBlock;
- struct IndexedRegion : public RefObject
- {
- IRLoop* loop;
- IndexedRegion* parent;
-
- IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
- { }
-
- IRBlock* getInitializerBlock() { return as<IRBlock>(loop->getParent()); }
- IRBlock* getConditionBlock()
- {
- auto condBlock = as<IRBlock>(loop->getTargetBlock());
- SLANG_RELEASE_ASSERT(as<IRIfElse>(condBlock->getTerminator()));
- return condBlock;
- }
-
- IRBlock* getBreakBlock() { return loop->getBreakBlock(); }
-
- IRBlock* getUpdateBlock()
- {
- auto initBlock = getInitializerBlock();
-
- auto condBlock = getConditionBlock();
-
- IRBlock* lastLoopBlock = nullptr;
-
- for (auto predecessor : condBlock->getPredecessors())
- {
- if (predecessor != initBlock)
- lastLoopBlock = predecessor;
- }
-
- // Should find atleast one predecessor that is _not_ the
- // init block (that contains the loop info). This
- // predecessor would be the last block in the loop
- // before looping back to the condition.
- //
- SLANG_RELEASE_ASSERT(lastLoopBlock);
-
- return lastLoopBlock;
- }
- };
-
-
- struct IndexedRegionMap : public RefObject
- {
- Dictionary<IRBlock*, IndexedRegion*> map;
- List<RefPtr<IndexedRegion>> regions;
-
- IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
- {
- auto region = new IndexedRegion(loop, parent);
- regions.add(region);
-
- return region;
- }
-
- void mapBlock(IRBlock* block, IndexedRegion* region)
- {
- map.Add(block, region);
- }
-
- bool hasMapping(IRBlock* block)
- {
- return map.ContainsKey(block);
- }
-
- IndexedRegion* getRegion(IRBlock* block)
- {
- return map[block];
- }
-
- List<IndexedRegion*> getAllAncestorRegions(IRBlock* block)
- {
- List<IndexedRegion*> regionList;
-
- IndexedRegion* region = getRegion(block);
- for (; region; region = region->parent)
- regionList.add(region);
-
- return regionList;
- }
- };
-
- RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func)
- {
- RefPtr<IndexedRegionMap> regionMap = new IndexedRegionMap;
-
- List<IRBlock*> workList;
-
- regionMap->mapBlock(func->getFirstBlock(), nullptr);
- workList.add(func->getFirstBlock());
-
- while (workList.getCount() > 0)
- {
- auto currentBlock = workList.getLast();
- workList.removeLast();
-
- auto terminator = currentBlock->getTerminator();
- auto currentRegion = regionMap->getRegion(currentBlock);
-
- switch (terminator->getOp())
- {
- case kIROp_loop:
- {
- auto loopRegion = regionMap->newRegion(as<IRLoop>(terminator), currentRegion);
- auto condBlock = as<IRLoop>(terminator)->getTargetBlock();
-
- regionMap->mapBlock(condBlock, loopRegion);
- workList.add(condBlock);
-
- auto ifElse = as<IRIfElse>(condBlock->getTerminator());
- SLANG_RELEASE_ASSERT(ifElse);
-
- // TODO: this is one of the places we'll need to change if we support loops that
- // loop on either the true or false side. For now, we assume the loop is on the
- // true side only.
- //
- regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion);
- workList.add(ifElse->getFalseBlock());
- }
- }
-
- for (auto successor : currentBlock->getSuccessors())
- {
- // If already mapped, skip.
- if (regionMap->hasMapping(successor))
- continue;
- regionMap->mapBlock(successor, currentRegion);
- workList.add(successor);
- }
- }
-
- return regionMap;
- }
-
-
RefPtr<IndexedRegionMap> indexRegionMap;
- struct IndexTrackingInfo : public RefObject
- {
- // After lowering, store references to the count
- // variables associated with this region
- //
- IRInst* primalCountParam = nullptr;
- IRInst* diffCountParam = nullptr;
-
- IRVar* primalCountLastVar = nullptr;
-
- enum CountStatus
- {
- Unresolved,
- Dynamic,
- Static
- };
-
- CountStatus status = CountStatus::Unresolved;
-
- // Inferred maximum number of iterations.
- Count maxIters = -1;
- };
-
Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap;
-
DiffUnzipPass(
AutoDiffSharedContext* autodiffContext)
: autodiffContext(autodiffContext)
@@ -217,7 +58,7 @@ struct DiffUnzipPass
return diffMap[inst];
}
- void unzipDiffInsts(IRFunc* func)
+ RefPtr<HoistedPrimalsInfo> unzipDiffInsts(IRFunc* func)
{
diffTypeContext.setFunc(func);
@@ -316,7 +157,12 @@ struct DiffUnzipPass
// Emit counter variables and other supporting
// instructions for all regions.
//
- lowerIndexedRegions();
+ // TODO: Need to have maxIndex in _both_ IndexTrackingInfo & IndexedRegionInfo.
+ // That way, we can do the various passes _before_ lowerIndexedRegions()
+ // TODO: Remove the call to lowerIndexedRegions() once checkpointing works properly.
+ //
+ RefPtr<HoistedPrimalsInfo> primalsInfo = new HoistedPrimalsInfo();
+ lowerIndexedRegions(primalsInfo);
// Copy regions from fwd-block to their split blocks
// to make it easier to do lookups.
@@ -338,21 +184,39 @@ struct DiffUnzipPass
indexRegionMap->map[as<IRBlock>(diffMap[block])] = (IndexedRegion*)indexRegionMap->map[block];
}
}
+
+ // Swap the first block's occurrences out for the first primal block.
+ firstBlock->replaceUsesWith(firstPrimalBlock);
- // Process intermediate insts in indexed blocks
- // into array loads/stores.
- //
+ RefPtr<BlockSplitInfo> splitInfo = new BlockSplitInfo();
+ for (auto block : mixedBlocks)
+ if (primalMap.ContainsKey(block))
+ splitInfo->diffBlockMap[as<IRBlock>(primalMap[block])] = as<IRBlock>(diffMap[block]);
+
+ Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlocksInfo;
for (auto block : mixedBlocks)
{
- if (indexRegionMap->getRegion(block) != nullptr)
- processIndexedFwdBlock(block);
+ indexedBlocksInfo[as<IRBlock>(diffMap[block])] = getIndexInfoList(as<IRBlock>(diffMap[block]));
+ indexedBlocksInfo[as<IRBlock>(primalMap[block])] = getIndexInfoList(as<IRBlock>(primalMap[block]));
}
-
- // Swap the first block's occurences out for the first primal block.
- firstBlock->replaceUsesWith(firstPrimalBlock);
for (auto block : mixedBlocks)
block->removeAndDeallocate();
+
+ // Run the three checkpointing passes to hoist/clone primal insts
+ // to the right spots.
+ //
+ {
+ RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(unzippedFunc->getModule());
+ chkPolicy->preparePolicy(func);
+
+ auto chkPrimalsInfo = chkPolicy->processFunc(func, splitInfo);
+ primalsInfo->merge(chkPrimalsInfo);
+
+ primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlocksInfo);
+ }
+
+ return primalsInfo;
}
void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
@@ -375,42 +239,6 @@ struct DiffUnzipPass
}
}
- UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
- {
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator()));
-
- auto branchInst = as<IRUnconditionalBranch>(block->getTerminator());
- List<IRInst*> phiArgs;
-
- for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
- phiArgs.add(branchInst->getArg(ii));
-
- phiArgs.add(arg);
-
- builder->setInsertInto(block);
- switch (branchInst->getOp())
- {
- case kIROp_unconditionalBranch:
- builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
- break;
-
- case kIROp_loop:
- builder->emitLoop(
- as<IRLoop>(branchInst)->getTargetBlock(),
- as<IRLoop>(branchInst)->getBreakBlock(),
- as<IRLoop>(branchInst)->getContinueBlock(),
- phiArgs.getCount(),
- phiArgs.getBuffer());
- break;
-
- default:
- break;
- }
-
- branchInst->removeAndDeallocate();
- return phiArgs.getCount() - 1;
- }
-
IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
{
builder->setInsertInto(block);
@@ -428,27 +256,7 @@ struct DiffUnzipPass
return addPhiInputParam(builder, block, type);
}
- IRBlock* getBlock(IRInst* inst)
- {
- SLANG_RELEASE_ASSERT(inst);
-
- if (auto block = as<IRBlock>(inst))
- return block;
-
- return getBlock(inst->getParent());
- }
-
- IRInst* getInstInBlock(IRInst* inst)
- {
- SLANG_RELEASE_ASSERT(inst);
-
- if (auto block = as<IRBlock>(inst->getParent()))
- return inst;
-
- return getInstInBlock(inst->getParent());
- }
-
- void lowerIndexedRegions()
+ void lowerIndexedRegions(HoistedPrimalsInfo* primalsInfo)
{
IRBuilder builder(autodiffContext->moduleInst->getModule());
@@ -464,6 +272,7 @@ struct DiffUnzipPass
// Make variable in the top-most block (so it's visible to diff blocks)
info->primalCountLastVar = builder.emitVar(builder.getIntType());
builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
+ primalsInfo->storeSet.Add(info->primalCountLastVar);
{
auto primalCondBlock = as<IRUnconditionalBranch>(
@@ -546,6 +355,23 @@ struct DiffUnzipPass
builder.addPrimalValueAccessDecoration(primalCounterLastVal);
builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal);
+
+ // We'll be manually creating the inversion entries for the counters
+ // TODO: This logic can be moved to the checkpointing alg.
+ //
+ primalsInfo->invertSet.Add(info->diffCountParam);
+ primalsInfo->instsToInvert.Add(incCounterVal);
+ primalsInfo->invertInfoMap[incCounterVal] = InversionInfo(
+ incCounterVal,
+ List<IRInst*>(incCounterVal),
+ List<IRInst*>(info->diffCountParam));
+
+ primalsInfo->invertSet.Add(incCounterVal);
+ primalsInfo->instsToInvert.Add(diffUpdateBlock->getTerminator());
+ primalsInfo->invertInfoMap[diffUpdateBlock->getTerminator()] = InversionInfo(
+ diffUpdateBlock->getTerminator(),
+ List<IRInst*>(diffUpdateBlock->getTerminator()),
+ List<IRInst*>(incCounterVal));
}
// Try to infer maximum possible number of iterations.
@@ -576,282 +402,10 @@ struct DiffUnzipPass
return indices;
}
- void processIndexedFwdBlock(IRBlock* fwdBlock)
- {
- // Grab first primal block.
- IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]);
-
- // Scan through instructions and identify those that are used
- // outside the local block.
- //
- IRBlock* primalBlock = as<IRBlock>(primalMap[fwdBlock]);
-
- List<IRInst*> primalInsts;
- for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst())
- {
- // TODO: This might be a decent place to enforce that each load has a single
- // corresponding store (i.e. that everything is SSAd properly)?
-
- // We're only interested in insts that generate values.
- if (child->getDataType() == nullptr ||
- as<IRVoidType>(child->getDataType()) ||
- as<IRFuncType>(child->getDataType()) ||
- as<IRTypeKind>(child->getDataType()))
- continue;
-
- primalInsts.add(child);
- }
-
- IRBuilder builder(autodiffContext->moduleInst->getModule());
-
- for (auto inst : primalInsts)
- {
- // 1. Check if we need to store inst (is it used in a differential block?)
-
- bool shouldStore = false;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- IRBlock* useBlock = getBlock(use->getUser());
-
- if (isDifferentialInst(useBlock))
- {
- shouldStore = true;
- break;
- }
- }
-
- if (!shouldStore) continue;
-
- // 2. If we're dealing with a var, we need to locate the value that
- // we actually need to store. We assume everything is SSA form
- // so there must be a single IRStore on this var.
- //
- IRInst* valueToStore = nullptr;
- IRBlock* valueBlock = nullptr;
- IRType* valueType = nullptr;
-
- bool isPtrType = false;
- bool isIntermediateContext = false;
-
- if (auto ptrValueType = as<IRPtrTypeBase>(inst->getDataType()))
- {
- isPtrType = true;
-
- // Find value to store
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (auto storeInst = as<IRStore>(use->getUser()))
- {
- // Should not see more than one IRStore
- SLANG_RELEASE_ASSERT(!valueToStore);
- valueToStore = storeInst->getVal();
-
- // Is this the right block to use to determine if the
- // store can have multiple values based on the index?
- //
- valueBlock = as<IRBlock>(storeInst->getParent());
- }
- }
-
- if (as<IRBackwardDiffIntermediateContextType>(ptrValueType->getValueType()))
- {
- isIntermediateContext = true;
-
- // TODO: This should be the parent block of the `call` associated
- // with this context type. The var itself _could_ be in a different place.
- //
- valueBlock = as<IRBlock>(inst->getParent());
- }
-
- valueType = ptrValueType->getValueType();
- }
- else
- {
- isPtrType = false;
- valueToStore = inst;
- valueBlock = as<IRBlock>(inst->getParent());
- valueType = inst->getDataType();
- }
-
- // What do we do for primal vars that are used in the diff block
- // but do not have an IRStore on them? This can happen for 'out'
- // primal variables.
- //
- if (!valueToStore && !isIntermediateContext)
- {
- // For now, we can ignore them since they are used as inputs
- // to 'out' parameters. If their value is every actually used,
- // we will see an IRLoad which will be hoisted accordingly.
- //
- continue;
- }
-
- // Build list of indices that the value's block is affected by.
- List<IndexTrackingInfo*> indices = getIndexInfoList(valueBlock);
-
- // 3. Emit an array to top-level to allocate space.
-
- builder.setInsertBefore(firstPrimalBlock->getTerminator());
-
- IRType* storageType = valueType;
-
- for (auto index : indices)
- {
- SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static);
- SLANG_ASSERT(index->maxIters >= 0);
-
- storageType = builder.getArrayType(
- storageType,
- builder.getIntValue(
- builder.getUIntType(),
- index->maxIters + 1));
- }
-
- // Reverse the list since the indices need to be
- // emitted in reverse order.
- //
- indices.reverse();
-
- auto storageVar = builder.emitVar(storageType);
- if (isIntermediateContext)
- builder.addBackwardDerivativePrimalContextDecoration(
- storageVar,
- storageVar);
-
- // 4. Store current value into the array and replace uses with a load.
- // If an index is missing, use the 'last' value of the primal index.
-
- {
- if (!isIntermediateContext)
- setInsertAfterOrdinaryInst(&builder, valueToStore);
- else
- setInsertAfterOrdinaryInst(&builder, inst);
-
- IRInst* storeAddr = storageVar;
- IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
-
- for (auto index : indices)
- {
- currType = as<IRArrayType>(currType)->getElementType();
-
- storeAddr = builder.emitElementAddress(
- builder.getPtrType(currType),
- storeAddr,
- index->primalCountParam);
- }
-
- if (!isIntermediateContext)
- builder.emitStore(storeAddr, valueToStore);
- else
- {
- List<IRUse*> primalUses;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (!isDifferentialInst(getBlock(use->getUser())))
- primalUses.add(use);
- }
-
- for (auto use : primalUses)
- use->set(storeAddr);
- }
- }
-
-
- // 5. Replace uses in differential blocks with loads from the array.
- List<IRInst*> instsToTag;
- {
- List<IRUse*> diffUses;
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (as<IRDecoration>(use->getUser()))
- {
- if (!as<IRLoopExitPrimalValueDecoration>(use->getUser()) &&
- !as<IRBackwardDerivativePrimalContextDecoration>(use->getUser()))
- continue;
- }
-
- IRBlock* useBlock = getBlock(use->getUser());
- if (useBlock && isDifferentialInst(useBlock))
- diffUses.add(use);
- }
-
- for (auto use : diffUses)
- {
- IRBlock* useBlock = getBlock(use->getUser());
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
-
- IRInst* loadAddr = storageVar;
- IRType* currType = as<IRPtrTypeBase>(storageVar->getDataType())->getValueType();
-
- // Enumerate use block regions.
- // TODO: Probably a good idea to do this ahead of time for
- // all blocks.
- //
- List<IndexTrackingInfo*> useBlockIndices = getIndexInfoList(useBlock);
-
- for (auto index : indices)
- {
- currType = as<IRArrayType>(currType)->getElementType();
- if (useBlockIndices.contains(index))
- {
- // If the use-block is under the same region, use the
- // differential counter variable
- //
- auto diffCounterCurrValue = index->diffCountParam;
-
- loadAddr = builder.emitElementAddress(
- builder.getPtrType(currType),
- loadAddr,
- diffCounterCurrValue);
- }
- else
- {
- // If the use-block is outside this region, use the
- // last available value (by indexing with primal counter minus 1)
- //
- auto primalCounterCurrValue = builder.emitLoad(index->primalCountLastVar);
- auto primalCounterLastValue = builder.emitSub(
- primalCounterCurrValue->getDataType(),
- primalCounterCurrValue,
- builder.getIntValue(builder.getIntType(), 1));
-
- instsToTag.add(primalCounterCurrValue);
- instsToTag.add(primalCounterLastValue);
-
- loadAddr = builder.emitElementAddress(
- builder.getPtrType(currType),
- loadAddr,
- primalCounterLastValue);
- }
-
- instsToTag.add(loadAddr);
- }
-
- if (!isPtrType)
- {
- auto loadedValue = builder.emitLoad(loadAddr);
- instsToTag.add(loadedValue);
-
- use->set(loadedValue);
- }
- else
- {
- use->set(loadAddr);
- }
- }
- }
-
- for (auto instToTag : instsToTag)
- {
- builder.addPrimalValueAccessDecoration(instToTag);
- builder.markInstAsPrimal(instToTag);
- }
- }
- }
-
IRFunc* extractPrimalFunc(
IRFunc* func,
IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
ParameterBlockTransposeInfo& paramInfo,
IRInst*& intermediateType);
@@ -973,6 +527,13 @@ struct DiffUnzipPass
auto primalArg = lookupPrimalInst(arg);
auto diffArg = lookupDiffInst(arg);
+ if (auto primalVar = as<IRVar>(primalArg))
+ {
+ primalArg = diffBuilder->emitVar(as<IRPtrTypeBase>(primalVar->getDataType())->getValueType());
+ if (auto storeUse = findUniqueStoredVal(primalVar))
+ diffBuilder->emitStore(primalArg, as<IRStore>(storeUse->getUser())->getVal());
+ }
+
// If arg is a mixed differential (pair), it should have already been split.
SLANG_ASSERT(primalArg);
SLANG_ASSERT(diffArg);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 517b9e3ea..a3a7e4b77 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -595,6 +595,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_FloatType:
case kIROp_VectorType:
case kIROp_MatrixType:
+ case kIROp_AttributedType:
return true;
default:
return false;
@@ -904,7 +905,7 @@ struct AutoDiffPass : public InstPassBase
else
{
IntermediateContextTypeDifferentialInfo diffFieldTypeInfo;
- diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo);
+ diffTypes.TryGetValue(field->getFieldType(), diffFieldTypeInfo);
diffFieldWitness = diffFieldTypeInfo.diffWitness;
}
if (diffFieldWitness)
@@ -1429,4 +1430,99 @@ bool finalizeAutoDiffPass(IRModule* module)
return false;
}
+IRBlock* getBlock(IRInst* inst) // Returns `inst` itself when it is an IRBlock, otherwise the nearest IRBlock reached by walking up getParent().
+{
+ SLANG_RELEASE_ASSERT(inst); // Recursive calls pass getParent(), so this also trips if a parent is null before any block is found.
+
+ if (auto block = as<IRBlock>(inst))
+ return block;
+
+ return getBlock(inst->getParent()); // Not a block: retry one level up.
+}
+
+IRInst* getInstInBlock(IRInst* inst) // Returns `inst`, or its nearest ancestor, whose direct parent is an IRBlock.
+{
+ SLANG_RELEASE_ASSERT(inst); // Recursive calls pass getParent(), so a null parent is caught here.
+
+ if (auto block = as<IRBlock>(inst->getParent())) // `block` is only tested for non-null; its value is not read.
+ return inst;
+
+ return getInstInBlock(inst->getParent()); // Parent is not a block: retry one level up.
+}
+
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg) // Re-emits `block`'s terminator with `arg` appended to its argument list; returns the index of the appended argument.
+{
+ SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator())); // The terminator must cast to IRUnconditionalBranch.
+
+ auto branchInst = as<IRUnconditionalBranch>(block->getTerminator());
+ List<IRInst*> phiArgs; // Existing branch arguments, followed by `arg`.
+
+ for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
+ phiArgs.add(branchInst->getArg(ii));
+
+ phiArgs.add(arg);
+
+ builder->setInsertInto(block); // Note: this changes the caller's builder insertion point.
+ switch (branchInst->getOp()) // Emit a replacement of the same kind as the old terminator.
+ {
+ case kIROp_unconditionalBranch:
+ builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
+ break;
+
+ case kIROp_loop: // Loops keep their target, break and continue blocks.
+ builder->emitLoop(
+ as<IRLoop>(branchInst)->getTargetBlock(),
+ as<IRLoop>(branchInst)->getBreakBlock(),
+ as<IRLoop>(branchInst)->getContinueBlock(),
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+ break;
+
+ default: // Any other opcode is reported via SLANG_UNEXPECTED.
+ SLANG_UNEXPECTED("Unexpected branch-type for phi replacement");
+ }
+
+ branchInst->removeAndDeallocate(); // The old terminator is removed only after its replacement has been emitted.
+ return phiArgs.getCount() - 1; // `arg` was added last, so it sits at the final index.
+}
+
+IRUse* findUniqueStoredVal(IRVar* var) // Returns the single use of `var` whose user is an IRCall (derivative-context vars) or an IRStore (all other vars); nullptr if there is none.
+{
+ if (isDerivativeContextVar(var)) // For these vars the scan looks for an IRCall user instead of an IRStore.
+ {
+ IRUse* primalCallUse = nullptr;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto callInst = as<IRCall>(use->getUser())) // `callInst` is only tested for non-null.
+ {
+ // Should not see more than one IRCall. If we do
+ // we'll need to pick the primal call.
+ //
+ SLANG_RELEASE_ASSERT(!primalCallUse);
+ primalCallUse = use;
+ }
+ }
+ return primalCallUse; // Still nullptr when no IRCall uses the var.
+ }
+ else
+ {
+ IRUse* storeUse = nullptr;
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ if (auto storeInst = as<IRStore>(use->getUser())) // `storeInst` is only tested for non-null.
+ {
+ // Should not see more than one IRStore
+ SLANG_RELEASE_ASSERT(!storeUse);
+ storeUse = use;
+ }
+ }
+ return storeUse; // Still nullptr when no IRStore uses the var; callers must check before dereferencing.
+ }
+}
+
+bool isDerivativeContextVar(IRVar* var) // True when `var` carries an IRBackwardDerivativePrimalContextDecoration.
+{
+ return var->findDecoration<IRBackwardDerivativePrimalContextDecoration>(); // The lookup result is converted to bool by the return type.
+}
+
}
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index e7a841323..d49babc52 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -298,6 +298,8 @@ bool processAutodiffCalls(
bool finalizeAutoDiffPass(IRModule* module);
+// Utility methods
+
void stripDerivativeDecorations(IRInst* inst);
bool isBackwardDifferentiableFunc(IRInst* func);
@@ -322,4 +324,15 @@ inline bool isRelevantDifferentialPair(IRType* type)
return false;
}
+IRBlock* getBlock(IRInst* inst);
+
+IRInst* getInstInBlock(IRInst* inst);
+
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg);
+
+IRUse* findUniqueStoredVal(IRVar* var);
+
+bool isDerivativeContextVar(IRVar* var);
+
+
};
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 0f5c36dcb..9c4c1f4e2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -650,6 +650,7 @@ struct IRBackwardDerivativePrimalContextDecoration : IRDecoration
};
IR_LEAF_ISA(BackwardDerivativePrimalContextDecoration)
+ IRUse primalContextVar;
IRInst* getBackwardDerivativePrimalContextVar() { return getOperand(0); }
};
@@ -703,6 +704,8 @@ struct IRLoopExitPrimalValueDecoration : IRDecoration
};
IR_LEAF_ISA(LoopExitPrimalValueDecoration)
+ IRUse target;
+ IRUse exitVal;
IRInst* getTargetInst() { return getOperand(0); }
IRInst* getLoopExitValInst() { return getOperand(1); }
};
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index d8246edae..9b50b9c30 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -802,6 +802,34 @@ IRInst* readVar(
return readVarRec(context, blockInfo, var);
}
+void collectInstsToRemove( // Records in context->instsToRemove each GetElementPtr/FieldAddress in `block` whose operand 0 passes asPromotableVarAccessChain().
+ ConstructSSAContext* context,
+ IRBlock* block)
+{
+ IRInst* next = nullptr;
+ for (auto ii = block->getFirstInst(); ii; ii = next)
+ {
+ next = ii->getNextInst(); // Successor is read before `ii` is examined; nothing in this loop removes insts.
+
+ switch (ii->getOp())
+ {
+ default:
+ // Ordinary instruction -> leave as-is
+ break;
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ {
+ auto ptrArg = ii->getOperand(0); // Operand 0 is the pointer this address computation is based on.
+ if (auto var = asPromotableVarAccessChain(context, ptrArg)) // `var` is only tested for non-null.
+ {
+ context->instsToRemove.add(ii); // Only recorded here; removal is not performed in this function.
+ }
+ }
+ break;
+ }
+ }
+}
+
void processBlock(
ConstructSSAContext* context,
IRBlock* block,
@@ -877,19 +905,6 @@ void processBlock(
}
}
break;
-
- case kIROp_GetElementPtr:
- case kIROp_FieldAddress:
- {
- auto ptrArg = ii->getOperand(0);
- if (auto var = asPromotableVarAccessChain(context, ptrArg))
- {
- context->instsToRemove.add(ii);
- }
- }
- break;
-
-
}
}
@@ -1078,6 +1093,10 @@ bool constructSSA(ConstructSSAContext* context)
context->blockInfos.Add(bb, blockInfo);
}
+
+ for(auto bb : globalVal->getBlocks())
+ collectInstsToRemove(context, bb);
+
for(auto bb : globalVal->getBlocks())
{
auto blockInfo = * context->blockInfos.TryGetValue(bb);