From 110d15b61ac5d76da001d412eaa4be07f3cd8f4d Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 19 Jun 2020 11:19:51 -0700 Subject: Dynamic dispatch for static member functions of associatedtypes. (#1404) --- source/slang/slang-emit-c-like.cpp | 13 ++---- source/slang/slang-emit-cpp.cpp | 21 +++++++-- tests/compute/dynamic-dispatch-2.slang | 53 ++++++++++++++++++++++ .../compute/dynamic-dispatch-2.slang.expected.txt | 4 ++ 4 files changed, 78 insertions(+), 13 deletions(-) create mode 100644 tests/compute/dynamic-dispatch-2.slang create mode 100644 tests/compute/dynamic-dispatch-2.slang.expected.txt diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 6c004c84c..3438fd3f4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -237,18 +237,15 @@ List CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { auto reqKey = cast(interfaceType->getOperand(i)); - bool matchingEntryFound = false; IRWitnessTableEntry* entry = nullptr; if (witnessTableEntryDictionary.TryGetValue(reqKey, entry)) { - if (entry->requirementKey.get() == reqKey) - { - matchingEntryFound = true; - sortedWitnessTableEntries.add(entry); - break; - } + sortedWitnessTableEntries.add(entry); + } + else + { + SLANG_UNREACHABLE("interface requirement key not found in witness table."); } - SLANG_ASSERT(matchingEntryFound); } return sortedWitnessTableEntries; } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index accb290fa..4a59f4cf9 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1711,6 +1711,15 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() m_writer->emit("&KernelContext::"); m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); } + else if (auto witnessTableVal = as(entry->getSatisfyingVal())) + { + if (!isFirstEntry) + m_writer->emit(",\n"); + else + isFirstEntry = false; + m_writer->emit("&"); + m_writer->emit(getName(witnessTableVal)); + } else { // TODO: handle other witness table entry types. @@ -1745,16 +1754,11 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( emitSimpleType(interfaceType); m_writer->emit("\n{\n"); m_writer->indent(); - bool isFirstEntry = true; for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++) { auto entry = sortedWitnessTableEntries[i]; if (auto funcVal = as(entry->satisfyingVal.get())) { - if (!isFirstEntry) - m_writer->emit(",\n"); - else - isFirstEntry = false; emitType(funcVal->getResultType()); m_writer->emit(" (KernelContext::*"); m_writer->emit(getName(entry->requirementKey.get())); @@ -1777,6 +1781,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( } m_writer->emit(");\n"); } + else if (auto witnessTableVal = as(entry->getSatisfyingVal())) + { + emitType(as(witnessTableVal->getOperand(0))); + m_writer->emit("* "); + m_writer->emit(getName(entry->requirementKey.get())); + m_writer->emit(";\n"); + } else { // TODO: handle other witness table entry types. diff --git a/tests/compute/dynamic-dispatch-2.slang b/tests/compute/dynamic-dispatch-2.slang new file mode 100644 index 000000000..ade8aeb84 --- /dev/null +++ b/tests/compute/dynamic-dispatch-2.slang @@ -0,0 +1,53 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for static member functions +// of associated type. + +interface IAssoc +{ + int get(); + static int getBase(); +} +interface IInterface +{ + associatedtype Assoc : IAssoc; + int Compute(int inVal); + Assoc getAssoc(); +}; + +int GenericCompute(T obj, int inVal) +{ + return obj.Compute(inVal) + T.Assoc.getBase(); +} + +struct Impl : IInterface +{ + struct Assoc : IAssoc + { + int val; + int get() { return val; } + static int getBase() { return -1; } + }; + int base; + int Compute(int inVal) { return base + inVal * inVal; } + Assoc getAssoc() { Assoc rs; rs.val = 1; return rs; } +}; + +int test(int inVal) +{ + Impl obj; + obj.base = 1; + return GenericCompute(obj, inVal); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-2.slang.expected.txt b/tests/compute/dynamic-dispatch-2.slang.expected.txt new file mode 100644 index 000000000..c9fa0697e --- /dev/null +++ b/tests/compute/dynamic-dispatch-2.slang.expected.txt @@ -0,0 +1,4 @@ +0 +1 +4 +9 -- cgit v1.2.3