summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/reverse-inout-param-custom-derivative.slang
blob: c4549e37b85bda7ca82b5fdfc49e04380efe6b8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
// Checks that a user-supplied [BackwardDerivativeOf] function whose
// differential-pair parameter is `inout` is the one invoked by __bwd_diff,
// and that its write to that pair is visible to the caller.
// The same shader is run through three COMPARE_COMPUTE_EX configurations:
// the default -slang target, Vulkan (-vk) and CPU (-cpu).
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj

//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer;

// Primal function: offsets x by the integer state.
// Its derivative with respect to x is 1; the gradient used by this test is
// supplied by the custom backward derivative below instead.
float rng(int state, float x)
{
    return float(state) + x;
}

// Custom backward derivative of rng, registered with [BackwardDerivativeOf].
//   inState - primal `state` argument (int, so it carries no derivative)
//   x       - inout pair: the primal value is read, the derivative is written
//   dOut    - incoming derivative of rng's result
// The stored derivative is NOT the true gradient (which would be dOut, since
// d(state + x)/dx = 1). Presumably this is so the test output shows whether
// this function, rather than an auto-generated derivative, was invoked.
[BackwardDerivativeOf(rng)]
void rng_bwd(int inState, inout DifferentialPair<float> x, float dOut)
{
    let grad = float(inState) + dOut - 1.0;
    x = diffPair(x.p, grad);
}

// Single-thread entry point: back-propagates through rng once and stores the
// resulting derivative of x in outputBuffer[0].
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
    let state = 4;
    let dResult = 3.0;

    // Primal 2.0, initial derivative 1.0. rng_bwd reassigns the whole pair, so
    // the expected result of 6 (= 4 + 3.0 - 1.0) implies this initial 1.0 is
    // overwritten rather than accumulated into.
    var pair = diffPair(2.0, 1.0);

    __bwd_diff(rng)(state, pair, dResult);

    // Expected: 6
    outputBuffer[0] = pair.d;
}