summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-08 12:26:07 -0700
committerGitHub <noreply@github.com>2023-09-08 12:26:07 -0700
commite8a1dd11eab4c07366b29aca775eb927a465e133 (patch)
tree6f1e0788860ace253c05c3fbb37be8cd1f07fecf /source/slang
parentcb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (diff)
Don't inline callees with custom derivative before autodiff. (#3196)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp13
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-inline.cpp41
-rw-r--r--source/slang/slang-ir-inline.h3
6 files changed, 64 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 2662498ed..b82818b99 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1521,6 +1521,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
+ copyOriginalDecorations(origFunc, diffFunc);
+
FuncBodyTranscriptionTask task;
task.type = FuncBodyTranscriptionTaskType::Forward;
task.originalFunc = origFunc;
@@ -1689,7 +1691,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
removeLinkageDecorations(func);
- performForceInlining(func);
+ performPreAutoDiffForceInlining(func);
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8d7582373..681c69cd3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -351,6 +351,7 @@ namespace Slang
builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
cloneInst(&cloneEnv, &builder, dictDecor);
}
+ copyOriginalDecorations(origFunc, diffFunc);
builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
return InstPair(primalFunc, diffFunc);
}
@@ -532,7 +533,7 @@ namespace Slang
{
removeLinkageDecorations(func);
- performForceInlining(func);
+ performPreAutoDiffForceInlining(func);
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 53bb30f54..cf2310fc8 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -508,6 +508,19 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType *type)
}
}
+// Walks the decorations attached to `origFunc` and clones the ones that the
+// transcribed function should also carry onto `diffFunc`. At present only
+// `kIROp_ForceInlineDecoration` is propagated; all other decorations are skipped.
+void AutoDiffTranscriberBase::copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc)
+{
+    for (auto decor : origFunc->getDecorations())
+    {
+        // Anything that is not a force-inline decoration stays on the original only.
+        if (decor->getOp() != kIROp_ForceInlineDecoration)
+            continue;
+        cloneDecoration(decor, diffFunc);
+    }
+}
+
InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst)
{
IRInst* origBase = origInst->getOperand(0);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index e9acbcd99..7b4c293e9 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -140,6 +140,8 @@ struct AutoDiffTranscriberBase
void _markInstAsDifferential(IRBuilder* builder, IRInst* diffInst, IRInst* primalInst = nullptr);
+ void copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc);
+
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0;
// Create an empty func to represent the transcribed func of `origFunc`.
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 5479a98ff..06b63db52 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -831,6 +831,47 @@ bool performForceInlining(IRGlobalValueWithCode* func)
return pass.considerAllCallSitesRec(func);
}
+/// Inlining pass used when preparing a function for auto-diff.
+///
+/// A call site is selected for inlining when its callee either
+///  - carries `UnsafeForceInlineEarly` or `IntrinsicOp` (always inlined), or
+///  - carries `ForceInline` and neither `UserDefinedBackwardDerivative` nor
+///    `ForwardDerivative`, so that callees with a custom derivative are kept
+///    as calls.
+struct PreAutoDiffForceInliningPass : InliningPassBase
+{
+    typedef InliningPassBase Super;
+
+    PreAutoDiffForceInliningPass(IRModule* module)
+        : Super(module)
+    {}
+
+    /// Returns true if the call site described by `info` should be inlined,
+    /// based solely on the decorations found on `info.callee`.
+    bool shouldInline(CallSiteInfo const& info)
+    {
+        bool hasForceInline = false;
+        bool hasUserDefinedDerivative = false;
+
+        // One pass over the decorations is enough: the two "always inline"
+        // kinds return immediately from the switch, so a separate
+        // findDecoration<> pre-check would only repeat the same scan.
+        // NOTE(review): assumes findDecoration<T> matches exactly on these op codes.
+        for (auto decor : info.callee->getDecorations())
+        {
+            switch (decor->getOp())
+            {
+            case kIROp_UnsafeForceInlineEarlyDecoration:
+            case kIROp_IntrinsicOpDecoration:
+                return true;
+            case kIROp_ForceInlineDecoration:
+                hasForceInline = true;
+                break;
+            case kIROp_UserDefinedBackwardDerivativeDecoration:
+            case kIROp_ForwardDerivativeDecoration:
+                hasUserDefinedDerivative = true;
+                break;
+            }
+        }
+
+        // Plain ForceInline only wins when no custom derivative is attached.
+        return (hasForceInline && !hasUserDefinedDerivative);
+    }
+};
+
+/// Builds a `PreAutoDiffForceInliningPass` for the module that owns `func`,
+/// applies it to `func` through `considerAllCallSitesRec`, and returns that
+/// call's result.
+bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func)
+{
+    auto owningModule = func->getModule();
+    PreAutoDiffForceInliningPass inliner(owningModule);
+    const bool result = inliner.considerAllCallSitesRec(func);
+    return result;
+}
+
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
bool isIllegalGLSLParameterType(IRType* type);
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 6eb3a1bb1..fe050b7b9 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -21,6 +21,9 @@ namespace Slang
/// Inline any call sites to functions marked `[ForceInline]` inside `func`.
bool performForceInlining(IRGlobalValueWithCode* func);
+
+ /// Perform force inlining of functions that do not have custom derivatives.
+ bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func);
/// Inline calls to functions that returns a resource/sampler via either return value or output parameter.
void performGLSLResourceReturnFunctionInlining(IRModule* module);