diff options
| author | Yong He <yonghe@outlook.com> | 2020-06-24 13:16:11 -0700 |
|---|---|---|
| committer | Yong He <yonghe@outlook.com> | 2020-06-24 18:10:15 -0700 |
| commit | 0ca75fe002f346f6ab9b77f40c0576d2905560f1 (patch) | |
| tree | ed8a3af372900923e59f0d6da629c2d0969ee7fd /source/slang/slang-ir-lower-generics.cpp | |
| parent | 3fe4f5398d524333e955ecb91be5646e86f3b2da (diff) | |
Dynamic dispatch for generic interface requirements.
-Lower interfaces into actual `IRInterfaceType` insts.
-Lower `DeclRef<AssocTypeDecl>` into `IRAssociatedType`
-Generate proper IRType for generic functions.
-Add a test case exercising dynamic dispatching a generic static function through an associated type.
-Bug fixes for the test case.
Diffstat (limited to 'source/slang/slang-ir-lower-generics.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 181 |
1 files changed, 173 insertions, 8 deletions
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index f6340a633..fe0fa3364 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -16,6 +16,7 @@ namespace Slang IRModule* module; Dictionary<IRInst*, IRInst*> loweredGenericFunctions; + HashSet<IRInterfaceType*> loweredInterfaceTypes; SharedIRBuilder sharedBuilderStorage; @@ -45,6 +46,20 @@ namespace Slang workListSet.Add(inst); } + bool isPolymorphicType(IRInst* typeInst) + { + if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getFullType())) + return true; + switch (typeInst->op) + { + case kIROp_AssociatedType: + case kIROp_InterfaceType: + return true; + default: + return false; + } + } + IRInst* lowerGenericFunction(IRInst* genericValue) { IRInst* result = nullptr; @@ -64,6 +79,7 @@ namespace Slang builder.sharedBuilder = &sharedBuilderStorage; builder.setInsertBefore(genericParent); auto loweredFunc = cloneInstAndOperands(&cloneEnv, &builder, func); + loweredFunc->setFullType(lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->typeUse.get()))); List<IRInst*> clonedParams; for (auto genericParam : genericParent->getParams()) { @@ -82,7 +98,7 @@ namespace Slang // Turn generic parameters into void pointers. for (auto param : cast<IRFunc>(loweredFunc)->getParams()) { - if (param->findDecoration<IRPolymorphicDecoration>()) + if (isPolymorphicType(param->getFullType())) { param->setFullType(builder.getPtrType(builder.getVoidType())); } @@ -91,6 +107,106 @@ namespace Slang return loweredFunc; } + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) + { + List<IRInst*> genericParamTypes; + for (auto genericParam : genericVal->getParams()) + { + if (isPolymorphicType(genericParam->getFullType())) + { + genericParamTypes.add(builder->getPtrType(builder->getVoidType())); + } + else + { + genericParamTypes.add(genericParam->getFullType()); + } + } + + auto innerType = (IRFuncType*)lowerFuncType( + builder, + cast<IRFuncType>(findGenericReturnVal(genericVal)), + genericParamTypes.getCount()); + + for (int i = 0; i < genericParamTypes.getCount(); i++) + { + innerType->setOperand( + innerType->getOperandCount() - genericParamTypes.getCount() + i, + genericParamTypes[i]); + } + + return innerType; + } + + IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, UInt additionalParamCount = 0) + { + List<IRInst*> newOperands; + bool translated = false; + for (UInt i = 0; i < funcType->getOperandCount(); i++) + { + auto paramType = funcType->getOperand(i); + if (isPolymorphicType(paramType)) + { + newOperands.add(builder->getPtrType(builder->getVoidType())); + translated = true; + } + else if (paramType->op == kIROp_Specialize) + { + // TODO: handle static specialized type here. + // For now treat all specialized types as dynamic. + // In the future, we need to turn things like Array<IDynamic> into Array<void*>. + newOperands.add(builder->getPtrType(builder->getVoidType())); + translated = true; + } + else + { + newOperands.add(paramType); + } + } + if (!translated) + return funcType; + for (UInt i = 0; i < additionalParamCount; i++) + { + newOperands.add(nullptr); + } + auto newFuncType = builder->getFuncType( + newOperands.getCount() - 1, + (IRType**)(newOperands.begin() + 1), + (IRType*)newOperands[0]); + + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType); + return newFuncType; + } + + IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + { + if (loweredInterfaceTypes.Contains(interfaceType)) + return interfaceType; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(interfaceType); + + // Translate IRFuncType in interface requirements. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) + { + if (auto funcType = as<IRFuncType>(entry->getRequirementVal())) + { + entry->requirementVal.set(lowerFuncType(&builder, funcType)); + } + else if (auto genericFuncType = as<IRGeneric>(entry->getRequirementVal())) + { + entry->requirementVal.set(lowerGenericFuncType(&builder, genericFuncType)); + } + } + } + + loweredInterfaceTypes.Add(interfaceType); + return interfaceType; + } + void processInst(IRInst* inst) { if (auto callInst = as<IRCall>(inst)) @@ -98,25 +214,53 @@ namespace Slang // If we see a call(specialize(gFunc, Targs), args), // translate it into call(gFunc, args, Targs). auto funcOperand = callInst->getOperand(0); + IRInst* loweredFunc = nullptr; if (auto specializeInst = as<IRSpecialize>(funcOperand)) { - auto loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) + auto funcToSpecialize = specializeInst->getOperand(0); + List<IRType*> paramTypes; + if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize)) { - // This is an intrinsic function, don't transform. - return; + // The callee is a result of witness table lookup, we will only + // translate the call. + IRInst* callee = nullptr; + auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(interfaceLookup->getInterfaceType())); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) + { + callee = entry->getRequirementVal(); + break; + } + } + auto funcType = cast<IRFuncType>(callee); + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + loweredFunc = funcToSpecialize; + } + else + { + loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); + if (loweredFunc == specializeInst->getOperand(0)) + { + // This is an intrinsic function, don't transform. + return; + } + for (auto param : as<IRFunc>(loweredFunc)->getParams()) + paramTypes.add(param->getDataType()); } + IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = &sharedBuilderStorage; builder->setInsertBefore(inst); List<IRInst*> args; - auto pp = as<IRFunc>(loweredFunc)->getParams().begin(); auto voidPtrType = builder->getPtrType(builder->getVoidType()); for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); - if ((*pp)->getDataType() == voidPtrType && + if (paramTypes[i] == voidPtrType && arg->getDataType() != voidPtrType) { // We are calling a generic function that with an argument of @@ -132,7 +276,6 @@ namespace Slang arg); } args.add(arg); - ++pp; } for (UInt i = 0; i < specializeInst->getArgCount(); i++) args.add(specializeInst->getArg(i)); @@ -141,6 +284,28 @@ namespace Slang callInst->removeAndDeallocate(); } } + else if (auto witnessTable = as<IRWitnessTable>(inst)) + { + // Lower generic functions in witness table. + for (auto child : witnessTable->getChildren()) + { + auto entry = as<IRWitnessTableEntry>(child); + if (!entry) + continue; + if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal())) + { + if (findGenericReturnVal(genericVal)->op == kIROp_Func) + { + auto loweredFunc = lowerGenericFunction(genericVal); + entry->satisfyingVal.set(loweredFunc); + } + } + } + } + else if (auto interfaceType = as<IRInterfaceType>(inst)) + { + maybeLowerInterfaceType(interfaceType); + } } void processModule() |
