summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-10 16:49:41 -0400
committerGitHub <noreply@github.com>2024-07-10 13:49:41 -0700
commit45ef0ce906c93c16495755fec2e597573e8631c4 (patch)
treeddb0ed618683488893d1c076f95b8e6e9e14d8ff
parent16a47816747ca9a9de67b842a73f0e981dbc8b91 (diff)
Fix lowering of associated types and synthesis of dispatch functions. (#4568)
* Treat global variables and parameters as non-differentiable when checking derivative data-flow Global parameters are by-default not differentiable (even if they are of a differentiable type), because our auto-diff passes do not touch anything outside of function bodies. The solution is to use wrapper objects with differentiable getter/setter methods (and we should provide a few such objects in the stdlib). Fixes: #3289 This is a potentially breaking change: User code that was previously working with global variables of a differentiable type will now throw an error (previously the gradient would be dropped without warning). The solution is to use `detach()` to keep same behavior as before or rewrite the access using differentiable getter/setter methods. * Fix issues with lookup witness lowering * Update slang-ir-lower-witness-lookup.cpp * Add tests * Update slang-ir-lower-witness-lookup.cpp * Cleanup * Update nested-assoc-types.slang --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-insts.h6
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp2
-rw-r--r--tests/compute/nested-assoc-types.slang118
-rw-r--r--tests/compute/nested-assoc-types.slang.expected.txt6
5 files changed, 126 insertions, 8 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 83b38b3b6..f0fd38061 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1292,12 +1292,6 @@ struct IRGetSequentialID : IRInst
IRInst* getRTTIOperand() { return getOperand(0); }
};
-struct IRLookupWitnessTable : IRInst
-{
- IRUse sourceType;
- IRUse interfaceType;
-};
-
/// Allocates space from local stack.
///
struct IRAlloca : IRInst
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index 5a7fd9412..12941469d 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -158,7 +158,7 @@ struct AssociatedTypeLookupSpecializationContext
builder.setInsertBefore(inst);
auto witnessTableArg = inst->getWitnessTable();
auto callInst = builder.emitCallInst(
- builder.getWitnessTableIDType(interfaceType), func, witnessTableArg);
+ func->getResultType(), func, witnessTableArg);
inst->replaceUsesWith(callInst);
inst->removeAndDeallocate();
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 6fa2ce67f..d8d573d63 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -10288,7 +10288,7 @@ static void _addFlattenedTupleArgs(
bool isAbstractWitnessTable(IRInst* inst)
{
- if (as<IRThisTypeWitness>(inst))
+ if (as<IRThisTypeWitness>(inst) || as<IRInterfaceRequirementEntry>(inst))
return true;
if (auto lookup = as<IRLookupWitnessMethod>(inst))
return isAbstractWitnessTable(lookup->getWitnessTable());
diff --git a/tests/compute/nested-assoc-types.slang b/tests/compute/nested-assoc-types.slang
new file mode 100644
index 000000000..374e31d6b
--- /dev/null
+++ b/tests/compute/nested-assoc-types.slang
@@ -0,0 +1,118 @@
+// Test calling differentiable function through dynamic dispatch.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+[anyValueSize(16)]
+interface IFoo
+{
+ float foo();
+}
+
+[anyValueSize(16)]
+interface INestedInterface
+{
+ associatedtype NestedAssocType : IFoo;
+}
+
+[anyValueSize(16)]
+interface IInterface
+{
+ associatedtype MyAssocType : INestedInterface;
+ MyAssocType.NestedAssocType calc(float x);
+}
+
+// ================================
+
+struct A_Assoc_Assoc : IFoo
+{
+ float a;
+
+ float foo()
+ {
+ return a;
+ }
+}
+
+struct A_Assoc : INestedInterface
+{
+ typedef A_Assoc_Assoc NestedAssocType;
+}
+
+struct A : IInterface
+{
+ typedef A_Assoc MyAssocType
+
+ int data1;
+
+ __init(int data1) { this.data1 = data1; }
+
+ A_Assoc_Assoc calc(float x) { return { x * x * x * data1 }; }
+};
+
+// ================================
+
+struct B_Assoc_Assoc : IFoo
+{
+ float b;
+
+ float foo()
+ {
+ return b;
+ }
+}
+
+struct B_Assoc : INestedInterface
+{
+ typedef B_Assoc_Assoc NestedAssocType;
+}
+
+struct B : IInterface
+{
+ typedef B_Assoc MyAssocType
+
+ int data1;
+ int data2;
+
+ __init(int data1, int data2) { this.data1 = data1; this.data2 = data2; }
+
+ B_Assoc_Assoc calc(float x) { return { x * x * data1 * data2 }; }
+};
+
+// ================================
+
+float doThing(IInterface obj, float x)
+{
+ let o = obj.calc(x);
+ return o.foo();
+}
+
+float f(uint id, float x)
+{
+ IInterface obj;
+
+ switch (id)
+ {
+ case 0:
+ obj = A(2);
+ break;
+
+ default:
+ obj = B(2, 3);
+ }
+
+ return doThing(obj, x);
+}
+
+//TEST_INPUT: type_conformance A:IInterface = 0
+//TEST_INPUT: type_conformance B:IInterface = 1
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ outputBuffer[0] = f(dispatchThreadID.x, 1.0); // A.calc, expect 2
+ outputBuffer[1] = f(dispatchThreadID.x + 1, 1.5); // B.calc, expect 13.5
+} \ No newline at end of file
diff --git a/tests/compute/nested-assoc-types.slang.expected.txt b/tests/compute/nested-assoc-types.slang.expected.txt
new file mode 100644
index 000000000..91a52a345
--- /dev/null
+++ b/tests/compute/nested-assoc-types.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+2.000000
+13.500000
+0.000000
+0.000000
+0.000000