From 977e4b21d69406b0b68c5963f50489d7433db830 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 11 Jul 2024 13:46:21 -0400 Subject: Fix issue with synthesizing `Differential` type for self-differential generic types (#4602) * Fix issue with synthesizing `Differential` type for self-differential generic types The problem was that we were using the type that was performing the lookup for `.Differential` which can have substitutions based on the local context where the decl is being referenced. We need to synthesize the type local to the decl itself * Update auto-differential-type-generic.slang --- .../autodiff/auto-differential-type-generic.slang | 64 ++++++++++++++++++++++ ...to-differential-type-generic.slang.expected.txt | 6 ++ 2 files changed, 70 insertions(+) create mode 100644 tests/autodiff/auto-differential-type-generic.slang create mode 100644 tests/autodiff/auto-differential-type-generic.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/auto-differential-type-generic.slang b/tests/autodiff/auto-differential-type-generic.slang new file mode 100644 index 000000000..060495e56 --- /dev/null +++ b/tests/autodiff/auto-differential-type-generic.slang @@ -0,0 +1,64 @@ +// Tests automatic synthesis of Differential type requirement for generic types. +// +// This specifically tests a synthesis path that occurs when the lookup of the Differential type happens before the conformance-check. +// If this path doesn't construct the generic differential type correctly, it will throw an error when constructing the array +// in this line: Feature<3>.Differential b = {0.2, 0.3, 0.4}; +// + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +__generic +struct Feature: IDifferentiable +{ + float vals[C]; +} + + +struct Linear +{ + typedef Feature Input; + typedef Feature Output; + + [BackwardDerivative(eval_bwd)] + Output eval(Input in_feature) + { + Output out_feature; + for (int i = 0; i < C; i++) + { + out_feature.vals[i] = in_feature.vals[i] * 2.0; + } + return out_feature; + } + + void eval_bwd(inout DifferentialPair in_feature_pair, Feature.Differential d_output) + { + /* empty.. doesn't really matter */ + } +} + +[Differentiable] +Feature<3> f(Feature<3> a, Linear<3> layer) +{ + return layer.eval(a); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + Feature<3> a = {1.0, 2.0, 3.0}; + Feature<3>.Differential b = {0.2, 0.3, 0.4}; + + Linear<3> layer; + + var dpA = diffPair(a, b); + + var result = fwd_diff(f)(dpA, layer).d; + + outputBuffer[0] = result.vals[0]; + outputBuffer[1] = result.vals[1]; + outputBuffer[2] = result.vals[2]; +} \ No newline at end of file diff --git a/tests/autodiff/auto-differential-type-generic.slang.expected.txt b/tests/autodiff/auto-differential-type-generic.slang.expected.txt new file mode 100644 index 000000000..c0d274b2b --- /dev/null +++ b/tests/autodiff/auto-differential-type-generic.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.400000 +0.600000 +0.800000 +0.000000 +0.000000 -- cgit v1.2.3