summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/ir.cpp6
-rw-r--r--source/slang/lookup.cpp17
-rw-r--r--source/slang/lower-to-ir.cpp27
-rw-r--r--tests/compute/transitive-interface.slang66
-rw-r--r--tests/compute/transitive-interface.slang.expected.txt4
5 files changed, 118 insertions, 2 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 46eff33e9..17ba14c6e 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3690,12 +3690,12 @@ namespace Slang
cloneFunctionCommon(context, clonedFunc, originalFunc);
// for now, clone all unreferenced witness tables
- /*for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
+ for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
gv; gv = gv->getNextValue())
{
if (gv->op == kIROp_witness_table)
cloneGlobalValue(context, (IRWitnessTable*)gv);
- }*/
+ }
// We need to attach the layout information for
// the entry point to this declaration, so that
@@ -4746,7 +4746,9 @@ namespace Slang
//
// We will first find or construct a specialized version
// of the callee funciton/
+ auto oldFunc = dumpIRFunc(genericFunc);
auto specFunc = getSpecializedFunc(sharedContext, genericFunc, specDeclRef);
+ auto newFunc = dumpIRFunc(specFunc);
//
// Then we will replace the use sites for the `specialize`
// instruction with uses of the specialized function.
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
index 19503d63f..0791c508b 100644
--- a/source/slang/lookup.cpp
+++ b/source/slang/lookup.cpp
@@ -297,6 +297,23 @@ void DoLocalLookupImpl(
session,
name, extDeclRef, request, result, inBreadcrumbs);
}
+
+ }
+ // for interface decls, also lookup in the base interfaces
+ if (request.semantics)
+ {
+ if (auto interfaceDeclRef = containerDeclRef.As<InterfaceDecl>())
+ {
+ auto baseInterfaces = getMembersOfType<InheritanceDecl>(interfaceDeclRef);
+ for (auto inheritanceDeclRef : baseInterfaces)
+ {
+ auto baseType = inheritanceDeclRef.getDecl()->base.type.As<DeclRefType>();
+ SLANG_ASSERT(baseType);
+ int diff = 0;
+ auto baseInterfaceDeclRef = baseType->declRef.SubstituteImpl(interfaceDeclRef.substitutions, &diff);
+ DoLocalLookupImpl(session, name, baseInterfaceDeclRef.As<ContainerDecl>(), request, result, inBreadcrumbs);
+ }
+ }
}
}
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 498783f4b..c8e010f7d 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -2781,6 +2781,30 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ void walkInheritanceHierarchyAndCreateWitnessTableCopies(IRWitnessTable* witnessTable, Type* subType, InheritanceDecl* inheritanceDecl)
+ {
+ auto baseDeclRef = inheritanceDecl->base.type.As<DeclRefType>();
+ if (auto baseInterfaceDeclRef = baseDeclRef->declRef.As<InterfaceDecl>())
+ {
+ for (auto subInheritanceDeclRef : getMembersOfType<InheritanceDecl>(baseInterfaceDeclRef))
+ {
+ auto cpyMangledName = getMangledNameForConformanceWitness(subType, subInheritanceDeclRef.getDecl()->getSup().type);
+ if (!witnessTablesDictionary.ContainsKey(cpyMangledName))
+ {
+ auto cpyTable = context->irBuilder->createWitnessTable();
+ cpyTable->mangledName = cpyMangledName;
+ context->irBuilder->createWitnessTableEntry(witnessTable,
+ context->irBuilder->getDeclRefVal(subInheritanceDeclRef), cpyTable);
+ cpyTable->entries = witnessTable->entries;
+ witnessTablesDictionary.Add(cpyMangledName, cpyTable);
+ walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, subType, subInheritanceDeclRef.getDecl());
+ }
+ }
+ }
+ }
+
+ Dictionary<String, IRWitnessTable*> witnessTablesDictionary;
+
LoweredValInfo visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
{
// Construct a type for the parent declaration.
@@ -2817,6 +2841,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// conformance of the type to its super-type.
auto witnessTable = context->irBuilder->createWitnessTable();
witnessTable->mangledName = mangledName;
+
+ witnessTablesDictionary.Add(mangledName, witnessTable);
if (parentDecl->ParentDecl)
witnessTable->genericDecl = dynamic_cast<GenericDecl*>(parentDecl->ParentDecl);
@@ -2850,6 +2876,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
witnessTable->moveToEnd();
+ walkInheritanceHierarchyAndCreateWitnessTableCopies(witnessTable, type, inheritanceDecl);
// A direct reference to this inheritance relationship (e.g.,
// as a subtype witness) will take the form of a reference to
diff --git a/tests/compute/transitive-interface.slang b/tests/compute/transitive-interface.slang
new file mode 100644
index 000000000..04ececf93
--- /dev/null
+++ b/tests/compute/transitive-interface.slang
@@ -0,0 +1,66 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IAdd
+{
+ float addf(float u, float v);
+}
+
+interface ISub
+{
+ float subf(float u, float v);
+}
+
+interface IAddAndSub : IAdd, ISub
+{
+}
+
+struct Simple : IAddAndSub
+{
+ float addf(float u, float v)
+ {
+ return u+v;
+ }
+ float subf(float u, float v)
+ {
+ return u-v;
+ }
+};
+
+float testAdd<T:IAdd>(T t)
+{
+ return t.addf(1.0, 1.0);
+}
+
+interface IAssoc
+{
+ associatedtype AT : IAdd;
+}
+
+struct AssocImpl : IAssoc
+{
+ typedef Simple AT;
+};
+
+float testAdd2<T:IAssoc>(T assoc)
+{
+ T.AT obj;
+ return obj.addf(1.0, 1.0);
+}
+
+float testSub<T:ISub>(T t, float base)
+{
+ return t.subf(base, 1.0);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ AssocImpl s;
+ float outVal = testAdd2(s);
+ Simple s1;
+ outVal += testSub(s1, outVal);
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/transitive-interface.slang.expected.txt b/tests/compute/transitive-interface.slang.expected.txt
new file mode 100644
index 000000000..e143b7f20
--- /dev/null
+++ b/tests/compute/transitive-interface.slang.expected.txt
@@ -0,0 +1,4 @@
+3F800000
+3F800000
+3F800000
+3F800000 \ No newline at end of file