From 6c3261b618b88c2b996e56dea58ba4f5435b0908 Mon Sep 17 00:00:00 2001 From: ArielG-NV <159081215+ArielG-NV@users.noreply.github.com> Date: Mon, 26 Aug 2024 17:17:28 -0400 Subject: Correct the `generic-jvp.slang` test (#4900) Fixes: #4899 Fixes invalid test results since `{...}` was differentiating the constructor of `myvector` when it should not (see #4877). This change modifies the test so it is correct so other PRs may be merged if indirectly/directly fixing the old use-case for this test. --- tests/autodiff/generic-jvp.slang | 10 ++++++++-- tests/autodiff/generic-jvp.slang.expected.txt | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) (limited to 'tests') diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 7e5625477..6cdd63bdb 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -10,6 +10,12 @@ __generic struct myvector { vector val; + + [TreatAsDifferentiable] + __init(vector data) + { + val = data; + } } extension myvector<3> : MyLinearArithmeticType @@ -115,7 +121,7 @@ extension myfloat3 : IDifferentiable [ForwardDifferentiable] static Differential dmul(T a, Differential b) { - return { __realCast(a) * b.val }; + return myfloat3(__realCast(a) * b.val); } }; @@ -141,7 +147,7 @@ extension myfloat4 : IDifferentiable [ForwardDifferentiable] static Differential dmul(T a, Differential b) { - return { __realCast(a) * b.val }; + return myfloat4(__realCast(a) * b.val); } }; diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt index ceeaf120e..5306c75b9 100644 --- a/tests/autodiff/generic-jvp.slang.expected.txt +++ b/tests/autodiff/generic-jvp.slang.expected.txt @@ -1,6 +1,6 @@ type: float 22.000000 9.500000 -27.500000 -40.500000 +0.000000 +0.000000 0.000000 -- cgit v1.2.3