summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/generic-jvp.slang10
-rw-r--r--tests/autodiff/generic-jvp.slang.expected.txt4
2 files changed, 10 insertions, 4 deletions
diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang
index 7e5625477..6cdd63bdb 100644
--- a/tests/autodiff/generic-jvp.slang
+++ b/tests/autodiff/generic-jvp.slang
@@ -10,6 +10,12 @@ __generic<let N : int>
struct myvector
{
vector<Real, N> val;
+
+ [TreatAsDifferentiable]
+ __init(vector<Real,N> data)
+ {
+ val = data;
+ }
}
extension myvector<3> : MyLinearArithmeticType
@@ -115,7 +121,7 @@ extension myfloat3 : IDifferentiable
[ForwardDifferentiable]
static Differential dmul<T : __BuiltinRealType>(T a, Differential b)
{
- return { __realCast<Real, T>(a) * b.val };
+ return myfloat3(__realCast<Real, T>(a) * b.val);
}
};
@@ -141,7 +147,7 @@ extension myfloat4 : IDifferentiable
[ForwardDifferentiable]
static Differential dmul<T: __BuiltinRealType>(T a, Differential b)
{
- return { __realCast<Real, T>(a) * b.val };
+ return myfloat4(__realCast<Real, T>(a) * b.val);
}
};
diff --git a/tests/autodiff/generic-jvp.slang.expected.txt b/tests/autodiff/generic-jvp.slang.expected.txt
index ceeaf120e..5306c75b9 100644
--- a/tests/autodiff/generic-jvp.slang.expected.txt
+++ b/tests/autodiff/generic-jvp.slang.expected.txt
@@ -1,6 +1,6 @@
type: float
22.000000
9.500000
-27.500000
-40.500000
+0.000000
+0.000000
0.000000