From 71efd949fa5276e2464416fcf237f8fd2c486281 Mon Sep 17 00:00:00 2001
From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>
Date: Wed, 15 Mar 2023 22:26:58 -0400
Subject: AD: Primal-Hoisting Rework + Checkpoint Policy Framework (#2702)
---
build/visual-studio/slang/slang.vcxproj | 4 +
build/visual-studio/slang/slang.vcxproj.filters | 12 +
source/slang/slang-ir-autodiff-fwd.cpp | 1 +
source/slang/slang-ir-autodiff-primal-hoist.cpp | 674 +++++++++++++++++++
source/slang/slang-ir-autodiff-primal-hoist.h | 264 ++++++++
source/slang/slang-ir-autodiff-region.cpp | 56 ++
source/slang/slang-ir-autodiff-region.h | 119 ++++
source/slang/slang-ir-autodiff-rev.cpp | 110 +---
source/slang/slang-ir-autodiff-transpose.h | 724 +++++++++++++++++----
source/slang/slang-ir-autodiff-unzip.cpp | 34 +-
source/slang/slang-ir-autodiff-unzip.h | 563 ++--------------
source/slang/slang-ir-autodiff.cpp | 98 ++-
source/slang/slang-ir-autodiff.h | 13 +
source/slang/slang-ir-insts.h | 3 +
source/slang/slang-ir-ssa.cpp | 45 +-
tests/autodiff/reverse-control-flow-1.slang | 42 ++
.../reverse-control-flow-1.slang.expected.txt | 6 +
tests/autodiff/reverse-control-flow.slang | 42 --
.../reverse-control-flow.slang.expected.txt | 6 -
tests/autodiff/reverse-inout-param-1.slang | 47 ++
.../reverse-inout-param-1.slang.expected.txt | 6 +
tests/autodiff/reverse-inout-param.slang | 47 --
.../reverse-inout-param.slang.expected.txt | 6 -
23 files changed, 2068 insertions(+), 854 deletions(-)
create mode 100644 source/slang/slang-ir-autodiff-primal-hoist.cpp
create mode 100644 source/slang/slang-ir-autodiff-primal-hoist.h
create mode 100644 source/slang/slang-ir-autodiff-region.cpp
create mode 100644 source/slang/slang-ir-autodiff-region.h
create mode 100644 tests/autodiff/reverse-control-flow-1.slang
create mode 100644 tests/autodiff/reverse-control-flow-1.slang.expected.txt
delete mode 100644 tests/autodiff/reverse-control-flow.slang
delete mode 100644 tests/autodiff/reverse-control-flow.slang.expected.txt
create mode 100644 tests/autodiff/reverse-inout-param-1.slang
create mode 100644 tests/autodiff/reverse-inout-param-1.slang.expected.txt
delete mode 100644 tests/autodiff/reverse-inout-param.slang
delete mode 100644 tests/autodiff/reverse-inout-param.slang.expected.txt
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index e97d6a2b1..879da8cdf 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -346,7 +346,9 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
+
+
@@ -534,6 +536,8 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
+
+
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 64267db4b..2b223b78d 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -144,9 +144,15 @@
Header Files
+
+ Header Files
+
Header Files
+
+ Header Files
+
Header Files
@@ -704,6 +710,12 @@
Source Files
+
+ Source Files
+
+
+ Source Files
+
Source Files
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 7057a5835..247c3ddde 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -308,6 +308,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
// differentiated, don't differentiate the inst
//
auto primalConstructType = (IRType*)findOrTranscribePrimalInst(builder, origConstruct->getDataType());
+ // TODO: Need to update this to generate derivatives on a per-key basis
if (auto diffConstructType = differentiateType(builder, primalConstructType))
{
UCount operandCount = origConstruct->getOperandCount();
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
new file mode 100644
index 000000000..793a8ff07
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -0,0 +1,674 @@
+#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang
+{
+
+// Reports whether `operand` occurs anywhere in the operand list of `inst`.
+bool containsOperand(IRInst* inst, IRInst* operand)
+{
+    UIndex operandIndex = 0;
+    while (operandIndex < inst->getOperandCount())
+    {
+        if (inst->getOperand(operandIndex) == operand)
+            return true;
+        operandIndex++;
+    }
+
+    return false;
+}
+
+// Walks `func`, asks the policy (via classify()) what to do with every primal value that
+// is used from the scanned blocks, and accumulates the answers into store / recompute /
+// invert sets. The sets are then handed to applyCheckpointSet().
+//
+// func      - function to analyse. The first (parameter) block is skipped, as is every
+//             block for which findDecoration() returns null.
+// splitInfo - forwarded unchanged to applyCheckpointSet().
+//
+// Returns the result of applyCheckpointSet() for the computed sets.
+//
+RefPtr AutodiffCheckpointPolicyBase::processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* splitInfo)
+{
+    RefPtr checkpointInfo = new CheckpointSetInfo();
+
+    RefPtr domTree = computeDominatorTree(func);
+
+    // Uses still to be classified, and the uses that have already been classified.
+    List workList;
+    HashSet processedUses;
+
+    // Uses that applyCheckpointSet() must redirect to cloned (recomputed/inverted) insts.
+    HashSet usesToReplace;
+
+    // Enqueue each operand use of `inst` that passes the filter below, plus the use of
+    // the inst's type when that type is computed inside `func`.
+    auto addPrimalOperandsToWorkList = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration() &&
+                !as(operand->get()) &&
+                !as(operand->get()) &&
+                !(as(operand->get()->getParent())) &&
+                !getBlock(operand->get())->findDecoration())
+                workList.add(operand);
+        }
+
+        // Is the type itself computed within our function?
+        // If so, we'll need to consider that too (this is for existential types, specialize insts, etc)
+        // TODO: We might not really need to query the checkpointing algorithm for these
+        // since they _have_ to be classified as 'recompute'
+        //
+        if (inst->getDataType() && (getParentFunc(inst->getDataType()) == func))
+        {
+            if (!getBlock(inst->getDataType())->findDecoration())
+                workList.add(&inst->typeUse);
+        }
+    };
+
+    // Populate recompute/store/invert sets with insts, by applying the policy
+    // to them.
+    //
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (!block->findDecoration())
+            continue;
+
+        for (auto child : block->getChildren())
+        {
+            // Special case: Ignore the primals used to construct the return pair.
+            if (as(child) &&
+                as(child->firstUse->getUser()))
+            {
+                // quick check
+                SLANG_RELEASE_ASSERT(child->firstUse->nextUse == nullptr);
+                continue;
+            }
+
+            addPrimalOperandsToWorkList(child);
+
+            // We'll be conservative with the decorations we consider as differential uses
+            // of a primal inst, in order to avoid weird behaviour with some decorations
+            //
+            for (auto decoration : child->getDecorations())
+            {
+                if (auto primalCtxDecoration = as(decoration))
+                    workList.add(&primalCtxDecoration->primalContextVar);
+                else if (auto loopExitDecoration = as(decoration))
+                    workList.add(&loopExitDecoration->exitVal);
+            }
+        }
+
+        addPrimalOperandsToWorkList(block->getTerminator());
+    }
+
+    // Drain the work list (LIFO). Classifying a use as 'recompute' may enqueue further
+    // uses: the values that the recomputation itself depends on.
+    while (workList.getCount() > 0)
+    {
+        auto use = workList.getLast();
+        workList.removeLast();
+
+        if (processedUses.Contains(use))
+            continue;
+
+        processedUses.Add(use);
+
+        HoistResult result = this->classify(use);
+
+        if (result.mode == HoistResult::Mode::Store)
+        {
+            SLANG_ASSERT(!checkpointInfo->recomputeSet.Contains(result.instToStore));
+            checkpointInfo->storeSet.Add(result.instToStore);
+        }
+        else if (result.mode == HoistResult::Mode::Recompute)
+        {
+            SLANG_ASSERT(!checkpointInfo->storeSet.Contains(result.instToRecompute));
+            checkpointInfo->recomputeSet.Add(result.instToRecompute);
+
+            if (use->getUser()->findDecoration())
+                usesToReplace.Add(use);
+
+            if (auto param = as(result.instToRecompute))
+            {
+                // Add in the branch-args of every predecessor block.
+                auto paramBlock = as(param->getParent());
+                UIndex paramIndex = 0;
+                for (auto _param : paramBlock->getParams())
+                {
+                    if (_param == param) break;
+                    paramIndex ++;
+                }
+
+                for (auto predecessor : paramBlock->getPredecessors())
+                {
+                    // If we hit this, the checkpoint policy is trying to recompute
+                    // values across a loop region boundary (we don't currently support this,
+                    // and in general this is quite inefficient in both compute & memory)
+                    //
+                    SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor));
+
+                    auto branchInst = as(predecessor->getTerminator());
+                    SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
+
+                    workList.add(&branchInst->getOperands()[paramIndex]);
+                }
+            }
+            else
+            {
+                if (auto var = as(result.instToRecompute))
+                {
+                    // A recomputed variable depends on the value stored into it, so
+                    // follow the unique store when findUniqueStoredVal() reports one.
+                    // Only a non-null use may be enqueued: every work-list entry is
+                    // passed to classify(), which dereferences it.
+                    IRUse* storeUse = findUniqueStoredVal(var);
+                    if (storeUse)
+                        workList.add(storeUse);
+                }
+                else
+                {
+                    addPrimalOperandsToWorkList(result.instToRecompute);
+                }
+            }
+        }
+        else if (result.mode == HoistResult::Mode::Invert)
+        {
+            auto instToInvert = result.inversionInfo.instToInvert;
+
+            SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser()));
+            SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser()));
+
+            if (use->getUser()->findDecoration())
+                usesToReplace.Add(use);
+
+            checkpointInfo->invertSet.Add(instToInvert);
+
+            if (checkpointInfo->invInfoMap.ContainsKey(instToInvert))
+            {
+                // The same inst was already scheduled for inversion: the operands it
+                // requires must agree with what was recorded the first time.
+                List currOperands = checkpointInfo->invInfoMap[instToInvert].GetValue().requiredOperands;
+                for (Index ii = 0; ii < result.inversionInfo.requiredOperands.getCount(); ii++)
+                {
+                    SLANG_RELEASE_ASSERT(result.inversionInfo.requiredOperands[ii] == currOperands[ii]);
+                }
+            }
+            else
+                checkpointInfo->invInfoMap[instToInvert] = result.inversionInfo;
+        }
+    }
+
+    return applyCheckpointSet(checkpointInfo, func, splitInfo, usesToReplace);
+}
+
+// Applies the checkpointing decision recorded in `checkpointInfo` to a single inst.
+//
+// - In the store set: the inst is recorded in hoistInfo->storeSet and nothing else is done.
+// - In the recompute set: a non-parameter inst is cloned through `cloneCtx` at the
+//   builder's insertion point and the clone is recorded in hoistInfo->recomputeSet.
+//   Parameters of the first block are ignored; any other parameter is unimplemented.
+// - In the invert set: the inst and its `instToInvert` are cloned, the inversion info is
+//   re-keyed on the cloned `instToInvert`, and its required operands are remapped through
+//   the clone environment.
+//
+void applyToInst(
+ IRBuilder* builder,
+ CheckpointSetInfo* checkpointInfo,
+ HoistedPrimalsInfo* hoistInfo,
+ IROutOfOrderCloneContext* cloneCtx,
+ IRInst* inst)
+{
+ // Early-out..
+ if (checkpointInfo->storeSet.Contains(inst))
+ {
+ hoistInfo->storeSet.Add(inst);
+ return;
+ }
+
+ bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst);
+ if (isInstRecomputed)
+ {
+ if (as(inst))
+ {
+ // Can completely ignore first block parameters
+ if (getBlock(inst) != getBlock(inst)->getParent()->getFirstBlock())
+ {
+ // TODO: We would need to clone in the control-flow for each region (without nested loops)
+ // prior to this, and then hoist this parameter into the within-region block, otherwise
+ // this parameter will not be visible to transposed insts.
+ // This will also include adding an extra case to 'ensurePrimalAvailability': if both insts
+ // are within the _same_ indexed region, skip the indexed store/load and use a simple var.
+ //
+ SLANG_UNIMPLEMENTED_X("Parameter recompute is not currently supported");
+ }
+ }
+ else
+ {
+ hoistInfo->recomputeSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+ }
+ }
+
+ bool isInstInverted = checkpointInfo->invertSet.Contains(inst);
+ if (isInstInverted)
+ {
+ // Work on a copy of the recorded info so the remapped operands do not alter
+ // the entry held by checkpointInfo.
+ InversionInfo info = checkpointInfo->invInfoMap[inst];
+ auto clonedInstToInvert = cloneCtx->cloneInstOutOfOrder(builder, info.instToInvert);
+
+ // Process operand set for the inverse inst.
+ // Operands that have a clone are replaced by it; the rest are kept as they are.
+ List newOperands;
+ for (auto operand : info.requiredOperands)
+ {
+ if (cloneCtx->cloneEnv.mapOldValToNew.ContainsKey(operand))
+ newOperands.add(cloneCtx->cloneEnv.mapOldValToNew[operand]);
+ else
+ newOperands.add(operand);
+ }
+
+ info.requiredOperands = newOperands;
+
+ // NOTE(review): info.instToInvert and info.targetInsts still refer to the original
+ // (un-cloned) insts here; only the map key and requiredOperands are remapped.
+ hoistInfo->invertInfoMap[clonedInstToInvert] = info;
+ hoistInfo->instsToInvert.Add(clonedInstToInvert);
+ hoistInfo->invertSet.Add(cloneCtx->cloneInstOutOfOrder(builder, inst));
+ }
+}
+
+// Realises the decisions in `checkpointInfo` on `func`: every block parameter and child
+// inst of each scanned block is passed through applyToInst(), which records stored insts
+// and clones recomputed/inverted insts into the matching block from
+// splitInfo->diffBlockMap.
+//
+// checkpointInfo - store / recompute / invert sets produced by the checkpoint policy.
+// func           - function being processed. The first block is skipped, as is every
+//                  block for which findDecoration() returns non-null.
+// splitInfo      - maps each scanned block to the block that receives the clones.
+// pendingUses    - uses to redirect to a clone as soon as the used inst is cloned.
+//
+// Returns the HoistedPrimalsInfo describing the stored and cloned insts.
+//
+RefPtr applyCheckpointSet(
+    CheckpointSetInfo* checkpointInfo,
+    IRGlobalValueWithCode* func,
+    BlockSplitInfo* splitInfo,
+    HashSet pendingUses)
+{
+    RefPtr hoistInfo = new HoistedPrimalsInfo();
+
+    RefPtr cloneCtx = new IROutOfOrderCloneContext();
+
+    for (auto use : pendingUses)
+        cloneCtx->pendingUses.Add(use);
+
+    // Populate the clone context with all the primal uses that we may need to replace with
+    // cloned versions. That way any insts we clone into the diff block will automatically replace
+    // their uses.
+    //
+    // NOTE(review): this helper is defined but never invoked in this function.
+    auto addPrimalUsesToCloneContext = [&](IRInst* inst)
+    {
+        UIndex opIndex = 0;
+        for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+        {
+            if (!operand->get()->findDecoration())
+                cloneCtx->pendingUses.Add(operand);
+        }
+    };
+
+    // Go back over the insts and move/clone them accordingly.
+    for (auto block : func->getBlocks())
+    {
+        // Skip parameter block.
+        if (block == func->getFirstBlock())
+            continue;
+
+        if (block->findDecoration())
+            continue;
+
+        auto diffBlock = as(splitInfo->diffBlockMap[block]);
+
+        // Captured before any parameter is processed, so clones of ordinary insts are
+        // inserted ahead of the diff block's original contents.
+        auto firstDiffInst = as(splitInfo->diffBlockMap[block])->getFirstOrdinaryInst();
+
+        IRBuilder builder(func->getModule());
+
+        UIndex ii = 0;
+        for (auto param : block->getParams())
+        {
+            builder.setInsertBefore(diffBlock->getFirstOrdinaryInst());
+
+            // Apply checkpoint rule to the parameter itself.
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, param);
+
+            // Copy primal branch-arg for predecessor blocks.
+            HashSet predecessorSet;
+            for (auto predecessor : block->getPredecessors())
+            {
+                // Visit each distinct predecessor only once.
+                if (predecessorSet.Contains(predecessor))
+                    continue;
+
+                predecessorSet.Add(predecessor);
+
+                // The extra branch argument belongs on the block that corresponds to the
+                // *predecessor*, since that is the block branching into `diffBlock`.
+                auto diffPredecessor = as(splitInfo->diffBlockMap[predecessor]);
+
+                if (checkpointInfo->recomputeSet.Contains(param))
+                    addPhiOutputArg(&builder,
+                        diffPredecessor,
+                        as(predecessor->getTerminator())->getArg(ii));
+
+                if (checkpointInfo->invertSet.Contains(param))
+                    addPhiOutputArg(&builder,
+                        diffPredecessor,
+                        as(predecessor->getTerminator())->getArg(ii));
+            }
+
+            ii++;
+        }
+
+        for (auto child : block->getChildren())
+        {
+            builder.setInsertBefore(firstDiffInst);
+
+            applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
+        }
+    }
+
+    return hoistInfo;
+}
+
+// Builds the type used to keep one value of `storageType` per iteration of each region in
+// `defBlockIndices`: the type is wrapped in one array per entry, each sized to the
+// entry's `maxIters + 1`. With no entries, `storageType` is returned unchanged.
+IRType* getTypeForLocalStorage(
+    IRBuilder* builder,
+    IRType* storageType,
+    List defBlockIndices)
+{
+    IRType* wrappedType = storageType;
+
+    for (auto region : defBlockIndices)
+    {
+        // Only regions with a statically known, non-negative trip count can be sized.
+        SLANG_ASSERT(region->status == IndexTrackingInfo::CountStatus::Static);
+        SLANG_ASSERT(region->maxIters >= 0);
+
+        auto elementCount = builder->getIntValue(
+            builder->getUIntType(),
+            region->maxIters + 1);
+
+        wrappedType = builder->getArrayType(wrappedType, elementCount);
+    }
+
+    return wrappedType;
+}
+
+// Emits, before the first ordinary inst of `varBlock`, a variable whose type is
+// `baseType` wrapped by getTypeForLocalStorage() for `defBlockIndices`, and initialises
+// it with a default-constructed value. Returns the new variable.
+IRVar* emitIndexedLocalVar(
+    IRBlock* varBlock,
+    IRType* baseType,
+    List defBlockIndices)
+{
+    SLANG_RELEASE_ASSERT(!as(baseType));
+
+    IRBuilder localBuilder(varBlock->getModule());
+    localBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());
+
+    IRType* storageType = getTypeForLocalStorage(&localBuilder, baseType, defBlockIndices);
+
+    auto storageVar = localBuilder.emitVar(storageType);
+    auto initialValue = localBuilder.emitDefaultConstruct(storageType);
+    localBuilder.emitStore(storageVar, initialValue);
+
+    return storageVar;
+}
+
+// Emits the chain of element-address insts that selects the slot of `localVar` to write,
+// indexing one array level per entry of `defBlockIndices` with that entry's
+// `primalCountParam`. Returns the final address (`localVar` itself when there are no
+// entries).
+IRInst* emitIndexedStoreAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List defBlockIndices)
+{
+    IRInst* addr = localVar;
+    IRType* elemType = as(localVar->getDataType())->getValueType();
+
+    for (auto region : defBlockIndices)
+    {
+        // Peel one array level off the type for each index applied.
+        elemType = as(elemType)->getElementType();
+
+        auto elemPtrType = builder->getPtrType(elemType);
+        addr = builder->emitElementAddress(elemPtrType, addr, region->primalCountParam);
+    }
+
+    return addr;
+}
+
+
+// Emits the chain of element-address insts that selects the slot of `localVar` to read.
+// One array level is indexed per entry of `defBlockIndices`:
+//  - if the entry is also in `useBlockIndices`, the entry's `diffCountParam` is the index;
+//  - otherwise the index is the value loaded from the entry's `primalCountLastVar`, minus 1.
+// Returns the final address (`localVar` itself when there are no entries).
+IRInst* emitIndexedLoadAddressForVar(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List defBlockIndices,
+    List useBlockIndices)
+{
+    IRInst* addr = localVar;
+    IRType* elemType = as(localVar->getDataType())->getValueType();
+
+    for (auto region : defBlockIndices)
+    {
+        elemType = as(elemType)->getElementType();
+
+        IRInst* elementIndex = nullptr;
+        if (useBlockIndices.contains(region))
+        {
+            // If the use-block is under the same region, use the
+            // differential counter variable
+            //
+            elementIndex = region->diffCountParam;
+        }
+        else
+        {
+            // If the use-block is outside this region, use the
+            // last available value (by indexing with primal counter minus 1)
+            //
+            auto counterNow = builder->emitLoad(region->primalCountLastVar);
+            elementIndex = builder->emitSub(
+                counterNow->getDataType(),
+                counterNow,
+                builder->getIntValue(builder->getIntType(), 1));
+        }
+
+        addr = builder->emitElementAddress(
+            builder->getPtrType(elemType),
+            addr,
+            elementIndex);
+    }
+
+    return addr;
+}
+
+// Creates an indexed local variable for `instToStore` in `defaultVarBlock` and, at the
+// current insertion point of `builder`, stores `instToStore` into the slot selected by
+// `defBlockIndices`. Returns the variable that now holds the value.
+IRVar* storeIndexedValue(
+    IRBuilder* builder,
+    IRBlock* defaultVarBlock,
+    IRInst* instToStore,
+    List defBlockIndices)
+{
+    IRVar* storageVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
+
+    builder->emitStore(
+        emitIndexedStoreAddressForVar(builder, storageVar, defBlockIndices),
+        instToStore);
+
+    return storageVar;
+}
+
+// Emits a load from the slot of `localVar` that emitIndexedLoadAddressForVar() selects
+// for the given def/use index lists, and returns the loaded value.
+IRInst* loadIndexedValue(
+    IRBuilder* builder,
+    IRVar* localVar,
+    List defBlockIndices,
+    List useBlockIndices)
+{
+    return builder->emitLoad(
+        emitIndexedLoadAddressForVar(builder, localVar, defBlockIndices, useBlockIndices));
+}
+
+// True when both lists have the same length and hold equal entries at every position.
+bool areIndicesEqual(
+    List indicesA,
+    List indicesB)
+{
+    const Index countA = indicesA.getCount();
+    if (countA != indicesB.getCount())
+        return false;
+
+    // Advance while the entries agree; reaching the end means no mismatch was found.
+    Index pos = 0;
+    while (pos < countA && !(indicesA[pos] != indicesB[pos]))
+        pos++;
+
+    return pos == countA;
+}
+
+// True when `indicesA` is a leading prefix of `indicesB`: it has no more entries than
+// `indicesB` and matches it position-by-position from the front. (Despite the name,
+// this is an ordered prefix comparison rather than a general subset test.)
+bool areIndicesSubsetOf(
+    List indicesA,
+    List indicesB)
+{
+    const Index countA = indicesA.getCount();
+    if (countA > indicesB.getCount())
+        return false;
+
+    Index pos = 0;
+    while (pos < countA && !(indicesA[pos] != indicesB[pos]))
+        pos++;
+
+    return pos == countA;
+}
+
+
+// True when findDecoration() finds the queried decoration on `block`.
+bool isDifferentialBlock(IRBlock* block)
+{
+    auto decoration = block->findDecoration();
+    return decoration ? true : false;
+}
+
+// Makes every inst in hoistInfo->storeSet reachable from the decorated blocks that use it.
+//
+// For each stored inst, the uses whose user block carries the decoration queried by
+// findDecoration() are inspected. A use is "out of scope" when the defining block does
+// not dominate the user block, when the defining block's index list is not a prefix of
+// the user block's, or when the definition sits in an indexed / pointer-typed position
+// in a non-differential block. Such uses are redirected to a load from a local variable
+// created in the block following the function's first block.
+//
+// hoistInfo        - its storeSet is replaced by the processed set before returning.
+// func             - function being rewritten.
+// indexedBlockInfo - per-block list of index-tracking entries.
+//
+// Returns `hoistInfo`.
+//
+RefPtr ensurePrimalAvailability(
+ HoistedPrimalsInfo* hoistInfo,
+ IRGlobalValueWithCode* func,
+ Dictionary> indexedBlockInfo)
+{
+ RefPtr domTree = computeDominatorTree(func);
+
+ IRBuilder builder(func->getModule());
+ // Block that receives the local variables created below.
+ IRBlock* defaultVarBlock = func->getFirstBlock()->getNextBlock();
+
+ SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
+
+ HashSet processedStoreSet;
+
+ // TODO: Also ensure availability of everything in the recompute set (for proper recompute support)
+ for (auto instToStore : hoistInfo->storeSet)
+ {
+ // Pointer-typed insts are treated as defined where their unique store happens;
+ // everything else is defined in its own block.
+ IRBlock* defBlock = nullptr;
+ if (auto ptrInst = as(instToStore->getDataType()))
+ {
+ auto varInst = as(instToStore);
+ auto storeUse = findUniqueStoredVal(varInst);
+
+ // NOTE(review): storeUse is dereferenced without a null check here, although a
+ // later branch of this function tests it for null — confirm it cannot be null.
+ defBlock = getBlock(storeUse->getUser());
+ }
+ else
+ defBlock = getBlock(instToStore);
+
+ SLANG_RELEASE_ASSERT(defBlock);
+
+ List outOfScopeUses;
+ // nextUse is captured up front so the walk does not depend on `use` afterwards.
+ for (auto use = instToStore->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ // Only consider uses in differential blocks.
+ // This method is not responsible for other blocks.
+ //
+ IRBlock* userBlock = getBlock(use->getUser());
+ if (userBlock->findDecoration())
+ {
+ if (!domTree->dominates(defBlock, userBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (indexedBlockInfo[defBlock].GetValue().getCount() > 0 &&
+ !isDifferentialBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (as(instToStore->getDataType()) &&
+ !isDifferentialBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ }
+
+ use = nextUse;
+ }
+
+ // Every relevant use can already see the inst: keep it as-is.
+ if (outOfScopeUses.getCount() == 0)
+ {
+ processedStoreSet.Add(instToStore);
+ continue;
+ }
+
+ if (auto ptrInst = as(instToStore->getDataType()))
+ {
+
+ IRVar* varToStore = as(instToStore);
+ SLANG_RELEASE_ASSERT(varToStore);
+
+ auto storeUse = findUniqueStoredVal(varToStore);
+
+ List defBlockIndices = indexedBlockInfo[defBlock];
+
+ bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
+
+ // TODO: There's a slight hackiness here. (Ideally we might just want to emit
+ // additional vars when splitting a call)
+ //
+ if (!isIndexedStore && isDerivativeContextVar(varToStore))
+ {
+ // Move the variable itself into the default var block instead of copying it.
+ varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
+ processedStoreSet.Add(varToStore);
+ continue;
+ }
+
+ // NOTE(review): storeUse may be null on this path (isIndexedStore can be false
+ // because storeUse is null) yet it is dereferenced below — confirm.
+ setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser()));
+
+ // Snapshot the variable's value right after its store into an indexed local.
+ IRVar* localVar = storeIndexedValue(
+ &builder,
+ defaultVarBlock,
+ builder.emitLoad(varToStore),
+ defBlockIndices);
+
+ for (auto use : outOfScopeUses)
+ {
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+ List useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+
+ // The original operand was an address, so it is replaced by an address
+ // (not by a loaded value).
+ IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
+ builder.replaceOperand(use, loadAddr);
+ }
+
+ processedStoreSet.Add(localVar);
+ }
+ else
+ {
+ setInsertAfterOrdinaryInst(&builder, instToStore);
+
+ List defBlockIndices = indexedBlockInfo[defBlock];
+ auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
+
+ for (auto use : outOfScopeUses)
+ {
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+ List useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+ builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
+ }
+
+ processedStoreSet.Add(localVar);
+ }
+ }
+
+ // Replace the old store set with the processed one.
+ hoistInfo->storeSet = processedStoreSet;
+
+ return hoistInfo;
+}
+
+// No per-function preparation is needed: this is an (almost) always-store policy.
+void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode*)
+{
+}
+
+// Decides how the value referenced by `use` is made available: store whenever possible,
+// recompute otherwise. By default, classify is only called on relevant differential
+// uses (or on uses in a 'recompute' inst).
+//
+//  - Value matched by the first cast below: stored, unless its value type matches the
+//    second cast and one of that type's arguments has a data type that cannot be
+//    stored, in which case it is recomputed.
+//  - Any other value: stored when canTypeBeStored() accepts its data type, recomputed
+//    otherwise.
+HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
+{
+    auto var = as(use->get());
+    if (!var)
+    {
+        if (canTypeBeStored(use->get()->getDataType()))
+            return HoistResult::store(use->get());
+
+        return HoistResult::recompute(use->get());
+    }
+
+    auto spec = as(as(var->getDataType())->getValueType());
+    if (!spec)
+        return HoistResult::store(use->get());
+
+    for (UInt argIndex = 0; argIndex < spec->getArgCount(); argIndex++)
+    {
+        if (!canTypeBeStored(spec->getArg(argIndex)->getDataType()))
+            return HoistResult::recompute(use->get());
+    }
+
+    return HoistResult::store(use->get());
+}
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
new file mode 100644
index 000000000..dc85942f6
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -0,0 +1,264 @@
+// slang-ir-autodiff-primal-hoist.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-dominators.h"
+
+
+namespace Slang
+{
+ // Clone helper for insts that may be cloned before the insts they refer to.
+ // It keeps a shared clone environment plus a set of "pending" uses: operand slots
+ // that should be redirected to a clone once the inst they reference gets cloned.
+ struct IROutOfOrderCloneContext : public RefObject
+ {
+ // Environment passed to every cloneInst() call made through this context.
+ IRCloneEnv cloneEnv;
+ // Operand slots waiting to be rewritten to a clone of the inst they currently use.
+ HashSet pendingUses;
+
+ // Clones `inst` with `builder`, then
+ //  (1) registers as pending every operand slot of the clone that still refers to the
+ //      same inst as the original's operand, and
+ //  (2) redirects every pending use of `inst` to the clone.
+ // Returns the clone.
+ IRInst* cloneInstOutOfOrder(IRBuilder* builder, IRInst* inst)
+ {
+ IRInst* clonedInst = cloneInst(&cloneEnv, builder, inst);
+
+ UInt operandCount = clonedInst->getOperandCount();
+ for (UInt ii = 0; ii < operandCount; ++ii)
+ {
+ auto oldOperand = inst->getOperand(ii);
+ auto newOperand = clonedInst->getOperand(ii);
+
+ // Unchanged operand: remember the slot so a later clone of it can patch it.
+ if (oldOperand == newOperand)
+ pendingUses.Add(&clonedInst->getOperands()[ii]);
+ }
+
+ // nextUse is read before replaceOperand() is called on `use`.
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
+
+ if (pendingUses.Contains(use))
+ {
+ pendingUses.Remove(use);
+ builder->replaceOperand(use, clonedInst);
+ }
+
+ use = nextUse;
+ }
+
+ return clonedInst;
+ }
+ };
+
+ // Describes one inversion request: the inst to invert, the operands the inversion
+ // requires, and the insts the inversion targets.
+ struct InversionInfo
+ {
+ IRInst* instToInvert;
+ List requiredOperands;
+ List targetInsts;
+
+ InversionInfo(
+ IRInst* instToInvert,
+ List requiredOperands,
+ List targetInsts) :
+ instToInvert(instToInvert),
+ requiredOperands(requiredOperands),
+ targetInsts(targetInsts)
+ { }
+
+ // Empty request: no inst, empty lists.
+ InversionInfo() : instToInvert(nullptr)
+ { }
+
+ // Returns a copy with every inst translated through env->mapOldValToNew.
+ // Insts without a mapping are dropped: instToInvert stays null and unmapped
+ // list entries are omitted, so the resulting lists may be shorter.
+ InversionInfo applyMap(IRCloneEnv* env)
+ {
+ InversionInfo newInfo;
+ if (env->mapOldValToNew.ContainsKey(instToInvert))
+ newInfo.instToInvert = env->mapOldValToNew[instToInvert];
+
+ for (auto inst : requiredOperands)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newInfo.requiredOperands.add(env->mapOldValToNew[inst]);
+
+ for (auto inst : targetInsts)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newInfo.targetInsts.add(env->mapOldValToNew[inst]);
+
+ return newInfo;
+ }
+ };
+
+ // Result of applying a checkpoint set to a function: which insts are stored, which
+ // clones were emitted for recomputation or inversion, and the inversion details.
+ struct HoistedPrimalsInfo : public RefObject
+ {
+ HashSet storeSet;
+ HashSet recomputeSet;
+ HashSet invertSet;
+
+ // Insts that are to be inverted (the keys of invertInfoMap).
+ HashSet instsToInvert;
+
+ Dictionary invertInfoMap;
+
+ // Returns a new HoistedPrimalsInfo with every inst translated through
+ // env->mapOldValToNew. Entries without a mapping are dropped.
+ RefPtr applyMap(IRCloneEnv* env)
+ {
+ RefPtr newPrimalsInfo = new HoistedPrimalsInfo();
+
+ for (auto inst : this->storeSet)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newPrimalsInfo->storeSet.Add(env->mapOldValToNew[inst]);
+
+ for (auto inst : this->recomputeSet)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newPrimalsInfo->recomputeSet.Add(env->mapOldValToNew[inst]);
+
+ for (auto inst : this->invertSet)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newPrimalsInfo->invertSet.Add(env->mapOldValToNew[inst]);
+
+ for (auto inst : this->instsToInvert)
+ if (env->mapOldValToNew.ContainsKey(inst))
+ newPrimalsInfo->instsToInvert.Add(env->mapOldValToNew[inst]);
+
+ // Both the key and the stored InversionInfo are remapped.
+ for (auto kvpair : this->invertInfoMap)
+ if (env->mapOldValToNew.ContainsKey(kvpair.Key))
+ newPrimalsInfo->invertInfoMap[env->mapOldValToNew[kvpair.Key]] = kvpair.Value.applyMap(env);
+
+ return newPrimalsInfo;
+ }
+
+ // Adds everything in `info` to this object. For invertInfoMap, an entry from
+ // `info` overwrites an existing entry with the same key.
+ void merge(HoistedPrimalsInfo* info)
+ {
+ for (auto inst : info->storeSet)
+ storeSet.Add(inst);
+
+ for (auto inst : info->recomputeSet)
+ recomputeSet.Add(inst);
+
+ for (auto inst : info->invertSet)
+ invertSet.Add(inst);
+
+ for (auto inst : info->instsToInvert)
+ instsToInvert.Add(inst);
+
+ for (auto kvpair : info->invertInfoMap)
+ invertInfoMap[kvpair.Key] = kvpair.Value;
+ }
+ };
+
+ // Answer returned by a checkpoint policy for a single use: store the value,
+ // recompute it, or obtain it by inverting another inst. Exactly one of
+ // instToStore / instToRecompute / inversionInfo is filled in, selected by `mode`.
+ struct HoistResult
+ {
+ enum Mode
+ {
+ Store,
+ Recompute,
+ Invert,
+
+ None
+ };
+
+ Mode mode;
+
+ IRInst* instToStore = nullptr;
+ IRInst* instToRecompute = nullptr;
+ InversionInfo inversionInfo;
+
+ // Constructor for the Store and Recompute modes; `target` is the inst concerned.
+ // Any other mode is reported through SLANG_UNEXPECTED.
+ HoistResult(Mode mode, IRInst* target) :
+ mode(mode)
+ {
+ switch (mode)
+ {
+ case Mode::Store:
+ instToStore = target;
+ break;
+ case Mode::Recompute:
+ instToRecompute = target;
+ break;
+ case Mode::Invert:
+ SLANG_UNEXPECTED("Wrong constructor for HoistResult::Mode::Invert");
+ break;
+ default:
+ SLANG_UNEXPECTED("Unhandled hoist mode");
+ break;
+ }
+ }
+
+ // Constructor for the Invert mode.
+ HoistResult(InversionInfo info) :
+ mode(Mode::Invert), inversionInfo(info)
+ { }
+
+ // Convenience factories, one per supported mode.
+ static HoistResult store(IRInst* inst)
+ {
+ return HoistResult(Mode::Store, inst);
+ }
+
+ static HoistResult recompute(IRInst* inst)
+ {
+ return HoistResult(Mode::Recompute, inst);
+ }
+
+ static HoistResult invert(InversionInfo inst)
+ {
+ return HoistResult(inst);
+ }
+ };
+
+
+ // Information on which insts are to be stored, recomputed
+ // and inverted within a single function.
+ // This data structure also holds a map from each inst to invert to the
+ // InversionInfo returned by the policy, to provide more information to later passes.
+ //
+ struct CheckpointSetInfo : public RefObject
+ {
+ HashSet storeSet;
+ HashSet recomputeSet;
+ HashSet invertSet;
+
+ // Keyed by the inst to invert.
+ Dictionary invInfoMap;
+ };
+
+ // Block correspondence produced by the unzip step and consumed by the hoisting pass.
+ struct BlockSplitInfo : public RefObject
+ {
+ // Maps primal to differential blocks from the unzip step.
+ Dictionary diffBlockMap;
+ };
+
+ // Base class for checkpointing policies. A policy answers, per use, whether the used
+ // value should be stored, recomputed or inverted (classify); processFunc() drives
+ // the classification over a whole function.
+ class AutodiffCheckpointPolicyBase : public RefObject
+ {
+ public:
+
+ AutodiffCheckpointPolicyBase(IRModule* module) : module(module)
+ { }
+
+ // Classifies the uses in `func` via classify() and applies the resulting sets.
+ RefPtr processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* info);
+
+ // Do pre-processing on the function (mainly for
+ // 'global' checkpointing methods that consider the entire
+ // function)
+ //
+ virtual void preparePolicy(IRGlobalValueWithCode* func) = 0;
+
+ // Returns the decision for the value referenced by `diffBlockUse`.
+ virtual HoistResult classify(IRUse* diffBlockUse) = 0;
+
+ protected:
+
+ // Module supplied at construction.
+ IRModule* module;
+ };
+
+    // Default policy: stores every value whose type can be stored and recomputes the
+    // rest. It needs no per-function preparation.
+    class DefaultCheckpointPolicy : public AutodiffCheckpointPolicyBase
+    {
+    public:
+
+        DefaultCheckpointPolicy(IRModule* module)
+            : AutodiffCheckpointPolicyBase(module)
+        { }
+
+        // `override` lets the compiler verify that these really implement the pure
+        // virtuals declared in AutodiffCheckpointPolicyBase.
+        virtual void preparePolicy(IRGlobalValueWithCode* func) override;
+        virtual HoistResult classify(IRUse* use) override;
+    };
+
+ // Applies the store / recompute / invert sets in `checkpointInfo` to `func`, using
+ // `splitInfo` to find the block that receives cloned insts. `pendingUses` are uses
+ // to be redirected to clones. Returns the resulting hoisted-primals description.
+ RefPtr applyCheckpointSet(
+ CheckpointSetInfo* checkpointInfo,
+ IRGlobalValueWithCode* func,
+ BlockSplitInfo* splitInfo,
+ HashSet pendingUses);
+
+ // Rewrites `func` so that each inst in hoistInfo->storeSet can be accessed from the
+ // blocks that use it, routing out-of-scope uses through local variables indexed
+ // according to `indexedBlockInfo`. Updates and returns `hoistInfo`.
+ RefPtr ensurePrimalAvailability(
+ HoistedPrimalsInfo* hoistInfo,
+ IRGlobalValueWithCode* func,
+ Dictionary> indexedBlockInfo);
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-region.cpp b/source/slang/slang-ir-autodiff-region.cpp
new file mode 100644
index 000000000..98b64f179
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.cpp
@@ -0,0 +1,56 @@
+// slang-ir-autodiff-region.cpp
+#include "slang-ir-autodiff-region.h"
+
+namespace Slang{
+ // Builds the block -> region map for `func`.
+ // The first block is mapped to the null region. Blocks are then visited via a work
+ // list: a block ending in a loop terminator opens a new region (child of the current
+ // one) for the loop's target block, while the false successor of that target block's
+ // terminator is mapped back to the current region. Every other not-yet-mapped
+ // successor inherits the region of the block it was reached from.
+ RefPtr buildIndexedRegionMap(IRGlobalValueWithCode* func)
+ {
+ RefPtr regionMap = new IndexedRegionMap;
+
+ List workList;
+
+ regionMap->mapBlock(func->getFirstBlock(), nullptr);
+ workList.add(func->getFirstBlock());
+
+ while (workList.getCount() > 0)
+ {
+ auto currentBlock = workList.getLast();
+ workList.removeLast();
+
+ auto terminator = currentBlock->getTerminator();
+ auto currentRegion = regionMap->getRegion(currentBlock);
+
+ switch (terminator->getOp())
+ {
+ case kIROp_loop:
+ {
+ auto loopRegion = regionMap->newRegion(as(terminator), currentRegion);
+ auto condBlock = as(terminator)->getTargetBlock();
+
+ regionMap->mapBlock(condBlock, loopRegion);
+ workList.add(condBlock);
+
+ // The loop's target block is required to end in the terminator cast below.
+ auto ifElse = as(condBlock->getTerminator());
+ SLANG_RELEASE_ASSERT(ifElse);
+
+ // TODO: this is one of the places we'll need to change if we support loops that
+ // loop on either the true or false side. For now, we assume the loop is on the
+ // true side only.
+ //
+ regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion);
+ workList.add(ifElse->getFalseBlock());
+ }
+ }
+
+ // Blocks mapped above (loop target, false block) are skipped here because they
+ // already have a mapping.
+ for (auto successor : currentBlock->getSuccessors())
+ {
+ // If already mapped, skip.
+ if (regionMap->hasMapping(successor))
+ continue;
+ regionMap->mapBlock(successor, currentRegion);
+ workList.add(successor);
+ }
+ }
+
+ return regionMap;
+ }
+};
diff --git a/source/slang/slang-ir-autodiff-region.h b/source/slang/slang-ir-autodiff-region.h
new file mode 100644
index 000000000..a4618e257
--- /dev/null
+++ b/source/slang/slang-ir-autodiff-region.h
@@ -0,0 +1,119 @@
+// slang-ir-autodiff-region.h
+#pragma once
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-autodiff.h"
+
+namespace Slang
+{
+// A region of blocks associated with one loop inst, linked to the region that encloses
+// it (`parent`, null for an outermost loop region).
+struct IndexedRegion : public RefObject
+{
+ IRLoop* loop;
+ IndexedRegion* parent;
+
+ IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
+ { }
+
+ // Block that contains the loop inst.
+ IRBlock* getInitializerBlock() { return as(loop->getParent()); }
+ // Target block of the loop inst; its terminator must match the cast asserted below.
+ IRBlock* getConditionBlock()
+ {
+ auto condBlock = as(loop->getTargetBlock());
+ SLANG_RELEASE_ASSERT(as(condBlock->getTerminator()));
+ return condBlock;
+ }
+
+ IRBlock* getBreakBlock() { return loop->getBreakBlock(); }
+
+ // Returns a predecessor of the condition block other than the initializer block.
+ // If several such predecessors exist, the last one enumerated is returned.
+ IRBlock* getUpdateBlock()
+ {
+ auto initBlock = getInitializerBlock();
+
+ auto condBlock = getConditionBlock();
+
+ IRBlock* lastLoopBlock = nullptr;
+
+ for (auto predecessor : condBlock->getPredecessors())
+ {
+ if (predecessor != initBlock)
+ lastLoopBlock = predecessor;
+ }
+
+ // Should find at least one predecessor that is _not_ the
+ // init block (that contains the loop info). This
+ // predecessor would be the last block in the loop
+ // before looping back to the condition.
+ //
+ SLANG_RELEASE_ASSERT(lastLoopBlock);
+
+ return lastLoopBlock;
+ }
+};
+
+// Per-region bookkeeping for iteration counters and the inferred iteration bound.
+struct IndexTrackingInfo : public RefObject
+{
+ // After lowering, store references to the count
+ // variables associated with this region
+ //
+ IRInst* primalCountParam = nullptr;
+ IRInst* diffCountParam = nullptr;
+
+ // Variable that is loaded to obtain the primal counter when indexing from outside
+ // the region.
+ IRVar* primalCountLastVar = nullptr;
+
+ enum CountStatus
+ {
+ Unresolved,
+ Dynamic,
+ Static
+ };
+
+ CountStatus status = CountStatus::Unresolved;
+
+ // Inferred maximum number of iterations.
+ // -1 until a bound has been determined.
+ Count maxIters = -1;
+};
+
+// Associates each block with the region it belongs to, and keeps the list of all
+// regions created through newRegion().
+struct IndexedRegionMap : public RefObject
+{
+ Dictionary map;
+ List> regions;
+
+ // Creates a region for `loop` under `parent`, records it in `regions`, and returns it.
+ IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
+ {
+ auto region = new IndexedRegion(loop, parent);
+ regions.add(region);
+
+ return region;
+ }
+
+ // Records `region` (possibly null) as the region of `block`.
+ void mapBlock(IRBlock* block, IndexedRegion* region)
+ {
+ map.Add(block, region);
+ }
+
+ bool hasMapping(IRBlock* block)
+ {
+ return map.ContainsKey(block);
+ }
+
+ IndexedRegion* getRegion(IRBlock* block)
+ {
+ return map[block];
+ }
+
+ // Returns the block's own region followed by each enclosing region, innermost
+ // first. The list is empty when the block maps to the null region.
+ List getAllAncestorRegions(IRBlock* block)
+ {
+ List regionList;
+
+ IndexedRegion* region = getRegion(block);
+ for (; region; region = region->parent)
+ regionList.add(region);
+
+ return regionList;
+ }
+};
+
+// Computes the block -> region map for `func`.
+RefPtr buildIndexedRegionMap(IRGlobalValueWithCode* func);
+
+
+};
\ No newline at end of file
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 328af4867..157011b7c 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -606,106 +606,6 @@ namespace Slang
return fwdDiffFunc;
}
- void BackwardDiffTranscriberBase::insertVariableForRecomputedPrimalInsts(IRFunc* diffPropFunc)
- {
- RefPtr domTree = computeDominatorTree(diffPropFunc);
- auto firstBlock = diffPropFunc->getFirstBlock();
- if (!firstBlock)
- return;
- Dictionary instVars;
- Dictionary cloneEnvs;
- auto storeInstAsLocalVar = [&](IRInst* inst)
- {
- IRVar* var = nullptr;
- if (instVars.TryGetValue(inst, var))
- return var;
- IRBuilder builder(diffPropFunc);
- builder.setInsertBefore(firstBlock->getFirstOrdinaryInst());
- var = builder.emitVar(inst->getDataType());
- builder.emitStore(var, builder.emitDefaultConstruct(inst->getDataType()));
-
- setInsertAfterOrdinaryInst(&builder, inst);
- builder.emitStore(var, inst);
- instVars[inst] = var;
- return var;
- };
-
- IRBuilder builder(diffPropFunc);
- List workList;
- for (auto block : diffPropFunc->getBlocks())
- {
- if (!block->findDecoration())
- continue;
- cloneEnvs[block] = IRCloneEnv();
- for (auto inst : block->getChildren())
- {
- workList.add(inst);
- }
- }
-
- for (Index i = 0; i < workList.getCount(); i++)
- {
- auto inst = workList[i];
- for (UInt j = 0; j < inst->getOperandCount(); j++)
- {
- auto operand = inst->getOperand(j);
- if (operand->getOp() == kIROp_Block)
- continue;
- auto operandParent = inst->getOperand(j)->getParent();
- if (!operandParent)
- continue;
- if (operandParent->parent != diffPropFunc)
- continue;
- if (domTree->dominates(operandParent, inst->parent))
- continue;
-
- // The def site of the operand does not dominate the use.
- // We need to insert a local variable to store this var.
-
- IRInst* operandReplacement = nullptr;
- if (canTypeBeStored(operand->getDataType()))
- {
- auto var = storeInstAsLocalVar(operand);
- builder.setInsertBefore(inst);
- operandReplacement = builder.emitLoad(var);
- }
- else if (operand->getOp() == kIROp_Var)
- {
- // Var can just be hoisted to first block.
- operand->insertBefore(firstBlock->getFirstOrdinaryInst());
- }
- else
- {
- // For all other insts, we need to copy it to right before this inst.
- // Before actually copying it, check if we have already copied it to
- // any blocks that dominates this block.
- auto dom = as(inst->getParent());
- while (dom)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(dom);
- if (!subCloneEnv) break;
- if (subCloneEnv->mapOldValToNew.TryGetValue(operand, operandReplacement))
- {
- break;
- }
- dom = domTree->getImmediateDominator(dom);
- }
- // We have not found an existing clone in dominators, so we need to copy it
- // to this block.
- if (!operandReplacement)
- {
- auto subCloneEnv = cloneEnvs.TryGetValue(as(inst->getParent()));
- builder.setInsertBefore(inst);
- operandReplacement = cloneInst(subCloneEnv, &builder, operand);
- workList.add(operandReplacement);
- }
- }
- if (operandReplacement)
- builder.replaceOperand(inst->getOperands() + j, operandReplacement);
- }
- }
- }
-
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType)
{
SLANG_UNUSED(primalType);
@@ -774,7 +674,7 @@ namespace Slang
// Copy primal insts to the first block of the unzipped function, copy diff insts to the
// second block of the unzipped function.
//
- diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ RefPtr primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
@@ -801,8 +701,8 @@ namespace Slang
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
// derivative of the return value.
- DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam };
- diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info);
+ DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo };
+ diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);
eliminateDeadCode(diffPropagateFunc);
@@ -810,7 +710,7 @@ namespace Slang
// with the intermediate results computed from the extracted func.
IRInst* intermediateType = nullptr;
auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
- diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType);
+ diffPropagateFunc, primalFunc, primalsInfo, paramTransposeInfo, intermediateType);
// At this point the unzipped func is just an empty shell
// and we can simply remove it.
@@ -870,7 +770,7 @@ namespace Slang
initializeLocalVariables(builder->getModule(), as(getGenericReturnVal(primalFuncGeneric)));
initializeLocalVariables(builder->getModule(), diffPropagateFunc);
- insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
+ // insertVariableForRecomputedPrimalInsts(diffPropagateFunc);
stripTempDecorations(diffPropagateFunc);
}
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index a92978817..e59f27881 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -8,6 +8,8 @@
#include "slang-ir-autodiff.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-cfg-norm.h"
+#include "slang-ir-autodiff-primal-hoist.h"
+#include "slang-ir-dominators.h"
namespace Slang
{
@@ -78,6 +80,11 @@ struct DiffTransposePass
// of the *output* of the function.
//
IRInst* dOutInst;
+
+ // Information from the unzip pass on how primal insts
+ // are split across the primal and differential blocks.
+ //
+ HoistedPrimalsInfo* hoistedPrimalsInfo;
};
struct PendingBlockTerminatorEntry
@@ -227,15 +234,17 @@ struct DiffTransposePass
IRBlock* revAfterBlock = revBlockMap[currentBlock];
builder.setInsertInto(revCondBlock);
-
- hoistPrimalInst(&builder, ifElse->getCondition());
- builder.emitIfElse(
+ //hoistPrimalInst(&builder, ifElse->getCondition());
+
+ auto newIfElse = builder.emitIfElse(
ifElse->getCondition(),
revTrueEntryBlock,
revFalseEntryBlock,
revAfterBlock);
+ hoistPrimalOperands(&builder, newIfElse);
+
if (!revTrueRegionInfo.isTrivial)
{
builder.setInsertInto(revTrueExitBlock);
@@ -348,14 +357,18 @@ struct DiffTransposePass
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
- hoistPrimalInst(&builder, ifElse->getCondition());
- builder.emitIfElse(
+ // TODO: Need to defer this until after the CFG reversal is complete.
+ //hoistPrimalInst(&builder, ifElse->getCondition());
+
+ auto newIfElse = builder.emitIfElse(
ifElse->getCondition(),
revTrueBlock,
revFalseBlock,
revTrueBlock);
+ hoistPrimalOperands(&builder, newIfElse);
+
// Old false-side starting block becomes end block
// for the new pre-cond region (which could be empty)
//
@@ -364,12 +377,13 @@ struct DiffTransposePass
{
IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
builder.setInsertInto(revPreCondEndBlock);
- builder.emitLoop(
+ auto revLoop = builder.emitLoop(
revCondBlock,
revBreakBlock,
revLoopEndBlock,
getPhiGrads(falseBlock).getCount(),
getPhiGrads(falseBlock).getBuffer());
+ loop->transferDecorationsTo(revLoop);
auto revLoopStartBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopStartBlock);
@@ -383,12 +397,13 @@ struct DiffTransposePass
// Emit loop into rev-version of the break block.
auto revLoopBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopBlock);
- builder.emitLoop(
+ auto revLoop = builder.emitLoop(
revPreCondBlock,
revBreakBlock,
revLoopEndBlock,
getPhiGrads(breakBlock).getCount(),
getPhiGrads(breakBlock).getBuffer());
+ loop->transferDecorationsTo(revLoop);
}
currentBlock = breakBlock;
@@ -463,14 +478,16 @@ struct DiffTransposePass
builder.setInsertInto(revSwitchBlock);
- hoistPrimalInst(&builder, switchInst->getCondition());
+ // hoistPrimalInst(&builder, switchInst->getCondition());
- builder.emitSwitch(
+ auto newSwitchInst = builder.emitSwitch(
switchInst->getCondition(),
revBreakBlock,
revDefaultRegionEntry,
reverseSwitchArgs.getCount(),
reverseSwitchArgs.getBuffer());
+
+ hoistPrimalOperands(&builder, newSwitchInst);
currentBlock = breakBlock;
break;
@@ -504,6 +521,13 @@ struct DiffTransposePass
IRFunc* revDiffFunc,
FuncTranspositionInfo transposeInfo)
{
+ // TODO (sai): We really need to make this method stateless
+ // (i.e. not store per-func info in 'this')
+ // since it is reused for every reverse-mode call.
+ //
+
+ hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo;
+
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
@@ -586,6 +610,9 @@ struct DiffTransposePass
}
}
+ // Make a temporary block to hold inverted insts.
+ tempInvBlock = builder.createBlock();
+
for (auto block : workList)
{
// Set dOutParameter as the transpose gradient for the return inst, if any.
@@ -620,7 +647,6 @@ struct DiffTransposePass
{
if (auto loopInst = as(block->getTerminator()))
{
- lowerLoopExitValues(&builder, loopInst);
invertLoopCondition(&builder, loopInst);
}
}
@@ -655,6 +681,50 @@ struct DiffTransposePass
subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal);
}
+ // TODO: Should move this to before all the transposition, but a lot of the
+ // transposition logic seems to access the parent of blocks to find the func.
+ // Replace those uses.
+ //
+ for (auto block : workList)
+ block->removeFromParent();
+
+ // Mark all primal operands for hoisting.
+ // TODO: Can we just merge this with finishHoistingPrimals?
+ // TODO: Some of this logic is replicated in finishHoistingPrimals. Merge it with the
+ // maybeAddPrimalOperandsToWorkList logic there.
+ //
+ for (auto block : workList)
+ {
+ IRBlock* revBlock = revBlockMap[block];
+
+ for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst())
+ {
+ hoistPrimalOperands(&builder, child);
+
+ for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration())
+ {
+ if (auto contextDecoration = as(decoration))
+ hoistPrimalUse(&builder, &contextDecoration->primalContextVar);
+
+ if (auto loopExitDecoration = as(decoration))
+ hoistPrimalUse(&builder, &loopExitDecoration->exitVal);
+ }
+
+ if (auto instType = child->getDataType())
+ if (!as(instType->getParent()))
+ hoistPrimalUse(&builder, &child->typeUse);
+ }
+ }
+
+ finishHoistingPrimals(revDiffFunc);
+
+ for (auto block : workList)
+ {
+ auto revBlock = as(revBlockMap[block]);
+ if (auto revLoop = as(revBlock->getTerminator()))
+ lowerLoopExitValues(&builder, revLoop);
+ }
+
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -723,7 +793,24 @@ struct DiffTransposePass
return tempRevVar;
}
+ IRVar* lookupInverseVar(IRInst* inst) // Looks up the inverse var recorded for `inst`.
+ {
+ return inverseVarMap[inst]; // plain inverseVarMap lookup; no presence check is done here
+ }
+
+ IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func) // Overload: the var block is looked up by function.
+ {
+ IRBlock* varBlock = firstRevDiffBlockMap[func]; // block recorded for `func` in firstRevDiffBlockMap
+ return getOrCreateInverseVar(primalInst, varBlock); // forwards to the (inst, block) overload
+ }
+
IRVar* getOrCreateInverseVar(IRInst* primalInst)
+ {
+ IRBlock* varBlock = firstRevDiffBlockMap[as(primalInst->getParent()->getParent())]; // key: the parent of primalInst's parent (neither is null-checked here)
+ return getOrCreateInverseVar(primalInst, varBlock); // forwards to the (inst, block) overload
+ }
+
+ IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock)
{
// No need to store inverse values for constants.
if (as(primalInst))
@@ -734,13 +821,11 @@ struct DiffTransposePass
return inverseVarMap[primalInst];
IRBuilder tempVarBuilder(autodiffContext->moduleInst);
-
- IRBlock* firstDiffBlock = firstRevDiffBlockMap[as(primalInst->getParent()->getParent())];
- if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst())
+ if (auto firstInst = varBlock->getFirstOrdinaryInst())
tempVarBuilder.setInsertBefore(firstInst);
else
- tempVarBuilder.setInsertInto(firstDiffBlock);
+ tempVarBuilder.setInsertInto(varBlock);
auto primalType = primalInst->getDataType();
@@ -766,6 +851,19 @@ struct DiffTransposePass
return false;
}
+ IRParam* getParamAt(IRBlock* block, UIndex ii) // Returns the ii-th (0-based) param of `block`.
+ {
+ UIndex index = 0;
+ for (auto param : block->getParams()) // linear walk over the block's params
+ {
+ if (ii == index)
+ return param;
+
+ index ++;
+ }
+ SLANG_UNEXPECTED("ii >= paramCount"); // reached only when the walk ends without hitting index ii
+ }
+
void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock)
{
IRBuilder builder(autodiffContext->moduleInst);
@@ -773,6 +871,10 @@ struct DiffTransposePass
// Insert into our reverse block.
builder.setInsertInto(revBlock);
+ // Create an inverse builder to insert insts into the inv-block.
+ IRBuilder invBuilder(autodiffContext->moduleInst);
+
+
// Check if this block has any 'outputs' (in the form of phi args
// sent to the successor block)
//
@@ -798,25 +900,43 @@ struct DiffTransposePass
revParam,
nullptr));
}
- else if (isPrimalInst(arg))
+ else if (hasInverse(arg))
{
- // If the output arg is a primal, emit a parameter
- // to accept it as an _input_ for the reverse-mode
- //
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
+ InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst];
+ if (invInfo.targetInsts.contains(arg))
+ {
+ SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii)));
+
+ // If the output arg is a primal, emit a parameter
+ // to accept it as an _input_ for the reverse-mode
+ //
+ auto primalType = arg->getDataType();
+ auto primalInvParam = builder.emitParam(primalType);
- setInverse(&builder, arg, primalInvParam);
+ invBuilder.setInsertBefore(branchInst);
+ setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
+ }
}
else
{
- SLANG_UNEXPECTED("Encountered inst not marked as primal or differential");
+ if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii)))
+ {
+ auto primalType = arg->getDataType();
+ auto primalInvParam = builder.emitParam(primalType);
+
+ invBuilder.setInsertBefore(branchInst);
+ setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
+ }
}
}
}
// Move pointer & reference insts to the top of the reverse-mode block.
- List nonValueInsts;
+ List typeInsts;
for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
// If the instruction is a variable allocation (or reverse-gradient pair reference),
@@ -824,17 +944,17 @@ struct DiffTransposePass
// TODO: This is hacky.. Need a more principled way to handle this
// (like primal inst hoisting)
//
- if (as(child) || as(child))
- nonValueInsts.add(child);
+ //if (as(child) || as(child))
+ // nonValueInsts.add(child);
// Slang doesn't support function values. So if we see a func-typed inst
// it's proabably a reference to a function.
//
if (as(child->getDataType()))
- nonValueInsts.add(child);
+ typeInsts.add(child);
}
- for (auto inst : nonValueInsts)
+ for (auto inst : typeInsts)
{
inst->insertAtEnd(revBlock);
}
@@ -846,9 +966,6 @@ struct DiffTransposePass
//
for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst())
{
- if (child->findDecoration())
- continue;
-
if (as(child) || as(child))
continue;
if (as(child))
@@ -856,8 +973,15 @@ struct DiffTransposePass
if (isDifferentialInst(child))
transposeInst(&builder, child);
- else if (isPrimalInst(child))
- invertInst(&builder, child);
+ else if (shouldInstBeInverted(child))
+ {
+ // We'll collect inverse insts in an orphaned block,
+ // so disable IR validation temporarily.
+ //
+ disableIRValidationAtInsert();
+ invertInst(&invBuilder, child);
+ enableIRValidationAtInsert();
+ }
}
// After processing the block's instructions, we 'flush' any remaining gradients
@@ -901,23 +1025,18 @@ struct DiffTransposePass
phiParamRevGradInsts.add(gradInst);
}
else
- {
+ {
phiParamRevGradInsts.add(
emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
}
}
- else if (isPrimalInst(param))
+ else if (hasInverse(param))
{
- if (hasInverse(param))
- phiParamRevGradInsts.add(getInverse(&builder, param));
- else
- {
- SLANG_UNEXPECTED("param is a primal inst but has no registered inverse");
- }
+ phiParamRevGradInsts.add(param);
}
else
- {
- SLANG_UNEXPECTED("param is neither differential nor primal");
+ {
+ SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion");
}
}
@@ -995,8 +1114,15 @@ struct DiffTransposePass
{ }
};
- List invertArithmetic(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ List invertArithmetic(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
{
+ SLANG_RELEASE_ASSERT(invInfo.requiredOperands.getCount() == 1);
+ SLANG_RELEASE_ASSERT(invInfo.targetInsts.getCount() == 1);
+
+ auto invOutput = invInfo.requiredOperands[0];
+
+ auto invTargetInst = invInfo.targetInsts[0];
+
switch (primalInst->getOp())
{
case kIROp_Add:
@@ -1004,7 +1130,7 @@ struct DiffTransposePass
SLANG_RELEASE_ASSERT(as(primalInst->getOperand(1)));
return List(
InvInstPair(
- primalInst->getOperand(0),
+ invTargetInst,
builder->emitSub(
primalInst->getOperand(0)->getDataType(),
invOutput,
@@ -1015,7 +1141,7 @@ struct DiffTransposePass
SLANG_RELEASE_ASSERT(as(primalInst->getOperand(1)));
return List(
InvInstPair(
- primalInst->getOperand(0),
+ invTargetInst,
builder->emitAdd(
primalInst->getOperand(0)->getDataType(),
invOutput,
@@ -1027,24 +1153,38 @@ struct DiffTransposePass
}
}
- void lowerLoopExitValues(IRBuilder* builder, IRLoop* fwdLoop)
+ // NOTE: This is a workaround for the fact that we expect inverses to use
+ // single-use variables. The loop exit value will add a
+ // second store to most inv-variables and mess with the primal hoisting mechanism.
+ // Instead of emitting into the orphaned inverse block, we'll directly emit into
+ // the reverse-mode block since we'll be running this _after_ the primal hoisting
+ // pass.
+ //
+ // This workaround is fine for inverting loop counters, but when we want to
+ // expand to supporting general-purpose adjoints, we would want to use per-region
+ // inverse vars based on 'invInfo' (enforcing single-use vars)
+ //
+ void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop) // Turns each loop-exit-value decoration on revLoop into a setInverse() call, then deletes it.
 {
- for (auto decoration : fwdLoop->getDecorations())
+ List processedDecorations; // removal is deferred so the decoration list is not edited while iterating
+ for (auto decoration : revLoop->getDecorations())
 {
 if (auto loopExitValueDecoration = as(decoration))
 {
- IRBlock* revLoopInitBlock = revBlockMap[fwdLoop->getBreakBlock()];
-
- if (auto revLoopInst = revLoopInitBlock->getTerminator())
- builder->setInsertBefore(revLoopInst);
- else
- builder->setInsertInto(revLoopInitBlock);
-
- hoistPrimalInst(builder, loopExitValueDecoration->getLoopExitValInst());
+ builder->setInsertBefore(revLoop); // emit directly in front of the loop inst itself
+ setInverse(
+ builder,
+ nullptr, // defBlock argument: none is supplied for loop exit values
+ builder->getFunc(),
+ loopExitValueDecoration->getTargetInst(),
+ loopExitValueDecoration->getLoopExitValInst());
- setInverse(builder, loopExitValueDecoration->getTargetInst(), loopExitValueDecoration->getLoopExitValInst());
+ processedDecorations.add(loopExitValueDecoration);
 }
 }
+
+ for (auto decoration : processedDecorations) // second pass: drop the decorations handled above
+ decoration->removeAndDeallocate();
 }
void lowerLoopExitValues(IRBuilder* builder, IRBlock* block)
@@ -1094,19 +1234,21 @@ struct DiffTransposePass
kIROp_Neq,
2,
List(
- hoistPrimalInst(builder, loopCounterParam),
- hoistPrimalInst(builder, loopCounterInitVal)).getBuffer());
+ loopCounterParam,
+ loopCounterInitVal).getBuffer());
+
+ hoistPrimalOperands(builder, paramBoundsCheck);
as(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
}
- List invertInst(IRBuilder* builder, IRInst* primalInst, IRInst* invOutput)
+ List invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
{
switch (primalInst->getOp())
{
case kIROp_Add:
case kIROp_Sub:
- return invertArithmetic(builder, primalInst, invOutput);
+ return invertArithmetic(builder, primalInst, invInfo);
default:
SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion");
@@ -1115,70 +1257,392 @@ struct DiffTransposePass
bool hasInverse(IRInst* primalInst)
{
- if (getOrCreateInverseVar(primalInst))
- return true;
- else
- return false;
+ return this->hoistedPrimalsInfo->invertSet.Contains(primalInst); // pure membership test on invertSet; no inverse var is created here
}
- IRInst* getInverse(IRBuilder* builder, IRInst* primalInst)
+ IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst) // Emits a load of primalInst's inverse var; returns nullptr when no var is obtained.
{
// Note: There are other possible cases here, although not important
// right now. For example, a value is available to load from the primal block.
//
- if (auto invVar = getOrCreateInverseVar(primalInst))
+
+ if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc())) // var block is chosen from the builder's current function
return builder->emitLoad(invVar);
return nullptr;
}
- void setInverse(IRBuilder* builder, IRInst* inst, IRInst* invInst)
+ IRInst* lookupInstInPrimalBlock(IRInst* invInst)
{
- if (auto invVar = getOrCreateInverseVar(inst))
- builder->emitStore(invVar, invInst);
+ // Lookup the inst in the primal block whose value we can use as an operand
+ // for the inverted inst.
+ // Currently a pass-through: the argument is returned unchanged.
+ // auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst];
+ return invInst;
}
- IRInst* hoistPrimalInst(IRBuilder* revBuilder, IRInst* inst)
+ void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst) // Stores invInst into inst's inverse var and records defBlock for that store.
 {
- if (as(inst->getParent()) &&
- isDifferentialInst(as(inst->getParent())))
+ auto instBlock = as(inst->getParent());
+ if (!instBlock) // the cast of inst's parent failed: nothing is stored
+ return;
+
+ disableIRValidationAtInsert(); // validation is switched off around the store emission
+ if (auto invVar = getOrCreateInverseVar(inst, func))
 {
- SLANG_RELEASE_ASSERT(isPrimalInst(inst));
+ auto invStore = builder->emitStore(invVar, invInst);
+ mapStoreToDefBlock[as(invStore)] = defBlock; // defBlock may be nullptr (lowerLoopExitValues passes nullptr)
 }
+ enableIRValidationAtInsert();
+ }
- // Are the operands of this primal inst also available in the reverse-mode context?
- // If not, move/load them.
+ bool shouldInstBeInverted(IRInst* inst) // True iff `inst` is in hoistedPrimalsInfo->instsToInvert.
+ {
+
+ if (this->hoistedPrimalsInfo->instsToInvert.Contains(inst)) // note: a different set from the invertSet that hasInverse() checks
+ return true;
+
+ return false;
+ }
+
+ IRInst* hoistPrimalUse(IRBuilder*, IRUse* use) // Queues `use` for later hoisting; the builder argument is unnamed and unused.
+ {
+ primalUsesToHoist.add(use); // the queue is consumed and cleared by finishHoistingPrimals()
+ return use->get(); // the used value is returned as-is, nothing is moved yet
+ }
+
+ bool doesInstRequireHoisting(IRInst* inst) // Decides whether `inst` still has to be moved into a reverse-mode block.
+ {
+ if (as(inst->getParent())) // early-outs below: the cast target types are not visible in this view
+ return false;
+
+ if (as(inst) ||
+ as(inst) ||
+ as(inst))
+ return false;
+
+ if (as(inst))
+ return false;
+
+ if (as(inst))
+ return doesInstRequireHoisting(getInstInBlock(inst)); // defer the answer to the inst returned by getInstInBlock()
+
+ // We're looking for primal insts in differential blocks
+ // that have not yet been moved to the 'active' blocks
+ // (i.e in diff blocks that do not have parents)
 //
- hoistPrimalOperands(revBuilder, inst);
+ return (!isDifferentialInst(inst) && // the inst itself is not differential
+ (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) && // lives in a differential block or in tempInvBlock
+ getBlock(inst)->getParent() == nullptr); // and that block has no parent
+ }
- if (isPrimalInst(inst) &&
- as(inst->getParent()) &&
- isDifferentialInst(as(inst->getParent())))
+ // Builds a map from each required operand of an inversion entry to the target insts of the entries that require it.
+ Dictionary> buildInvOperandMap()
+ {
+ Dictionary> invOperandMap;
+ for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap) // one InversionInfo per entry
 {
- if (!inst->findDecoration())
+ InversionInfo invInfo = kvpair.Value;
+
+ for (auto operand : invInfo.requiredOperands)
 {
- return getInverse(revBuilder, inst);
+ if (!invOperandMap.ContainsKey(operand)) // first time this operand is seen: start with an empty list
+ invOperandMap[operand] = List();
+
+ for (auto target : invInfo.targetInsts) // every target of this entry is recorded under the operand
+ invOperandMap[operand].GetValue().add(target);
 }
- else
+ }
+
+ return invOperandMap;
+ }
+
+ IRBlock* walkToEndOfRegion(IRBlock* block) // Follows terminators forward from `block` and returns the block where the walk stops.
+ {
+ IRBlock* currBlock = block;
+
+ bool keepGoing = true;
+ while (keepGoing)
+ {
+ auto terminator = currBlock->getTerminator();
+ switch (terminator->getOp()) // NOTE(review): no default case; an op not listed below changes neither currBlock nor keepGoing
 {
- auto block = as(inst->getParent());
- SLANG_RELEASE_ASSERT(block);
+ case kIROp_Return: // a returning block ends the walk and is itself the result
+ keepGoing = false;
+ break;
- if (block == revBuilder->getBlock())
+ case kIROp_unconditionalBranch:
 {
- // Already in block..
- return inst;
- }
+ auto nextBlock = as(terminator)->getTargetBlock();
- // Otherwise, move our inst to the the current builder location.
- inst->removeFromParent();
- revBuilder->addInst(inst);
+ HashSet predecessorSet; // set, so a predecessor listed more than once is counted once
+ for (auto predecessor : nextBlock->getPredecessors())
+ predecessorSet.Add(predecessor);
- return inst;
+ if (predecessorSet.Count() > 1) // target has several distinct predecessors: stop on the current block
+ {
+ keepGoing = false;
+ break;
+ }
+
+ currBlock = nextBlock; // single predecessor: keep walking through the target
+ break;
+ }
+
+ case kIROp_ifElse:
+ {
+ for (auto predecessor : currBlock->getPredecessors())
+ {
+ if (as(predecessor->getTerminator())) // cast target type is not visible in this view
+ {
+ keepGoing = false;
+ break; // NOTE(review): exits only this for loop; currBlock is still set to the after-block below
+ }
+ }
+
+ currBlock = as(terminator)->getAfterBlock(); // continue (or finish) at the if-else's after block
+ break;
+ }
+
+ case kIROp_Switch: // skip over the whole switch to its break label
+ currBlock = as(terminator)->getBreakLabel();
+ break;
+
+ case kIROp_loop: // skip over the whole loop to its break block
+ currBlock = as(terminator)->getBreakBlock();
+ break;
 }
 }
- return inst;
+ return currBlock;
+ }
+
+ void finishHoistingPrimals(IRGlobalValueWithCode* func) // Drains the uses queued by hoistPrimalUse() and moves the insts they reference into reverse-mode blocks.
+ {
+ List workList; // used as a stack: items are taken from the end
+
+ Dictionary hoistedInstMap; // insts already handled (invertible insts are recorded with a nullptr value)
+
+ RefPtr domTree = computeDominatorTree(func); // NOTE(review): not referenced again in this function
+
+ Dictionary> invOperandMap = buildInvOperandMap();
+
+ auto varBlock = func->getFirstBlock()->getNextBlock(); // second block of func; hoisted vars are placed here
+
+ // Load up pending insts into workList.
+ for (auto use : primalUsesToHoist)
+ workList.add(use->get());
+
+ primalUsesToHoist.clear();
+
+
+ auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst) // queues operands (and the data type) of inst that still need hoisting
+ {
+ UIndex opIndex = 0;
+ for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
+ {
+ if (doesInstRequireHoisting(operand->get()) &&
+ !hoistedInstMap.ContainsKey(operand->get()))
+ {
+ workList.add(operand->get());
+ }
+ }
+
+ if (auto instType = inst->getDataType())
+ {
+ if (doesInstRequireHoisting(instType) &&
+ !hoistedInstMap.ContainsKey(instType))
+ workList.add(instType);
+ }
+ };
+
+ auto maybeAddUsersToWorkList = [&](IRInst* inst) // queues the pending users of inst
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (doesInstRequireHoisting(use->getUser()))
+ {
+ if (as(inst) && as(use->getUser()))
+ continue;
+
+ // Uses that haven't already been hoisted into reverse-mode
+ // blocks, and are not in the invert-set are pending uses.
+ //
+ if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
+ workList.add(use->getUser());
+ }
+ }
+ };
+
+ auto doesInstHavePendingUsers = [&](IRInst* inst) // same filter as maybeAddUsersToWorkList, but only reports whether one exists
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (doesInstRequireHoisting(use->getUser()))
+ {
+ if (as(inst) && as(use->getUser()))
+ continue;
+
+ // Users that haven't already been hoisted into reverse-mode
+ // blocks are pending users.
+ // (users that are in the invert-set are not counted)
+ if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
+ return true;
+ }
+ }
+
+ return false;
+ };
+
+ auto isInstHoisted = [&](IRInst* inst) // true when inst sits in a differential block that has a parent
+ {
+ return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst));
+ };
+
+ while (workList.getCount() > 0)
+ {
+ // Pop work item
+ auto inst = workList.getLast();
+ workList.removeLast();
+
+ // Already recorded in hoistedInstMap: nothing is replaced
+ // here, the item is simply skipped
+ // (the original note says this should never be hit).
+ //
+ if (hoistedInstMap.ContainsKey(inst))
+ continue;
+
+ if (invOperandMap.ContainsKey(inst)) // inst is a required operand of some inversion
+ {
+ List pendingInvDependencies;
+ for (auto dependency : invOperandMap[inst].GetValue())
+ {
+ if (doesInstRequireHoisting(dependency) &&
+ !hoistedInstMap.ContainsKey(dependency))
+ pendingInvDependencies.add(dependency);
+ }
+
+ if (pendingInvDependencies.getCount() > 0)
+ {
+ workList.add(inst); // re-queue inst underneath its dependencies
+ for (auto dependency : pendingInvDependencies)
+ workList.add(dependency);
+
+ // Skip until all the dependencies have been handled.
+ continue;
+ }
+ }
+
+ // Are the uses of this primal inst already hoisted into the reverse-mode
+ // blocks? We cannot hoist this inst unless the uses are hoisted.
+ //
+ if (doesInstHavePendingUsers(inst))
+ {
+ // Add inst back to work list.
+ workList.add(inst);
+
+ // Then, add all the pending users to the top of the
+ // list, ensuring they are processed before we see
+ // inst again.
+ //
+ maybeAddUsersToWorkList(inst);
+
+ continue;
+ }
+
+ // The used inst is marked for inversion, lookup and load
+ // an inverse.
+ //
+ if (this->hoistedPrimalsInfo->invertSet.Contains(inst))
+ {
+ // Replace with inverse.
+ IRBuilder builder(func->getModule());
+
+ for (auto use = inst->firstUse; use;)
+ {
+ auto nextUse = use->nextUse; // fetched up front because use->set() below rewrites this use
+
+ if (!isInstHoisted(use->getUser())) // only users already in reverse-mode blocks are rewritten
+ {
+ use = nextUse;
+ continue;
+ }
+
+ // TODO: Hacky workaround to prevent the 'key' being overwritten,
+ // avoid this by adding the decoration on the param instead of the loop
+ //
+ if (auto exitValDecoration = as(use->getUser()))
+ {
+ if (&exitValDecoration->target == use)
+ {
+ use = nextUse;
+ continue;
+ }
+ }
+
+
+ builder.setInsertBefore(getInstInBlock(use->getUser()));
+ use->set(loadInverse(&builder, inst)); // the use now reads a load of the inverse var
+
+ use = nextUse;
+ }
+
+ // If all uses of the invertible inst have been hoisted,
+ // add the inv-var to the worklist.
+ //
+ workList.add(lookupInverseVar(inst));
+ hoistedInstMap[inst] = nullptr; // recorded with no replacement inst
+
+ continue;
+ }
+
+ // Should not see an inst marked for inversion here.
+ SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst));
+
+ List relevantUses; // NOTE(review): not referenced again in this function
+
+ IRBlock* defBlock = nullptr;
+ if (auto varToHoist = as(inst))
+ {
+ varToHoist->insertBefore(varBlock->getFirstOrdinaryInst()); // the var itself goes to the top of varBlock
+ inst = findUniqueStoredVal(varToHoist)->getUser(); // NOTE(review): result is dereferenced before the assert below -- confirm it cannot be null
+ SLANG_ASSERT(inst);
+
+ defBlock = getBlock(inst);
+ }
+ else
+ {
+ defBlock = getBlock(inst);
+ }
+
+ if (!doesInstRequireHoisting(inst))
+ continue;
+
+ // Move this inst to after its diff uses.
+ //
+ {
+
+ IRBlock* currTopBlock = revBlockMap[walkToEndOfRegion(defBlock)]; // reverse block mapped from the block where the walk from defBlock stops
+
+ SLANG_RELEASE_ASSERT(currTopBlock);
+
+ // More consistency checks
+ SLANG_RELEASE_ASSERT(currTopBlock->getFirstOrdinaryInst() != nullptr);
+ SLANG_RELEASE_ASSERT(currTopBlock->getParent() != nullptr);
+ SLANG_RELEASE_ASSERT(isDifferentialInst(currTopBlock));
+
+ // Insert at top. (disabling validation since the operands of
+ // this inst might not be hoisted to the right place yet)
+ //
+ disableIRValidationAtInsert();
+ inst->insertBefore(currTopBlock->getFirstOrdinaryInst());
+ enableIRValidationAtInsert();
+ }
+
+ // Finish up..
+ hoistedInstMap[inst] = inst;
+ maybeAddPrimalOperandsToWorkList(inst); // operands of the moved inst are hoisted next
+ }
 }
