summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-inline.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-08 12:26:07 -0700
committerGitHub <noreply@github.com>2023-09-08 12:26:07 -0700
commite8a1dd11eab4c07366b29aca775eb927a465e133 (patch)
tree6f1e0788860ace253c05c3fbb37be8cd1f07fecf /source/slang/slang-ir-inline.cpp
parentcb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (diff)
Don't inline callees with custom derivative before autodiff. (#3196)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-inline.cpp')
-rw-r--r--source/slang/slang-ir-inline.cpp41
1 files changed, 41 insertions, 0 deletions
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 5479a98ff..06b63db52 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -831,6 +831,47 @@ bool performForceInlining(IRGlobalValueWithCode* func)
return pass.considerAllCallSitesRec(func);
}
+struct PreAutoDiffForceInliningPass : InliningPassBase
+{
+ typedef InliningPassBase Super;
+
+ PreAutoDiffForceInliningPass(IRModule* module)
+ : Super(module)
+ {}
+
+ bool shouldInline(CallSiteInfo const& info)
+ {
+ if (info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>() ||
+ info.callee->findDecoration<IRIntrinsicOpDecoration>())
+ return true;
+ bool hasForceInline = false;
+ bool hasUserDefinedDerivative = false;
+ for (auto decor : info.callee->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_UnsafeForceInlineEarlyDecoration:
+ case kIROp_IntrinsicOpDecoration:
+ return true;
+ case kIROp_ForceInlineDecoration:
+ hasForceInline = true;
+ break;
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ hasUserDefinedDerivative = true;
+ break;
+ }
+ }
+ return (hasForceInline && !hasUserDefinedDerivative);
+ }
+};
+
+bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func)
+{
+ PreAutoDiffForceInliningPass pass(func->getModule());
+ return pass.considerAllCallSitesRec(func);
+}
+
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
bool isIllegalGLSLParameterType(IRType* type);