// slang-ir-lower-generic-function.cpp #include "slang-ir-lower-generic-function.h" #include "slang-ir-clone.h" #include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" #include "slang-ir.h" namespace Slang { // This is a subpass of generics lowering IR transformation. // This pass lowers all generic function types and function definitions, including // the function types used in interface types, to ordinary functions that takes // raw pointers in place of generic types. struct GenericFunctionLoweringContext { SharedGenericsLoweringContext* sharedContext; IRInst* lowerGenericFunction(IRInst* genericValue) { IRInst* result = nullptr; if (sharedContext->loweredGenericFunctions.tryGetValue(genericValue, result)) return result; // Do not lower intrinsic functions. if (genericValue->findDecoration()) return genericValue; auto genericParent = as(genericValue); SLANG_ASSERT(genericParent); SLANG_ASSERT(genericParent->getDataType()); auto genericRetVal = findGenericReturnVal(genericParent); auto func = as(genericRetVal); if (!func) { // Nested generic functions are supposed to be flattened before entering // this pass. The reason we are still seeing them must be that they are // intrinsic functions. In this case we ignore the function. if (as(genericRetVal)) { SLANG_ASSERT( findInnerMostGenericReturnVal(genericParent) ->findDecoration() != nullptr); } return genericValue; } SLANG_ASSERT(func); // Do not lower intrinsic functions. UnownedStringSlice intrinsicDef; IRInst* intrinsicInst; if (!func->isDefinition() || findTargetIntrinsicDefinition( func, sharedContext->targetProgram->getTargetReq()->getTargetCaps(), intrinsicDef, intrinsicInst)) { sharedContext->loweredGenericFunctions[genericValue] = genericValue; return genericValue; } IRCloneEnv cloneEnv; IRBuilder builder(sharedContext->module); builder.setInsertBefore(genericParent); // Do not clone func type (which would break IR def-use rules if we do it here) // This is OK since we will lower the type immediately after the clone. cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); auto loweredFunc = cast(cloneInstAndOperands(&cloneEnv, &builder, func)); auto loweredGenericType = lowerGenericFuncType(&builder, genericParent, cast(func->getFullType())); SLANG_ASSERT(loweredGenericType); loweredFunc->setFullType(loweredGenericType); OrderedHashSet childrenToDemote; List clonedParams; auto moduleInst = genericParent->getModule()->getModuleInst(); for (auto genericChild : genericParent->getFirstBlock()->getChildren()) { switch (genericChild->getOp()) { case kIROp_Func: continue; case kIROp_Return: continue; } // Process all generic parameters and local type definitions. auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); switch (clonedChild->getOp()) { case kIROp_Param: { auto paramType = clonedChild->getFullType(); auto loweredParamType = sharedContext->lowerType(&builder, paramType); if (loweredParamType != paramType) { clonedChild->setFullType((IRType*)loweredParamType); } clonedParams.add(clonedChild); } break; case kIROp_Specialize: case kIROp_LookupWitnessMethod: childrenToDemote.add(clonedChild); break; default: { bool shouldDemote = false; if (childrenToDemote.contains(clonedChild->getFullType())) shouldDemote = true; for (UInt i = 0; i < clonedChild->getOperandCount(); i++) { if (childrenToDemote.contains(clonedChild->getOperand(i))) { shouldDemote = true; break; } } if (shouldDemote && clonedChild->getParent() == moduleInst) { childrenToDemote.add(clonedChild); } continue; } } } cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, func, loweredFunc); auto block = as(loweredFunc->getFirstChild()); for (auto param : clonedParams) { param->removeFromParent(); block->addParam(as(param)); } // Demote specialize and lookupWitness insts and their dependents down to function body. auto insertPoint = block->getFirstOrdinaryInst(); List childrenToDemoteList; for (auto child : childrenToDemote) childrenToDemoteList.add(child); for (Index i = childrenToDemoteList.getCount() - 1; i >= 0; i--) { auto child = childrenToDemoteList[i]; child->insertBefore(insertPoint); } // Lower generic typed parameters into AnyValueType. auto firstInst = loweredFunc->getFirstOrdinaryInst(); builder.setInsertBefore(firstInst); sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; sharedContext->addToWorkList(loweredFunc); return loweredFunc; } IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) { ShortList genericParamTypes; Dictionary typeMapping; for (auto genericParam : genericVal->getParams()) { genericParamTypes.add(sharedContext->lowerType(builder, genericParam->getFullType())); if (auto anyValueSizeDecor = genericParam->findDecoration()) { auto anyValueSize = sharedContext->getInterfaceAnyValueSize( anyValueSizeDecor->getConstraintType(), genericParam->sourceLoc); auto anyValueType = builder->getAnyValueType(anyValueSize); typeMapping[genericParam] = anyValueType; } } auto innerType = (IRFuncType*)lowerFuncType( builder, funcType, typeMapping, genericParamTypes.getArrayView().arrayView); return innerType; } IRType* lowerFuncType( IRBuilder* builder, IRFuncType* funcType, const Dictionary& typeMapping, ArrayView additionalParams) { List newOperands; bool translated = false; for (UInt i = 0; i < funcType->getOperandCount(); i++) { auto paramType = funcType->getOperand(i); auto loweredParamType = sharedContext->lowerType(builder, paramType, typeMapping, nullptr); SLANG_ASSERT(loweredParamType); translated = translated || (loweredParamType != paramType); newOperands.add(loweredParamType); } if (!translated && additionalParams.getCount() == 0) return funcType; for (Index i = 0; i < additionalParams.getCount(); i++) { newOperands.add(additionalParams[i]); } auto newFuncType = builder->getFuncType( newOperands.getCount() - 1, (IRType**)(newOperands.begin() + 1), (IRType*)newOperands[0]); IRCloneEnv cloneEnv; cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, funcType, newFuncType); return newFuncType; } IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) { IRInterfaceType* loweredType = nullptr; if (sharedContext->loweredInterfaceTypes.tryGetValue(interfaceType, loweredType)) return loweredType; if (sharedContext->mapLoweredInterfaceToOriginal.containsKey(interfaceType)) return interfaceType; // Do not lower intrinsic interfaces. if (isBuiltin(interfaceType)) return interfaceType; // Do not lower COM interfaces. if (isComInterfaceType(interfaceType)) return interfaceType; List newEntries; IRBuilder builder(sharedContext->module); builder.setInsertBefore(interfaceType); // Translate IRFuncType in interface requirements. for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { if (auto entry = as(interfaceType->getOperand(i))) { // Note: The logic that creates the `IRInterfaceRequirementEntry`s does // not currently guarantee that the *value* part of each key-value pair // gets filled in. We thus need to defend against a null `requirementVal` // here, at least until the underlying issue gets resolved. // IRInst* requirementVal = entry->getRequirementVal(); IRInst* loweredVal = nullptr; if (!requirementVal) { } else if (auto funcType = as(requirementVal)) { loweredVal = lowerFuncType( &builder, funcType, Dictionary(), ArrayView()); } else if (auto genericFuncType = as(requirementVal)) { loweredVal = lowerGenericFuncType( &builder, genericFuncType, cast(findGenericReturnVal(genericFuncType))); } else if (requirementVal->getOp() == kIROp_AssociatedType) { loweredVal = builder.getRTTIHandleType(); } else { loweredVal = requirementVal; } auto newEntry = builder.createInterfaceRequirementEntry(entry->getRequirementKey(), loweredVal); newEntries.add(newEntry); } } loweredType = builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer()); loweredType->sourceLoc = interfaceType->sourceLoc; IRCloneEnv cloneEnv; cloneInstDecorationsAndChildren( &cloneEnv, sharedContext->module, interfaceType, loweredType); sharedContext->loweredInterfaceTypes.add(interfaceType, loweredType); sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType; return loweredType; } bool isTypeKindVal(IRInst* inst) { auto type = inst->getDataType(); if (!type) return false; return type->getOp() == kIROp_TypeKind; } // Lower items in a witness table. This triggers lowering of generic functions, // and emission of wrapper functions. void lowerWitnessTable(IRWitnessTable* witnessTable) { IRInterfaceType* conformanceType = as(witnessTable->getConformanceType()); if (!conformanceType) return; auto interfaceType = maybeLowerInterfaceType(conformanceType); IRBuilder builderStorage(sharedContext->module); auto builder = &builderStorage; builder->setInsertBefore(witnessTable); if (interfaceType != witnessTable->getConformanceType()) { auto newWitnessTableType = builder->getWitnessTableType(interfaceType); witnessTable->setFullType(newWitnessTableType); } if (isBuiltin(interfaceType)) return; for (auto child : witnessTable->getChildren()) { auto entry = as(child); if (!entry) continue; if (auto genericVal = as(entry->getSatisfyingVal())) { // Lower generic functions in witness table. if (findGenericReturnVal(genericVal)->getOp() == kIROp_Func) { auto loweredFunc = lowerGenericFunction(genericVal); entry->satisfyingVal.set(loweredFunc); } } else if (isTypeKindVal(entry->getSatisfyingVal())) { // Translate a Type value to an RTTI object pointer. auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); auto rttiObjectPtr = builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); entry->satisfyingVal.set(rttiObjectPtr); } else if (as(entry->getSatisfyingVal())) { // No processing needed here. // The witness table will be processed from the work list. } } } void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) { // Update the type of lookupInst to the lowered type of the corresponding interface // requirement val. // If the requirement is a function, interfaceRequirementVal will be the lowered function // type. If the requirement is an associatedtype, interfaceRequirementVal will be // Ptr. IRInst* interfaceRequirementVal = nullptr; auto witnessTableType = as(lookupInst->getWitnessTable()->getDataType()); if (!witnessTableType) return; if (witnessTableType->getConformanceType()->findDecoration()) return; IRInterfaceType* conformanceType = as(witnessTableType->getConformanceType()); // NoneWitness generates conformance types which aren't interfaces. In // that case, the method can just be skipped entirely, since there's no // real witness for it and it should be in unreachable code at this // point. if (!conformanceType) return; auto interfaceType = maybeLowerInterfaceType(conformanceType); interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( interfaceType, lookupInst->getRequirementKey()); IRBuilder builder(lookupInst); builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); } void lowerSpecialize(IRSpecialize* specializeInst) { // If we see a call(specialize(gFunc, Targs), args), // translate it into call(gFunc, args, Targs). IRInst* loweredFunc = nullptr; auto funcToSpecialize = specializeInst->getBase(); if (funcToSpecialize->getOp() == kIROp_Generic) { loweredFunc = lowerGenericFunction(funcToSpecialize); if (loweredFunc != funcToSpecialize) { IRBuilder builder; builder.replaceOperand(specializeInst->getOperands(), loweredFunc); } } } void processInst(IRInst* inst) { if (auto specializeInst = as(inst)) { lowerSpecialize(specializeInst); } else if (auto lookupInterfaceMethod = as(inst)) { lowerLookupInterfaceMethodInst(lookupInterfaceMethod); } else if (auto witnessTable = as(inst)) { lowerWitnessTable(witnessTable); } else if (auto interfaceType = as(inst)) { maybeLowerInterfaceType(interfaceType); } } void replaceLoweredInterfaceTypes() { for (const auto& [loweredKey, loweredValue] : sharedContext->loweredInterfaceTypes) loweredKey->replaceUsesWith(loweredValue); sharedContext->mapInterfaceRequirementKeyValue.clear(); } void processModule() { sharedContext->addToWorkList(sharedContext->module->getModuleInst()); while (sharedContext->workList.getCount() != 0) { IRInst* inst = sharedContext->workList.getLast(); sharedContext->workList.removeLast(); sharedContext->workListSet.remove(inst); processInst(inst); for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { sharedContext->addToWorkList(child); } } replaceLoweredInterfaceTypes(); } }; void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) { GenericFunctionLoweringContext context; context.sharedContext = sharedContext; context.processModule(); } } // namespace Slang