summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-lower-generic-call.cpp107
-rw-r--r--source/slang/slang-ir.cpp11
-rw-r--r--source/slang/slang-ir.h3
3 files changed, 116 insertions, 5 deletions
diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp
index e3080c612..577b4e86d 100644
--- a/source/slang/slang-ir-lower-generic-call.cpp
+++ b/source/slang/slang-ir-lower-generic-call.cpp
@@ -8,6 +8,9 @@ namespace Slang
{
SharedGenericsLoweringContext* sharedContext;
+ // Map from interface requirement keys to its corresponding dispatch method.
+ OrderedDictionary<IRInst*, IRFunc*> mapInterfaceRequirementKeyToDispatchMethods;
+
// Represents a work item for unpacking `inout` or `out` arguments after a generic call.
struct ArgumentUnpackWorkItem
{
@@ -85,6 +88,67 @@ namespace Slang
return value;
}
+ // Create a dispatch function for a interface method.
+ // On CPU, the dispatch function is implemented as a witness table lookup followed by
+ // a function-pointer call.
+ // TODO: On GPU targets, we should implement the dispatch function with a `switch` statement
+ // based on the type ID.
+ IRFunc* _createInterfaceDispatchMethod(
+ IRBuilder* builder,
+ IRInterfaceType* interfaceType,
+ IRInst* requirementKey,
+ IRInst* requirementVal)
+ {
+ auto func = builder->createFunc();
+ if (auto linkage = requirementKey->findDecoration<IRLinkageDecoration>())
+ {
+ builder->addNameHintDecoration(func, linkage->getMangledName());
+ }
+
+ auto reqFuncType = cast<IRFuncType>(requirementVal);
+ List<IRType*> paramTypes;
+ paramTypes.add(builder->getWitnessTableType(interfaceType));
+ for (UInt i = 0; i < reqFuncType->getParamCount(); i++)
+ {
+ paramTypes.add(reqFuncType->getParamType(i));
+ }
+ auto dispatchFuncType = builder->getFuncType(paramTypes, reqFuncType->getResultType());
+ func->setFullType(dispatchFuncType);
+ builder->setInsertInto(func);
+ builder->emitBlock();
+ List<IRInst*> params;
+ IRParam* witnessTableParam = builder->emitParam(paramTypes[0]);
+ for (Index i = 1; i < paramTypes.getCount(); i++)
+ {
+ params.add(builder->emitParam(paramTypes[i]));
+ }
+ auto callee = builder->emitLookupInterfaceMethodInst(
+ reqFuncType, witnessTableParam, requirementKey);
+ auto call = (IRCall*)builder->emitCallInst(reqFuncType->getResultType(), callee, params);
+ if (call->getDataType()->op == kIROp_VoidType)
+ builder->emitReturn();
+ else
+ builder->emitReturn(call);
+ return func;
+ }
+
+ // If an interface dispatch method is already created, return it.
+ // Otherwise, create the method.
+ IRFunc* getOrCreateInterfaceDispatchMethod(
+ IRBuilder* builder,
+ IRInterfaceType* interfaceType,
+ IRInst* requirementKey,
+ IRInst* requirementVal)
+ {
+ if (auto func = mapInterfaceRequirementKeyToDispatchMethods.TryGetValue(requirementKey))
+ return *func;
+ auto dispatchFunc =
+ _createInterfaceDispatchMethod(builder, interfaceType, requirementKey, requirementVal);
+ mapInterfaceRequirementKeyToDispatchMethods.AddIfNotExists(
+ requirementKey, dispatchFunc);
+ return dispatchFunc;
+ }
+
// Translate `callInst` into a call of `newCallee`, and respect the new `funcType`.
// If `newCallee` is a lowered generic function, `specializeInst` contains the type
// arguments used to specialize the callee.
@@ -213,12 +277,45 @@ namespace Slang
{
// If we see a call(lookup_interface_method(...), ...), we need to translate
// all occurences of associatedtypes.
- auto funcType = cast<IRFuncType>(lookupInst->getDataType());
- auto loweredFunc = lookupInst;
- if (isBuiltin(cast<IRWitnessTableType>(
- lookupInst->getWitnessTable()->getDataType())->getConformanceType()))
+ auto interfaceType = cast<IRInterfaceType>(
+ cast<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType())
+ ->getConformanceType());
+ if (isBuiltin(interfaceType))
return;
- translateCallInst(callInst, funcType, loweredFunc, nullptr);
+
+ IRBuilder builderStorage;
+ auto builder = &builderStorage;
+ builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
+ builder->setInsertBefore(callInst);
+
+ // Create interface dispatch method that bottlenecks the dispatch logic.
+ auto requirementKey = lookupInst->getRequirementKey();
+ auto requirementVal =
+ sharedContext->findInterfaceRequirementVal(interfaceType, requirementKey);
+ auto dispatchFunc = getOrCreateInterfaceDispatchMethod(
+ builder, interfaceType, requirementKey, requirementVal);
+
+ auto parentFunc = getParentFunc(callInst);
+ // Don't process the call inst that is the one in the dispatch function itself.
+ if (parentFunc == dispatchFunc)
+ return;
+
+ // Replace `callInst` with a new call inst that calls `dispatchFunc` instead, and
+ // with the witness table as first argument,
+ builder->setInsertBefore(callInst);
+ List<IRInst*> newArgs;
+ newArgs.add(lookupInst->getWitnessTable());
+ for (UInt i = 0; i < callInst->getArgCount(); i++)
+ newArgs.add(callInst->getArg(i));
+ auto newCall =
+ (IRCall*)builder->emitCallInst(callInst->getFullType(), dispatchFunc, newArgs);
+ callInst->replaceUsesWith(newCall);
+ callInst->removeAndDeallocate();
+
+ // Translate the new call inst as normal, taking care of packing/unpacking inputs
+ // and outputs.
+ translateCallInst(
+ newCall, cast<IRFuncType>(dispatchFunc->getFullType()), dispatchFunc, nullptr);
}
void lowerCall(IRCall* callInst)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 6f0cc43e2..778f066dd 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5559,5 +5559,16 @@ namespace Slang
{
return inst->findDecoration<IRBuiltinDecoration>() != nullptr;
}
+ IRFunc* getParentFunc(IRInst* inst)
+ {
+ auto parent = inst->getParent();
+ while (parent)
+ {
+ if (auto func = as<IRFunc>(parent))
+ return func;
+ parent = parent->getParent();
+ }
+ return nullptr;
+ }
} // namespace Slang
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 03d29280a..d5bd55442 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1488,6 +1488,9 @@ bool isPointerOfType(IRInst* ptrType, IROp opCode);
// True if the IR inst represents a builtin object (e.g. __BuiltinFloatingPointType).
bool isBuiltin(IRInst* inst);
+ // Get the enclosuing function of an instruction.
+IRFunc* getParentFunc(IRInst* inst);
+
}
#endif