diff options
| author | Yong He <yonghe@outlook.com> | 2020-06-19 11:19:51 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-19 11:19:51 -0700 |
| commit | 110d15b61ac5d76da001d412eaa4be07f3cd8f4d (patch) | |
| tree | 6d2cef5dab495b484844b4d54c312751af62091e | |
| parent | 5fbb9ff7e1516bd787695d2c9d80b696f0a9ca9a (diff) | |
Dynamic dispatch for static member functions of associatedtypes. (#1404)
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 21 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-2.slang | 53 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-2.slang.expected.txt | 4 |
4 files changed, 78 insertions, 13 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 6c004c84c..3438fd3f4 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -237,18 +237,15 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { auto reqKey = cast<IRStructKey>(interfaceType->getOperand(i)); - bool matchingEntryFound = false; IRWitnessTableEntry* entry = nullptr; if (witnessTableEntryDictionary.TryGetValue(reqKey, entry)) { - if (entry->requirementKey.get() == reqKey) - { - matchingEntryFound = true; - sortedWitnessTableEntries.add(entry); - break; - } + sortedWitnessTableEntries.add(entry); + } + else + { + SLANG_UNREACHABLE("interface requirement key not found in witness table."); } - SLANG_ASSERT(matchingEntryFound); } return sortedWitnessTableEntries; } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index accb290fa..4a59f4cf9 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1711,6 +1711,15 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() m_writer->emit("&KernelContext::"); m_writer->emit(_getWitnessTableWrapperFuncName(funcVal)); } + else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal())) + { + if (!isFirstEntry) + m_writer->emit(",\n"); + else + isFirstEntry = false; + m_writer->emit("&"); + m_writer->emit(getName(witnessTableVal)); + } else { // TODO: handle other witness table entry types. @@ -1745,16 +1754,11 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( emitSimpleType(interfaceType); m_writer->emit("\n{\n"); m_writer->indent(); - bool isFirstEntry = true; for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++) { auto entry = sortedWitnessTableEntries[i]; if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get())) { - if (!isFirstEntry) - m_writer->emit(",\n"); - else - isFirstEntry = false; emitType(funcVal->getResultType()); m_writer->emit(" (KernelContext::*"); m_writer->emit(getName(entry->requirementKey.get())); @@ -1777,6 +1781,13 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( } m_writer->emit(");\n"); } + else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal())) + { + emitType(as<IRType>(witnessTableVal->getOperand(0))); + m_writer->emit("* "); + m_writer->emit(getName(entry->requirementKey.get())); + m_writer->emit(";\n"); + } else { // TODO: handle other witness table entry types. diff --git a/tests/compute/dynamic-dispatch-2.slang b/tests/compute/dynamic-dispatch-2.slang new file mode 100644 index 000000000..ade8aeb84 --- /dev/null +++ b/tests/compute/dynamic-dispatch-2.slang @@ -0,0 +1,53 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for static member functions +// of associated type. + +interface IAssoc +{ + int get(); + static int getBase(); +} +interface IInterface +{ + associatedtype Assoc : IAssoc; + int Compute(int inVal); + Assoc getAssoc(); +}; + +int GenericCompute<T:IInterface>(T obj, int inVal) +{ + return obj.Compute(inVal) + T.Assoc.getBase(); +} + +struct Impl : IInterface +{ + struct Assoc : IAssoc + { + int val; + int get() { return val; } + static int getBase() { return -1; } + }; + int base; + int Compute(int inVal) { return base + inVal * inVal; } + Assoc getAssoc() { Assoc rs; rs.val = 1; return rs; } +}; + +int test(int inVal) +{ + Impl obj; + obj.base = 1; + return GenericCompute<Impl>(obj, inVal); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-2.slang.expected.txt b/tests/compute/dynamic-dispatch-2.slang.expected.txt new file mode 100644 index 000000000..c9fa0697e --- /dev/null +++ b/tests/compute/dynamic-dispatch-2.slang.expected.txt @@ -0,0 +1,4 @@ +0 +1 +4 +9 |
