summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-autodiff-rev.cpp
diff options
context:
space:
mode:
authorvenkataram-nv <vedavamadath@nvidia.com>2024-09-18 20:42:07 -0700
committerGitHub <noreply@github.com>2024-09-18 20:42:07 -0700
commitb808aa4df50d46eaa569561f7e464c55c1c2d72a (patch)
tree5483a3f9e73a401ff82d66fd1ac3729a9a84a97c /source/slang/slang-ir-autodiff-rev.cpp
parent3240799c00488858afc7eeac9d1dc479609a1040 (diff)
Report AD checkpoint contexts (#5058)
* Transferring source locations when creating phi instructions * Tracking for simple variables * Deriving source locations for loop counters * Printing checkpoint structure breakdown * More readable output format * Special behavior for loop counters * Writing report to file * Add slangc option to enable checkpoint reports * Display types of checkpointed fields * Message in case there are no checkpointing contexts * Catch source locations for function calls * Source cleanup * Fix compilation warnings * Remove stray dump() * Provide the report through diagnostic notes * Add missing path for sourceLoc during unzip pass * Add tests for reporting intermediates * Include more transfer cases for source locations * Fix ordering in address elimination * Fill in more holes with source location transfer * Remove debugging line * Reverting changes to diagnostic sink * Simplify address elimination using source location RAII contexts * Eliminating manual source loc transfers in forward transcription * Fix local var adaptation to use RAII location setter * Simplify primal hoisting logic for source location transfer * Simplify unzipping with RAII location scopes * Simplify transpose logic * Cleaning up for rev.cpp * Reverting spacing changes * Fix mistake with source loc RAII instantiation * Fix formatting issues
Diffstat (limited to 'source/slang/slang-ir-autodiff-rev.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp28
1 files changed, 26 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 35a197f29..2fb73c4ac 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -403,8 +403,11 @@ namespace Slang
List<IRType*> primalTypes, propagateTypes;
IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType());
+ IRParam *currentParam = origFunc->getFirstParam();
for (UInt i = 0; i < origFuncType->getParamCount(); i++)
{
+ IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc);
+
auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
if (propagateParamType)
@@ -453,6 +456,7 @@ namespace Slang
primalArgs.add(var);
}
primalTypes.add(primalParamType);
+ currentParam = currentParam->getNextParam();
}
// Add dOut argument to propagateArgs.
@@ -588,6 +592,8 @@ namespace Slang
autoDiffSharedContext->transcriberSet.forwardTranscriber);
auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));
+ fwdDiffFunc->sourceLoc = primalFunc->sourceLoc;
+
SLANG_ASSERT(fwdDiffFunc);
auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
for (auto i = oldCount; i < newCount; i++)
@@ -712,8 +718,10 @@ namespace Slang
}
// Transpose the first block (parameter block)
- auto paramTransposeInfo =
- splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable);
+ auto paramTransposeInfo = splitAndTransposeParameterBlock(builder,
+ diffPropagateFunc,
+ primalFunc->sourceLoc,
+ isResultDifferentiable);
// The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
// may be used by write back logic that we are going to insert later.
@@ -815,6 +823,7 @@ namespace Slang
ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
IRBuilder* builder,
IRFunc* diffFunc,
+ SourceLoc primalLoc,
bool isResultDifferentiable)
{
// This method splits transposes the all the parameters for both the primal and propagate computation.
@@ -841,6 +850,7 @@ namespace Slang
auto nextBlockBuilder = *builder;
nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst());
+ SourceLoc returnLoc;
IRBlock* firstDiffBlock = nullptr;
for (auto block : diffFunc->getBlocks())
{
@@ -849,6 +859,13 @@ namespace Slang
firstDiffBlock = block;
break;
}
+
+ auto terminator = block->getTerminator();
+ if (as<IRReturn>(terminator))
+ {
+ returnLoc = terminator->sourceLoc;
+ break;
+ }
}
SLANG_RELEASE_ASSERT(firstDiffBlock);
@@ -895,6 +912,8 @@ namespace Slang
// from the primal compuation logic in the future propagate function be replaced to.
for (auto fwdParam : fwdParams)
{
+ IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc);
+
// Define the replacement insts that we are going to fill in for each case.
IRInst* diffRefReplacement = nullptr;
IRInst* primalRefReplacement = nullptr;
@@ -1186,6 +1205,7 @@ namespace Slang
SLANG_ASSERT(dOutParamType);
dOutParam = builder->emitParam(dOutParamType);
+ dOutParam->sourceLoc = returnLoc;
builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut"));
result.propagateFuncParams.add(dOutParam);
}
@@ -1196,6 +1216,10 @@ namespace Slang
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
+
+ diffFunc->sourceLoc = primalLoc;
+ ctxParam->sourceLoc = primalLoc;
+
return result;
}