summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize-dispatch.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2020-10-29 10:21:07 -0700
committerGitHub <noreply@github.com>2020-10-29 10:21:07 -0700
commit060071604bc715951ddf940a51ced1da48b3dd10 (patch)
tree19daa4c23bdc5098e8bf5c1e28d5dbe1a389eca3 /source/slang/slang-ir-specialize-dispatch.cpp
parent494e09af2cebafa34db49dc1f60afd43aebed619 (diff)
Generate `switch` based dynamic dispatch logic. (#1591)
Co-authored-by: Tim Foley <tim.foley.is@gmail.com>
Diffstat (limited to 'source/slang/slang-ir-specialize-dispatch.cpp')
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp138
1 files changed, 108 insertions, 30 deletions
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index 05ed867bd..ddbb743a8 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -16,7 +16,7 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
return nullptr;
}
-void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
+IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
@@ -62,35 +62,63 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR
IRBuilder builderStorage;
auto builder = &builderStorage;
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
- builder->setInsertBefore(callInst);
+ builder->setInsertBefore(dispatchFunc);
+
+ // Create a new dispatch func to replace the existing one.
+ auto newDispatchFunc = builder->createFunc();
+
+ List<IRType*> paramTypes;
+ for (auto paramInst : dispatchFunc->getParams())
+ {
+ paramTypes.add(paramInst->getFullType());
+ }
+
+ // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID.
+ paramTypes[0] = builder->getUIntType();
+
+ auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType());
+ newDispatchFunc->setFullType(newDipsatchFuncType);
+ dispatchFunc->transferDecorationsTo(newDispatchFunc);
+
+ builder->setInsertInto(newDispatchFunc);
+ auto newBlock = builder->emitBlock();
+
+ IRBlock* defaultBlock = nullptr;
- auto witnessTableParam = block->getFirstParam();
auto requirementKey = lookupInst->getRequirementKey();
List<IRInst*> params;
- for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam())
+ for (Index i = 0; i < paramTypes.getCount(); i++)
{
- params.add(param);
+ auto param = builder->emitParam(paramTypes[i]);
+ if (i > 0)
+ params.add(param);
}
+ auto witnessTableParam = newBlock->getFirstParam();
- // Emit cascaded if statements to call the correct concrete function based on
- // the witness table pointer passed in.
- auto ifBlock = block;
+ // Generate case blocks for each possible witness table.
+ List<IRInst*> caseBlocks;
for (Index i = 0; i < witnessTables.getCount(); i++)
{
auto witnessTable = witnessTables[i];
- bool isLast = (i == witnessTables.getCount() - 1);
- IRInst* cmpArgs[] =
+ auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>();
+ SLANG_ASSERT(seqIdDecoration);
+
+ if (i != witnessTables.getCount() - 1)
{
- builder->emitBitCast(builder->getUInt64Type(), witnessTableParam),
- builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable)
- };
- IRInst* condition = nullptr;
- IRBlock* trueBlock = nullptr;
- if (!isLast)
+ // Create a case block if we are not the last case.
+ caseBlocks.add(seqIdDecoration->getSequentialIDOperand());
+ builder->setInsertInto(newDispatchFunc);
+ auto caseBlock = builder->emitBlock();
+ caseBlocks.add(caseBlock);
+ }
+ else
{
- condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
- trueBlock = builder->emitBlock();
+ // Generate code for the last possible value in the `default` block.
+ builder->setInsertInto(newDispatchFunc);
+ defaultBlock = builder->emitBlock();
+ builder->setInsertInto(defaultBlock);
}
+
auto callee = findWitnessTableEntry(witnessTable, requirementKey);
SLANG_ASSERT(callee);
auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
@@ -98,20 +126,28 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR
builder->emitReturn();
else
builder->emitReturn(specializedCallInst);
- if (!isLast)
- {
- auto falseBlock = builder->emitBlock();
- builder->setInsertInto(ifBlock);
- builder->emitIf(condition, trueBlock, falseBlock);
- builder->setInsertInto(falseBlock);
- ifBlock = falseBlock;
- }
}
+ // Emit a switch statement to call the correct concrete function based on
+ // the witness table sequential ID passed in.
+ builder->setInsertInto(newDispatchFunc);
+ auto breakBlock = builder->emitBlock();
+ builder->setInsertInto(breakBlock);
+ builder->emitUnreachable();
+
+ builder->setInsertInto(newBlock);
+ builder->emitSwitch(
+ witnessTableParam,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+
// Remove old implementation.
- lookupInst->removeAndDeallocate();
- callInst->removeAndDeallocate();
- returnInst->removeAndDeallocate();
+ dispatchFunc->replaceUsesWith(newDispatchFunc);
+ dispatchFunc->removeAndDeallocate();
+
+ return newDispatchFunc;
}
// Ensures every witness table object has been assigned a sequential ID.
@@ -179,6 +215,40 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex
}
}
+// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with
+// its sequential ID.
+void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc)
+{
+ List<IRInst*> users;
+ for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse)
+ {
+ users.add(use->getUser());
+ }
+ for (auto user : users)
+ {
+ if (auto call = as<IRCall>(user))
+ {
+ if (call->getCallee() != newDispatchFunc)
+ continue;
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder.setInsertBefore(call);
+ List<IRInst*> args;
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ args.add(call->getArg(i));
+ }
+ if (as<IRWitnessTable>(args[0]->getDataType()))
+ continue;
+ auto seqIdArg = builder.emitGetSequentialIDInst(args[0]);
+ args[0] = seqIdArg;
+ auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args);
+ call->replaceUsesWith(newCall);
+ call->removeAndDeallocate();
+ }
+ }
+}
+
void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
{
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
@@ -186,10 +256,18 @@ void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
// First we ensure that all witness table objects has a sequential ID assigned.
ensureWitnessTableSequentialIDs(sharedContext);
+ // Generate specialized dispatch functions and fixup call sites.
for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
{
auto dispatchFunc = kv.Value;
- specializeDispatchFunction(sharedContext, dispatchFunc);
+
+ // Generate a specialized `switch` statement based dispatch func,
+ // from the witness tables present in the module.
+ auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc);
+
+ // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of
+ // witness table objects.
+ fixupDispatchFuncCall(sharedContext, newDispatchFunc);
}
}
} // namespace Slang