diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2024-10-29 14:49:26 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-29 14:49:26 +0800 |
| commit | f65d756bff8d4c5cbc15bd0322a2ae8e6b896a21 (patch) | |
| tree | ea1d61342cd29368e19135000ec2948813096205 /source/slang/slang-ir-lower-generic-function.cpp | |
| parent | a729c15e9dce9f5116a38afc66329ab2ca4cea54 (diff) | |
format
* format
* Minor test fixes
* enable checking cpp format in ci
Diffstat (limited to 'source/slang/slang-ir-lower-generic-function.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-generic-function.cpp | 678 |
1 files changed, 351 insertions, 327 deletions
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index cb6555a7e..5c39b2124 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -1,409 +1,433 @@ // slang-ir-lower-generic-function.cpp #include "slang-ir-lower-generic-function.h" -#include "slang-ir-generics-lowering-context.h" -#include "slang-ir.h" #include "slang-ir-clone.h" +#include "slang-ir-generics-lowering-context.h" #include "slang-ir-insts.h" #include "slang-ir-util.h" +#include "slang-ir.h" namespace Slang { - // This is a subpass of generics lowering IR transformation. - // This pass lowers all generic function types and function definitions, including - // the function types used in interface types, to ordinary functions that takes - // raw pointers in place of generic types. - struct GenericFunctionLoweringContext - { - SharedGenericsLoweringContext* sharedContext; +// This is a subpass of generics lowering IR transformation. +// This pass lowers all generic function types and function definitions, including +// the function types used in interface types, to ordinary functions that takes +// raw pointers in place of generic types. +struct GenericFunctionLoweringContext +{ + SharedGenericsLoweringContext* sharedContext; - IRInst* lowerGenericFunction(IRInst* genericValue) + IRInst* lowerGenericFunction(IRInst* genericValue) + { + IRInst* result = nullptr; + if (sharedContext->loweredGenericFunctions.tryGetValue(genericValue, result)) + return result; + // Do not lower intrinsic functions. + if (genericValue->findDecoration<IRTargetIntrinsicDecoration>()) + return genericValue; + auto genericParent = as<IRGeneric>(genericValue); + SLANG_ASSERT(genericParent); + SLANG_ASSERT(genericParent->getDataType()); + auto genericRetVal = findGenericReturnVal(genericParent); + auto func = as<IRFunc>(genericRetVal); + if (!func) { - IRInst* result = nullptr; - if (sharedContext->loweredGenericFunctions.tryGetValue(genericValue, result)) - return result; - // Do not lower intrinsic functions. - if (genericValue->findDecoration<IRTargetIntrinsicDecoration>()) - return genericValue; - auto genericParent = as<IRGeneric>(genericValue); - SLANG_ASSERT(genericParent); - SLANG_ASSERT(genericParent->getDataType()); - auto genericRetVal = findGenericReturnVal(genericParent); - auto func = as<IRFunc>(genericRetVal); - if (!func) + // Nested generic functions are supposed to be flattened before entering + // this pass. The reason we are still seeing them must be that they are + // intrinsic functions. In this case we ignore the function. + if (as<IRGeneric>(genericRetVal)) { - // Nested generic functions are supposed to be flattened before entering - // this pass. The reason we are still seeing them must be that they are - // intrinsic functions. In this case we ignore the function. - if (as<IRGeneric>(genericRetVal)) - { - SLANG_ASSERT(findInnerMostGenericReturnVal(genericParent) - ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr); - } - return genericValue; + SLANG_ASSERT( + findInnerMostGenericReturnVal(genericParent) + ->findDecoration<IRTargetIntrinsicDecoration>() != nullptr); } - SLANG_ASSERT(func); - // Do not lower intrinsic functions. - UnownedStringSlice intrinsicDef; - IRInst* intrinsicInst; - if (!func->isDefinition() || findTargetIntrinsicDefinition(func, sharedContext->targetProgram->getTargetReq()->getTargetCaps(), intrinsicDef, intrinsicInst)) + return genericValue; + } + SLANG_ASSERT(func); + // Do not lower intrinsic functions. + UnownedStringSlice intrinsicDef; + IRInst* intrinsicInst; + if (!func->isDefinition() || + findTargetIntrinsicDefinition( + func, + sharedContext->targetProgram->getTargetReq()->getTargetCaps(), + intrinsicDef, + intrinsicInst)) + { + sharedContext->loweredGenericFunctions[genericValue] = genericValue; + return genericValue; + } + IRCloneEnv cloneEnv; + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(genericParent); + // Do not clone func type (which would break IR def-use rules if we do it here) + // This is OK since we will lower the type immediately after the clone. + cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); + auto loweredFunc = cast<IRFunc>(cloneInstAndOperands(&cloneEnv, &builder, func)); + auto loweredGenericType = + lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType())); + SLANG_ASSERT(loweredGenericType); + loweredFunc->setFullType(loweredGenericType); + + OrderedHashSet<IRInst*> childrenToDemote; + List<IRInst*> clonedParams; + auto moduleInst = genericParent->getModule()->getModuleInst(); + for (auto genericChild : genericParent->getFirstBlock()->getChildren()) + { + switch (genericChild->getOp()) { - sharedContext->loweredGenericFunctions[genericValue] = genericValue; - return genericValue; + case kIROp_Func: continue; + case kIROp_Return: continue; } - IRCloneEnv cloneEnv; - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(genericParent); - // Do not clone func type (which would break IR def-use rules if we do it here) - // This is OK since we will lower the type immediately after the clone. - cloneEnv.mapOldValToNew[func->getFullType()] = builder.getTypeKind(); - auto loweredFunc = cast<IRFunc>(cloneInstAndOperands(&cloneEnv, &builder, func)); - auto loweredGenericType = - lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType())); - SLANG_ASSERT(loweredGenericType); - loweredFunc->setFullType(loweredGenericType); - - OrderedHashSet<IRInst*> childrenToDemote; - List<IRInst*> clonedParams; - auto moduleInst = genericParent->getModule()->getModuleInst(); - for (auto genericChild : genericParent->getFirstBlock()->getChildren()) + // Process all generic parameters and local type definitions. + auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); + switch (clonedChild->getOp()) { - switch (genericChild->getOp()) + case kIROp_Param: { - case kIROp_Func: - continue; - case kIROp_Return: - continue; + auto paramType = clonedChild->getFullType(); + auto loweredParamType = sharedContext->lowerType(&builder, paramType); + if (loweredParamType != paramType) + { + clonedChild->setFullType((IRType*)loweredParamType); + } + clonedParams.add(clonedChild); } - // Process all generic parameters and local type definitions. - auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - switch (clonedChild->getOp()) + break; + case kIROp_Specialize: + case kIROp_LookupWitness: childrenToDemote.add(clonedChild); break; + default: { - case kIROp_Param: + bool shouldDemote = false; + if (childrenToDemote.contains(clonedChild->getFullType())) + shouldDemote = true; + for (UInt i = 0; i < clonedChild->getOperandCount(); i++) { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = sharedContext->lowerType(&builder, paramType); - if (loweredParamType != paramType) + if (childrenToDemote.contains(clonedChild->getOperand(i))) { - clonedChild->setFullType((IRType*)loweredParamType); + shouldDemote = true; + break; } - clonedParams.add(clonedChild); } - break; - case kIROp_Specialize: - case kIROp_LookupWitness: - childrenToDemote.add(clonedChild); - break; - default: + if (shouldDemote && clonedChild->getParent() == moduleInst) { - bool shouldDemote = false; - if (childrenToDemote.contains(clonedChild->getFullType())) - shouldDemote = true; - for (UInt i = 0; i < clonedChild->getOperandCount(); i++) - { - if (childrenToDemote.contains(clonedChild->getOperand(i))) - { - shouldDemote = true; - break; - } - } - if (shouldDemote && clonedChild->getParent() == moduleInst) - { - childrenToDemote.add(clonedChild); - } - continue; + childrenToDemote.add(clonedChild); } + continue; } } - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, func, loweredFunc); - - auto block = as<IRBlock>(loweredFunc->getFirstChild()); - for (auto param : clonedParams) - { - param->removeFromParent(); - block->addParam(as<IRParam>(param)); - } + } + cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, func, loweredFunc); - // Demote specialize and lookupWitness insts and their dependents down to function body. - auto insertPoint = block->getFirstOrdinaryInst(); - List<IRInst*> childrenToDemoteList; - for (auto child : childrenToDemote) - childrenToDemoteList.add(child); - for (Index i = childrenToDemoteList.getCount() - 1; i >= 0; i--) - { - auto child = childrenToDemoteList[i]; - child->insertBefore(insertPoint); - } + auto block = as<IRBlock>(loweredFunc->getFirstChild()); + for (auto param : clonedParams) + { + param->removeFromParent(); + block->addParam(as<IRParam>(param)); + } - // Lower generic typed parameters into AnyValueType. - auto firstInst = loweredFunc->getFirstOrdinaryInst(); - builder.setInsertBefore(firstInst); - sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; - sharedContext->addToWorkList(loweredFunc); - return loweredFunc; + // Demote specialize and lookupWitness insts and their dependents down to function body. + auto insertPoint = block->getFirstOrdinaryInst(); + List<IRInst*> childrenToDemoteList; + for (auto child : childrenToDemote) + childrenToDemoteList.add(child); + for (Index i = childrenToDemoteList.getCount() - 1; i >= 0; i--) + { + auto child = childrenToDemoteList[i]; + child->insertBefore(insertPoint); } - IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) + // Lower generic typed parameters into AnyValueType. + auto firstInst = loweredFunc->getFirstOrdinaryInst(); + builder.setInsertBefore(firstInst); + sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; + sharedContext->addToWorkList(loweredFunc); + return loweredFunc; + } + + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal, IRFuncType* funcType) + { + ShortList<IRInst*> genericParamTypes; + Dictionary<IRInst*, IRInst*> typeMapping; + for (auto genericParam : genericVal->getParams()) { - ShortList<IRInst*> genericParamTypes; - Dictionary<IRInst*, IRInst*> typeMapping; - for (auto genericParam : genericVal->getParams()) + genericParamTypes.add(sharedContext->lowerType(builder, genericParam->getFullType())); + if (auto anyValueSizeDecor = genericParam->findDecoration<IRTypeConstraintDecoration>()) { - genericParamTypes.add(sharedContext->lowerType(builder, genericParam->getFullType())); - if (auto anyValueSizeDecor = genericParam->findDecoration<IRTypeConstraintDecoration>()) - { - auto anyValueSize = sharedContext->getInterfaceAnyValueSize(anyValueSizeDecor->getConstraintType(), genericParam->sourceLoc); - auto anyValueType = builder->getAnyValueType(anyValueSize); - typeMapping[genericParam] = anyValueType; - } + auto anyValueSize = sharedContext->getInterfaceAnyValueSize( + anyValueSizeDecor->getConstraintType(), + genericParam->sourceLoc); + auto anyValueType = builder->getAnyValueType(anyValueSize); + typeMapping[genericParam] = anyValueType; } + } - auto innerType = (IRFuncType*)lowerFuncType( - builder, - funcType, - typeMapping, - genericParamTypes.getArrayView().arrayView); + auto innerType = (IRFuncType*)lowerFuncType( + builder, + funcType, + typeMapping, + genericParamTypes.getArrayView().arrayView); - return innerType; - } + return innerType; + } - IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, - const Dictionary<IRInst*, IRInst*>& typeMapping, - ArrayView<IRInst*> additionalParams) + IRType* lowerFuncType( + IRBuilder* builder, + IRFuncType* funcType, + const Dictionary<IRInst*, IRInst*>& typeMapping, + ArrayView<IRInst*> additionalParams) + { + List<IRInst*> newOperands; + bool translated = false; + for (UInt i = 0; i < funcType->getOperandCount(); i++) { - List<IRInst*> newOperands; - bool translated = false; - for (UInt i = 0; i < funcType->getOperandCount(); i++) - { - auto paramType = funcType->getOperand(i); - auto loweredParamType = sharedContext->lowerType(builder, paramType, typeMapping, nullptr); - SLANG_ASSERT(loweredParamType); - translated = translated || (loweredParamType != paramType); - newOperands.add(loweredParamType); - } - if (!translated && additionalParams.getCount() == 0) - return funcType; - for (Index i = 0; i < additionalParams.getCount(); i++) - { - newOperands.add(additionalParams[i]); - } - auto newFuncType = builder->getFuncType( - newOperands.getCount() - 1, - (IRType**)(newOperands.begin() + 1), - (IRType*)newOperands[0]); - - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, funcType, newFuncType); - return newFuncType; + auto paramType = funcType->getOperand(i); + auto loweredParamType = + sharedContext->lowerType(builder, paramType, typeMapping, nullptr); + SLANG_ASSERT(loweredParamType); + translated = translated || (loweredParamType != paramType); + newOperands.add(loweredParamType); } - - IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + if (!translated && additionalParams.getCount() == 0) + return funcType; + for (Index i = 0; i < additionalParams.getCount(); i++) { - IRInterfaceType* loweredType = nullptr; - if (sharedContext->loweredInterfaceTypes.tryGetValue(interfaceType, loweredType)) - return loweredType; - if (sharedContext->mapLoweredInterfaceToOriginal.containsKey(interfaceType)) - return interfaceType; - // Do not lower intrinsic interfaces. - if (isBuiltin(interfaceType)) - return interfaceType; - // Do not lower COM interfaces. - if (isComInterfaceType(interfaceType)) - return interfaceType; - - List<IRInterfaceRequirementEntry*> newEntries; + newOperands.add(additionalParams[i]); + } + auto newFuncType = builder->getFuncType( + newOperands.getCount() - 1, + (IRType**)(newOperands.begin() + 1), + (IRType*)newOperands[0]); - IRBuilder builder(sharedContext->module); - builder.setInsertBefore(interfaceType); + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, funcType, newFuncType); + return newFuncType; + } - // Translate IRFuncType in interface requirements. - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) - { - // Note: The logic that creates the `IRInterfaceRequirementEntry`s does - // not currently guarantee that the *value* part of each key-value pair - // gets filled in. We thus need to defend against a null `requirementVal` - // here, at least until the underlying issue gets resolved. - // - IRInst* requirementVal = entry->getRequirementVal(); - IRInst* loweredVal = nullptr; - if(!requirementVal) - {} - else if (auto funcType = as<IRFuncType>(requirementVal)) - { - loweredVal = lowerFuncType(&builder, funcType, Dictionary<IRInst*, IRInst*>(), ArrayView<IRInst*>()); - } - else if (auto genericFuncType = as<IRGeneric>(requirementVal)) - { - loweredVal = lowerGenericFuncType( - &builder, - genericFuncType, - cast<IRFuncType>(findGenericReturnVal(genericFuncType))); - } - else if (requirementVal->getOp() == kIROp_AssociatedType) - { - loweredVal = builder.getRTTIHandleType(); - } - else - { - loweredVal = requirementVal; - } - auto newEntry = builder.createInterfaceRequirementEntry(entry->getRequirementKey(), loweredVal); - newEntries.add(newEntry); - } - } - loweredType = builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer()); - loweredType->sourceLoc = interfaceType->sourceLoc; - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, sharedContext->module, - interfaceType, loweredType); - sharedContext->loweredInterfaceTypes.add(interfaceType, loweredType); - sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType; + IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + { + IRInterfaceType* loweredType = nullptr; + if (sharedContext->loweredInterfaceTypes.tryGetValue(interfaceType, loweredType)) return loweredType; - } + if (sharedContext->mapLoweredInterfaceToOriginal.containsKey(interfaceType)) + return interfaceType; + // Do not lower intrinsic interfaces. + if (isBuiltin(interfaceType)) + return interfaceType; + // Do not lower COM interfaces. + if (isComInterfaceType(interfaceType)) + return interfaceType; - bool isTypeKindVal(IRInst* inst) - { - auto type = inst->getDataType(); - if (!type) return false; - return type->getOp() == kIROp_TypeKind; - } + List<IRInterfaceRequirementEntry*> newEntries; + + IRBuilder builder(sharedContext->module); + builder.setInsertBefore(interfaceType); - // Lower items in a witness table. This triggers lowering of generic functions, - // and emission of wrapper functions. - void lowerWitnessTable(IRWitnessTable* witnessTable) + // Translate IRFuncType in interface requirements. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { - auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType())); - IRBuilder builderStorage(sharedContext->module); - auto builder = &builderStorage; - builder->setInsertBefore(witnessTable); - if (interfaceType != witnessTable->getConformanceType()) + if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) { - auto newWitnessTableType = builder->getWitnessTableType(interfaceType); - witnessTable->setFullType(newWitnessTableType); - } - if (isBuiltin(interfaceType)) - return; - for (auto child : witnessTable->getChildren()) - { - auto entry = as<IRWitnessTableEntry>(child); - if (!entry) - continue; - if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal())) + // Note: The logic that creates the `IRInterfaceRequirementEntry`s does + // not currently guarantee that the *value* part of each key-value pair + // gets filled in. We thus need to defend against a null `requirementVal` + // here, at least until the underlying issue gets resolved. + // + IRInst* requirementVal = entry->getRequirementVal(); + IRInst* loweredVal = nullptr; + if (!requirementVal) { - // Lower generic functions in witness table. - if (findGenericReturnVal(genericVal)->getOp() == kIROp_Func) - { - auto loweredFunc = lowerGenericFunction(genericVal); - entry->satisfyingVal.set(loweredFunc); - } } - else if (isTypeKindVal(entry->getSatisfyingVal())) + else if (auto funcType = as<IRFuncType>(requirementVal)) + { + loweredVal = lowerFuncType( + &builder, + funcType, + Dictionary<IRInst*, IRInst*>(), + ArrayView<IRInst*>()); + } + else if (auto genericFuncType = as<IRGeneric>(requirementVal)) { - // Translate a Type value to an RTTI object pointer. - auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); - auto rttiObjectPtr = builder->emitGetAddress( - builder->getRTTIHandleType(), - rttiObject); - entry->satisfyingVal.set(rttiObjectPtr); + loweredVal = lowerGenericFuncType( + &builder, + genericFuncType, + cast<IRFuncType>(findGenericReturnVal(genericFuncType))); } - else if (as<IRWitnessTable>(entry->getSatisfyingVal())) + else if (requirementVal->getOp() == kIROp_AssociatedType) { - // No processing needed here. - // The witness table will be processed from the work list. + loweredVal = builder.getRTTIHandleType(); } + else + { + loweredVal = requirementVal; + } + auto newEntry = + builder.createInterfaceRequirementEntry(entry->getRequirementKey(), loweredVal); + newEntries.add(newEntry); } } + loweredType = + builder.createInterfaceType(newEntries.getCount(), (IRInst**)newEntries.getBuffer()); + loweredType->sourceLoc = interfaceType->sourceLoc; + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren( + &cloneEnv, + sharedContext->module, + interfaceType, + loweredType); + sharedContext->loweredInterfaceTypes.add(interfaceType, loweredType); + sharedContext->mapLoweredInterfaceToOriginal[loweredType] = interfaceType; + return loweredType; + } - void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) - { - // Update the type of lookupInst to the lowered type of the corresponding interface requirement val. + bool isTypeKindVal(IRInst* inst) + { + auto type = inst->getDataType(); + if (!type) + return false; + return type->getOp() == kIROp_TypeKind; + } - // If the requirement is a function, interfaceRequirementVal will be the lowered function type. - // If the requirement is an associatedtype, interfaceRequirementVal will be Ptr<RTTIObject>. - IRInst* interfaceRequirementVal = nullptr; - auto witnessTableType = as<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType()); - if (!witnessTableType) return; - if (witnessTableType->getConformanceType()->findDecoration<IRComInterfaceDecoration>()) - return; - auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); - interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, lookupInst->getRequirementKey()); - IRBuilder builder(lookupInst); - builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); + // Lower items in a witness table. This triggers lowering of generic functions, + // and emission of wrapper functions. + void lowerWitnessTable(IRWitnessTable* witnessTable) + { + auto interfaceType = + maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTable->getConformanceType())); + IRBuilder builderStorage(sharedContext->module); + auto builder = &builderStorage; + builder->setInsertBefore(witnessTable); + if (interfaceType != witnessTable->getConformanceType()) + { + auto newWitnessTableType = builder->getWitnessTableType(interfaceType); + witnessTable->setFullType(newWitnessTableType); } - - void lowerSpecialize(IRSpecialize* specializeInst) + if (isBuiltin(interfaceType)) + return; + for (auto child : witnessTable->getChildren()) { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - IRInst* loweredFunc = nullptr; - auto funcToSpecialize = specializeInst->getBase(); - if (funcToSpecialize->getOp() == kIROp_Generic) + auto entry = as<IRWitnessTableEntry>(child); + if (!entry) + continue; + if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal())) { - loweredFunc = lowerGenericFunction(funcToSpecialize); - if (loweredFunc != funcToSpecialize) + // Lower generic functions in witness table. + if (findGenericReturnVal(genericVal)->getOp() == kIROp_Func) { - IRBuilder builder; - builder.replaceOperand(specializeInst->getOperands(), loweredFunc); + auto loweredFunc = lowerGenericFunction(genericVal); + entry->satisfyingVal.set(loweredFunc); } } - } - - void processInst(IRInst* inst) - { - if (auto specializeInst = as<IRSpecialize>(inst)) + else if (isTypeKindVal(entry->getSatisfyingVal())) { - lowerSpecialize(specializeInst); + // Translate a Type value to an RTTI object pointer. + auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); + auto rttiObjectPtr = + builder->emitGetAddress(builder->getRTTIHandleType(), rttiObject); + entry->satisfyingVal.set(rttiObjectPtr); } - else if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(inst)) + else if (as<IRWitnessTable>(entry->getSatisfyingVal())) { - lowerLookupInterfaceMethodInst(lookupInterfaceMethod); + // No processing needed here. + // The witness table will be processed from the work list. } - else if (auto witnessTable = as<IRWitnessTable>(inst)) - { - lowerWitnessTable(witnessTable); - } - else if (auto interfaceType = as<IRInterfaceType>(inst)) + } + } + + void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) + { + // Update the type of lookupInst to the lowered type of the corresponding interface + // requirement val. + + // If the requirement is a function, interfaceRequirementVal will be the lowered function + // type. If the requirement is an associatedtype, interfaceRequirementVal will be + // Ptr<RTTIObject>. + IRInst* interfaceRequirementVal = nullptr; + auto witnessTableType = + as<IRWitnessTableType>(lookupInst->getWitnessTable()->getDataType()); + if (!witnessTableType) + return; + if (witnessTableType->getConformanceType()->findDecoration<IRComInterfaceDecoration>()) + return; + auto interfaceType = + maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType())); + interfaceRequirementVal = sharedContext->findInterfaceRequirementVal( + interfaceType, + lookupInst->getRequirementKey()); + IRBuilder builder(lookupInst); + builder.replaceOperand(&lookupInst->typeUse, interfaceRequirementVal); + } + + void lowerSpecialize(IRSpecialize* specializeInst) + { + // If we see a call(specialize(gFunc, Targs), args), + // translate it into call(gFunc, args, Targs). + IRInst* loweredFunc = nullptr; + auto funcToSpecialize = specializeInst->getBase(); + if (funcToSpecialize->getOp() == kIROp_Generic) + { + loweredFunc = lowerGenericFunction(funcToSpecialize); + if (loweredFunc != funcToSpecialize) { - maybeLowerInterfaceType(interfaceType); + IRBuilder builder; + builder.replaceOperand(specializeInst->getOperands(), loweredFunc); } } + } - void replaceLoweredInterfaceTypes() + void processInst(IRInst* inst) + { + if (auto specializeInst = as<IRSpecialize>(inst)) { - for (const auto& [loweredKey, loweredValue] : sharedContext->loweredInterfaceTypes) - loweredKey->replaceUsesWith(loweredValue); - sharedContext->mapInterfaceRequirementKeyValue.clear(); + lowerSpecialize(specializeInst); } - - void processModule() + else if (auto lookupInterfaceMethod = as<IRLookupWitnessMethod>(inst)) + { + lowerLookupInterfaceMethodInst(lookupInterfaceMethod); + } + else if (auto witnessTable = as<IRWitnessTable>(inst)) { - sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + lowerWitnessTable(witnessTable); + } + else if (auto interfaceType = as<IRInterfaceType>(inst)) + { + maybeLowerInterfaceType(interfaceType); + } + } - while (sharedContext->workList.getCount() != 0) - { - IRInst* inst = sharedContext->workList.getLast(); + void replaceLoweredInterfaceTypes() + { + for (const auto& [loweredKey, loweredValue] : sharedContext->loweredInterfaceTypes) + loweredKey->replaceUsesWith(loweredValue); + sharedContext->mapInterfaceRequirementKeyValue.clear(); + } - sharedContext->workList.removeLast(); - sharedContext->workListSet.remove(inst); + void processModule() + { + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); - processInst(inst); + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - sharedContext->addToWorkList(child); - } - } + sharedContext->workList.removeLast(); + sharedContext->workListSet.remove(inst); - replaceLoweredInterfaceTypes(); + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } } - }; - void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) - { - GenericFunctionLoweringContext context; - context.sharedContext = sharedContext; - context.processModule(); + + replaceLoweredInterfaceTypes(); } +}; +void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) +{ + GenericFunctionLoweringContext context; + context.sharedContext = sharedContext; + context.processModule(); } - +} // namespace Slang |
