diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-09 04:47:53 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-09 20:47:53 +0800 |
| commit | 051ae8acec0a641bcaf86e7eeff35eff29e8922d (patch) | |
| tree | 4e385415742ad98c8842454fda14a9abb8112cb2 /tests | |
| parent | 71e90a7ba78d0566e3b7da54df48f9af598e4cbb (diff) | |
Fix crash during emitCast of attributed type, allow MaxIters to take linktime const. (#5791)
* Fix crash during emitCast of attributed type.
* Allow [MaxIters] to take link time constants.
---------
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bugs/gh-5781.slang | 57 | ||||
| -rw-r--r-- | tests/language-feature/constants/max-iters-link-time-const.slang | 15 |
2 files changed, 72 insertions, 0 deletions
diff --git a/tests/bugs/gh-5781.slang b/tests/bugs/gh-5781.slang new file mode 100644 index 000000000..33456f500 --- /dev/null +++ b/tests/bugs/gh-5781.slang @@ -0,0 +1,57 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +// CHECK: OpEntryPoint + +module test; + +public enum class MaterialID : uint { invalid = 0xffffffff }; + +public struct Material : IDifferentiable +{ + float x; +} + +public struct Hit +{ + MaterialID material; +} + +public struct Scene +{ + StructuredBuffer<Material> materials; + RWStructuredBuffer<Material> grads; + + [Differentiable] + Material load(MaterialID id) { return materials[uint(id)]; } + + void accumulate(MaterialID id, Material d) { grads[uint(id)].x += d.x; } + + [Differentiable, BackwardDerivative(_get_material_bwd)] + public Material get_material(MaterialID id) { return load(id); } + + public void _get_material_bwd(MaterialID id, Material d) { accumulate(id, d); } + + [Differentiable] + public Material get_material(Hit hit) { return get_material(hit.material); } +} + +[Differentiable] +float trace(const Scene scene, Hit hit) +{ + Material m = scene.get_material(hit); + return m.x; +} + + +[shader("compute")] +void main( + uniform Scene scene, + uniform StructuredBuffer<uint> input, + uniform RWStructuredBuffer<float> output, + uniform RWStructuredBuffer<float> grads +) +{ + Hit hit; + hit.material = MaterialID(input[0]); + output[0] = trace(scene, hit); + bwd_diff(trace)(scene, hit, grads[0]); +}
\ No newline at end of file diff --git a/tests/language-feature/constants/max-iters-link-time-const.slang b/tests/language-feature/constants/max-iters-link-time-const.slang new file mode 100644 index 000000000..cf1ccbbd1 --- /dev/null +++ b/tests/language-feature/constants/max-iters-link-time-const.slang @@ -0,0 +1,15 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv +// CHECK: OpEntryPoint + +extern static const int num = 10; +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1,1,1)] +void computeMain() +{ + [MaxIters(num)] + for (int i = 0; i < num; i++) + { + outputBuffer[i] = i; + } +} |
