diff options
| author | Yong He <yonghe@outlook.com> | 2020-06-24 13:16:11 -0700 |
|---|---|---|
| committer | Yong He <yonghe@outlook.com> | 2020-06-24 18:10:15 -0700 |
| commit | 0ca75fe002f346f6ab9b77f40c0576d2905560f1 (patch) | |
| tree | ed8a3af372900923e59f0d6da629c2d0969ee7fd /source | |
| parent | 3fe4f5398d524333e955ecb91be5646e86f3b2da (diff) | |
Dynamic dispatch for generic interface requirements.
-Lower interfaces into actual `IRInterfaceType` insts.
-Lower `DeclRef<AssocTypeDecl>` into `IRAssociatedType`
-Generate proper IRType for generic functions.
-Add a test case exercising dynamic dispatching a generic static function through an associated type.
-Bug fixes for the test case.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.cpp | 64 | ||||
| -rw-r--r-- | source/slang/slang-emit-cpp.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-generics.cpp | 181 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 32 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 336 |
10 files changed, 529 insertions, 170 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 3438fd3f4..2605723c7 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -236,9 +236,9 @@ List<IRWitnessTableEntry*> CLikeSourceEmitter::getSortedWitnessTableEntries(IRWi // Get a sorted list of entries using RequirementKeys defined in `interfaceType`. for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { - auto reqKey = cast<IRStructKey>(interfaceType->getOperand(i)); + auto reqEntry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); IRWitnessTableEntry* entry = nullptr; - if (witnessTableEntryDictionary.TryGetValue(reqKey, entry)) + if (witnessTableEntryDictionary.TryGetValue(reqEntry->getRequirementKey(), entry)) { sortedWitnessTableEntries.add(entry); } @@ -1962,6 +1962,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO are hashed with 'getStringHash' */ break; + case kIROp_undefined: + m_writer->emit(getName(inst)); + break; + case kIROp_IntLit: case kIROp_FloatLit: case kIROp_BoolLit: @@ -3554,6 +3558,11 @@ void CLikeSourceEmitter::emitGlobalInst(IRInst* inst) are hashed with 'getStringHash' */ break; + case kIROp_InterfaceRequirementEntry: + // Don't emit anything for interface requirement at global level. + // They are handled in `emitInterface`. + break; + case kIROp_Func: emitFunc((IRFunc*) inst); break; @@ -3610,6 +3619,10 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I ensureInstOperand(ctx, inst->getFullType()); UInt operandCount = inst->operandCount; + auto requiredLevel = EmitAction::Definition; + if (inst->op == kIROp_InterfaceType) + requiredLevel = EmitAction::ForwardDeclaration; + for(UInt ii = 0; ii < operandCount; ++ii) { // TODO: there are some special cases we can add here, @@ -3620,8 +3633,8 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I // only need the type they point to to be forward-declared. // Similarly, a `call` instruction only needs the callee // to be forward-declared, etc. - - ensureInstOperand(ctx, inst->getOperand(ii)); + + ensureInstOperand(ctx, inst->getOperand(ii), requiredLevel); } for(auto child : inst->getDecorationsAndChildren()) diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 4a59f4cf9..eeace4aa7 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -390,12 +390,27 @@ static UnownedStringSlice _getResourceTypePrefix(IROp op) } } +static bool isVoidPtrType(IRType* type) +{ + auto ptrType = as<IRPtrType>(type); + if (!ptrType) return false; + return ptrType->getValueType()->op == kIROp_VoidType; +} + SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) { switch (type->op) { case kIROp_PtrType: { + if (isVoidPtrType(type)) + { + // A `void*` type will always emit as `void*`. + // `void*` types are generated as a result of generics lowering + // for dynamic dispatch. + out << "void*"; + return SLANG_OK; + } auto ptrType = static_cast<IRPtrType*>(type); SLANG_RETURN_ON_FAIL(calcTypeName(ptrType->getValueType(), target, out)); // TODO(JS): It seems although it says it is a pointer, it can actually be output as a reference @@ -494,7 +509,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S // struct of function pointers corresponding to the interface type. auto witnessTableType = static_cast<IRWitnessTableType*>(type); auto baseType = cast<IRType>(witnessTableType->getOperand(0)); - emitType(baseType); + SLANG_RETURN_ON_FAIL(calcTypeName(baseType, target, out)); out << "*"; return SLANG_OK; } @@ -1591,8 +1606,7 @@ void CPPSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable) { auto interfaceType = cast<IRInterfaceType>(witnessTable->getOperand(0)); auto witnessTableItems = witnessTable->getChildren(); - List<IRWitnessTableEntry*> sortedWitnessTableEntries = getSortedWitnessTableEntries(witnessTable); - _maybeEmitWitnessTableTypeDefinition(interfaceType, sortedWitnessTableEntries); + _maybeEmitWitnessTableTypeDefinition(interfaceType); // Define a global variable for the witness table. m_writer->emit("extern "); @@ -1747,17 +1761,16 @@ void CPPSourceEmitter::emitInterface(IRInterfaceType* interfaceType) /// acoording to the order defined by `interfaceType`. /// void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( - IRInterfaceType* interfaceType, - const List<IRWitnessTableEntry*>& sortedWitnessTableEntries) + IRInterfaceType* interfaceType) { m_writer->emit("struct "); emitSimpleType(interfaceType); m_writer->emit("\n{\n"); m_writer->indent(); - for (Index i = 0; i < sortedWitnessTableEntries.getCount(); i++) + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) { - auto entry = sortedWitnessTableEntries[i]; - if (auto funcVal = as<IRFunc>(entry->satisfyingVal.get())) + auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (auto funcVal = as<IRFuncType>(entry->getRequirementVal())) { emitType(funcVal->getResultType()); m_writer->emit(" (KernelContext::*"); @@ -1765,33 +1778,35 @@ void CPPSourceEmitter::_maybeEmitWitnessTableTypeDefinition( m_writer->emit(")"); m_writer->emit("("); bool isFirstParam = true; - for (auto param : funcVal->getParams()) + for (UInt p = 0; p < funcVal->getParamCount(); p++) { + auto paramType = funcVal->getParamType(p); + // Ingore TypeType-typed parameters for now. + if (as<IRTypeType>(paramType)) + continue; + if (!isFirstParam) m_writer->emit(", "); else isFirstParam = false; - if (param->findDecoration<IRThisPointerDecoration>()) + auto thisDecor = funcVal->findDecoration<IRThisPointerDecoration>(); + if (thisDecor && cast<IRIntLit>(thisDecor->getOperand(0))->value.intVal == (IRIntegerValue)p) { - m_writer->emit("void* "); - m_writer->emit(getName(param)); + m_writer->emit("void* param"); + m_writer->emit(p); continue; } - emitSimpleFuncParamImpl(param); + emitParamType(paramType, String("param") + String(p)); } m_writer->emit(");\n"); } - else if (auto witnessTableVal = as<IRWitnessTable>(entry->getSatisfyingVal())) + else if (auto constraintInterfaceType = as<IRInterfaceType>(entry->getRequirementVal())) { - emitType(as<IRType>(witnessTableVal->getOperand(0))); + emitType(constraintInterfaceType); m_writer->emit("* "); m_writer->emit(getName(entry->requirementKey.get())); m_writer->emit(";\n"); } - else - { - // TODO: handle other witness table entry types. - } } m_writer->dedent(); m_writer->emit("};\n"); @@ -1990,13 +2005,6 @@ void CPPSourceEmitter::emitSimpleValueImpl(IRInst* inst) } } -static bool isVoidPtrType(IRType* type) -{ - auto ptrType = as<IRPtrType>(type); - if (!ptrType) return false; - return ptrType->getValueType()->op == kIROp_VoidType; -} - void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) { // Polymorphic types are already translated to void* type in @@ -2004,9 +2012,7 @@ void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) // emit "void&" instead of "void*" for pointer types. // In the future, we will handle pointer types more properly, // and this override logic will not be necessary. - // For now we special-case this scenario. - if (param->findDecoration<IRPolymorphicDecoration>() && - isVoidPtrType(param->getDataType())) + if (isVoidPtrType(param->getDataType())) { m_writer->emit("void* "); m_writer->emit(getName(param)); diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 47ba03d70..6f91444a3 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -89,7 +89,7 @@ protected: virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder); // Emits a struct of function pointers defined in `interfaceType`. - void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType, const List<IRWitnessTableEntry*>& sortedWitnessTableEntries); + void _maybeEmitWitnessTableTypeDefinition(IRInterfaceType* interfaceType); void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp); void _emitForwardDeclarations(const List<EmitAction>& actions); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 58ff1a79f..e9bc23993 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -164,7 +164,8 @@ INST(Nop, nop, 0, 0) // `field` instructions. // INST(StructType, struct, 0, PARENT) -INST(InterfaceType, interface, 0, PARENT) +INST(InterfaceType, interface, 0, 0) +INST(AssociatedType, associated_type, 0, 0) // A TypeType-typed IRValue represents a IRType. // It is used to represent a type parameter/argument in a generics. @@ -223,6 +224,7 @@ INST(Call, call, 1, 0) INST(WitnessTableEntry, witness_table_entry, 2, 0) +INST(InterfaceRequirementEntry, interface_req_entry, 2, 0) INST(Param, param, 0, 0) INST(StructField, field, 2, 0) @@ -507,14 +509,12 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BindExistentialSlotsDecoration, bindExistentialSlots, 0, 0) - /// A `[polymorphic]` decoration marks a function parameter that should translate to an abstract type - /// e.g. (void*) that are casted to actual type before use. For example, a parameter of generic type - /// is marked `[polymorphic]`, so that the code gen logic can emit it as a `void*` parameter, - /// allowing the function to be used at sites that are agnostic of the actual object type. - INST(PolymorphicDecoration, polymorphic, 0, 0) /// A `[this_ptr]` decoration marks a function parameter that serves as `this` pointer. - INST(ThisPointerDecoration, this_ptr, 0, 0) + /// `[this_ptr]` decoration is also used to mark an `IRFunc` as a non-static function. + /// The argument is an integer value that represents the index of the `this` parameter, + /// which is always 0. + INST(ThisPointerDecoration, this_ptr, 1, 0) /// A `[format(f)]` decoration specifies that the format of an image should be `f` diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b13d52981..fb0cc57c7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -165,8 +165,6 @@ IR_SIMPLE_DECORATION(VulkanCallablePayloadDecoration) /// vulkan hit attributes, and should have a location assigned /// to it. IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration) - -IR_SIMPLE_DECORATION(PolymorphicDecoration) IR_SIMPLE_DECORATION(ThisPointerDecoration) @@ -410,9 +408,13 @@ struct IRLookupWitnessMethod : IRInst { IRUse witnessTable; IRUse requirementKey; + IRUse interfaceType; IRInst* getWitnessTable() { return witnessTable.get(); } IRInst* getRequirementKey() { return requirementKey.get(); } + IRInst* getInterfaceType() { return interfaceType.get(); } + + IR_LEAF_ISA(lookup_interface_method) }; struct IRLookupWitnessTable : IRInst @@ -1675,7 +1677,8 @@ struct IRBuilder IRInst* emitLookupInterfaceMethodInst( IRType* type, IRInst* witnessTableVal, - IRInst* interfaceMethodVal); + IRInst* interfaceMethodVal, + IRType* interfaceType); IRInst* emitCallInst( IRType* type, @@ -1809,9 +1812,16 @@ struct IRBuilder IRInst* requirementKey, IRInst* satisfyingVal); + IRInterfaceRequirementEntry* createInterfaceRequirementEntry( + IRInst* requirementKey, + IRInst* requirementVal); + // Create an initially empty `struct` type. IRStructType* createStructType(); + // Create an IRType representing an `associatedtype` decl. + IRAssociatedType* createAssociatedType(); + // Create an empty `interface` type. IRInterfaceType* createInterfaceType(UInt operandCount, IRInst* const* operands); @@ -2160,14 +2170,9 @@ struct IRBuilder addDecoration(value, kIROp_LoopControlDecoration, getIntValue(getIntType(), IRIntegerValue(mode))); } - void addPolymorphicDecoration(IRInst* value) - { - addDecoration(value, kIROp_PolymorphicDecoration); - } - - void addThisPointerDecoration(IRInst* value) + void addThisPointerDecoration(IRInst* value, int paramIndex) { - addDecoration(value, kIROp_ThisPointerDecoration); + addDecoration(value, kIROp_ThisPointerDecoration, getIntValue(getIntType(), paramIndex)); } void addSemanticDecoration(IRInst* value, UnownedStringSlice const& text, int index = 0) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 3f51aa876..4e6ad74a4 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -228,6 +228,7 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) case kIROp_StructKey: case kIROp_GlobalGenericParam: case kIROp_WitnessTable: + case kIROp_InterfaceType: case kIROp_TaggedUnionType: return cloneGlobalValue(this, originalValue); @@ -607,8 +608,7 @@ IRInterfaceType* cloneInterfaceTypeImpl( auto clonedInterface = builder->createInterfaceType(originalInterface->getOperandCount(), nullptr); for (UInt i = 0; i < originalInterface->getOperandCount(); i++) { - auto clonedKey = findClonedValue(context, originalInterface->getOperand(i)); - SLANG_ASSERT(clonedKey); + auto clonedKey = cloneValue(context, originalInterface->getOperand(i)); clonedInterface->setOperand(i, clonedKey); } cloneSimpleGlobalValueImpl(context, originalInterface, originalValues, clonedInterface); @@ -628,6 +628,7 @@ void cloneGlobalValueWithCodeCommon( cloneDecorations(context, clonedValue, originalValue); cloneExtraDecorations(context, clonedValue, originalValues); + clonedValue->setFullType((IRType*)cloneValue(context, originalValue->getFullType())); // We will walk through the blocks of the function, and clone each of them. // diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index f6340a633..fe0fa3364 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -16,6 +16,7 @@ namespace Slang IRModule* module; Dictionary<IRInst*, IRInst*> loweredGenericFunctions; + HashSet<IRInterfaceType*> loweredInterfaceTypes; SharedIRBuilder sharedBuilderStorage; @@ -45,6 +46,20 @@ namespace Slang workListSet.Add(inst); } + bool isPolymorphicType(IRInst* typeInst) + { + if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getFullType())) + return true; + switch (typeInst->op) + { + case kIROp_AssociatedType: + case kIROp_InterfaceType: + return true; + default: + return false; + } + } + IRInst* lowerGenericFunction(IRInst* genericValue) { IRInst* result = nullptr; @@ -64,6 +79,7 @@ namespace Slang builder.sharedBuilder = &sharedBuilderStorage; builder.setInsertBefore(genericParent); auto loweredFunc = cloneInstAndOperands(&cloneEnv, &builder, func); + loweredFunc->setFullType(lowerGenericFuncType(&builder, cast<IRGeneric>(genericParent->typeUse.get()))); List<IRInst*> clonedParams; for (auto genericParam : genericParent->getParams()) { @@ -82,7 +98,7 @@ namespace Slang // Turn generic parameters into void pointers. for (auto param : cast<IRFunc>(loweredFunc)->getParams()) { - if (param->findDecoration<IRPolymorphicDecoration>()) + if (isPolymorphicType(param->getFullType())) { param->setFullType(builder.getPtrType(builder.getVoidType())); } @@ -91,6 +107,106 @@ namespace Slang return loweredFunc; } + IRType* lowerGenericFuncType(IRBuilder* builder, IRGeneric* genericVal) + { + List<IRInst*> genericParamTypes; + for (auto genericParam : genericVal->getParams()) + { + if (isPolymorphicType(genericParam->getFullType())) + { + genericParamTypes.add(builder->getPtrType(builder->getVoidType())); + } + else + { + genericParamTypes.add(genericParam->getFullType()); + } + } + + auto innerType = (IRFuncType*)lowerFuncType( + builder, + cast<IRFuncType>(findGenericReturnVal(genericVal)), + genericParamTypes.getCount()); + + for (int i = 0; i < genericParamTypes.getCount(); i++) + { + innerType->setOperand( + innerType->getOperandCount() - genericParamTypes.getCount() + i, + genericParamTypes[i]); + } + + return innerType; + } + + IRType* lowerFuncType(IRBuilder* builder, IRFuncType* funcType, UInt additionalParamCount = 0) + { + List<IRInst*> newOperands; + bool translated = false; + for (UInt i = 0; i < funcType->getOperandCount(); i++) + { + auto paramType = funcType->getOperand(i); + if (isPolymorphicType(paramType)) + { + newOperands.add(builder->getPtrType(builder->getVoidType())); + translated = true; + } + else if (paramType->op == kIROp_Specialize) + { + // TODO: handle static specialized type here. + // For now treat all specialized types as dynamic. + // In the future, we need to turn things like Array<IDynamic> into Array<void*>. + newOperands.add(builder->getPtrType(builder->getVoidType())); + translated = true; + } + else + { + newOperands.add(paramType); + } + } + if (!translated) + return funcType; + for (UInt i = 0; i < additionalParamCount; i++) + { + newOperands.add(nullptr); + } + auto newFuncType = builder->getFuncType( + newOperands.getCount() - 1, + (IRType**)(newOperands.begin() + 1), + (IRType*)newOperands[0]); + + IRCloneEnv cloneEnv; + cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, funcType, newFuncType); + return newFuncType; + } + + IRInterfaceType* maybeLowerInterfaceType(IRInterfaceType* interfaceType) + { + if (loweredInterfaceTypes.Contains(interfaceType)) + return interfaceType; + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(interfaceType); + + // Translate IRFuncType in interface requirements. + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + if (auto entry = as<IRInterfaceRequirementEntry>(interfaceType->getOperand(i))) + { + if (auto funcType = as<IRFuncType>(entry->getRequirementVal())) + { + entry->requirementVal.set(lowerFuncType(&builder, funcType)); + } + else if (auto genericFuncType = as<IRGeneric>(entry->getRequirementVal())) + { + entry->requirementVal.set(lowerGenericFuncType(&builder, genericFuncType)); + } + } + } + + loweredInterfaceTypes.Add(interfaceType); + return interfaceType; + } + void processInst(IRInst* inst) { if (auto callInst = as<IRCall>(inst)) @@ -98,25 +214,53 @@ namespace Slang // If we see a call(specialize(gFunc, Targs), args), // translate it into call(gFunc, args, Targs). auto funcOperand = callInst->getOperand(0); + IRInst* loweredFunc = nullptr; if (auto specializeInst = as<IRSpecialize>(funcOperand)) { - auto loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); - if (loweredFunc == specializeInst->getOperand(0)) + auto funcToSpecialize = specializeInst->getOperand(0); + List<IRType*> paramTypes; + if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize)) { - // This is an intrinsic function, don't transform. - return; + // The callee is a result of witness table lookup, we will only + // translate the call. + IRInst* callee = nullptr; + auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(interfaceLookup->getInterfaceType())); + for (UInt i = 0; i < interfaceType->getOperandCount(); i++) + { + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i)); + if (entry->getRequirementKey() == interfaceLookup->getOperand(1)) + { + callee = entry->getRequirementVal(); + break; + } + } + auto funcType = cast<IRFuncType>(callee); + for (UInt i = 0; i < funcType->getParamCount(); i++) + paramTypes.add(funcType->getParamType(i)); + loweredFunc = funcToSpecialize; + } + else + { + loweredFunc = lowerGenericFunction(specializeInst->getOperand(0)); + if (loweredFunc == specializeInst->getOperand(0)) + { + // This is an intrinsic function, don't transform. + return; + } + for (auto param : as<IRFunc>(loweredFunc)->getParams()) + paramTypes.add(param->getDataType()); } + IRBuilder builderStorage; auto builder = &builderStorage; builder->sharedBuilder = &sharedBuilderStorage; builder->setInsertBefore(inst); List<IRInst*> args; - auto pp = as<IRFunc>(loweredFunc)->getParams().begin(); auto voidPtrType = builder->getPtrType(builder->getVoidType()); for (UInt i = 0; i < callInst->getArgCount(); i++) { auto arg = callInst->getArg(i); - if ((*pp)->getDataType() == voidPtrType && + if (paramTypes[i] == voidPtrType && arg->getDataType() != voidPtrType) { // We are calling a generic function that with an argument of @@ -132,7 +276,6 @@ namespace Slang arg); } args.add(arg); - ++pp; } for (UInt i = 0; i < specializeInst->getArgCount(); i++) args.add(specializeInst->getArg(i)); @@ -141,6 +284,28 @@ namespace Slang callInst->removeAndDeallocate(); } } + else if (auto witnessTable = as<IRWitnessTable>(inst)) + { + // Lower generic functions in witness table. + for (auto child : witnessTable->getChildren()) + { + auto entry = as<IRWitnessTableEntry>(child); + if (!entry) + continue; + if (auto genericVal = as<IRGeneric>(entry->getSatisfyingVal())) + { + if (findGenericReturnVal(genericVal)->op == kIROp_Func) + { + auto loweredFunc = lowerGenericFunction(genericVal); + entry->satisfyingVal.set(loweredFunc); + } + } + } + } + else if (auto interfaceType = as<IRInterfaceType>(inst)) + { + maybeLowerInterfaceType(interfaceType); + } } void processModule() diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 77011b569..891f4b3e0 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2508,14 +2508,16 @@ namespace Slang IRInst* IRBuilder::emitLookupInterfaceMethodInst( IRType* type, IRInst* witnessTableVal, - IRInst* interfaceMethodVal) + IRInst* interfaceMethodVal, + IRType* interfaceType) { + IRInst* args[3] = { witnessTableVal , interfaceMethodVal, interfaceType }; auto inst = createInst<IRLookupWitnessMethod>( this, kIROp_lookup_interface_method, type, - witnessTableVal, - interfaceMethodVal); + 3, + args); addInst(inst); return inst; @@ -2811,6 +2813,20 @@ namespace Slang return entry; } + IRInterfaceRequirementEntry* IRBuilder::createInterfaceRequirementEntry( + IRInst* requirementKey, + IRInst* requirementVal) + { + IRInterfaceRequirementEntry* entry = createInst<IRInterfaceRequirementEntry>( + this, + kIROp_InterfaceRequirementEntry, + nullptr, + requirementKey, + requirementVal); + addGlobalValue(this, entry); + return entry; + } + IRStructType* IRBuilder::createStructType() { IRStructType* structType = createInst<IRStructType>( @@ -2821,6 +2837,16 @@ namespace Slang return structType; } + IRAssociatedType* IRBuilder::createAssociatedType() + { + IRAssociatedType* associatedType = createInst<IRAssociatedType>( + this, + kIROp_AssociatedType, + nullptr); + addGlobalValue(this, associatedType); + return associatedType; + } + IRInterfaceType* IRBuilder::createInterfaceType(UInt operandCount, IRInst* const* operands) { IRInterfaceType* interfaceType = createInst<IRInterfaceType>( diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 3c9a15650..b41c94e7f 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1190,6 +1190,25 @@ struct IRStructType : IRType IR_LEAF_ISA(StructType) }; +struct IRAssociatedType : IRType +{ + IR_LEAF_ISA(AssociatedType) +}; + +struct IRInterfaceRequirementEntry : IRInst +{ + // The AST-level requirement + IRUse requirementKey; + + // The IR-level value that represents the declaration of the requirement + IRUse requirementVal; + + IRInst* getRequirementKey() { return getOperand(0); } + IRInst* getRequirementVal() { return getOperand(1); } + + IR_LEAF_ISA(InterfaceRequirementEntry); +}; + struct IRInterfaceType : IRType { IR_LEAF_ISA(InterfaceType) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ea04ea85c..ff356fd48 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -401,6 +401,18 @@ struct IRGenContext { return shared->m_mainModuleDecl; } + + LoweredValInfo* findLoweredDecl(Decl* decl) + { + IRGenEnv* envToFindIn = env; + while (envToFindIn) + { + if (auto rs = envToFindIn->mapDeclToValue.TryGetValue(decl)) + return rs; + envToFindIn = envToFindIn->outer; + } + return nullptr; + } }; void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value) @@ -986,6 +998,8 @@ IRStructKey* getInterfaceRequirementKey( IRGenContext* context, Decl* requirementDecl) { + if (auto genericDecl = as<GenericDecl>(requirementDecl)) + return getInterfaceRequirementKey(context, genericDecl->inner); IRStructKey* requirementKey = nullptr; if(context->shared->interfaceRequirementKeys.TryGetValue(requirementDecl, requirementKey)) { @@ -1059,7 +1073,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( nullptr, baseWitnessTable, - requirementKey)); + requirementKey, + lowerType(context, val->subToMid->sup))); } LoweredValInfo visitTaggedUnionSubtypeWitness( @@ -1240,7 +1255,8 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto caseFunc = subBuilder->emitLookupInterfaceMethodInst( caseFuncType, caseWitnessTable, - irReqKey); + irReqKey, + irWitnessTableBaseType); // We are going to emit a `call` to the satisfying value // for the case type, so we will collect the arguments for that call. @@ -4520,7 +4536,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable) { auto subBuilder = subContext->irBuilder; - + for(auto entry : astWitnessTable->requirementDictionary) { auto requiredMemberDecl = entry.Key; @@ -5275,11 +5291,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // a witness table for the interface type's conformance // to its own interface. // - List<IRStructKey*> requirementKeys; + NestedContext nestedContext(this); + auto subBuilder = nestedContext.getBuilder(); + auto subContext = nestedContext.getContext(); + List<IRInterfaceRequirementEntry*> requirementEntries; + for (auto requirementDecl : decl->members) { - requirementKeys.add(getInterfaceRequirementKey(requirementDecl)); - + auto key = getInterfaceRequirementKey(requirementDecl); + auto entry = subBuilder->createInterfaceRequirementEntry(key, nullptr); + requirementEntries.add(entry); // As a special case, any type constraints placed // on an associated type will *also* need to be turned // into requirement keys for this interface. @@ -5287,22 +5308,20 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>()) { - requirementKeys.add(getInterfaceRequirementKey(constraintDecl)); + auto constraintKey = getInterfaceRequirementKey(constraintDecl); + requirementEntries.add( + subBuilder->createInterfaceRequirementEntry(constraintKey, + lowerType(context, constraintDecl->getSup().type))); } } } - - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContext(); - // Emit any generics that should wrap the actual type. emitOuterGenerics(subContext, decl, decl); IRInterfaceType* irInterface = subBuilder->createInterfaceType( - requirementKeys.getCount(), - reinterpret_cast<IRInst**>(requirementKeys.getBuffer())); + requirementEntries.getCount(), + reinterpret_cast<IRInst**>(requirementEntries.getBuffer())); addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); subBuilder->setInsertInto(irInterface); @@ -5389,63 +5408,76 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Emit any generics that should wrap the actual type. emitOuterGenerics(subContext, decl, decl); - IRStructType* irStruct = subBuilder->createStructType(); - addNameHint(context, irStruct, decl); - addLinkageDecoration(context, irStruct, decl); + IRInst* resultType = nullptr; + if (as<AssocTypeDecl>(decl)) + { + resultType = subBuilder->createAssociatedType(); + } + else + { + resultType = subBuilder->createStructType(); + } - subBuilder->setInsertInto(irStruct); + addNameHint(context, resultType, decl); + addLinkageDecoration(context, resultType, decl); - // A `struct` that inherits from another `struct` must start - // with a member for the direct base type. - // - for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() ) + if (resultType->op == kIROp_StructType) { - auto superType = inheritanceDecl->base; - if(auto superDeclRefType = as<DeclRefType>(superType)) + IRStructType* irStruct = (IRStructType*)resultType; + subBuilder->setInsertInto(irStruct); + + // A `struct` that inherits from another `struct` must start + // with a member for the direct base type. + // + for( auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>() ) { - if(auto superStructDeclRef = superDeclRefType->declRef.as<StructDecl>()) + auto superType = inheritanceDecl->base; + if(auto superDeclRefType = as<DeclRefType>(superType)) { - auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl)); - auto irSuperType = lowerType(context, superType.type); - subBuilder->createStructField( - irStruct, - superKey, - irSuperType); + if(auto superStructDeclRef = superDeclRefType->declRef.as<StructDecl>()) + { + auto superKey = (IRStructKey*) getSimpleVal(context, ensureDecl(context, inheritanceDecl)); + auto irSuperType = lowerType(context, superType.type); + subBuilder->createStructField( + irStruct, + superKey, + irSuperType); + } } } - } - for (auto fieldDecl : decl->getMembersOfType<VarDeclBase>()) - { - if (fieldDecl->hasModifier<HLSLStaticModifier>()) + for (auto fieldDecl : decl->getMembersOfType<VarDeclBase>()) { - // A `static` field is actually a global variable, - // and we should emit it as such. - ensureDecl(context, fieldDecl); - continue; - } - - // Each ordinary field will need to turn into a struct "key" - // that is used for fetching the field. - IRInst* fieldKeyInst = getSimpleVal(context, - ensureDecl(context, fieldDecl)); - auto fieldKey = as<IRStructKey>(fieldKeyInst); - SLANG_ASSERT(fieldKey); - - // Note: we lower the type of the field in the "sub" - // context, so that any generic parameters that were - // set up for the type can be referenced by the field type. - IRType* fieldType = lowerType( - subContext, - fieldDecl->getType()); + if (fieldDecl->hasModifier<HLSLStaticModifier>()) + { + // A `static` field is actually a global variable, + // and we should emit it as such. + ensureDecl(context, fieldDecl); + continue; + } - // Then, the parent `struct` instruction itself will have - // a "field" instruction. - subBuilder->createStructField( - irStruct, - fieldKey, - fieldType); + // Each ordinary field will need to turn into a struct "key" + // that is used for fetching the field. + IRInst* fieldKeyInst = getSimpleVal(context, + ensureDecl(context, fieldDecl)); + auto fieldKey = as<IRStructKey>(fieldKeyInst); + SLANG_ASSERT(fieldKey); + + // Note: we lower the type of the field in the "sub" + // context, so that any generic parameters that were + // set up for the type can be referenced by the field type. + IRType* fieldType = lowerType( + subContext, + fieldDecl->getType()); + + // Then, the parent `struct` instruction itself will have + // a "field" instruction. + subBuilder->createStructField( + irStruct, + fieldKey, + fieldType); + } } // There may be members not handled by the above logic (e.g., @@ -5455,10 +5487,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Instead we will force emission of all children of aggregate // type declarations later, from the top-level emit logic. - irStruct->moveToEnd(); - addTargetIntrinsicDecorations(irStruct, decl); + resultType->moveToEnd(); + addTargetIntrinsicDecorations(resultType, decl); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irStruct)); + return LoweredValInfo::simple(finishOuterGenerics(subBuilder, resultType)); } LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) @@ -5995,29 +6027,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return as<IRStringLit>(builder->getStringValue(stringLitExpr->value.getUnownedSlice())); } - LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) + void _lowerFuncResultAndParameterTypes( + ParameterLists& parameterLists, + List<IRType*>& paramTypes, + IRType*& irResultType, + IRBuilder* subBuilder, + IRGenContext* subContext, + FunctionDeclBase* decl) { - // We are going to use a nested builder, because we will - // change the parent node that things get nested into. - // - NestedContext nestedContext(this); - auto subBuilder = nestedContext.getBuilder(); - auto subContext = nestedContext.getContext(); - - // The actual `IRFunction` that we emit needs to be nested - // inside of one `IRGeneric` for every outer `GenericDecl` - // in the declaration hierarchy. - - emitOuterGenerics(subContext, decl, decl); - // Collect the parameter lists we will use for our new function. - ParameterLists parameterLists; collectParameterLists(decl, ¶meterLists, kParameterListCollectMode_Default); - // TODO: if there are any generic parameters in the collected list, then - // we need to output an IR function with generic parameters (or a generic - // with a nested function... the exact representation is still TBD). - // In most cases the return type for a declaration can be read off the declaration // itself, but things get a bit more complicated when we have to deal with // accessors for subscript declarations (and eventually for properties). @@ -6036,14 +6056,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } - // need to create an IR function here - - IRFunc* irFunc = subBuilder->createFunc(); - addNameHint(context, irFunc, decl); - addLinkageDecoration(context, irFunc, decl); - - List<IRType*> paramTypes; - for( auto paramInfo : parameterLists.params ) { IRType* irParamType = lowerType(subContext, paramInfo.type); @@ -6054,10 +6066,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Simple case of a by-value input parameter. break; - // If the parameter is declared `out` or `inout`, - // then we will represent it with a pointer type in - // the IR, but we will use a specialized pointer - // type that encodes the parameter direction information. + // If the parameter is declared `out` or `inout`, + // then we will represent it with a pointer type in + // the IR, but we will use a specialized pointer + // type that encodes the parameter direction information. case kParameterDirection_Out: irParamType = subBuilder->getOutType(irParamType); break; @@ -6084,7 +6096,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> paramTypes.add(irParamType); } - auto irResultType = lowerType(subContext, declForReturnType->returnType); + irResultType = lowerType(subContext, declForReturnType->returnType); if (auto setterDecl = as<SetterDecl>(decl)) { @@ -6107,11 +6119,83 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // being accessed, rather than a simple value. irResultType = subBuilder->getPtrType(irResultType); } + } - auto irFuncType = subBuilder->getFuncType( + IRFuncType* _lowerFuncTypeImpl( + ParameterLists& parameterLists, + List<IRType*>& paramTypes, + IRType*& irResultType, + IRBuilder* builder, + IRGenContext* irGenContext, + FunctionDeclBase* decl) + { + _lowerFuncResultAndParameterTypes( + parameterLists, + paramTypes, + irResultType, + builder, + irGenContext, + decl); + + auto irFuncType = builder->getFuncType( paramTypes.getCount(), paramTypes.getBuffer(), irResultType); + + if (parameterLists.params.getCount() && parameterLists.params[0].isThisParam) + builder->addThisPointerDecoration(irFuncType, 0); + return irFuncType; + } + + IRInst* lowerFuncType(FunctionDeclBase* decl) + { + NestedContext nestedContextFuncType(this); + auto funcTypeBuilder = nestedContextFuncType.getBuilder(); + auto funcTypeContext = nestedContextFuncType.getContext(); + + emitOuterGenerics(funcTypeContext, decl, decl); + + ParameterLists parameterLists; + List<IRType*> paramTypes; + IRType* irResultType = nullptr; + auto irFuncType = _lowerFuncTypeImpl( + parameterLists, + paramTypes, + irResultType, + funcTypeBuilder, + funcTypeContext, + decl); + + return finishOuterGenerics(funcTypeBuilder, irFuncType); + } + + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) + { + // We are going to use a nested builder, because we will + // change the parent node that things get nested into. + // + NestedContext nestedContextFunc(this); + auto subBuilder = nestedContextFunc.getBuilder(); + auto subContext = nestedContextFunc.getContext(); + + emitOuterGenerics(subContext, decl, decl); + + // need to create an IR function here + + IRFunc* irFunc = subBuilder->createFunc(); + addNameHint(context, irFunc, decl); + addLinkageDecoration(context, irFunc, decl); + + ParameterLists parameterLists; + List<IRType*> paramTypes; + IRType* irResultType = nullptr; + auto irFuncType = _lowerFuncTypeImpl( + parameterLists, + paramTypes, + irResultType, + subBuilder, + subContext, + decl); irFunc->setFullType(irFuncType); subBuilder->setInsertInto(irFunc); @@ -6251,14 +6335,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (paramInfo.isThisParam) { subContext->thisVal = paramVal; - subBuilder->addThisPointerDecoration(irParam); - } - - // Add a [polymorphic] decoration for generic-typed parameters. - if (as<IRParam>(irParamType) && - as<IRTypeType>(irParamType->getFullType())) - { - subBuilder->addPolymorphicDecoration(irParam); + subBuilder->addThisPointerDecoration(irParam, (int)(paramTypeIndex - 1)); } } @@ -6470,7 +6547,53 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // body appear before the function itself in the list // of global values. irFunc->moveToEnd(); - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irFunc)); + + // If this function is defined inside an interface, add a reference to the IRFunc from + // the interface's type definition. + auto finalVal = finishOuterGenerics(subBuilder, irFunc); + + if (auto genericVal = as<IRGeneric>(finalVal)) + { + auto funcType = lowerFuncType(decl); + genericVal->typeUse.set(funcType); + } + + maybeAssociateToInterfaceType(decl, finalVal); + + return LoweredValInfo::simple(finalVal); + } + + void maybeAssociateToInterfaceType(Decl* decl, IRInst* irFuncVal) + { + auto parent = decl->parentDecl; + InterfaceDecl* interfaceDecl = nullptr; + while (parent) + { + interfaceDecl = as<InterfaceDecl>(parent); + if (interfaceDecl) break; + parent = parent->parentDecl; + } + if (!interfaceDecl) + return; + auto loweredVal = context->findLoweredDecl(interfaceDecl); + if (!loweredVal) + { + return; + } + IRInst* irFuncType = irFuncVal->typeUse.get(); + auto irInterfaceType = cast<IRInterfaceType>(loweredVal->val); + auto key = getInterfaceRequirementKey(decl); + for (UInt i = 0; i < irInterfaceType->getOperandCount(); i++) + { + auto operand = cast<IRInterfaceRequirementEntry>(irInterfaceType->getOperand(i)); + if (operand->getOperand(0) == key) + { + operand->setOperand(1, irFuncType); + return; + } + } + SLANG_UNREACHABLE("associating interface function declaration:" + "requirement not found in the interface type."); } LoweredValInfo visitGenericDecl(GenericDecl * genDecl) @@ -6759,7 +6882,8 @@ LoweredValInfo emitDeclRef( auto irSatisfyingVal = context->irBuilder->emitLookupInterfaceMethodInst( type, irWitnessTable, - irRequirementKey); + irRequirementKey, + lowerType(context, thisTypeSubst->witness->sup)); return LoweredValInfo::simple(irSatisfyingVal); } else |
