diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2025-04-04 15:25:20 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-04 15:25:20 +0800 |
| commit | 83a42cb76feb1f702ff730040f359cabc01c571a (patch) | |
| tree | 442ff945665ace3bbdafaf410f2664825419fd7d | |
| parent | 4233d69cf88f1623cb573c8edb61456b24dc5339 (diff) | |
Do no fail on missing no_diff annotation on non-differentiable (inputs and output) function outputs (#6737)
Closes https://github.com/shader-slang/slang/issues/6632
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 21 | ||||
| -rw-r--r-- | tests/bugs/gh-6632.slang | 29 |
4 files changed, 58 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index afd698e8b..f70e30b72 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -147,6 +147,20 @@ bool isNoDiffType(IRType* paramType) return false; } +// Return true if the result type and all the parameter types are no_diff +bool isNeverDiffFuncType(IRFuncType* const funcType) +{ + const auto resultType = funcType->getResultType(); + if (!isNoDiffType(resultType)) + return false; + for (const auto p : funcType->getParamTypes()) + { + if (!isNoDiffType(p)) + return false; + } + return true; +} + IRInst* lookupForwardDerivativeReference(IRInst* primalFunction) { if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 2cd08eb28..befd1f98a 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -563,6 +563,7 @@ void stripAutoDiffDecorations(IRModule* module); void stripTempDecorations(IRInst* inst); bool isNoDiffType(IRType* paramType); +bool isNeverDiffFuncType(IRFuncType* funcType); IRInst* lookupForwardDerivativeReference(IRInst* primalFunction); diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 9001295e0..7f3c7bf01 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -487,18 +487,25 @@ public: { if (auto call = as<IRCall>(inst)) { + const auto callee = call->getCallee(); // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. if (isDifferentiableType(diffTypeContext, inst->getFullType()) && - !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) + !isDifferentiableFunc(callee, requiredDiffLevel)) { - sink->diagnose( - inst, - Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" - : "backward"); + // No need to fail here if the function is no_diff in + // both inputs and all outputs, this is equivalent of + // inserting no_diff on this inst. + if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType()))) + { + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getResolvedInstForDecorations(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" + : "backward"); + } } } } diff --git a/tests/bugs/gh-6632.slang b/tests/bugs/gh-6632.slang new file mode 100644 index 000000000..f9240e63b --- /dev/null +++ b/tests/bugs/gh-6632.slang @@ -0,0 +1,29 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu + +// CHECK: 40000000 + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer + +// Non-differentiable function with no_diff parameters and return type +no_diff float targetFunc(no_diff float x) +{ + return x * 2.0f; +} + +[Differentiable] +float errorForward(no_diff float x) +{ + float result = targetFunc(x); + return result; +} + +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float input = 1.0f; + + outputBuffer[0] = errorForward(input); +} + |
