summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
author Yong He <yonghe@outlook.com> 2020-10-22 23:44:11 -0700
committer GitHub <noreply@github.com> 2020-10-22 23:44:11 -0700
commit 6d1fe29cdcbca18d559e302d6427a504d1762173 (patch)
tree c4f2539c4ad926f3a71ee4af5e13e28e3f7b9606 /source
parent 10e1bae34733f1cdb5abc001666b1aafa1c1f406 (diff)
Generate `if` based dispatch logic on GPU targets. (#1585)
Diffstat (limited to 'source')
-rw-r--r-- source/slang/slang-ir-generics-lowering-context.h 3
-rw-r--r-- source/slang/slang-ir-lower-generic-call.cpp 11
-rw-r--r-- source/slang/slang-ir-lower-generics.cpp 10
-rw-r--r-- source/slang/slang-ir-specialize-dispatch.cpp 127
-rw-r--r-- source/slang/slang-ir-specialize-dispatch.h 13
-rw-r--r-- source/slang/slang.vcxproj 4
-rw-r--r-- source/slang/slang.vcxproj.filters 6
7 files changed, 166 insertions, 8 deletions
diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h
index be56e9c84..3bd86e068 100644
--- a/source/slang/slang-ir-generics-lowering-context.h
+++ b/source/slang/slang-ir-generics-lowering-context.h
@@ -31,6 +31,9 @@ namespace Slang
// Dictionaries for interface type requirement key-value lookups.
// Used by `findInterfaceRequirementVal`.
Dictionary<IRInterfaceType*, Dictionary<IRInst*, IRInst*>> mapInterfaceRequirementKeyValue;
+
+ // Maps each interface requirement key to its corresponding dispatch method.
+ OrderedDictionary<IRInst*, IRFunc*> mapInterfaceRequirementKeyToDispatchMethods;
SharedIRBuilder sharedBuilderStorage;
diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp
index 577b4e86d..bd01a78fb 100644
--- a/source/slang/slang-ir-lower-generic-call.cpp
+++ b/source/slang/slang-ir-lower-generic-call.cpp
@@ -8,9 +8,6 @@ namespace Slang
{
SharedGenericsLoweringContext* sharedContext;
- // Map from interface requirement keys to its corresponding dispatch method.
- OrderedDictionary<IRInst*, IRFunc*> mapInterfaceRequirementKeyToDispatchMethods;
-
// Represents a work item for unpacking `inout` or `out` arguments after a generic call.
struct ArgumentUnpackWorkItem
{
@@ -91,8 +88,8 @@ namespace Slang
// Create a dispatch function for a interface method.
// On CPU, the dispatch function is implemented as a witness table lookup followed by
// a function-pointer call.
- // TODO: On GPU targets, we should implement the dispatch function with a `switch` statement
- // based on the type ID.
+ // On GPU targets, a follow-up pass can rewrite the body of the dispatch function
+ // as cascaded `if` statements that select the callee from the witness table passed in.
IRFunc* _createInterfaceDispatchMethod(
IRBuilder* builder,
IRInterfaceType* interfaceType,
@@ -140,11 +137,11 @@ namespace Slang
IRInst* requirementKey,
IRInst* requirementVal)
{
- if (auto func = mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey))
+ if (auto func = sharedContext->mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey))
return *func;
auto dispatchFunc =
_createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal);
- mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists(
+ sharedContext->mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists(
requirementKey, dispatchFunc);
return dispatchFunc;
}
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index a9540a87a..4b86cff51 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -8,6 +8,7 @@
#include "slang-ir-lower-generic-function.h"
#include "slang-ir-lower-generic-call.h"
#include "slang-ir-lower-generic-type.h"
+#include "slang-ir-specialize-dispatch.h"
#include "slang-ir-witness-table-wrapper.h"
#include "slang-ir-ssa.h"
#include "slang-ir-dce.h"
@@ -57,6 +58,15 @@ namespace Slang
generateAnyValueMarshallingFunctions(&sharedContext);
if (sink->getErrorCount() != 0)
return;
+
+ // On non-CPU targets, generate `if` based dispatch functions.
+ if (sharedContext.targetReq->getTarget() != CodeGenTarget::CPPSource)
+ {
+ specializeDispatchFunctions(&sharedContext);
+ if (sink->getErrorCount() != 0)
+ return;
+ }
+
// We might have generated new temporary variables during lowering.
// An SSA pass can clean up unnecessary load/stores.
constructSSA(module);
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
new file mode 100644
index 000000000..0c519427d
--- /dev/null
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -0,0 +1,127 @@
+#include "slang-ir-specialize-dispatch.h"
+
+#include "slang-ir-generics-lowering-context.h"
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
+{
+ for (auto entry : table->getEntries())
+ {
+ if (entry->getRequirementKey() == key)
+ return entry->getSatisfyingVal();
+ }
+ return nullptr;
+}
+
+void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
+{
+ auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
+
+ // Collect all witness tables of `witnessTableType` in current module.
+ List<IRWitnessTable*> witnessTables;
+ for (auto globalInst : sharedContext->module->getGlobalInsts())
+ {
+ if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType)
+ {
+ witnessTables.add(cast<IRWitnessTable>(globalInst));
+ }
+ }
+
+ SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
+ auto block = dispatchFunc->getFirstBlock();
+
+ // The dispatch function before modification must be in the form of
+ // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args)
+ // We now find the relevant instructions.
+ IRCall* callInst = nullptr;
+ IRLookupWitnessMethod* lookupInst = nullptr;
+ IRReturn* returnInst = nullptr;
+ for (auto inst : block->getOrdinaryInsts())
+ {
+ switch (inst->op)
+ {
+ case kIROp_Call:
+ callInst = cast<IRCall>(inst);
+ break;
+ case kIROp_lookup_interface_method:
+ lookupInst = cast<IRLookupWitnessMethod>(inst);
+ break;
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ returnInst = cast<IRReturn>(inst);
+ break;
+ default:
+ break;
+ }
+ }
+ SLANG_ASSERT(callInst && lookupInst && returnInst);
+
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(callInst);
+
+ auto witnessTableParam = block->getFirstParam();
+ auto requirementKey = lookupInst->getRequirementKey();
+ List<IRInst*> params;
+ for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam())
+ {
+ params.add(param);
+ }
+
+ // Emit cascaded if statements to call the correct concrete function based on
+ // the witness table pointer passed in.
+ auto ifBlock = block;
+ for (Index i = 0; i < witnessTables.getCount(); i++)
+ {
+ auto witnessTable = witnessTables[i];
+ bool isLast = (i == witnessTables.getCount() - 1);
+ IRInst* cmpArgs[] =
+ {
+ builder->emitBitCast(builder->getUInt64Type(), witnessTableParam),
+ builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable)
+ };
+ IRInst* condition = nullptr;
+ IRBlock* trueBlock = nullptr;
+ if (!isLast)
+ {
+ condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
+ trueBlock = builder->emitBlock();
+ }
+ auto callee = findWitnessTableEntry(witnessTable, requirementKey);
+ SLANG_ASSERT(callee);
+ auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
+ if (callInst->getDataType()->op == kIROp_VoidType)
+ builder->emitReturn();
+ else
+ builder->emitReturn(specializedCallInst);
+ if (!isLast)
+ {
+ auto falseBlock = builder->emitBlock();
+ builder->setInsertInto(ifBlock);
+ builder->emitIf(condition, trueBlock, falseBlock);
+ builder->setInsertInto(falseBlock);
+ ifBlock = falseBlock;
+ }
+ }
+
+ // Remove old implementation.
+ lookupInst->removeAndDeallocate();
+ callInst->removeAndDeallocate();
+ returnInst->removeAndDeallocate();
+}
+
+void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
+{
+ sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
+
+ for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
+ {
+ auto dispatchFunc = kv.Value;
+ specializeDispatchFunction(sharedContext, dispatchFunc);
+ }
+}
+} // namespace Slang
diff --git a/source/slang/slang-ir-specialize-dispatch.h b/source/slang/slang-ir-specialize-dispatch.h
new file mode 100644
index 000000000..fe87eb0bf
--- /dev/null
+++ b/source/slang/slang-ir-specialize-dispatch.h
@@ -0,0 +1,13 @@
+// slang-ir-specialize-dispatch.h
+#pragma once
+
+namespace Slang
+{
+struct SharedGenericsLoweringContext;
+
+/// Modifies the body of interface dispatch functions to use branching instead
+/// of function pointer calls to implement the dynamic dispatch logic.
+/// This is only used on GPU targets where function pointers are not supported
+/// or are not efficient.
+void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext);
+}
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index bb4293b8f..a09282a4a 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -258,6 +258,7 @@
<ClInclude Include="slang-ir-restructure.h" />
<ClInclude Include="slang-ir-sccp.h" />
<ClInclude Include="slang-ir-specialize-arrays.h" />
+ <ClInclude Include="slang-ir-specialize-dispatch.h" />
<ClInclude Include="slang-ir-specialize-function-call.h" />
<ClInclude Include="slang-ir-specialize-resources.h" />
<ClInclude Include="slang-ir-specialize.h" />
@@ -390,6 +391,7 @@
<ClCompile Include="slang-ir-restructure.cpp" />
<ClCompile Include="slang-ir-sccp.cpp" />
<ClCompile Include="slang-ir-specialize-arrays.cpp" />
+ <ClCompile Include="slang-ir-specialize-dispatch.cpp" />
<ClCompile Include="slang-ir-specialize-function-call.cpp" />
<ClCompile Include="slang-ir-specialize-resources.cpp" />
<ClCompile Include="slang-ir-specialize.cpp" />
@@ -453,4 +455,4 @@
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
</ImportGroup>
-</Project> \ No newline at end of file
+</Project> \ No newline at end of file
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index aad88a15f..ab0e52ddb 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -225,6 +225,9 @@
<ClInclude Include="slang-ir-specialize-arrays.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="slang-ir-specialize-dispatch.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="slang-ir-specialize-function-call.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -617,6 +620,9 @@
<ClCompile Include="slang-ir-specialize-arrays.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="slang-ir-specialize-dispatch.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="slang-ir-specialize-function-call.cpp">
<Filter>Source Files</Filter>
</ClCompile>