summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2018-01-21 16:26:52 -0800
committerTim Foley <tfoleyNV@users.noreply.github.com>2018-01-21 16:26:52 -0800
commit8196dc4a684a75344e507697273e2123af97b979 (patch)
treedbee3f92a10f15a9f7202010e1683b369c38ba15
parent4044a1d3a0605198465a7eb6e0e3c1f8b1a3c298 (diff)
specialize witness tables when needed when specializing `lookup_witness_table` instruction. (#376)
-rw-r--r--source/slang/ir.cpp14
-rw-r--r--source/slang/syntax.cpp7
-rw-r--r--source/slang/type-defs.h1
-rw-r--r--tests/compute/int-generic.slang42
-rw-r--r--tests/compute/int-generic.slang.expected.txt1
5 files changed, 64 insertions, 1 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 7318bff4c..1d3c91979 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -4862,6 +4862,20 @@ namespace Slang
auto interfaceDeclRef = ((IRDeclRef*)lookupInst->interfaceType.usedValue)->declRef;
auto mangledName = getMangledNameForConformanceWitness(srcDeclRef, interfaceDeclRef);
witnessTables.TryGetValue(mangledName, witnessTable);
+
+ if (!witnessTable)
+ {
+ // try specialize the witness table
+ auto genDeclRef = srcDeclRef;
+ genDeclRef.substitutions = createDefaultSubstitutions(module->session, genDeclRef.decl);
+ auto genName = getMangledNameForConformanceWitness(genDeclRef, interfaceDeclRef);
+ IRWitnessTable* genTable = nullptr;
+ if (witnessTables.TryGetValue(genName, genTable))
+ {
+ witnessTable = specializeWitnessTable(sharedContext, genTable, srcDeclRef, nullptr);
+ witnessTables.AddIfNotExists(witnessTable->mangledName, witnessTable);
+ }
+ }
if (witnessTable)
{
lookupInst->replaceUsesWith(witnessTable);
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index ab4a5f94c..3bccf51ce 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -856,7 +856,12 @@ void Type::accept(IValVisitor* visitor, void* extra)
Type* ErrorType::CreateCanonicalType()
{
- return this;
+ return this;
+ }
+
+ RefPtr<Val> ErrorType::SubstituteImpl(SubstitutionSet /*subst*/, int* /*ioDiff*/)
+ {
+ return this;
}
int ErrorType::GetHashCode()
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index c4ec09f1d..db9630c0e 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -36,6 +36,7 @@ public:
protected:
virtual bool EqualsImpl(Type * type) override;
+ virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
virtual Type* CreateCanonicalType() override;
virtual int GetHashCode() override;
)
diff --git a/tests/compute/int-generic.slang b/tests/compute/int-generic.slang
new file mode 100644
index 000000000..7531ee74e
--- /dev/null
+++ b/tests/compute/int-generic.slang
@@ -0,0 +1,42 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT:type Material<1,2>
+RWStructuredBuffer<int> outputBuffer;
+
+interface IBRDF
+{
+ int compute();
+};
+
+interface IMaterial
+{
+ associatedtype TBRDF : IBRDF;
+ TBRDF getBRDF();
+}
+
+struct BRDF<let A:int, let B:int> : IBRDF
+{
+ int c;
+ int compute()
+ {
+ return A+B;
+ }
+};
+
+struct Material<let A:int, let B: int> : IMaterial
+{
+ typedef BRDF<A,B> TBRDF;
+ TBRDF getBRDF() { TBRDF a; a.c = 0; return a; }
+};
+
+type_param TMaterial : IMaterial;
+
+TMaterial material;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ TMaterial.TBRDF brdf = material.getBRDF();
+ int outVal = brdf.compute();
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/int-generic.slang.expected.txt b/tests/compute/int-generic.slang.expected.txt
new file mode 100644
index 000000000..e440e5c84
--- /dev/null
+++ b/tests/compute/int-generic.slang.expected.txt
@@ -0,0 +1 @@
+3 \ No newline at end of file