summaryrefslogtreecommitdiffstats
path: root/tests/bugs/gh-6589.slang
blob: 9433f510e5b923f1ad19c2bb3aa05bbbe5553186 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
//TEST:SIMPLE(filecheck=CHECK): -target spirv

// This test checks that, when we partially specialize a function, we don't specialize
// the function's parameters too aggressively. Instead, the parameters are specialized at
// the same time as the arguments; otherwise we could lose the chance to specialize the
// argument.
//
// In this test, `matrix_vector_interfaces` must be fully specialized; otherwise compilation
// fails, because opaque types are not allowed inside an existential type. So as long as
// SPIR-V code can be generated for the target, the test passes.

// CHECK: %main
/// Tensor interface over elements of type `T`, parameterized by a compile-time
/// dimension count `D`. It exposes only element reads by flat integer index.
public interface ITensor<T : IDifferentiable, let D : int>
{
    /// Returns the element at flat index `idx`.
    public T get(int idx);
}

/// Read-write tensor interface. It declares no members of its own; everything
/// comes from `ITensor<T, D>`.
public interface IRWTensor<T : IDifferentiable, let D : int> : ITensor<T, D> {}


/// Concrete tensor backed directly by a structured buffer.
public struct RWTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
{
    /// Element storage, indexed directly by `get`.
    public RWStructuredBuffer<T> buffer;

    /// Reads element `idx` straight out of `buffer`.
    public T get(int idx)
    {
        return this.buffer[idx];
    }
}

/// Tensor wrapper that holds a primal `RWTensor` and forwards reads to it.
public struct GradInOutTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
{
    /// Wrapped tensor that `get` delegates to.
    public RWTensor<T, D> primal;

    public T get(int idx)
    {
        // Delegate the read to the wrapped primal tensor.
        return this.primal.get(idx);
    }
}

// Shader parameters, exposed through the global `call_data` parameter block below.
// NOTE(review): the tensor dimensions here (3 for `weights`, 2 for `biases`) differ from
// the locals declared in `main` (2 and 1). Only the `RWStructuredBuffer<float>` members are
// copied across, and that type does not depend on `D`, so this compiles. Confirm the
// mismatch is intentional (it may come from the original bug repro).
struct CallData
{
    GradInOutTensor<float, 3> weights;
    GradInOutTensor<float, 2> biases;
    // Output buffer; `main` writes its single result to element 0.
    RWStructuredBuffer<float> _result;
}
// Global parameter block through which `main` reads its input buffers and writes its output.
ParameterBlock<CallData> call_data;

// Function under test. Both parameters are interface-typed (`ITensor`), so each call
// site must be specialized to the concrete argument types passed in. Per this file's
// test description, compilation to SPIR-V only succeeds if this function is fully
// specialized.
// Returns the element at index 0 of `weights`. `biases` is never read; presumably it
// exists only to exercise specialization of a second interface parameter.
float matrix_vector_interfaces(ITensor<float, 2> weights, ITensor<float, 1> biases)
{
    return weights.get(0);
}

// Compute entry point (a single thread group of 1x1x1). It builds local
// `GradInOutTensor` values from the buffers in `call_data` and passes them to
// `matrix_vector_interfaces` through its interface-typed parameters. It then stores the
// returned value in `call_data._result[0]`. `dispatchThreadID` is not used.
[shader("compute")]
[numthreads(1, 1, 1)]
void main(uint3 dispatchThreadID: SV_DispatchThreadID)
{
    float _result;
    // Dimensions 2 and 1 match the parameter types of `matrix_vector_interfaces`.
    GradInOutTensor<float, 2> weights;
    GradInOutTensor<float, 1> biases;

    // Only the buffer handles are copied. `RWStructuredBuffer<float>` is the same type
    // whatever the tensor's `D`, so assigning across differing dimensions is allowed.
    weights.primal.buffer = call_data.weights.primal.buffer;
    biases.primal.buffer = call_data.biases.primal.buffer;

    // The concrete GradInOutTensor arguments flow into interface parameters here; this
    // is the call the test needs the compiler to specialize.
    _result = matrix_vector_interfaces(weights, biases);

    call_data._result[0] = _result;
}