From 45ef0ce906c93c16495755fec2e597573e8631c4 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Wed, 10 Jul 2024 16:49:41 -0400 Subject: Fix lowering of associated types and synthesis of dispatch functions. (#4568) * Treat global variables and parameters as non-differentiable when checking derivative data-flow Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies. The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib). Fixes: #3289 This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods. * Fix issues with lookup witness lowering * Update slang-ir-lower-witness-lookup.cpp * Add tests * Update slang-ir-lower-witness-lookup.cpp * Cleanup * Update nested-assoc-types.slang --------- Co-authored-by: Yong He --- tests/compute/nested-assoc-types.slang | 118 +++++++++++++++++++++ .../compute/nested-assoc-types.slang.expected.txt | 6 ++ 2 files changed, 124 insertions(+) create mode 100644 tests/compute/nested-assoc-types.slang create mode 100644 tests/compute/nested-assoc-types.slang.expected.txt (limited to 'tests/compute') diff --git a/tests/compute/nested-assoc-types.slang b/tests/compute/nested-assoc-types.slang new file mode 100644 index 000000000..374e31d6b --- /dev/null +++ b/tests/compute/nested-assoc-types.slang @@ -0,0 +1,118 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[anyValueSize(16)] +interface IFoo +{ + float foo(); +} + +[anyValueSize(16)] +interface INestedInterface +{ + associatedtype NestedAssocType : IFoo; +} + +[anyValueSize(16)] +interface IInterface +{ + associatedtype MyAssocType : INestedInterface; + MyAssocType.NestedAssocType calc(float x); +} + +// ================================ + +struct A_Assoc_Assoc : IFoo +{ + float a; + + float foo() + { + return a; + } +} + +struct A_Assoc : INestedInterface +{ + typedef A_Assoc_Assoc NestedAssocType; +} + +struct A : IInterface +{ + typedef A_Assoc MyAssocType + + int data1; + + __init(int data1) { this.data1 = data1; } + + A_Assoc_Assoc calc(float x) { return { x * x * x * data1 }; } +}; + +// ================================ + +struct B_Assoc_Assoc : IFoo +{ + float b; + + float foo() + { + return b; + } +} + +struct B_Assoc : INestedInterface +{ + typedef B_Assoc_Assoc NestedAssocType; +} + +struct B : IInterface +{ + typedef B_Assoc MyAssocType + + int data1; + int data2; + + __init(int data1, int data2) { this.data1 = data1; this.data2 = data2; } + + B_Assoc_Assoc calc(float x) { return { x * x * data1 * data2 }; } +}; + +// ================================ + +float doThing(IInterface obj, float x) +{ + let o = obj.calc(x); + return o.foo(); +} + +float f(uint id, float x) +{ + IInterface obj; + + switch (id) + { + case 0: + obj = A(2); + break; + + default: + obj = B(2, 3); + } + + return doThing(obj, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(dispatchThreadID.x, 1.0); // A.calc, expect 2 + outputBuffer[1] = f(dispatchThreadID.x + 1, 1.5); // B.calc, expect 13.5 +} \ No newline at end of file diff --git a/tests/compute/nested-assoc-types.slang.expected.txt b/tests/compute/nested-assoc-types.slang.expected.txt new file mode 100644 index 000000000..91a52a345 --- /dev/null +++ b/tests/compute/nested-assoc-types.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +13.500000 +0.000000 +0.000000 +0.000000 -- cgit v1.2.3