diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-11 03:08:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-11 19:08:27 +0800 |
| commit | 551bbb5fbd61b53253de8f6ba3303bb4d29f8c86 (patch) | |
| tree | 44b21ceefd66ac2b92c1b165fe280dcfe276cf65 | |
| parent | 0b4e463aee4107b383067424007c6a995f1f9f87 (diff) | |
Add checking for differentiability of the primal substitute function. (#6277)
Co-authored-by: Yong He <yonghe@outlook.com>
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 | ||||
| -rw-r--r-- | tests/autodiff/primal-substitute-4.slang | 46 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-primal-substitute.slang | 44 |
5 files changed, 118 insertions, 0 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2ce4f81f1..76074f551 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -12114,6 +12114,27 @@ static void checkDerivativeAttribute( imaginaryArguments.directions, imaginaryArguments.thisArg, imaginaryArguments.thisArgDirection); + + // For primal-substitute we'd also want to make sure that the differentiability + // level of the target is as high as the funcDecl itself + // + if (auto declRefExpr = as<DeclRefExpr>(attr->funcExpr)) + { + if (auto declRef = declRefExpr->declRef) + { + auto targetDiffLevel = visitor->getShared()->getFuncDifferentiableLevel( + declRef.as<FunctionDeclBase>().getDecl()); + auto currDiffLevel = visitor->getShared()->getFuncDifferentiableLevel(funcDecl); + if (targetDiffLevel < currDiffLevel) + { + visitor->getSink()->diagnose( + attr->loc, + Diagnostics::primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + declRefExpr->declRef.getDecl(), + funcDecl); + } + } + } } static void checkCudaKernelAttribute( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 59a1bbdb6..acb9beb94 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -1208,6 +1208,12 @@ DIAGNOSTIC( Error, overloadedFuncUsedWithDerivativeOfAttributes, "cannot resolve overloaded functions for derivative-of attributes.") +DIAGNOSTIC( + 31158, + Error, + primalSubstituteTargetMustHaveHigherDifferentiabilityLevel, + "primal substitute function for differentiable method must also be differentiable. Use " + "[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)") DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.") diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index d65a22e77..9075002e0 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -2471,6 +2471,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalSubstituteDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_IntermediateContextFieldDifferentialTypeDecoration: diff --git a/tests/autodiff/primal-substitute-4.slang b/tests/autodiff/primal-substitute-4.slang new file mode 100644 index 000000000..8f7720639 --- /dev/null +++ b/tests/autodiff/primal-substitute-4.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + [PrimalSubstituteOf(add), Differentiable] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); + + // CHECK: type: float + // CHECK-NEXT: 1.0 + // CHECK-NEXT: 0.0 +} diff --git a/tests/diagnostics/autodiff-primal-substitute.slang b/tests/diagnostics/autodiff-primal-substitute.slang new file mode 100644 index 000000000..178698719 --- /dev/null +++ b/tests/diagnostics/autodiff-primal-substitute.slang @@ -0,0 +1,44 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + // check for diagnostic: + // CHECK-DAG: ([[# @LINE+1]]): error 31158 + [PrimalSubstituteOf(add)] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); +} |
