From 29c318bfe5c66350a67467e3b6ef08120f00fb7e Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 20 Sep 2023 20:54:10 -0400 Subject: Move force inlining step to before `processAutodiffCalls` (and run in loop) (#3217) * Move auto-diff force inlining step to before `processAutodiffCalls` * Fix `replaceUsesWith` to handle existing inst defined after current use. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-emit.cpp | 2 + source/slang/slang-ir-autodiff-fwd.cpp | 2 - source/slang/slang-ir-autodiff-primal-hoist.cpp | 6 +++ source/slang/slang-ir-autodiff-rev.cpp | 2 - source/slang/slang-ir-inline.cpp | 6 +++ source/slang/slang-ir-inline.h | 3 ++ source/slang/slang-ir.cpp | 55 +++++++++++++++++++++++++ 7 files changed, 72 insertions(+), 4 deletions(-) (limited to 'source') diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 98c9c9803..3a0ee36be 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -413,6 +413,8 @@ Result linkAndOptimizeIR( // since we may be missing out cases prevented by the functions that we just specialzied. performMandatoryEarlyInlining(irModule); + performPreAutoDiffForceInlining(irModule); + // Unroll loops. if (codeGenContext->getSink()->getErrorCount() == 0) { diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index b82818b99..c906f93eb 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1690,8 +1690,6 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) { insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); removeLinkageDecorations(func); - - performPreAutoDiffForceInlining(func); initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index ebf7a9484..61baa7dd7 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -969,7 +969,10 @@ void applyCheckpointSet( bool isInverted = checkpointInfo->invertSet.contains(param); bool loopInductionInfo = checkpointInfo->loopInductionInfo.tryGetValue(param); if (!isRecomputed && !isInverted) + { + ii++; continue; + } if (!loopInductionInfo) { @@ -982,7 +985,10 @@ void applyCheckpointSet( applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, blockIndexInfo, param); if (loopInductionInfo) + { + ii++; continue; + } // Copy primal branch-arg for predecessor blocks. HashSet predecessorSet; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 335b6572e..7d5659425 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -533,8 +533,6 @@ namespace Slang { removeLinkageDecorations(func); - performPreAutoDiffForceInlining(func); - DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); diffTypeContext.setFunc(func); diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 06b63db52..1d308c507 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -872,6 +872,12 @@ bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func) return pass.considerAllCallSitesRec(func); } +bool performPreAutoDiffForceInlining(IRModule* module) +{ + PreAutoDiffForceInliningPass pass(module); + return pass.considerAllCallSitesRec(module->getModuleInst()); +} + // Defined in slang-ir-specialize-resource.cpp bool isResourceType(IRType* type); bool isIllegalGLSLParameterType(IRType* type); diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index fe050b7b9..539bb26c0 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -25,6 +25,9 @@ namespace Slang /// Perform force inlining of functions that does not have custom derivatives. bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func); + /// Perform force inlining of all functions in a module that does not have custom derivatives. + bool performPreAutoDiffForceInlining(IRModule* module); + /// Inline calls to functions that returns a resource/sampler via either return value or output parameter. void performGLSLResourceReturnFunctionInlining(IRModule* module); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index a54bc1f2e..b4a8e6f42 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7177,6 +7177,57 @@ namespace Slang void validateIRInstOperands(IRInst*); + + // Returns true if `instToCheck` is defined after `otherInst`. + static bool _isInstDefinedAfter(IRInst* instToCheck, IRInst* otherInst) + { + for (auto inst = otherInst->getNextInst(); inst; inst = inst->getNextInst()) + { + if (inst == instToCheck) + return true; + } + return false; + } + + static void _maybeHoistOperand(IRUse* use) + { + ShortList workList1, workList2; + workList1.add(use); + while (workList1.getCount()) + { + for (auto item : workList1) + { + auto user = item->getUser(); + auto operand = item->get(); + if (!operand) + continue; + + if (!getIROpInfo(operand->getOp()).isHoistable()) + continue; + + // We can't handle the case where operand and user are in different blocks. + if (operand->getParent() != user->getParent()) + continue; + + // We allow out-of-order uses in global scope. + if (operand->getParent() && operand->getParent()->getOp() == kIROp_Module) + continue; + + // If the operand is defined after user, move it to before user. + if (_isInstDefinedAfter(operand, user)) + { + operand->insertBefore(user); + for (UInt i = 0; i < operand->getOperandCount(); i++) + { + workList2.add(operand->getOperands() + i); + } + workList2.add(&operand->typeUse); + } + } + workList1 = _Move(workList2); + } + } + static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) { IRDeduplicationContext* dedupContext = nullptr; @@ -7259,6 +7310,10 @@ namespace Slang // Swap this use over to use the other value. uu->usedValue = other; + // If `other` is hoistable, then we need to make sure `other` is hoisted + // to a point before `user`, if it is not already so. + _maybeHoistOperand(uu); + if (userIsHoistable) { // Is the updated inst already exists in the global numbering map? -- cgit v1.2.3