diff options
| author | Yong He <yonghe@outlook.com> | 2020-10-29 10:21:07 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-10-29 10:21:07 -0700 |
| commit | 060071604bc715951ddf940a51ced1da48b3dd10 (patch) | |
| tree | 19daa4c23bdc5098e8bf5c1e28d5dbe1a389eca3 | |
| parent | 494e09af2cebafa34db49dc1f60afd43aebed619 (diff) | |
Generate `switch` based dynamic dispatch logic. (#1591)
Co-authored-by: Tim Foley <tim.foley.is@gmail.com>
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-dispatch.cpp | 138 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 4 |
7 files changed, 145 insertions, 53 deletions
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 6a82815a9..d1250357a 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1651,35 +1651,27 @@ void CPPSourceEmitter::_emitWitnessTableDefinitions() m_writer->emit(getName(witnessTable)); m_writer->emit(" = {\n"); m_writer->indent(); - bool isFirstEntry = true; + auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>(); + SLANG_ASSERT(seqIdDecoration); + m_writer->emit((UInt)seqIdDecoration->getSequentialID()); for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++) { auto entry = sortedWitnessTableEntries[i]; if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get())) { - if (!isFirstEntry) - m_writer->emit(",\n"); - else - isFirstEntry = false; - + m_writer->emit(",\n"); m_writer->emit(getName(funcVal)); } else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal())) { - if (!isFirstEntry) - m_writer->emit(",\n"); - else - isFirstEntry = false; + m_writer->emit(",\n"); m_writer->emit("&"); m_writer->emit(getName(witnessTableVal)); } else if (entry->getSatisfyingVal() && isPointerOfType(entry->getSatisfyingVal()->getDataType(), kIROp_RTTIType)) { - if (!isFirstEntry) - m_writer->emit(",\n"); - else - isFirstEntry = false; + m_writer->emit(",\n"); emitInstExpr(entry->getSatisfyingVal(), getInfo(EmitOp::General)); } else @@ -1704,6 +1696,7 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType) emitSimpleType(interfaceType); m_writer->emit("\n{\n"); m_writer->indent(); + m_writer->emit("uint32_t sequentialID;\n"); for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); @@ -2196,6 +2189,12 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut m_writer->emit(getName(inst->getOperand(1))); return true; } + case kIROp_GetSequentialID: + { + emitInstExpr(inst->getOperand(0), inOuterPrec); + m_writer->emit("->sequentialID"); + return true; + } case kIROp_WitnessTable: { m_writer->emit("(&"); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index b225cbf0b..3681077cb 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -229,6 +229,7 @@ INST(DefaultConstruct, defaultConstruct, 0, 0) INST(Specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) +INST(GetSequentialID, GetSequentialID, 1, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index aedcea232..d3f94abed 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -497,6 +497,14 @@ struct IRLookupWitnessMethod : IRInst IR_LEAF_ISA(lookup_interface_method) }; +// Returns the sequential ID of an RTTI object. +struct IRGetSequentialID : IRInst +{ + IR_LEAF_ISA(GetSequentialID) + + IRInst* getRTTIOperand() { return getOperand(0); } +}; + struct IRLookupWitnessTable : IRInst { IRUse sourceType; @@ -1916,6 +1924,8 @@ struct IRBuilder IRInst* witnessTableVal, IRInst* interfaceMethodVal); + IRInst* emitGetSequentialIDInst(IRInst* rttiObj); + IRInst* emitAlloca(IRInst* type, IRInst* rttiObjPtr); IRInst* emitCopy(IRInst* dst, IRInst* src, IRInst* rttiObjPtr); diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 4b86cff51..11bb400b0 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -59,13 +59,9 @@ namespace Slang if (sink->getErrorCount() != 0) return; - // On non-CPU targets, generate `if` based dispatch functions. - if (sharedContext.targetReq->getTarget() != CodeGenTarget::CPPSource) - { - specializeDispatchFunctions(&sharedContext); - if (sink->getErrorCount() != 0) - return; - } + specializeDispatchFunctions(&sharedContext); + if (sink->getErrorCount() != 0) + return; // We might have generated new temporary variables during lowering. // An SSA pass can clean up unnecessary load/stores. diff --git a/source/slang/slang-ir-specialize-dispatch.cpp b/source/slang/slang-ir-specialize-dispatch.cpp index 05ed867bd..ddbb743a8 100644 --- a/source/slang/slang-ir-specialize-dispatch.cpp +++ b/source/slang/slang-ir-specialize-dispatch.cpp @@ -16,7 +16,7 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key) return nullptr; } -void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) +IRFunc* specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IRFunc* dispatchFunc) { auto witnessTableType = cast<IRFuncType>(dispatchFunc->getDataType())->getParamType(0); @@ -62,35 +62,63 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = &sharedContext->sharedBuilderStorage; - builder->setInsertBefore(callInst); + builder->setInsertBefore(dispatchFunc); + + // Create a new dispatch func to replace the existing one. + auto newDispatchFunc = builder->createFunc(); + + List<IRType*> paramTypes; + for (auto paramInst : dispatchFunc->getParams()) + { + paramTypes.add(paramInst->getFullType()); + } + + // Modify the first paramter from IRWitnessTable to UInt representing the sequential ID. + paramTypes[0] = builder->getUIntType(); + + auto newDipsatchFuncType = builder->getFuncType(paramTypes, dispatchFunc->getResultType()); + newDispatchFunc->setFullType(newDipsatchFuncType); + dispatchFunc->transferDecorationsTo(newDispatchFunc); + + builder->setInsertInto(newDispatchFunc); + auto newBlock = builder->emitBlock(); + + IRBlock* defaultBlock = nullptr; - auto witnessTableParam = block->getFirstParam(); auto requirementKey = lookupInst->getRequirementKey(); List<IRInst*> params; - for (auto param = block->getFirstParam()->getNextParam(); param; param = param->getNextParam()) + for (Index i = 0; i < paramTypes.getCount(); i++) { - params.add(param); + auto param = builder->emitParam(paramTypes[i]); + if (i > 0) + params.add(param); } + auto witnessTableParam = newBlock->getFirstParam(); - // Emit cascaded if statements to call the correct concrete function based on - // the witness table pointer passed in. - auto ifBlock = block; + // Generate case blocks for each possible witness table. + List<IRInst*> caseBlocks; for (Index i = 0; i < witnessTables.getCount(); i++) { auto witnessTable = witnessTables[i]; - bool isLast = (i == witnessTables.getCount() - 1); - IRInst* cmpArgs[] = + auto seqIdDecoration = witnessTable->findDecoration<IRSequentialIDDecoration>(); + SLANG_ASSERT(seqIdDecoration); + + if (i != witnessTables.getCount() - 1) { - builder->emitBitCast(builder->getUInt64Type(), witnessTableParam), - builder->emitBitCast(builder->getUInt64Type(),(IRInst*)witnessTable) - }; - IRInst* condition = nullptr; - IRBlock* trueBlock = nullptr; - if (!isLast) + // Create a case block if we are not the last case. + caseBlocks.add(seqIdDecoration->getSequentialIDOperand()); + builder->setInsertInto(newDispatchFunc); + auto caseBlock = builder->emitBlock(); + caseBlocks.add(caseBlock); + } + else { - condition = builder->emitIntrinsicInst(builder->getBoolType(), kIROp_Eql, 2, cmpArgs); - trueBlock = builder->emitBlock(); + // Generate code for the last possible value in the `default` block. + builder->setInsertInto(newDispatchFunc); + defaultBlock = builder->emitBlock(); + builder->setInsertInto(defaultBlock); } + auto callee = findWitnessTableEntry(witnessTable, requirementKey); SLANG_ASSERT(callee); auto specializedCallInst = builder->emitCallInst(callInst->getFullType(), callee, params); @@ -98,20 +126,28 @@ void specializeDispatchFunction(SharedGenericsLoweringContext* sharedContext, IR builder->emitReturn(); else builder->emitReturn(specializedCallInst); - if (!isLast) - { - auto falseBlock = builder->emitBlock(); - builder->setInsertInto(ifBlock); - builder->emitIf(condition, trueBlock, falseBlock); - builder->setInsertInto(falseBlock); - ifBlock = falseBlock; - } } + // Emit a switch statement to call the correct concrete function based on + // the witness table sequential ID passed in. + builder->setInsertInto(newDispatchFunc); + auto breakBlock = builder->emitBlock(); + builder->setInsertInto(breakBlock); + builder->emitUnreachable(); + + builder->setInsertInto(newBlock); + builder->emitSwitch( + witnessTableParam, + breakBlock, + defaultBlock, + caseBlocks.getCount(), + caseBlocks.getBuffer()); + // Remove old implementation. - lookupInst->removeAndDeallocate(); - callInst->removeAndDeallocate(); - returnInst->removeAndDeallocate(); + dispatchFunc->replaceUsesWith(newDispatchFunc); + dispatchFunc->removeAndDeallocate(); + + return newDispatchFunc; } // Ensures every witness table object has been assigned a sequential ID. @@ -179,6 +215,40 @@ void ensureWitnessTableSequentialIDs(SharedGenericsLoweringContext* sharedContex } } +// Fixes up call sites of a dispatch function, so that the witness table argument is replaced with +// its sequential ID. +void fixupDispatchFuncCall(SharedGenericsLoweringContext* sharedContext, IRFunc* newDispatchFunc) +{ + List<IRInst*> users; + for (auto use = newDispatchFunc->firstUse; use; use = use->nextUse) + { + users.add(use->getUser()); + } + for (auto user : users) + { + if (auto call = as<IRCall>(user)) + { + if (call->getCallee() != newDispatchFunc) + continue; + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(call); + List<IRInst*> args; + for (UInt i = 0; i < call->getArgCount(); i++) + { + args.add(call->getArg(i)); + } + if (as<IRWitnessTable>(args[0]->getDataType())) + continue; + auto seqIdArg = builder.emitGetSequentialIDInst(args[0]); + args[0] = seqIdArg; + auto newCall = builder.emitCallInst(call->getFullType(), newDispatchFunc, args); + call->replaceUsesWith(newCall); + call->removeAndDeallocate(); + } + } +} + void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) { sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); @@ -186,10 +256,18 @@ void specializeDispatchFunctions(SharedGenericsLoweringContext* sharedContext) // First we ensure that all witness table objects has a sequential ID assigned. ensureWitnessTableSequentialIDs(sharedContext); + // Generate specialized dispatch functions and fixup call sites. for (auto kv : sharedContext->mapInterfaceRequirementKeyToDispatchMethods) { auto dispatchFunc = kv.Value; - specializeDispatchFunction(sharedContext, dispatchFunc); + + // Generate a specialized `switch` statement based dispatch func, + // from the witness tables present in the module. + auto newDispatchFunc = specializeDispatchFunction(sharedContext, dispatchFunc); + + // Fix up the call sites of newDispatchFunc to pass in sequential IDs instead of + // witness table objects. + fixupDispatchFuncCall(sharedContext, newDispatchFunc); } } } // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 778f066dd..dc106fff1 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2642,6 +2642,14 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj) + { + auto inst = createInst<IRAlloca>(this, kIROp_GetSequentialID, getUIntType(), rttiObj); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitAlloca(IRInst* type, IRInst* rttiObjPtr) { auto inst = createInst<IRAlloca>( diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index d5bd55442..3b79d9d14 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -73,8 +73,8 @@ enum IROp : int32_t /* IROpMeta describe values for layout of IROp, as well as values for accessing aspects of IROp bits. */ enum IROpMeta { - kIROpMeta_OtherShift = 8, ///< Number of bits for op (shift right by this to get the other bits) - kIROpMeta_OpMask = 0xff, ///< Mask for just opcode + kIROpMeta_OtherShift = 10, ///< Number of bits for op (shift right by this to get the other bits) + kIROpMeta_OpMask = 0x3ff, ///< Mask for just opcode }; IROp findIROp(const UnownedStringSlice& name); |
