diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-16 12:17:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-16 12:17:49 -0800 |
| commit | 801aa3b44254341018a1acbe754f2ce3b0900e2a (patch) | |
| tree | b3066778522edb99bf64c0ac80c91b0b4cb788f8 /tests/autodiff | |
| parent | 09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff) | |
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions.
* Replace `goto` with `break` to pacify clang.
* Fix.
* Fixes.
* Fix more tests.
* Fix lowerWitnessTable parameter error.
* Exclude attributes from ast printing.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/high-order-diff-operator.slang | 35 | ||||
| -rw-r--r-- | tests/autodiff/high-order-diff-operator.slang.expected.txt | 5 |
2 files changed, 40 insertions, 0 deletions
diff --git a/tests/autodiff/high-order-diff-operator.slang b/tests/autodiff/high-order-diff-operator.slang new file mode 100644 index 000000000..dca67e9f3 --- /dev/null +++ b/tests/autodiff/high-order-diff-operator.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDifferentiable] +float mySqr(float x) +{ + return x * x; +} + +[ForwardDifferentiable] +float f(float x) +{ + return mySqr(x * x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Given f(x) = x^4, + // f''(x) = 12 * x^2 + // Expect f''(4) = 192 + float.Differential t = 2; + outputBuffer[0] = __fwd_diff(__fwd_diff(f))( + DifferentialPair<DifferentialPair<float>>( + DifferentialPair<float>(4.0, 1.0), DifferentialPair<float>(1.0, 0.0))).d.d; + + // sin''(x) = cos'(x) = -sin(x). + // Expect sin''(Pi/2) = -1. + outputBuffer[1] = __fwd_diff(__fwd_diff(sin))( + DifferentialPair<DifferentialPair<float>>( + DifferentialPair<float>(float.getPi()/2, 1.0), DifferentialPair<float>(1.0, 0.0))).d.d; +} diff --git a/tests/autodiff/high-order-diff-operator.slang.expected.txt b/tests/autodiff/high-order-diff-operator.slang.expected.txt new file mode 100644 index 000000000..305a8e111 --- /dev/null +++ b/tests/autodiff/high-order-diff-operator.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +192.000000 +-1.000000 +0.000000 +0.000000 |
