diff options
| author | Yong He <yonghe@outlook.com> | 2023-04-10 14:36:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-10 14:36:39 -0700 |
| commit | d934bbcc5702ebd8964f65b1708c239c29320103 (patch) | |
| tree | 0c34aeddc873e65b76553fe28bfdd7b9cc830292 /source | |
| parent | d82992e30d5985001870e00afdf27091f59464f2 (diff) | |
Fix inlining. (#2786)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 67 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.h | 5 |
4 files changed, 56 insertions, 22 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index df94bf69f..6c3d6a934 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -11,6 +11,7 @@ #include "slang-ir-addr-inst-elimination.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-validate.h" +#include "slang-ir-inline.h" namespace Slang { @@ -1566,6 +1567,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); removeLinkageDecorations(func); + performForceInlining(func); + AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext; auto result = eliminateAddressInsts(&cvtPolicty, func, sink); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 66c85647f..0bdc4a935 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -13,6 +13,7 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-dominators.h" #include "slang-ir-loop-unroll.h" +#include "slang-ir-inline.h" namespace Slang { @@ -521,6 +522,8 @@ namespace Slang { removeLinkageDecorations(func); + performForceInlining(func); + DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); diffTypeContext.setFunc(func); diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 5223e35cf..ed4fc7b06 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -40,35 +40,54 @@ struct InliningPassBase return considerAllCallSitesRec(m_module->getModuleInst()); } + bool considerCallSiteInFunc(IRFunc* func) + { + bool result = false; + + // Repeat until we run out of callees to inline. + for (;;) + { + bool changed = false; + + // Collect all the call sites in the function. + List<IRCall*> callsites; + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (auto call = as<IRCall>(inst)) + { + callsites.add(call); + } + } + } + + // Consider each call site. + for (auto call : callsites) + { + changed |= considerCallSite(call); + } + result |= changed; + if (!changed) + break; + } + return result; + } + /// Consider all call sites at or under `inst` for inlining bool considerAllCallSitesRec(IRInst* inst) { bool changed = false; - if( auto call = as<IRCall>(inst) ) + + if( auto func = as<IRFunc>(inst) ) { - changed = considerCallSite(call); + changed = considerCallSiteInFunc(func); } - // Note: we iterate until no more changes can be applied. - // This is defensive against changes made by inlining one callsite - // and make sure we get to process all callsites. - // - for (;;) + // Recursively consider the children of inst. + for (auto child : inst->getModifiableChildren()) { - bool changedInThisIteration = false; - // Note: getModifiableChildren will skip any insts that are no - // longer the chhild of `inst`. If we process one callsite, the - // remaining insts of the block will be moved into a different - // block and therefore we won't process them during this iteration. - // However, those callsites will eventually be processed - // by the outer loop. - for (auto child : inst->getModifiableChildren()) - { - changedInThisIteration = considerAllCallSitesRec(child); - changed |= changedInThisIteration; - } - if (!changedInThisIteration) - break; + changed |= considerAllCallSitesRec(child); } return changed; } @@ -770,6 +789,12 @@ void performForceInlining(IRModule* module) pass.considerAllCallSites(); } +bool performForceInlining(IRGlobalValueWithCode* func) +{ + ForceInliningPass pass(func->getModule()); + return pass.considerAllCallSitesRec(func); +} + // Defined in slang-ir-specialize-resource.cpp bool isResourceType(IRType* type); bool isIllegalGLSLParameterType(IRType* type); diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index 2f3a7a965..61a411d32 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -7,7 +7,7 @@ namespace Slang { struct IRModule; struct IRCall; - + struct IRGlobalValueWithCode; class DiagnosticSink; /// Any call to a function that takes or returns a string parameter is inlined @@ -19,6 +19,9 @@ namespace Slang /// Inline any call sites to functions marked `[ForceInline]` void performForceInlining(IRModule* module); + /// Inline any call sites to functions marked `[ForceInline]` inside `func`. + bool performForceInlining(IRGlobalValueWithCode* func); + /// Inline calls to functions that returns a resource/sampler via either return value or output parameter. void performGLSLResourceReturnFunctionInlining(IRModule* module); |
