diff options
| author | Yong He <yonghe@outlook.com> | 2023-01-11 15:33:28 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-11 15:33:28 -0800 |
| commit | a3ac6e71cbc922b7c941c45f23ee18a9fc274d1f (patch) | |
| tree | acf8c18601f124e9290494f8b379d2420369fc35 /tests | |
| parent | 20262684bcbb707d16669b2670039df870b65ca8 (diff) | |
Make backward differentiation work with generics. (#2586)
* Make backward differentiation work with generics.
* Fix.
* Another fix.
* More fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-reverse-1.slang | 47 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-reverse-1.slang.expected.txt | 6 |
2 files changed, 53 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-reverse-1.slang b/tests/autodiff/dynamic-dispatch-reverse-1.slang new file mode 100644 index 000000000..846004f95 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-reverse-1.slang @@ -0,0 +1,47 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return 1.0; } +}; + +struct B : IInterface +{ + static float calc(float x) { return 2.0; } +}; + +[BackwardDifferentiable] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + var p = DifferentialPair<float>(2.0, 1.0); + __bwd_diff(sqr)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[0] = p.d; + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + p = DifferentialPair<float>(1.5, 1.0); + __bwd_diff(sqr)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[1] = p.d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-reverse-1.slang.expected.txt b/tests/autodiff/dynamic-dispatch-reverse-1.slang.expected.txt new file mode 100644 index 000000000..780ba6ed4 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-reverse-1.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +3.000000 +0.000000 +0.000000 +0.000000 |
