diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-bwd-diff.slang | 52 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt | 6 |
2 files changed, 58 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang b/tests/autodiff/dynamic-dispatch-bwd-diff.slang new file mode 100644 index 000000000..5945c22cd --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang @@ -0,0 +1,52 @@ +// Test calling backward differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [BackwardDifferentiable] + float calc(float x); +} + +struct A : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(float x) { return a*x*x; } +}; + +struct B : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(float x) { return a*x*x*x; } +}; + +[BackwardDifferentiable] +float run(IInterface obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0.5f); // A + var p = diffPair(3.0); + + __bwd_diff(run)(obj, p, 1.0f); + outputBuffer[0] = p.d; // A.calc, expect 3 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 1.5f); // B + p = diffPair(3.0); + __bwd_diff(run)(obj, p, 1.0f); + outputBuffer[1] = p.d; // B.calc, expect 40.5 +} diff --git a/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt new file mode 100644 index 000000000..57bb1ee65 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-bwd-diff.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +3.000000 +40.500000 +0.000000 +0.000000 +0.000000 |
