diff options
Diffstat (limited to 'tests/autodiff')
| -rw-r--r-- | tests/autodiff/custom-differential-type-error.slang | 26 | ||||
| -rw-r--r-- | tests/autodiff/custom-differential-type.slang | 46 |
2 files changed, 72 insertions, 0 deletions
diff --git a/tests/autodiff/custom-differential-type-error.slang b/tests/autodiff/custom-differential-type-error.slang new file mode 100644 index 000000000..2542eff4f --- /dev/null +++ b/tests/autodiff/custom-differential-type-error.slang @@ -0,0 +1,26 @@ +//TEST:SIMPLE(filecheck=CHECK):-target spirv + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + // CHECK: ([[# @LINE+1]]): error 30027: 'x' is not a member of 'typeof(Foo.Differential)'. + [DerivativeMember(Differential.x)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} diff --git a/tests/autodiff/custom-differential-type.slang b/tests/autodiff/custom-differential-type.slang new file mode 100644 index 000000000..4782b4297 --- /dev/null +++ b/tests/autodiff/custom-differential-type.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-cuda -compute -shaderobj -output-using-type + +interface IFoo : IDifferentiable +{ + float myFunc(); +} + +struct Bar: IDifferentiable +{ + float y; +} + +struct Foo : IFoo +{ + typealias Differential = Bar; + + [DerivativeMember(Differential.y)] + float x; + + [BackwardDifferentiable] + float myFunc() + { + return x * x; + }; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=g_output +RWStructuredBuffer<float> g_output; + +[BackwardDifferentiable] +float wrapper(Foo f) +{ + return f.myFunc(); +} + +[shader("compute")] +void computeMain(uint3 thread_id : SV_DispatchThreadID) +{ + var di = diffPair(Foo(1.0f), Bar(0.0f)); + bwd_diff(wrapper)(di, 1.0f); + + // CHECK: 2.0 + g_output[0] = di.d.y; +} |
