diff options
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 22 | ||||
| -rw-r--r-- | tests/autodiff/custom-differential-type-error.slang | 26 | ||||
| -rw-r--r-- | tests/autodiff/custom-differential-type.slang | 46 |
4 files changed, 85 insertions, 11 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 89f7a70bb..5793167af 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1689,7 +1689,7 @@ FIDDLE() class DerivativeMemberAttribute : public Attribute { FIDDLE(...) - FIDDLE() DeclRefExpr* memberDeclRef; + FIDDLE() DeclRefExpr* memberDeclRef = nullptr; }; /// An attribute that marks an interface type as a COM interface declaration. diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6f2b01aa1..2afd05df2 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7193,6 +7193,18 @@ bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( if (!diffMemberType) continue; + // Since the conformance checking happens before the decl body checking, the + // DerivativeMemberAttribute might not have been checked yet. So we need to make sure + // they are checked before we use them. `checkDerivativeMemberAttributeReferences` + // already handles the case that the attribute has already been checked. + checkDerivativeMemberAttributeReferences(varMember, derivativeAttr); + + // If there is anything wrong in the checking, `checkDerivativeMemberAttributeReferences` + // will diagnose an error, and `derivativeAttr->memberDeclRef` will be null. We will skip + // the remaining synthesis to avoid crash. + if (!derivativeAttr->memberDeclRef) + continue; + // Pull up the derivative member name from the attribute auto derivMemberName = derivativeAttr->memberDeclRef->declRef.getName(); @@ -8145,16 +8157,6 @@ void SemanticsVisitor::checkAggTypeConformance(AggTypeDecl* decl) auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList(); for (auto inheritanceDecl : inheritanceDecls) { - // Special handling for when we check for conformance against `IDifferentiable` - // We will reference-checking for the [DerivativeMember(DiffType.member)] - // attributes here, since they have to be performed after types can be referenced - // and before conformance checking, where this information can be used to synthesize - // member methods (such as `dzero`, `dadd`, etc..) - // - if (inheritanceDecl->getSup().type->equals( - astBuilder->getDifferentiableInterfaceType())) - checkDifferentiableMembersInType(decl); - checkConformance(type, inheritanceDecl, decl); } diff --git a/tests/autodiff/custom-differential-type-error.slang b/tests/autodiff/custom-differential-type-error.slang new file mode 100644 index 000000000..2542eff4f --- /dev/null +++ b/tests/autodiff/custom-differential-type-error.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + // CHECK: ([[# @LINE+1]]): error 30027: 'x' is not a member of 'typeof(Foo.Differential)'. + [DerivativeMember(Differential.x)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} diff --git a/tests/autodiff/custom-differential-type.slang b/tests/autodiff/custom-differential-type.slang new file mode 100644 index 000000000..4782b4297 --- /dev/null +++ b/tests/autodiff/custom-differential-type.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute -shaderobj -output-using-type + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + [DerivativeMember(Differential.y)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_output +RWStructuredBuffer<float> g_output; + +[BackwardDifferentiable] +float wrapper(Foo f) +{ + return f.myFunc(); +} + +[shader("compute")] +void computeMain(uint3 thread_id : SV_DispatchThreadID) +{ + var di = diffPair(Foo(1.0f), Bar(0.0f)); + bwd_diff(wrapper)(di, 1.0f); + + // CHECK: 2.0 + g_output[0] = di.d.y; +} |
