summaryrefslogtreecommitdiffstats
path: root/tests/diagnostics/custom-derivative-generic.slang
blob: fb65dd2ccaf0e995baa9a3ca0016fd681e5e96aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):

//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer; // presumably bound via name=outputBuffer in the test-input line above
// Shorthand for a float DifferentialPair (primal read as .p, differential as .d); used by computeMain.
typedef DifferentialPair<float> dpfloat;

interface IFoo // its requirements are referenced by the custom-derivative attributes in this file
{
    static float bar1(float x); // named as the primal substitute in bar's attribute
    // bar: an interface requirement that carries a custom-derivative attribute; a diagnostic is expected on it.
    // CHECK-DAG: {{.*}}(13): error 31152
    [PrimalSubstitute(bar1)]
    static float bar(float x);
    // dd: requirement that the declaration bbb (further down) names as its forward derivative.
    static DifferentialPair<float> dd(DifferentialPair<float> x);
}

__generic<let N:int> // N: compile-time integer scale factor
float f(float x) // f(x) = N * x * x; note N does not appear in the parameter list
{
    return N*x*x;
}

// CHECK-DAG: {{.*}}(26): error 31153
[ForwardDerivative(IFoo.dd)]
float bbb(float x); // body-less declaration whose forward derivative is the interface requirement IFoo.dd

// CHECK-DAG: {{.*}}(30): error 31152
[ForwardDerivativeOf(IFoo.bar)]
DifferentialPair<float> dd1(DifferentialPair<float> x) // declares itself the forward derivative of requirement IFoo.bar
{
    return x; // placeholder body: returns the input pair unchanged
}

// CHECK-DAG: {{.*}}(37): error 31151
[BackwardDerivativeOf(f)]
DifferentialPair<float> df<let N:int>(inout DifferentialPair<float> x, float dOut) // N is not inferable from f's (float x) parameters; presumably why a diagnostic is expected
{ // NOTE(review): forward-derivative shape (returns a pair, dOut unused); presumably irrelevant since only the diagnostic is tested
    var primal = x.p * x.p; // NOTE(review): omits the N factor that f applies to its result
    var diff = 2 * x.p * x.d * N; // d(N*x*x) = 2*N*x*dx
    return DifferentialPair<float>(primal, diff);
}
// Compute entry point: writes the forward-mode derivative of f<3> at x = 3
// (tangent 1) into outputBuffer[1]. Because this file is a diagnostic test,
// the kernel is presumably only compiled here, not executed.
[numthreads(1, 1, 1)]
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    dpfloat dpa = dpfloat(3.0, 1.0); // primal 3.0, tangent 1.0
    // f<3>(x) = 3*x*x, so d/dx = 6*x, which is 18 at x = 3.
    outputBuffer[1] = __fwd_diff(f<3>)(dpa).d; // Expect: 18.0
}