From a9f2f8a592c4514cd116c947486055788092ea56 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 25 Feb 2025 10:42:19 -0800 Subject: Fix `UseGraph::replace` (#6395) * Fix `UseGraph::isTrivial()` test. * Fix. * Fix. * Refactor `UseGraph` and `UseChain` * Update slang-ir-autodiff-primal-hoist.cpp * Update all auto-diff locations that handle pointers to treat user pointers as regular values * Update test to use direct-SPIRV only --------- Co-authored-by: Yong He --- tests/autodiff/dynamic-dispatch-ptr.slang | 43 +++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/autodiff/dynamic-dispatch-ptr.slang (limited to 'tests') diff --git a/tests/autodiff/dynamic-dispatch-ptr.slang b/tests/autodiff/dynamic-dispatch-ptr.slang new file mode 100644 index 000000000..3f2269f78 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-ptr.slang @@ -0,0 +1,43 @@ +//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type -emit-spirv-directly + +//CHECK: 1.0 + +//TEST_INPUT: type_conformance Sensor:ISensor = 1; + +[anyValueSize(16)] +interface ISensor +{ + [Differentiable] + float4 splat(float4 point); +} + +struct Sensor : ISensor +{ + [Differentiable] + float4 splat(float4 point) + { + return point; + } +} + +[Differentiable] +float4 splat(ISensor* obj, float4 point) +{ + return obj->splat(point); +} + +//TEST_INPUT: set s = ubuffer(data=[0 0 1 0 0 0 0 0]) +uniform ISensor *s; + +//TEST_INPUT: set outBuffer = out ubuffer(data=[0 0 0 0], stride=4) +RWStructuredBuffer outBuffer; + +[shader("compute"), numthreads(1, 1, 1)] +void computeMain( + uint3 id : SV_DispatchThreadID +) +{ + DifferentialPair dp; + bwd_diff(splat)(s, dp, float4(1.0f)); + outBuffer[id.x] = dp.d; +} \ No newline at end of file -- cgit v1.2.3