diff options
26 files changed, 769 insertions, 171 deletions
diff --git a/source/core/slang-dictionary.h b/source/core/slang-dictionary.h index 9c445c3c9..9c753ffcd 100644 --- a/source/core/slang-dictionary.h +++ b/source/core/slang-dictionary.h @@ -9,7 +9,7 @@ #include "slang-math.h" #include "slang-hash.h" -#include <ankerl/unordered_dense.h> +#include "../../external/unordered_dense/include/ankerl/unordered_dense.h" #include <initializer_list> diff --git a/source/core/slang-hash.h b/source/core/slang-hash.h index 3e255f7db..ff0cdc181 100644 --- a/source/core/slang-hash.h +++ b/source/core/slang-hash.h @@ -4,7 +4,7 @@ #include "../../include/slang.h" #include "slang-math.h" -#include <ankerl/unordered_dense.h> +#include "../../external/unordered_dense/include/ankerl/unordered_dense.h" #include <cstring> #include <type_traits> diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index b39d91494..6042ff5cc 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -855,17 +855,20 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T>(x))); } + [ForceInline] __generic<let N : int> DifferentialPair<T> __load_forward(vector<uint, N> x) { return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.load_forward<T, N>(x))); } + [ForceInline] void __load_backward(uint x, T.Differential dOut) { diff.load_backward<T>(x, reinterpret<T, T.Differential>(dOut)); } + [ForceInline] __generic<let N : int> void __load_backward(vector<uint, N> x, T.Differential dOut) { @@ -894,11 +897,13 @@ struct DiffTensorView diff.store_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d)); } + [ForceInline] void __store_backward(uint x, inout DifferentialPair<T> dpval) { dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.store_backward<T>(x))); } + [ForceInline] __generic<let N : int> void __store_backward(vector<uint, N> x, inout DifferentialPair<T> dpval) { @@ -999,11 +1004,13 @@ struct DiffTensorView return diffPair(primal.load(x), reinterpret<T.Differential, T>(diff.loadOnce_forward<T, N>(x))); } + [ForceInline] void __loadOnce_backward(uint x, T.Differential dOut) { diff.loadOnce_backward<T>(x, reinterpret<T, T.Differential>(dOut)); } + [ForceInline] __generic<let N : int> void __loadOnce_backward(vector<uint, N> x, T.Differential dOut) { @@ -1032,11 +1039,13 @@ struct DiffTensorView diff.storeOnce_forward<T, N>(x, reinterpret<T, T.Differential>(dpval.d)); } + [ForceInline] void __storeOnce_backward(uint x, inout DifferentialPair<T> dpval) { dpval = diffPair(dpval.p, reinterpret<T.Differential, T>(diff.storeOnce_backward<T>(x))); } + [ForceInline] __generic<let N : int> void __storeOnce_backward(vector<uint, N> x, inout DifferentialPair<T> dpval) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 2881abe3e..fcc1f95ee 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -359,135 +359,155 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc( while (workList.getCount() > 0) { - auto use = workList.getLast(); - workList.removeLast(); + while (workList.getCount() > 0) + { + auto use = workList.getLast(); + workList.removeLast(); - if (processedUses.contains(use)) - continue; + if (processedUses.contains(use)) + continue; - processedUses.add(use); + processedUses.add(use); - HoistResult result = this->classify(use); + HoistResult result = this->classify(use); - if (result.mode == HoistResult::Mode::Store) - { - SLANG_ASSERT(!checkpointInfo->recomputeSet.contains(result.instToStore)); - checkpointInfo->storeSet.add(result.instToStore); - } - else if (result.mode == HoistResult::Mode::Recompute) - { - SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute)); - checkpointInfo->recomputeSet.add(result.instToRecompute); + if (result.mode == HoistResult::Mode::Store) + { + SLANG_ASSERT(!checkpointInfo->recomputeSet.contains(result.instToStore)); + checkpointInfo->storeSet.add(result.instToStore); + } + else if (result.mode == HoistResult::Mode::Recompute) + { + SLANG_ASSERT(!checkpointInfo->storeSet.contains(result.instToRecompute)); + checkpointInfo->recomputeSet.add(result.instToRecompute); - if (isDifferentialInst(use.user) && use.irUse) - usesToReplace.add(use.irUse); + if (isDifferentialInst(use.user) && use.irUse) + usesToReplace.add(use.irUse); - if (auto param = as<IRParam>(result.instToRecompute)) - { - if (auto inductionInfo = inductionValueInsts.tryGetValue(param)) + if (auto param = as<IRParam>(result.instToRecompute)) { - checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo); - continue; - } + if (auto inductionInfo = inductionValueInsts.tryGetValue(param)) + { + checkpointInfo->loopInductionInfo.addIfNotExists(param, *inductionInfo); + continue; + } - // Add in the branch-args of every predecessor block. - auto paramBlock = as<IRBlock>(param->getParent()); - UIndex paramIndex = 0; - for (auto _param : paramBlock->getParams()) - { - if (_param == param) break; - paramIndex ++; - } + // Add in the branch-args of every predecessor block. + auto paramBlock = as<IRBlock>(param->getParent()); + UIndex paramIndex = 0; + for (auto _param : paramBlock->getParams()) + { + if (_param == param) break; + paramIndex ++; + } - for (auto predecessor : paramBlock->getPredecessors()) - { - // If we hit this, the checkpoint policy is trying to recompute - // values across a loop region boundary (we don't currently support this, - // and in general this is quite inefficient in both compute & memory) - // - SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor)); + for (auto predecessor : paramBlock->getPredecessors()) + { + // If we hit this, the checkpoint policy is trying to recompute + // values across a loop region boundary (we don't currently support this, + // and in general this is quite inefficient in both compute & memory) + // + SLANG_RELEASE_ASSERT(!domTree->dominates(paramBlock, predecessor)); - auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator()); - SLANG_ASSERT(branchInst->getOperandCount() > paramIndex); + auto branchInst = as<IRUnconditionalBranch>(predecessor->getTerminator()); + SLANG_ASSERT(branchInst->getOperandCount() > paramIndex); - workList.add(&branchInst->getArgs()[paramIndex]); + workList.add(&branchInst->getArgs()[paramIndex]); + } } - } - else - { - if (auto var = as<IRVar>(result.instToRecompute)) + else { - for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse) + if (auto var = as<IRVar>(result.instToRecompute)) { - switch (varUse->getUser()->getOp()) + for (auto varUse = var->firstUse; varUse; varUse = varUse->nextUse) { - case kIROp_Store: - case kIROp_Call: - // When we have a var and a store/call insts that writes to the var, - // we treat as if there is a pseudo-use of the store/call to compute - // the var inst, i.e. the var depends on the store/call, despite - // the IR's def-use chain doesn't reflect this. - workList.add(UseOrPseudoUse(var, varUse->getUser())); - break; + switch (varUse->getUser()->getOp()) + { + case kIROp_Store: + case kIROp_Call: + // When we have a var and a store/call insts that writes to the var, + // we treat as if there is a pseudo-use of the store/call to compute + // the var inst, i.e. the var depends on the store/call, despite + // the IR's def-use chain doesn't reflect this. + workList.add(UseOrPseudoUse(var, varUse->getUser())); + break; + } } } - } - else - { - addPrimalOperandsToWorkList(result.instToRecompute); + else + { + addPrimalOperandsToWorkList(result.instToRecompute); + } } } } - } - // If a var or call is in recomputeSet, move any var/calls associated with the same call to - // recomputeSet. - List<IRInst*> instWorkList; - HashSet<IRInst*> instWorkListSet; - for (auto inst : checkpointInfo->recomputeSet) - { - switch (inst->getOp()) + // If a var or call is in recomputeSet, move any var/calls associated with the same call to + // recomputeSet. + // This is a bit of a 'retro-active' analysis where we go back on processed insts and + // correct them. + // + List<IRInst*> callVarWorkList; + HashSet<IRInst*> callVarWorkListSet; + for (auto inst : checkpointInfo->recomputeSet) { - case kIROp_Call: - case kIROp_Var: - instWorkList.add(inst); - instWorkListSet.add(inst); - break; + switch (inst->getOp()) + { + case kIROp_Call: + case kIROp_Var: + callVarWorkList.add(inst); + callVarWorkListSet.add(inst); + break; + } } - } - for (Index i = 0; i < instWorkList.getCount(); i++) - { - auto inst = instWorkList[i]; - if (auto var = as<IRVar>(inst)) + + for (Index i = 0; i < callVarWorkList.getCount(); i++) { - for (auto use = var->firstUse; use; use = use->nextUse) + auto inst = callVarWorkList[i]; + if (auto var = as<IRVar>(inst)) { - if (auto callUser = as<IRCall>(use->getUser())) + for (auto use = var->firstUse; use; use = use->nextUse) { - checkpointInfo->recomputeSet.add(callUser); - checkpointInfo->storeSet.remove(callUser); - if (instWorkListSet.add(callUser)) - instWorkList.add(callUser); - } - else if (auto storeUser = as<IRStore>(use->getUser())) - { - checkpointInfo->recomputeSet.add(storeUser); - checkpointInfo->storeSet.remove(storeUser); - if (instWorkListSet.add(callUser)) - instWorkList.add(callUser); + if (auto callUser = as<IRCall>(use->getUser())) + { + checkpointInfo->recomputeSet.add(callUser); + checkpointInfo->storeSet.remove(callUser); + if (callVarWorkListSet.add(callUser)) + callVarWorkList.add(callUser); + } + else if (auto storeUser = as<IRStore>(use->getUser())) + { + checkpointInfo->recomputeSet.add(storeUser); + checkpointInfo->storeSet.remove(storeUser); + if (callVarWorkListSet.add(callUser)) + callVarWorkList.add(callUser); + } } } - } - else if (auto call = as<IRCall>(inst)) - { - for (UInt j = 0; j < call->getArgCount(); j++) + else if (auto call = as<IRCall>(inst)) { - if (auto varArg = as<IRVar>(call->getArg(j))) + for (UInt j = 0; j < call->getArgCount(); j++) + { + if (auto varArg = as<IRVar>(call->getArg(j))) + { + checkpointInfo->recomputeSet.add(varArg); + checkpointInfo->storeSet.remove(varArg); + if (callVarWorkListSet.add(varArg)) + callVarWorkList.add(varArg); + } + } + + // This next few lines are a bit of a hack.. ideally we need to add the call to the main worklist + // for processing, so we don't have to repeat the recomputationn actions. + // + auto calleeUse = &call->getOperands()[0]; + if (!as<IRModuleInst>(calleeUse->get()->getParent()) && !processedUses.contains(calleeUse)) + addPrimalOperandsToWorkList(call); + + for (auto use = call->firstUse; use; use = use->nextUse) { - checkpointInfo->recomputeSet.add(varArg); - checkpointInfo->storeSet.remove(varArg); - if (instWorkListSet.add(varArg)) - instWorkList.add(varArg); + if (isDifferentialInst(use->getUser())) + usesToReplace.add(use); } } } @@ -1048,14 +1068,6 @@ void applyCheckpointSet( // primal logic may have already updated the param with a new value, and instead we // want the original value. builder.setInsertBefore(recomputeInsertBeforeInst); - if (auto load = as<IRLoad>(child)) - { - if (load->getPtr()->getOp() == kIROp_Param && - load->getPtr()->getParent() == func->getFirstBlock()) - { - builder.setInsertBefore(getParamPreludeBlock(func)->getTerminator()); - } - } applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, child); } } @@ -1331,15 +1343,17 @@ struct UseChain IRBuilder chainBuilder(builder->getModule()); setInsertAfterOrdinaryInst(&chainBuilder, inst); - + + chain.reverse(); + // Clone the rest of the chain. for (auto& use : chain) { - lastInstInChain = cloneInst(&env, &chainBuilder, use->getUser()); + lastInstInChain = cloneInst(&env, &chainBuilder, use->get()); } // Replace the base use. - builder->replaceOperand(baseUse, lastInstInChain); + builder->replaceOperand(chain.getLast(), lastInstInChain); chain.clear(); } @@ -1347,7 +1361,7 @@ struct UseChain IRInst* getUser() const { SLANG_ASSERT(chain.getCount() > 0); - return chain.getLast()->getUser(); + return chain.getFirst()->getUser(); } }; @@ -1385,11 +1399,14 @@ static List<IndexTrackingInfo> maybeTrimIndices( bool canInstBeStored(IRInst* inst) { - // Cannot store insts whose value is a type or a witness table. + // Cannot store insts whose value is a type or a witness table, or a function. // These insts get lowered to target-specific logic, and cannot be // stored into variables or context structs as normal values. // - if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType())) + if (as<IRTypeType>(inst->getDataType()) || + as<IRWitnessTableType>(inst->getDataType()) || + as<IRTypeKind>(inst->getDataType()) || + as<IRFuncType>(inst->getDataType())) return false; return true; @@ -1499,18 +1516,15 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( SLANG_RELEASE_ASSERT(defBlock); - List<UseChain> outOfScopeUses; - for (auto use = instToStore->firstUse; use;) + // Lambda to check if a use is relevant. + auto isRelevantUse = [&](IRUse* use) { - auto nextUse = use->nextUse; - - // Lambda to check if a use is relevant. - auto isRelevantUse = [&](IRUse* use) + // Only consider uses in differential blocks. + // This method is not responsible for other blocks. + // + IRBlock* userBlock = getBlock(use->getUser()); + if (isRecomputeInst) { - // Only consider uses in differential blocks. - // This method is not responsible for other blocks. - // - IRBlock* userBlock = getBlock(use->getUser()); if (isDifferentialOrRecomputeBlock(userBlock)) { if (!domTree->dominates(defBlock, userBlock)) @@ -1532,16 +1546,37 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( return true; } } - return false; - }; + } + else + { + if (isDifferentialOrRecomputeBlock(userBlock)) + return true; + } + return false; + }; - // Lambda to check if an inst is transparent. We lookup uses 'through' transparent - // insts recursively. - // - auto isPassthroughInst = [&](IRInst* inst) + // Lambda to check if an inst is transparent. We lookup uses 'through' transparent + // insts recursively. + // + auto isPassthroughInst = [&](IRInst* inst) + { + if (!canInstBeStored(inst)) + return true; + + switch (inst->getOp()) { - return !canInstBeStored(inst); - }; + case kIROp_GetSequentialID: + case kIROp_ExtractExistentialValue: + return true; + } + + return false; + }; + + List<UseChain> outOfScopeUses; + for (auto use = instToStore->firstUse; use;) + { + auto nextUse = use->nextUse; List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst); outOfScopeUses.addRange(useChains); @@ -1622,7 +1657,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( if (!isRecomputeInst) processedStoreSet.add(localVar); } - else if (!canInstBeStored(instToStore)) + else if (isPassthroughInst(instToStore)) { // We won't actually process these insts here. Instead we'll // simply make sure that their operands are either already present @@ -1683,12 +1718,52 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( if (!isRecomputeInst) processedStoreSet.add(localVar); } - + + // Put the inst back on the worklist since there's a possibility that we created more uses + // for it in the process. + // + //workList.add(instToStore); seenInstSet.add(instToStore); } }; + // Pull any loop counter in the store set out to another list. + // + Dictionary<UIndex, OrderedHashSet<IRInst*>> loopCounters; + { + List<IRInst*> loopCounterInsts; + for (auto inst : hoistInfo->storeSet) + { + if (inst->findDecoration<IRLoopCounterDecoration>()) + { + auto block = cast<IRBlock>(inst->getParent()); + auto nestDepth = indexedBlockInfo.getValue(block).getCount() - 1; + + if (!loopCounters.containsKey(nestDepth)) + loopCounters[nestDepth] = OrderedHashSet<IRInst*>(); + + loopCounters[nestDepth].add(inst); + loopCounterInsts.add(inst); + } + } + + for (auto inst : loopCounterInsts) + hoistInfo->storeSet.remove(inst); + } + + // First handle all non-loop-counter insts. ensureInstAvailable(hoistInfo->storeSet, false); + + // Then handle the loop counter insts in reverse-order of nest depth + // This ordering is important because loop counters at level N _may_ depend on + // the counters at the previous levels. + // + for (Index ii = (Index)loopCounters.getCount() - 1; ii >= 0; --ii) + { + ensureInstAvailable(loopCounters[(UIndex)ii], false); + } + + // Next handle all recompute insts, from within ensureInstAvailable(hoistInfo->recomputeSet, true); // Replace the old store set with the processed one. @@ -2080,12 +2155,27 @@ static bool shouldStoreInst(IRInst* inst) return false; case kIROp_Call: - // If the callee prefers recompute policy, don't store. + { + // If the callee has a preference, we should follow it. if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferRecompute) { return false; } + else if (getCheckpointPreference(inst->getOperand(0)) == CheckpointPreference::PreferCheckpoint) + { + return true; + } + + // If not, we'll default to recomputing calls that don't have side effects & don't + // load from non-local variables. A previous data-flow pass should have already tagged functions + // with the appropriate decorations. + // + auto callee = getResolvedInstForDecorations(inst->getOperand(0), true); + if (callee->findDecoration<IRReadNoneDecoration>()) + return false; + break; + } default: break; } @@ -2123,8 +2213,8 @@ static bool shouldStoreVar(IRVar* var) // of the var will be the same as the decision for the call. return shouldStoreInst(callUser); } - // Default behavior is to store if we can. - return true; + // Default behavior is to recompute stuff. + return false; } // If the var has never been written to, don't store it. return false; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 169dd31ee..87a5d2281 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -70,6 +70,7 @@ namespace Slang // Don't need to do anything other than add a decoration in the original func to point to the primal func. // The body of the primal func will be generated by propagateTranscriber together with propagate func. addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); + builder->addDecoration(diffFunc, kIROp_IgnoreSideEffectsDecoration); return InstPair(primalFunc, diffFunc); } @@ -116,6 +117,9 @@ namespace Slang builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params); builder->emitReturn(); + // Copy other decorations from the original func to the generated primal func wrapper. + copyOriginalDecorations(udf, diffPropFunc); + // Now create the trivial primal function. auto existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>(); if (!existingDecor) @@ -148,6 +152,9 @@ namespace Slang checkpointHint = originalFunc->findDecoration<IRCheckpointHintDecoration>(); if (checkpointHint) cloneCheckpointHint(builder, checkpointHint, cast<IRGlobalValueWithCode>(existingPrimalFunc)); + + // Copy other decorations from the original func to the generated primal func wrapper. + copyOriginalDecorations(udf, existingPrimalFunc); builder->emitBlock(); params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); @@ -193,7 +200,7 @@ namespace Slang addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); if (auto udf = primalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) { - generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf); + generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf); } else { @@ -360,6 +367,7 @@ namespace Slang auto newName = this->getTranscribedFuncName(&builder, origFunc); builder.addNameHintDecoration(diffFunc, newName); } + addTranscribedFuncDecoration(builder, primalFunc, diffFunc); // Transfer checkpoint hint decorations copyCheckpointHints(&builder, origFunc, diffFunc); diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index b65701a7a..428fe088c 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -32,8 +32,8 @@ struct ParameterBlockTransposeInfo // The value with which a primal specific parameter should be replaced in propagate func. OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc; + // The insts added that is specific for propagate functions and should be removed - // from the future primal func. List<IRInst*> propagateFuncSpecificPrimalInsts; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 507a2bf92..28715395a 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -57,7 +57,10 @@ struct ExtractPrimalFuncContext return createGenericIntermediateType(as<IRGeneric>(func)); IRBuilder builder(module); builder.setInsertBefore(func); + auto intermediateType = builder.createStructType(); + + builder.addDecoration(intermediateType, kIROp_OptimizableTypeDecoration); if (auto nameHint = func->findDecoration<IRNameHintDecoration>()) { StringBuilder newName; @@ -65,6 +68,7 @@ struct ExtractPrimalFuncContext builder.addNameHintDecoration( intermediateType, UnownedStringSlice(newName.getBuffer())); } + return intermediateType; } @@ -289,6 +293,32 @@ struct ExtractPrimalFuncContext } }; +bool isIntermediateContextType(IRInst* type) +{ + switch (type->getOp()) + { + case kIROp_BackwardDiffIntermediateContextType: + return true; + case kIROp_AttributedType: + return isIntermediateContextType(as<IRAttributedType>(type)->getBaseType()); + case kIROp_Specialize: + return isIntermediateContextType(as<IRSpecialize>(type)->getBase()); + default: + if (as<IRPtrTypeBase>(type)) + return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType()); + return false; + } +} + +void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func) +{ + for (auto param : func->getParams()) + { + if (!isIntermediateContextType(param->getDataType())) + builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration); + } +} + static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneEnv) { IRInst* newInst = nullptr; @@ -306,6 +336,21 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } } +IRBlock* getFirstRecomputeBlock(IRFunc* func) +{ + // This logic is a bit fragile. + // We shouldn't necessarily make the + // assumption that the order in the list of blocks is related to the + // control-flow order, but it works with the current system. + // + for (auto block : func->getBlocks()) + { + if (block->findDecoration<IRRecomputeBlockDecoration>()) + return block; + } + return nullptr; +} + IRFunc* DiffUnzipPass::extractPrimalFunc( IRFunc* func, IRFunc* originalFunc, @@ -359,13 +404,18 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { auto primalName = String("s_primal_ctx_") + UnownedStringSlice(originalNameHint->getName()); builder.addNameHintDecoration(primalFunc, builder.getStringValue(primalName.getUnownedSlice())); + builder.addDecoration(primalFunc, kIROp_IgnoreSideEffectsDecoration); + + markNonContextParamsAsSideEffectFree(&builder, primalFunc); } // Copy PrimalValueStructKey decorations from primal func. copyPrimalValueStructKeyDecorations(func, subEnv); - - auto paramBlock = func->getFirstBlock(); - auto firstBlock = *(paramBlock->getSuccessors().begin()); + + auto firstRecomputeBlock = getFirstRecomputeBlock(func); + SLANG_ASSERT(firstRecomputeBlock); + + auto firstBlock = firstRecomputeBlock; builder.setInsertBefore(firstBlock->getFirstInst()); auto intermediateVar = func->getLastParam(); @@ -447,6 +497,50 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( inst->removeAndDeallocate(); } + auto paramBlock = func->getFirstBlock(); + auto paramPreludeBlock = paramBlock->getNextBlock(); + + // Remove primal blocks from the propagate func & wire the param block directly to the first + // recompute block. + // + { + auto terminator = cast<IRUnconditionalBranch>(paramPreludeBlock->getTerminator()); + builder.replaceOperand(&(terminator->block), firstRecomputeBlock); + } + + // Erase all primal blocks (except for the param & prelude blocks). + // TODO: Lots of ways to clean this up. + // + List<IRBlock*> blocksToRemove; + for (auto block : func->getBlocks()) + { + if (block != paramBlock && + block != paramPreludeBlock && + !block->findDecoration<IRRecomputeBlockDecoration>() && + !block->findDecoration<IRDifferentialInstDecoration>()) + blocksToRemove.add(block); + } + + // Before erasing the blocks, go through and 're-hoist' any hoistable instructions (such as types) + // Any remaining valid instructions should be automatically moved to the recompute blocks. + // The rest can be removed. + // + List<IRInst*> instsToReHoist; + for (auto block : blocksToRemove) + for (auto inst : block->getChildren()) + if (getIROpInfo(inst->getOp()).flags & kIROpFlag_Hoistable) + instsToReHoist.add(inst); + + for (auto inst : instsToReHoist) + { + inst->removeFromParent(); + addHoistableInst(&builder, inst); + } + + for (auto block : blocksToRemove) + block->removeAndDeallocate(); + + return primalFunc; } } // namespace Slang diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index bed6a68e4..975a6e554 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -6,6 +6,7 @@ #include "slang-ir-single-return.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-validate.h" +#include "slang-ir-inline.h" #include "../core/slang-performance-profiler.h" namespace Slang @@ -1511,6 +1512,7 @@ void stripTempDecorations(IRInst* inst) case kIROp_RecomputeBlockDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_PrimalValueStructKeyDecoration: case kIROp_PrimalElementTypeDecoration: decor->removeAndDeallocate(); @@ -2283,12 +2285,15 @@ struct AutoDiffPass : public InstPassBase } } - // Get rid of block-level decorations that are used to keep track of - // different block types. These don't work well with the IR simplification - // passes since they don't expect decorations in blocks. - // + for (auto diffFunc : autodiffCleanupList) + { + // Get rid of block-level decorations that are used to keep track of + // different block types. These don't work well with the IR simplification + // passes since they don't expect decorations in blocks. + // stripTempDecorations(diffFunc); + } autodiffCleanupList.clear(); @@ -2300,9 +2305,7 @@ struct AutoDiffPass : public InstPassBase break; if (lowerIntermediateContextType(builder)) - { hasChanges = true; - } // We have done transcribing the functions, now it is time to demote all DifferentialPair types // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they @@ -2312,7 +2315,6 @@ struct AutoDiffPass : public InstPassBase hasChanges |= changed; } - return hasChanges; } diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index f414f7266..12cc4ed93 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -294,6 +294,237 @@ bool isFirstBlock(IRInst* inst) return block->getParent()->getFirstBlock() == block; } +bool isPtrUsed(IRInst* ptrInst) +{ + for (auto use = ptrInst->firstUse; use; use = use->nextUse) + { + if (as<IRLoad>(use->getUser())) + return true; + else if (as<IRCall>(use->getUser())) // TODO: narrow this case to 'inout' parameters only. + return true; + else if (as<IRPtrTypeBase>(use->getUser()->getDataType()) && + isPtrUsed(use->getUser())) + return true; + } + + return false; +} + +bool isFieldUsed(IRStructField* fieldInst) +{ + auto structKey = fieldInst->getKey(); + for (auto use = structKey->firstUse; use; use = use->nextUse) + { + if (as<IRPtrTypeBase>(use->getUser()->getDataType()) && + isPtrUsed(use->getUser())) + return true; + + if (as<IRFieldExtract>(use->getUser())) + return true; + } + + // Check fields that have this field as a sub-field. + auto parentType = cast<IRStructType>(fieldInst->getParent()); + + if (as<IRModuleInst>(parentType->getParent())) + { + for (auto use = parentType->firstUse; use; use = use->nextUse) + { + auto useField = as<IRStructField>(use->getUser()); + if (useField && isFieldUsed(useField)) + return true; + } + } + else if (as<IRBlock>(parentType->getParent())) + { + if (auto genericParentType = as<IRGeneric>(parentType->getParent())) + { + List<IRSpecialize*> specInsts; + for (auto use = genericParentType->firstUse; use; use = use->nextUse) + { + if (auto specInst = as<IRSpecialize>(use->getUser())) + specInsts.add(specInst); + } + + for (auto specInst : specInsts) + { + for (auto use = specInst->firstUse; use; use = use->nextUse) + { + auto useField = as<IRStructField>(use->getUser()); + if (useField && isFieldUsed(useField)) + return true; + } + } + } + } + + return false; +} + +bool removeStoresIntoInst(IRInst* ptrInst) +{ + bool changed = false; + + List<IRInst*> storesToRemove; + for (auto use = ptrInst->firstUse; use; use = use->nextUse) + { + // If this is a store, remove it. + if (auto store = as<IRStore>(use->getUser())) + { + if (store->getPtr() == ptrInst) + storesToRemove.add(store); + } + + // If there are any stores into a 'sub-object' of the pointer, + // remove them. + // + + if (auto subAddr = as<IRFieldAddress>(use->getUser())) + changed |= removeStoresIntoInst(subAddr); + + if (auto subIndex = as<IRGetElementPtr>(use->getUser())) + changed |= removeStoresIntoInst(subIndex); + } + + for (auto store : storesToRemove) + { + changed = true; + store->removeAndDeallocate(); + } + + return changed; +} + +bool removeStoresIntoField(IRStructField* field) +{ + return removeStoresIntoInst(field->getKey()); +} + +bool trimMakeStructOperands(IRStructField* field) +{ + // TODO: This can be sped up by considering the full set of fields instead + // of one at a time. + + bool changed = false; + auto structType = cast<IRStructType>(field->getParent()); + + UIndex indexInStruct = 0; + for (auto _field : structType->getFields()) + { + if (field == _field) + break; + indexInStruct++; + } + + List<IRInst*> workList; + for (auto use = structType->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_MakeStruct) + { + workList.add(use->getUser()); + } + } + + IRBuilder builder(field->getModule()); + + for (auto makeStruct : workList) + { + // Make a replacement list of operands. + List<IRInst*> newOperands; + for (UInt index = 0; index < makeStruct->getOperandCount(); ++index) + { + if (index == indexInStruct) + { + // skip.. + changed = true; + continue; + } + else + { + newOperands.add(makeStruct->getOperand(index)); + } + } + + builder.setInsertAfter(makeStruct); + auto newMakeStruct = builder.emitMakeStruct(makeStruct->getFullType(), newOperands); + makeStruct->replaceUsesWith(newMakeStruct); + } + + for (auto makeStruct : workList) + { + makeStruct->removeAndDeallocate(); + } + + return changed; +} + +bool isStructEmpty(IRType* type) +{ + auto structType = as<IRStructType>(type); + if (!structType) + return false; + + UCount nonEmptyFieldCount = 0; + for (auto field : structType->getFields()) + { + if (as<IRVoidType>(field->getFieldType())) + continue; + if (isStructEmpty(field->getFieldType())) + continue; + nonEmptyFieldCount++; + } + + return nonEmptyFieldCount == 0; +} + +bool trimOptimizableType(IRStructType* type) +{ + bool changed = false; + List<IRStructField*> fieldsToRemove; + for (auto field : type->getFields()) + { + // We'll ignore void-type fields, since they're handled differently. + if (as<IRVoidType>(field->getFieldType())) + continue; + + // ... same for empty struct fields. + if(as<IRStructType>(field->getFieldType()) && isStructEmpty(field->getFieldType())) + continue; + + if (!isFieldUsed(field)) + fieldsToRemove.add(field); + } + + for (auto field : fieldsToRemove) + { + changed |= removeStoresIntoField(field); + changed |= trimMakeStructOperands(field); + field->removeFromParent(); + } + + for (auto field : fieldsToRemove) + { + changed = true; + field->removeAndDeallocate(); + } + + return changed; +} + +bool trimOptimizableTypes(IRModule* module) +{ + bool changed = false; + for (auto inst : module->getGlobalInsts()) + { + if (auto type = as<IRStructType>(inst)) + { + if (type->findDecoration<IROptimizableTypeDecoration>()) + changed |= trimOptimizableType(type); + } + } + return changed; +} + bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options) { // The main source of confusion/complexity here is that diff --git a/source/slang/slang-ir-dce.h b/source/slang/slang-ir-dce.h index 55eed1c92..a0e76ece5 100644 --- a/source/slang/slang-ir-dce.h +++ b/source/slang/slang-ir-dce.h @@ -33,4 +33,6 @@ namespace Slang bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions options); bool isWeakReferenceOperand(IRInst* inst, UInt operandIndex); + + bool trimOptimizableTypes(IRModule* module); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 7fe521486..f4365cf62 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1069,6 +1069,18 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Mark a call as explicitly calling a differentiable function. INST(DifferentiableCallDecoration, differentiableCallDecoration, 0, 0) + /// Mark a type as being eligible for trimming if necessary. If + /// any fields don't have any effective loads from them, they can be + /// removed. + /// + INST(OptimizableTypeDecoration, optimizableTypeDecoration, 0, 0) + + /// Informs the DCE pass to ignore side-effects on this call for + /// the purposes of dead code elimination, even if the call does have + /// side-effects. + /// + INST(IgnoreSideEffectsDecoration, ignoreSideEffectsDecoration, 0, 0) + /// Hint that the result from a call to the decorated function should be stored in backward prop function. INST(PreferCheckpointDecoration, PreferCheckpointDecoration, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index ba96fedf6..bd86f14a4 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1119,6 +1119,32 @@ struct IRDifferentiableCallDecoration : IRDecoration IR_LEAF_ISA(DifferentiableCallDecoration) }; +// Mark a type as being eligible for trimming if necessary. If +// any fields don't have any effective loads from them, they can be +// removed. +// +struct IROptimizableTypeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_OptimizableTypeDecoration + }; + IR_LEAF_ISA(OptimizableTypeDecoration) +}; + +// Informs the DCE pass to ignore side-effects on this call for +// the purposes of dead code elimination, even if the call does have +// side-effects. +// +struct IRIgnoreSideEffectsDecoration : IRDecoration +{ + enum + { + kOp = kIROp_IgnoreSideEffectsDecoration + }; + IR_LEAF_ISA(IgnoreSideEffectsDecoration) +}; + // Treat a call to a non-differentiable function as a differentiable call. struct IRTreatCallAsDifferentiableDecoration : IRDecoration { diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 82287f58e..9662dc522 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -331,6 +331,12 @@ case kIROp_##TYPE##Type: \ case kIROp_DefaultBufferLayoutType: *outSizeAndAlignment = IRSizeAndAlignment(0, 4); return SLANG_OK; + case kIROp_AttributedType: + { + auto attributedType = cast<IRAttributedType>(type); + SLANG_ASSERT(attributedType->getAttr()->getOp() == kIROp_NoDiffAttr); + return getSizeAndAlignment(optionSet, rules, attributedType->getBaseType(), outSizeAndAlignment); + } default: break; } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index cd0f67186..a71b2c86c 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -64,6 +64,7 @@ namespace Slang changed |= removeUnusedGenericParam(module); changed |= applySparseConditionalConstantPropagationForGlobalScope(module, sink); changed |= peepholeOptimizeGlobalScope(target, module); + changed |= trimOptimizableTypes(module); for (auto inst : module->getGlobalInsts()) { @@ -74,6 +75,8 @@ namespace Slang int funcIterationCount = 0; while (funcChanged && funcIterationCount < kMaxFuncIterations) { + + eliminateDeadCode(func, options.deadCodeElimOptions); funcChanged = false; funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(target, func); @@ -83,6 +86,8 @@ namespace Slang // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. // DCE will always remove those nearly generated consts and always returns true here. + // Run eliminate-dead-code twice to ensure optimizations are applied on the dce'd code. + // eliminateDeadCode(func, options.deadCodeElimOptions); if (funcIterationCount == 0) funcChanged |= constructSSA(func); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index f81cde30b..b0eeca4dd 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -971,12 +971,41 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) } } +IRInst* tryFindBasePtr(IRInst* inst, IRInst* parentFunc) +{ + // Keep going up the tree until we find a variable. + switch (inst->getOp()) + { + case kIROp_Var: + return getParentFunc(inst) == parentFunc ? inst : nullptr; + case kIROp_Param: + return getParentFunc(inst) == parentFunc ? inst : nullptr; + case kIROp_GetElementPtr: + return tryFindBasePtr(as<IRGetElementPtr>(inst)->getBase(), parentFunc); + case kIROp_FieldAddress: + return tryFindBasePtr(as<IRFieldAddress>(inst)->getBase(), parentFunc); + default: + return nullptr; + } +} + bool areCallArgumentsSideEffectFree(IRCall* call, SideEffectAnalysisOptions options) { // If the function has no side effect and is not writing to any outputs, // we can safely treat the call as a normal inst. + IRFunc* parentFunc = nullptr; - for (UInt i = 0; i < call->getArgCount(); i++) + + IRParam* param = nullptr; + if (auto calleeFunc = getResolvedInstForDecorations(call->getCallee())) + { + if (auto block = calleeFunc->getFirstBlock()) + { + param = block->getFirstParam(); + } + } + + for (UInt i = 0; i < call->getArgCount(); i++, (param = param ? param->getNextParam() : nullptr)) { auto arg = call->getArg(i); if (isValueType(arg->getDataType())) @@ -1074,6 +1103,9 @@ bool areCallArgumentsSideEffectFree(IRCall* call, SideEffectAnalysisOptions opti } else { + if (param && param->findDecoration<IRIgnoreSideEffectsDecoration>()) + continue; + return false; } } @@ -1107,6 +1139,7 @@ bool doesCalleeHaveSideEffect(IRInst* callee) { case kIROp_NoSideEffectDecoration: case kIROp_ReadNoneDecoration: + case kIROp_IgnoreSideEffectsDecoration: return false; } } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 56e6e1676..bd1a212aa 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -8494,7 +8494,7 @@ namespace Slang // common subexpression elimination, etc. // auto call = cast<IRCall>(this); - return !isSideEffectFreeFunctionalCall(call, options); + return !(isSideEffectFreeFunctionalCall(call, options)); } break; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ecc06ebfd..ce7ab9ac6 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9898,6 +9898,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { addVarDecorations(context, irParam, paramDecl); subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + irParam->sourceLoc = paramDecl->loc; } addParamNameHint(irParam, paramInfo); @@ -9929,6 +9930,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { addVarDecorations(context, irParam, paramDecl); subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + irParam->sourceLoc = paramDecl->loc; } addParamNameHint(irParam, paramInfo); paramVal = LoweredValInfo::simple(irParam); diff --git a/tests/autodiff/high-order-backward-diff-3.slang b/tests/autodiff/high-order-backward-diff-3.slang index 100a9a1e0..1df6415fe 100644 --- a/tests/autodiff/high-order-backward-diff-3.slang +++ b/tests/autodiff/high-order-backward-diff-3.slang @@ -1,6 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/high-order-backward-diff-4.slang b/tests/autodiff/high-order-backward-diff-4.slang index 9ee9aa4c4..e1392c05f 100644 --- a/tests/autodiff/high-order-backward-diff-4.slang +++ b/tests/autodiff/high-order-backward-diff-4.slang @@ -1,6 +1,8 @@ //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type + //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/path-tracer/pt-loop.slang b/tests/autodiff/path-tracer/pt-loop.slang index ac8bf763d..85e1825ab 100644 --- a/tests/autodiff/path-tracer/pt-loop.slang +++ b/tests/autodiff/path-tracer/pt-loop.slang @@ -1,7 +1,7 @@ //Tests automatic synthesis of Differential type requirement. -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -loop-inversion -//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -Xslang -loop-inversion +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -dx12 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang index 0b6e56f78..72f112d4c 100644 --- a/tests/autodiff/reverse-continue-loop.slang +++ b/tests/autodiff/reverse-continue-loop.slang @@ -1,7 +1,7 @@ -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type -Xslang -dump-intermediates //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates +//TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; @@ -9,14 +9,14 @@ RWStructuredBuffer<float> outputBuffer; typedef DifferentialPair<float> dpfloat; typedef float.Differential dfloat; -//CHK: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue' +//CHK-DAG: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue' [BackwardDifferentiable] float test_loop_with_continue(float y) { - //CHK: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: + //CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item: float t = y; - //CHK: note: 4 bytes (int32_t) used for a loop counter here: + //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here: for (int i = 0; i < 3; i++) { if (t > 4.0) diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang index b4fa68e3a..c8c09d44f 100644 --- a/tests/autodiff/reverse-control-flow-3.slang +++ b/tests/autodiff/reverse-control-flow-3.slang @@ -1,5 +1,5 @@ //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates +//DISABLE_TEST:SIMPLE(filecheck=CHK):-target hlsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang index 53a089b21..19316a786 100644 --- a/tests/autodiff/reverse-loop-checkpoint-test.slang +++ b/tests/autodiff/reverse-loop-checkpoint-test.slang @@ -1,7 +1,8 @@ -//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none -//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates -//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu +//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang index 2ba8535be..18b672860 100644 --- a/tests/autodiff/reverse-loop.slang +++ b/tests/autodiff/reverse-loop.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates +//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang index 3c1a52c21..1b59cc75d 100644 --- a/tests/autodiff/reverse-nested-calls.slang +++ b/tests/autodiff/reverse-nested-calls.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type //TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type //TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj -//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates +//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; @@ -16,11 +16,9 @@ float g(float y) return result * result; } -//CHK: note: checkpointing context of 4 bytes associated with function: 'f' [BackwardDifferentiable] float f(float x) { - //CHK: note: 4 bytes (float) used to checkpoint the following item: return 3.0f * g(2.0f * x); } diff --git a/tests/autodiff/test-minimal-context.slang b/tests/autodiff/test-minimal-context.slang new file mode 100644 index 000000000..c2a2b87ed --- /dev/null +++ b/tests/autodiff/test-minimal-context.slang @@ -0,0 +1,76 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none +//DISABLE_TEST:SIMPLE(filecheck=CTX):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDerivative(bwd_load)] +float load(uint idx) +{ + return outputBuffer[idx]; +} + +void bwd_load(uint idx, float dOut) +{ + outputBuffer[idx + 2] += dOut; +} + +[BackwardDerivative(bwd_store)] +void store(uint idx, float a) +{ + outputBuffer[idx] = a; +} + +[ForceInline] +float inner_bwd_store(uint idx) +{ + return outputBuffer[idx + 2]; +} + +[ForceInline] +void bwd_store(uint idx, inout DifferentialPair<float> a) +{ + a = diffPair(a.p, inner_bwd_store(idx)); +} + +[BackwardDerivative(bwd_g)] +float g(float x) +{ + return load(1) * load(1); +} + +void bwd_g(inout DifferentialPair<float> x, float dOut) +{ + float y = load(1); + x = diffPair(x.p + 2 * y, x.d + 2 * y * dOut); + store(0, x.d); +} + +[BackwardDifferentiable] +float f(int p, float x) +{ + float y = g(x); + + store(0, y); + + return 0; +} + +// Check that there are no calls to primal_ctx_f in bwd_f. + +// CHECK: void s_bwd_f_{{[0-9]+}} +// CHECK-NOT: s_primal_ctx_f_{{[0-9]+}} +// CHECK: return + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + dpfloat dpa = dpfloat(2.0, 0.0); + + bwd_diff(f)(0, dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 1 +} +// CTX: note:
\ No newline at end of file |
