diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-lower-generic-call.cpp | 107 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 3 |
3 files changed, 116 insertions, 5 deletions
diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp index e3080c612..577b4e86d 100644 --- a/source/slang/slang-ir-lower-generic-call.cpp +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -8,6 +8,9 @@ namespace Slang { SharedGenericsLoweringContext* sharedContext; + // Map from interface requirement keys to its corresponding dispatch method. + OrderedDictionary<IRInst*, IRFunc*> mapInterfaceRequirementKeyToDispatchMethods; + // Represents a work item for unpacking `inout` or `out` arguments after a generic call. struct ArgumentUnpackWorkItem { @@ -85,6 +88,67 @@ namespace Slang return value; } + // Create a dispatch function for a interface method. + // On CPU, the dispatch function is implemented as a witness table lookup followed by + // a function-pointer call. + // TODO: On GPU targets, we should implement the dispatch function with a `switch` statement + // based on the type ID. + IRFunc* _createInterfaceDispatchMethod( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRInst* requirementKey, + IRInst* requirementVal) + { + auto func = builder->createFunc(); + if (auto linkage = requirementKey->findDecoration<IRLinkageDecoration>()) + { + builder->addNameHintDecoration(func, linkage->getMangledName()); + } + + auto reqFuncType = cast<IRFuncType>(requirementVal); + List<IRType*> paramTypes; + paramTypes.add(builder->getWitnessTableType(interfaceType)); + for (UInt i = 0; i < reqFuncType->getParamCount(); i++) + { + paramTypes.add(reqFuncType->getParamType(i)); + } + auto dispatchFuncType = builder->getFuncType(paramTypes, reqFuncType->getResultType()); + func->setFullType(dispatchFuncType); + builder->setInsertInto(func); + builder->emitBlock(); + List<IRInst*> params; + IRParam* witnessTableParam = builder->emitParam(paramTypes[0]); + for (Index i = 1; i < paramTypes.getCount(); i++) + { + params.add(builder->emitParam(paramTypes[i])); + } + auto callee = builder->emitLookupInterfaceMethodInst( + reqFuncType, witnessTableParam, requirementKey); + auto call = (IRCall*)builder->emitCallInst(reqFuncType->getResultType(), callee, params); + if (call->getDataType()->op == kIROp_VoidType) + builder->emitReturn(); + else + builder->emitReturn(call); + return func; + } + + // If an interface dispatch method is already created, return it. + // Otherwise, create the method. + IRFunc* getOrCreateInterfaceDispatchMethod( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRInst* requirementKey, + IRInst* requirementVal) + { + if (auto func = mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey)) + return *func; + auto dispatchFunc = + _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal); + mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists( + requirementKey, dispatchFunc); + return dispatchFunc; + } + // Translate `callInst` into a call of `newCallee`, and respect the new `funcType`. // If `newCallee` is a lowered generic function, `specializeInst` contains the type // arguments used to specialize the callee. @@ -213,12 +277,45 @@ namespace Slang { // If we see a call(lookup_interface_method(...), ...), we need to translate // all occurences of associatedtypes. - auto funcType = cast<IRFuncType>(lookupInst->getDataType()); - auto loweredFunc = lookupInst; - if (isBuiltin(cast<IRWitnessTableType>( - lookupInst->getWitnessTable()->getDataType())->getConformanceType())) + auto interfaceType = cast<IRInterfaceType>( + cast<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType()) + ->getConformanceType()); + if (isBuiltin(interfaceType)) return; - translateCallInst(callInst, funcType, loweredFunc, nullptr); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(callInst); + + // Create interface dispatch method that bottlenecks the dispatch logic. + auto requirementKey = lookupInst->getRequirementKey(); + auto requirementVal = + sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey); + auto dispatchFunc = getOrCreateInterfaceDispatchMethod( + builder, interfaceType, requirementKey, requirementVal); + + auto parentFunc = getParentFunc(callInst); + // Don't process the call inst that is the one in the dispatch function itself. + if (parentFunc == dispatchFunc) + return; + + // Replace `callInst` with a new call inst that calls `dispatchFunc` instead, and + // with the witness table as first argument, + builder->setInsertBefore(callInst); + List<IRInst*> newArgs; + newArgs.add(lookupInst->getWitnessTable()); + for (UInt i = 0; i < callInst->getArgCount(); i++) + newArgs.add(callInst->getArg(i)); + auto newCall = + (IRCall*)builder->emitCallInst(callInst->getFullType(), dispatchFunc, newArgs); + callInst->replaceUsesWith(newCall); + callInst->removeAndDeallocate(); + + // Translate the new call inst as normal, taking care of packing/unpacking inputs + // and outputs. + translateCallInst( + newCall, cast<IRFuncType>(dispatchFunc->getFullType()), dispatchFunc, nullptr); } void lowerCall(IRCall* callInst) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6f0cc43e2..778f066dd 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5559,5 +5559,16 @@ namespace Slang { return inst->findDecoration<IRBuiltinDecoration>() != nullptr; } + IRFunc* getParentFunc(IRInst* inst) + { + auto parent = inst->getParent(); + while (parent) + { + if (auto func = as<IRFunc>(parent)) + return func; + parent = parent->getParent(); + } + return nullptr; + } } // namespace Slang diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 03d29280a..d5bd55442 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1488,6 +1488,9 @@ bool isPointerOfType(IRInst* ptrType, IROp opCode); // True if the IR inst represents a builtin object (e.g. __BuiltinFloatingPointType). bool isBuiltin(IRInst* inst); + // Get the enclosuing function of an instruction. +IRFunc* getParentFunc(IRInst* inst); + } #endif |