void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst)
@@ -1192,10 +1656,9 @@ struct DiffTransposePass
// make sure all requried primal insts are moved to the right
// place)
//
- if (isPrimalInst(fwdInst->getOperand(ii)))
+ if (doesInstRequireHoisting(fwdInst->getOperand(ii)))
{
- auto hoistedPrimalInst = hoistPrimalInst(revBuilder, fwdInst->getOperand(ii));
- fwdInst->setOperand(ii, hoistedPrimalInst);
+ hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]);
}
}
}
@@ -1203,14 +1666,28 @@ struct DiffTransposePass
void invertInst(IRBuilder* builder, IRInst* primalInst)
{
// Look for an available inverse entry for this primalInst's *output*
- if (hasInverse(primalInst))
+ if (shouldInstBeInverted(primalInst))
{
- auto invOutput = getInverse(builder, primalInst);
+ // This logic is already handled in transposeBlock() so we skip
+ // it here.
+ //
+ if (as(primalInst))
+ return;
+
+ auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst];
- auto invEntries = invertInst(builder, primalInst, invOutput);
+ IRBuilder invBuilder(builder->getModule());
+ invBuilder.setInsertAfter(primalInst);
+ auto invEntries = invertInst(&invBuilder, primalInst, invInfo);
+
for (auto entry : invEntries)
- setInverse(builder, entry.inst, entry.invInst);
+ setInverse(
+ &invBuilder,
+ getBlock(primalInst),
+ as(entry.inst->getParent()->getParent()),
+ entry.inst,
+ entry.invInst);
}
else
{
@@ -1270,11 +1747,6 @@ struct DiffTransposePass
SLANG_ASSERT(gradients.getCount() == 0);
}
- // Ensure primal operands are replaced with insts accessible in the
- // reverse-mode context.
- //
- hoistPrimalOperands(builder, inst);
-
// Is this inst used in another differential block?
// Emit a function-scope accumulator variable, and include it's value.
// Also, we ignore this if it's a load since those are turned into stores
@@ -1381,7 +1853,9 @@ struct DiffTransposePass
// In order to perform the call, we need a temporary var to store the DiffPair.
auto pairType = as(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
- auto primalVal = builder->emitLoad(hoistPrimalInst(builder, instPair->getPrimal()));
+ auto primalVal = builder->emitLoad(instPair->getPrimal());
+ hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here?
+
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
builder->emitStore(tempVar, pairVal);
@@ -1453,13 +1927,15 @@ struct DiffTransposePass
// If the callee provides a primal implementation that produces continuation context for propagation phase
// we grab it and pass it as argument to the propagation function.
+ //
if (auto primalContextDecor = fwdCall->findDecoration())
- {
- // Ensure availability of the primal context var
- auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar());
- SLANG_RELEASE_ASSERT(primalContextVar);
+ {
+ auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar();
+
+ auto contextLoad = builder->emitLoad(primalContextVar);
+ hoistPrimalOperands(builder, contextLoad);
- args.add(builder->emitLoad(primalContextVar));
+ args.add(contextLoad);
argTypes.add(as(
primalContextVar->getDataType())
->getValueType());
@@ -1735,6 +2211,7 @@ struct DiffTransposePass
return transposeUpdateElement(builder, fwdInst, revValue);
case kIROp_LoadReverseGradient:
+ case kIROp_ReverseGradientDiffPairRef:
case kIROp_DefaultConstruct:
case kIROp_Specialize:
case kIROp_unconditionalBranch:
@@ -2255,18 +2732,17 @@ struct DiffTransposePass
{
// current type should be a scalar.
SLANG_RELEASE_ASSERT(!as(currentType->getDataType()));
-
- auto targetVectorType = as(targetType);
- List operands;
- for (Index ii = 0; ii < as(targetVectorType->getElementCount())->getValue(); ii++)
- {
- operands.add(inst);
- }
+ return builder->emitMakeVectorFromScalar(targetType, inst);
+ }
- IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer());
+ case kIROp_MatrixType:
+ {
+ // current type should be a scalar.
+ SLANG_RELEASE_ASSERT(!as(currentType->getDataType()) &&
+ !as(currentType->getDataType()));
- return newInst;
+ return builder->emitMakeMatrixFromScalar(targetType, inst);
}
default:
@@ -2968,6 +3444,10 @@ struct DiffTransposePass
DifferentialPairTypeBuilder pairBuilder;
+ HoistedPrimalsInfo* hoistedPrimalsInfo;
+
+ IRBlock* tempInvBlock;
+
Dictionary> gradientsMap;
Dictionary revAccumulatorVarMap;
@@ -2987,6 +3467,10 @@ struct DiffTransposePass
Dictionary> phiGradsMap;
Dictionary inverseValueMap;
+
+ List primalUsesToHoist;
+
+ Dictionary mapStoreToDefBlock;
};
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 5b59416d4..16862bb19 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -332,7 +332,12 @@ struct ExtractPrimalFuncContext
inst);
}
- IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet& primalParams, IRInst*& outIntermediateType)
+ IRFunc* turnUnzippedFuncIntoPrimalFunc(
+ IRFunc* unzippedFunc,
+ IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
+ HashSet& primalParams,
+ IRInst*& outIntermediateType)
{
IRBuilder builder(module);
@@ -375,17 +380,9 @@ struct ExtractPrimalFuncContext
// output intermediary struct.
for (auto inst : block->getChildren())
{
- if (shouldStoreInst(inst))
+ if (primalsInfo->storeSet.Contains(inst))
{
- if (as(inst))
- builder.setInsertBefore(block->getFirstOrdinaryInst());
- else
- builder.setInsertAfter(inst);
- storeInst(builder, inst, outIntermediary);
- }
- else if (inst->getOp() == kIROp_Var)
- {
- if (shouldStoreVar(as(inst)))
+ if (as(inst))
{
auto field = addIntermediateContextField(cast(inst->getDataType())->getValueType(), outIntermediary);
builder.setInsertBefore(inst);
@@ -394,7 +391,14 @@ struct ExtractPrimalFuncContext
inst->replaceUsesWith(fieldAddr);
builder.addPrimalValueStructKeyDecoration(inst, field->getKey());
}
-
+ else
+ {
+ if (as(inst))
+ builder.setInsertBefore(block->getFirstOrdinaryInst());
+ else
+ builder.setInsertAfter(inst);
+ storeInst(builder, inst, outIntermediary);
+ }
}
}
}
@@ -459,6 +463,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
IRFunc* DiffUnzipPass::extractPrimalFunc(
IRFunc* func,
IRFunc* originalFunc,
+ HoistedPrimalsInfo* primalsInfo,
ParameterBlockTransposeInfo& paramInfo,
IRInst*& intermediateType)
{
@@ -470,6 +475,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
subEnv.parent = &cloneEnv;
auto clonedFunc = as(cloneInst(&subEnv, &builder, func));
+ auto clonedPrimalsInfo = primalsInfo->applyMap(&subEnv);
+
// Remove [KeepAlive] decorations in clonedFunc.
for (auto block : clonedFunc->getBlocks())
for (auto inst : block->getChildren())
@@ -494,7 +501,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
context.init(autodiffContext->moduleInst->getModule(), autodiffContext->transcriberSet.primalTranscriber);
intermediateType = nullptr;
- auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType);
+ auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, clonedPrimalsInfo, newPrimalParams, intermediateType);
if (auto nameHint = primalFunc->findDecoration())
{
@@ -580,6 +587,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
// The primal calls should be marked as no side effect so they can be DCE'd if possible.
// We can only do so if the intermediate context of the callee is stored.
+ //
if (primalCtx->getBackwardDerivativePrimalContextVar()
->findDecoration())
{
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index f2aa1fd29..8b24b122e 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -9,6 +9,8 @@
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-autodiff-propagate.h"
#include "slang-ir-autodiff-transcriber-base.h"
+#include "slang-ir-autodiff-region.h"
+#include "slang-ir-autodiff-primal-hoist.h"
#include "slang-ir-validate.h"
#include "slang-ir-ssa.h"
@@ -36,171 +38,10 @@ struct DiffUnzipPass
// might run into an issue here?
IRBlock* firstDiffBlock;
- struct IndexedRegion : public RefObject
- {
- IRLoop* loop;
- IndexedRegion* parent;
-
- IndexedRegion(IRLoop* loop, IndexedRegion* parent) : loop(loop), parent(parent)
- { }
-
- IRBlock* getInitializerBlock() { return as(loop->getParent()); }
- IRBlock* getConditionBlock()
- {
- auto condBlock = as(loop->getTargetBlock());
- SLANG_RELEASE_ASSERT(as(condBlock->getTerminator()));
- return condBlock;
- }
-
- IRBlock* getBreakBlock() { return loop->getBreakBlock(); }
-
- IRBlock* getUpdateBlock()
- {
- auto initBlock = getInitializerBlock();
-
- auto condBlock = getConditionBlock();
-
- IRBlock* lastLoopBlock = nullptr;
-
- for (auto predecessor : condBlock->getPredecessors())
- {
- if (predecessor != initBlock)
- lastLoopBlock = predecessor;
- }
-
- // Should find atleast one predecessor that is _not_ the
- // init block (that contains the loop info). This
- // predecessor would be the last block in the loop
- // before looping back to the condition.
- //
- SLANG_RELEASE_ASSERT(lastLoopBlock);
-
- return lastLoopBlock;
- }
- };
-
-
- struct IndexedRegionMap : public RefObject
- {
- Dictionary map;
- List> regions;
-
- IndexedRegion* newRegion(IRLoop* loop, IndexedRegion* parent)
- {
- auto region = new IndexedRegion(loop, parent);
- regions.add(region);
-
- return region;
- }
-
- void mapBlock(IRBlock* block, IndexedRegion* region)
- {
- map.Add(block, region);
- }
-
- bool hasMapping(IRBlock* block)
- {
- return map.ContainsKey(block);
- }
-
- IndexedRegion* getRegion(IRBlock* block)
- {
- return map[block];
- }
-
- List getAllAncestorRegions(IRBlock* block)
- {
- List regionList;
-
- IndexedRegion* region = getRegion(block);
- for (; region; region = region->parent)
- regionList.add(region);
-
- return regionList;
- }
- };
-
- RefPtr buildIndexedRegionMap(IRGlobalValueWithCode* func)
- {
- RefPtr regionMap = new IndexedRegionMap;
-
- List workList;
-
- regionMap->mapBlock(func->getFirstBlock(), nullptr);
- workList.add(func->getFirstBlock());
-
- while (workList.getCount() > 0)
- {
- auto currentBlock = workList.getLast();
- workList.removeLast();
-
- auto terminator = currentBlock->getTerminator();
- auto currentRegion = regionMap->getRegion(currentBlock);
-
- switch (terminator->getOp())
- {
- case kIROp_loop:
- {
- auto loopRegion = regionMap->newRegion(as(terminator), currentRegion);
- auto condBlock = as(terminator)->getTargetBlock();
-
- regionMap->mapBlock(condBlock, loopRegion);
- workList.add(condBlock);
-
- auto ifElse = as(condBlock->getTerminator());
- SLANG_RELEASE_ASSERT(ifElse);
-
- // TODO: this is one of the places we'll need to change if we support loops that
- // loop on either the true or false side. For now, we assume the loop is on the
- // true side only.
- //
- regionMap->mapBlock(ifElse->getFalseBlock(), currentRegion);
- workList.add(ifElse->getFalseBlock());
- }
- }
-
- for (auto successor : currentBlock->getSuccessors())
- {
- // If already mapped, skip.
- if (regionMap->hasMapping(successor))
- continue;
- regionMap->mapBlock(successor, currentRegion);
- workList.add(successor);
- }
- }
-
- return regionMap;
- }
-
-
RefPtr indexRegionMap;
- struct IndexTrackingInfo : public RefObject
- {
- // After lowering, store references to the count
- // variables associated with this region
- //
- IRInst* primalCountParam = nullptr;
- IRInst* diffCountParam = nullptr;
-
- IRVar* primalCountLastVar = nullptr;
-
- enum CountStatus
- {
- Unresolved,
- Dynamic,
- Static
- };
-
- CountStatus status = CountStatus::Unresolved;
-
- // Inferred maximum number of iterations.
- Count maxIters = -1;
- };
-
Dictionary> indexInfoMap;
-
DiffUnzipPass(
AutoDiffSharedContext* autodiffContext)
: autodiffContext(autodiffContext)
@@ -217,7 +58,7 @@ struct DiffUnzipPass
return diffMap[inst];
}
- void unzipDiffInsts(IRFunc* func)
+ RefPtr unzipDiffInsts(IRFunc* func)
{
diffTypeContext.setFunc(func);
@@ -316,7 +157,12 @@ struct DiffUnzipPass
// Emit counter variables and other supporting
// instructions for all regions.
//
- lowerIndexedRegions();
+ // TODO: Need to have maxIndex in _both_ IndexTrackingInfo & IndexedRegionInfo.
+ // That way, we can do the various passes _before_ lowerIndexedRegions()
+ // TODO: Remove the call to lowerIndexedRegions() once checkpointing works properly.
+ //
+ RefPtr primalsInfo = new HoistedPrimalsInfo();
+ lowerIndexedRegions(primalsInfo);
// Copy regions from fwd-block to their split blocks
// to make it easier to do lookups.
@@ -338,21 +184,39 @@ struct DiffUnzipPass
indexRegionMap->map[as(diffMap[block])] = (IndexedRegion*)indexRegionMap->map[block];
}
}
+
+ // Swap the first block's occurrences out for the first primal block.
+ firstBlock->replaceUsesWith(firstPrimalBlock);
- // Process intermediate insts in indexed blocks
- // into array loads/stores.
- //
+ RefPtr splitInfo = new BlockSplitInfo();
+ for (auto block : mixedBlocks)
+ if (primalMap.ContainsKey(block))
+ splitInfo->diffBlockMap[as(primalMap[block])] = as(diffMap[block]);
+
+ Dictionary> indexedBlocksInfo;
for (auto block : mixedBlocks)
{
- if (indexRegionMap->getRegion(block) != nullptr)
- processIndexedFwdBlock(block);
+ indexedBlocksInfo[as(diffMap[block])] = getIndexInfoList(as(diffMap[block]));
+ indexedBlocksInfo[as(primalMap[block])] = getIndexInfoList(as(primalMap[block]));
}
-
- // Swap the first block's occurences out for the first primal block.
- firstBlock->replaceUsesWith(firstPrimalBlock);
for (auto block : mixedBlocks)
block->removeAndDeallocate();
+
+ // Run the three checkpointing passes to hoist/clone primal insts
+ // to the right spots.
+ //
+ {
+ RefPtr chkPolicy = new DefaultCheckpointPolicy(unzippedFunc->getModule());
+ chkPolicy->preparePolicy(func);
+
+ auto chkPrimalsInfo = chkPolicy->processFunc(func, splitInfo);
+ primalsInfo->merge(chkPrimalsInfo);
+
+ primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlocksInfo);
+ }
+
+ return primalsInfo;
}
void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
@@ -375,42 +239,6 @@ struct DiffUnzipPass
}
}
- UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
- {
- SLANG_RELEASE_ASSERT(as(block->getTerminator()));
-
- auto branchInst = as(block->getTerminator());
- List phiArgs;
-
- for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
- phiArgs.add(branchInst->getArg(ii));
-
- phiArgs.add(arg);
-
- builder->setInsertInto(block);
- switch (branchInst->getOp())
- {
- case kIROp_unconditionalBranch:
- builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
- break;
-
- case kIROp_loop:
- builder->emitLoop(
- as(branchInst)->getTargetBlock(),
- as(branchInst)->getBreakBlock(),
- as(branchInst)->getContinueBlock(),
- phiArgs.getCount(),
- phiArgs.getBuffer());
- break;
-
- default:
- break;
- }
-
- branchInst->removeAndDeallocate();
- return phiArgs.getCount() - 1;
- }
-
IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
{
builder->setInsertInto(block);
@@ -428,27 +256,7 @@ struct DiffUnzipPass
return addPhiInputParam(builder, block, type);
}
- IRBlock* getBlock(IRInst* inst)
- {
- SLANG_RELEASE_ASSERT(inst);
-
- if (auto block = as(inst))
- return block;
-
- return getBlock(inst->getParent());
- }
-
- IRInst* getInstInBlock(IRInst* inst)
- {
- SLANG_RELEASE_ASSERT(inst);
-
- if (auto block = as(inst->getParent()))
- return inst;
-
- return getInstInBlock(inst->getParent());
- }
-
- void lowerIndexedRegions()
+ void lowerIndexedRegions(HoistedPrimalsInfo* primalsInfo)
{
IRBuilder builder(autodiffContext->moduleInst->getModule());
@@ -464,6 +272,7 @@ struct DiffUnzipPass
// Make variable in the top-most block (so it's visible to diff blocks)
info->primalCountLastVar = builder.emitVar(builder.getIntType());
builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
+ primalsInfo->storeSet.Add(info->primalCountLastVar);
{
auto primalCondBlock = as(
@@ -546,6 +355,23 @@ struct DiffUnzipPass
builder.addPrimalValueAccessDecoration(primalCounterLastVal);
builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal);
+
+ // We'll be manually creating the inversion entries for the counters
+ // TODO: This logic can be moved to the checkpointing alg.
+ //
+ primalsInfo->invertSet.Add(info->diffCountParam);
+ primalsInfo->instsToInvert.Add(incCounterVal);
+ primalsInfo->invertInfoMap[incCounterVal] = InversionInfo(
+ incCounterVal,
+ List(incCounterVal),
+ List(info->diffCountParam));
+
+ primalsInfo->invertSet.Add(incCounterVal);
+ primalsInfo->instsToInvert.Add(diffUpdateBlock->getTerminator());
+ primalsInfo->invertInfoMap[diffUpdateBlock->getTerminator()] = InversionInfo(
+ diffUpdateBlock->getTerminator(),
+ List(diffUpdateBlock->getTerminator()),
+ List(incCounterVal));
}
// Try to infer maximum possible number of iterations.
@@ -576,282 +402,10 @@ struct DiffUnzipPass
return indices;
}
- void processIndexedFwdBlock(IRBlock* fwdBlock)
- {
- // Grab first primal block.
- IRBlock* firstPrimalBlock = as