From 1b40fe56725eeefe9c601461278376b697d4d35a Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 23 Nov 2022 17:50:02 -0800 Subject: Make differentiable data-flow pass recognize interface methods. (#2530) * Make differentiable data-flow pass recognize interface methods. * Make existing test to work with `[TreatAsDifferentiable]`. Co-authored-by: Yong He --- tests/autodiff/generic-autodiff-1.slang | 2 +- tests/autodiff/generic-impl-jvp.slang | 1 + tests/autodiff/generic-jvp.slang | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/autodiff/generic-autodiff-1.slang b/tests/autodiff/generic-autodiff-1.slang index 43a6d3b10..9ab0d5fef 100644 --- a/tests/autodiff/generic-autodiff-1.slang +++ b/tests/autodiff/generic-autodiff-1.slang @@ -23,7 +23,7 @@ struct A : IInterface [ForwardDifferentiable] float sqr(inout T obj, float x) { - return obj.sample() + x*x; + return (no_diff obj.sample()) + x*x; } [numthreads(1, 1, 1)] diff --git a/tests/autodiff/generic-impl-jvp.slang b/tests/autodiff/generic-impl-jvp.slang index a1bc18252..332833fff 100644 --- a/tests/autodiff/generic-impl-jvp.slang +++ b/tests/autodiff/generic-impl-jvp.slang @@ -225,6 +225,7 @@ typedef lineardvector<4> mydfloat4; typedef DifferentialPair dpfloat; +[TreatAsDifferentiable] interface MyLinearArithmeticType { static This ladd(This a, This b); diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang index 61ec077f4..2be0045d4 100644 --- a/tests/autodiff/generic-jvp.slang +++ b/tests/autodiff/generic-jvp.slang @@ -85,6 +85,7 @@ typedef myvector<4> myfloat4; typedef DifferentialPair dpfloat; +[TreatAsDifferentiable] interface MyLinearArithmeticType { static This ladd(This a, This b); -- cgit v1.2.3