summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-lower-generics.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2020-06-26 11:59:33 -0700
committerGitHub <noreply@github.com>2020-06-26 11:59:33 -0700
commit3e8bdb60afb5b0c0a53ce06d1dbbc429988f5885 (patch)
tree03f379d064f5e4df3423824140fad897b8a688e7 /source/slang/slang-ir-lower-generics.cpp
parentd084f632a136354dd12952183994240b459240ee (diff)
parent4e443984065552cc2f648ae2fae9e49a4ef21107 (diff)
Merge pull request #1408 from csyonghe/dyndispatch2
Dynamic dispatch for generic interface requirements and `associatedtype`
Diffstat (limited to 'source/slang/slang-ir-lower-generics.cpp')
-rw-r--r--source/slang/slang-ir-lower-generics.cpp191
1 files changed, 179 insertions, 12 deletions
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index f6340a633..774836e29 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -16,6 +16,7 @@ namespace Slang
IRModule* module;
Dictionary<IRInst*, IRInst*> loweredGenericFunctions;
+ HashSet<IRInterfaceType*> loweredInterfaceTypes;
SharedIRBuilder sharedBuilderStorage;
@@ -45,6 +46,21 @@ namespace Slang
workListSet.Add(inst);
}
+ bool isPolymorphicType(IRInst* typeInst)
+ {
+ if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getFullType()))
+ return true;
+ switch (typeInst->op)
+ {
+ case kIROp_ThisType:
+ case kIROp_AssociatedType:
+ case kIROp_InterfaceType:
+ return true;
+ default:
+ return false;
+ }
+ }
+
IRInst* lowerGenericFunction(IRInst* genericValue)
{
IRInst* result = nullptr;
@@ -64,6 +80,7 @@ namespace Slang
builder.sharedBuilder = &sharedBuilderStorage;
builder.setInsertBefore(genericParent);
auto loweredFunc = cloneInstAndOperands(&cloneEnv, &builder, func);
+ loweredFunc->setFullType(lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->getFullType())));
List<IRInst*> clonedParams;
for (auto genericParam : genericParent->getParams())
{
@@ -82,15 +99,115 @@ namespace Slang
// Turn generic parameters into void pointers.
for (auto param : cast<IRFunc>(loweredFunc)->getParams())
{
- if (param->findDecoration<IRPolymorphicDecoration>())
+ if (isPolymorphicType(param->getFullType()))
{
- param->setFullType(builder.getPtrType(builder.getVoidType()));
+ param->setFullType(builder.getRawPointerType());
}
}
addToWorkList(loweredFunc);
return loweredFunc;
}
+ IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal)
+ {
+ List<IRInst*> genericParamTypes;
+ for (auto genericParam : genericVal->getParams())
+ {
+ if (isPolymorphicType(genericParam->getFullType()))
+ {
+ genericParamTypes.add(builder->getRawPointerType());
+ }
+ else
+ {
+ genericParamTypes.add(genericParam->getFullType());
+ }
+ }
+
+ auto innerType = (IRFuncType*)lowerFuncType(
+ builder,
+ cast<IRFuncType>(findGenericReturnVal(genericVal)),
+ genericParamTypes.getCount());
+
+ for (int i = 0; i < genericParamTypes.getCount(); i++)
+ {
+ innerType->setOperand(
+ innerType->getOperandCount() - genericParamTypes.getCount() + i,
+ genericParamTypes[i]);
+ }
+
+ return innerType;
+ }
+
+ IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, UInt additionalParamCount = 0)
+ {
+ List<IRInst*> newOperands;
+ bool translated = false;
+ for (UInt i = 0; i < funcType->getOperandCount(); i++)
+ {
+ auto paramType = funcType->getOperand(i);
+ if (isPolymorphicType(paramType))
+ {
+ newOperands.add(builder->getRawPointerType());
+ translated = true;
+ }
+ else if (paramType->op == kIROp_Specialize)
+ {
+ // TODO: handle static specialized type here.
+ // For now treat all specialized types as dynamic.
+ // In the future, we need to turn things like Array<IDynamic> into Array<void*>.
+ newOperands.add(builder->getRawPointerType());
+ translated = true;
+ }
+ else
+ {
+ newOperands.add(paramType);
+ }
+ }
+ if (!translated && additionalParamCount == 0)
+ return funcType;
+ for (UInt i = 0; i < additionalParamCount; i++)
+ {
+ newOperands.add(nullptr);
+ }
+ auto newFuncType = builder->getFuncType(
+ newOperands.getCount() - 1,
+ (IRType**)(newOperands.begin() + 1),
+ (IRType*)newOperands[0]);
+
+ IRCloneEnv cloneEnv;
+ cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType);
+ return newFuncType;
+ }
+
+ IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType)
+ {
+ if (loweredInterfaceTypes.Contains(interfaceType))
+ return interfaceType;
+
+ IRBuilder builder;
+ builder.sharedBuilder = &sharedBuilderStorage;
+ builder.setInsertBefore(interfaceType);
+
+ // Translate IRFuncType in interface requirements.
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ {
+ if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)))
+ {
+ if (auto funcType = as<IRFuncType>(entry->getRequirementVal()))
+ {
+ entry->setRequirementVal(lowerFuncType(&builder, funcType));
+ }
+ else if (auto genericFuncType = as<IRGeneric>(entry->getRequirementVal()))
+ {
+ entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType));
+ }
+ }
+ }
+
+ loweredInterfaceTypes.Add(interfaceType);
+ return interfaceType;
+ }
+
void processInst(IRInst* inst)
{
if (auto callInst = as<IRCall>(inst))
@@ -98,26 +215,55 @@ namespace Slang
// If we see a call(specialize(gFunc, Targs), args),
// translate it into call(gFunc, args, Targs).
auto funcOperand = callInst->getOperand(0);
+ IRInst* loweredFunc = nullptr;
if (auto specializeInst = as<IRSpecialize>(funcOperand))
{
- auto loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
- if (loweredFunc == specializeInst->getOperand(0))
+ auto funcToSpecialize = specializeInst->getOperand(0);
+ List<IRType*> paramTypes;
+ if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
{
- // This is an intrinsic function, don't transform.
- return;
+ // The callee is a result of witness table lookup, we will only
+ // translate the call.
+ IRInst* callee = nullptr;
+ auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getFullType());
+ auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
+ for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
+ {
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
+ if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
+ {
+ callee = entry->getRequirementVal();
+ break;
+ }
+ }
+ auto funcType = cast<IRFuncType>(callee);
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ paramTypes.add(funcType->getParamType(i));
+ loweredFunc = funcToSpecialize;
+ }
+ else
+ {
+ loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
+ if (loweredFunc == specializeInst->getOperand(0))
+ {
+ // This is an intrinsic function, don't transform.
+ return;
+ }
+ for (auto param : as<IRFunc>(loweredFunc)->getParams())
+ paramTypes.add(param->getDataType());
}
+
IRBuilder builderStorage;
auto builder = &builderStorage;
builder->sharedBuilder = &sharedBuilderStorage;
builder->setInsertBefore(inst);
List<IRInst*> args;
- auto pp = as<IRFunc>(loweredFunc)->getParams().begin();
- auto voidPtrType = builder->getPtrType(builder->getVoidType());
+ auto rawPtrType = builder->getRawPointerType();
for (UInt i = 0; i < callInst->getArgCount(); i++)
{
auto arg = callInst->getArg(i);
- if ((*pp)->getDataType() == voidPtrType &&
- arg->getDataType() != voidPtrType)
+ if (paramTypes[i] == rawPtrType &&
+ arg->getDataType() != rawPtrType)
{
// We are calling a generic function that with an argument of
// concrete type. We need to convert this argument o void*.
@@ -128,11 +274,10 @@ namespace Slang
// what we needed. For now we use another instruction here
// to keep changes minimal.
arg = builder->emitGetAddress(
- voidPtrType,
+ rawPtrType,
arg);
}
args.add(arg);
- ++pp;
}
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
args.add(specializeInst->getArg(i));
@@ -141,6 +286,28 @@ namespace Slang
callInst->removeAndDeallocate();
}
}
+ else if (auto witnessTable = as<IRWitnessTable>(inst))
+ {
+ // Lower generic functions in witness table.
+ for (auto child : witnessTable->getChildren())
+ {
+ auto entry = as<IRWitnessTableEntry>(child);
+ if (!entry)
+ continue;
+ if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal()))
+ {
+ if (findGenericReturnVal(genericVal)->op == kIROp_Func)
+ {
+ auto loweredFunc = lowerGenericFunction(genericVal);
+ entry->satisfyingVal.set(loweredFunc);
+ }
+ }
+ }
+ }
+ else if (auto interfaceType = as<IRInterfaceType>(inst))
+ {
+ maybeLowerInterfaceType(interfaceType);
+ }
}
void processModule()