summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-11 13:46:21 -0400
committerGitHub <noreply@github.com>2024-07-11 13:46:21 -0400
commit977e4b21d69406b0b68c5963f50489d7433db830 (patch)
treeaf5e68f910c1887ef3c638d8620449ab66d9faf3 /tests
parentc3061afffdc078914cc522538749c6e0c5e36f65 (diff)
Fix issue with synthesizing `Differential` type for self-differential generic types (#4602)
* Fix issue with synthesizing `Differential` type for self-differential generic types The problem was that we were using the type that was performing the lookup for `.Differential` which can have substitutions based on the local context where the decl is being referenced. We need to synthesize the type local to the decl itself * Update auto-differential-type-generic.slang
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/auto-differential-type-generic.slang64
-rw-r--r--tests/autodiff/auto-differential-type-generic.slang.expected.txt6
2 files changed, 70 insertions, 0 deletions
diff --git a/tests/autodiff/auto-differential-type-generic.slang b/tests/autodiff/auto-differential-type-generic.slang
new file mode 100644
index 000000000..060495e56
--- /dev/null
+++ b/tests/autodiff/auto-differential-type-generic.slang
@@ -0,0 +1,64 @@
+// Tests automatic synthesis of Differential type requirement for generic types.
+//
+// This specifically tests a synthesis path that occurs when the lookup of the Differential type happens before the conformance-check.
+// If this path doesn't construct the generic differential type correctly, it will throw an error when constructing the array
+// in this line: Feature<3>.Differential b = {0.2, 0.3, 0.4};
+//
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__generic<let C : int>
+struct Feature: IDifferentiable
+{
+ float vals[C];
+}
+
+
+struct Linear<let C : int>
+{
+ typedef Feature<C> Input;
+ typedef Feature<C> Output;
+
+ [BackwardDerivative(eval_bwd)]
+ Output eval(Input in_feature)
+ {
+ Output out_feature;
+ for (int i = 0; i < C; i++)
+ {
+ out_feature.vals[i] = in_feature.vals[i] * 2.0;
+ }
+ return out_feature;
+ }
+
+ void eval_bwd(inout DifferentialPair<Input> in_feature_pair, Feature<C>.Differential d_output)
+ {
+ /* empty.. doesn't really matter */
+ }
+}
+
+[Differentiable]
+Feature<3> f(Feature<3> a, Linear<3> layer)
+{
+ return layer.eval(a);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ Feature<3> a = {1.0, 2.0, 3.0};
+ Feature<3>.Differential b = {0.2, 0.3, 0.4};
+
+ Linear<3> layer;
+
+ var dpA = diffPair(a, b);
+
+ var result = fwd_diff(f)(dpA, layer).d;
+
+ outputBuffer[0] = result.vals[0];
+ outputBuffer[1] = result.vals[1];
+ outputBuffer[2] = result.vals[2];
+} \ No newline at end of file
diff --git a/tests/autodiff/auto-differential-type-generic.slang.expected.txt b/tests/autodiff/auto-differential-type-generic.slang.expected.txt
new file mode 100644
index 000000000..c0d274b2b
--- /dev/null
+++ b/tests/autodiff/auto-differential-type-generic.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.400000
+0.600000
+0.800000
+0.000000
+0.000000