summary refs log tree commit diff stats
path: root/source
diff options
context:
space:
mode:
author Yong He <yonghe@outlook.com> 2023-04-10 14:36:39 -0700
committer GitHub <noreply@github.com> 2023-04-10 14:36:39 -0700
commit d934bbcc5702ebd8964f65b1708c239c29320103 (patch)
tree 0c34aeddc873e65b76553fe28bfdd7b9cc830292 /source
parent d82992e30d5985001870e00afdf27091f59464f2 (diff)
Fix inlining. (#2786)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r-- source/slang/slang-ir-autodiff-fwd.cpp 3
-rw-r--r-- source/slang/slang-ir-autodiff-rev.cpp 3
-rw-r--r-- source/slang/slang-ir-inline.cpp 67
-rw-r--r-- source/slang/slang-ir-inline.h 5
4 files changed, 56 insertions, 22 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index df94bf69f..6c3d6a934 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -11,6 +11,7 @@
#include "slang-ir-addr-inst-elimination.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-validate.h"
+#include "slang-ir-inline.h"
namespace Slang
{
@@ -1566,6 +1567,8 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
removeLinkageDecorations(func);
+ performForceInlining(func);
+
AutoDiffAddressConversionPolicy cvtPolicty;
cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext;
auto result = eliminateAddressInsts(&cvtPolicty, func, sink);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 66c85647f..0bdc4a935 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -13,6 +13,7 @@
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-dominators.h"
#include "slang-ir-loop-unroll.h"
+#include "slang-ir-inline.h"
namespace Slang
{
@@ -521,6 +522,8 @@ namespace Slang
{
removeLinkageDecorations(func);
+ performForceInlining(func);
+
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 5223e35cf..ed4fc7b06 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -40,35 +40,54 @@ struct InliningPassBase
return considerAllCallSitesRec(m_module->getModuleInst());
}
+ bool considerCallSiteInFunc(IRFunc* func)
+ {
+ bool result = false;
+
+ // Repeat until we run out of callees to inline.
+ for (;;)
+ {
+ bool changed = false;
+
+ // Collect all the call sites in the function.
+ List<IRCall*> callsites;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getChildren())
+ {
+ if (auto call = as<IRCall>(inst))
+ {
+ callsites.add(call);
+ }
+ }
+ }
+
+ // Consider each call site.
+ for (auto call : callsites)
+ {
+ changed |= considerCallSite(call);
+ }
+ result |= changed;
+ if (!changed)
+ break;
+ }
+ return result;
+ }
+
/// Consider all call sites at or under `inst` for inlining
bool considerAllCallSitesRec(IRInst* inst)
{
bool changed = false;
- if( auto call = as<IRCall>(inst) )
+
+ if( auto func = as<IRFunc>(inst) )
{
- changed = considerCallSite(call);
+ changed = considerCallSiteInFunc(func);
}
- // Note: we iterate until no more changes can be applied.
- // This is defensive against changes made by inlining one callsite
- // and make sure we get to process all callsites.
- //
- for (;;)
+ // Recursively consider the children of inst.
+ for (auto child : inst->getModifiableChildren())
{
- bool changedInThisIteration = false;
- // Note: getModifiableChildren will skip any insts that are no
- // longer the chhild of `inst`. If we process one callsite, the
- // remaining insts of the block will be moved into a different
- // block and therefore we won't process them during this iteration.
- // However, those callsites will eventually be processed
- // by the outer loop.
- for (auto child : inst->getModifiableChildren())
- {
- changedInThisIteration = considerAllCallSitesRec(child);
- changed |= changedInThisIteration;
- }
- if (!changedInThisIteration)
- break;
+ changed |= considerAllCallSitesRec(child);
}
return changed;
}
@@ -770,6 +789,12 @@ void performForceInlining(IRModule* module)
pass.considerAllCallSites();
}
+bool performForceInlining(IRGlobalValueWithCode* func)
+{
+ ForceInliningPass pass(func->getModule());
+ return pass.considerAllCallSitesRec(func);
+}
+
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
bool isIllegalGLSLParameterType(IRType* type);
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 2f3a7a965..61a411d32 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -7,7 +7,7 @@ namespace Slang
{
struct IRModule;
struct IRCall;
-
+ struct IRGlobalValueWithCode;
class DiagnosticSink;
/// Any call to a function that takes or returns a string parameter is inlined
@@ -19,6 +19,9 @@ namespace Slang
/// Inline any call sites to functions marked `[ForceInline]`
void performForceInlining(IRModule* module);
+ /// Inline any call sites to functions marked `[ForceInline]` inside `func`.
+ bool performForceInlining(IRGlobalValueWithCode* func);
+
/// Inline calls to functions that returns a resource/sampler via either return value or output parameter.
void performGLSLResourceReturnFunctionInlining(IRModule* module);