diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 16:02:56 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 16:02:56 -0800 |
| commit | 4ad0470025da4e808c46023f9a2525febcf973a2 (patch) | |
| tree | 8fcb1c84121ddf40c50ca58b5de867da0da435ee /tests/autodiff/dynamic-dispatch-generic-2.slang | |
| parent | 97cb4851eed7a43f10196971b08d3d311386ce9f (diff) | |
Fix issues around dynamic generic function and autodiff. (#2528)
* Fix issues around dynamic generic function and autodiff.
* Fix return type issue.
* Fix type unification for generic `inout` parameter.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff/dynamic-dispatch-generic-2.slang')
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-generic-2.slang | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-generic-2.slang b/tests/autodiff/dynamic-dispatch-generic-2.slang new file mode 100644 index 000000000..bbf7c7da1 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-generic-2.slang @@ -0,0 +1,49 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [ForwardDifferentiable] + float calc(float x); +} + +struct A : IInterface +{ + float z; + [ForwardDifferentiable] + float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + float z; + + [ForwardDifferentiable] + float calc(float x) { return x * x + z; } +}; + +[ForwardDifferentiable] +float sqr<T:IInterface>(T obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 +} |
