From bffac95febd7a29cfac0becfcb019cd057b53765 Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 2 Oct 2025 15:23:53 -0700 Subject: Fix the missing derivative member check (#8569) Close #8568. The root cause of this issue is that when the struct is indirectly inherited from IDifferentiable type, we will not check the reference of the DerivativeMember attribute. This PR fixes this issue by checking the DerivativeMember attribute right before synthesize the requirement methods of IDifferentiable interface. --- source/slang/slang-ast-modifier.h | 2 +- source/slang/slang-check-decl.cpp | 22 ++++++----- .../autodiff/custom-differential-type-error.slang | 26 ++++++++++++ tests/autodiff/custom-differential-type.slang | 46 ++++++++++++++++++++++ 4 files changed, 85 insertions(+), 11 deletions(-) create mode 100644 tests/autodiff/custom-differential-type-error.slang create mode 100644 tests/autodiff/custom-differential-type.slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 89f7a70bb..5793167af 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1689,7 +1689,7 @@ FIDDLE() class DerivativeMemberAttribute : public Attribute { FIDDLE(...) - FIDDLE() DeclRefExpr* memberDeclRef; + FIDDLE() DeclRefExpr* memberDeclRef = nullptr; }; /// An attribute that marks an interface type as a COM interface declaration. diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6f2b01aa1..2afd05df2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7193,6 +7193,18 @@ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( if (!diffMemberType) continue; + // Since the conformance checking happens before the decl body checking, the + // DerivativeMemberAttribute might not have been checked yet. So we need to make sure + // they are checked before we use them. `checkDerivativeMemberAttributeReferences` + // already handles the case that the attribute has already been checked. + checkDerivativeMemberAttributeReferences(varMember, derivativeAttr); + + // If there is anything wrong in the checking, `checkDerivativeMemberAttributeReferences` + // will diagnose an error, and `derivativeAttr->memberDeclRef` will be null. We will skip + // the remaining synthesis to avoid crash. + if (!derivativeAttr->memberDeclRef) + continue; + // Pull up the derivative member name from the attribute auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); @@ -8145,16 +8157,6 @@ void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) auto inheritanceDecls = decl->getMembersOfType().toList(); for (auto inheritanceDecl : inheritanceDecls) { - // Special handling for when we check for conformance against `IDifferentiable` - // We will reference-checking for the [DerivativeMember(DiffType.member)] - // attributes here, since they have to be performed after types can be referenced - // and before conformance checking, where this information can be used to synthesize - // member methods (such as `dzero`, `dadd`, etc..) - // - if (inheritanceDecl->getSup().type->equals( - astBuilder->getDifferentiableInterfaceType())) - checkDifferentiableMembersInType(decl); - checkConformance(type, inheritanceDecl, decl); } diff --git a/tests/autodiff/custom-differential-type-error.slang b/tests/autodiff/custom-differential-type-error.slang new file mode 100644 index 000000000..2542eff4f --- /dev/null +++ b/tests/autodiff/custom-differential-type-error.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + // CHECK: ([[# @LINE+1]]): error 30027: 'x' is not a member of 'typeof(Foo.Differential)'. + [DerivativeMember(Differential.x)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} diff --git a/tests/autodiff/custom-differential-type.slang b/tests/autodiff/custom-differential-type.slang new file mode 100644 index 000000000..4782b4297 --- /dev/null +++ b/tests/autodiff/custom-differential-type.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute -shaderobj -output-using-type + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + [DerivativeMember(Differential.y)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_output +RWStructuredBuffer g_output; + +[BackwardDifferentiable] +float wrapper(Foo f) +{ + return f.myFunc(); +} + +[shader("compute")] +void computeMain(uint3 thread_id : SV_DispatchThreadID) +{ + var di = diffPair(Foo(1.0f), Bar(0.0f)); + bwd_diff(wrapper)(di, 1.0f); + + // CHECK: 2.0 + g_output[0] = di.d.y; +} -- cgit v1.2.3