diff options
| author | Yong He <yonghe@outlook.com> | 2020-10-29 10:21:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-10-29 10:21:07 -0700 |
| commit | 060071604bc715951ddf940a51ced1da48b3dd10 (patch) | |
| tree | 19daa4c23bdc5098e8bf5c1e28d5dbe1a389eca3 /source/slang/slang-ir-specialize-dispatch.cpp | |
| parent | 494e09af2cebafa34db49dc1f60afd43aebed619 (diff) | |
Generate `switch` based dynamic dispatch logic. (#1591)
Co-authored-by: Tim Foley <tim.foley.is@gmail.com>
Diffstat (limited to 'source/slang/slang-ir-specialize-dispatch.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-dispatch.cpp | 138 |
1 files changed, 108 insertions, 30 deletions
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index 05ed867bd..ddbb743a8 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -16,7 +16,7 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) return nullptr; } -void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) +IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) { auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); @@ -62,35 +62,63 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = &sharedContext->sharedBuilderStorage; - builder->setInsertBefore(callInst); + builder->setInsertBefore(dispatchFunc); + + // Create a new dispatch func to replace the existing one. + auto newDispatchFunc = builder->createFunc(); + + List<IRType*> paramTypes; + for (auto paramInst : dispatchFunc->getParams()) + { + paramTypes.add(paramInst->getFullType()); + } + + // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID. + paramTypes[0] = builder->getUIntType(); + + auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); + newDispatchFunc->setFullType(newDipsatchFuncType); + dispatchFunc->transferDecorationsTo(newDispatchFunc); + + builder->setInsertInto(newDispatchFunc); + auto newBlock = builder->emitBlock(); + + IRBlock* defaultBlock = nullptr; - auto witnessTableParam = block->getFirstParam(); auto requirementKey = lookupInst->getRequirementKey(); List<IRInst*> params; - for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam()) + for (Index i = 0; i < paramTypes.getCount(); i++) { - params.add(param); + auto param = builder->emitParam(paramTypes[i]); + if (i > 0) + params.add(param); } + auto witnessTableParam = newBlock->getFirstParam(); - // Emit cascaded if statements to call the correct concrete function based on - // the witness table pointer passed in. - auto ifBlock = block; + // Generate case blocks for each possible witness table. + List<IRInst*> caseBlocks; for (Index i = 0; i < witnessTables.getCount(); i++) { auto witnessTable = witnessTables[i]; - bool isLast = (i == witnessTables.getCount() - 1); - IRInst* cmpArgs[] = + auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>(); + SLANG_ASSERT(seqIdDecoration); + + if (i != witnessTables.getCount() - 1) { - builder->emitBitCast(builder->getUInt64Type(), witnessTableParam), - builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable) - }; - IRInst* condition = nullptr; - IRBlock* trueBlock = nullptr; - if (!isLast) + // Create a case block if we are not the last case. + caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); + builder->setInsertInto(newDispatchFunc); + auto caseBlock = builder->emitBlock(); + caseBlocks.add(caseBlock); + } + else { - condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs); - trueBlock = builder->emitBlock(); + // Generate code for the last possible value in the `default` block. + builder->setInsertInto(newDispatchFunc); + defaultBlock = builder->emitBlock(); + builder->setInsertInto(defaultBlock); } + auto callee = findWitnessTableEntry(witnessTable, requirementKey); SLANG_ASSERT(callee); auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); @@ -98,20 +126,28 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR builder->emitReturn(); else builder->emitReturn(specializedCallInst); - if (!isLast) - { - auto falseBlock = builder->emitBlock(); - builder->setInsertInto(ifBlock); - builder->emitIf(condition, trueBlock, falseBlock); - builder->setInsertInto(falseBlock); - ifBlock = falseBlock; - } } + // Emit a switch statement to call the correct concrete function based on + // the witness table sequential ID passed in. + builder->setInsertInto(newDispatchFunc); + auto breakBlock = builder->emitBlock(); + builder->setInsertInto(breakBlock); + builder->emitUnreachable(); + + builder->setInsertInto(newBlock); + builder->emitSwitch( + witnessTableParam, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + // Remove old implementation. - lookupInst->removeAndDeallocate(); - callInst->removeAndDeallocate(); - returnInst->removeAndDeallocate(); + dispatchFunc->replaceUsesWith(newDispatchFunc); + dispatchFunc->removeAndDeallocate(); + + return newDispatchFunc; } // Ensures every witness table object has been assigned a sequential ID. @@ -179,6 +215,40 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex } } +// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with +// its sequential ID. +void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc) +{ + List<IRInst*> users; + for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse) + { + users.add(use->getUser()); + } + for (auto user : users) + { + if (auto call = as<IRCall>(user)) + { + if (call->getCallee() != newDispatchFunc) + continue; + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(call); + List<IRInst*> args; + for (UInt i = 0; i < call->getArgCount(); i++) + { + args.add(call->getArg(i)); + } + if (as<IRWitnessTable>(args[0]->getDataType())) + continue; + auto seqIdArg = builder.emitGetSequentialIDInst(args[0]); + args[0] = seqIdArg; + auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + } +} + void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) { sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); @@ -186,10 +256,18 @@ void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) // First we ensure that all witness table objects has a sequential ID assigned. ensureWitnessTableSequentialIDs(sharedContext); + // Generate specialized dispatch functions and fixup call sites. for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) { auto dispatchFunc = kv.Value; - specializeDispatchFunction(sharedContext, dispatchFunc); + + // Generate a specialized `switch` statement based dispatch func, + // from the witness tables present in the module. + auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc); + + // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of + // witness table objects. + fixupDispatchFuncCall(sharedContext, newDispatchFunc); } } } // namespace Slang |
