diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-23 17:12:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-23 14:12:14 -0700 |
| commit | 94d696801e8b313267e518cb16949d0ec122d46f (patch) | |
| tree | ad9f9628882792dc7f5f0fd4f987a5810ad0040e /tests | |
| parent | e8673a535e91af8fd8d31d6845af1c792f554f05 (diff) | |
Add support for `kIROp_MakeExistential` (#2832)
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/dynamic-object-bwd-diff.slang | 76 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt | 6 |
2 files changed, 82 insertions, 0 deletions
diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang b/tests/autodiff/dynamic-object-bwd-diff.slang new file mode 100644 index 000000000..a10c48f9b --- /dev/null +++ b/tests/autodiff/dynamic-object-bwd-diff.slang @@ -0,0 +1,76 @@ +// Test calling backward differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [BackwardDifferentiable] + float calc(IInterface2 i2, float x); +} + +interface IInterface2 +{ + float innerCalc(float x); +} + +struct C : IInterface2 +{ + float innerCalc(float x) { return 2 * x; } +} + +struct A : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(IInterface2 i2, float x) + { + float b = no_diff(i2.innerCalc(x)); + return a*b*x; + } +}; + +struct B : IInterface +{ + float a; + [BackwardDifferentiable] + float calc(IInterface2 i2, float x) + { + float b = no_diff(i2.innerCalc(x)); + return a*b*x*x; + } +}; + +[BackwardDifferentiable] +float run(int id, float x, no_diff float y) +{ + IInterface obj = createDynamicObject<IInterface>(id, y); + C c = {}; + return obj.calc(c, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 +//TEST_INPUT: type_conformance C:IInterface2 = 0 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + var p = diffPair(3.0); + + __bwd_diff(run)(0, p, 0.5, 1.0f); + outputBuffer[0] = p.d; // A.calc, expect 3 + } + + { + var p = diffPair(3.0); + + __bwd_diff(run)(1, p, 1.5, 1.0f); + outputBuffer[1] = p.d; // B.calc, expect 40.5 + } +} diff --git a/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt b/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt new file mode 100644 index 000000000..7c6952bfa --- /dev/null +++ b/tests/autodiff/dynamic-object-bwd-diff.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +3.000000 +54.000000 +0.000000 +0.000000 +0.000000 |
