summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-31 10:03:39 -0700
committerGitHub <noreply@github.com>2024-07-31 10:03:39 -0700
commit134f8ccc930a8da28808c2e288344c21c67a577e (patch)
tree483c09957f94aa626c2e866ebc7634591d725657 /source
parent6e4b82741893be55f6216c31e19650029c667078 (diff)
Fix IR lowering for generic interface types. (#4761)
* Fix IR lowering for generic interface types. * Fix. * Fix.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-lower-to-ir.cpp112
1 files changed, 78 insertions, 34 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 74aa0a0ee..583dcaacc 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8230,6 +8230,40 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
context->irBuilder->addDecoration(originalKey, op, associatedKey);
}
+ // Given `value` defined as an independent generic of `outerGeneric`, emit IR that specializes it using
+ // the generic params defined in `outerGeneric`.
+ // For example:
+ // ```
+ // interface IFoo<T> { void f(); }
+ // ```
+ // We will lower `IFoo<T>::f` into `%f = IRGeneric(T) { return IRFunc(...) }`
+ // When we lower the interface type `IFoo`, it will become:
+ // ```
+ // %IFoo = IRGeneric(T1) { return IRInterfaceType(???); )
+ // ```
+ // We want the `???` to be `specialize(%f, T1)`.
+ // To do so, we will call `specializeWithOuterGeneric` with `value` = `%f`, and `outerGeneric` = %IFoo.
+ //
+ IRInst* specializeWithOuterGeneric(IRBuilder* irBuilder, IRInst* value, IRGeneric* outerGeneric)
+ {
+ if (!as<IRGeneric>(value))
+ return value;
+ if (!outerGeneric)
+ return value;
+
+ // If `outerGeneric` has a generic parent, we want to recursively specialize value
+ // using the parent generic first.
+ auto parentGeneric = getOuterGeneric(outerGeneric);
+ if (parentGeneric)
+ value = specializeWithOuterGeneric(irBuilder, value, parentGeneric);
+
+ // Now we can specialize `value` using the params defined in `outerGeneric`.
+ List<IRInst*> args;
+ for (auto param : outerGeneric->getParams())
+ args.add(param);
+ return irBuilder->emitSpecializeInst(irBuilder->getGenericKind(), value, args);
+ }
+
LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl)
{
// The members of an interface will turn into the keys that will
@@ -8306,54 +8340,55 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
UInt entryIndex = 0;
- auto addEntry = [&](IRStructKey* requirementKey, Decl* requirementDecl)
+ auto addEntry = [&](IRStructKey* requirementKey, DeclRef<Decl> requirementDeclRef)
{
auto entry = subBuilder->createInterfaceRequirementEntry(
requirementKey,
nullptr);
- if (auto inheritance = as<InheritanceDecl>(requirementDecl))
+ if (auto inheritance = requirementDeclRef.as<InheritanceDecl>())
{
- auto irBaseType = lowerType(context, inheritance->base.type);
+ auto irBaseType = lowerType(subContext, getSup(subContext->astBuilder, inheritance));
auto irWitnessTableType = subBuilder->getWitnessTableType(irBaseType);
entry->setRequirementVal(irWitnessTableType);
}
else
{
- IRInst* requirementVal = ensureDecl(subContext, requirementDecl).val;
- if (requirementVal)
+ auto requirementVal = ensureDecl(subContext, requirementDeclRef.getDecl()).val;
+
+ switch (requirementVal->getOp())
{
- switch (requirementVal->getOp())
- {
- case kIROp_Func:
- case kIROp_Generic:
- {
- // Remove lowered `IRFunc`s since we only care about
- // function types.
- auto reqType = requirementVal->getFullType();
- entry->setRequirementVal(reqType);
- break;
- }
- default:
- entry->setRequirementVal(requirementVal);
- break;
- }
- if (requirementDecl->findModifier<HLSLStaticModifier>())
- {
- getBuilder()->addStaticRequirementDecoration(requirementKey);
- }
+ default:
+ // For the majority of requirements, we only care about its type in an
+ // interface definition, so we store only the type from the lowered IR
+ // in the interface entry.
+ // We need to make sure the type is specialized with the outer generic
+ // parameters in case the interface itself is inside a generic.
+ //
+ requirementVal = specializeWithOuterGeneric(context->irBuilder, requirementVal->getFullType(), outerGeneric);
+ entry->setRequirementVal(requirementVal);
+ break;
+
+ case kIROp_AssociatedType:
+ // For associated types, we will store it directly inside the interface type.
+ entry->setRequirementVal(requirementVal);
+ break;
+ }
+ if (requirementDeclRef.getDecl()->findModifier<HLSLStaticModifier>())
+ {
+ getBuilder()->addStaticRequirementDecoration(requirementKey);
}
}
irInterface->setOperand(entryIndex, entry);
entryIndex++;
// Add addtional requirements for type constraints placed
// on an associated types.
- if (auto associatedTypeDecl = as<AssocTypeDecl>(requirementDecl))
+ if (auto associatedTypeDeclRef = requirementDeclRef.as<AssocTypeDecl>())
{
- for (auto constraintDecl : associatedTypeDecl->getMembersOfType<TypeConstraintDecl>())
+ for (auto constraintDeclRef : getMembersOfType<TypeConstraintDecl>(subContext->astBuilder, associatedTypeDeclRef))
{
- auto constraintKey = getInterfaceRequirementKey(constraintDecl);
+ auto constraintKey = getInterfaceRequirementKey(constraintDeclRef.getDecl());
auto constraintInterfaceType =
- lowerType(context, constraintDecl->getSup().type);
+ lowerType(context, getSup(subContext->astBuilder, constraintDeclRef));
auto witnessTableType =
getBuilder()->getWitnessTableType(constraintInterfaceType);
@@ -8362,16 +8397,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irInterface->setOperand(entryIndex, constraintEntry);
entryIndex++;
- context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry));
+ context->setValue(constraintDeclRef.getDecl(), LoweredValInfo::simple(constraintEntry));
}
}
else
{
CallableDecl* callableDecl = nullptr;
- if (auto genDecl = as<GenericDecl>(requirementDecl))
+ if (auto genDecl = as<GenericDecl>(requirementDeclRef.getDecl()))
callableDecl = as<CallableDecl>(genDecl->inner);
else
- callableDecl = as<CallableDecl>(requirementDecl);
+ callableDecl = as<CallableDecl>(requirementDeclRef.getDecl());
if (callableDecl)
{
// Differentiable functions has additional requirements for the derivatives.
@@ -8384,7 +8419,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Add lowered requirement entry to current decl mapping to prevent
// the function requirements from being lowered again when we get to
// `ensureAllDeclsRec`.
- context->setValue(requirementDecl, LoweredValInfo::simple(entry));
+ context->setValue(requirementDeclRef.getDecl(), LoweredValInfo::simple(entry));
}
};
for (auto requirementDecl : decl->members)
@@ -8400,7 +8435,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
auto accessorKey = getInterfaceRequirementKey(accessorDecl);
if (accessorKey)
- addEntry(accessorKey, accessorDecl);
+ {
+ auto accessorDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, accessorDecl);
+ addEntry(accessorKey, accessorDeclRef);
+ }
}
}
}
@@ -8408,7 +8446,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
else
{
- addEntry(requirementKey, requirementDecl);
+ if (auto genericDecl = as<GenericDecl>(requirementDecl))
+ {
+ // We need to form a declref into the inner decls in case of a generic requirement.
+ requirementDecl = getInner(genericDecl);
+ }
+ auto requirementDeclRef = createDefaultSpecializedDeclRef(subContext, nullptr, requirementDecl);
+ addEntry(requirementKey, requirementDeclRef);
}
}