From bffac95febd7a29cfac0becfcb019cd057b53765 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 2 Oct 2025 15:23:53 -0700 Subject: Fix the missing derivative member check (#8569) Close #8568. The root cause of this issue is that when the struct is indirectly inherited from IDifferentiable type, we will not check the reference of the DerivativeMember attribute. This PR fixes this issue by checking the DerivativeMember attribute right before synthesize the requirement methods of IDifferentiable interface. --- .../autodiff/custom-differential-type-error.slang | 26 ++++++++++++ tests/autodiff/custom-differential-type.slang | 46 ++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 tests/autodiff/custom-differential-type-error.slang create mode 100644 tests/autodiff/custom-differential-type.slang (limited to 'tests/autodiff') diff --git a/tests/autodiff/custom-differential-type-error.slang b/tests/autodiff/custom-differential-type-error.slang new file mode 100644 index 000000000..2542eff4f --- /dev/null +++ b/tests/autodiff/custom-differential-type-error.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + // CHECK: ([[# @LINE+1]]): error 30027: 'x' is not a member of 'typeof(Foo.Differential)'. + [DerivativeMember(Differential.x)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} diff --git a/tests/autodiff/custom-differential-type.slang b/tests/autodiff/custom-differential-type.slang new file mode 100644 index 000000000..4782b4297 --- /dev/null +++ b/tests/autodiff/custom-differential-type.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute -shaderobj -output-using-type + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + [DerivativeMember(Differential.y)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_output +RWStructuredBuffer g_output; + +[BackwardDifferentiable] +float wrapper(Foo f) +{ + return f.myFunc(); +} + +[shader("compute")] +void computeMain(uint3 thread_id : SV_DispatchThreadID) +{ + var di = diffPair(Foo(1.0f), Bar(0.0f)); + bwd_diff(wrapper)(di, 1.0f); + + // CHECK: 2.0 + g_output[0] = di.d.y; +} -- cgit v1.2.3