summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/core/slang-dictionary.h2
-rw-r--r--source/core/slang-hash.h2
-rw-r--r--source/slang/diff.meta.slang9
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp362
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp10
-rw-r--r--source/slang/slang-ir-autodiff-rev.h2
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp100
-rw-r--r--source/slang/slang-ir-autodiff.cpp16
-rw-r--r--source/slang/slang-ir-dce.cpp231
-rw-r--r--source/slang/slang-ir-dce.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h12
-rw-r--r--source/slang/slang-ir-insts.h26
-rw-r--r--source/slang/slang-ir-layout.cpp6
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp5
-rw-r--r--source/slang/slang-ir-util.cpp35
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
-rw-r--r--tests/autodiff/high-order-backward-diff-3.slang5
-rw-r--r--tests/autodiff/high-order-backward-diff-4.slang6
-rw-r--r--tests/autodiff/path-tracer/pt-loop.slang4
-rw-r--r--tests/autodiff/reverse-continue-loop.slang10
-rw-r--r--tests/autodiff/reverse-control-flow-3.slang2
-rw-r--r--tests/autodiff/reverse-loop-checkpoint-test.slang7
-rw-r--r--tests/autodiff/reverse-loop.slang2
-rw-r--r--tests/autodiff/reverse-nested-calls.slang4
-rw-r--r--tests/autodiff/test-minimal-context.slang76
26 files changed, 769 insertions, 171 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h
index 9c445c3c9..9c753ffcd 100644
--- a/source/core/slang-dictionary.h
+++ b/source/core/slang-dictionary.h
@@ -9,7 +9,7 @@
#include "slang-math.h"
#include "slang-hash.h"
-#include <ankerl/unordered_dense.h>
+#include "../../external/unordered_dense/include/ankerl/unordered_dense.h"
#include <initializer_list>
diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h
index 3e255f7db..ff0cdc181 100644
--- a/source/core/slang-hash.h
+++ b/source/core/slang-hash.h
@@ -4,7 +4,7 @@
#include "../../include/slang.h"
#include "slang-math.h"
-#include <ankerl/unordered_dense.h>
+#include "../../external/unordered_dense/include/ankerl/unordered_dense.h"
#include <cstring>
#include <type_traits>
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index b39d91494..6042ff5cc 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -855,17 +855,20 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
DifferentialPair<T> __load_forward(vector<uint, N> x)
{
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x)));
}
+ [ForceInline]
void __load_backward(uint x, T.Differential dOut)
{
diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __load_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -894,11 +897,13 @@ struct DiffTensorView
diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __store_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
@@ -999,11 +1004,13 @@ struct DiffTensorView
return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x)));
}
+ [ForceInline]
void __loadOnce_backward(uint x, T.Differential dOut)
{
diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut));
}
+ [ForceInline]
__generic<let N : int>
void __loadOnce_backward(vector<uint, N> x, T.Differential dOut)
{
@@ -1032,11 +1039,13 @@ struct DiffTensorView
diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d));
}
+ [ForceInline]
void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval)
{
dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x)));
}
+ [ForceInline]
__generic<let N : int>
void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval)
{
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 2881abe3e..fcc1f95ee 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -359,135 +359,155 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
while (workList.getCount() > 0)
{
- auto use = workList.getLast();
- workList.removeLast();
+ while (workList.getCount() > 0)
+ {
+ auto use = workList.getLast();
+ workList.removeLast();
- if (processedUses.contains(use))
- continue;
+ if (processedUses.contains(use))
+ continue;
- processedUses.add(use);
+ processedUses.add(use);
- HoistResult result = this->classify(use);
+ HoistResult result = this->classify(use);
- if (result.mode == HoistResult::Mode::Store)
- {
- SLANG_ASSERT(!checkpointInfo->recomputeSet.contains(result.instToStore));
- checkpointInfo->storeSet.add(result.instToStore);
- }
- else if (result.mode == HoistResult::Mode::Recompute)
- {
- SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute));
- checkpointInfo->recomputeSet.add(result.instToRecompute);
+ if (result.mode == HoistResult::Mode::Store)
+ {
+ SLANG_ASSERT(!checkpointInfo->recomputeSet.contains(result.instToStore));
+ checkpointInfo->storeSet.add(result.instToStore);
+ }
+ else if (result.mode == HoistResult::Mode::Recompute)
+ {
+ SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute));
+ checkpointInfo->recomputeSet.add(result.instToRecompute);
- if (isDifferentialInst(use.user) && use.irUse)
- usesToReplace.add(use.irUse);
+ if (isDifferentialInst(use.user) && use.irUse)
+ usesToReplace.add(use.irUse);
- if (auto param = as<IRParam>(result.instToRecompute))
- {
- if (auto inductionInfo = inductionValueInsts.tryGetValue(param))
+ if (auto param = as<IRParam>(result.instToRecompute))
{
- checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo);
- continue;
- }
+ if (auto inductionInfo = inductionValueInsts.tryGetValue(param))
+ {
+ checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo);
+ continue;
+ }
- // Add in the branch-args of every predecessor block.
- auto paramBlock = as<IRBlock>(param->getParent());
- UIndex paramIndex = 0;
- for (auto _param : paramBlock->getParams())
- {
- if (_param == param) break;
- paramIndex ++;
- }
+ // Add in the branch-args of every predecessor block.
+ auto paramBlock = as<IRBlock>(param->getParent());
+ UIndex paramIndex = 0;
+ for (auto _param : paramBlock->getParams())
+ {
+ if (_param == param) break;
+ paramIndex ++;
+ }
- for (auto predecessor : paramBlock->getPredecessors())
- {
- // If we hit this, the checkpoint policy is trying to recompute
- // values across a loop region boundary (we don't currently support this,
- // and in general this is quite inefficient in both compute & memory)
- //
- SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor));
+ for (auto predecessor : paramBlock->getPredecessors())
+ {
+ // If we hit this, the checkpoint policy is trying to recompute
+ // values across a loop region boundary (we don't currently support this,
+ // and in general this is quite inefficient in both compute & memory)
+ //
+ SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor));
- auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator());
- SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
+ auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator());
+ SLANG_ASSERT(branchInst->getOperandCount() > paramIndex);
- workList.add(&branchInst->getArgs()[paramIndex]);
+ workList.add(&branchInst->getArgs()[paramIndex]);
+ }
}
- }
- else
- {
- if (auto var = as<IRVar>(result.instToRecompute))
+ else
{
- for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse)
+ if (auto var = as<IRVar>(result.instToRecompute))
{
- switch (varUse->getUser()->getOp())
+ for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse)
{
- case kIROp_Store:
- case kIROp_Call:
- // When we have a var and a store/call insts that writes to the var,
- // we treat as if there is a pseudo-use of the store/call to compute
- // the var inst, i.e. the var depends on the store/call, despite
- // the IR's def-use chain doesn't reflect this.
- workList.add(UseOrPseudoUse(var, varUse->getUser()));
- break;
+ switch (varUse->getUser()->getOp())
+ {
+ case kIROp_Store:
+ case kIROp_Call:
+ // When we have a var and a store/call insts that writes to the var,
+ // we treat as if there is a pseudo-use of the store/call to compute
+ // the var inst, i.e. the var depends on the store/call, despite
+ // the IR's def-use chain doesn't reflect this.
+ workList.add(UseOrPseudoUse(var, varUse->getUser()));
+ break;
+ }
}
}
- }
- else
- {
- addPrimalOperandsToWorkList(result.instToRecompute);
+ else
+ {
+ addPrimalOperandsToWorkList(result.instToRecompute);
+ }
}
}
}
- }
- // If a var or call is in recomputeSet, move any var/calls associated with the same call to
- // recomputeSet.
- List<IRInst*> instWorkList;
- HashSet<IRInst*> instWorkListSet;
- for (auto inst : checkpointInfo->recomputeSet)
- {
- switch (inst->getOp())
+ // If a var or call is in recomputeSet, move any var/calls associated with the same call to
+ // recomputeSet.
+ // This is a bit of a 'retro-active' analysis where we go back on processed insts and
+ // correct them.
+ //
+ List<IRInst*> callVarWorkList;
+ HashSet<IRInst*> callVarWorkListSet;
+ for (auto inst : checkpointInfo->recomputeSet)
{
- case kIROp_Call:
- case kIROp_Var:
- instWorkList.add(inst);
- instWorkListSet.add(inst);
- break;
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ case kIROp_Var:
+ callVarWorkList.add(inst);
+ callVarWorkListSet.add(inst);
+ break;
+ }
}
- }
- for (Index i = 0; i < instWorkList.getCount(); i++)
- {
- auto inst = instWorkList[i];
- if (auto var = as<IRVar>(inst))
+
+ for (Index i = 0; i < callVarWorkList.getCount(); i++)
{
- for (auto use = var->firstUse; use; use = use->nextUse)
+ auto inst = callVarWorkList[i];
+ if (auto var = as<IRVar>(inst))
{
- if (auto callUser = as<IRCall>(use->getUser()))
+ for (auto use = var->firstUse; use; use = use->nextUse)
{
- checkpointInfo->recomputeSet.add(callUser);
- checkpointInfo->storeSet.remove(callUser);
- if (instWorkListSet.add(callUser))
- instWorkList.add(callUser);
- }
- else if (auto storeUser = as<IRStore>(use->getUser()))
- {
- checkpointInfo->recomputeSet.add(storeUser);
- checkpointInfo->storeSet.remove(storeUser);
- if (instWorkListSet.add(callUser))
- instWorkList.add(callUser);
+ if (auto callUser = as<IRCall>(use->getUser()))
+ {
+ checkpointInfo->recomputeSet.add(callUser);
+ checkpointInfo->storeSet.remove(callUser);
+ if (callVarWorkListSet.add(callUser))
+ callVarWorkList.add(callUser);
+ }
+ else if (auto storeUser = as<IRStore>(use->getUser()))
+ {
+ checkpointInfo->recomputeSet.add(storeUser);
+ checkpointInfo->storeSet.remove(storeUser);
+ if (callVarWorkListSet.add(callUser))
+ callVarWorkList.add(callUser);
+ }
}
}
- }
- else if (auto call = as<IRCall>(inst))
- {
- for (UInt j = 0; j < call->getArgCount(); j++)
+ else if (auto call = as<IRCall>(inst))
{
- if (auto varArg = as<IRVar>(call->getArg(j)))
+ for (UInt j = 0; j < call->getArgCount(); j++)
+ {
+ if (auto varArg = as<IRVar>(call->getArg(j)))
+ {
+ checkpointInfo->recomputeSet.add(varArg);
+ checkpointInfo->storeSet.remove(varArg);
+ if (callVarWorkListSet.add(varArg))
+ callVarWorkList.add(varArg);
+ }
+ }
+
+ // This next few lines are a bit of a hack.. ideally we need to add the call to the main worklist
+ // for processing, so we don't have to repeat the recomputationn actions.
+ //
+ auto calleeUse = &call->getOperands()[0];
+ if (!as<IRModuleInst>(calleeUse->get()->getParent()) && !processedUses.contains(calleeUse))
+ addPrimalOperandsToWorkList(call);
+
+ for (auto use = call->firstUse; use; use = use->nextUse)
{
- checkpointInfo->recomputeSet.add(varArg);
- checkpointInfo->storeSet.remove(varArg);
- if (instWorkListSet.add(varArg))
- instWorkList.add(varArg);
+ if (isDifferentialInst(use->getUser()))
+ usesToReplace.add(use);
}
}
}
@@ -1048,14 +1068,6 @@ void applyCheckpointSet(
// primal logic may have already updated the param with a new value, and instead we
// want the original value.
builder.setInsertBefore(recomputeInsertBeforeInst);
- if (auto load = as<IRLoad>(child))
- {
- if (load->getPtr()->getOp() == kIROp_Param &&
- load->getPtr()->getParent() == func->getFirstBlock())
- {
- builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator());
- }
- }
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, child);
}
}
@@ -1331,15 +1343,17 @@ struct UseChain
IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);
-
+
+ chain.reverse();
+
// Clone the rest of the chain.
for (auto& use : chain)
{
- lastInstInChain = cloneInst(&env, &chainBuilder, use->getUser());
+ lastInstInChain = cloneInst(&env, &chainBuilder, use->get());
}
// Replace the base use.
- builder->replaceOperand(baseUse, lastInstInChain);
+ builder->replaceOperand(chain.getLast(), lastInstInChain);
chain.clear();
}
@@ -1347,7 +1361,7 @@ struct UseChain
IRInst* getUser() const
{
SLANG_ASSERT(chain.getCount() > 0);
- return chain.getLast()->getUser();
+ return chain.getFirst()->getUser();
}
};
@@ -1385,11 +1399,14 @@ static List<IndexTrackingInfo> maybeTrimIndices(
bool canInstBeStored(IRInst* inst)
{
- // Cannot store insts whose value is a type or a witness table.
+ // Cannot store insts whose value is a type or a witness table, or a function.
// These insts get lowered to target-specific logic, and cannot be
// stored into variables or context structs as normal values.
//
- if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()))
+ if (as<IRTypeType>(inst->getDataType()) ||
+ as<IRWitnessTableType>(inst->getDataType()) ||
+ as<IRTypeKind>(inst->getDataType()) ||
+ as<IRFuncType>(inst->getDataType()))
return false;
return true;
@@ -1499,18 +1516,15 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
SLANG_RELEASE_ASSERT(defBlock);
- List<UseChain> outOfScopeUses;
- for (auto use = instToStore->firstUse; use;)
+ // Lambda to check if a use is relevant.
+ auto isRelevantUse = [&](IRUse* use)
{
- auto nextUse = use->nextUse;
-
- // Lambda to check if a use is relevant.
- auto isRelevantUse = [&](IRUse* use)
+ // Only consider uses in differential blocks.
+ // This method is not responsible for other blocks.
+ //
+ IRBlock* userBlock = getBlock(use->getUser());
+ if (isRecomputeInst)
{
- // Only consider uses in differential blocks.
- // This method is not responsible for other blocks.
- //
- IRBlock* userBlock = getBlock(use->getUser());
if (isDifferentialOrRecomputeBlock(userBlock))
{
if (!domTree->dominates(defBlock, userBlock))
@@ -1532,16 +1546,37 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return true;
}
}
- return false;
- };
+ }
+ else
+ {
+ if (isDifferentialOrRecomputeBlock(userBlock))
+ return true;
+ }
+ return false;
+ };
- // Lambda to check if an inst is transparent. We lookup uses 'through' transparent
- // insts recursively.
- //
- auto isPassthroughInst = [&](IRInst* inst)
+ // Lambda to check if an inst is transparent. We lookup uses 'through' transparent
+ // insts recursively.
+ //
+ auto isPassthroughInst = [&](IRInst* inst)
+ {
+ if (!canInstBeStored(inst))
+ return true;
+
+ switch (inst->getOp())
{
- return !canInstBeStored(inst);
- };
+ case kIROp_GetSequentialID:
+ case kIROp_ExtractExistentialValue:
+ return true;
+ }
+
+ return false;
+ };
+
+ List<UseChain> outOfScopeUses;
+ for (auto use = instToStore->firstUse; use;)
+ {
+ auto nextUse = use->nextUse;
List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst);
outOfScopeUses.addRange(useChains);
@@ -1622,7 +1657,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
if (!isRecomputeInst)
processedStoreSet.add(localVar);
}
- else if (!canInstBeStored(instToStore))
+ else if (isPassthroughInst(instToStore))
{
// We won't actually process these insts here. Instead we'll
// simply make sure that their operands are either already present
@@ -1683,12 +1718,52 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
if (!isRecomputeInst)
processedStoreSet.add(localVar);
}
-
+
+ // Put the inst back on the worklist since there's a possibility that we created more uses
+ // for it in the process.
+ //
+ //workList.add(instToStore);
seenInstSet.add(instToStore);
}
};
+ // Pull any loop counter in the store set out to another list.
+ //
+ Dictionary<UIndex, OrderedHashSet<IRInst*>> loopCounters;
+ {
+ List<IRInst*> loopCounterInsts;
+ for (auto inst : hoistInfo->storeSet)
+ {
+ if (inst->findDecoration<IRLoopCounterDecoration>())
+ {
+ auto block = cast<IRBlock>(inst->getParent());
+ auto nestDepth = indexedBlockInfo.getValue(block).getCount() - 1;
+
+ if (!loopCounters.containsKey(nestDepth))
+ loopCounters[nestDepth] = OrderedHashSet<IRInst*>();
+
+ loopCounters[nestDepth].add(inst);
+ loopCounterInsts.add(inst);
+ }
+ }
+
+ for (auto inst : loopCounterInsts)
+ hoistInfo->storeSet.remove(inst);
+ }
+
+ // First handle all non-loop-counter insts.
ensureInstAvailable(hoistInfo->storeSet, false);
+
+ // Then handle the loop counter insts in reverse-order of nest depth
+ // This ordering is important because loop counters at level N _may_ depend on
+ // the counters at the previous levels.
+ //
+ for (Index ii = (Index)loopCounters.getCount() - 1; ii >= 0; --ii)
+ {
+ ensureInstAvailable(loopCounters[(UIndex)ii], false);
+ }
+
+ // Next handle all recompute insts, from within
ensureInstAvailable(hoistInfo->recomputeSet, true);
// Replace the old store set with the processed one.
@@ -2080,12 +2155,27 @@ static bool shouldStoreInst(IRInst* inst)
return false;
case kIROp_Call:
- // If the callee prefers recompute policy, don't store.
+ {
+ // If the callee has a preference, we should follow it.
if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute)
{
return false;
}
+ else if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferCheckpoint)
+ {
+ return true;
+ }
+
+ // If not, we'll default to recomputing calls that don't have side effects & don't
+ // load from non-local variables. A previous data-flow pass should have already tagged functions
+ // with the appropriate decorations.
+ //
+ auto callee = getResolvedInstForDecorations(inst->getOperand(0), true);
+ if (callee->findDecoration<IRReadNoneDecoration>())
+ return false;
+
break;
+ }
default:
break;
}
@@ -2123,8 +2213,8 @@ static bool shouldStoreVar(IRVar* var)
// of the var will be the same as the decision for the call.
return shouldStoreInst(callUser);
}
- // Default behavior is to store if we can.
- return true;
+ // Default behavior is to recompute stuff.
+ return false;
}
// If the var has never been written to, don't store it.
return false;
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 169dd31ee..87a5d2281 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -70,6 +70,7 @@ namespace Slang
// Don't need to do anything other than add a decoration in the original func to point to the primal func.
// The body of the primal func will be generated by propagateTranscriber together with propagate func.
addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
+ builder->addDecoration(diffFunc, kIROp_IgnoreSideEffectsDecoration);
return InstPair(primalFunc, diffFunc);
}
@@ -116,6 +117,9 @@ namespace Slang
builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params);
builder->emitReturn();
+ // Copy other decorations from the original func to the generated primal func wrapper.
+ copyOriginalDecorations(udf, diffPropFunc);
+
// Now create the trivial primal function.
auto existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>();
if (!existingDecor)
@@ -148,6 +152,9 @@ namespace Slang
checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>();
if (checkpointHint)
cloneCheckpointHint(builder, checkpointHint, cast<IRGlobalValueWithCode>(existingPrimalFunc));
+
+ // Copy other decorations from the original func to the generated primal func wrapper.
+ copyOriginalDecorations(udf, existingPrimalFunc);
builder->emitBlock();
params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc));
@@ -193,7 +200,7 @@ namespace Slang
addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
if (auto udf = primalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>())
{
- generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf);
+ generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf);
}
else
{
@@ -360,6 +367,7 @@ namespace Slang
auto newName = this->getTranscribedFuncName(&builder, origFunc);
builder.addNameHintDecoration(diffFunc, newName);
}
+ addTranscribedFuncDecoration(builder, primalFunc, diffFunc);
// Transfer checkpoint hint decorations
copyCheckpointHints(&builder, origFunc, diffFunc);
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index b65701a7a..428fe088c 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -32,8 +32,8 @@ struct ParameterBlockTransposeInfo
// The value with which a primal specific parameter should be replaced in propagate func.
OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc;
+
// The insts added that is specific for propagate functions and should be removed
-
// from the future primal func.
List<IRInst*> propagateFuncSpecificPrimalInsts;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 507a2bf92..28715395a 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -57,7 +57,10 @@ struct ExtractPrimalFuncContext
return createGenericIntermediateType(as<IRGeneric>(func));
IRBuilder builder(module);
builder.setInsertBefore(func);
+
auto intermediateType = builder.createStructType();
+
+ builder.addDecoration(intermediateType, kIROp_OptimizableTypeDecoration);
if (auto nameHint = func->findDecoration<IRNameHintDecoration>())
{
StringBuilder newName;
@@ -65,6 +68,7 @@ struct ExtractPrimalFuncContext
builder.addNameHintDecoration(
intermediateType, UnownedStringSlice(newName.getBuffer()));
}
+
return intermediateType;
}
@@ -289,6 +293,32 @@ struct ExtractPrimalFuncContext
}
};
+bool isIntermediateContextType(IRInst* type)
+{
+ switch (type->getOp())
+ {
+ case kIROp_BackwardDiffIntermediateContextType:
+ return true;
+ case kIROp_AttributedType:
+ return isIntermediateContextType(as<IRAttributedType>(type)->getBaseType());
+ case kIROp_Specialize:
+ return isIntermediateContextType(as<IRSpecialize>(type)->getBase());
+ default:
+ if (as<IRPtrTypeBase>(type))
+ return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
+ return false;
+ }
+}
+
+void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func)
+{
+ for (auto param : func->getParams())
+ {
+ if (!isIntermediateContextType(param->getDataType()))
+ builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration);
+ }
+}
+
static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneEnv)
{
IRInst* newInst = nullptr;
@@ -306,6 +336,21 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE
}
}
+IRBlock* getFirstRecomputeBlock(IRFunc* func)
+{
+ // This logic is a bit fragile.
+ // We shouldn't necessarily make the
+ // assumption that the order in the list of blocks is related to the
+ // control-flow order, but it works with the current system.
+ //
+ for (auto block : func->getBlocks())
+ {
+ if (block->findDecoration<IRRecomputeBlockDecoration>())
+ return block;
+ }
+ return nullptr;
+}
+
IRFunc* DiffUnzipPass::extractPrimalFunc(
IRFunc* func,
IRFunc* originalFunc,
@@ -359,13 +404,18 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
{
auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName());
builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice()));
+ builder.addDecoration(primalFunc, kIROp_IgnoreSideEffectsDecoration);
+
+ markNonContextParamsAsSideEffectFree(&builder, primalFunc);
}
// Copy PrimalValueStructKey decorations from primal func.
copyPrimalValueStructKeyDecorations(func, subEnv);
-
- auto paramBlock = func->getFirstBlock();
- auto firstBlock = *(paramBlock->getSuccessors().begin());
+
+ auto firstRecomputeBlock = getFirstRecomputeBlock(func);
+ SLANG_ASSERT(firstRecomputeBlock);
+
+ auto firstBlock = firstRecomputeBlock;
builder.setInsertBefore(firstBlock->getFirstInst());
auto intermediateVar = func->getLastParam();
@@ -447,6 +497,50 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
inst->removeAndDeallocate();
}
+ auto paramBlock = func->getFirstBlock();
+ auto paramPreludeBlock = paramBlock->getNextBlock();
+
+ // Remove primal blocks from the propagate func & wire the param block directly to the first
+ // recompute block.
+ //
+ {
+ auto terminator = cast<IRUnconditionalBranch>(paramPreludeBlock->getTerminator());
+ builder.replaceOperand(&(terminator->block), firstRecomputeBlock);
+ }
+
+ // Erase all primal blocks (except for the param & prelude blocks).
+ // TODO: Lots of ways to clean this up.
+ //
+ List<IRBlock*> blocksToRemove;
+ for (auto block : func->getBlocks())
+ {
+ if (block != paramBlock &&
+ block != paramPreludeBlock &&
+ !block->findDecoration<IRRecomputeBlockDecoration>() &&
+ !block->findDecoration<IRDifferentialInstDecoration>())
+ blocksToRemove.add(block);
+ }
+
+ // Before erasing the blocks, go through and 're-hoist' any hoistable instructions (such as types)
+ // Any remaining valid instructions should be automatically moved to the recompute blocks.
+ // The rest can be removed.
+ //
+ List<IRInst*> instsToReHoist;
+ for (auto block : blocksToRemove)
+ for (auto inst : block->getChildren())
+ if (getIROpInfo(inst->getOp()).flags & kIROpFlag_Hoistable)
+ instsToReHoist.add(inst);
+
+ for (auto inst : instsToReHoist)
+ {
+ inst->removeFromParent();
+ addHoistableInst(&builder, inst);
+ }
+
+ for (auto block : blocksToRemove)
+ block->removeAndDeallocate();
+
+
return primalFunc;
}
} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index bed6a68e4..975a6e554 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-single-return.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
+#include "slang-ir-inline.h"
#include "../core/slang-performance-profiler.h"
namespace Slang
@@ -1511,6 +1512,7 @@ void stripTempDecorations(IRInst* inst)
case kIROp_RecomputeBlockDecoration:
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
+ case kIROp_BackwardDerivativePrimalContextDecoration:
case kIROp_PrimalValueStructKeyDecoration:
case kIROp_PrimalElementTypeDecoration:
decor->removeAndDeallocate();
@@ -2283,12 +2285,15 @@ struct AutoDiffPass : public InstPassBase
}
}
- // Get rid of block-level decorations that are used to keep track of
- // different block types. These don't work well with the IR simplification
- // passes since they don't expect decorations in blocks.
- //
+
for (auto diffFunc : autodiffCleanupList)
+ {
+ // Get rid of block-level decorations that are used to keep track of
+ // different block types. These don't work well with the IR simplification
+ // passes since they don't expect decorations in blocks.
+ //
stripTempDecorations(diffFunc);
+ }
autodiffCleanupList.clear();
@@ -2300,9 +2305,7 @@ struct AutoDiffPass : public InstPassBase
break;
if (lowerIntermediateContextType(builder))
- {
hasChanges = true;
- }
// We have done transcribing the functions, now it is time to demote all DifferentialPair types
// and their operations down to DifferentialPairUserCodeType and *UserCode operations so they
@@ -2312,7 +2315,6 @@ struct AutoDiffPass : public InstPassBase
hasChanges |= changed;
}
-
return hasChanges;
}
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index f414f7266..12cc4ed93 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -294,6 +294,237 @@ bool isFirstBlock(IRInst* inst)
return block->getParent()->getFirstBlock() == block;
}
+bool isPtrUsed(IRInst* ptrInst)
+{
+ for (auto use = ptrInst->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRLoad>(use->getUser()))
+ return true;
+ else if (as<IRCall>(use->getUser())) // TODO: narrow this case to 'inout' parameters only.
+ return true;
+ else if (as<IRPtrTypeBase>(use->getUser()->getDataType()) &&
+ isPtrUsed(use->getUser()))
+ return true;
+ }
+
+ return false;
+}
+
+bool isFieldUsed(IRStructField* fieldInst)
+{
+ auto structKey = fieldInst->getKey();
+ for (auto use = structKey->firstUse; use; use = use->nextUse)
+ {
+ if (as<IRPtrTypeBase>(use->getUser()->getDataType()) &&
+ isPtrUsed(use->getUser()))
+ return true;
+
+ if (as<IRFieldExtract>(use->getUser()))
+ return true;
+ }
+
+ // Check fields that have this field as a sub-field.
+ auto parentType = cast<IRStructType>(fieldInst->getParent());
+
+ if (as<IRModuleInst>(parentType->getParent()))
+ {
+ for (auto use = parentType->firstUse; use; use = use->nextUse)
+ {
+ auto useField = as<IRStructField>(use->getUser());
+ if (useField && isFieldUsed(useField))
+ return true;
+ }
+ }
+ else if (as<IRBlock>(parentType->getParent()))
+ {
+ if (auto genericParentType = as<IRGeneric>(parentType->getParent()))
+ {
+ List<IRSpecialize*> specInsts;
+ for (auto use = genericParentType->firstUse; use; use = use->nextUse)
+ {
+ if (auto specInst = as<IRSpecialize>(use->getUser()))
+ specInsts.add(specInst);
+ }
+
+ for (auto specInst : specInsts)
+ {
+ for (auto use = specInst->firstUse; use; use = use->nextUse)
+ {
+ auto useField = as<IRStructField>(use->getUser());
+ if (useField && isFieldUsed(useField))
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
+bool removeStoresIntoInst(IRInst* ptrInst)
+{
+ bool changed = false;
+
+ List<IRInst*> storesToRemove;
+ for (auto use = ptrInst->firstUse; use; use = use->nextUse)
+ {
+ // If this is a store, remove it.
+ if (auto store = as<IRStore>(use->getUser()))
+ {
+ if (store->getPtr() == ptrInst)
+ storesToRemove.add(store);
+ }
+
+ // If there are any stores into a 'sub-object' of the pointer,
+ // remove them.
+ //
+
+ if (auto subAddr = as<IRFieldAddress>(use->getUser()))
+ changed |= removeStoresIntoInst(subAddr);
+
+ if (auto subIndex = as<IRGetElementPtr>(use->getUser()))
+ changed |= removeStoresIntoInst(subIndex);
+ }
+
+ for (auto store : storesToRemove)
+ {
+ changed = true;
+ store->removeAndDeallocate();
+ }
+
+ return changed;
+}
+
+bool removeStoresIntoField(IRStructField* field)
+{
+ return removeStoresIntoInst(field->getKey());
+}
+
+bool trimMakeStructOperands(IRStructField* field)
+{
+ // TODO: This can be sped up by considering the full set of fields instead
+ // of one at a time.
+
+ bool changed = false;
+ auto structType = cast<IRStructType>(field->getParent());
+
+ UIndex indexInStruct = 0;
+ for (auto _field : structType->getFields())
+ {
+ if (field == _field)
+ break;
+ indexInStruct++;
+ }
+
+ List<IRInst*> workList;
+ for (auto use = structType->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getOp() == kIROp_MakeStruct)
+ {
+ workList.add(use->getUser());
+ }
+ }
+
+ IRBuilder builder(field->getModule());
+
+ for (auto makeStruct : workList)
+ {
+ // Make a replacement list of operands.
+ List<IRInst*> newOperands;
+ for (UInt index = 0; index < makeStruct->getOperandCount(); ++index)
+ {
+ if (index == indexInStruct)
+ {
+ // skip..
+ changed = true;
+ continue;
+ }
+ else
+ {
+ newOperands.add(makeStruct->getOperand(index));
+ }
+ }
+
+ builder.setInsertAfter(makeStruct);
+ auto newMakeStruct = builder.emitMakeStruct(makeStruct->getFullType(), newOperands);
+ makeStruct->replaceUsesWith(newMakeStruct);
+ }
+
+ for (auto makeStruct : workList)
+ {
+ makeStruct->removeAndDeallocate();
+ }
+
+ return changed;
+}
+
+bool isStructEmpty(IRType* type)
+{
+ auto structType = as<IRStructType>(type);
+ if (!structType)
+ return false;
+
+ UCount nonEmptyFieldCount = 0;
+ for (auto field : structType->getFields())
+ {
+ if (as<IRVoidType>(field->getFieldType()))
+ continue;
+ if (isStructEmpty(field->getFieldType()))
+ continue;
+ nonEmptyFieldCount++;
+ }
+
+ return nonEmptyFieldCount == 0;
+}
+
+bool trimOptimizableType(IRStructType* type)
+{
+ bool changed = false;
+ List<IRStructField*> fieldsToRemove;
+ for (auto field : type->getFields())
+ {
+ // We'll ignore void-type fields, since they're handled differently.
+ if (as<IRVoidType>(field->getFieldType()))
+ continue;
+
+ // ... same for empty struct fields.
+ if(as<IRStructType>(field->getFieldType()) && isStructEmpty(field->getFieldType()))
+ continue;
+
+ if (!isFieldUsed(field))
+ fieldsToRemove.add(field);
+ }
+
+ for (auto field : fieldsToRemove)
+ {
+ changed |= removeStoresIntoField(field);
+ changed |= trimMakeStructOperands(field);
+ field->removeFromParent();
+ }
+
+ for (auto field : fieldsToRemove)
+ {
+ changed = true;
+ field->removeAndDeallocate();
+ }
+
+ return changed;
+}
+
+bool trimOptimizableTypes(IRModule* module)
+{
+ bool changed = false;
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto type = as<IRStructType>(inst))
+ {
+ if (type->findDecoration<IROptimizableTypeDecoration>())
+ changed |= trimOptimizableType(type);
+ }
+ }
+ return changed;
+}
+
bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options)
{
// The main source of confusion/complexity here is that
diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h
index 55eed1c92..a0e76ece5 100644
--- a/source/slang/slang-ir-dce.h
+++ b/source/slang/slang-ir-dce.h
@@ -33,4 +33,6 @@ namespace Slang
bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options);
bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex);
+
+ bool trimOptimizableTypes(IRModule* module);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 7fe521486..f4365cf62 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1069,6 +1069,18 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
/// Mark a call as explicitly calling a differentiable function.
INST(DifferentiableCallDecoration, differentiableCallDecoration, 0, 0)
+ /// Mark a type as being eligible for trimming if necessary. If
+ /// any fields don't have any effective loads from them, they can be
+ /// removed.
+ ///
+ INST(OptimizableTypeDecoration, optimizableTypeDecoration, 0, 0)
+
+ /// Informs the DCE pass to ignore side-effects on this call for
+ /// the purposes of dead code elimination, even if the call does have
+ /// side-effects.
+ ///
+ INST(IgnoreSideEffectsDecoration, ignoreSideEffectsDecoration, 0, 0)
+
/// Hint that the result from a call to the decorated function should be stored in backward prop function.
INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index ba96fedf6..bd86f14a4 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1119,6 +1119,32 @@ struct IRDifferentiableCallDecoration : IRDecoration
IR_LEAF_ISA(DifferentiableCallDecoration)
};
+// Mark a type as being eligible for trimming if necessary. If
+// any fields don't have any effective loads from them, they can be
+// removed.
+//
+struct IROptimizableTypeDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_OptimizableTypeDecoration
+ };
+ IR_LEAF_ISA(OptimizableTypeDecoration)
+};
+
+// Informs the DCE pass to ignore side-effects on this call for
+// the purposes of dead code elimination, even if the call does have
+// side-effects.
+//
+struct IRIgnoreSideEffectsDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_IgnoreSideEffectsDecoration
+ };
+ IR_LEAF_ISA(IgnoreSideEffectsDecoration)
+};
+
// Treat a call to a non-differentiable function as a differentiable call.
struct IRTreatCallAsDifferentiableDecoration : IRDecoration
{
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 82287f58e..9662dc522 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -331,6 +331,12 @@ case kIROp_##TYPE##Type: \
case kIROp_DefaultBufferLayoutType:
*outSizeAndAlignment = IRSizeAndAlignment(0, 4);
return SLANG_OK;
+ case kIROp_AttributedType:
+ {
+ auto attributedType = cast<IRAttributedType>(type);
+ SLANG_ASSERT(attributedType->getAttr()->getOp() == kIROp_NoDiffAttr);
+ return getSizeAndAlignment(optionSet, rules, attributedType->getBaseType(), outSizeAndAlignment);
+ }
default:
break;
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index cd0f67186..a71b2c86c 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -64,6 +64,7 @@ namespace Slang
changed |= removeUnusedGenericParam(module);
changed |= applySparseConditionalConstantPropagationForGlobalScope(module, sink);
changed |= peepholeOptimizeGlobalScope(target, module);
+ changed |= trimOptimizableTypes(module);
for (auto inst : module->getGlobalInsts())
{
@@ -74,6 +75,8 @@ namespace Slang
int funcIterationCount = 0;
while (funcChanged && funcIterationCount < kMaxFuncIterations)
{
+
+ eliminateDeadCode(func, options.deadCodeElimOptions);
funcChanged = false;
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(target, func);
@@ -83,6 +86,8 @@ namespace Slang
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never actually use them.
// DCE will always remove those nearly generated consts and always returns true here.
+ // Run eliminate-dead-code twice to ensure optimizations are applied on the dce'd code.
+ //
eliminateDeadCode(func, options.deadCodeElimOptions);
if (funcIterationCount == 0)
funcChanged |= constructSSA(func);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index f81cde30b..b0eeca4dd 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -971,12 +971,41 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst)
}
}
+IRInst* tryFindBasePtr(IRInst* inst, IRInst* parentFunc)
+{
+ // Keep going up the tree until we find a variable.
+ switch (inst->getOp())
+ {
+ case kIROp_Var:
+ return getParentFunc(inst) == parentFunc ? inst : nullptr;
+ case kIROp_Param:
+ return getParentFunc(inst) == parentFunc ? inst : nullptr;
+ case kIROp_GetElementPtr:
+ return tryFindBasePtr(as<IRGetElementPtr>(inst)->getBase(), parentFunc);
+ case kIROp_FieldAddress:
+ return tryFindBasePtr(as<IRFieldAddress>(inst)->getBase(), parentFunc);
+ default:
+ return nullptr;
+ }
+}
+
bool areCallArgumentsSideEffectFree(IRCall* call, SideEffectAnalysisOptions options)
{
// If the function has no side effect and is not writing to any outputs,
// we can safely treat the call as a normal inst.
+
IRFunc* parentFunc = nullptr;
- for (UInt i = 0; i < call->getArgCount(); i++)
+
+ IRParam* param = nullptr;
+ if (auto calleeFunc = getResolvedInstForDecorations(call->getCallee()))
+ {
+ if (auto block = calleeFunc->getFirstBlock())
+ {
+ param = block->getFirstParam();
+ }
+ }
+
+ for (UInt i = 0; i < call->getArgCount(); i++, (param = param ? param->getNextParam() : nullptr))
{
auto arg = call->getArg(i);
if (isValueType(arg->getDataType()))
@@ -1074,6 +1103,9 @@ bool areCallArgumentsSideEffectFree(IRCall* call, SideEffectAnalysisOptions opti
}
else
{
+ if (param && param->findDecoration<IRIgnoreSideEffectsDecoration>())
+ continue;
+
return false;
}
}
@@ -1107,6 +1139,7 @@ bool doesCalleeHaveSideEffect(IRInst* callee)
{
case kIROp_NoSideEffectDecoration:
case kIROp_ReadNoneDecoration:
+ case kIROp_IgnoreSideEffectsDecoration:
return false;
}
}
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 56e6e1676..bd1a212aa 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -8494,7 +8494,7 @@ namespace Slang
// common subexpression elimination, etc.
//
auto call = cast<IRCall>(this);
- return !isSideEffectFreeFunctionalCall(call, options);
+ return !(isSideEffectFreeFunctionalCall(call, options));
}
break;
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ecc06ebfd..ce7ab9ac6 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9898,6 +9898,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
addVarDecorations(context, irParam, paramDecl);
subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
+ irParam->sourceLoc = paramDecl->loc;
}
addParamNameHint(irParam, paramInfo);
@@ -9929,6 +9930,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
addVarDecorations(context, irParam, paramDecl);
subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
+ irParam->sourceLoc = paramDecl->loc;
}
addParamNameHint(irParam, paramInfo);
paramVal = LoweredValInfo::simple(irParam);
diff --git a/tests/autodiff/high-order-backward-diff-3.slang b/tests/autodiff/high-order-backward-diff-3.slang
index 100a9a1e0..1df6415fe 100644
--- a/tests/autodiff/high-order-backward-diff-3.slang
+++ b/tests/autodiff/high-order-backward-diff-3.slang
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/high-order-backward-diff-4.slang b/tests/autodiff/high-order-backward-diff-4.slang
index 9ee9aa4c4..e1392c05f 100644
--- a/tests/autodiff/high-order-backward-diff-4.slang
+++ b/tests/autodiff/high-order-backward-diff-4.slang
@@ -1,6 +1,8 @@
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/path-tracer/pt-loop.slang b/tests/autodiff/path-tracer/pt-loop.slang
index ac8bf763d..85e1825ab 100644
--- a/tests/autodiff/path-tracer/pt-loop.slang
+++ b/tests/autodiff/path-tracer/pt-loop.slang
@@ -1,7 +1,7 @@
//Tests automatic synthesis of Differential type requirement.
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -loop-inversion
-//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -Xslang -loop-inversion
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -dx12
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
index 0b6e56f78..72f112d4c 100644
--- a/tests/autodiff/reverse-continue-loop.slang
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -1,7 +1,7 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -dump-intermediates
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
+//TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -9,14 +9,14 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
-//CHK: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue'
+//CHK-DAG: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue'
[BackwardDifferentiable]
float test_loop_with_continue(float y)
{
- //CHK: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
+ //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
- //CHK: note: 4 bytes (int32_t) used for a loop counter here:
+ //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here:
for (int i = 0; i < 3; i++)
{
if (t > 4.0)
diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang
index b4fa68e3a..c8c09d44f 100644
--- a/tests/autodiff/reverse-control-flow-3.slang
+++ b/tests/autodiff/reverse-control-flow-3.slang
@@ -1,5 +1,5 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
+//DISABLE_TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang
index 53a089b21..19316a786 100644
--- a/tests/autodiff/reverse-loop-checkpoint-test.slang
+++ b/tests/autodiff/reverse-loop-checkpoint-test.slang
@@ -1,7 +1,8 @@
-//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
-//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
+//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang
index 2ba8535be..18b672860 100644
--- a/tests/autodiff/reverse-loop.slang
+++ b/tests/autodiff/reverse-loop.slang
@@ -1,7 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
+//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang
index 3c1a52c21..1b59cc75d 100644
--- a/tests/autodiff/reverse-nested-calls.slang
+++ b/tests/autodiff/reverse-nested-calls.slang
@@ -1,7 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
+//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -16,11 +16,9 @@ float g(float y)
return result * result;
}
-//CHK: note: checkpointing context of 4 bytes associated with function: 'f'
[BackwardDifferentiable]
float f(float x)
{
- //CHK: note: 4 bytes (float) used to checkpoint the following item:
return 3.0f * g(2.0f * x);
}
diff --git a/tests/autodiff/test-minimal-context.slang b/tests/autodiff/test-minimal-context.slang
new file mode 100644
index 000000000..c2a2b87ed
--- /dev/null
+++ b/tests/autodiff/test-minimal-context.slang
@@ -0,0 +1,76 @@
+//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+//DISABLE_TEST:SIMPLE(filecheck=CTX):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDerivative(bwd_load)]
+float load(uint idx)
+{
+ return outputBuffer[idx];
+}
+
+void bwd_load(uint idx, float dOut)
+{
+ outputBuffer[idx + 2] += dOut;
+}
+
+[BackwardDerivative(bwd_store)]
+void store(uint idx, float a)
+{
+ outputBuffer[idx] = a;
+}
+
+[ForceInline]
+float inner_bwd_store(uint idx)
+{
+ return outputBuffer[idx + 2];
+}
+
+[ForceInline]
+void bwd_store(uint idx, inout DifferentialPair<float> a)
+{
+ a = diffPair(a.p, inner_bwd_store(idx));
+}
+
+[BackwardDerivative(bwd_g)]
+float g(float x)
+{
+ return load(1) * load(1);
+}
+
+void bwd_g(inout DifferentialPair<float> x, float dOut)
+{
+ float y = load(1);
+ x = diffPair(x.p + 2 * y, x.d + 2 * y * dOut);
+ store(0, x.d);
+}
+
+[BackwardDifferentiable]
+float f(int p, float x)
+{
+ float y = g(x);
+
+ store(0, y);
+
+ return 0;
+}
+
+// Check that there are no calls to primal_ctx_f in bwd_f.
+
+// CHECK: void s_bwd_f_{{[0-9]+}}
+// CHECK-NOT: s_primal_ctx_f_{{[0-9]+}}
+// CHECK: return
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ dpfloat dpa = dpfloat(2.0, 0.0);
+
+ bwd_diff(f)(0, dpa, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 1
+}
+// CTX: note: \ No newline at end of file