From 551bbb5fbd61b53253de8f6ba3303bb4d29f8c86 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 11 Feb 2025 03:08:27 -0800 Subject: Add checking for differentiability of the primal substitute function. (#6277) Co-authored-by: Yong He Co-authored-by: Ellie Hermaszewska --- source/slang/slang-check-decl.cpp | 21 ++++++++++ source/slang/slang-diagnostic-defs.h | 6 +++ source/slang/slang-ir-autodiff.cpp | 1 + tests/autodiff/primal-substitute-4.slang | 46 ++++++++++++++++++++++ tests/diagnostics/autodiff-primal-substitute.slang | 44 +++++++++++++++++++++ 5 files changed, 118 insertions(+) create mode 100644 tests/autodiff/primal-substitute-4.slang create mode 100644 tests/diagnostics/autodiff-primal-substitute.slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2ce4f81f1..76074f551 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12114,6 +12114,27 @@ static void checkDerivativeAttribute( imaginaryArguments.directions, imaginaryArguments.thisArg, imaginaryArguments.thisArgDirection); + + // For primal-substitute we'd also want to make sure that the differentiability + // level of the target is as high as the funcDecl itself + // + if (auto declRefExpr = as(attr->funcExpr)) + { + if (auto declRef = declRefExpr->declRef) + { + auto targetDiffLevel = visitor->getShared()->getFuncDifferentiableLevel( + declRef.as().getDecl()); + auto currDiffLevel = visitor->getShared()->getFuncDifferentiableLevel(funcDecl); + if (targetDiffLevel < currDiffLevel) + { + visitor->getSink()->diagnose( + attr->loc, + Diagnostics::primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + declRefExpr->declRef.getDecl(), + funcDecl); + } + } + } } static void checkCudaKernelAttribute( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 59a1bbdb6..acb9beb94 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1208,6 +1208,12 @@ DIAGNOSTIC( Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.") +DIAGNOSTIC( + 31158, + Error, + primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + "primal substitute function for differentiable method must also be differentiable. Use " + "[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index d65a22e77..9075002e0 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -2471,6 +2471,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: diff --git a/tests/autodiff/primal-substitute-4.slang b/tests/autodiff/primal-substitute-4.slang new file mode 100644 index 000000000..8f7720639 --- /dev/null +++ b/tests/autodiff/primal-substitute-4.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer primal; + RWStructuredBuffer grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + [PrimalSubstituteOf(add), Differentiable] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); + + // CHECK: type: float + // CHECK-NEXT: 1.0 + // CHECK-NEXT: 0.0 +} diff --git a/tests/diagnostics/autodiff-primal-substitute.slang b/tests/diagnostics/autodiff-primal-substitute.slang new file mode 100644 index 000000000..178698719 --- /dev/null +++ b/tests/diagnostics/autodiff-primal-substitute.slang @@ -0,0 +1,44 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer primal; + RWStructuredBuffer grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + // check for diagnostic: + // CHECK-DAG: ([[# @LINE+1]]): error 31158 + [PrimalSubstituteOf(add)] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); +} -- cgit v1.2.3