summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-03-07 11:22:32 -0800
committerGitHub <noreply@github.com>2023-03-07 11:22:32 -0800
commit257733f328f38a763c8b0c8830ff4c0d34ec9491 (patch)
tree87e444746f353d69a365380904f3f8caf15fbfec /tests
parent6f31eae79d5b4297d0099c5779a9806a786cf9f8 (diff)
Reuse higher-order `ResolveInvoke` logic to resolve func refs in `[*DerivativeOf]` attribs. (#2688)
* Reuse higher-order `ResolveInvoke` logic to resolve func refs in [*DerivativeOf] attribs. * Add diff implementation matrix versions of binary and ternary intrinsics. * Add diff impl for legacy intrinsics. * Fix diagnostics of using non-differentiable function in a diff operator. * Add diff implementation for `determinant`. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang12
-rw-r--r--tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt5
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang11
-rw-r--r--tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt6
4 files changed, 30 insertions, 4 deletions
diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
index 379e2c3ef..53972ac2c 100644
--- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
+++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang
@@ -2,7 +2,7 @@
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -43,4 +43,14 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
__bwd_diff(diffSin)(dpx, 1.0);
outputBuffer[4] = dpx.d; // Expect: -1.000000
}
+
+ {
+ dpfloat dpx = dpfloat(float.getPi() / 3.0, 1.0);
+ __bwd_diff(sincos)(dpx, 1.0, 0.0);
+ outputBuffer[5] = dpx.d; // Expect: 0.5
+ __bwd_diff(sincos)(dpx, 0.0, 1.0);
+ outputBuffer[6] = dpx.d; // Expect: -0.8660254
+ __bwd_diff(sincos)(dpx, 1.0, 1.0);
+ outputBuffer[7] = dpx.d; // Expect: -0.3660254
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
index a4b804cb8..17627df68 100644
--- a/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-elementary-bwd.slang.expected.txt
@@ -3,4 +3,7 @@ type: float
7.389056
0.000000
1.000000
--1.00000 \ No newline at end of file
+-1.000000
+0.500000
+-0.866025
+-0.366025
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
index 15573c4ef..d68a2697c 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang
@@ -1,7 +1,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;
typedef DifferentialPair<float> dpfloat;
@@ -50,4 +50,13 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
outputBuffer[6] = dpx.d[0]; // Expect: 0.158114
outputBuffer[7] = dpx.d[1]; // Expect: 0.577350
}
+
+ {
+ var dpx = diffPair(float2x2(4.0, 9.0, 16.0, 25.0), float2x2(0.0, 0.0, 0.0, 0.0));
+ __bwd_diff(sqrt)(dpx, float2x2(1.0, 2.0, 3.0, 4.0));
+ outputBuffer[8] = dpx.d[0][0]; // Expect: 0.25
+ outputBuffer[9] = dpx.d[0][1]; // Expect: 0.3333
+ outputBuffer[10] = dpx.d[1][0]; // Expect: 0.375
+ outputBuffer[11] = dpx.d[1][1]; // Expect: 0.4
+ }
}
diff --git a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
index fe6487fef..7e0fdf02f 100644
--- a/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
+++ b/tests/autodiff-dstdlib/dstdlib-sqrt.slang.expected.txt
@@ -6,4 +6,8 @@ type: float
0.000000
0.000000
0.158114
-0.577350 \ No newline at end of file
+0.577350
+0.250000
+0.333333
+0.375000
+0.400000 \ No newline at end of file