From 3d4eaf3c9b13e32c4e4d7737f17805503cddcb0b Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 15 Jan 2018 18:15:49 -0500 Subject: Support transitive interfaces This commit is a bunch of quick hacks to get transitive interfaces to work. The idea is for each concrete type we create one giant witness table that contains entries for all the transitively reachable interface requirements, and then create one copy of that witness table for each interface it implements. `DoLocalLookupImpl` now also looks up in inherited interface decles when looking up for a symbol in an interface decl. When visiting `InheritanceDecl` in `lower-to-ir`, create copies of the giant witness table for each transitively inherited interface, so that these witness tables can be found later when the IR is specialized. Re-enable the `copy all witness tables` hack in `specializeIRForEntryPoint` to ensure those transitive witness tables are copied over. --- source/slang/ir.cpp | 6 +- source/slang/lookup.cpp | 17 ++++++ source/slang/lower-to-ir.cpp | 27 +++++++++ tests/compute/transitive-interface.slang | 66 ++++++++++++++++++++++ .../transitive-interface.slang.expected.txt | 4 ++ 5 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 tests/compute/transitive-interface.slang create mode 100644 tests/compute/transitive-interface.slang.expected.txt diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 46eff33e9..17ba14c6e 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3690,12 +3690,12 @@ namespace Slang cloneFunctionCommon(context, clonedFunc, originalFunc); // for now, clone all unreferenced witness tables - /*for (auto gv = context->getOriginalModule()->getFirstGlobalValue(); + for (auto gv = context->getOriginalModule()->getFirstGlobalValue(); gv; gv = gv->getNextValue()) { if (gv->op == kIROp_witness_table) cloneGlobalValue(context, (IRWitnessTable*)gv); - }*/ + } // We need to attach the layout information for // the entry point to this declaration, so that @@ -4746,7 +4746,9 @@ namespace Slang // // We will first find or construct a specialized version // of the callee funciton/ + auto oldFunc = dumpIRFunc(genericFunc); auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef); + auto newFunc = dumpIRFunc(specFunc); // // Then we will replace the use sites for the `specialize` // instruction with uses of the specialized function. diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index 19503d63f..0791c508b 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -297,6 +297,23 @@ void DoLocalLookupImpl( session, name, extDeclRef, request, result, inBreadcrumbs); } + + } + // for interface decls, also lookup in the base interfaces + if (request.semantics) + { + if (auto interfaceDeclRef = containerDeclRef.As()) + { + auto baseInterfaces = getMembersOfType(interfaceDeclRef); + for (auto inheritanceDeclRef : baseInterfaces) + { + auto baseType = inheritanceDeclRef.getDecl()->base.type.As(); + SLANG_ASSERT(baseType); + int diff = 0; + auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(interfaceDeclRef.substitutions, &diff); + DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As(), request, result, inBreadcrumbs); + } + } } } diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 498783f4b..c8e010f7d 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -2781,6 +2781,30 @@ struct DeclLoweringVisitor : DeclVisitor return LoweredValInfo(); } + void walkInheritanceHierarchyAndCreateWitnessTableCopies(IRWitnessTable* witnessTable, Type* subType, InheritanceDecl* inheritanceDecl) + { + auto baseDeclRef = inheritanceDecl->base.type.As(); + if (auto baseInterfaceDeclRef = baseDeclRef->declRef.As()) + { + for (auto subInheritanceDeclRef : getMembersOfType(baseInterfaceDeclRef)) + { + auto cpyMangledName = getMangledNameForConformanceWitness(subType, subInheritanceDeclRef.getDecl()->getSup().type); + if (!witnessTablesDictionary.ContainsKey(cpyMangledName)) + { + auto cpyTable = context->irBuilder->createWitnessTable(); + cpyTable->mangledName = cpyMangledName; + context->irBuilder->createWitnessTableEntry(witnessTable, + context->irBuilder->getDeclRefVal(subInheritanceDeclRef), cpyTable); + cpyTable->entries = witnessTable->entries; + witnessTablesDictionary.Add(cpyMangledName, cpyTable); + walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, subType, subInheritanceDeclRef.getDecl()); + } + } + } + } + + Dictionary witnessTablesDictionary; + LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl) { // Construct a type for the parent declaration. @@ -2817,6 +2841,8 @@ struct DeclLoweringVisitor : DeclVisitor // conformance of the type to its super-type. auto witnessTable = context->irBuilder->createWitnessTable(); witnessTable->mangledName = mangledName; + + witnessTablesDictionary.Add(mangledName, witnessTable); if (parentDecl->ParentDecl) witnessTable->genericDecl = dynamic_cast(parentDecl->ParentDecl); @@ -2850,6 +2876,7 @@ struct DeclLoweringVisitor : DeclVisitor } witnessTable->moveToEnd(); + walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, type, inheritanceDecl); // A direct reference to this inheritance relationship (e.g., // as a subtype witness) will take the form of a reference to diff --git a/tests/compute/transitive-interface.slang b/tests/compute/transitive-interface.slang new file mode 100644 index 000000000..04ececf93 --- /dev/null +++ b/tests/compute/transitive-interface.slang @@ -0,0 +1,66 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer outputBuffer; + +interface IAdd +{ + float addf(float u, float v); +} + +interface ISub +{ + float subf(float u, float v); +} + +interface IAddAndSub : IAdd, ISub +{ +} + +struct Simple : IAddAndSub +{ + float addf(float u, float v) + { + return u+v; + } + float subf(float u, float v) + { + return u-v; + } +}; + +float testAdd(T t) +{ + return t.addf(1.0, 1.0); +} + +interface IAssoc +{ + associatedtype AT : IAdd; +} + +struct AssocImpl : IAssoc +{ + typedef Simple AT; +}; + +float testAdd2(T assoc) +{ + T.AT obj; + return obj.addf(1.0, 1.0); +} + +float testSub(T t, float base) +{ + return t.subf(base, 1.0); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + AssocImpl s; + float outVal = testAdd2(s); + Simple s1; + outVal += testSub(s1, outVal); + outputBuffer[dispatchThreadID.x] = outVal; +} \ No newline at end of file diff --git a/tests/compute/transitive-interface.slang.expected.txt b/tests/compute/transitive-interface.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/transitive-interface.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000 \ No newline at end of file -- cgit v1.2.3