diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 09:39:08 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 09:39:08 -0800 |
| commit | 97cb4851eed7a43f10196971b08d3d311386ce9f (patch) | |
| tree | 99ba81368068b3345fa23b749108265aa753ed2b /tests | |
| parent | 6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff) | |
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch.
* Revert changes.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-autodiff-simple.slang | 48 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt | 6 |
2 files changed, 54 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang new file mode 100644 index 000000000..1247253f9 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang @@ -0,0 +1,48 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [ForwardDifferentiable] + static float calc(float x); +} + +struct A : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x; } +}; + +[ForwardDifferentiable] +float sqr(IInterface obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 + + outputBuffer[2] = __fwd_diff(obj.calc)(DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt new file mode 100644 index 000000000..1b1844a5d --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +3.000000 +3.000000 +0.000000 +0.000000 |
