summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-09-08 12:26:07 -0700
committerGitHub <noreply@github.com>2023-09-08 12:26:07 -0700
commite8a1dd11eab4c07366b29aca775eb927a465e133 (patch)
tree6f1e0788860ace253c05c3fbb37be8cd1f07fecf
parentcb5dd19992fb77ca2be866d9c6f2f4436c8b1c1e (diff)
Don't inline callees with custom derivative before autodiff. (#3196)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp3
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp13
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.h2
-rw-r--r--source/slang/slang-ir-inline.cpp41
-rw-r--r--source/slang/slang-ir-inline.h3
-rw-r--r--tests/autodiff/inline.slang45
7 files changed, 109 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 2662498ed..b82818b99 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1521,6 +1521,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFu
inBuilder->addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
+ copyOriginalDecorations(origFunc, diffFunc);
+
FuncBodyTranscriptionTask task;
task.type = FuncBodyTranscriptionTaskType::Forward;
task.originalFunc = origFunc;
@@ -1689,7 +1691,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func);
removeLinkageDecorations(func);
- performForceInlining(func);
+ performPreAutoDiffForceInlining(func);
initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func);
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 8d7582373..681c69cd3 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -351,6 +351,7 @@ namespace Slang
builder.setInsertBefore(diffFunc->getFirstDecorationOrChild());
cloneInst(&cloneEnv, &builder, dictDecor);
}
+ copyOriginalDecorations(origFunc, diffFunc);
builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);
return InstPair(primalFunc, diffFunc);
}
@@ -532,7 +533,7 @@ namespace Slang
{
removeLinkageDecorations(func);
- performForceInlining(func);
+ performPreAutoDiffForceInlining(func);
DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
diffTypeContext.setFunc(func);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 53bb30f54..cf2310fc8 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -508,6 +508,19 @@ bool AutoDiffTranscriberBase::isExistentialType(IRType *type)
}
}
+void AutoDiffTranscriberBase::copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc)
+{
+ for (auto decor : origFunc->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_ForceInlineDecoration:
+ cloneDecoration(decor, diffFunc);
+ break;
+ }
+ }
+}
+
InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst)
{
IRInst* origBase = origInst->getOperand(0);
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h
index e9acbcd99..7b4c293e9 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.h
+++ b/source/slang/slang-ir-autodiff-transcriber-base.h
@@ -140,6 +140,8 @@ struct AutoDiffTranscriberBase
void _markInstAsDifferential(IRBuilder* builder, IRInst* diffInst, IRInst* primalInst = nullptr);
+ void copyOriginalDecorations(IRInst* origFunc, IRInst* diffFunc);
+
virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0;
// Create an empty func to represent the transcribed func of `origFunc`.
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 5479a98ff..06b63db52 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -831,6 +831,47 @@ bool performForceInlining(IRGlobalValueWithCode* func)
return pass.considerAllCallSitesRec(func);
}
+struct PreAutoDiffForceInliningPass : InliningPassBase
+{
+ typedef InliningPassBase Super;
+
+ PreAutoDiffForceInliningPass(IRModule* module)
+ : Super(module)
+ {}
+
+ bool shouldInline(CallSiteInfo const& info)
+ {
+ if (info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>() ||
+ info.callee->findDecoration<IRIntrinsicOpDecoration>())
+ return true;
+ bool hasForceInline = false;
+ bool hasUserDefinedDerivative = false;
+ for (auto decor : info.callee->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_UnsafeForceInlineEarlyDecoration:
+ case kIROp_IntrinsicOpDecoration:
+ return true;
+ case kIROp_ForceInlineDecoration:
+ hasForceInline = true;
+ break;
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ hasUserDefinedDerivative = true;
+ break;
+ }
+ }
+ return (hasForceInline && !hasUserDefinedDerivative);
+ }
+};
+
+bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func)
+{
+ PreAutoDiffForceInliningPass pass(func->getModule());
+ return pass.considerAllCallSitesRec(func);
+}
+
// Defined in slang-ir-specialize-resource.cpp
bool isResourceType(IRType* type);
bool isIllegalGLSLParameterType(IRType* type);
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 6eb3a1bb1..fe050b7b9 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -21,6 +21,9 @@ namespace Slang
/// Inline any call sites to functions marked `[ForceInline]` inside `func`.
bool performForceInlining(IRGlobalValueWithCode* func);
+
+ /// Perform force inlining of functions that does not have custom derivatives.
+ bool performPreAutoDiffForceInlining(IRGlobalValueWithCode* func);
/// Inline calls to functions that returns a resource/sampler via either return value or output parameter.
void performGLSLResourceReturnFunctionInlining(IRModule* module);
diff --git a/tests/autodiff/inline.slang b/tests/autodiff/inline.slang
new file mode 100644
index 000000000..9b41b5f1f
--- /dev/null
+++ b/tests/autodiff/inline.slang
@@ -0,0 +1,45 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -output-using-type
+//TEST:SIMPLE(filecheck=CHECK):-stage compute -entry computeMain -target hlsl
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+// CHECK-NOT: void mySqr{{.*}}(
+
+// Test that calls to a ForceInline function stil get correct custom derivative.
+[BackwardDerivative(bwd_mySqr)]
+[ForceInline]
+void mySqr(float x, out float y)
+{
+ y = x * x;
+}
+
+void bwd_mySqr(inout DifferentialPair<float> dpx, in float.Differential dy)
+{
+ dpx = DifferentialPair<float>(dpx.p, 1001.0);
+}
+
+[Differentiable]
+void myF(float x, out float y)
+{
+ mySqr(x, y);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(2.0, 1.0);
+ __bwd_diff(myF)(dpa, 1.0);
+ // BUFFER: 1001.0
+ outputBuffer[0] = dpa.d;
+
+ float o;
+ myF(1.0, o);
+ outputBuffer[1] = o;
+ }
+}