summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-specialize-dispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-specialize-dispatch.cpp')
-rw-r--r--source/slang/slang-ir-specialize-dispatch.cpp48
1 files changed, 33 insertions, 15 deletions
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp
index ebf3f1909..fc8f384ec 100644
--- a/source/slang/slang-ir-specialize-dispatch.cpp
+++ b/source/slang/slang-ir-specialize-dispatch.cpp
@@ -9,10 +9,9 @@ namespace Slang
IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc)
{
auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0);
-
+ auto conformanceType = cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType();
// Collect all witness tables of `witnessTableType` in current module.
- List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(
- cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType());
+ List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(conformanceType);
SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock());
auto block = dispatchFunc->getFirstBlock();
@@ -57,8 +56,8 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
paramTypes.add(paramInst->getFullType());
}
- // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID.
- paramTypes[0] = builder->getUIntType();
+ // Modify the first paramter from IRWitnessTable to IRWitnessTableID representing the sequential ID.
+ paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType);
auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType());
newDispatchFunc->setFullType(newDipsatchFuncType);
@@ -79,6 +78,15 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
}
auto witnessTableParam = newBlock->getFirstParam();
+ // `witnessTableParam` is expected to have `IRWitnessTableID` type, which
+ // will later lower into a `uint2`. We only use the first element of the uint2
+ // to store the sequential ID and reserve the second 32-bit value for future
+ // pointer-compatibility. We insert a member extract inst right now
+ // to obtain the first element and use it in our switch statement.
+ UInt elemIdx = 0;
+ auto witnessTableSequentialID =
+ builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx);
+
// Generate case blocks for each possible witness table.
List<IRInst*> caseBlocks;
for (Index i = 0; i < witnessTables.getCount(); i++)
@@ -115,18 +123,28 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext,
// Emit a switch statement to call the correct concrete function based on
// the witness table sequential ID passed in.
builder->setInsertInto(newDispatchFunc);
- auto breakBlock = builder->emitBlock();
- builder->setInsertInto(breakBlock);
- builder->emitUnreachable();
- builder->setInsertInto(newBlock);
- builder->emitSwitch(
- witnessTableParam,
- breakBlock,
- defaultBlock,
- caseBlocks.getCount(),
- caseBlocks.getBuffer());
+
+ if (witnessTables.getCount() == 1)
+ {
+ // If there is only 1 case, no switch statement is necessary.
+ builder->setInsertInto(newBlock);
+ builder->emitBranch(defaultBlock);
+ }
+ else
+ {
+ auto breakBlock = builder->emitBlock();
+ builder->setInsertInto(breakBlock);
+ builder->emitUnreachable();
+ builder->setInsertInto(newBlock);
+ builder->emitSwitch(
+ witnessTableSequentialID,
+ breakBlock,
+ defaultBlock,
+ caseBlocks.getCount(),
+ caseBlocks.getBuffer());
+ }
// Remove old implementation.
dispatchFunc->replaceUsesWith(newDispatchFunc);
dispatchFunc->removeAndDeallocate();