diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-01 18:55:43 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-01 18:55:43 -0800 |
| commit | e7df8538eb8f0ed06f0838d946bec8e9e0fe0985 (patch) | |
| tree | 3c08e646600ab82ffda260f2b6deb96dd2085776 /tests | |
| parent | f51f69d045d9e0b83d9ab1f4623d4319ce1867be (diff) | |
Allow `no_diff` on `this` parameter. (#2543)
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/no-diff-this-interface.slang | 69 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-this-interface.slang.expected.txt | 7 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-this.slang | 49 | ||||
| -rw-r--r-- | tests/autodiff/no-diff-this.slang.expected.txt | 5 |
4 files changed, 130 insertions, 0 deletions
diff --git a/tests/autodiff/no-diff-this-interface.slang b/tests/autodiff/no-diff-this-interface.slang new file mode 100644 index 000000000..4f4d45089 --- /dev/null +++ b/tests/autodiff/no-diff-this-interface.slang @@ -0,0 +1,69 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +interface IFoo +{ + // Since IFoo is not inheriting from IDifferentiable, + // The `this` parameter should be considered as `no_diff` when `getVal` + // is called through this interface. + [ForwardDifferentiable] + float getVal(float y); +} + +struct A : IDifferentiable, IFoo +{ + float x; + + // This `getVal` implementation will have `this` parameter treated as + // differentiable. In order for this method to satisfy the `IFoo.getVal` + // requirement, we need to synthesize a method with `[NoDiffThis]` attribute + // that calls this. + [ForwardDifferentiable] + float getVal(float y){ return x * x + y * y; } +} + +[ForwardDifferentiable] +static float f<T:IFoo>(T obj, float y) +{ + return obj.getVal(y); +} + +[ForwardDifferentiable] +static float f2(IFoo obj, float y) +{ + return obj.getVal(y); +} + +[ForwardDifferentiable] +float f3(A obj, float y) +{ + return obj.getVal(y); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a; + a.x = 2.0; + A.Differential ad; + ad.x = 1.0; + + let rs = __fwd_diff(f)(a, dpfloat(3.0, 1.0)); + outputBuffer[0] = rs.p; // Expect: 13.0 + outputBuffer[1] = rs.d; // Expect: 6.0 + + let rs2 = __fwd_diff(f2)(a, dpfloat(3.0, 1.0)); + outputBuffer[2] = rs2.p; // Expect: 13.0 + outputBuffer[3] = rs2.d; // Expect: 6.0 + + // By calling A.getVal directly, we will invoke the implementation + // that differentiates the `this` argument. + let rs3 = __fwd_diff(f3)(DifferentialPair<A>(a, ad), dpfloat(3.0, 1.0)); + outputBuffer[4] = rs3.p; // Expect: 13.0 + outputBuffer[5] = rs3.d; // Expect: 10.0 +} diff --git a/tests/autodiff/no-diff-this-interface.slang.expected.txt b/tests/autodiff/no-diff-this-interface.slang.expected.txt new file mode 100644 index 000000000..2116bf00f --- /dev/null +++ b/tests/autodiff/no-diff-this-interface.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +13.000000 +6.000000 +13.000000 +6.000000 +13.000000 +10.000000 diff --git a/tests/autodiff/no-diff-this.slang b/tests/autodiff/no-diff-this.slang new file mode 100644 index 000000000..9daf07d05 --- /dev/null +++ b/tests/autodiff/no-diff-this.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; + +struct A : IDifferentiable +{ + float x; + + [ForwardDifferentiable] + float getVal(float y){ return x * x + y * y; } + + [ForwardDifferentiable] + [NoDiffThis] + float getVal2(float y) { return x * x + y * y; } + + [ForwardDifferentiable] + static float f(A obj, float y) + { + return obj.getVal(y); + } + + [ForwardDifferentiable] + static float f2(A obj, float y) + { + return obj.getVal2(y); + } +} + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + A a; + a.x = 2.0; + A.Differential ad; + ad.x = 1.0; + + let rs = __fwd_diff(A.f)(DifferentialPair<A>(a, ad), dpfloat(3.0, 1.0)); + outputBuffer[0] = rs.p; // Expect: 13.0 + outputBuffer[1] = rs.d; // Expect: 10.0 + + let rs2 = __fwd_diff(A.f2)(DifferentialPair<A>(a, ad), dpfloat(3.0, 1.0)); + outputBuffer[2] = rs2.p; // Expect: 13.0 + outputBuffer[3] = rs2.d; // Expect: 6.0 +} diff --git a/tests/autodiff/no-diff-this.slang.expected.txt b/tests/autodiff/no-diff-this.slang.expected.txt new file mode 100644 index 000000000..e55d2a51b --- /dev/null +++ b/tests/autodiff/no-diff-this.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +13.000000 +10.000000 +13.000000 +6.000000 |
