diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-10 16:49:41 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 13:49:41 -0700 |
| commit | 45ef0ce906c93c16495755fec2e597573e8631c4 (patch) | |
| tree | ddb0ed618683488893d1c076f95b8e6e9e14d8ff /tests | |
| parent | 16a47816747ca9a9de67b842a73f0e981dbc8b91 (diff) | |
Fix lowering of associated types and synthesis of dispatch functions. (#4568)
* Treat global variables and parameters as non-differentiable when checking derivative data-flow
Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies.
The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib).
Fixes: #3289
This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods.
* Fix issues with lookup witness lowering
* Update slang-ir-lower-witness-lookup.cpp
* Add tests
* Update slang-ir-lower-witness-lookup.cpp
* Cleanup
* Update nested-assoc-types.slang
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/compute/nested-assoc-types.slang | 118 | ||||
| -rw-r--r-- | tests/compute/nested-assoc-types.slang.expected.txt | 6 |
2 files changed, 124 insertions, 0 deletions
diff --git a/tests/compute/nested-assoc-types.slang b/tests/compute/nested-assoc-types.slang new file mode 100644 index 000000000..374e31d6b --- /dev/null +++ b/tests/compute/nested-assoc-types.slang @@ -0,0 +1,118 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IFoo +{ + float foo(); +} + +[anyValueSize(16)] +interface INestedInterface +{ + associatedtype NestedAssocType : IFoo; +} + +[anyValueSize(16)] +interface IInterface +{ + associatedtype MyAssocType : INestedInterface; + MyAssocType.NestedAssocType calc(float x); +} + +// ================================ + +struct A_Assoc_Assoc : IFoo +{ + float a; + + float foo() + { + return a; + } +} + +struct A_Assoc : INestedInterface +{ + typedef A_Assoc_Assoc NestedAssocType; +} + +struct A : IInterface +{ + typedef A_Assoc MyAssocType + + int data1; + + __init(int data1) { this.data1 = data1; } + + A_Assoc_Assoc calc(float x) { return { x * x * x * data1 }; } +}; + +// ================================ + +struct B_Assoc_Assoc : IFoo +{ + float b; + + float foo() + { + return b; + } +} + +struct B_Assoc : INestedInterface +{ + typedef B_Assoc_Assoc NestedAssocType; +} + +struct B : IInterface +{ + typedef B_Assoc MyAssocType + + int data1; + int data2; + + __init(int data1, int data2) { this.data1 = data1; this.data2 = data2; } + + B_Assoc_Assoc calc(float x) { return { x * x * data1 * data2 }; } +}; + +// ================================ + +float doThing(IInterface obj, float x) +{ + let o = obj.calc(x); + return o.foo(); +} + +float f(uint id, float x) +{ + IInterface obj; + + switch (id) + { + case 0: + obj = A(2); + break; + + default: + obj = B(2, 3); + } + + return doThing(obj, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(dispatchThreadID.x, 1.0); // A.calc, expect 2 + outputBuffer[1] = f(dispatchThreadID.x + 1, 1.5); // B.calc, expect 13.5 +}
\ No newline at end of file diff --git a/tests/compute/nested-assoc-types.slang.expected.txt b/tests/compute/nested-assoc-types.slang.expected.txt new file mode 100644 index 000000000..91a52a345 --- /dev/null +++ b/tests/compute/nested-assoc-types.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +13.500000 +0.000000 +0.000000 +0.000000 |
