summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/bugs/gh-5781.slang57
-rw-r--r--tests/language-feature/constants/max-iters-link-time-const.slang15
2 files changed, 72 insertions, 0 deletions
diff --git a/tests/bugs/gh-5781.slang b/tests/bugs/gh-5781.slang
new file mode 100644
index 000000000..33456f500
--- /dev/null
+++ b/tests/bugs/gh-5781.slang
@@ -0,0 +1,57 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+// CHECK: OpEntryPoint
+
+module test;
+
+public enum class MaterialID : uint { invalid = 0xffffffff };
+
+public struct Material : IDifferentiable
+{
+ float x;
+}
+
+public struct Hit
+{
+ MaterialID material;
+}
+
+public struct Scene
+{
+ StructuredBuffer<Material> materials;
+ RWStructuredBuffer<Material> grads;
+
+ [Differentiable]
+ Material load(MaterialID id) { return materials[uint(id)]; }
+
+ void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; }
+
+ [Differentiable, BackwardDerivative(_get_material_bwd)]
+ public Material get_material(MaterialID id) { return load(id); }
+
+ public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); }
+
+ [Differentiable]
+ public Material get_material(Hit hit) { return get_material(hit.material); }
+}
+
+[Differentiable]
+float trace(const Scene scene, Hit hit)
+{
+ Material m = scene.get_material(hit);
+ return m.x;
+}
+
+
+[shader("compute")]
+void main(
+ uniform Scene scene,
+ uniform StructuredBuffer<uint> input,
+ uniform RWStructuredBuffer<float> output,
+ uniform RWStructuredBuffer<float> grads
+)
+{
+ Hit hit;
+ hit.material = MaterialID(input[0]);
+ output[0] = trace(scene, hit);
+ bwd_diff(trace)(scene, hit, grads[0]);
+} \ No newline at end of file
diff --git a/tests/language-feature/constants/max-iters-link-time-const.slang b/tests/language-feature/constants/max-iters-link-time-const.slang
new file mode 100644
index 000000000..cf1ccbbd1
--- /dev/null
+++ b/tests/language-feature/constants/max-iters-link-time-const.slang
@@ -0,0 +1,15 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+// CHECK: OpEntryPoint
+
+extern static const int num = 10;
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ [MaxIters(num)]
+ for (int i = 0; i < num; i++)
+ {
+ outputBuffer[i] = i;
+ }
+}