diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-08 12:26:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-08 12:26:07 -0700 |
| commit | e8a1dd11eab4c07366b29aca775eb927a465e133 (patch) | |
| tree | 6f1e0788860ace253c05c3fbb37be8cd1f07fecf /source/slang | |
| parent | cb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (diff) | |
Don't inline callees with custom derivative before autodiff. (#3196)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.h | 3 |
6 files changed, 64 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 2662498ed..b82818b99 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1521,6 +1521,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); + copyOriginalDecorations(origFunc, diffFunc); + FuncBodyTranscriptionTask task; task.type = FuncBodyTranscriptionTaskType::Forward; task.originalFunc = origFunc; @@ -1689,7 +1691,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); removeLinkageDecorations(func); - performForceInlining(func); + performPreAutoDiffForceInlining(func); initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8d7582373..681c69cd3 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -351,6 +351,7 @@ namespace Slang builder.setInsertBefore(diffFunc->getFirstDecorationOrChild()); cloneInst(&cloneEnv, &builder, dictDecor); } + copyOriginalDecorations(origFunc, diffFunc); builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast); return InstPair(primalFunc, diffFunc); } @@ -532,7 +533,7 @@ namespace Slang { removeLinkageDecorations(func); - performForceInlining(func); + performPreAutoDiffForceInlining(func); DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext); diffTypeContext.setFunc(func); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 53bb30f54..cf2310fc8 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -508,6 +508,19 @@ 
bool AutoDiffTranscriberBase::isExistentialType(IRType *type) } } +void AutoDiffTranscriberBase::copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc) +{ + for (auto decor : origFunc->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_ForceInlineDecoration: + cloneDecoration(decor, diffFunc); + break; + } + } +} + InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) { IRInst* origBase = origInst->getOperand(0); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index e9acbcd99..7b4c293e9 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -140,6 +140,8 @@ struct AutoDiffTranscriberBase void _markInstAsDifferential(IRBuilder* builder, IRInst* diffInst, IRInst* primalInst = nullptr); + void copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc); + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0; // Create an empty func to represent the transcribed func of `origFunc`. 
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 5479a98ff..06b63db52 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -831,6 +831,47 @@ bool performForceInlining(IRGlobalValueWithCode* func) return pass.considerAllCallSitesRec(func); } +struct PreAutoDiffForceInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + PreAutoDiffForceInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const& info) + { + if (info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>() || + info.callee->findDecoration<IRIntrinsicOpDecoration>()) + return true; + bool hasForceInline = false; + bool hasUserDefinedDerivative = false; + for (auto decor : info.callee->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_UnsafeForceInlineEarlyDecoration: + case kIROp_IntrinsicOpDecoration: + return true; + case kIROp_ForceInlineDecoration: + hasForceInline = true; + break; + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_ForwardDerivativeDecoration: + hasUserDefinedDerivative = true; + break; + } + } + return (hasForceInline && !hasUserDefinedDerivative); + } +}; + +bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func) +{ + PreAutoDiffForceInliningPass pass(func->getModule()); + return pass.considerAllCallSitesRec(func); +} + // Defined in slang-ir-specialize-resource.cpp bool isResourceType(IRType* type); bool isIllegalGLSLParameterType(IRType* type); diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index 6eb3a1bb1..fe050b7b9 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -21,6 +21,9 @@ namespace Slang /// Inline any call sites to functions marked `[ForceInline]` inside `func`. bool performForceInlining(IRGlobalValueWithCode* func); + + /// Perform force inlining of functions that do not have custom derivatives. 
+ bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func); /// Inline calls to functions that returns a resource/sampler via either return value or output parameter. void performGLSLResourceReturnFunctionInlining(IRModule* module); |
