summaryrefslogtreecommitdiffstats
path: root/tests/language-feature/generics/generic-witness-derived.slang
blob: e9659102c318b19505a179e8ab7cdb3626884d7e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type

// Test that we can compile a generic function with a generic type constraint that is dependent on an
// outer generic type parameter.

namespace ns{

    // Interface for a binary operation over values of type `T`: an implementation
    // provides a static `call` that combines `lhs` and `rhs` into a single `T`.
    // The interface is itself generic, so a conformance always names a specific
    // `T` (see the `IBinaryElementWiseFunction<T>` constraint on `test` below).
    public interface IBinaryElementWiseFunction<T>
    {
        // Returns the result of applying the operation to `lhs` and `rhs`.
        public static T call(const in T lhs, const in T rhs);
    }
    // Addition operation: conforms to `IBinaryElementWiseFunction<T>` for any
    // `T` that satisfies `IArithmetic`.
    public struct AddOp<T : IArithmetic> : IBinaryElementWiseFunction<T>
    {
        // Returns `lhs + rhs`.
        public static  T call(const in T lhs, const in T rhs)
        {
            return lhs + rhs; 
        }
    }
    // Holds the two operands of a binary element-wise operation and applies a
    // caller-supplied operation to them via `test`.
    public struct BinaryElementWiseInputData<T : IArithmetic>
    {
        // Left-hand and right-hand operands handed to the operation in `test`.
        T lhs;
        T rhs;

        // Note: `U` is constrained by `IBinaryElementWiseFunction<T>`, which is dependent on `T`,
        // another generic type parameter that is defined on the outer type.
        // This eventually leads to an IRGeneric where one param has a type that is dependent on
        // another param.
        // In this case, the IR for `test` after generic flattening will be:
        // ```
        // %g_test = IRGeneric
        // {
        //     IRBlock
        //     {
        //          %T = IRParam : Type;
        //          %T_w = IRParam : IRWitnessTableType<IArithmetic>;
        //          %U = IRParam : Type;
        //          %U_w = IRParam : IRWitnessTableType<%s>; // note that the type here is a forward reference to %s
        //          %s = specialize(%IBinaryElementWiseFunction, %T) // %s is dependent on %T.
        //          ...
        //     }
        // }
        // ```
        //
        // Applies the operation `U` to the stored operands: invokes `call` through `x`
        // with `lhs` and `rhs` and returns the result.
        public T test<U : IBinaryElementWiseFunction<T>>(U x)
        {
            return x.call(lhs ,rhs);
        }
    }
}


// Buffer that receives the test result; `computeMain` writes element 0.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer;

// Compute entry point for this test. Builds a `BinaryElementWiseInputData<int>`
// with lhs = threadId.x + 1 and rhs = 2, runs `test` with `ns::AddOp<int>`, and
// stores the result in `outputBuffer[0]`. The expected output of 3 below
// corresponds to threadId.x == 0 (1 + 2).
[shader("compute")]
[numthreads(1,1,1)]
void computeMain(uint3 threadId: SV_DispatchThreadID)
{
    ns::BinaryElementWiseInputData<int> cb;
    cb.lhs = threadId.x + 1;
    cb.rhs = 2;
    // `ns::AddOp<int>` is the argument for `U`; it conforms to
    // `IBinaryElementWiseFunction<int>`, matching the constraint on `test`
    // for T = int.
    // CHECK: 3
    outputBuffer[0] = cb.test(ns::AddOp<int>());
}