summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp211
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h18
-rw-r--r--source/slang/slang-ir-util.cpp94
-rw-r--r--source/slang/slang-ir.cpp7
-rw-r--r--tests/autodiff/reverse-continue-loop.slang2
-rw-r--r--tests/autodiff/reverse-loop-immediate-return.slang59
7 files changed, 336 insertions, 57 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index b5ac784ce..c2403e53b 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -281,6 +281,142 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
return recomputeBlockMap;
}
+// Checks if list A is a subset of list B by comparing their primal count parameters.
+//
+// Parameters:
+// indicesA - First list of IndexTrackingInfo to compare
+// indicesB - Second list of IndexTrackingInfo to compare
+//
+// Returns:
+// true if all indices in indicesA are present in indicesB, false otherwise
+//
+bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
+{
+ if (indicesA.getCount() > indicesB.getCount())
+ return false;
+
+ for (Index ii = 0; ii < indicesA.getCount(); ii++)
+ {
+ if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
+ return false;
+ }
+
+ return true;
+}
+
+bool canInstBeStored(IRInst* inst)
+{
+ // Cannot store insts whose value is a type or a witness table, or a function.
+ // These insts get lowered to target-specific logic, and cannot be
+ // stored into variables or context structs as normal values.
+ //
+ if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
+ as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
+ !inst->getDataType())
+ return false;
+
+ return true;
+}
+
+// This is a helper that converts insts in a loop condition block into two if necessary,
+// then replaces all uses 'outside' the loop region with the new insts. This is because
+// insts in loop condition blocks can be used in two distinct regions (the loop body, and
+// after the loop).
+//
+// We'll use CheckpointObject for the splitting, which is allowed on any value-typed inst.
+//
+void splitLoopConditionBlockInsts(
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
+{
+ // RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
+
+ // Collect primal loop condition blocks, and map differential blocks to their primal blocks.
+ List<IRBlock*> loopConditionBlocks;
+ Dictionary<IRBlock*, IRBlock*> diffBlockMap;
+ for (auto block : func->getBlocks())
+ {
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ auto loopConditionBlock = getLoopConditionBlock(loop);
+ if (isDifferentialBlock(loopConditionBlock))
+ {
+ auto diffDecor = loopConditionBlock->findDecoration<IRDifferentialInstDecoration>();
+ diffBlockMap[cast<IRBlock>(diffDecor->getPrimalInst())] = loopConditionBlock;
+ }
+ else
+ loopConditionBlocks.add(loopConditionBlock);
+ }
+ }
+
+ // For each loop condition block, split the insts that are used in both the loop body and
+ // after the loop.
+ // Use the dominator tree to find uses of insts outside the loop body
+ //
+ // Essentially we want to split the uses dominated by the true block and the false block of the
+ // condition.
+ //
+ IRBuilder builder(func->getModule());
+
+
+ List<IRUse*> loopUses;
+ List<IRUse*> afterLoopUses;
+
+ for (auto condBlock : loopConditionBlocks)
+ {
+ // For each inst in the primal condition block, check if it has uses inside the loop body
+ // as well as outside of it. (Use the indexedBlockInfo to perform the teets)
+ //
+ for (auto inst = condBlock->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ // Skip terminators and insts that can't be stored
+ if (as<IRTerminatorInst>(inst) || !canInstBeStored(inst))
+ continue;
+ // Shouldn't see any vars.
+ SLANG_ASSERT(!as<IRVar>(inst));
+
+ // Get the indices for the condition block
+ auto& condBlockIndices = indexedBlockInfo[condBlock];
+
+ loopUses.clear();
+ afterLoopUses.clear();
+
+ // Check all uses of this inst
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto userBlock = getBlock(use->getUser());
+ auto& userBlockIndices = indexedBlockInfo[userBlock];
+
+ // If all of the condBlock's indices are a subset of the userBlock's indices,
+ // then the userBlock is inside the loop.
+ //
+ bool isInLoop = areIndicesSubsetOf(condBlockIndices, userBlockIndices);
+
+ if (isInLoop)
+ loopUses.add(use);
+ else
+ afterLoopUses.add(use);
+ }
+
+ // If inst has uses both inside and after the loop, create a copy for after-loop uses
+ if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0)
+ {
+ setInsertAfterOrdinaryInst(&builder, inst);
+ auto copy = builder.emitCheckpointObject(inst);
+
+ // Copy source location so that checkpoint reporting is accurate
+ copy->sourceLoc = inst->sourceLoc;
+
+ // Replace after-loop uses with the copy
+ for (auto use : afterLoopUses)
+ {
+ builder.replaceOperand(use, copy);
+ }
+ }
+ }
+ }
+}
+
RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
IRGlobalValueWithCode* func,
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
@@ -1297,20 +1433,6 @@ bool areIndicesEqual(
return true;
}
-bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
-{
- if (indicesA.getCount() > indicesB.getCount())
- return false;
-
- for (Index ii = 0; ii < indicesA.getCount(); ii++)
- {
- if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
- return false;
- }
-
- return true;
-}
-
static int getInstRegionNestLevel(
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
IRBlock* defBlock,
@@ -1510,21 +1632,6 @@ static List<IndexTrackingInfo> maybeTrimIndices(
return result;
}
-bool canInstBeStored(IRInst* inst)
-{
- // Cannot store insts whose value is a type or a witness table, or a function.
- // These insts get lowered to target-specific logic, and cannot be
- // stored into variables or context structs as normal values.
- //
- if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
- as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
- !inst->getDataType())
- return false;
-
- return true;
-}
-
-
/// Legalizes all accesses to primal insts from recompute and diff blocks.
///
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
@@ -2104,6 +2211,39 @@ void buildIndexedBlocks(
}
}
+// This function simply turns all CheckpointObject insts into a 'no-op'.
+// i.e. simply replaces all uses of CheckpointObject with the original value.
+//
+// This operation is 'correct' because if CheckpointObject's operand is visible
+// in a block, then it is visible in all dominated blocks.
+//
+void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func)
+{
+ // For each block in the function
+ for (auto block : func->getBlocks())
+ {
+ // For each instruction in the block
+ for (auto inst = block->getFirstInst(); inst;)
+ {
+ // Get next inst before potentially removing current one
+ auto nextInst = inst->getNextInst();
+
+ // Check if this is a CheckpointObject instruction
+ if (auto copyInst = as<IRCheckpointObject>(inst))
+ {
+ // Replace all uses of the copy with the original value
+ auto originalVal = copyInst->getVal();
+ copyInst->replaceUsesWith(originalVal);
+
+ // Remove the now unused copy instruction
+ inst->removeAndDeallocate();
+ }
+
+ inst = nextInst;
+ }
+ }
+}
+
// For each primal inst that is used in reverse blocks, decide if we should recompute or store
// its value, then make them accessible in reverse blocks based the decision.
//
@@ -2117,6 +2257,9 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
buildIndexedBlocks(indexedBlockInfo, func);
+ // Split loop condition insts into two if necessary.
+ splitLoopConditionBlockInsts(func, indexedBlockInfo);
+
// Create recompute blocks for each region following the same control flow structure
// as in primal code.
//
@@ -2136,7 +2279,12 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
//
- return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
+ auto hoistedPrimalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
+
+ // Lower CheckpointObject insts to a no-op.
+ lowerCheckpointObjectInsts(func);
+
+ return hoistedPrimalsInfo;
}
void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
@@ -2312,6 +2460,9 @@ static bool shouldStoreInst(IRInst* inst)
break;
}
+ case kIROp_CheckpointObject:
+ // Special inst for when a value must be stored.
+ return true;
default:
break;
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5a1966d00..9ffaeeeb9 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -716,6 +716,8 @@ INST(BitNot, bitnot, 1, 0)
INST(Select, select, 3, 0)
+INST(CheckpointObject, checkpointObj, 1, 0)
+
INST(GetStringHash, getStringHash, 1, 0)
INST(WaveGetActiveMask, waveGetActiveMask, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index d64820aa6..7c975cfcd 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2664,6 +2664,22 @@ struct IRDiscard : IRTerminatorInst
{
};
+// Used for representing a distinct copy of an object.
+// This will get lowered into a no-op in the backend,
+// but is useful for IR transformations that need to consider
+// different uses of an inst separately.
+//
+// For example, when we hoist primal insts out of a loop,
+// we need to make distinct copies of the inst for its uses
+// within the loop body and outside of it.
+//
+struct IRCheckpointObject : IRInst
+{
+ IR_LEAF_ISA(CheckpointObject);
+
+ IRInst* getVal() { return getOperand(0); }
+};
+
// Signals that this point in the code should be unreachable.
// We can/should emit a dataflow error if we can ever determine
// that a block ending in one of these can actually be
@@ -4408,6 +4424,8 @@ public:
IRInst* emitDiscard();
+ IRInst* emitCheckpointObject(IRInst* value);
+
IRInst* emitUnreachable();
IRInst* emitMissingReturn();
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index bf5b25d9c..39c1c5bb1 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -2078,9 +2078,36 @@ Int getSpecializationConstantId(IRGlobalParam* param)
return offset->getOffset();
}
+IRBlock* getLoopHeaderForConditionBlock(IRBlock* block)
+{
+ // Go through uses and check if any of them are a loop condition block.
+ for (auto use = block->firstUse; use; use = use->nextUse)
+ {
+ if (auto loop = as<IRLoop>(use->getUser()))
+ {
+ if (loop->getTargetBlock() == block)
+ return cast<IRBlock>(loop->getParent());
+ }
+ }
+ return nullptr;
+}
+
void legalizeDefUse(IRGlobalValueWithCode* func)
{
auto dom = computeDominatorTree(func);
+
+ // Make a map of loop condition blocks to their loop header.
+ // We need this because we'll be treating loop condition blocks as
+ // special cases (they are the special blocks since they "dominate" themselves,
+ // in the dominator tree sense)
+ //
+ Dictionary<IRBlock*, IRBlock*> loopHeaderBlockMap;
+ for (auto block : func->getBlocks())
+ {
+ if (auto header = getLoopHeaderForConditionBlock(block))
+ loopHeaderBlockMap.add(block, header);
+ }
+
for (auto block : func->getBlocks())
{
for (auto inst : block->getModifiableChildren())
@@ -2099,16 +2126,22 @@ void legalizeDefUse(IRGlobalValueWithCode* func)
}
SLANG_ASSERT(commonDominator);
- if (commonDominator == block)
+ // If commonDominator is 'block' and if the inst is not a Var in
+ // a loop condition block, we can skip the legalization.
+ //
+ if (commonDominator == block &&
+ !(as<IRVar>(inst) && loopHeaderBlockMap.containsKey(block)))
continue;
- // If the common dominator is not `block`, it means we have detected
- // uses that is no longer dominated by the current definition, and need
- // to be fixed.
-
- // Normally, we can simply move the definition to the common dominator.
+ // Normally, if the common dominator is not `block`, we can simply move the definition
+ // to the common dominator.
// An exception is when the common dominator is the target block of a
- // loop. Note that after normalization, loops are in the form of:
+ // loop.
+ // Another exception is when a var in the loop condition block is accessed both inside
+ // and outside the loop. It is technically visible, but effects on the 'var' are not
+ // visible outside the loop, so we'll need to hoist it out of the loop.
+ //
+ // Note that after normalization, loops are in the form of:
// ```
// loop { if (condition) block; else break; }
// ```
@@ -2117,38 +2150,47 @@ void legalizeDefUse(IRGlobalValueWithCode* func)
// In this case, we should insert a var/move the inst before the loop
// instead of before the `if`. This situation can occur in the IR if
// the original code is lowered from a `do-while` loop.
- for (auto use = commonDominator->firstUse; use; use = use->nextUse)
+ //
+ bool shouldInitializeVar = false;
+ if (loopHeaderBlockMap.containsKey(commonDominator))
{
- if (auto loopUser = as<IRLoop>(use->getUser()))
+ bool shouldMoveToHeader = false;
+
+ // Check that the break-block dominates any of the uses are past the break
+ // block
+ for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
{
- if (loopUser->getTargetBlock() == commonDominator)
+ if (dom->dominates(
+ as<IRLoop>(loopHeaderBlockMap[commonDominator]->getTerminator())
+ ->getBreakBlock(),
+ _use->getUser()->getParent()))
{
- bool shouldMoveToHeader = false;
- // Check that the break-block dominates any of the uses are past the break
- // block
- for (auto _use = inst->firstUse; _use; _use = _use->nextUse)
- {
- if (dom->dominates(
- loopUser->getBreakBlock(),
- _use->getUser()->getParent()))
- {
- shouldMoveToHeader = true;
- break;
- }
- }
-
- if (shouldMoveToHeader)
- commonDominator = as<IRBlock>(loopUser->getParent());
+ shouldMoveToHeader = true;
break;
}
}
+ if (shouldMoveToHeader)
+ {
+ commonDominator = loopHeaderBlockMap[commonDominator];
+ shouldInitializeVar = true;
+ }
}
+
// Now we can legalize uses based on the type of `inst`.
if (auto var = as<IRVar>(inst))
{
// If inst is an var, this is easy, we just move it to the
// common dominator.
var->insertBefore(commonDominator->getTerminator());
+ if (shouldInitializeVar)
+ {
+ IRBuilder builder(func);
+ builder.setInsertAfter(var);
+ builder.emitStore(
+ var,
+ builder.emitDefaultConstruct(
+ as<IRPtrTypeBase>(var->getDataType())->getValueType()));
+ }
}
else
{
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index fb274c4a0..3a7ace37d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5664,6 +5664,13 @@ IRInst* IRBuilder::emitDiscard()
return inst;
}
+IRInst* IRBuilder::emitCheckpointObject(IRInst* value)
+{
+ auto inst =
+ createInst<IRCheckpointObject>(this, kIROp_CheckpointObject, value->getFullType(), value);
+ addInst(inst);
+ return inst;
+}
IRInst* IRBuilder::emitBranch(IRBlock* pBlock)
{
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
index 2dfad0a61..51f17b611 100644
--- a/tests/autodiff/reverse-continue-loop.slang
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -16,7 +16,7 @@ float test_loop_with_continue(float y)
//CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
- //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here:
+ //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item:
for (int i = 0; i < 3; i++)
{
if (t > 4.0)
diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang
new file mode 100644
index 000000000..121836115
--- /dev/null
+++ b/tests/autodiff/reverse-loop-immediate-return.slang
@@ -0,0 +1,59 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+
+[BackwardDerivative(set_bwd)]
+void set(uint idx, float x)
+{
+ outputBuffer[idx] = x;
+}
+
+void set_bwd(uint idx, inout DifferentialPair<float> x)
+{
+ // For debugging, we'll set the derivative to 1.0
+ x = DifferentialPair<float>(x.p, 1.0f);
+}
+
+[Differentiable]
+void run(
+ uint idx,
+ float x)
+{
+ if (idx >= 1) return;
+
+ if (idx == 0)
+ { }
+
+ for (int i = 0; i < 1; i++)
+ {
+ if (idx > 0)
+ {
+ return;
+ }
+
+ if (idx == 0)
+ {
+ x = x * 2.0f;
+ }
+ }
+
+ if (idx == 0)
+ { }
+
+ set(idx, x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // bwd_diff
+ DifferentialPair<float> dpa = DifferentialPair<float>(1.0, 0.0);
+ bwd_diff(run)(dispatchThreadID.x, dpa);
+ outputBuffer[dispatchThreadID.x] = dpa.d;
+
+ // CHECK: type: float
+ // CHECK: 2.0
+}