From 83a42cb76feb1f702ff730040f359cabc01c571a Mon Sep 17 00:00:00 2001 From: Ellie Hermaszewska Date: Fri, 4 Apr 2025 15:25:20 +0800 Subject: Do no fail on missing no_diff annotation on non-differentiable (inputs and output) function outputs (#6737) Closes https://github.com/shader-slang/slang/issues/6632 --- source/slang/slang-ir-autodiff.cpp | 14 +++++++++++ source/slang/slang-ir-autodiff.h | 1 + source/slang/slang-ir-check-differentiability.cpp | 21 ++++++++++------ tests/bugs/gh-6632.slang | 29 +++++++++++++++++++++++ 4 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 tests/bugs/gh-6632.slang diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index afd698e8b..f70e30b72 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -147,6 +147,20 @@ bool isNoDiffType(IRType* paramType) return false; } +// Return true if the result type and all the parameter types are no_diff +bool isNeverDiffFuncType(IRFuncType* const funcType) +{ + const auto resultType = funcType->getResultType(); + if (!isNoDiffType(resultType)) + return false; + for (const auto p : funcType->getParamTypes()) + { + if (!isNoDiffType(p)) + return false; + } + return true; +} + IRInst* lookupForwardDerivativeReference(IRInst* primalFunction) { if (auto jvpDefinition = primalFunction->findDecoration()) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 2cd08eb28..befd1f98a 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -563,6 +563,7 @@ void stripAutoDiffDecorations(IRModule* module); void stripTempDecorations(IRInst* inst); bool isNoDiffType(IRType* paramType); +bool isNeverDiffFuncType(IRFuncType* funcType); IRInst* lookupForwardDerivativeReference(IRInst* primalFunction); diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 9001295e0..7f3c7bf01 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -487,18 +487,25 @@ public: { if (auto call = as(inst)) { + const auto callee = call->getCallee(); // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. if (isDifferentiableType(diffTypeContext, inst->getFullType()) && - !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) + !isDifferentiableFunc(callee, requiredDiffLevel)) { - sink->diagnose( - inst, - Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" - : "backward"); + // No need to fail here if the function is no_diff in + // both inputs and all outputs, this is equivalent of + // inserting no_diff on this inst. + if (!isNeverDiffFuncType(cast(callee->getDataType()))) + { + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getResolvedInstForDecorations(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" + : "backward"); + } } } } diff --git a/tests/bugs/gh-6632.slang b/tests/bugs/gh-6632.slang new file mode 100644 index 000000000..f9240e63b --- /dev/null +++ b/tests/bugs/gh-6632.slang @@ -0,0 +1,29 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu + +// CHECK: 40000000 + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer + +// Non-differentiable function with no_diff parameters and return type +no_diff float targetFunc(no_diff float x) +{ + return x * 2.0f; +} + +[Differentiable] +float errorForward(no_diff float x) +{ + float result = targetFunc(x); + return result; +} + +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float input = 1.0f; + + outputBuffer[0] = errorForward(input); +} + -- cgit v1.2.3