summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArielG-NV <159081215+ArielG-NV@users.noreply.github.com>2024-08-26 17:17:28 -0400
committerGitHub <noreply@github.com>2024-08-26 17:17:28 -0400
commit6c3261b618b88c2b996e56dea58ba4f5435b0908 (patch)
treecbc961e42f98d99810032ca956c8767fdbaf7849
parente1c6fecd90142761aaecbf4e281beb87893fc531 (diff)
Correct the `generic-jvp.slang` test (#4900)
Fixes: #4899 Fixes invalid test results since `{...}` was differentiating the constructor of `myvector` when it should not (see #4877). This change modifies the test so it is correct so other PRs may be merged if indirectly/directly fixing the old use-case for this test.
-rw-r--r--tests/autodiff/generic-jvp.slang10
-rw-r--r--tests/autodiff/generic-jvp.slang.expected.txt4
2 files changed, 10 insertions, 4 deletions
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 7e5625477..6cdd63bdb 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -10,6 +10,12 @@ __generic<let N : int>
struct myvector
{
vector<Real, N> val;
+
+ [TreatAsDifferentiable]
+ __init(vector<Real,N> data)
+ {
+ val = data;
+ }
}
extension myvector<3> : MyLinearArithmeticType
@@ -115,7 +121,7 @@ extension myfloat3 : IDifferentiable
[ForwardDifferentiable]
static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- return { __realCast<Real, T>(a) * b.val };
+ return myfloat3(__realCast<Real, T>(a) * b.val);
}
};
@@ -141,7 +147,7 @@ extension myfloat4 : IDifferentiable
[ForwardDifferentiable]
static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return { __realCast<Real, T>(a) * b.val };
+ return myfloat4(__realCast<Real, T>(a) * b.val);
}
};
diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt
index ceeaf120e..5306c75b9 100644
--- a/tests/autodiff/generic-jvp.slang.expected.txt
+++ b/tests/autodiff/generic-jvp.slang.expected.txt
@@ -1,6 +1,6 @@
type: float
22.000000
9.500000
-27.500000
-40.500000
+0.000000
+0.000000
0.000000