From 48f26ef082fa3b0c2a02dc57585f7e43210bbb63 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Jul 2020 15:16:09 -0700 Subject: Dynamic code gen for functions returning generic types. (#1439) * Dynamic code gen for functions returning generic types. * Add expected test result. --- source/slang/slang-ir-insts.h | 2 + source/slang/slang-ir-lower-generics.cpp | 287 ++++++++++++++++++++----------- source/slang/slang-ir.cpp | 30 ++++ source/slang/slang-ir.h | 1 + 4 files changed, 222 insertions(+), 98 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index d07a6d76e..a7dd1355a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1956,6 +1956,8 @@ struct IRBuilder IRType* type); IRParam* emitParam( IRType* type); + IRParam* emitParamAtHead( + IRType* type); IRVar* emitVar( IRType* type); diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 8b003f854..260f87bde 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -60,6 +60,15 @@ namespace Slang case kIROp_AssociatedType: case kIROp_InterfaceType: return true; + case kIROp_Specialize: + { + for (UInt i = 0; i < typeInst->getOperandCount(); i++) + { + if (isPolymorphicType(typeInst->getOperand(i))) + return true; + } + return false; + } default: break; } @@ -124,6 +133,41 @@ namespace Slang } } cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, func, loweredFunc); + + // If the function returns a generic typed value, we need to turn it + // into an `out` parameter, since only the caller can allocate space + // for it. + auto oldFuncType = cast(func->getDataType()); + if (isPolymorphicType(oldFuncType->getResultType())) + { + builder.setInsertInto(loweredFunc->getFirstBlock()); + // We defer creation of the returnVal parameter until we see the first + // `return` instruction, because we can only obtain the cloned return type + // of this function by checking the type of the cloned return inst. + IRParam* retValParam = nullptr; + // Translate all return insts to `store`s. + // Those `store`s will be processed and translated into `copy`s when we + // get to process them via workList. + for (auto bb : loweredFunc->getBlocks()) + { + auto retInst = as(bb->getTerminator()); + if (!retInst) + continue; + if (!retValParam) + { + // Now we have the return type, emit the returnVal parameter. + // The type of this parameter is still not translated to RawPointer yet, + // and will be processed together with all the other existing parameters. + retValParam = builder.emitParamAtHead( + builder.getOutType(retInst->getVal()->getDataType())); + } + builder.setInsertBefore(retInst); + builder.emitStore(retValParam, retInst->getVal()); + builder.emitReturn(); + retInst->removeAndDeallocate(); + } + } + auto block = as(loweredFunc->getFirstChild()); for (auto param : clonedParams) { @@ -139,7 +183,10 @@ namespace Slang param = param->getNextInst()) { // Generic typed parameters have a type that is a param itself. - if (auto rttiParam = as(param->getDataType())) + auto paramType = param->getDataType(); + if (auto ptrType = as(paramType)) + paramType = ptrType->getValueType(); + if (auto rttiParam = as(paramType)) { SLANG_ASSERT(isPointerOfType(rttiParam->getDataType(), kIROp_RTTIType)); // Lower into a function parameter of raw pointer type. @@ -189,6 +236,14 @@ namespace Slang { auto loweredParamType = lowerParameterType(builder, paramType); translated = translated || (loweredParamType != paramType); + if (translated && i == 0) + { + // We are translating the return value, this means that + // the return value must be passed explicitly via an `out` parameter. + // In this case, the new return value will be `void`, and the + // translated return value type will be the first parameter type; + newOperands.add(builder->getVoidType()); + } newOperands.add(loweredParamType); } } @@ -382,110 +437,146 @@ namespace Slang return result; } - void processInst(IRInst* inst) + void lowerCall(IRCall* callInst) { - if (auto callInst = as(inst)) + // If we see a call(specialize(gFunc, Targs), args), + // translate it into call(gFunc, args, Targs). + auto funcOperand = callInst->getOperand(0); + IRInst* loweredFunc = nullptr; + auto specializeInst = as(funcOperand); + if (!specializeInst) + return; + + auto funcToSpecialize = specializeInst->getOperand(0); + List paramTypes; + IRFuncType* funcType = nullptr; + if (auto interfaceLookup = as(funcToSpecialize)) { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - auto funcOperand = callInst->getOperand(0); - IRInst* loweredFunc = nullptr; - if (auto specializeInst = as(funcOperand)) + // The callee is a result of witness table lookup, we will only + // translate the call. + IRInst* callee = nullptr; + auto witnessTableType = cast(interfaceLookup->getWitnessTable()->getDataType()); + auto interfaceType = maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { - auto funcToSpecialize = specializeInst->getOperand(0); - List paramTypes; - if (auto interfaceLookup = as(funcToSpecialize)) - { - // The callee is a result of witness table lookup, we will only - // translate the call. - IRInst* callee = nullptr; - auto witnessTableType = cast(interfaceLookup->getWitnessTable()->getDataType()); - auto interfaceType = maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - auto entry = cast(interfaceType->getOperand(i)); - if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) - { - callee = entry->getRequirementVal(); - break; - } - } - auto funcType = cast(callee); - for (UInt i = 0; i < funcType->getParamCount(); i++) - paramTypes.add(funcType->getParamType(i)); - loweredFunc = funcToSpecialize; - } - else + auto entry = cast(interfaceType->getOperand(i)); + if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) { - loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) - { - // This is an intrinsic function, don't transform. - return; - } - for (auto param : as(loweredFunc)->getParams()) - paramTypes.add(param->getDataType()); + callee = entry->getRequirementVal(); + break; } + } + funcType = cast(callee); + loweredFunc = funcToSpecialize; + } + else + { + loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); + if (loweredFunc == specializeInst->getOperand(0)) + { + // This is an intrinsic function, don't transform. + return; + } + funcType = cast(loweredFunc->getDataType()); + } - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(inst); - List args; - auto rawPtrType = builder->getRawPointerType(); - for (UInt i = 0; i < callInst->getArgCount(); i++) - { - auto arg = callInst->getArg(i); - if (as(paramTypes[i]) && - !as(arg->getDataType())) - { - // We are calling a generic function that with an argument of - // concrete type. We need to convert this argument to void*. - - // Ideally this should just be a GetElementAddress inst. - // However the current code emitting logic for this instruction - // doesn't truly respect the pointerness and does not produce - // what we needed. For now we use another instruction here - // to keep changes minimal. - arg = builder->emitGetAddress( - rawPtrType, - arg); - } - args.add(arg); - } - for (UInt i = 0; i < specializeInst->getArgCount(); i++) - { - auto arg = specializeInst->getArg(i); - // Translate Type arguments into RTTI object. - if (as(arg)) - { - // We are using a simple type to specialize a callee. - // Generate RTTI for this type. - auto rttiObject = maybeEmitRTTIObject(arg); - arg = builder->emitGetAddress( - builder->getPtrType(builder->getRTTIType()), - rttiObject); - } - else if (arg->op == kIROp_Specialize) - { - // The type argument used to specialize a callee is itself a - // specialization of some generic type. - // TODO: generate RTTI object for specializations of generic types. - SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); - } - else if (arg->op == kIROp_RTTIObject) - { - // We are inside a generic function and using a generic parameter - // to specialize another callee. The generic parameter of the caller - // has already been translated into an RTTI object, so we just need - // to pass this object down. - } - args.add(arg); - } - auto newCall = builder->emitCallInst(callInst->getFullType(), loweredFunc, args); - callInst->replaceUsesWith(newCall); - callInst->removeAndDeallocate(); + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(callInst); + + List args; + + // Indicates whether the caller should allocate space for return value. + // If the lowered callee returns void and this call inst has a type that is not void, + // then we are calling a transformed function that expects caller allocated return value + // as the first argument. + bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType && + callInst->getDataType() != funcType->getResultType()); + + IRVar* retVarInst = nullptr; + int startParamIndex = 0; + if (shouldCallerAllocateReturnValue) + { + // Declare a var for the return value. + retVarInst = builder->emitVar(callInst->getFullType()); + args.add(retVarInst); + startParamIndex = 1; + } + + for (UInt i = 0; i < callInst->getArgCount(); i++) + { + auto arg = callInst->getArg(i); + if (as(paramTypes[i] + startParamIndex) && + !as(arg->getDataType())) + { + // We are calling a generic function that with an argument of + // concrete type. We need to convert this argument to void*. + + // Ideally this should just be a GetElementAddress inst. + // However the current code emitting logic for this instruction + // doesn't truly respect the pointerness and does not produce + // what we needed. For now we use another instruction here + // to keep changes minimal. + arg = builder->emitGetAddress( + builder->getRawPointerType(), + arg); + } + args.add(arg); + } + for (UInt i = 0; i < specializeInst->getArgCount(); i++) + { + auto arg = specializeInst->getArg(i); + // Translate Type arguments into RTTI object. + if (as(arg)) + { + // We are using a simple type to specialize a callee. + // Generate RTTI for this type. + auto rttiObject = maybeEmitRTTIObject(arg); + arg = builder->emitGetAddress( + builder->getPtrType(builder->getRTTIType()), + rttiObject); + } + else if (arg->op == kIROp_Specialize) + { + // The type argument used to specialize a callee is itself a + // specialization of some generic type. + // TODO: generate RTTI object for specializations of generic types. + SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); } + else if (arg->op == kIROp_RTTIObject) + { + // We are inside a generic function and using a generic parameter + // to specialize another callee. The generic parameter of the caller + // has already been translated into an RTTI object, so we just need + // to pass this object down. + } + args.add(arg); + } + auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType(); + auto newCall = builder->emitCallInst(callInstType, loweredFunc, args); + if (retVarInst) + { + auto loadInst = builder->emitLoad(retVarInst); + callInst->replaceUsesWith(loadInst); + addToWorkList(loadInst); + addToWorkList(retVarInst); + } + else + { + callInst->replaceUsesWith(newCall); + } + callInst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + if (auto callInst = as(inst)) + { + lowerCall(callInst); } else if (auto witnessTable = as(inst)) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b2ddc8ed3..efb8b3b27 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -343,6 +343,25 @@ namespace Slang } } + // Similar to addParam, but instead of appending `param` to the end + // of the parameter list, this function inserts `param` before the + // head of the list. + void IRBlock::insertParamAtHead(IRParam* param) + { + if (auto firstParam = getFirstParam()) + { + param->insertBefore(firstParam); + } + else if (auto firstOrdinary = getFirstOrdinaryInst()) + { + param->insertBefore(firstOrdinary); + } + else + { + param->insertAtEnd(this); + } + } + IRInst* IRBlock::getFirstOrdinaryInst() { // Find the last parameter (if any) of the block @@ -3030,6 +3049,17 @@ namespace Slang return param; } + IRParam* IRBuilder::emitParamAtHead( + IRType* type) + { + auto param = createParam(type); + if (auto bb = getBlock()) + { + bb->insertParamAtHead(param); + } + return param; + } + IRVar* IRBuilder::emitVar( IRType* type) { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9792f0625..6ca5d5058 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -822,6 +822,7 @@ struct IRBlock : IRInst } void addParam(IRParam* param); + void insertParamAtHead(IRParam* param); // The "ordinary" instructions come after the parameters IRInst* getFirstOrdinaryInst(); -- cgit v1.2.3