diff options
Diffstat (limited to 'source/slang/slang-ir-specialize-dispatch.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize-dispatch.cpp | 48 |
1 files changed, 33 insertions, 15 deletions
diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index ebf3f1909..fc8f384ec 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -9,10 +9,9 @@ namespace Slang IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) { auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); - + auto conformanceType = cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType(); // Collect all witness tables of `witnessTableType` in current module. - List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType( - cast<IRWitnessTableTypeBase>(witnessTableType)->getConformanceType()); + List<IRWitnessTable*> witnessTables = sharedContext->getWitnessTablesFromInterfaceType(conformanceType); SLANG_ASSERT(dispatchFunc->getFirstBlock() == dispatchFunc->getLastBlock()); auto block = dispatchFunc->getFirstBlock(); @@ -57,8 +56,8 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, paramTypes.add(paramInst->getFullType()); } - // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID. - paramTypes[0] = builder->getUIntType(); + // Modify the first paramter from IRWitnessTable to IRWitnessTableID representing the sequential ID. + paramTypes[0] = builder->getWitnessTableIDType((IRType*)conformanceType); auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); newDispatchFunc->setFullType(newDipsatchFuncType); @@ -79,6 +78,15 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, } auto witnessTableParam = newBlock->getFirstParam(); + // `witnessTableParam` is expected to have `IRWitnessTableID` type, which + // will later lower into a `uint2`. We only use the first element of the uint2 + // to store the sequential ID and reserve the second 32-bit value for future + // pointer-compatibility. We insert a member extract inst right now + // to obtain the first element and use it in our switch statement. + UInt elemIdx = 0; + auto witnessTableSequentialID = + builder->emitSwizzle(builder->getUIntType(), witnessTableParam, 1, &elemIdx); + // Generate case blocks for each possible witness table. List<IRInst*> caseBlocks; for (Index i = 0; i < witnessTables.getCount(); i++) @@ -115,18 +123,28 @@ IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, // Emit a switch statement to call the correct concrete function based on // the witness table sequential ID passed in. builder->setInsertInto(newDispatchFunc); - auto breakBlock = builder->emitBlock(); - builder->setInsertInto(breakBlock); - builder->emitUnreachable(); - builder->setInsertInto(newBlock); - builder->emitSwitch( - witnessTableParam, - breakBlock, - defaultBlock, - caseBlocks.getCount(), - caseBlocks.getBuffer()); + + if (witnessTables.getCount() == 1) + { + // If there is only 1 case, no switch statement is necessary. + builder->setInsertInto(newBlock); + builder->emitBranch(defaultBlock); + } + else + { + auto breakBlock = builder->emitBlock(); + builder->setInsertInto(breakBlock); + builder->emitUnreachable(); + builder->setInsertInto(newBlock); + builder->emitSwitch( + witnessTableSequentialID, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + } // Remove old implementation. dispatchFunc->replaceUsesWith(newDispatchFunc); dispatchFunc->removeAndDeallocate(); |
