From b2ca2d5a4efeae807d3c3f48f60235e47413b559 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 23 Aug 2024 21:45:59 -0700 Subject: Make variadic generics work with interfaces and forward autodiff. (#4905) --- tests/language-feature/generics/variadic-0.slang | 4 +- .../language-feature/generics/variadic-void.slang | 2 + tests/language-feature/ifunc/diff-functor.slang | 44 ++++++++++++++++++++++ tests/language-feature/ifunc/ifunc.slang | 40 ++++++++++++++++++++ 4 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 tests/language-feature/ifunc/diff-functor.slang create mode 100644 tests/language-feature/ifunc/ifunc.slang (limited to 'tests') diff --git a/tests/language-feature/generics/variadic-0.slang b/tests/language-feature/generics/variadic-0.slang index 8ee41647f..ac9ca2c1c 100644 --- a/tests/language-feature/generics/variadic-0.slang +++ b/tests/language-feature/generics/variadic-0.slang @@ -1,4 +1,6 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; @@ -19,7 +21,7 @@ S makeS(T x) } bool cmp(T a, int b) { - return a > __int_cast(b); + return a > T(b); } void accept(expand each T value) {} diff --git a/tests/language-feature/generics/variadic-void.slang b/tests/language-feature/generics/variadic-void.slang index d44acbfd4..976c104f8 100644 --- a/tests/language-feature/generics/variadic-void.slang +++ b/tests/language-feature/generics/variadic-void.slang @@ -1,4 +1,6 @@ //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu -shaderobj -output-using-type //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer outputBuffer; diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang new file mode 100644 index 000000000..04b0be44f --- /dev/null +++ b/tests/language-feature/ifunc/diff-functor.slang @@ -0,0 +1,44 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type + +struct DiffFunctor : IDifferentiableFunc +{ + [Differentiable] + float __call(float p) + { + return p + 1; + } +} + +float apply(IMutatingFunc f, float p) +{ + return f.__call(p); +} + +[Differentiable] +float applyDiff(IDifferentiableFunc f, float p) +{ + return f.__call(p); +} + +[Differentiable] +TR applyDiffGen(IDifferentiableFunc f, expand each TP p) +{ + return f.__call(expand each p); +} + +//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint tid: SV_DispatchThreadID) +{ + // CHECK: 4 + outputBuffer[0] = (uint)apply(DiffFunctor(), 3.0); + // CHECK: 1 + outputBuffer[1] = (uint)fwd_diff(applyDiff)(DiffFunctor(), diffPair(2.0, 1.0)).d; + // CHECK: 1 + outputBuffer[2] = (uint)fwd_diff(applyDiffGen)(DiffFunctor(), diffPair(2.0, 1.0)).d; +} diff --git a/tests/language-feature/ifunc/ifunc.slang b/tests/language-feature/ifunc/ifunc.slang new file mode 100644 index 000000000..f270299b3 --- /dev/null +++ b/tests/language-feature/ifunc/ifunc.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type + +struct Functor : IFunc +{ + int __call(int p, bool t) + { + return p + 1; + } +} + +struct MutatingFunctor : IMutatingFunc +{ + int data = 0; + [mutating] + int __call(int p, bool t) + { + data++; + return p + 1; + } +} + +int apply(IMutatingFunc f, int p) +{ + return f.__call(p, true); +} + +//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint tid: SV_DispatchThreadID) +{ + // CHECK: 2 + outputBuffer[0] = apply(MutatingFunctor(), 1); + // CHECK: 3 + outputBuffer[1] = apply(Functor(), 2); +} -- cgit v1.2.3