diff options
| author | Yong He <yonghe@outlook.com> | 2020-07-13 15:16:09 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-07-13 15:16:09 -0700 |
| commit | 48f26ef082fa3b0c2a02dc57585f7e43210bbb63 (patch) | |
| tree | e3e13e8034c0f2efe1454a51b4df0290056dae9f /source | |
| parent | 249f48dbb5e240c713661be969a6939ec57561e5 (diff) | |
Dynamic code gen for functions returning generic types. (#1439)
* Dynamic code gen for functions returning generic types.
* Add expected test result.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 287 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 1 |
4 files changed, 222 insertions, 98 deletions
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d07a6d76e..a7dd1355a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1956,6 +1956,8 @@ struct IRBuilder IRType* type); IRParam* emitParam( IRType* type); + IRParam* emitParamAtHead( + IRType* type); IRVar* emitVar( IRType* type); diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 8b003f854..260f87bde 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -60,6 +60,15 @@ namespace Slang case kIROp_AssociatedType: case kIROp_InterfaceType: return true; + case kIROp_Specialize: + { + for (UInt i = 0; i < typeInst->getOperandCount(); i++) + { + if (isPolymorphicType(typeInst->getOperand(i))) + return true; + } + return false; + } default: break; } @@ -124,6 +133,41 @@ namespace Slang } } cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, func, loweredFunc); + + // If the function returns a generic typed value, we need to turn it + // into an `out` parameter, since only the caller can allocate space + // for it. + auto oldFuncType = cast<IRFuncType>(func->getDataType()); + if (isPolymorphicType(oldFuncType->getResultType())) + { + builder.setInsertInto(loweredFunc->getFirstBlock()); + // We defer creation of the returnVal parameter until we see the first + // `return` instruction, because we can only obtain the cloned return type + // of this function by checking the type of the cloned return inst. + IRParam* retValParam = nullptr; + // Translate all return insts to `store`s. + // Those `store`s will be processed and translated into `copy`s when we + // get to process them via workList. + for (auto bb : loweredFunc->getBlocks()) + { + auto retInst = as<IRReturnVal>(bb->getTerminator()); + if (!retInst) + continue; + if (!retValParam) + { + // Now we have the return type, emit the returnVal parameter. + // The type of this parameter is still not translated to RawPointer yet, + // and will be processed together with all the other existing parameters. + retValParam = builder.emitParamAtHead( + builder.getOutType(retInst->getVal()->getDataType())); + } + builder.setInsertBefore(retInst); + builder.emitStore(retValParam, retInst->getVal()); + builder.emitReturn(); + retInst->removeAndDeallocate(); + } + } + auto block = as<IRBlock>(loweredFunc->getFirstChild()); for (auto param : clonedParams) { @@ -139,7 +183,10 @@ namespace Slang param = param->getNextInst()) { // Generic typed parameters have a type that is a param itself. - if (auto rttiParam = as<IRParam>(param->getDataType())) + auto paramType = param->getDataType(); + if (auto ptrType = as<IRPtrTypeBase>(paramType)) + paramType = ptrType->getValueType(); + if (auto rttiParam = as<IRParam>(paramType)) { SLANG_ASSERT(isPointerOfType(rttiParam->getDataType(), kIROp_RTTIType)); // Lower into a function parameter of raw pointer type. @@ -189,6 +236,14 @@ namespace Slang { auto loweredParamType = lowerParameterType(builder, paramType); translated = translated || (loweredParamType != paramType); + if (translated && i == 0) + { + // We are translating the return value, this means that + // the return value must be passed explicitly via an `out` parameter. + // In this case, the new return value will be `void`, and the + // translated return value type will be the first parameter type; + newOperands.add(builder->getVoidType()); + } newOperands.add(loweredParamType); } } @@ -382,110 +437,146 @@ namespace Slang return result; } - void processInst(IRInst* inst) + void lowerCall(IRCall* callInst) { - if (auto callInst = as<IRCall>(inst)) + // If we see a call(specialize(gFunc, Targs), args), + // translate it into call(gFunc, args, Targs). + auto funcOperand = callInst->getOperand(0); + IRInst* loweredFunc = nullptr; + auto specializeInst = as<IRSpecialize>(funcOperand); + if (!specializeInst) + return; + + auto funcToSpecialize = specializeInst->getOperand(0); + List<IRType*> paramTypes; + IRFuncType* funcType = nullptr; + if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize)) { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - auto funcOperand = callInst->getOperand(0); - IRInst* loweredFunc = nullptr; - if (auto specializeInst = as<IRSpecialize>(funcOperand)) + // The callee is a result of witness table lookup, we will only + // translate the call. + IRInst* callee = nullptr; + auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType()); + auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { - auto funcToSpecialize = specializeInst->getOperand(0); - List<IRType*> paramTypes; - if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize)) - { - // The callee is a result of witness table lookup, we will only - // translate the call. - IRInst* callee = nullptr; - auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType()); - auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); - if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) - { - callee = entry->getRequirementVal(); - break; - } - } - auto funcType = cast<IRFuncType>(callee); - for (UInt i = 0; i < funcType->getParamCount(); i++) - paramTypes.add(funcType->getParamType(i)); - loweredFunc = funcToSpecialize; - } - else + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) { - loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) - { - // This is an intrinsic function, don't transform. - return; - } - for (auto param : as<IRFunc>(loweredFunc)->getParams()) - paramTypes.add(param->getDataType()); + callee = entry->getRequirementVal(); + break; } + } + funcType = cast<IRFuncType>(callee); + loweredFunc = funcToSpecialize; + } + else + { + loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); + if (loweredFunc == specializeInst->getOperand(0)) + { + // This is an intrinsic function, don't transform. + return; + } + funcType = cast<IRFuncType>(loweredFunc->getDataType()); + } - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(inst); - List<IRInst*> args; - auto rawPtrType = builder->getRawPointerType(); - for (UInt i = 0; i < callInst->getArgCount(); i++) - { - auto arg = callInst->getArg(i); - if (as<IRRawPointerType>(paramTypes[i]) && - !as<IRRawPointerType>(arg->getDataType())) - { - // We are calling a generic function that with an argument of - // concrete type. We need to convert this argument to void*. - - // Ideally this should just be a GetElementAddress inst. - // However the current code emitting logic for this instruction - // doesn't truly respect the pointerness and does not produce - // what we needed. For now we use another instruction here - // to keep changes minimal. - arg = builder->emitGetAddress( - rawPtrType, - arg); - } - args.add(arg); - } - for (UInt i = 0; i < specializeInst->getArgCount(); i++) - { - auto arg = specializeInst->getArg(i); - // Translate Type arguments into RTTI object. - if (as<IRType>(arg)) - { - // We are using a simple type to specialize a callee. - // Generate RTTI for this type. - auto rttiObject = maybeEmitRTTIObject(arg); - arg = builder->emitGetAddress( - builder->getPtrType(builder->getRTTIType()), - rttiObject); - } - else if (arg->op == kIROp_Specialize) - { - // The type argument used to specialize a callee is itself a - // specialization of some generic type. - // TODO: generate RTTI object for specializations of generic types. - SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); - } - else if (arg->op == kIROp_RTTIObject) - { - // We are inside a generic function and using a generic parameter - // to specialize another callee. The generic parameter of the caller - // has already been translated into an RTTI object, so we just need - // to pass this object down. - } - args.add(arg); - } - auto newCall = builder->emitCallInst(callInst->getFullType(), loweredFunc, args); - callInst->replaceUsesWith(newCall); - callInst->removeAndDeallocate(); + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(callInst); + + List<IRInst*> args; + + // Indicates whether the caller should allocate space for return value. + // If the lowered callee returns void and this call inst has a type that is not void, + // then we are calling a transformed function that expects caller allocated return value + // as the first argument. + bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType && + callInst->getDataType() != funcType->getResultType()); + + IRVar* retVarInst = nullptr; + int startParamIndex = 0; + if (shouldCallerAllocateReturnValue) + { + // Declare a var for the return value. + retVarInst = builder->emitVar(callInst->getFullType()); + args.add(retVarInst); + startParamIndex = 1; + } + + for (UInt i = 0; i < callInst->getArgCount(); i++) + { + auto arg = callInst->getArg(i); + if (as<IRRawPointerType>(paramTypes[i] + startParamIndex) && + !as<IRRawPointerType>(arg->getDataType())) + { + // We are calling a generic function that with an argument of + // concrete type. We need to convert this argument to void*. + + // Ideally this should just be a GetElementAddress inst. + // However the current code emitting logic for this instruction + // doesn't truly respect the pointerness and does not produce + // what we needed. For now we use another instruction here + // to keep changes minimal. + arg = builder->emitGetAddress( + builder->getRawPointerType(), + arg); + } + args.add(arg); + } + for (UInt i = 0; i < specializeInst->getArgCount(); i++) + { + auto arg = specializeInst->getArg(i); + // Translate Type arguments into RTTI object. + if (as<IRType>(arg)) + { + // We are using a simple type to specialize a callee. + // Generate RTTI for this type. + auto rttiObject = maybeEmitRTTIObject(arg); + arg = builder->emitGetAddress( + builder->getPtrType(builder->getRTTIType()), + rttiObject); + } + else if (arg->op == kIROp_Specialize) + { + // The type argument used to specialize a callee is itself a + // specialization of some generic type. + // TODO: generate RTTI object for specializations of generic types. + SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); } + else if (arg->op == kIROp_RTTIObject) + { + // We are inside a generic function and using a generic parameter + // to specialize another callee. The generic parameter of the caller + // has already been translated into an RTTI object, so we just need + // to pass this object down. + } + args.add(arg); + } + auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType(); + auto newCall = builder->emitCallInst(callInstType, loweredFunc, args); + if (retVarInst) + { + auto loadInst = builder->emitLoad(retVarInst); + callInst->replaceUsesWith(loadInst); + addToWorkList(loadInst); + addToWorkList(retVarInst); + } + else + { + callInst->replaceUsesWith(newCall); + } + callInst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + if (auto callInst = as<IRCall>(inst)) + { + lowerCall(callInst); } else if (auto witnessTable = as<IRWitnessTable>(inst)) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b2ddc8ed3..efb8b3b27 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -343,6 +343,25 @@ namespace Slang } } + // Similar to addParam, but instead of appending `param` to the end + // of the parameter list, this function inserts `param` before the + // head of the list. + void IRBlock::insertParamAtHead(IRParam* param) + { + if (auto firstParam = getFirstParam()) + { + param->insertBefore(firstParam); + } + else if (auto firstOrdinary = getFirstOrdinaryInst()) + { + param->insertBefore(firstOrdinary); + } + else + { + param->insertAtEnd(this); + } + } + IRInst* IRBlock::getFirstOrdinaryInst() { // Find the last parameter (if any) of the block @@ -3030,6 +3049,17 @@ namespace Slang return param; } + IRParam* IRBuilder::emitParamAtHead( + IRType* type) + { + auto param = createParam(type); + if (auto bb = getBlock()) + { + bb->insertParamAtHead(param); + } + return param; + } + IRVar* IRBuilder::emitVar( IRType* type) { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9792f0625..6ca5d5058 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -822,6 +822,7 @@ struct IRBlock : IRInst } void addParam(IRParam* param); + void insertParamAtHead(IRParam* param); // The "ordinary" instructions come after the parameters IRInst* getFirstOrdinaryInst(); |
