diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-07-10 16:49:41 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-10 13:49:41 -0700 |
| commit | 45ef0ce906c93c16495755fec2e597573e8631c4 (patch) | |
| tree | ddb0ed618683488893d1c076f95b8e6e9e14d8ff | |
| parent | 16a47816747ca9a9de67b842a73f0e981dbc8b91 (diff) | |
Fix lowering of associated types and synthesis of dispatch functions. (#4568)
* Treat global variables and parameters as non-differentiable when checking derivative data-flow
Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies.
The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib).
Fixes: #3289
This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods.
* Fix issues with lookup witness lowering
* Update slang-ir-lower-witness-lookup.cpp
* Add tests
* Update slang-ir-lower-witness-lookup.cpp
* Cleanup
* Update nested-assoc-types.slang
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-insts.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 2 | ||||
| -rw-r--r-- | tests/compute/nested-assoc-types.slang | 118 | ||||
| -rw-r--r-- | tests/compute/nested-assoc-types.slang.expected.txt | 6 |
5 files changed, 126 insertions, 8 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 83b38b3b6..f0fd38061 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1292,12 +1292,6 @@ struct IRGetSequentialID : IRInst IRInst* getRTTIOperand() { return getOperand(0); } }; -struct IRLookupWitnessTable : IRInst -{ - IRUse sourceType; - IRUse interfaceType; -}; - /// Allocates space from local stack. /// struct IRAlloca : IRInst diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 5a7fd9412..12941469d 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -158,7 +158,7 @@ struct AssociatedTypeLookupSpecializationContext builder.setInsertBefore(inst); auto witnessTableArg = inst->getWitnessTable(); auto callInst = builder.emitCallInst( - builder.getWitnessTableIDType(interfaceType), func, witnessTableArg); + func->getResultType(), func, witnessTableArg); inst->replaceUsesWith(callInst); inst->removeAndDeallocate(); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 6fa2ce67f..d8d573d63 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -10288,7 +10288,7 @@ static void _addFlattenedTupleArgs( bool isAbstractWitnessTable(IRInst* inst) { - if (as<IRThisTypeWitness>(inst)) + if (as<IRThisTypeWitness>(inst) || as<IRInterfaceRequirementEntry>(inst)) return true; if (auto lookup = as<IRLookupWitnessMethod>(inst)) return isAbstractWitnessTable(lookup->getWitnessTable()); diff --git a/tests/compute/nested-assoc-types.slang b/tests/compute/nested-assoc-types.slang new file mode 100644 index 000000000..374e31d6b --- /dev/null +++ b/tests/compute/nested-assoc-types.slang @@ -0,0 +1,118 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IFoo +{ + float foo(); +} + +[anyValueSize(16)] +interface INestedInterface +{ + associatedtype NestedAssocType : IFoo; +} + +[anyValueSize(16)] +interface IInterface +{ + associatedtype MyAssocType : INestedInterface; + MyAssocType.NestedAssocType calc(float x); +} + +// ================================ + +struct A_Assoc_Assoc : IFoo +{ + float a; + + float foo() + { + return a; + } +} + +struct A_Assoc : INestedInterface +{ + typedef A_Assoc_Assoc NestedAssocType; +} + +struct A : IInterface +{ + typedef A_Assoc MyAssocType + + int data1; + + __init(int data1) { this.data1 = data1; } + + A_Assoc_Assoc calc(float x) { return { x * x * x * data1 }; } +}; + +// ================================ + +struct B_Assoc_Assoc : IFoo +{ + float b; + + float foo() + { + return b; + } +} + +struct B_Assoc : INestedInterface +{ + typedef B_Assoc_Assoc NestedAssocType; +} + +struct B : IInterface +{ + typedef B_Assoc MyAssocType + + int data1; + int data2; + + __init(int data1, int data2) { this.data1 = data1; this.data2 = data2; } + + B_Assoc_Assoc calc(float x) { return { x * x * data1 * data2 }; } +}; + +// ================================ + +float doThing(IInterface obj, float x) +{ + let o = obj.calc(x); + return o.foo(); +} + +float f(uint id, float x) +{ + IInterface obj; + + switch (id) + { + case 0: + obj = A(2); + break; + + default: + obj = B(2, 3); + } + + return doThing(obj, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[0] = f(dispatchThreadID.x, 1.0); // A.calc, expect 2 + outputBuffer[1] = f(dispatchThreadID.x + 1, 1.5); // B.calc, expect 13.5 +}
\ No newline at end of file diff --git a/tests/compute/nested-assoc-types.slang.expected.txt b/tests/compute/nested-assoc-types.slang.expected.txt new file mode 100644 index 000000000..91a52a345 --- /dev/null +++ b/tests/compute/nested-assoc-types.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +2.000000 +13.500000 +0.000000 +0.000000 +0.000000 |
