diff options
| author | venkataram-nv <vedavamadath@nvidia.com> | 2024-09-18 20:42:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-18 20:42:07 -0700 |
| commit | b808aa4df50d46eaa569561f7e464c55c1c2d72a (patch) | |
| tree | 5483a3f9e73a401ff82d66fd1ac3729a9a84a97c /source | |
| parent | 3240799c00488858afc7eeac9d1dc479609a1040 (diff) | |
Report AD checkpoint contexts (#5058)
* Transferring source locations when creating phi instructions
* Tracking for simple variables
* Deriving source locations for loop counters
* Printing checkpoint structure breakdown
* More readable output format
* Special behavior for loop counters
* Writing report to file
* Add slangc option to enable checkpoint reports
* Display types of checkpointed fields
* Message in case there are no checkpointing contexts
* Catch source locations for function calls
* Source cleanup
* Fix compilation warnings
* Remove stray dump()
* Provide the report through diagnostic notes
* Add missing path for sourceLoc during unzip pass
* Add tests for reporting intermediates
* Include more transfer cases for source locations
* Fix ordering in address elimination
* Fill in more holes with source location transfer
* Remove debugging line
* Reverting changes to diagnostic sink
* Simplify address elimination using source location RAII contexts
* Eliminating manual source loc transfers in forward transcription
* Fix local var adaptation to use RAII location setter
* Simplify primal hoisting logic for source location transfer
* Simplify unzipping with RAII location scopes
* Simplify transpose logic
* Cleaning up for rev.cpp
* Reverting spacing changes
* Fix mistake with source loc RAII instantiation
* Fix formatting issues
Diffstat (limited to 'source')
23 files changed, 215 insertions, 31 deletions
diff --git a/source/slang-record-replay/util/emum-to-string.h b/source/slang-record-replay/util/emum-to-string.h index 7226edc04..8c140cf3d 100644 --- a/source/slang-record-replay/util/emum-to-string.h +++ b/source/slang-record-replay/util/emum-to-string.h @@ -149,6 +149,7 @@ namespace SlangRecord CASE(EmitIr); CASE(ReportDownstreamTime); CASE(ReportPerfBenchmark); + CASE(ReportCheckpointIntermediates); CASE(SkipSPIRVValidation); CASE(SourceEmbedStyle); CASE(SourceEmbedName); diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 541085b4e..c89d94c80 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2451,12 +2451,16 @@ namespace Slang return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIr); } + bool CodeGenContext::shouldReportCheckpointIntermediates() + { + return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::ReportCheckpointIntermediates); + } + bool CodeGenContext::shouldDumpIntermediates() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::DumpIntermediates); } - bool CodeGenContext::shouldTrackLiveness() { return getTargetProgram()->getOptionSet().getBoolOption(CompilerOptionName::TrackLiveness); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 0c788ae18..4b20d1f76 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2728,6 +2728,7 @@ namespace Slang bool shouldValidateIR(); bool shouldDumpIR(); + bool shouldReportCheckpointIntermediates(); bool shouldTrackLiveness(); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 81170fac3..e0f1e90c5 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -894,6 +894,12 @@ DIAGNOSTIC(58002, Error, unhandledGLSLSSBOType, "Unhandled GLSL Shader Storage B DIAGNOSTIC(58003, Error, inconsistentPointerAddressSpace, "'$0': use of pointer with inconsistent address space.") +// Autodiff checkpoint reporting +DIAGNOSTIC(-1, Note, reportCheckpointIntermediates, "checkpointing context of $1 bytes associated with function: '$0'") +DIAGNOSTIC(-1, Note, reportCheckpointVariable, "$0 bytes ($1) used to checkpoint the following item:") +DIAGNOSTIC(-1, Note, reportCheckpointCounter, "$0 bytes ($1) used for a loop counter here:") +DIAGNOSTIC(-1, Note, reportCheckpointNone, "no checkpoint contexts to report") + // // 8xxxx - Issues specific to a particular library/technology/platform/etc. // diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index cdd2ca5b6..6e3556064 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -34,6 +34,7 @@ #include "slang-ir-wgsl-legalize.h" #include "slang-ir-insts.h" #include "slang-ir-inline.h" +#include "slang-ir-layout.h" #include "slang-ir-legalize-array-return-type.h" #include "slang-ir-legalize-mesh-outputs.h" #include "slang-ir-legalize-varying-params.h" @@ -214,6 +215,68 @@ static void dumpIRIfEnabled( } } +static void reportCheckpointIntermediates(CodeGenContext* codeGenContext, DiagnosticSink* sink, IRModule* irModule) +{ + // Report checkpointing information + CompilerOptionSet& optionSet = codeGenContext->getTargetProgram()->getOptionSet(); + SourceManager* sourceManager = sink->getSourceManager(); + + SourceWriter typeWriter(sourceManager, LineDirectiveMode::None, nullptr); + + CLikeSourceEmitter::Desc description; + description.codeGenContext = codeGenContext; + description.sourceWriter = &typeWriter; + + CPPSourceEmitter emitter(description); + + int nonEmptyStructs = 0; + for (auto inst : irModule->getGlobalInsts()) + { + IRStructType *structType = as<IRStructType>(inst); + if (!structType) + continue; + + auto checkpointDecoration = structType->findDecoration<IRCheckpointIntermediateDecoration>(); + if (!checkpointDecoration) + continue; + + IRSizeAndAlignment structSize; + getNaturalSizeAndAlignment(optionSet, structType, &structSize); + + // Reporting happens before empty structs are optimized out + // and we still want to keep the checkpointing decorations, + // so we end up needing to check for non-zero-ness + if (structSize.size == 0) + continue; + + auto func = checkpointDecoration->getSourceFunction(); + sink->diagnose(structType, Diagnostics::reportCheckpointIntermediates, func, structSize.size); + nonEmptyStructs++; + + for (auto field : structType->getFields()) + { + IRType *fieldType = field->getFieldType(); + IRSizeAndAlignment fieldSize; + getNaturalSizeAndAlignment(optionSet, fieldType, &fieldSize); + if (fieldSize.size == 0) + continue; + + typeWriter.clearContent(); + emitter.emitType(fieldType); + + sink->diagnose(field->sourceLoc, + field->findDecoration<IRLoopCounterDecoration>() + ? Diagnostics::reportCheckpointCounter + : Diagnostics::reportCheckpointVariable, + fieldSize.size, + typeWriter.getContent()); + } + } + + if (nonEmptyStructs == 0) + sink->diagnose(SourceLoc(), Diagnostics::reportCheckpointNone); +} + struct LinkingAndOptimizationOptions { bool shouldLegalizeExistentialAndResourceTypes = true; @@ -767,6 +830,10 @@ Result linkAndOptimizeIR( break; } + // Report checkpointing information + if (codeGenContext->shouldReportCheckpointIntermediates()) + reportCheckpointIntermediates(codeGenContext, sink, irModule); + if (requiredLoweringPassSet.autodiff) finalizeAutoDiffPass(targetProgram, irModule); diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 8a48936d7..b55f6b93d 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -69,30 +69,28 @@ struct AddressInstEliminationContext } } - void transformLoadAddr(IRUse* use) + void transformLoadAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto load = as<IRLoad>(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); auto value = getValue(builder, addr); load->replaceUsesWith(value); load->removeAndDeallocate(); } - void transformStoreAddr(IRUse* use) + void transformStoreAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto store = as<IRStore>(use->getUser()); - IRBuilder builder(module); builder.setInsertBefore(use->getUser()); storeValue(builder, addr, store->getVal()); store->removeAndDeallocate(); } - void transformCallAddr(IRUse* use) + void transformCallAddr(IRBuilder& builder, IRUse* use) { auto addr = use->get(); auto call = as<IRCall>(use->getUser()); @@ -103,7 +101,6 @@ struct AddressInstEliminationContext return; } - IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType()); @@ -155,17 +152,20 @@ struct AddressInstEliminationContext use = nextUse; continue; } + + IRBuilder transformBuilder(module); + IRBuilderSourceLocRAII sourceLocationScope(&transformBuilder, use->getUser()->sourceLoc); switch (use->getUser()->getOp()) { case kIROp_Load: - transformLoadAddr(use); + transformLoadAddr(transformBuilder, use); break; case kIROp_Store: - transformStoreAddr(use); + transformStoreAddr(transformBuilder, use); break; case kIROp_Call: - transformCallAddr(use); + transformCallAddr(transformBuilder, use); break; case kIROp_GetElementPtr: case kIROp_FieldAddress: diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 9fe4ec70b..f51178f0f 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -3,8 +3,9 @@ #include "slang-ir-autodiff-region.h" #include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" -#include "../core/slang-func-ptr.h" +#include "slang-ir-insts.h" #include "slang-ir.h" +#include "../core/slang-func-ptr.h" namespace Slang { @@ -1092,7 +1093,8 @@ IRType* getTypeForLocalStorage( IRVar* emitIndexedLocalVar( IRBlock* varBlock, IRType* baseType, - const List<IndexTrackingInfo>& defBlockIndices) + const List<IndexTrackingInfo>& defBlockIndices, + SourceLoc location) { // Cannot store pointers. Case should have been handled by now. SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType)); @@ -1101,6 +1103,8 @@ IRVar* emitIndexedLocalVar( SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType)); IRBuilder varBuilder(varBlock->getModule()); + IRBuilderSourceLocRAII sourceLocationScope(&varBuilder, location); + varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst()); IRType* varType = getTypeForLocalStorage(&varBuilder, baseType, defBlockIndices); @@ -1179,9 +1183,14 @@ IRVar* storeIndexedValue( IRInst* instToStore, const List<IndexTrackingInfo>& defBlockIndices) { - IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices); + IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, + instToStore->getDataType(), + defBlockIndices, + instToStore->sourceLoc); - IRInst* addr = emitIndexedStoreAddressForVar(builder, localVar, defBlockIndices); + IRInst* addr = emitIndexedStoreAddressForVar(builder, + localVar, + defBlockIndices); builder->emitStore(addr, instToStore); @@ -1574,12 +1583,16 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( // region, that means there's no need to allocate a fully indexed var. // defBlockIndices = maybeTrimIndices(defBlockIndices, indexedBlockInfo, outOfScopeUses); - - IRVar* localVar = storeIndexedValue( - &builder, - varBlock, - builder.emitLoad(varToStore), - defBlockIndices); + + IRVar* localVar = nullptr; + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varToStore->sourceLoc); + localVar = storeIndexedValue( + &builder, + varBlock, + builder.emitLoad(varToStore), + defBlockIndices); + } for (auto use : outOfScopeUses) { @@ -1626,6 +1639,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( } else { + IRBuilderSourceLocRAII sourceLocationScope(&builder, instToStore->sourceLoc); + // Handle the special case of loop counters. // The only case where there will be a reference of primal loop counter from rev blocks // is the start of a loop in the reverse code. Since loop counters are not considered a @@ -1643,6 +1658,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( setInsertAfterOrdinaryInst(&builder, instToStore); auto localVar = storeIndexedValue(&builder, varBlock, instToStore, defBlockIndices); + if (isLoopCounter) + builder.addLoopCounterDecoration(localVar); for (auto use : outOfScopeUses) { @@ -1728,6 +1745,8 @@ static IRBlock* getUpdateBlock(IRLoop* loop) void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam) { IRBuilder builder(primalLoop); + IRBuilderSourceLocRAII sourceLocationScope(&builder, primalLoop->sourceLoc); + primalCountParam = nullptr; // Grab first primal block. @@ -1899,8 +1918,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func) // Legalize the primal inst accesses by introducing local variables / arrays and emitting // necessary load/store logic. // - primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); - return primalsInfo; + return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo); } void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 35a197f29..2fb73c4ac 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -403,8 +403,11 @@ namespace Slang List<IRType*> primalTypes, propagateTypes; IRType* primalResultType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType()); + IRParam *currentParam = origFunc->getFirstParam(); for (UInt i = 0; i < origFuncType->getParamCount(); i++) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc); + auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i)); if (propagateParamType) @@ -453,6 +456,7 @@ namespace Slang primalArgs.add(var); } primalTypes.add(primalParamType); + currentParam = currentParam->getNextParam(); } // Add dOut argument to propagateArgs. @@ -588,6 +592,8 @@ namespace Slang autoDiffSharedContext->transcriberSet.forwardTranscriber); auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent))); + fwdDiffFunc->sourceLoc = primalFunc->sourceLoc; + SLANG_ASSERT(fwdDiffFunc); auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); for (auto i = oldCount; i < newCount; i++) @@ -712,8 +718,10 @@ namespace Slang } // Transpose the first block (parameter block) - auto paramTransposeInfo = - splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable); + auto paramTransposeInfo = splitAndTransposeParameterBlock(builder, + diffPropagateFunc, + primalFunc->sourceLoc, + isResultDifferentiable); // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc // may be used by write back logic that we are going to insert later. @@ -815,6 +823,7 @@ namespace Slang ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable) { // This method splits transposes the all the parameters for both the primal and propagate computation. @@ -841,6 +850,7 @@ namespace Slang auto nextBlockBuilder = *builder; nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst()); + SourceLoc returnLoc; IRBlock* firstDiffBlock = nullptr; for (auto block : diffFunc->getBlocks()) { @@ -849,6 +859,13 @@ namespace Slang firstDiffBlock = block; break; } + + auto terminator = block->getTerminator(); + if (as<IRReturn>(terminator)) + { + returnLoc = terminator->sourceLoc; + break; + } } SLANG_RELEASE_ASSERT(firstDiffBlock); @@ -895,6 +912,8 @@ namespace Slang // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc); + // Define the replacement insts that we are going to fill in for each case. IRInst* diffRefReplacement = nullptr; IRInst* primalRefReplacement = nullptr; @@ -1186,6 +1205,7 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + dOutParam->sourceLoc = returnLoc; builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut")); result.propagateFuncParams.add(dOutParam); } @@ -1196,6 +1216,10 @@ namespace Slang result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; + + diffFunc->sourceLoc = primalLoc; + ctxParam->sourceLoc = primalLoc; + return result; } diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 68cb4e0c9..b65701a7a 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -105,6 +105,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase ParameterBlockTransposeInfo splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, + SourceLoc primalLoc, bool isResultDifferentiable); void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index da69ed8ae..1fa76c730 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -1033,8 +1033,9 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori if (as<IRModuleInst>(origInst->getParent()) && !as<IRType>(origInst)) return InstPair(origInst, nullptr); - auto result = transcribeInstImpl(builder, origInst); + IRBuilderSourceLocRAII sourceLocationScope(builder, origInst->sourceLoc); + auto result = transcribeInstImpl(builder, origInst); if (result.primal == nullptr && result.differential == nullptr) { if (auto origType = as<IRType>(origInst)) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index d42462e1b..1f8c3052e 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -609,6 +609,8 @@ struct DiffTransposePass auto nextInst = inst->getNextInst(); if (auto varInst = as<IRVar>(inst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, varInst->sourceLoc); + if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { if (auto ptrPrimalType = as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst))) @@ -692,7 +694,11 @@ struct DiffTransposePass SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr); builder.setInsertInto(lastRevBlock); - builder.emitReturn(); + + { + IRBuilderSourceLocRAII sourceLocationScope(&builder, revDiffFunc->sourceLoc); + builder.emitReturn(); + } // Remove fwd-mode blocks. for (auto block : workList) @@ -703,6 +709,8 @@ struct DiffTransposePass IRInst* extractAccumulatorVarGradient(IRBuilder* builder, IRInst* fwdInst) { + IRBuilderSourceLocRAII sourceLocationScope(builder, fwdInst->sourceLoc); + if (auto accVar = getOrCreateAccumulatorVar(fwdInst)) { auto gradValue = builder->emitLoad(accVar); @@ -731,6 +739,7 @@ struct DiffTransposePass return revAccumulatorVarMap[fwdInst]; IRBuilder tempVarBuilder(autodiffContext->moduleInst->getModule()); + IRBuilderSourceLocRAII sourceLocationSCope(&tempVarBuilder, fwdInst->sourceLoc); IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())]; @@ -785,6 +794,8 @@ struct DiffTransposePass for (UIndex ii = 0; ii < branchInst->getArgCount(); ii++) { auto arg = branchInst->getArg(ii); + + IRBuilderSourceLocRAII sourceLocationScope(&builder, arg->sourceLoc); if (isDifferentialInst(arg)) { // If the arg is a differential, emit a parameter @@ -885,6 +896,8 @@ struct DiffTransposePass List<IRInst*> phiParamRevGradInsts; for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, param->sourceLoc); + if (isDifferentialInst(param)) { // This param might be used outside this block. @@ -949,6 +962,8 @@ struct DiffTransposePass if (auto accVar = getOrCreateAccumulatorVar(externInst)) { + IRBuilderSourceLocRAII sourceLocationScope(&builder, externInst->sourceLoc); + // Accumulate all gradients, including our accumulator variable, // into one inst. // @@ -1050,6 +1065,7 @@ struct DiffTransposePass // Emit the aggregate of all the gradients here. // This will form the total derivative for this inst. + IRBuilderSourceLocRAII sourceLocationScope(builder, inst->sourceLoc); auto revValue = emitAggregateValue(builder, primalType, gradients); auto transposeResult = transposeInst(builder, inst, revValue); @@ -2738,7 +2754,6 @@ struct DiffTransposePass gradient.revGradInst, gradient.fwdGradInst )); - } for (auto pair : bucketedGradients) diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 9b3e3a324..0953c535a 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -75,6 +75,9 @@ struct ExtractPrimalFuncContext builder.setInsertBefore(destFunc); IRFuncType* originalFuncType = nullptr; outIntermediateType = createIntermediateType(destFunc); + + builder.addCheckpointIntermediateDecoration(outIntermediateType, originalFunc); + outIntermediateType->sourceLoc = originalFunc->sourceLoc; GenericChildrenMigrationContext migrationContext; migrationContext.init(as<IRGeneric>(findOuterGeneric(originalFunc)), as<IRGeneric>(findOuterGeneric(destFunc)), destFunc); @@ -154,6 +157,7 @@ struct ExtractPrimalFuncContext IRInst* intermediateOutput) { auto field = addIntermediateContextField(inst->getDataType(), intermediateOutput); + field->sourceLoc = inst->sourceLoc; auto key = field->getKey(); if (auto nameHint = inst->findDecoration<IRNameHintDecoration>()) cloneDecoration(nameHint, key); @@ -219,6 +223,10 @@ struct ExtractPrimalFuncContext if (inst->hasUses()) { auto field = addIntermediateContextField(cast<IRPtrTypeBase>(inst->getDataType())->getValueType(), outIntermediary); + field->sourceLoc = inst->sourceLoc; + if (inst->findDecoration<IRLoopCounterDecoration>()) + builder.addLoopCounterDecoration(field); + builder.setInsertBefore(inst); auto fieldAddr = builder.emitFieldAddress( inst->getFullType(), outIntermediary, field->getKey()); @@ -379,12 +387,16 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( use->set(builder.getVoidValue()); continue; } + + IRBuilderSourceLocRAII sourceLocationScope(&builder, use->getUser()->sourceLoc); + builder.setInsertBefore(use->getUser()); auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType(); auto val = builder.emitFieldExtract( valType, intermediateVar, structKeyDecor->getStructKey()); + if (use->getUser()->getOp() == kIROp_Load) { use->getUser()->replaceUsesWith(val); @@ -392,8 +404,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - auto tempVar = - builder.emitVar(valType); + auto tempVar = builder.emitVar(valType); builder.emitStore(tempVar, val); use->set(tempVar); } @@ -401,7 +412,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( } else { - // Orindary value. + // Ordinary value. // We insert a fieldExtract at each use site instead of before `inst`, // since at this stage of autodiff pass, `inst` does not necessarily // dominate all the use sites if `inst` is defined in partial branch @@ -417,6 +428,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( inst->getFullType(), intermediateVar, structKeyDecor->getStructKey()); + val->sourceLoc = user->sourceLoc; builder.replaceOperand(iuse, val); } } diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 9f18db6e0..6ae5126f9 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -588,7 +588,6 @@ struct DiffUnzipPass as<IRBlock>(diffMap[targetBlock]), diffArgs.getCount(), diffArgs.getBuffer())); - } case kIROp_conditionalBranch: @@ -710,6 +709,9 @@ struct DiffUnzipPass void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) { + IRBuilderSourceLocRAII primalLocationScope(primalBuilder, inst->sourceLoc); + IRBuilderSourceLocRAII diffLocationScope(diffBuilder, inst->sourceLoc); + auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst); primalMap[inst] = instPair.primal; diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 0979c097c..07a6a76fb 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -1203,6 +1203,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: + case kIROp_CheckpointIntermediateDecoration: decor->removeAndDeallocate(); break; case kIROp_AutoDiffBuiltinDecoration: diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index e2297bcb2..a8b9b548e 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -220,6 +220,7 @@ static void _cloneInstDecorationsAndChildren( auto oldType = oldParam->getFullType(); auto newType = (IRType*)findCloneForOperand(env, oldType); newParam->setFullType(newType); + newParam->sourceLoc = oldParam->sourceLoc; } } diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index b17fad6ec..0db2fc765 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -462,6 +462,7 @@ struct PhiEliminationContext // to the temporary that will replace it. // param->transferDecorationsTo(temp); + temp->sourceLoc = param->sourceLoc; } // The other main auxilliary sxtructure is used to track @@ -550,6 +551,7 @@ struct PhiEliminationContext auto user = use->getUser(); m_builder.setInsertBefore(user); auto newVal = m_builder.emitLoad(temp); + newVal->sourceLoc = param->sourceLoc; m_builder.replaceOperand(use, newVal); } @@ -938,6 +940,7 @@ struct PhiEliminationContext newOperands.getCount(), newOperands.getArrayView().getBuffer()); oldBranch->transferDecorationsTo(newBranch); + newBranch->sourceLoc = oldBranch->sourceLoc; // TODO: We could consider just modifying `branch` in-place by clearing // the relevant operands for the phi arguments and setting its operand diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp index 34a0e5ff4..fa556bc58 100644 --- a/source/slang/slang-ir-init-local-var.cpp +++ b/source/slang/slang-ir-init-local-var.cpp @@ -47,6 +47,9 @@ void initializeLocalVariables(IRModule* module, IRGlobalValueWithCode* func) breakLabel:; if (initialized) continue; + + IRBuilderSourceLocRAII sourceLocationScope(&builder, inst->sourceLoc); + builder.setInsertAfter(inst); builder.emitStore( inst, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b526df3a9..301a9c789 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1056,6 +1056,9 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /// Hint that the result from a call to the decorated function should be recomputed in backward prop function. INST(PreferRecomputeDecoration, PreferRecomputeDecoration, 0, 0) + /// Hint that a struct is used for reverse mode checkpointing + INST(CheckpointIntermediateDecoration, CheckpointIntermediateDecoration, 1, 0) + INST_RANGE(CheckpointHintDecoration, PreferCheckpointDecoration, PreferRecomputeDecoration) /// Marks a function whose return value is never dynamic uniform. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 69f129986..37f242e55 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -947,6 +947,16 @@ struct IRPreferCheckpointDecoration : IRCheckpointHintDecoration IR_LEAF_ISA(PreferCheckpointDecoration) }; +struct IRCheckpointIntermediateDecoration : IRCheckpointHintDecoration +{ + enum + { + kOp = kIROp_CheckpointIntermediateDecoration + }; + IR_LEAF_ISA(CheckpointIntermediateDecoration) + + IRInst* getSourceFunction() { return getOperand(0); } +}; struct IRLoopCounterDecoration : IRDecoration { @@ -5152,6 +5162,11 @@ public: { addDecoration(inst, kIROp_MemoryQualifierSetDecoration, getIntValue(getIntType(), flags)); } + + void addCheckpointIntermediateDecoration(IRInst* inst, IRGlobalValueWithCode *func) + { + addDecoration(inst, kIROp_CheckpointIntermediateDecoration, func); + } }; // Helper to establish the source location that will be used diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 753c930a8..ef0551161 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -526,6 +526,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) // we will now introduce a breakable region for each iteration. IRBuilder builder(module); + IRBuilderSourceLocRAII sourceLocationScope(&builder, loopInst->sourceLoc); auto targetBlock = loopInst->getTargetBlock(); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index e44c4079b..506e6a335 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -431,6 +431,7 @@ PhiInfo* addPhi( RefPtr<PhiInfo> phiInfo = new PhiInfo(); context->phiInfos.add(phi, phiInfo); + phi->sourceLoc = var->sourceLoc; phiInfo->phi = phi; phiInfo->var = var; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9305d1783..6c7691d13 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3512,6 +3512,7 @@ namespace Slang auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>( this, kIROp_MakeDifferentialPair, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; return inst; } @@ -3524,6 +3525,7 @@ namespace Slang auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>( this, kIROp_MakeDifferentialPairUserCode, type, 2, args); addInst(inst); + inst->sourceLoc = primal->sourceLoc; return inst; } diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index c02a00957..b9a12f971 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -339,6 +339,7 @@ void initCommandOptions(CommandOptions& options) { OptionKind::InputFilesRemain, "--", nullptr, "Treat the rest of the command line as input files."}, { OptionKind::ReportDownstreamTime, "-report-downstream-time", nullptr, "Reports the time spent in the downstream compiler." }, { OptionKind::ReportPerfBenchmark, "-report-perf-benchmark", nullptr, "Reports compiler performance benchmark results." }, + { OptionKind::ReportCheckpointIntermediates, "-report-checkpoint-intermediates", nullptr, "Reports information about checkpoint contexts used for reverse-mode automatic differentiation." }, { OptionKind::SkipSPIRVValidation, "-skip-spirv-validation", nullptr, "Skips spirv validation." }, { OptionKind::SourceEmbedStyle, "-source-embed-style", "-source-embed-style <source-embed-style>", "If source embedding is enabled, defines the style used. When enabled (with any style other than `none`), " @@ -1703,6 +1704,7 @@ SlangResult OptionsParser::_parse( case OptionKind::DumpReproOnError: case OptionKind::ReportDownstreamTime: case OptionKind::ReportPerfBenchmark: + case OptionKind::ReportCheckpointIntermediates: case OptionKind::SkipSPIRVValidation: case OptionKind::DisableSpecialization: case OptionKind::DisableDynamicDispatch: |
