summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2024-09-18 20:42:07 -0700
committerGitHub <noreply@github.com>2024-09-18 20:42:07 -0700
commitb808aa4df50d46eaa569561f7e464c55c1c2d72a (patch)
tree5483a3f9e73a401ff82d66fd1ac3729a9a84a97c
parent3240799c00488858afc7eeac9d1dc479609a1040 (diff)
Report AD checkpoint contexts (#5058)
* Transferring source locations when creating phi instructions * Tracking for simple variables * Deriving source locations for loop counters * Printing checkpoint structure breakdown * More readable output format * Special behavior for loop counters * Writing report to file * Add slangc option to enable checkpoint reports * Display types of checkpointed fields * Message in case there are no checkpointing contexts * Catch source locations for function calls * Source cleanup * Fix compilation warnings * Remove stray dump() * Provide the report through diagnostic notes * Add missing path for sourceLoc during unzip pass * Add tests for reporting intermediates * Include more transfer cases for source locations * Fix ordering in address elimination * Fill in more holes with source location transfer * Remove debugging line * Reverting changes to diagnostic sink * Simplify address elimination using source location RAII contexts * Eliminating manual source loc transfers in forward transcription * Fix local var adaptation to use RAII location setter * Simplify primal hoisting logic for source location transfer * Simplify unzipping with RAII location scopes * Simplify transpose logic * Cleaning up for rev.cpp * Reverting spacing changes * Fix mistake with source loc RAII instantiation * Fix formatting issues
-rw-r--r--include/slang.h1
-rw-r--r--source/slang-record-replay/util/emum-to-string.h1
-rw-r--r--source/slang/slang-compiler.cpp6
-rwxr-xr-xsource/slang/slang-compiler.h1
-rw-r--r--source/slang/slang-diagnostic-defs.h6
-rw-r--r--source/slang/slang-emit.cpp67
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp18
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp42
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp28
-rw-r--r--source/slang/slang-ir-autodiff-rev.h1
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h19
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp18
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h4
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
-rw-r--r--source/slang/slang-ir-clone.cpp1
-rw-r--r--source/slang/slang-ir-eliminate-phis.cpp3
-rw-r--r--source/slang/slang-ir-init-local-var.cpp3
-rw-r--r--source/slang/slang-ir-inst-defs.h3
-rw-r--r--source/slang/slang-ir-insts.h15
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp1
-rw-r--r--source/slang/slang-ir-ssa.cpp1
-rw-r--r--source/slang/slang-ir.cpp2
-rw-r--r--source/slang/slang-options.cpp2
-rw-r--r--tests/autodiff/reverse-checkpoint-1.slang6
-rw-r--r--tests/autodiff/reverse-checkpoint-2.slang2
-rw-r--r--tests/autodiff/reverse-continue-loop.slang6
-rw-r--r--tests/autodiff/reverse-control-flow-1.slang3
-rw-r--r--tests/autodiff/reverse-control-flow-2.slang3
-rw-r--r--tests/autodiff/reverse-control-flow-3.slang11
-rw-r--r--tests/autodiff/reverse-loop-checkpoint-test.slang8
-rw-r--r--tests/autodiff/reverse-loop.slang6
-rw-r--r--tests/autodiff/reverse-nested-calls.slang5
33 files changed, 264 insertions, 33 deletions
diff --git a/include/slang.h b/include/slang.h
index 3024aa884..3bcdcbba8 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -852,6 +852,7 @@ extern "C"
EmitIr, // bool
ReportDownstreamTime, // bool
ReportPerfBenchmark, // bool
+ ReportCheckpointIntermediates, // bool
SkipSPIRVValidation, // bool
SourceEmbedStyle,
SourceEmbedName,
diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h
index 7226edc04..8c140cf3d 100644
--- a/source/slang-record-replay/util/emum-to-string.h
+++ b/source/slang-record-replay/util/emum-to-string.h
@@ -149,6 +149,7 @@ namespace SlangRecord
CASE(EmitIr);
CASE(ReportDownstreamTime);
CASE(ReportPerfBenchmark);
+ CASE(ReportCheckpointIntermediates);
CASE(SkipSPIRVValidation);
CASE(SourceEmbedStyle);
CASE(SourceEmbedName);
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 541085b4e..c89d94c80 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2451,12 +2451,16 @@ namespace Slang
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr);
}
+ bool CodeGenContext::shouldReportCheckpointIntermediates() // Returns the ReportCheckpointIntermediates boolean option of the target program's option set.
+ {
+ return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates);
+ }
+
bool CodeGenContext::shouldDumpIntermediates()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates);
}
-
bool CodeGenContext::shouldTrackLiveness()
{
return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 0c788ae18..4b20d1f76 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2728,6 +2728,7 @@ namespace Slang
bool shouldValidateIR();
bool shouldDumpIR();
+ bool shouldReportCheckpointIntermediates();
bool shouldTrackLiveness();
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 81170fac3..e0f1e90c5 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -894,6 +894,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B
DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.")
+// Autodiff checkpoint reporting
+DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'")
+DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:")
+DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:")
+DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report")
+
//
// 8xxxx - Issues specific to a particular library/technology/platform/etc.
//
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index cdd2ca5b6..6e3556064 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -34,6 +34,7 @@
#include "slang-ir-wgsl-legalize.h"
#include "slang-ir-insts.h"
#include "slang-ir-inline.h"
+#include "slang-ir-layout.h"
#include "slang-ir-legalize-array-return-type.h"
#include "slang-ir-legalize-mesh-outputs.h"
#include "slang-ir-legalize-varying-params.h"
@@ -214,6 +215,68 @@ static void dumpIRIfEnabled(
}
}
+static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule)
+{
+ // Report checkpointing information: for every global struct in irModule that carries an IRCheckpointIntermediateDecoration and has a non-zero natural size, emit a note with its total size followed by one note per non-empty field. If nothing qualifies, emit a single "none" note.
+ CompilerOptionSet& optionSet = codeGenContext->getTargetProgram()->getOptionSet();
+ SourceManager* sourceManager = sink->getSourceManager();
+ // Scratch writer: only ever holds the text of the field type currently being reported.
+ SourceWriter typeWriter(sourceManager, LineDirectiveMode::None, nullptr);
+
+ CLikeSourceEmitter::Desc description;
+ description.codeGenContext = codeGenContext;
+ description.sourceWriter = &typeWriter;
+ // The C++ emitter is used here only to spell out field types for the notes below.
+ CPPSourceEmitter emitter(description);
+ // Counts the structs that actually produced a report, so the "none" note can be emitted otherwise.
+ int nonEmptyStructs = 0;
+ for (auto inst : irModule->getGlobalInsts())
+ {
+ IRStructType *structType = as<IRStructType>(inst);
+ if (!structType)
+ continue;
+
+ auto checkpointDecoration = structType->findDecoration<IRCheckpointIntermediateDecoration>();
+ if (!checkpointDecoration)
+ continue;
+
+ IRSizeAndAlignment structSize;
+ getNaturalSizeAndAlignment(optionSet, structType, &structSize);
+
+ // Reporting happens before empty structs are optimized out
+ // and we still want to keep the checkpointing decorations,
+ // so we end up needing to check for non-zero-ness
+ if (structSize.size == 0)
+ continue;
+ // Header note for this struct: $0 is the decoration's source function, $1 is the struct size in bytes.
+ auto func = checkpointDecoration->getSourceFunction();
+ sink->diagnose(structType, Diagnostics::reportCheckpointIntermediates, func, structSize.size);
+ nonEmptyStructs++;
+ // Per-field breakdown; zero-sized fields are skipped.
+ for (auto field : structType->getFields())
+ {
+ IRType *fieldType = field->getFieldType();
+ IRSizeAndAlignment fieldSize;
+ getNaturalSizeAndAlignment(optionSet, fieldType, &fieldSize);
+ if (fieldSize.size == 0)
+ continue;
+ // Reuse the scratch writer: discard the previous field's text before emitting this field's type.
+ typeWriter.clearContent();
+ emitter.emitType(fieldType);
+ // Fields tagged with IRLoopCounterDecoration get the loop-counter wording; $0 is the size in bytes, $1 is the emitted type text.
+ sink->diagnose(field->sourceLoc,
+ field->findDecoration<IRLoopCounterDecoration>()
+ ? Diagnostics::reportCheckpointCounter
+ : Diagnostics::reportCheckpointVariable,
+ fieldSize.size,
+ typeWriter.getContent());
+ }
+ }
+ // No struct was reported above: say so explicitly rather than staying silent.
+ if (nonEmptyStructs == 0)
+ sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone);
+}
+
struct LinkingAndOptimizationOptions
{
bool shouldLegalizeExistentialAndResourceTypes = true;
@@ -767,6 +830,10 @@ Result linkAndOptimizeIR(
break;
}
+ // Report checkpointing information
+ if (codeGenContext->shouldReportCheckpointIntermediates())
+ reportCheckpointIntermediates(codeGenContext, sink, irModule);
+
if (requiredLoweringPassSet.autodiff)
finalizeAutoDiffPass(targetProgram, irModule);
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 8a48936d7..b55f6b93d 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -69,30 +69,28 @@ struct AddressInstEliminationContext
}
}
- void transformLoadAddr(IRUse* use)
+ void transformLoadAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto load = as<IRLoad>(use->getUser());
- IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
auto value = getValue(builder, addr);
load->replaceUsesWith(value);
load->removeAndDeallocate();
}
- void transformStoreAddr(IRUse* use)
+ void transformStoreAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto store = as<IRStore>(use->getUser());
- IRBuilder builder(module);
builder.setInsertBefore(use->getUser());
storeValue(builder, addr, store->getVal());
store->removeAndDeallocate();
}
- void transformCallAddr(IRUse* use)
+ void transformCallAddr(IRBuilder& builder, IRUse* use)
{
auto addr = use->get();
auto call = as<IRCall>(use->getUser());
@@ -103,7 +101,6 @@ struct AddressInstEliminationContext
return;
}
- IRBuilder builder(module);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
@@ -155,17 +152,20 @@ struct AddressInstEliminationContext
use = nextUse;
continue;
}
+
+ IRBuilder transformBuilder(module);
+ IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc);
switch (use->getUser()->getOp())
{
case kIROp_Load:
- transformLoadAddr(use);
+ transformLoadAddr(transformBuilder, use);
break;
case kIROp_Store:
- transformStoreAddr(use);
+ transformStoreAddr(transformBuilder, use);
break;
case kIROp_Call:
- transformCallAddr(use);
+ transformCallAddr(transformBuilder, use);
break;
case kIROp_GetElementPtr:
case kIROp_FieldAddress:
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 9fe4ec70b..f51178f0f 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -3,8 +3,9 @@
#include "slang-ir-autodiff-region.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-util.h"
-#include "../core/slang-func-ptr.h"
+#include "slang-ir-insts.h"
#include "slang-ir.h"
+#include "../core/slang-func-ptr.h"
namespace Slang
{
@@ -1092,7 +1093,8 @@ IRType* getTypeForLocalStorage(
IRVar* emitIndexedLocalVar(
IRBlock* varBlock,
IRType* baseType,
- const List<IndexTrackingInfo>& defBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices,
+ SourceLoc location)
{
// Cannot store pointers. Case should have been handled by now.
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
@@ -1101,6 +1103,8 @@ IRVar* emitIndexedLocalVar(
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
IRBuilder varBuilder(varBlock->getModule());
+ IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location);
+
varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst());
IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices);
@@ -1179,9 +1183,14 @@ IRVar* storeIndexedValue(
IRInst* instToStore,
const List<IndexTrackingInfo>& defBlockIndices)
{
- IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
+ IRVar* localVar = emitIndexedLocalVar(defaultVarBlock,
+ instToStore->getDataType(),
+ defBlockIndices,
+ instToStore->sourceLoc);
- IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices);
+ IRInst* addr = emitIndexedStoreAddressForVar(builder,
+ localVar,
+ defBlockIndices);
builder->emitStore(addr, instToStore);
@@ -1574,12 +1583,16 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
// region, that means there's no need to allocate a fully indexed var.
//
defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses);
-
- IRVar* localVar = storeIndexedValue(
- &builder,
- varBlock,
- builder.emitLoad(varToStore),
- defBlockIndices);
+
+ IRVar* localVar = nullptr;
+ {
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc);
+ localVar = storeIndexedValue(
+ &builder,
+ varBlock,
+ builder.emitLoad(varToStore),
+ defBlockIndices);
+ }
for (auto use : outOfScopeUses)
{
@@ -1626,6 +1639,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
}
else
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc);
+
// Handle the special case of loop counters.
// The only case where there will be a reference of primal loop counter from rev blocks
// is the start of a loop in the reverse code. Since loop counters are not considered a
@@ -1643,6 +1658,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
setInsertAfterOrdinaryInst(&builder, instToStore);
auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices);
+ if (isLoopCounter)
+ builder.addLoopCounterDecoration(localVar);
for (auto use : outOfScopeUses)
{
@@ -1728,6 +1745,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop)
void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam)
{
IRBuilder builder(primalLoop);
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc);
+
primalCountParam = nullptr;
// Grab first primal block.
@@ -1899,8 +1918,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
// necessary load/store logic.
//
- primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
- return primalsInfo;
+ return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
}
void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 35a197f29..2fb73c4ac 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -403,8 +403,11 @@ namespace Slang
List<IRType*> primalTypes, propagateTypes;
IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType());
+ IRParam *currentParam = origFunc->getFirstParam();
for (UInt i = 0; i < origFuncType->getParamCount(); i++)
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc);
+
auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
if (propagateParamType)
@@ -453,6 +456,7 @@ namespace Slang
primalArgs.add(var);
}
primalTypes.add(primalParamType);
+ currentParam = currentParam->getNextParam();
}
// Add dOut argument to propagateArgs.
@@ -588,6 +592,8 @@ namespace Slang
autoDiffSharedContext->transcriberSet.forwardTranscriber);
auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));
+ fwdDiffFunc->sourceLoc = primalFunc->sourceLoc;
+
SLANG_ASSERT(fwdDiffFunc);
auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
for (auto i = oldCount; i < newCount; i++)
@@ -712,8 +718,10 @@ namespace Slang
}
// Transpose the first block (parameter block)
- auto paramTransposeInfo =
- splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable);
+ auto paramTransposeInfo = splitAndTransposeParameterBlock(builder,
+ diffPropagateFunc,
+ primalFunc->sourceLoc,
+ isResultDifferentiable);
// The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
// may be used by write back logic that we are going to insert later.
@@ -815,6 +823,7 @@ namespace Slang
ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
IRBuilder* builder,
IRFunc* diffFunc,
+ SourceLoc primalLoc,
bool isResultDifferentiable)
{
// This method splits and transposes all the parameters for both the primal and propagate computation.
@@ -841,6 +850,7 @@ namespace Slang
auto nextBlockBuilder = *builder;
nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());
+ SourceLoc returnLoc;
IRBlock* firstDiffBlock = nullptr;
for (auto block : diffFunc->getBlocks())
{
@@ -849,6 +859,13 @@ namespace Slang
firstDiffBlock = block;
break;
}
+
+ auto terminator = block->getTerminator();
+ if (as<IRReturn>(terminator))
+ {
+ returnLoc = terminator->sourceLoc;
+ break;
+ }
}
SLANG_RELEASE_ASSERT(firstDiffBlock);
@@ -895,6 +912,8 @@ namespace Slang
// from the primal computation logic in the future propagate function be replaced to.
for (auto fwdParam : fwdParams)
{
+ IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc);
+
// Define the replacement insts that we are going to fill in for each case.
IRInst* diffRefReplacement = nullptr;
IRInst* primalRefReplacement = nullptr;
@@ -1186,6 +1205,7 @@ namespace Slang
SLANG_ASSERT(dOutParamType);
dOutParam = builder->emitParam(dOutParamType);
+ dOutParam->sourceLoc = returnLoc;
builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut"));
result.propagateFuncParams.add(dOutParam);
}
@@ -1196,6 +1216,10 @@ namespace Slang
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
+
+ diffFunc->sourceLoc = primalLoc;
+ ctxParam->sourceLoc = primalLoc;
+
return result;
}
diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h
index 68cb4e0c9..b65701a7a 100644
--- a/source/slang/slang-ir-autodiff-rev.h
+++ b/source/slang/slang-ir-autodiff-rev.h
@@ -105,6 +105,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase
ParameterBlockTransposeInfo splitAndTransposeParameterBlock(
IRBuilder* builder,
IRFunc* diffFunc,
+ SourceLoc primalLoc,
bool isResultDifferentiable);
void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index da69ed8ae..1fa76c730 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -1033,8 +1033,9 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
if (as<IRModuleInst>(origInst->getParent()) && !as<IRType>(origInst))
return InstPair(origInst, nullptr);
- auto result = transcribeInstImpl(builder, origInst);
+ IRBuilderSourceLocRAII sourceLocationScope(builder, origInst->sourceLoc);
+ auto result = transcribeInstImpl(builder, origInst);
if (result.primal == nullptr && result.differential == nullptr)
{
if (auto origType = as<IRType>(origInst))
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index d42462e1b..1f8c3052e 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -609,6 +609,8 @@ struct DiffTransposePass
auto nextInst = inst->getNextInst();
if (auto varInst = as<IRVar>(inst))
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, varInst->sourceLoc);
+
if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
{
if (auto ptrPrimalType = as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
@@ -692,7 +694,11 @@ struct DiffTransposePass
SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr);
builder.setInsertInto(lastRevBlock);
- builder.emitReturn();
+
+ {
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, revDiffFunc->sourceLoc);
+ builder.emitReturn();
+ }
// Remove fwd-mode blocks.
for (auto block : workList)
@@ -703,6 +709,8 @@ struct DiffTransposePass
IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst)
{
+ IRBuilderSourceLocRAII sourceLocationScope(builder, fwdInst->sourceLoc);
+
if (auto accVar = getOrCreateAccumulatorVar(fwdInst))
{
auto gradValue = builder->emitLoad(accVar);
@@ -731,6 +739,7 @@ struct DiffTransposePass
return revAccumulatorVarMap[fwdInst];
IRBuilder tempVarBuilder(autodiffContext->moduleInst->getModule());
+ IRBuilderSourceLocRAII sourceLocationSCope(&tempVarBuilder, fwdInst->sourceLoc);
IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())];
@@ -785,6 +794,8 @@ struct DiffTransposePass
for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++)
{
auto arg = branchInst->getArg(ii);
+
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, arg->sourceLoc);
if (isDifferentialInst(arg))
{
// If the arg is a differential, emit a parameter
@@ -885,6 +896,8 @@ struct DiffTransposePass
List<IRInst*> phiParamRevGradInsts;
for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam())
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, param->sourceLoc);
+
if (isDifferentialInst(param))
{
// This param might be used outside this block.
@@ -949,6 +962,8 @@ struct DiffTransposePass
if (auto accVar = getOrCreateAccumulatorVar(externInst))
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, externInst->sourceLoc);
+
// Accumulate all gradients, including our accumulator variable,
// into one inst.
//
@@ -1050,6 +1065,7 @@ struct DiffTransposePass
// Emit the aggregate of all the gradients here.
// This will form the total derivative for this inst.
+ IRBuilderSourceLocRAII sourceLocationScope(builder, inst->sourceLoc);
auto revValue = emitAggregateValue(builder, primalType, gradients);
auto transposeResult = transposeInst(builder, inst, revValue);
@@ -2738,7 +2754,6 @@ struct DiffTransposePass
gradient.revGradInst,
gradient.fwdGradInst
));
-
}
for (auto pair : bucketedGradients)
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 9b3e3a324..0953c535a 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -75,6 +75,9 @@ struct ExtractPrimalFuncContext
builder.setInsertBefore(destFunc);
IRFuncType* originalFuncType = nullptr;
outIntermediateType = createIntermediateType(destFunc);
+
+ builder.addCheckpointIntermediateDecoration(outIntermediateType, originalFunc);
+ outIntermediateType->sourceLoc = originalFunc->sourceLoc;
GenericChildrenMigrationContext migrationContext;
migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)), destFunc);
@@ -154,6 +157,7 @@ struct ExtractPrimalFuncContext
IRInst* intermediateOutput)
{
auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput);
+ field->sourceLoc = inst->sourceLoc;
auto key = field->getKey();
if (auto nameHint = inst->findDecoration<IRNameHintDecoration>())
cloneDecoration(nameHint, key);
@@ -219,6 +223,10 @@ struct ExtractPrimalFuncContext
if (inst->hasUses())
{
auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary);
+ field->sourceLoc = inst->sourceLoc;
+ if (inst->findDecoration<IRLoopCounterDecoration>())
+ builder.addLoopCounterDecoration(field);
+
builder.setInsertBefore(inst);
auto fieldAddr = builder.emitFieldAddress(
inst->getFullType(), outIntermediary, field->getKey());
@@ -379,12 +387,16 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
use->set(builder.getVoidValue());
continue;
}
+
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, use->getUser()->sourceLoc);
+
builder.setInsertBefore(use->getUser());
auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType();
auto val = builder.emitFieldExtract(
valType,
intermediateVar,
structKeyDecor->getStructKey());
+
if (use->getUser()->getOp() == kIROp_Load)
{
use->getUser()->replaceUsesWith(val);
@@ -392,8 +404,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
}
else
{
- auto tempVar =
- builder.emitVar(valType);
+ auto tempVar = builder.emitVar(valType);
builder.emitStore(tempVar, val);
use->set(tempVar);
}
@@ -401,7 +412,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
}
else
{
- // Orindary value.
+ // Ordinary value.
// We insert a fieldExtract at each use site instead of before `inst`,
// since at this stage of autodiff pass, `inst` does not necessarily
// dominate all the use sites if `inst` is defined in partial branch
@@ -417,6 +428,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
inst->getFullType(),
intermediateVar,
structKeyDecor->getStructKey());
+ val->sourceLoc = user->sourceLoc;
builder.replaceOperand(iuse, val);
}
}
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 9f18db6e0..6ae5126f9 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -588,7 +588,6 @@ struct DiffUnzipPass
as<IRBlock>(diffMap[targetBlock]),
diffArgs.getCount(),
diffArgs.getBuffer()));
-
}
case kIROp_conditionalBranch:
@@ -710,6 +709,9 @@ struct DiffUnzipPass
void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst)
{
+ IRBuilderSourceLocRAII primalLocationScope(primalBuilder, inst->sourceLoc);
+ IRBuilderSourceLocRAII diffLocationScope(diffBuilder, inst->sourceLoc);
+
auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst);
primalMap[inst] = instPair.primal;
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 0979c097c..07a6a76fb 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -1203,6 +1203,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_UserDefinedBackwardDerivativeDecoration:
case kIROp_IntermediateContextFieldDifferentialTypeDecoration:
+ case kIROp_CheckpointIntermediateDecoration:
decor->removeAndDeallocate();
break;
case kIROp_AutoDiffBuiltinDecoration:
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index e2297bcb2..a8b9b548e 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -220,6 +220,7 @@ static void _cloneInstDecorationsAndChildren(
auto oldType = oldParam->getFullType();
auto newType = (IRType*)findCloneForOperand(env, oldType);
newParam->setFullType(newType);
+ newParam->sourceLoc = oldParam->sourceLoc;
}
}
diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp
index b17fad6ec..0db2fc765 100644
--- a/source/slang/slang-ir-eliminate-phis.cpp
+++ b/source/slang/slang-ir-eliminate-phis.cpp
@@ -462,6 +462,7 @@ struct PhiEliminationContext
// to the temporary that will replace it.
//
param->transferDecorationsTo(temp);
+ temp->sourceLoc = param->sourceLoc;
}
// The other main auxiliary structure is used to track
@@ -550,6 +551,7 @@ struct PhiEliminationContext
auto user = use->getUser();
m_builder.setInsertBefore(user);
auto newVal = m_builder.emitLoad(temp);
+ newVal->sourceLoc = param->sourceLoc;
m_builder.replaceOperand(use, newVal);
}
@@ -938,6 +940,7 @@ struct PhiEliminationContext
newOperands.getCount(),
newOperands.getArrayView().getBuffer());
oldBranch->transferDecorationsTo(newBranch);
+ newBranch->sourceLoc = oldBranch->sourceLoc;
// TODO: We could consider just modifying `branch` in-place by clearing
// the relevant operands for the phi arguments and setting its operand
diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp
index 34a0e5ff4..fa556bc58 100644
--- a/source/slang/slang-ir-init-local-var.cpp
+++ b/source/slang/slang-ir-init-local-var.cpp
@@ -47,6 +47,9 @@ void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func)
breakLabel:;
if (initialized)
continue;
+
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, inst->sourceLoc);
+
builder.setInsertAfter(inst);
builder.emitStore(
inst,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index b526df3a9..301a9c789 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1056,6 +1056,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
/// Hint that the result from a call to the decorated function should be recomputed in backward prop function.
INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0)
+ /// Hint that a struct is used for reverse mode checkpointing
+ INST(CheckpointIntermediateDecoration, CheckpointIntermediateDecoration, 1, 0)
+
INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration)
/// Marks a function whose return value is never dynamic uniform.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 69f129986..37f242e55 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -947,6 +947,16 @@ struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration
IR_LEAF_ISA(PreferCheckpointDecoration)
};
+struct IRCheckpointIntermediateDecoration : IRCheckpointHintDecoration // Decoration hinting that the decorated struct is used for reverse-mode checkpointing.
+{
+ enum
+ {
+ kOp = kIROp_CheckpointIntermediateDecoration
+ };
+ IR_LEAF_ISA(CheckpointIntermediateDecoration) // NOTE(review): the op appears to be declared outside the CheckpointHintDecoration INST_RANGE in slang-ir-inst-defs.h - confirm the IRCheckpointHintDecoration base is intended.
+ // Returns operand 0, presumably the function supplied when the decoration was added (see addCheckpointIntermediateDecoration).
+ IRInst* getSourceFunction() { return getOperand(0); }
+};
struct IRLoopCounterDecoration : IRDecoration
{
@@ -5152,6 +5162,11 @@ public:
{
addDecoration(inst, kIROp_MemoryQualifierSetDecoration, getIntValue(getIntType(), flags));
}
+
+ void addCheckpointIntermediateDecoration(IRInst* inst, IRGlobalValueWithCode *func) // Adds a CheckpointIntermediateDecoration to inst, passing func as the decoration's argument.
+ {
+ addDecoration(inst, kIROp_CheckpointIntermediateDecoration, func);
+ }
};
// Helper to establish the source location that will be used
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 753c930a8..ef0551161 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -526,6 +526,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
// we will now introduce a breakable region for each iteration.
IRBuilder builder(module);
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, loopInst->sourceLoc);
auto targetBlock = loopInst->getTargetBlock();
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index e44c4079b..506e6a335 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -431,6 +431,7 @@ PhiInfo* addPhi(
RefPtr<PhiInfo> phiInfo = new PhiInfo();
context->phiInfos.add(phi, phiInfo);
+ phi->sourceLoc = var->sourceLoc;
phiInfo->phi = phi;
phiInfo->var = var;
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9305d1783..6c7691d13 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3512,6 +3512,7 @@ namespace Slang
auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
this, kIROp_MakeDifferentialPair, type, 2, args);
addInst(inst);
+ inst->sourceLoc = primal->sourceLoc;
return inst;
}
@@ -3524,6 +3525,7 @@ namespace Slang
auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
this, kIROp_MakeDifferentialPairUserCode, type, 2, args);
addInst(inst);
+ inst->sourceLoc = primal->sourceLoc;
return inst;
}
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index c02a00957..b9a12f971 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -339,6 +339,7 @@ void initCommandOptions(CommandOptions& options)
{ OptionKind::InputFilesRemain, "--", nullptr, "Treat the rest of the command line as input files."},
{ OptionKind::ReportDownstreamTime, "-report-downstream-time", nullptr, "Reports the time spent in the downstream compiler." },
{ OptionKind::ReportPerfBenchmark, "-report-perf-benchmark", nullptr, "Reports compiler performance benchmark results." },
+ { OptionKind::ReportCheckpointIntermediates, "-report-checkpoint-intermediates", nullptr, "Reports information about checkpoint contexts used for reverse-mode automatic differentiation." },
{ OptionKind::SkipSPIRVValidation, "-skip-spirv-validation", nullptr, "Skips spirv validation." },
{ OptionKind::SourceEmbedStyle, "-source-embed-style", "-source-embed-style <source-embed-style>",
"If source embedding is enabled, defines the style used. When enabled (with any style other than `none`), "
@@ -1703,6 +1704,7 @@ SlangResult OptionsParser::_parse(
case OptionKind::DumpReproOnError:
case OptionKind::ReportDownstreamTime:
case OptionKind::ReportPerfBenchmark:
+ case OptionKind::ReportCheckpointIntermediates:
case OptionKind::SkipSPIRVValidation:
case OptionKind::DisableSpecialization:
case OptionKind::DisableDynamicDispatch:
diff --git a/tests/autodiff/reverse-checkpoint-1.slang b/tests/autodiff/reverse-checkpoint-1.slang
index 517297013..3d6e9e702 100644
--- a/tests/autodiff/reverse-checkpoint-1.slang
+++ b/tests/autodiff/reverse-checkpoint-1.slang
@@ -2,6 +2,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -16,13 +17,16 @@ float g(float x)
return log(x);
}
+//CHK: note: checkpointing context of 4 bytes associated with function: 'f'
[BackwardDifferentiable]
float f(int p, float x)
{
float y = 1.0;
// Test that phi parameter can be restored.
if (p == 0)
+ //CHK: note: 4 bytes (float) used to checkpoint the following item:
y = g(x);
+
return y * y;
}
@@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
__bwd_diff(f)(0, dpa, 1.0f);
outputBuffer[0] = dpa.d; // Expect: 1
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-checkpoint-2.slang b/tests/autodiff/reverse-checkpoint-2.slang
index 8a7262aa4..1dd3f2963 100644
--- a/tests/autodiff/reverse-checkpoint-2.slang
+++ b/tests/autodiff/reverse-checkpoint-2.slang
@@ -41,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
__bwd_diff(f)(0, dpa, 1.0f);
outputBuffer[0] = dpa.d; // Expect: 1
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
index 0f9502673..0b6e56f78 100644
--- a/tests/autodiff/reverse-continue-loop.slang
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -8,11 +9,14 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
+//CHK: note: checkpointing context of 24 bytes associated with function: 'test_loop_with_continue'
[BackwardDifferentiable]
float test_loop_with_continue(float y)
{
+ //CHK: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
+ //CHK: note: 4 bytes (int32_t) used for a loop counter here:
for (int i = 0; i < 3; i++)
{
if (t > 4.0)
@@ -41,3 +45,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputBuffer[1] = dpa.d; // Expect: 0.0131072
}
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-control-flow-1.slang b/tests/autodiff/reverse-control-flow-1.slang
index 7d2f518be..334de4137 100644
--- a/tests/autodiff/reverse-control-flow-1.slang
+++ b/tests/autodiff/reverse-control-flow-1.slang
@@ -1,5 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -40,3 +41,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputBuffer[1] = dpa.d; // Expect: 1.0
}
}
+
+//CHK: (0): note: no checkpoint contexts to report
\ No newline at end of file
diff --git a/tests/autodiff/reverse-control-flow-2.slang b/tests/autodiff/reverse-control-flow-2.slang
index cde707b4d..c3790367c 100644
--- a/tests/autodiff/reverse-control-flow-2.slang
+++ b/tests/autodiff/reverse-control-flow-2.slang
@@ -1,5 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -73,3 +74,5 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[1] = dpx.d;
}
}
+
+//CHK: (0): note: no checkpoint contexts to report
\ No newline at end of file
diff --git a/tests/autodiff/reverse-control-flow-3.slang b/tests/autodiff/reverse-control-flow-3.slang
index 01b533279..b4fa68e3a 100644
--- a/tests/autodiff/reverse-control-flow-3.slang
+++ b/tests/autodiff/reverse-control-flow-3.slang
@@ -1,4 +1,5 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
@@ -75,7 +76,8 @@ void d_getParam(uint id, MaterialParam.Differential diff)
outputBuffer[id] += diff.roughness;
}
-
+//CHK-DAG: note: checkpointing context of 8 bytes associated with function: 'updatePathThroughput'
+//CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item:
[BackwardDifferentiable]
void updatePathThroughput(inout PathResult path, const float weight)
{
@@ -122,9 +124,13 @@ bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, ino
\param[in,out] path The path state.
\return True if a ray was generated, false otherwise.
*/
+
+//CHK-DAG: note: checkpointing context of 16 bytes associated with function: 'generateScatterRay'
[BackwardDifferentiable]
bool generateScatterRay(const BSDFSample bs, const MaterialParam bsdfParams, inout PathState path, inout PathResult pathRes)
{
+ //CHK-DAG: note: 8 bytes (s_bwd_prop_updatePathThroughput_Intermediates_0) used to checkpoint the following item:
+ //CHK-DAG: note: 8 bytes (PathResult_0) used to checkpoint the following item:
updatePathThroughput(pathRes, bs.val);
return true;
}
@@ -215,5 +221,6 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
var dpx = diffPair(pathRes, pathResD);
__bwd_diff(tracePath)(1, dpx); // Expect: 5.0 in outputBuffer[3]
}
-
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-loop-checkpoint-test.slang b/tests/autodiff/reverse-loop-checkpoint-test.slang
index fc206e128..68ad823ac 100644
--- a/tests/autodiff/reverse-loop-checkpoint-test.slang
+++ b/tests/autodiff/reverse-loop-checkpoint-test.slang
@@ -1,5 +1,6 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -44,13 +45,18 @@ float3 infinitesimal(float3 x)
return x - detach(x);
}
+//CHK: note: checkpointing context of 20 bytes associated with function: 'computeLoop'
[BackwardDifferentiable]
[PreferRecompute]
float3 computeLoop(float y)
{
+ //CHK: note: 4 bytes (float) used to checkpoint the following item:
float w = 0;
+
+ //CHK: note: 12 bytes (Vector<float, 3> ) used to checkpoint the following item:
float3 w3 = float3(0, 0, 0);
+ //CHK: note: 4 bytes (int32_t) used for a loop counter here:
for (int i = 0; i < 8; i++)
{
float k = compute(i, y);
@@ -93,3 +99,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputBuffer[2] = computeLoop(1.0).x;
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-loop.slang b/tests/autodiff/reverse-loop.slang
index a2c826be9..2ba8535be 100644
--- a/tests/autodiff/reverse-loop.slang
+++ b/tests/autodiff/reverse-loop.slang
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -8,11 +9,14 @@ RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
typedef float.Differential dfloat;
+//CHK: note: checkpointing context of 24 bytes associated with function: 'test_simple_loop'
[Differentiable]
float test_simple_loop(float y)
{
+ //CHK: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
+ //CHK: note: 4 bytes (int32_t) used for a loop counter here:
for (int i = 0; i < 3; i++)
{
t = t * t;
@@ -38,3 +42,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
outputBuffer[1] = dpa.d; // Expect: 0.0131072
}
}
+
+//CHK-NOT: note
\ No newline at end of file
diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang
index caf2df6f8..3c1a52c21 100644
--- a/tests/autodiff/reverse-nested-calls.slang
+++ b/tests/autodiff/reverse-nested-calls.slang
@@ -1,6 +1,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+//TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
@@ -15,9 +16,11 @@ float g(float y)
return result * result;
}
+//CHK: note: checkpointing context of 4 bytes associated with function: 'f'
[BackwardDifferentiable]
float f(float x)
{
+ //CHK: note: 4 bytes (float) used to checkpoint the following item:
return 3.0f * g(2.0f * x);
}
@@ -29,3 +32,5 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
__bwd_diff(f)(dpa, 1.0f);
outputBuffer[0] = dpa.d; // Expect: 96.0
}
+
+//CHK-NOT: note
\ No newline at end of file