From 6d1fe29cdcbca18d559e302d6427a504d1762173 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 22 Oct 2020 23:44:11 -0700 Subject: Generate `if` based dispatch logic on GPU targets. (#1585) --- source/slang/slang-ir-generics-lowering-context.h | 3 + source/slang/slang-ir-lower-generic-call.cpp | 11 +- source/slang/slang-ir-lower-generics.cpp | 10 ++ source/slang/slang-ir-specialize-dispatch.cpp | 127 ++++++++++++++++++++++ source/slang/slang-ir-specialize-dispatch.h | 13 +++ source/slang/slang.vcxproj | 4 +- source/slang/slang.vcxproj.filters | 6 + 7 files changed, 166 insertions(+), 8 deletions(-) create mode 100644 source/slang/slang-ir-specialize-dispatch.cpp create mode 100644 source/slang/slang-ir-specialize-dispatch.h (limited to 'source') diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index be56e9c84..3bd86e068 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -31,6 +31,9 @@ namespace Slang // Dictionaries for interface type requirement key-value lookups. // Used by `findInterfaceRequirementVal`. Dictionary> mapInterfaceRequirementKeyValue; + + // Map from interface requirement keys to its corresponding dispatch method. + OrderedDictionary mapInterfaceRequirementKeyToDispatchMethods; SharedIRBuilder sharedBuilderStorage; diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index 577b4e86d..bd01a78fb 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -8,9 +8,6 @@ namespace Slang { SharedGenericsLoweringContext* sharedContext; - // Map from interface requirement keys to its corresponding dispatch method. - OrderedDictionary mapInterfaceRequirementKeyToDispatchMethods; - // Represents a work item for unpacking `inout` or `out` arguments after a generic call. 
struct ArgumentUnpackWorkItem { @@ -91,8 +88,8 @@ namespace Slang // Create a dispatch function for a interface method. // On CPU, the dispatch function is implemented as a witness table lookup followed by // a function-pointer call. - // TODO: On GPU targets, we should implement the dispatch function with a `switch` statement - // based on the type ID. + // On GPU targets, we can modify the body of the dispatch function in a follow-up + // pass to implement it with a `switch` statement based on the type ID. IRFunc* _createInterfaceDispatchMethod( IRBuilder* builder, IRInterfaceType* interfaceType, @@ -140,11 +137,11 @@ namespace Slang IRInst* requirementKey, IRInst* requirementVal) { - if (auto func = mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey)) + if (auto func = sharedContext->mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey)) return *func; auto dispatchFunc = _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal); - mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists( + sharedContext->mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists( requirementKey, dispatchFunc); return dispatchFunc; } diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index a9540a87a..4b86cff51 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -8,6 +8,7 @@ #include "slang-ir-lower-generic-function.h" #include "slang-ir-lower-generic-call.h" #include "slang-ir-lower-generic-type.h" +#include "slang-ir-specialize-dispatch.h" #include "slang-ir-witness-table-wrapper.h" #include "slang-ir-ssa.h" #include "slang-ir-dce.h" @@ -57,6 +58,15 @@ namespace Slang generateAnyValueMarshallingFunctions(&sharedContext); if (sink->getErrorCount() != 0) return; + + // On non-CPU targets, generate `if` based dispatch functions. 
+ if (sharedContext.targetReq->getTarget() != CodeGenTarget::CPPSource) + { + specializeDispatchFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; + } + // We might have generated new temporary variables during lowering. // An SSA pass can clean up unnecessary load/stores. constructSSA(module); diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp new file mode 100644 index 000000000..0c519427d --- /dev/null +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -0,0 +1,127 @@ +#include "slang-ir-specialize-dispatch.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ +IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) +{ + for (auto entry : table->getEntries()) + { + if (entry->getRequirementKey() == key) + return entry->getSatisfyingVal(); + } + return nullptr; +} + +void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) +{ + auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); + + // Collect all witness tables of `witnessTableType` in current module. + List<IRWitnessTable*> witnessTables; + for (auto globalInst : sharedContext->module->getGlobalInsts()) + { + if (globalInst->op == kIROp_WitnessTable && globalInst->getDataType() == witnessTableType) + { + witnessTables.add(cast<IRWitnessTable>(globalInst)); + } + } + + SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); + auto block = dispatchFunc->getFirstBlock(); + + // The dispatch function before modification must be in the form of + // call(lookup_interface_method(witnessTableParam, interfaceReqKey), args) + // We now find the relevant instructions. 
+ IRCall* callInst = nullptr; + IRLookupWitnessMethod* lookupInst = nullptr; + IRReturn* returnInst = nullptr; + for (auto inst : block->getOrdinaryInsts()) + { + switch (inst->op) + { + case kIROp_Call: + callInst = cast<IRCall>(inst); + break; + case kIROp_lookup_interface_method: + lookupInst = cast<IRLookupWitnessMethod>(inst); + break; + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + returnInst = cast<IRReturn>(inst); + break; + default: + break; + } + } + SLANG_ASSERT(callInst && lookupInst && returnInst); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(callInst); + + auto witnessTableParam = block->getFirstParam(); + auto requirementKey = lookupInst->getRequirementKey(); + List<IRInst*> params; + for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam()) + { + params.add(param); + } + + // Emit cascaded if statements to call the correct concrete function based on + // the witness table pointer passed in. 
+ auto ifBlock = block; + for (Index i = 0; i < witnessTables.getCount(); i++) + { + auto witnessTable = witnessTables[i]; + bool isLast = (i == witnessTables.getCount() - 1); + IRInst* cmpArgs[] = + { + builder->emitBitCast(builder->getUInt64Type(), witnessTableParam), + builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable) + }; + IRInst* condition = nullptr; + IRBlock* trueBlock = nullptr; + if (!isLast) + { + condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs); + trueBlock = builder->emitBlock(); + } + auto callee = findWitnessTableEntry(witnessTable, requirementKey); + SLANG_ASSERT(callee); + auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); + if (callInst->getDataType()->op == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(specializedCallInst); + if (!isLast) + { + auto falseBlock = builder->emitBlock(); + builder->setInsertInto(ifBlock); + builder->emitIf(condition, trueBlock, falseBlock); + builder->setInsertInto(falseBlock); + ifBlock = falseBlock; + } + } + + // Remove old implementation. 
+ lookupInst->removeAndDeallocate(); + callInst->removeAndDeallocate(); + returnInst->removeAndDeallocate(); +} + +void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) +{ + sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); + + for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) + { + auto dispatchFunc = kv.Value; + specializeDispatchFunction(sharedContext, dispatchFunc); + } +} +} // namespace Slang diff --git a/source/slang/slang-ir-specialize-dispatch.h b/source/slang/slang-ir-specialize-dispatch.h new file mode 100644 index 000000000..fe87eb0bf --- /dev/null +++ b/source/slang/slang-ir-specialize-dispatch.h @@ -0,0 +1,13 @@ +// slang-ir-specialize-dispatch.h +#pragma once + +namespace Slang +{ +struct SharedGenericsLoweringContext; + +/// Modifies the body of interface dispatch functions to use branching instead +/// of function pointer calls to implement the dynamic dispatch logic. +/// This is only used on GPU targets where function pointers are not supported +/// or are not efficient. +void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext); +} diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index bb4293b8f..a09282a4a 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -258,6 +258,7 @@ + @@ -390,6 +391,7 @@ + @@ -453,4 +455,4 @@ - \ No newline at end of file + \ No newline at end of file diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index aad88a15f..ab0e52ddb 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -225,6 +225,9 @@ Header Files + + Header Files + Header Files @@ -617,6 +620,9 @@ Source Files + + Source Files + Source Files -- cgit v1.2.3