diff options
| author | Yong He <yonghe@outlook.com> | 2024-07-31 10:03:39 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-07-31 10:03:39 -0700 |
| commit | 134f8ccc930a8da28808c2e288344c21c67a577e (patch) | |
| tree | 483c09957f94aa626c2e866ebc7634591d725657 /source | |
| parent | 6e4b82741893be55f6216c31e19650029c667078 (diff) | |
Fix IR lowering for generic interface types. (#4761)
* Fix IR lowering for generic interface types.
* Fix.
* Fix.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 112 |
1 files changed, 78 insertions, 34 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 74aa0a0ee..583dcaacc 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8230,6 +8230,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> context->irBuilder->addDecoration(originalKey, op, associatedKey); } + // Given `value` defined as an independent generic of `outerGeneric`, emit IR that specializes it using + // the generic params defined in `outerGeneric`. + // For example: + // ``` + // interface IFoo<T> { void f(); } + // ``` + // We will lower `IFoo<T>::f` into `%f = IRGeneric(T) { return IRFunc(...) }` + // When we lower the interface type `IFoo`, it will become: + // ``` + // %IFoo = IRGeneric(T1) { return IRInterfaceType(???); ) + // ``` + // We want the `???` to be `specialize(%f, T1)`. + // To do so, we will call `specializeWithOuterGeneric` with `value` = `%f`, and `outerGeneric` = %IFoo. + // + IRInst* specializeWithOuterGeneric(IRBuilder* irBuilder, IRInst* value, IRGeneric* outerGeneric) + { + if (!as<IRGeneric>(value)) + return value; + if (!outerGeneric) + return value; + + // If `outerGeneric` has a generic parent, we want to recursively specialize value + // using the parent generic first. + auto parentGeneric = getOuterGeneric(outerGeneric); + if (parentGeneric) + value = specializeWithOuterGeneric(irBuilder, value, parentGeneric); + + // Now we can specialize `value` using the params defined in `outerGeneric`. + List<IRInst*> args; + for (auto param : outerGeneric->getParams()) + args.add(param); + return irBuilder->emitSpecializeInst(irBuilder->getGenericKind(), value, args); + } + LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) { // The members of an interface will turn into the keys that will @@ -8306,54 +8340,55 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } UInt entryIndex = 0; - auto addEntry = [&](IRStructKey* requirementKey, Decl* requirementDecl) + auto addEntry = [&](IRStructKey* requirementKey, DeclRef<Decl> requirementDeclRef) { auto entry = subBuilder->createInterfaceRequirementEntry( requirementKey, nullptr); - if (auto inheritance = as<InheritanceDecl>(requirementDecl)) + if (auto inheritance = requirementDeclRef.as<InheritanceDecl>()) { - auto irBaseType = lowerType(context, inheritance->base.type); + auto irBaseType = lowerType(subContext, getSup(subContext->astBuilder, inheritance)); auto irWitnessTableType = subBuilder->getWitnessTableType(irBaseType); entry->setRequirementVal(irWitnessTableType); } else { - IRInst* requirementVal = ensureDecl(subContext, requirementDecl).val; - if (requirementVal) + auto requirementVal = ensureDecl(subContext, requirementDeclRef.getDecl()).val; + + switch (requirementVal->getOp()) { - switch (requirementVal->getOp()) - { - case kIROp_Func: - case kIROp_Generic: - { - // Remove lowered `IRFunc`s since we only care about - // function types. - auto reqType = requirementVal->getFullType(); - entry->setRequirementVal(reqType); - break; - } - default: - entry->setRequirementVal(requirementVal); - break; - } - if (requirementDecl->findModifier<HLSLStaticModifier>()) - { - getBuilder()->addStaticRequirementDecoration(requirementKey); - } + default: + // For the majority of requirements, we only care about its type in an + // interface definition, so we store only the type from the lowered IR + // in the interface entry. + // We need to make sure the type is specialized with the outer generic + // parameters in case the interface itself is inside a generic. + // + requirementVal = specializeWithOuterGeneric(context->irBuilder, requirementVal->getFullType(), outerGeneric); + entry->setRequirementVal(requirementVal); + break; + + case kIROp_AssociatedType: + // For associated types, we will store it directly inside the interface type. + entry->setRequirementVal(requirementVal); + break; + } + if (requirementDeclRef.getDecl()->findModifier<HLSLStaticModifier>()) + { + getBuilder()->addStaticRequirementDecoration(requirementKey); } } irInterface->setOperand(entryIndex, entry); entryIndex++; // Add addtional requirements for type constraints placed // on an associated types. - if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl)) + if (auto associatedTypeDeclRef = requirementDeclRef.as<AssocTypeDecl>()) { - for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>()) + for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(subContext->astBuilder, associatedTypeDeclRef)) { - auto constraintKey = getInterfaceRequirementKey(constraintDecl); + auto constraintKey = getInterfaceRequirementKey(constraintDeclRef.getDecl()); auto constraintInterfaceType = - lowerType(context, constraintDecl->getSup().type); + lowerType(context, getSup(subContext->astBuilder, constraintDeclRef)); auto witnessTableType = getBuilder()->getWitnessTableType(constraintInterfaceType); @@ -8362,16 +8397,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irInterface->setOperand(entryIndex, constraintEntry); entryIndex++; - context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry)); + context->setValue(constraintDeclRef.getDecl(), LoweredValInfo::simple(constraintEntry)); } } else { CallableDecl* callableDecl = nullptr; - if (auto genDecl = as<GenericDecl>(requirementDecl)) + if (auto genDecl = as<GenericDecl>(requirementDeclRef.getDecl())) callableDecl = as<CallableDecl>(genDecl->inner); else - callableDecl = as<CallableDecl>(requirementDecl); + callableDecl = as<CallableDecl>(requirementDeclRef.getDecl()); if (callableDecl) { // Differentiable functions has additional requirements for the derivatives. @@ -8384,7 +8419,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Add lowered requirement entry to current decl mapping to prevent // the function requirements from being lowered again when we get to // `ensureAllDeclsRec`. - context->setValue(requirementDecl, LoweredValInfo::simple(entry)); + context->setValue(requirementDeclRef.getDecl(), LoweredValInfo::simple(entry)); } }; for (auto requirementDecl : decl->members) @@ -8400,7 +8435,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { auto accessorKey = getInterfaceRequirementKey(accessorDecl); if (accessorKey) - addEntry(accessorKey, accessorDecl); + { + auto accessorDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, accessorDecl); + addEntry(accessorKey, accessorDeclRef); + } } } } @@ -8408,7 +8446,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } else { - addEntry(requirementKey, requirementDecl); + if (auto genericDecl = as<GenericDecl>(requirementDecl)) + { + // We need to form a declref into the inner decls in case of a generic requirement. + requirementDecl = getInner(genericDecl); + } + auto requirementDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, requirementDecl); + addEntry(requirementKey, requirementDeclRef); } } |
