summaryrefslogtreecommitdiffstats
path: root/tests/diagnostics
diff options
context:
space:
mode:
Diffstat (limited to 'tests/diagnostics')
-rw-r--r--tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang40
-rw-r--r--tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang48
-rw-r--r--tests/diagnostics/force-no-diff-this.slang42
3 files changed, 130 insertions, 0 deletions
diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang
new file mode 100644
index 000000000..961ac75d5
--- /dev/null
+++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement.slang
@@ -0,0 +1,40 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+float someNoDiffFunc(float x, no_diff float y)
+{
+ return x * x + y * y;
+}
+
+// Previously, when we call a no-diff function side a differntiable function, we will have to use no_diff to tell compiler that this is intended.
+// However, if the parameter is just a constant, there is no need to use no_diff, because constant won't carry any derivative information.
+// Therefore, this test is to check we won't report any error when the parameter is a constant in this case.
+[Differentiable]
+float eval(float x)
+{
+ // CHECK-NOT: ([[# @LINE+1]]): error 41020
+ return exp(x) - someNoDiffFunc(1.0f, x);
+}
+
+[Differentiable]
+float eval1(float x)
+{
+ // CHECK: ([[# @LINE+1]]): error 41020
+ return exp(x) - someNoDiffFunc(x, 1.0);
+}
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(2.0f);
+ bwd_diff(eval)(x, 1.0f);
+
+ output[0] = x.d;
+
+ var x1 = diffPair(2.0f);
+ bwd_diff(eval1)(x1, 1.0f);
+ output[1] = x1.d;
+}
+
diff --git a/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang
new file mode 100644
index 000000000..f27c6ec6b
--- /dev/null
+++ b/tests/diagnostics/const-to-nodiff-function-diagnostic-improvement1.slang
@@ -0,0 +1,48 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+
+// Similar to const-to-nodiff-function-diagnostic-improvement.slang, but with a CoopVec type
+// to reproduce a more realistic scenario.
+extension<T : __BuiltinFloatingPointType, let K : int> CoopVec<T, K> : IDifferentiable
+{
+ typealias Differential = CoopVec<T, K>;
+};
+
+[BackwardDerivativeOf(exp)]
+void exp_BackwardAutoDiff<T : __BuiltinFloatingPointType, let K : int>(inout DifferentialPair<CoopVec<T, K>> p0, CoopVec<T, K>.Differential dResult)
+{
+ p0 = diffPair(p0.p, dResult * exp(p0.p));
+}
+
+[Differentiable]
+CoopVec<T, K> eval<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x)
+{
+ // CHECK-NOT: ([[# @LINE+1]]): error 41020
+ return exp(x) - CoopVec<T, K>(1.);
+}
+
+[Differentiable]
+CoopVec<T, K> eval1<T : __BuiltinFloatingPointType, let K : int>(CoopVec<T, K> x)
+{
+ // test.slang(25): error 41020: derivative cannot be propagated through call to non-backward-differentiable function `CoopVec.$init`, use 'no_diff' to clarify intention.
+ // CHECK: ([[# @LINE+1]]): error 41020
+ return exp(x) - CoopVec<T, K>(x[0]);
+}
+
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f));
+ bwd_diff(eval)(x, CoopVec<float, 2>(1.0f));
+
+ output[0] = x.d[0];
+
+ var x1 = diffPair(CoopVec<float, 2>(2.0f), CoopVec<float, 2>(1.0f));
+ bwd_diff(eval1)(x1, CoopVec<float, 2>(1.0f));
+ output[1] = x1.d[1];
+}
+
diff --git a/tests/diagnostics/force-no-diff-this.slang b/tests/diagnostics/force-no-diff-this.slang
new file mode 100644
index 000000000..ae1464ffb
--- /dev/null
+++ b/tests/diagnostics/force-no-diff-this.slang
@@ -0,0 +1,42 @@
+//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):
+
+struct MyStruct<T> where T: __BuiltinFloatingPointType
+{
+ float a;
+ __init(float a) { this.a = a;}
+
+ [Differentiable]
+ T eval(T x)
+ {
+ //CHECK: ([[# @LINE+1]]): warning 31159
+ return exp(x * T(a) * T(a));
+ }
+
+ [Differentiable]
+ [NoDiffThis]
+ T eval1(T x)
+ {
+ //CHECK-NOT: ([[# @LINE+1]]): warning 31159
+ return exp(x * T(a) * T(a));
+ }
+};
+
+[Differentiable]
+float evalFunc(float x)
+{
+ MyStruct<float> s = {x};
+ return s.eval(x) + s.eval1(x);
+}
+
+RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain(uint id : SV_DispatchThreadID)
+{
+ var x = diffPair(2.0f);
+ bwd_diff(evalFunc)(x, 1.0f);
+
+ output[0] = x.d;
+}
+