summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2020-10-29 10:21:07 -0700
committerGitHub <noreply@github.com>2020-10-29 10:21:07 -0700
commit060071604bc715951ddf940a51ced1da48b3dd10 (patch)
tree19daa4c23bdc5098e8bf5c1e28d5dbe1a389eca3
parent494e09af2cebafa34db49dc1f60afd43aebed619 (diff)
Generate `switch` based dynamic dispatch logic. (#1591)
Co-authored-by: Tim Foley <tim.foley.is@gmail.com>
-rw-r--r--source/slang/slang-emit-cpp.cpp27
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h10
-rw-r--r--source/slang/slang-ir-lower-generics.cpp10
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp138
-rw-r--r--source/slang/slang-ir.cpp8
-rw-r--r--source/slang/slang-ir.h4
7 files changed, 145 insertions, 53 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 6a82815a9..d1250357a 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -1651,35 +1651,27 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions()
m_writer->emit(getName(witnessTable));
m_writer->emit(" = {\n");
m_writer->indent();
- bool isFirstEntry = true;
+ auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>();
+ SLANG_ASSERT(seqIdDecoration);
+ m_writer->emit((UInt)seqIdDecoration->getSequentialID());
for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++)
{
auto entry = sortedWitnessTableEntries[i];
if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get()))
{
- if (!isFirstEntry)
- m_writer->emit(",\n");
- else
- isFirstEntry = false;
-
+ m_writer->emit(",\n");
m_writer->emit(getName(funcVal));
}
else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal()))
{
- if (!isFirstEntry)
- m_writer->emit(",\n");
- else
- isFirstEntry = false;
+ m_writer->emit(",\n");
m_writer->emit("&");
m_writer->emit(getName(witnessTableVal));
}
else if (entry->getSatisfyingVal() &&
isPointerOfType(entry->getSatisfyingVal()->getDataType(), kIROp_RTTIType))
{
- if (!isFirstEntry)
- m_writer->emit(",\n");
- else
- isFirstEntry = false;
+ m_writer->emit(",\n");
emitInstExpr(entry->getSatisfyingVal(), getInfo(EmitOp::General));
}
else
@@ -1704,6 +1696,7 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType)
emitSimpleType(interfaceType);
m_writer->emit("\n{\n");
m_writer->indent();
+ m_writer->emit("uint32_t sequentialID;\n");
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
{
auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
@@ -2196,6 +2189,12 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
m_writer->emit(getName(inst->getOperand(1)));
return true;
}
+ case kIROp_GetSequentialID:
+ {
+ emitInstExpr(inst->getOperand(0), inOuterPrec);
+ m_writer->emit("->sequentialID");
+ return true;
+ }
case kIROp_WitnessTable:
{
m_writer->emit("(&");
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index b225cbf0b..3681077cb 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -229,6 +229,7 @@ INST(DefaultConstruct, defaultConstruct, 0, 0)
INST(Specialize, specialize, 2, 0)
INST(lookup_interface_method, lookup_interface_method, 2, 0)
+INST(GetSequentialID, GetSequentialID, 1, 0)
INST(lookup_witness_table, lookup_witness_table, 2, 0)
INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index aedcea232..d3f94abed 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -497,6 +497,14 @@ struct IRLookupWitnessMethod : IRInst
IR_LEAF_ISA(lookup_interface_method)
};
+// Returns the sequential ID of an RTTI object.
+struct IRGetSequentialID : IRInst
+{
+ IR_LEAF_ISA(GetSequentialID)
+
+ IRInst* getRTTIOperand() { return getOperand(0); }
+};
+
struct IRLookupWitnessTable : IRInst
{
IRUse sourceType;
@@ -1916,6 +1924,8 @@ struct IRBuilder
IRInst* witnessTableVal,
IRInst* interfaceMethodVal);
+ IRInst* emitGetSequentialIDInst(IRInst* rttiObj);
+
IRInst* emitAlloca(IRInst* type, IRInst* rttiObjPtr);
IRInst* emitCopy(IRInst* dst, IRInst* src, IRInst* rttiObjPtr);
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 4b86cff51..11bb400b0 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -59,13 +59,9 @@ namespace Slang
if (sink->getErrorCount() != 0)
return;
- // On non-CPU targets, generate `if` based dispatch functions.
- if (sharedContext.targetReq->getTarget() != CodeGenTarget::CPPSource)
- {
- specializeDispatchFunctions(&sharedContext);
- if (sink->getErrorCount() != 0)
- return;
- }
+ specializeDispatchFunctions(&sharedContext);
+ if (sink->getErrorCount() != 0)
+ return;
// We might have generated new temporary variables during lowering.
// An SSA pass can clean up unnecessary load/stores.
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index 05ed867bd..ddbb743a8 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -16,7 +16,7 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
return nullptr;
}
-void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
+IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
@@ -62,35 +62,63 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR
IRBuilder builderStorage;
auto builder = &builderStorage;
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
- builder->setInsertBefore(callInst);
+ builder->setInsertBefore(dispatchFunc);
+
+ // Create a new dispatch func to replace the existing one.
+ auto newDispatchFunc = builder->createFunc();
+
+ List<IRType*> paramTypes;
+ for (auto paramInst : dispatchFunc->getParams())
+ {
+ paramTypes.add(paramInst->getFullType());
+ }
+
+ // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID.
+ paramTypes[0] = builder->getUIntType();
+
+ auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType());
+ newDispatchFunc->setFullType(newDipsatchFuncType);
+ dispatchFunc->transferDecorationsTo(newDispatchFunc);
+
+ builder->setInsertInto(newDispatchFunc);
+ auto newBlock = builder->emitBlock();
+
+ IRBlock* defaultBlock = nullptr;
- auto witnessTableParam = block->getFirstParam();
auto requirementKey = lookupInst->getRequirementKey();
List<IRInst*> params;
- for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam())
+ for (Index i = 0; i < paramTypes.getCount(); i++)
{
- params.add(param);
+ auto param = builder->emitParam(paramTypes[i]);
+ if (i > 0)
+ params.add(param);
}
+ auto witnessTableParam = newBlock->getFirstParam();
- // Emit cascaded if statements to call the correct concrete function based on
- // the witness table pointer passed in.
- auto ifBlock = block;
+ // Generate case blocks for each possible witness table.
+ List<IRInst*> caseBlocks;
for (Index i = 0; i < witnessTables.getCount(); i++)
{
auto witnessTable = witnessTables[i];
- bool isLast = (i == witnessTables.getCount() - 1);
- IRInst* cmpArgs[] =
+ auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>();
+ SLANG_ASSERT(seqIdDecoration);
+
+ if (i != witnessTables.getCount() - 1)
{
- builder->emitBitCast(builder->getUInt64Type(), witnessTableParam),
- builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable)
- };
- IRInst* condition = nullptr;
- IRBlock* trueBlock = nullptr;
- if (!isLast)
+ // Create a case block if we are not the last case.
+ caseBlocks.add(seqIdDecoration->getSequentialIDOperand());
+ builder->setInsertInto(newDispatchFunc);
+ auto caseBlock = builder->emitBlock();
+ caseBlocks.add(caseBlock);
+ }
+ else
{
- condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs);
- trueBlock = builder->emitBlock();
+ // Generate code for the last possible value in the `default` block.
+ builder->setInsertInto(newDispatchFunc);
+ defaultBlock = builder->emitBlock();
+ builder->setInsertInto(defaultBlock);
}
+
auto callee = findWitnessTableEntry(witnessTable, requirementKey);
SLANG_ASSERT(callee);
auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params);
@@ -98,20 +126,28 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR
builder->emitReturn();
else
builder->emitReturn(specializedCallInst);
- if (!isLast)
- {
- auto falseBlock = builder->emitBlock();
- builder->setInsertInto(ifBlock);
- builder->emitIf(condition, trueBlock, falseBlock);
- builder->setInsertInto(falseBlock);
- ifBlock = falseBlock;
- }
}
+ // Emit a switch statement to call the correct concrete function based on
+ // the witness table sequential ID passed in.
+ builder->setInsertInto(newDispatchFunc);
+ auto breakBlock = builder->emitBlock();
+ builder->setInsertInto(breakBlock);
+ builder->emitUnreachable();
+
+ builder->setInsertInto(newBlock);
+ builder->emitSwitch(
+ witnessTableParam,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+
// Remove old implementation.
- lookupInst->removeAndDeallocate();
- callInst->removeAndDeallocate();
- returnInst->removeAndDeallocate();
+ dispatchFunc->replaceUsesWith(newDispatchFunc);
+ dispatchFunc->removeAndDeallocate();
+
+ return newDispatchFunc;
}
// Ensures every witness table object has been assigned a sequential ID.
@@ -179,6 +215,40 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex
}
}
+// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with
+// its sequential ID.
+void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc)
+{
+ List<IRInst*> users;
+ for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse)
+ {
+ users.add(use->getUser());
+ }
+ for (auto user : users)
+ {
+ if (auto call = as<IRCall>(user))
+ {
+ if (call->getCallee() != newDispatchFunc)
+ continue;
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder.setInsertBefore(call);
+ List<IRInst*> args;
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ args.add(call->getArg(i));
+ }
+ if (as<IRWitnessTable>(args[0]->getDataType()))
+ continue;
+ auto seqIdArg = builder.emitGetSequentialIDInst(args[0]);
+ args[0] = seqIdArg;
+ auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args);
+ call->replaceUsesWith(newCall);
+ call->removeAndDeallocate();
+ }
+ }
+}
+
void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
{
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
@@ -186,10 +256,18 @@ void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext)
// First we ensure that all witness table objects has a sequential ID assigned.
ensureWitnessTableSequentialIDs(sharedContext);
+ // Generate specialized dispatch functions and fixup call sites.
for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods)
{
auto dispatchFunc = kv.Value;
- specializeDispatchFunction(sharedContext, dispatchFunc);
+
+ // Generate a specialized `switch` statement based dispatch func,
+ // from the witness tables present in the module.
+ auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc);
+
+ // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of
+ // witness table objects.
+ fixupDispatchFuncCall(sharedContext, newDispatchFunc);
}
}
} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 778f066dd..dc106fff1 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2642,6 +2642,14 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj)
+ {
+ auto inst = createInst<IRAlloca>(this, kIROp_GetSequentialID, getUIntType(), rttiObj);
+
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitAlloca(IRInst* type, IRInst* rttiObjPtr)
{
auto inst = createInst<IRAlloca>(
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index d5bd55442..3b79d9d14 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -73,8 +73,8 @@ enum IROp : int32_t
/* IROpMeta describe values for layout of IROp, as well as values for accessing aspects of IROp bits. */
enum IROpMeta
{
- kIROpMeta_OtherShift = 8, ///< Number of bits for op (shift right by this to get the other bits)
- kIROpMeta_OpMask = 0xff, ///< Mask for just opcode
+ kIROpMeta_OtherShift = 10, ///< Number of bits for op (shift right by this to get the other bits)
+ kIROpMeta_OpMask = 0x3ff, ///< Mask for just opcode
};
IROp findIROp(const UnownedStringSlice& name);