diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-08 12:26:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-08 12:26:07 -0700 |
| commit | e8a1dd11eab4c07366b29aca775eb927a465e133 (patch) | |
| tree | 6f1e0788860ace253c05c3fbb37be8cd1f07fecf /source/slang/slang-ir-inline.cpp | |
| parent | cb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (diff) | |
Don't inline callees with custom derivative before autodiff. (#3196)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-inline.cpp')
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 5479a98ff..06b63db52 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -831,6 +831,47 @@ bool performForceInlining(IRGlobalValueWithCode* func) return pass.considerAllCallSitesRec(func); } +struct PreAutoDiffForceInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + PreAutoDiffForceInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const& info) + { + if (info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>() || + info.callee->findDecoration<IRIntrinsicOpDecoration>()) + return true; + bool hasForceInline = false; + bool hasUserDefinedDerivative = false; + for (auto decor : info.callee->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_UnsafeForceInlineEarlyDecoration: + case kIROp_IntrinsicOpDecoration: + return true; + case kIROp_ForceInlineDecoration: + hasForceInline = true; + break; + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_ForwardDerivativeDecoration: + hasUserDefinedDerivative = true; + break; + } + } + return (hasForceInline && !hasUserDefinedDerivative); + } +}; + +bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func) +{ + PreAutoDiffForceInliningPass pass(func->getModule()); + return pass.considerAllCallSitesRec(func); +} + // Defined in slang-ir-specialize-resource.cpp bool isResourceType(IRType* type); bool isIllegalGLSLParameterType(IRType* type); |
