diff options
| author | Yong He <yonghe@outlook.com> | 2020-06-26 11:59:33 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-26 11:59:33 -0700 |
| commit | 3e8bdb60afb5b0c0a53ce06d1dbbc429988f5885 (patch) | |
| tree | 03f379d064f5e4df3423824140fad897b8a688e7 /source/slang/slang-ir-lower-generics.cpp | |
| parent | d084f632a136354dd12952183994240b459240ee (diff) | |
| parent | 4e443984065552cc2f648ae2fae9e49a4ef21107 (diff) | |
Merge pull request #1408 from csyonghe/dyndispatch2
Dynamic dispatch for generic interface requirements and `associatedtype`
Diffstat (limited to 'source/slang/slang-ir-lower-generics.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 191 |
1 files changed, 179 insertions, 12 deletions
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index f6340a633..774836e29 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -16,6 +16,7 @@ namespace Slang IRModule* module; Dictionary<IRInst*, IRInst*> loweredGenericFunctions; + HashSet<IRInterfaceType*> loweredInterfaceTypes; SharedIRBuilder sharedBuilderStorage; @@ -45,6 +46,21 @@ namespace Slang workListSet.Add(inst); } + bool isPolymorphicType(IRInst* typeInst) + { + if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getFullType())) + return true; + switch (typeInst->op) + { + case kIROp_ThisType: + case kIROp_AssociatedType: + case kIROp_InterfaceType: + return true; + default: + return false; + } + } + IRInst* lowerGenericFunction(IRInst* genericValue) { IRInst* result = nullptr; @@ -64,6 +80,7 @@ namespace Slang builder.sharedBuilder = &sharedBuilderStorage; builder.setInsertBefore(genericParent); auto loweredFunc = cloneInstAndOperands(&cloneEnv, &builder, func); + loweredFunc->setFullType(lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->getFullType()))); List<IRInst*> clonedParams; for (auto genericParam : genericParent->getParams()) { @@ -82,15 +99,115 @@ namespace Slang // Turn generic parameters into void pointers. for (auto param : cast<IRFunc>(loweredFunc)->getParams()) { - if (param->findDecoration<IRPolymorphicDecoration>()) + if (isPolymorphicType(param->getFullType())) { - param->setFullType(builder.getPtrType(builder.getVoidType())); + param->setFullType(builder.getRawPointerType()); } } addToWorkList(loweredFunc); return loweredFunc; } + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) + { + List<IRInst*> genericParamTypes; + for (auto genericParam : genericVal->getParams()) + { + if (isPolymorphicType(genericParam->getFullType())) + { + genericParamTypes.add(builder->getRawPointerType()); + } + else + { + genericParamTypes.add(genericParam->getFullType()); + } + } + + auto innerType = (IRFuncType*)lowerFuncType( + builder, + cast<IRFuncType>(findGenericReturnVal(genericVal)), + genericParamTypes.getCount()); + + for (int i = 0; i < genericParamTypes.getCount(); i++) + { + innerType->setOperand( + innerType->getOperandCount() - genericParamTypes.getCount() + i, + genericParamTypes[i]); + } + + return innerType; + } + + IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, UInt additionalParamCount = 0) + { + List<IRInst*> newOperands; + bool translated = false; + for (UInt i = 0; i < funcType->getOperandCount(); i++) + { + auto paramType = funcType->getOperand(i); + if (isPolymorphicType(paramType)) + { + newOperands.add(builder->getRawPointerType()); + translated = true; + } + else if (paramType->op == kIROp_Specialize) + { + // TODO: handle static specialized type here. + // For now treat all specialized types as dynamic. + // In the future, we need to turn things like Array<IDynamic> into Array<void*>. + newOperands.add(builder->getRawPointerType()); + translated = true; + } + else + { + newOperands.add(paramType); + } + } + if (!translated && additionalParamCount == 0) + return funcType; + for (UInt i = 0; i < additionalParamCount; i++) + { + newOperands.add(nullptr); + } + auto newFuncType = builder->getFuncType( + newOperands.getCount() - 1, + (IRType**)(newOperands.begin() + 1), + (IRType*)newOperands[0]); + + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType); + return newFuncType; + } + + IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + { + if (loweredInterfaceTypes.Contains(interfaceType)) + return interfaceType; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(interfaceType); + + // Translate IRFuncType in interface requirements. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) + { + if (auto funcType = as<IRFuncType>(entry->getRequirementVal())) + { + entry->setRequirementVal(lowerFuncType(&builder, funcType)); + } + else if (auto genericFuncType = as<IRGeneric>(entry->getRequirementVal())) + { + entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType)); + } + } + } + + loweredInterfaceTypes.Add(interfaceType); + return interfaceType; + } + void processInst(IRInst* inst) { if (auto callInst = as<IRCall>(inst)) @@ -98,26 +215,55 @@ namespace Slang // If we see a call(specialize(gFunc, Targs), args), // translate it into call(gFunc, args, Targs). auto funcOperand = callInst->getOperand(0); + IRInst* loweredFunc = nullptr; if (auto specializeInst = as<IRSpecialize>(funcOperand)) { - auto loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) + auto funcToSpecialize = specializeInst->getOperand(0); + List<IRType*> paramTypes; + if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize)) { - // This is an intrinsic function, don't transform. - return; + // The callee is a result of witness table lookup, we will only + // translate the call. + IRInst* callee = nullptr; + auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getFullType()); + auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) + { + callee = entry->getRequirementVal(); + break; + } + } + auto funcType = cast<IRFuncType>(callee); + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + loweredFunc = funcToSpecialize; + } + else + { + loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); + if (loweredFunc == specializeInst->getOperand(0)) + { + // This is an intrinsic function, don't transform. + return; + } + for (auto param : as<IRFunc>(loweredFunc)->getParams()) + paramTypes.add(param->getDataType()); } + IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = &sharedBuilderStorage; builder->setInsertBefore(inst); List<IRInst*> args; - auto pp = as<IRFunc>(loweredFunc)->getParams().begin(); - auto voidPtrType = builder->getPtrType(builder->getVoidType()); + auto rawPtrType = builder->getRawPointerType(); for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); - if ((*pp)->getDataType() == voidPtrType && - arg->getDataType() != voidPtrType) + if (paramTypes[i] == rawPtrType && + arg->getDataType() != rawPtrType) { // We are calling a generic function that with an argument of // concrete type. We need to convert this argument o void*. @@ -128,11 +274,10 @@ namespace Slang // what we needed. For now we use another instruction here // to keep changes minimal. arg = builder->emitGetAddress( - voidPtrType, + rawPtrType, arg); } args.add(arg); - ++pp; } for (UInt i = 0; i < specializeInst->getArgCount(); i++) args.add(specializeInst->getArg(i)); @@ -141,6 +286,28 @@ namespace Slang callInst->removeAndDeallocate(); } } + else if (auto witnessTable = as<IRWitnessTable>(inst)) + { + // Lower generic functions in witness table. + for (auto child : witnessTable->getChildren()) + { + auto entry = as<IRWitnessTableEntry>(child); + if (!entry) + continue; + if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal())) + { + if (findGenericReturnVal(genericVal)->op == kIROp_Func) + { + auto loweredFunc = lowerGenericFunction(genericVal); + entry->satisfyingVal.set(loweredFunc); + } + } + } + } + else if (auto interfaceType = as<IRInterfaceType>(inst)) + { + maybeLowerInterfaceType(interfaceType); + } } void processModule() |
