From e9d5ecbf19147af6e1473020b64ced4286b79079 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 15 Jul 2020 11:39:11 -0700 Subject: Refactor lower-generics pass into separate subpasses. (#1442) --- .../slang/slang-ir-generics-lowering-context.cpp | 101 ++++ source/slang/slang-ir-generics-lowering-context.h | 65 ++ source/slang/slang-ir-inst-defs.h | 4 + source/slang/slang-ir-insts.h | 5 + source/slang/slang-ir-lower-generic-call.cpp | 185 ++++++ source/slang/slang-ir-lower-generic-call.h | 13 + source/slang/slang-ir-lower-generic-function.cpp | 359 +++++++++++ source/slang/slang-ir-lower-generic-function.h | 13 + source/slang/slang-ir-lower-generic-var.cpp | 193 ++++++ source/slang/slang-ir-lower-generic-var.h | 12 + source/slang/slang-ir-lower-generics.cpp | 656 +-------------------- source/slang/slang-ir-lower-generics.h | 2 + source/slang/slang-ir.cpp | 4 +- source/slang/slang-ir.h | 10 +- source/slang/slang-lower-to-ir.cpp | 23 +- source/slang/slang.vcxproj | 8 + source/slang/slang.vcxproj.filters | 24 + 17 files changed, 1015 insertions(+), 662 deletions(-) create mode 100644 source/slang/slang-ir-generics-lowering-context.cpp create mode 100644 source/slang/slang-ir-generics-lowering-context.h create mode 100644 source/slang/slang-ir-lower-generic-call.cpp create mode 100644 source/slang/slang-ir-lower-generic-call.h create mode 100644 source/slang/slang-ir-lower-generic-function.cpp create mode 100644 source/slang/slang-ir-lower-generic-function.h create mode 100644 source/slang/slang-ir-lower-generic-var.cpp create mode 100644 source/slang/slang-ir-lower-generic-var.h (limited to 'source') diff --git a/source/slang/slang-ir-generics-lowering-context.cpp b/source/slang/slang-ir-generics-lowering-context.cpp new file mode 100644 index 000000000..6ee0a17f0 --- /dev/null +++ b/source/slang/slang-ir-generics-lowering-context.cpp @@ -0,0 +1,101 @@ +//slang-ir-generics-lowering-context.cpp + +#include "slang-ir-generics-lowering-context.h" + +#include "slang-ir-layout.h" + +namespace Slang +{ + bool isPolymorphicType(IRInst* typeInst) + { + if (as(typeInst) && as(typeInst->getDataType())) + return true; + switch (typeInst->op) + { + case kIROp_ThisType: + case kIROp_AssociatedType: + case kIROp_InterfaceType: + return true; + case kIROp_Specialize: + { + for (UInt i = 0; i < typeInst->getOperandCount(); i++) + { + if (isPolymorphicType(typeInst->getOperand(i))) + return true; + } + return false; + } + default: + break; + } + if (auto ptrType = as(typeInst)) + { + return isPolymorphicType(ptrType->getValueType()); + } + return false; + } + + bool isTypeValue(IRInst* typeInst) + { + if (typeInst) + { + switch (typeInst->op) + { + case kIROp_TypeType: + return true; + case kIROp_lookup_interface_method: + return typeInst->getDataType()->op == kIROp_TypeKind; + default: + return false; + } + } + return false; + } + + IRInst* SharedGenericsLoweringContext::maybeEmitRTTIObject(IRInst* typeInst) + { + IRInst* result = nullptr; + if (mapTypeToRTTIObject.TryGetValue(typeInst, result)) + return result; + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedBuilderStorage; + builder->setInsertBefore(typeInst->next); + + result = builder->emitMakeRTTIObject(typeInst); + + // For now the only type info we encapsualte is type size. + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment((IRType*)typeInst, &sizeAndAlignment); + builder->addRTTITypeSizeDecoration(result, sizeAndAlignment.size); + + // Give a name to the rtti object. + if (auto exportDecoration = typeInst->findDecoration()) + { + String rttiObjName = String(exportDecoration->getMangledName()) + "_rtti"; + builder->addExportDecoration(result, rttiObjName.getUnownedSlice()); + } + mapTypeToRTTIObject[typeInst] = result; + return result; + } + + IRInst* SharedGenericsLoweringContext::findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey) + { + if (auto dict = mapInterfaceRequirementKeyValue.TryGetValue(interfaceType)) + return (*dict)[requirementKey].GetValue(); + _builldInterfaceRequirementMap(interfaceType); + return findInterfaceRequirementVal(interfaceType, requirementKey); + } + + void SharedGenericsLoweringContext::_builldInterfaceRequirementMap(IRInterfaceType* interfaceType) + { + mapInterfaceRequirementKeyValue.Add(interfaceType, + Dictionary()); + auto dict = mapInterfaceRequirementKeyValue.TryGetValue(interfaceType); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = cast(interfaceType->getOperand(i)); + (*dict)[entry->getRequirementKey()] = entry->getRequirementVal(); + } + } +} diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h new file mode 100644 index 000000000..11aef34aa --- /dev/null +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -0,0 +1,65 @@ +// slang-ir-generics-lowering-context.h +#pragma once + +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + struct IRModule; + + struct SharedGenericsLoweringContext + { + // For convenience, we will keep a pointer to the module + // we are processing. + IRModule* module; + + // RTTI objects for each type used to call a generic function. + Dictionary mapTypeToRTTIObject; + + Dictionary loweredGenericFunctions; + HashSet loweredInterfaceTypes; + + // Dictionaries for interface type requirement key-value lookups. + // Used by `findInterfaceRequirementVal`. + Dictionary> mapInterfaceRequirementKeyValue; + + SharedIRBuilder sharedBuilderStorage; + + // We will use a single work list of instructions that need + // to be considered for lowering. + // + List workList; + HashSet workListSet; + + void addToWorkList( + IRInst* inst) + { + for (auto ii = inst->getParent(); ii; ii = ii->getParent()) + { + if (as(ii)) + return; + } + + if (workListSet.Contains(inst)) + return; + + workList.add(inst); + workListSet.Add(inst); + } + + + void _builldInterfaceRequirementMap(IRInterfaceType* interfaceType); + + IRInst* findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey); + + // Emits an IRRTTIObject containing type information for a given type. + IRInst* maybeEmitRTTIObject(IRInst* typeInst); + }; + + bool isPolymorphicType(IRInst* typeInst); + + // Returns true if typeInst represents a type and should be lowered into + // Ptr(RTTIType). + bool isTypeValue(IRInst* typeInst); +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5cbf7f03b..e01121d36 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -25,8 +25,12 @@ INST(Nop, nop, 0, 0) INST_RANGE(BasicType, VoidType, AfterBaseType) INST(StringType, String, 0, 0) + INST(RawPointerType, RawPointerType, 0, 0) INST(RTTIPointerType, RTTIPointerType, 1, 0) + INST(AfterRawPointerTypeBase, AfterRawPointerTypeBase, 0, 0) + INST_RANGE(RawPointerTypeBase, RawPointerType, AfterRawPointerTypeBase) + /* ArrayTypeBase */ INST(ArrayType, Array, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 745cc6b02..ed9993512 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1483,6 +1483,11 @@ struct IRWitnessTable : IRInst return getOperand(0); } + void setConformanceType(IRInst* type) + { + setOperand(0, type); + } + IR_LEAF_ISA(WitnessTable) }; diff --git a/source/slang/slang-ir-lower-generic-call.cpp b/source/slang/slang-ir-lower-generic-call.cpp new file mode 100644 index 000000000..f339d5309 --- /dev/null +++ b/source/slang/slang-ir-lower-generic-call.cpp @@ -0,0 +1,185 @@ +// slang-ir-lower-generic-function.cpp +#include "slang-ir-lower-generic-function.h" +#include "slang-ir-generics-lowering-context.h" + +namespace Slang +{ + struct GenericCallLoweringContext + { + SharedGenericsLoweringContext* sharedContext; + + // Translate `callInst` into a call of `newCallee`, and respect the new `funcType`. + // If `funcType` involve lowered generic parameters or return values, this function + // also translates the argument list to match with that. + // If `newCallee` is a lowered generic function, `specializeInst` contains the type + // arguments used to specialize the callee. + void translateCallInst( + IRCall* callInst, + IRFuncType* funcType, + IRInst* newCallee, + IRSpecialize* specializeInst) + { + List paramTypes; + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(callInst); + + List args; + + // Indicates whether the caller should allocate space for return value. + // If the lowered callee returns void and this call inst has a type that is not void, + // then we are calling a transformed function that expects caller allocated return value + // as the first argument. + bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType && + callInst->getDataType() != funcType->getResultType()); + + IRVar* retVarInst = nullptr; + int startParamIndex = 0; + if (shouldCallerAllocateReturnValue) + { + // Declare a var for the return value. + retVarInst = builder->emitVar(callInst->getFullType()); + args.add(retVarInst); + startParamIndex = 1; + } + + for (UInt i = 0; i < callInst->getArgCount(); i++) + { + auto arg = callInst->getArg(i); + if (as(paramTypes[i] + startParamIndex) && + !as(arg->getDataType()) && + !as(arg->getDataType())) + { + // We are calling a generic function that with an argument of + // some concrete value type. We need to convert this argument to void*. + // We do so by defining a local variable, store the SSA value + // in the variable, and use the pointer of this variable as argument. + auto localVar = builder->emitVar(arg->getDataType()); + builder->emitStore(localVar, arg); + arg = localVar; + } + args.add(arg); + } + if (specializeInst) + { + for (UInt i = 0; i < specializeInst->getArgCount(); i++) + { + auto arg = specializeInst->getArg(i); + // Translate Type arguments into RTTI object. + if (as(arg)) + { + // We are using a simple type to specialize a callee. + // Generate RTTI for this type. + auto rttiObject = sharedContext->maybeEmitRTTIObject(arg); + arg = builder->emitGetAddress( + builder->getPtrType(builder->getRTTIType()), + rttiObject); + } + else if (arg->op == kIROp_Specialize) + { + // The type argument used to specialize a callee is itself a + // specialization of some generic type. + // TODO: generate RTTI object for specializations of generic types. + SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); + } + else if (arg->op == kIROp_RTTIObject) + { + // We are inside a generic function and using a generic parameter + // to specialize another callee. The generic parameter of the caller + // has already been translated into an RTTI object, so we just need + // to pass this object down. + } + args.add(arg); + } + } + auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType(); + auto newCall = builder->emitCallInst(callInstType, newCallee, args); + if (retVarInst) + { + auto loadInst = builder->emitLoad(retVarInst); + callInst->replaceUsesWith(loadInst); + } + else + { + callInst->replaceUsesWith(newCall); + } + callInst->removeAndDeallocate(); + } + + void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst) + { + // If we see a call(specialize(gFunc, Targs), args), + // translate it into call(gFunc, args, Targs). + auto loweredFunc = specializeInst->getBase(); + // All callees should have already been lowered in lower-generic-functions pass. + // For intrinsic generic functions, they are left as is, and we also need to ignore + // them here. + if (loweredFunc->op == kIROp_Generic) + { + // This is an intrinsic function, don't transform. + return; + } + IRFuncType* funcType = cast(loweredFunc->getDataType()); + translateCallInst(callInst, funcType, loweredFunc, specializeInst); + } + + void lowerCall(IRCall* callInst) + { + if (auto specializeInst = as(callInst->getCallee())) + lowerCallToSpecializedFunc(callInst, specializeInst); + } + + void processInst(IRInst* inst) + { + if (auto callInst = as(inst)) + { + lowerCall(callInst); + } + } + + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = sharedContext->module; + sharedBuilder->session = sharedContext->module->session; + + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + + while (sharedContext->workList.getCount() != 0) + { + // We will then iterate until our work list goes dry. + // + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + } + } + }; + + void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext) + { + GenericCallLoweringContext context; + context.sharedContext = sharedContext; + context.processModule(); + } + +} diff --git a/source/slang/slang-ir-lower-generic-call.h b/source/slang/slang-ir-lower-generic-call.h new file mode 100644 index 000000000..6b8d24515 --- /dev/null +++ b/source/slang/slang-ir-lower-generic-call.h @@ -0,0 +1,13 @@ +// slang-ir-lower-generic-call.h +#pragma once + +namespace Slang +{ + struct SharedGenericsLoweringContext; + + /// Lower generic and interface-based code to ordinary types and functions using + /// dynamic dispatch mechanisms. + void lowerGenericCalls( + SharedGenericsLoweringContext* sharedContext); + +} diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp new file mode 100644 index 000000000..1e725cfae --- /dev/null +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -0,0 +1,359 @@ +// slang-ir-lower-generic-function.cpp +#include "slang-ir-lower-generic-function.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + // This is a subpass of generics lowering IR transformation. + // This pass lowers all generic function types and function definitions, including + // the function types used in interface types, to ordinary functions that takes + // raw pointers in place of generic types. + struct GenericFunctionLoweringContext + { + SharedGenericsLoweringContext* sharedContext; + IRInst* lowerParameterType(IRBuilder* builder, IRInst* paramType) + { + if (isTypeValue(paramType)) + { + return builder->getPtrType(builder->getRTTIType()); + } + if (isPolymorphicType(paramType)) + { + return builder->getRawPointerType(); + } + return paramType; + } + + IRInst* lowerGenericFunction(IRInst* genericValue) + { + IRInst* result = nullptr; + if (sharedContext->loweredGenericFunctions.TryGetValue(genericValue, result)) + return result; + auto genericParent = as(genericValue); + SLANG_ASSERT(genericParent); + auto func = as(findGenericReturnVal(genericParent)); + SLANG_ASSERT(func); + if (!func->isDefinition()) + { + sharedContext->loweredGenericFunctions[genericValue] = genericValue; + return genericValue; + } + IRCloneEnv cloneEnv; + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(genericParent); + auto loweredFunc = cast(cloneInstAndOperands(&cloneEnv, &builder, func)); + loweredFunc->setFullType(lowerGenericFuncType(&builder, cast(genericParent->getFullType()))); + List clonedParams; + for (auto genericChild : genericParent->getFirstBlock()->getChildren()) + { + if (genericChild == func) + continue; + if (genericChild->op == kIROp_ReturnVal) + continue; + // Process all generic parameters and local type definitions. + auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); + if (clonedChild->op == kIROp_Param) + { + auto paramType = clonedChild->getFullType(); + auto loweredParamType = lowerParameterType(&builder, paramType); + if (loweredParamType != paramType) + { + clonedChild->setFullType((IRType*)loweredParamType); + } + clonedParams.add(clonedChild); + } + } + cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage, func, loweredFunc); + + // If the function returns a generic typed value, we need to turn it + // into an `out` parameter, since only the caller can allocate space + // for it. + auto oldFuncType = cast(func->getDataType()); + if (isPolymorphicType(oldFuncType->getResultType())) + { + builder.setInsertBefore(loweredFunc->getFirstBlock()->getFirstOrdinaryInst()); + // We defer creation of the returnVal parameter until we see the first + // `return` instruction, because we can only obtain the cloned return type + // of this function by checking the type of the cloned return inst. + IRParam* retValParam = nullptr; + // Translate all return insts to `store`s. + // Those `store`s will be processed and translated into `copy`s when we + // get to process them via workList. + for (auto bb : loweredFunc->getBlocks()) + { + auto retInst = as(bb->getTerminator()); + if (!retInst) + continue; + if (!retValParam) + { + // Now we have the return type, emit the returnVal parameter. + // The type of this parameter is still not translated to RawPointer yet, + // and will be processed together with all the other existing parameters. + retValParam = builder.emitParamAtHead( + builder.getOutType(retInst->getVal()->getDataType())); + } + builder.setInsertBefore(retInst); + builder.emitStore(retValParam, retInst->getVal()); + builder.emitReturn(); + retInst->removeAndDeallocate(); + } + } + + auto block = as(loweredFunc->getFirstChild()); + for (auto param : clonedParams) + { + param->removeFromParent(); + block->addParam(as(param)); + } + // Lower generic typed parameters into RTTIPointers. + auto firstInst = loweredFunc->getFirstOrdinaryInst(); + builder.setInsertBefore(firstInst); + + for (IRInst* param = loweredFunc->getFirstParam(); + param && param->op == kIROp_Param; + param = param->getNextInst()) + { + // Generic typed parameters have a type that is a param itself. + auto paramType = param->getDataType(); + if (auto ptrType = as(paramType)) + paramType = ptrType->getValueType(); + if (isPointerOfType(paramType->getDataType(), kIROp_RTTIType)) + { + // Lower into a function parameter of raw pointer type. + param->setFullType(builder.getRawPointerType()); + auto newType = builder.getRTTIPointerType(paramType); + // Cast the raw pointer parameter into a RTTIPointer with RTTI info from the type parameter. + auto typedPtr = builder.emitBitCast(newType, param); + // Replace all uses of param with typePtr. + param->replaceUsesWith(typedPtr); + typedPtr->setOperand(0, param); + } + } + sharedContext->loweredGenericFunctions[genericValue] = loweredFunc; + sharedContext->addToWorkList(loweredFunc); + return loweredFunc; + } + + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) + { + ShortList genericParamTypes; + for (auto genericParam : genericVal->getParams()) + { + genericParamTypes.add(lowerParameterType(builder, genericParam->getFullType())); + } + + auto innerType = (IRFuncType*)lowerFuncType( + builder, + cast(findGenericReturnVal(genericVal)), + genericParamTypes.getArrayView().arrayView); + + return innerType; + } + + IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, ArrayView additionalParams) + { + List newOperands; + bool translated = false; + for (UInt i = 0; i < funcType->getOperandCount(); i++) + { + auto paramType = funcType->getOperand(i); + auto loweredParamType = lowerParameterType(builder, paramType); + translated = translated || (loweredParamType != paramType); + if (translated && i == 0) + { + // We are translating the return value, this means that + // the return value must be passed explicitly via an `out` parameter. + // In this case, the new return value will be `void`, and the + // translated return value type will be the first parameter type; + newOperands.add(builder->getVoidType()); + } + newOperands.add(loweredParamType); + } + if (!translated && additionalParams.getCount() == 0) + return funcType; + for (Index i = 0; i < additionalParams.getCount(); i++) + { + newOperands.add(additionalParams[i]); + } + auto newFuncType = builder->getFuncType( + newOperands.getCount() - 1, + (IRType**)(newOperands.begin() + 1), + (IRType*)newOperands[0]); + + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage, funcType, newFuncType); + return newFuncType; + } + + IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + { + if (sharedContext->loweredInterfaceTypes.Contains(interfaceType)) + return interfaceType; + + IRBuilder builder; + builder.sharedBuilder = &sharedContext->sharedBuilderStorage; + builder.setInsertBefore(interfaceType); + + // Translate IRFuncType in interface requirements. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + if (auto entry = as(interfaceType->getOperand(i))) + { + if (auto funcType = as(entry->getRequirementVal())) + { + entry->setRequirementVal(lowerFuncType(&builder, funcType, ArrayView())); + } + else if (auto genericFuncType = as(entry->getRequirementVal())) + { + entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType)); + } + } + } + + sharedContext->loweredInterfaceTypes.Add(interfaceType); + return interfaceType; + } + + bool isTypeKindVal(IRInst* inst) + { + auto type = inst->getDataType(); + if (!type) return false; + return type->op == kIROp_TypeKind; + } + + // Lower items in a witness table. This triggers lowering of generic functions, + // and emission of wrapper functions. + void lowerWitnessTable(IRWitnessTable* witnessTable) + { + auto interfaceType = maybeLowerInterfaceType(cast(witnessTable->getConformanceType())); + if (interfaceType != witnessTable->getConformanceType()) + witnessTable->setConformanceType(interfaceType); + for (auto child : witnessTable->getChildren()) + { + auto entry = as(child); + if (!entry) + continue; + if (auto genericVal = as(entry->getSatisfyingVal())) + { + // Lower generic functions in witness table. + if (findGenericReturnVal(genericVal)->op == kIROp_Func) + { + auto loweredFunc = lowerGenericFunction(genericVal); + entry->satisfyingVal.set(loweredFunc); + } + } + else if (isTypeKindVal(entry->getSatisfyingVal())) + { + // Translate a Type value to an RTTI object pointer. + auto rttiObject = sharedContext->maybeEmitRTTIObject(entry->getSatisfyingVal()); + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(witnessTable); + auto rttiObjectPtr = builder->emitGetAddress( + builder->getPtrType(builder->getRTTIType()), + rttiObject); + entry->satisfyingVal.set(rttiObjectPtr); + } + else if (as(entry->getSatisfyingVal())) + { + // No processing needed here. + // The witness table will be processed from the work list. + } + } + } + + void lowerLookupInterfaceMethodInst(IRLookupWitnessMethod* lookupInst) + { + // Update the type of lookupInst to the lowered type of the corresponding interface requirement val. + + // If the requirement is a function, interfaceRequirementVal will be the lowered function type. + IRInst* interfaceRequirementVal = nullptr; + auto witnessTableType = cast(lookupInst->getWitnessTable()->getDataType()); + auto interfaceType = maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); + interfaceRequirementVal = sharedContext->findInterfaceRequirementVal(interfaceType, lookupInst->getRequirementKey()); + lookupInst->setFullType((IRType*)interfaceRequirementVal); + } + + void lowerSpecialize(IRSpecialize* specializeInst) + { + // If we see a call(specialize(gFunc, Targs), args), + // translate it into call(gFunc, args, Targs). + IRInst* loweredFunc = nullptr; + auto funcToSpecialize = specializeInst->getBase(); + if (funcToSpecialize->op == kIROp_Generic) + { + loweredFunc = lowerGenericFunction(funcToSpecialize); + if (loweredFunc != funcToSpecialize) + { + specializeInst->setOperand(0, loweredFunc); + } + } + } + + void processInst(IRInst* inst) + { + if (auto specializeInst = as(inst)) + { + lowerSpecialize(specializeInst); + } + else if (auto lookupInterfaceMethod = as(inst)) + { + lowerLookupInterfaceMethodInst(lookupInterfaceMethod); + } + else if (auto witnessTable = as(inst)) + { + lowerWitnessTable(witnessTable); + } + else if (auto interfaceType = as(inst)) + { + maybeLowerInterfaceType(interfaceType); + } + } + + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = sharedContext->module; + sharedBuilder->session = sharedContext->module->session; + + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + + while (sharedContext->workList.getCount() != 0) + { + // We will then iterate until our work list goes dry. + // + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + } + } + }; + void lowerGenericFunctions(SharedGenericsLoweringContext* sharedContext) + { + GenericFunctionLoweringContext context; + context.sharedContext = sharedContext; + context.processModule(); + } +} + diff --git a/source/slang/slang-ir-lower-generic-function.h b/source/slang/slang-ir-lower-generic-function.h new file mode 100644 index 000000000..c364cfdd0 --- /dev/null +++ b/source/slang/slang-ir-lower-generic-function.h @@ -0,0 +1,13 @@ +// slang-ir-lower-generic-function.h +#pragma once + +namespace Slang +{ + struct SharedGenericsLoweringContext; + + /// Lower generic and interface-based code to ordinary types and functions using + /// dynamic dispatch mechanisms. + void lowerGenericFunctions( + SharedGenericsLoweringContext* sharedContext); + +} diff --git a/source/slang/slang-ir-lower-generic-var.cpp b/source/slang/slang-ir-lower-generic-var.cpp new file mode 100644 index 000000000..0c45e8e38 --- /dev/null +++ b/source/slang/slang-ir-lower-generic-var.cpp @@ -0,0 +1,193 @@ +// slang-ir-lower-generic-function.cpp +#include "slang-ir-lower-generic-function.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + // This is a subpass of generics lowering IR transformation. + // This pass lowers all generic function types and function definitions, including + // the function types used in interface types, to ordinary functions that takes + // raw pointers in place of generic types. + struct GenericVarLoweringContext + { + SharedGenericsLoweringContext* sharedContext; + + void processVarInst(IRInst* varInst) + { + // We process only var declarations that have type + // `Ptr`. + // Due to the processing of `lowerGenericFunction`, + // A local variable of generic type now appears as + // `var X:Ptr>` + // We match this pattern and turn this inst into + // `X:RawPtr = alloca(rtti_extract_size(irParam))` + auto varTypeInst = varInst->getDataType(); + if (!varTypeInst) + return; + auto ptrType = as(varTypeInst); + if (!ptrType) + return; + + // `varTypeParam` represents a pointer to the RTTI object. + auto varTypeParam = ptrType->getValueType(); + if (varTypeParam->op != kIROp_Param) + return; + if (!varTypeParam->getDataType()) + return; + if (varTypeParam->getDataType()->op != kIROp_PtrType) + return; + if (as(varTypeParam->getDataType())->getValueType()->op != kIROp_RTTIType) + return; + + + // A local variable of generic type has a type that is an IRParam. + // This parameter represents the RTTI that tells us the size of the type. + // We need to transform the variable into an `alloca` call to allocate its + // space based on the provided RTTI object. + + // Initialize IRBuilder for emitting instructions. + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(varInst); + + // The result of `alloca` is an RTTIPointer(rttiObject). + auto type = builder->getRTTIPointerType(varTypeParam); + auto newVarInst = builder->emitAlloca(type, varTypeParam); + varInst->replaceUsesWith(newVarInst); + varInst->removeAndDeallocate(); + } + + void processStoreInst(IRStore* storeInst) + { + auto rttiType = as(storeInst->ptr.get()->getDataType()); + if (!rttiType) + return; + // All stores of generic typed variables needs to be translated + // to `IRCopy`s. + auto valPtr = storeInst->val.get(); + if (valPtr->getDataType()->op == kIROp_RTTIPointerType) + { + // If `value` of the store is from another generic variable, it should + // have already been replaced with the pointer to that variable by now. + // So we don't need to do anything here. + } + else + { + // If value does not come from another generic variable, then it must be + // a param. In this case, the parameter is a bitCast of the parameter to an + // RTTIPointer type, so we just use the original parameter pointer and get + // rid of the bitcast. + SLANG_ASSERT(valPtr->op == kIROp_BitCast); + valPtr = valPtr->getOperand(0); + SLANG_ASSERT(valPtr->op == kIROp_Param); + } + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(storeInst); + auto copy = builder->emitCopy( + storeInst->ptr.get(), + valPtr, + rttiType->getRTTIOperand()); + storeInst->replaceUsesWith(copy); + storeInst->removeAndDeallocate(); + } + + void processLoadInst(IRLoad* loadInst) + { + auto rttiType = as(loadInst->ptr.get()->getDataType()); + if (!rttiType) + return; + // There are only two possible uses of a load(genericVar): + // 1. store(x, load(genVar)), which will be handled by processStoreInst. + // 2. call(f, load(genVar)) when calling a generic function or a member function + // via an interface witness lookup. In this case, we need to replace with + // just `genVar`, since that function has already been lowered to take + // raw pointers. + // In both cases, we can simply replace the use side with a pointer instead + // and never need to represent a "value" typed object explicitly. + // However, to preserve the ordering, we must make a copy of every load so + // we don't change the meaning of the code if there are `store`s between the + // `load` and the use site. + + IRBuilder builderStorage; + auto builder = &builderStorage; + builder->sharedBuilder = &sharedContext->sharedBuilderStorage; + builder->setInsertBefore(loadInst); + + // Allocate a copy of the value. + auto allocaInst = builder->emitAlloca(rttiType, rttiType->getRTTIOperand()); + builder->emitCopy(allocaInst, loadInst->ptr.get(), rttiType->getRTTIOperand()); + + // Here we replace all uses of load to just the pointer to the copy. + // After this, all arguments in `call`s will be in its correct form. + // All `store`s will become `store(x, genVar)`, and still need + // to be translated into another `copy`, we leave that step when we get to + // process the `store` inst. + loadInst->replaceUsesWith(allocaInst); + loadInst->removeAndDeallocate(); + } + + void processInst(IRInst* inst) + { + if (inst->op == kIROp_Var || inst->op == kIROp_undefined) + { + processVarInst(inst); + } + else if (inst->op == kIROp_Load) + { + processLoadInst(cast(inst)); + } + else if (inst->op == kIROp_Store) + { + processStoreInst(cast(inst)); + } + } + + void processModule() + { + // We start by initializing our shared IR building state, + // since we will re-use that state for any code we + // generate along the way. + // + SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; + sharedBuilder->module = sharedContext->module; + sharedBuilder->session = sharedContext->module->session; + + sharedContext->addToWorkList(sharedContext->module->getModuleInst()); + + while (sharedContext->workList.getCount() != 0) + { + // We will then iterate until our work list goes dry. + // + while (sharedContext->workList.getCount() != 0) + { + IRInst* inst = sharedContext->workList.getLast(); + + sharedContext->workList.removeLast(); + sharedContext->workListSet.Remove(inst); + + processInst(inst); + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + sharedContext->addToWorkList(child); + } + } + } + } + }; + + void lowerGenericVar(SharedGenericsLoweringContext* sharedContext) + { + GenericVarLoweringContext context; + context.sharedContext = sharedContext; + context.processModule(); + } +} + diff --git a/source/slang/slang-ir-lower-generic-var.h b/source/slang/slang-ir-lower-generic-var.h new file mode 100644 index 000000000..dfc59b24e --- /dev/null +++ b/source/slang/slang-ir-lower-generic-var.h @@ -0,0 +1,12 @@ +// slang-ir-lower-generic-var.h +#pragma once + +namespace Slang +{ + struct SharedGenericsLoweringContext; + + /// Lower load and stores of generic local variables into raw pointer operations. + void lowerGenericVar( + SharedGenericsLoweringContext* sharedContext); + +} diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 260f87bde..7876cc7d8 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -1,658 +1,20 @@ // slang-ir-lower-generics.cpp #include "slang-ir-lower-generics.h" -#include "slang-ir.h" -#include "slang-ir-layout.h" -#include "slang-ir-clone.h" -#include "slang-ir-insts.h" +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir-lower-generic-function.h" +#include "slang-ir-lower-generic-call.h" +#include "slang-ir-lower-generic-var.h" namespace Slang { - struct GenericsLoweringContext; - - struct GenericsLoweringContext - { - // For convenience, we will keep a pointer to the module - // we are processing. - IRModule* module; - - // RTTI objects for each type used to call a generic function. - Dictionary mapTypeToRTTIObject; - - Dictionary loweredGenericFunctions; - HashSet loweredInterfaceTypes; - - SharedIRBuilder sharedBuilderStorage; - - // We will use a single work list of instructions that need - // to be considered for lowering. - // - List workList; - HashSet workListSet; - - void addToWorkList( - IRInst* inst) - { - // We will ignore any code that is nested under a generic, - // because they will be recursively processed through specialized - // call sites. - // - for (auto ii = inst->getParent(); ii; ii = ii->getParent()) - { - if (as(ii)) - return; - } - - if (workListSet.Contains(inst)) - return; - - workList.add(inst); - workListSet.Add(inst); - } - - bool isPolymorphicType(IRInst* typeInst) - { - if (as(typeInst) && as(typeInst->getDataType())) - return true; - switch (typeInst->op) - { - case kIROp_ThisType: - case kIROp_AssociatedType: - case kIROp_InterfaceType: - return true; - case kIROp_Specialize: - { - for (UInt i = 0; i < typeInst->getOperandCount(); i++) - { - if (isPolymorphicType(typeInst->getOperand(i))) - return true; - } - return false; - } - default: - break; - } - if (auto ptrType = as(typeInst)) - { - return isPolymorphicType(ptrType->getValueType()); - } - return false; - } - - IRInst* lowerParameterType(IRBuilder* builder, IRInst* paramType) - { - if (paramType && paramType->op == kIROp_TypeType) - { - return builder->getPtrType(builder->getRTTIType()); - } - if (isPolymorphicType(paramType)) - { - return builder->getRawPointerType(); - } - return paramType; - } - - IRInst* lowerGenericFunction(IRInst* genericValue) - { - IRInst* result = nullptr; - if (loweredGenericFunctions.TryGetValue(genericValue, result)) - return result; - auto genericParent = as(genericValue); - SLANG_ASSERT(genericParent); - auto func = as(findGenericReturnVal(genericParent)); - SLANG_ASSERT(func); - if (!func->isDefinition()) - { - loweredGenericFunctions[genericValue] = genericValue; - return genericValue; - } - IRCloneEnv cloneEnv; - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(genericParent); - auto loweredFunc = cast(cloneInstAndOperands(&cloneEnv, &builder, func)); - loweredFunc->setFullType(lowerGenericFuncType(&builder, cast(genericParent->getFullType()))); - List clonedParams; - for (auto genericChild : genericParent->getFirstBlock()->getChildren()) - { - if (genericChild == func) - continue; - if (genericChild->op == kIROp_ReturnVal) - continue; - // Process all generic parameters and local type definitions. - auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - if (clonedChild->op == kIROp_Param) - { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = lowerParameterType(&builder, paramType); - if (loweredParamType != paramType) - { - clonedChild->setFullType((IRType*)loweredParamType); - } - clonedParams.add(clonedChild); - } - } - cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, func, loweredFunc); - - // If the function returns a generic typed value, we need to turn it - // into an `out` parameter, since only the caller can allocate space - // for it. - auto oldFuncType = cast(func->getDataType()); - if (isPolymorphicType(oldFuncType->getResultType())) - { - builder.setInsertInto(loweredFunc->getFirstBlock()); - // We defer creation of the returnVal parameter until we see the first - // `return` instruction, because we can only obtain the cloned return type - // of this function by checking the type of the cloned return inst. - IRParam* retValParam = nullptr; - // Translate all return insts to `store`s. - // Those `store`s will be processed and translated into `copy`s when we - // get to process them via workList. - for (auto bb : loweredFunc->getBlocks()) - { - auto retInst = as(bb->getTerminator()); - if (!retInst) - continue; - if (!retValParam) - { - // Now we have the return type, emit the returnVal parameter. - // The type of this parameter is still not translated to RawPointer yet, - // and will be processed together with all the other existing parameters. - retValParam = builder.emitParamAtHead( - builder.getOutType(retInst->getVal()->getDataType())); - } - builder.setInsertBefore(retInst); - builder.emitStore(retValParam, retInst->getVal()); - builder.emitReturn(); - retInst->removeAndDeallocate(); - } - } - - auto block = as(loweredFunc->getFirstChild()); - for (auto param : clonedParams) - { - param->removeFromParent(); - block->addParam(as(param)); - } - // Lower generic typed parameters into RTTIPointers. - auto firstInst = loweredFunc->getFirstOrdinaryInst(); - builder.setInsertBefore(firstInst); - - for (IRInst* param = loweredFunc->getFirstParam(); - param && param->op == kIROp_Param; - param = param->getNextInst()) - { - // Generic typed parameters have a type that is a param itself. - auto paramType = param->getDataType(); - if (auto ptrType = as(paramType)) - paramType = ptrType->getValueType(); - if (auto rttiParam = as(paramType)) - { - SLANG_ASSERT(isPointerOfType(rttiParam->getDataType(), kIROp_RTTIType)); - // Lower into a function parameter of raw pointer type. - param->setFullType(builder.getRawPointerType()); - auto newType = builder.getRTTIPointerType(rttiParam); - // Cast the raw pointer parameter into a RTTIPointer with RTTI info from the type parameter. - auto typedPtr = builder.emitBitCast(newType, param); - // Replace all uses of param with typePtr. - param->replaceUsesWith(typedPtr); - typedPtr->setOperand(0, param); - } - } - loweredGenericFunctions[genericValue] = loweredFunc; - addToWorkList(loweredFunc); - return loweredFunc; - } - - IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) - { - ShortList genericParamTypes; - for (auto genericParam : genericVal->getParams()) - { - genericParamTypes.add(lowerParameterType(builder, genericParam->getFullType())); - } - - auto innerType = (IRFuncType*)lowerFuncType( - builder, - cast(findGenericReturnVal(genericVal)), - genericParamTypes.getArrayView().arrayView); - - return innerType; - } - - IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, ArrayView additionalParams) - { - List newOperands; - bool translated = false; - for (UInt i = 0; i < funcType->getOperandCount(); i++) - { - auto paramType = funcType->getOperand(i); - if (paramType->op == kIROp_Specialize) - { - newOperands.add(builder->getRawPointerType()); - translated = true; - } - else - { - auto loweredParamType = lowerParameterType(builder, paramType); - translated = translated || (loweredParamType != paramType); - if (translated && i == 0) - { - // We are translating the return value, this means that - // the return value must be passed explicitly via an `out` parameter. - // In this case, the new return value will be `void`, and the - // translated return value type will be the first parameter type; - newOperands.add(builder->getVoidType()); - } - newOperands.add(loweredParamType); - } - } - if (!translated && additionalParams.getCount() == 0) - return funcType; - for (Index i = 0; i < additionalParams.getCount(); i++) - { - newOperands.add(additionalParams[i]); - } - auto newFuncType = builder->getFuncType( - newOperands.getCount() - 1, - (IRType**)(newOperands.begin() + 1), - (IRType*)newOperands[0]); - - IRCloneEnv cloneEnv; - cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType); - return newFuncType; - } - - IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) - { - if (loweredInterfaceTypes.Contains(interfaceType)) - return interfaceType; - - IRBuilder builder; - builder.sharedBuilder = &sharedBuilderStorage; - builder.setInsertBefore(interfaceType); - - // Translate IRFuncType in interface requirements. - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - if (auto entry = as(interfaceType->getOperand(i))) - { - if (auto funcType = as(entry->getRequirementVal())) - { - entry->setRequirementVal(lowerFuncType(&builder, funcType, ArrayView())); - } - else if (auto genericFuncType = as(entry->getRequirementVal())) - { - entry->setRequirementVal(lowerGenericFuncType(&builder, genericFuncType)); - } - } - } - - loweredInterfaceTypes.Add(interfaceType); - return interfaceType; - } - - void processVarInst(IRInst* varInst) - { - // We process only var declarations that have type - // `Ptr`. - // Due to the processing of `lowerGenericFunction`, - // A local variable of generic type now appears as - // `var X:Ptr>` - // We match this pattern and turn this inst into - // `X:RawPtr = alloca(rtti_extract_size(irParam))` - auto varTypeInst = varInst->getDataType(); - if (!varTypeInst) - return; - auto ptrType = as(varTypeInst); - if (!ptrType) - return; - - // `varTypeParam` represents a pointer to the RTTI object. - auto varTypeParam = ptrType->getValueType(); - if (varTypeParam->op != kIROp_Param) - return; - if (!varTypeParam->getDataType()) - return; - if (varTypeParam->getDataType()->op != kIROp_PtrType) - return; - if (as(varTypeParam->getDataType())->getValueType()->op != kIROp_RTTIType) - return; - - - // A local variable of generic type has a type that is an IRParam. - // This parameter represents the RTTI that tells us the size of the type. - // We need to transform the variable into an `alloca` call to allocate its - // space based on the provided RTTI object. - - // Initialize IRBuilder for emitting instructions. - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(varInst); - - // The result of `alloca` is an RTTIPointer(rttiObject). - auto type = builder->getRTTIPointerType(varTypeParam); - auto newVarInst = builder->emitAlloca(type, varTypeParam); - varInst->replaceUsesWith(newVarInst); - varInst->removeAndDeallocate(); - } - - void processStoreInst(IRStore* storeInst) - { - auto rttiType = as(storeInst->ptr.get()->getDataType()); - if (!rttiType) - return; - // All stores of generic typed variables needs to be translated - // to `IRCopy`s. - auto valPtr = storeInst->val.get(); - if (valPtr->getDataType()->op == kIROp_RTTIPointerType) - { - // If `value` of the store is from another generic variable, it should - // have already been replaced with the pointer to that variable by now. - // So we don't need to do anything here. - } - else - { - // If value does not come from another generic variable, then it must be - // a param. In this case, the parameter is a bitCast of the parameter to an - // RTTIPointer type, so we just use the original parameter pointer and get - // rid of the bitcast. - SLANG_ASSERT(valPtr->op == kIROp_BitCast); - valPtr = valPtr->getOperand(0); - SLANG_ASSERT(valPtr->op == kIROp_Param); - } - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(storeInst); - auto copy = builder->emitCopy( - storeInst->ptr.get(), - valPtr, - rttiType->getRTTIOperand()); - storeInst->replaceUsesWith(copy); - storeInst->removeAndDeallocate(); - } - - void processLoadInst(IRLoad* loadInst) - { - auto rttiType = as(loadInst->ptr.get()->getDataType()); - if (!rttiType) - return; - // There are only two possible uses of a load(genericVar): - // 1. store(x, load(genVar)), which will be handled by processStoreInst. - // 2. call(f, load(genVar)) when calling a generic function or a member function - // via an interface witness lookup. In this case, we need to replace with - // just `genVar`, since that function has already been lowered to take - // raw pointers. - // In both cases, we can simply replace the use side with a pointer instead - // and never need to represent a "value" typed object explicitly. - // However, to preserve the ordering, we must make a copy of every load so - // we don't change the meaning of the code if there are `store`s between the - // `load` and the use site. - - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(loadInst); - - // Allocate a copy of the value. - auto allocaInst = builder->emitAlloca(rttiType, rttiType->getRTTIOperand()); - builder->emitCopy(allocaInst, loadInst->ptr.get(), rttiType->getRTTIOperand()); - - // Here we replace all uses of load to just the pointer to the copy. - // After this, all arguments in `call`s will be in its correct form. - // All `store`s will become `store(x, genVar)`, and still need - // to be translated into another `copy`, we leave that step when we get to - // process the `store` inst. - loadInst->replaceUsesWith(allocaInst); - loadInst->removeAndDeallocate(); - } - - // Emits an IRRTTIObject containing type information for a given type. - IRInst* maybeEmitRTTIObject(IRInst* typeInst) - { - IRInst* result = nullptr; - if (mapTypeToRTTIObject.TryGetValue(typeInst, result)) - return result; - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(typeInst->next); - - result = builder->emitMakeRTTIObject(typeInst); - - // For now the only type info we encapsualte is type size. - IRSizeAndAlignment sizeAndAlignment; - getNaturalSizeAndAlignment((IRType*)typeInst, &sizeAndAlignment); - builder->addRTTITypeSizeDecoration(result, sizeAndAlignment.size); - - // Give a name to the rtti object. - if (auto exportDecoration = typeInst->findDecoration()) - { - String rttiObjName = String(exportDecoration->getMangledName()) + "_rtti"; - builder->addExportDecoration(result, rttiObjName.getUnownedSlice()); - } - mapTypeToRTTIObject[typeInst] = result; - return result; - } - - void lowerCall(IRCall* callInst) - { - // If we see a call(specialize(gFunc, Targs), args), - // translate it into call(gFunc, args, Targs). - auto funcOperand = callInst->getOperand(0); - IRInst* loweredFunc = nullptr; - auto specializeInst = as(funcOperand); - if (!specializeInst) - return; - - auto funcToSpecialize = specializeInst->getOperand(0); - List paramTypes; - IRFuncType* funcType = nullptr; - if (auto interfaceLookup = as(funcToSpecialize)) - { - // The callee is a result of witness table lookup, we will only - // translate the call. - IRInst* callee = nullptr; - auto witnessTableType = cast(interfaceLookup->getWitnessTable()->getDataType()); - auto interfaceType = maybeLowerInterfaceType(cast(witnessTableType->getConformanceType())); - for (UInt i = 0; i < interfaceType->getOperandCount(); i++) - { - auto entry = cast(interfaceType->getOperand(i)); - if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) - { - callee = entry->getRequirementVal(); - break; - } - } - funcType = cast(callee); - loweredFunc = funcToSpecialize; - } - else - { - loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) - { - // This is an intrinsic function, don't transform. - return; - } - funcType = cast(loweredFunc->getDataType()); - } - - for (UInt i = 0; i < funcType->getParamCount(); i++) - paramTypes.add(funcType->getParamType(i)); - - IRBuilder builderStorage; - auto builder = &builderStorage; - builder->sharedBuilder = &sharedBuilderStorage; - builder->setInsertBefore(callInst); - - List args; - - // Indicates whether the caller should allocate space for return value. - // If the lowered callee returns void and this call inst has a type that is not void, - // then we are calling a transformed function that expects caller allocated return value - // as the first argument. - bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType && - callInst->getDataType() != funcType->getResultType()); - - IRVar* retVarInst = nullptr; - int startParamIndex = 0; - if (shouldCallerAllocateReturnValue) - { - // Declare a var for the return value. - retVarInst = builder->emitVar(callInst->getFullType()); - args.add(retVarInst); - startParamIndex = 1; - } - - for (UInt i = 0; i < callInst->getArgCount(); i++) - { - auto arg = callInst->getArg(i); - if (as(paramTypes[i] + startParamIndex) && - !as(arg->getDataType())) - { - // We are calling a generic function that with an argument of - // concrete type. We need to convert this argument to void*. - - // Ideally this should just be a GetElementAddress inst. - // However the current code emitting logic for this instruction - // doesn't truly respect the pointerness and does not produce - // what we needed. For now we use another instruction here - // to keep changes minimal. - arg = builder->emitGetAddress( - builder->getRawPointerType(), - arg); - } - args.add(arg); - } - for (UInt i = 0; i < specializeInst->getArgCount(); i++) - { - auto arg = specializeInst->getArg(i); - // Translate Type arguments into RTTI object. - if (as(arg)) - { - // We are using a simple type to specialize a callee. - // Generate RTTI for this type. - auto rttiObject = maybeEmitRTTIObject(arg); - arg = builder->emitGetAddress( - builder->getPtrType(builder->getRTTIType()), - rttiObject); - } - else if (arg->op == kIROp_Specialize) - { - // The type argument used to specialize a callee is itself a - // specialization of some generic type. - // TODO: generate RTTI object for specializations of generic types. - SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types"); - } - else if (arg->op == kIROp_RTTIObject) - { - // We are inside a generic function and using a generic parameter - // to specialize another callee. The generic parameter of the caller - // has already been translated into an RTTI object, so we just need - // to pass this object down. - } - args.add(arg); - } - auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType(); - auto newCall = builder->emitCallInst(callInstType, loweredFunc, args); - if (retVarInst) - { - auto loadInst = builder->emitLoad(retVarInst); - callInst->replaceUsesWith(loadInst); - addToWorkList(loadInst); - addToWorkList(retVarInst); - } - else - { - callInst->replaceUsesWith(newCall); - } - callInst->removeAndDeallocate(); - } - - void processInst(IRInst* inst) - { - if (auto callInst = as(inst)) - { - lowerCall(callInst); - } - else if (auto witnessTable = as(inst)) - { - // Lower generic functions in witness table. - for (auto child : witnessTable->getChildren()) - { - auto entry = as(child); - if (!entry) - continue; - if (auto genericVal = as(entry->getSatisfyingVal())) - { - if (findGenericReturnVal(genericVal)->op == kIROp_Func) - { - auto loweredFunc = lowerGenericFunction(genericVal); - entry->satisfyingVal.set(loweredFunc); - } - } - } - } - else if (auto interfaceType = as(inst)) - { - maybeLowerInterfaceType(interfaceType); - } - else if (inst->op == kIROp_Var || inst->op == kIROp_undefined) - { - processVarInst(inst); - } - else if (inst->op == kIROp_Load) - { - processLoadInst(cast(inst)); - } - else if (inst->op == kIROp_Store) - { - processStoreInst(cast(inst)); - } - } - - void processModule() - { - // We start by initializing our shared IR building state, - // since we will re-use that state for any code we - // generate along the way. - // - SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; - sharedBuilder->module = module; - sharedBuilder->session = module->session; - - addToWorkList(module->getModuleInst()); - - while (workList.getCount() != 0) - { - // We will then iterate until our work list goes dry. - // - while (workList.getCount() != 0) - { - IRInst* inst = workList.getLast(); - - workList.removeLast(); - workListSet.Remove(inst); - - processInst(inst); - - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - addToWorkList(child); - } - } - } - } - }; - void lowerGenerics( IRModule* module) { - GenericsLoweringContext context; - context.module = module; - context.processModule(); + SharedGenericsLoweringContext sharedContext; + sharedContext.module = module; + lowerGenericFunctions(&sharedContext); + lowerGenericCalls(&sharedContext); + lowerGenericVar(&sharedContext); } } // namespace Slang diff --git a/source/slang/slang-ir-lower-generics.h b/source/slang/slang-ir-lower-generics.h index ed9e58c8f..664dd11ff 100644 --- a/source/slang/slang-ir-lower-generics.h +++ b/source/slang/slang-ir-lower-generics.h @@ -1,6 +1,8 @@ // slang-ir-lower-generics.h #pragma once +#include "slang-ir.h" + namespace Slang { struct IRModule; diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index efb8b3b27..fc4b71138 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2919,7 +2919,7 @@ namespace Slang IRStructType* structType = createInst( this, kIROp_StructType, - nullptr); + getTypeKind()); addGlobalValue(this, structType); return structType; } @@ -2929,7 +2929,7 @@ namespace Slang IRInterfaceType* interfaceType = createInst( this, kIROp_InterfaceType, - nullptr, + getTypeKind(), operandCount, operands); addGlobalValue(this, interfaceType); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 6ca5d5058..481e4f601 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1110,8 +1110,14 @@ SIMPLE_IR_TYPE(OutType, OutTypeBase) SIMPLE_IR_TYPE(InOutType, OutTypeBase) SIMPLE_IR_TYPE(ExistentialBoxType, PtrTypeBase) + /// The base class of RawPointerType and RTTIPointerType. +struct IRRawPointerTypeBase : IRType +{ + IR_PARENT_ISA(RawPointerTypeBase); +}; + /// Represents a pointer to an object of unknown type. -struct IRRawPointerType : IRType +struct IRRawPointerType : IRRawPointerTypeBase { IR_LEAF_ISA(RawPointerType) }; @@ -1119,7 +1125,7 @@ struct IRRawPointerType : IRType /// Represents a pointer to an object whose type is determined at runtime, /// with type information available through `rttiOperand`. /// -struct IRRTTIPointerType : IRRawPointerType +struct IRRTTIPointerType : IRRawPointerTypeBase { IRInst* getRTTIOperand() { return getOperand(0); } IR_LEAF_ISA(RTTIPointerType) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index cfed09dd0..175f9264f 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5863,21 +5863,22 @@ struct DeclLoweringVisitor : DeclVisitor IRInst* requirementVal = lowerDecl(subContext, requirementDecl).val; if (requirementVal) { - auto reqType = requirementVal->getFullType(); - entry->setRequirementVal(reqType); - if (!requirementVal->hasUses()) + switch (requirementVal->op) + { + case kIROp_Func: + case kIROp_Generic: { // Remove lowered `IRFunc`s since we only care about // function types. - switch (requirementVal->op) - { - case kIROp_Func: - case kIROp_Generic: + auto reqType = requirementVal->getFullType(); + entry->setRequirementVal(reqType); + if (!requirementVal->hasUses()) requirementVal->removeAndDeallocate(); - break; - default: - break; - } + break; + } + default: + entry->setRequirementVal(requirementVal); + break; } } irInterface->setOperand(entryIndex, entry); diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index d23dccabf..53d4681b3 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -231,6 +231,7 @@ + @@ -238,6 +239,9 @@ + + + @@ -326,12 +330,16 @@ + + + + diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index f32e911b5..5ae8d77ff 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -144,6 +144,9 @@ Header Files + + Header Files + Header Files @@ -165,6 +168,15 @@ Header Files + + Header Files + + + Header Files + + + Header Files + Header Files @@ -425,6 +437,9 @@ Source Files + + Source Files + Source Files @@ -443,6 +458,15 @@ Source Files + + Source Files + + + Source Files + + + Source Files + Source Files -- cgit v1.2.3