summaryrefslogtreecommitdiffstats
path: root/tests/bugs
diff options
context:
space:
mode:
Diffstat (limited to 'tests/bugs')
-rw-r--r--tests/bugs/gh-6589.slang63
1 files changed, 63 insertions, 0 deletions
diff --git a/tests/bugs/gh-6589.slang b/tests/bugs/gh-6589.slang
new file mode 100644
index 000000000..9433f510e
--- /dev/null
+++ b/tests/bugs/gh-6589.slang
@@ -0,0 +1,63 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+// This is a test that checks that we can apply partial specialization to a function
+// we won't specialize the function parameters too aggressively. Instead, we will specialize
+// the parameters at the same time of specializing the arguments. Otherwise, we could lose
+// the chance to specialize the argument.
+//
+// In this test, `matrix_vector_interfaces` will be fully specialized, otherwise the compile
+// will fail because we don't allow opaque type in the existential type. So as long as the target
+// spirv code can be generated, we are good.
+
+// CHECK: %main
+public interface ITensor<T : IDifferentiable, let D : int>
+{
+ public T get(int idx);
+
+}
+
+public interface IRWTensor<T : IDifferentiable, let D : int> : ITensor<T, D>
+{
+}
+
+
+public struct RWTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
+{
+ public RWStructuredBuffer<T> buffer;
+ public T get(int idx) { return buffer[idx]; }
+}
+
+public struct GradInOutTensor<T : IDifferentiable, let D : int> : IRWTensor<T, D>
+{
+ public RWTensor<T, D> primal;
+ public T get(int idx) { return primal.get(idx); }
+}
+
+struct CallData
+{
+ GradInOutTensor<float, 3> weights;
+ GradInOutTensor<float, 2> biases;
+ RWStructuredBuffer<float> _result;
+}
+ParameterBlock<CallData> call_data;
+
+float matrix_vector_interfaces(ITensor<float, 2> weights, ITensor<float, 1> biases)
+{
+ return weights.get(0);
+}
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void main(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float _result;
+ GradInOutTensor<float, 2> weights;
+ GradInOutTensor<float, 1> biases;
+
+ weights.primal.buffer = call_data.weights.primal.buffer;
+ biases.primal.buffer = call_data.biases.primal.buffer;
+
+ _result = matrix_vector_interfaces(weights, biases);
+
+ call_data._result[0] = _result;
+}