summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-25 07:50:14 -0700
committerGitHub <noreply@github.com>2023-10-25 22:50:14 +0800
commit5dc3c2f57963de93ad03724a01ea48b8585dc15a (patch)
tree072748b952eb03da7950110ed3a8f87da9b5e72f /source
parentf8bf75cf1ae0aeee155996a917c2925bc500f3e2 (diff)
Add `IArray`. (#3281)
* Initial support for generic interfaces. * Cleanup. * Add generic syntax for interfaces. * Add `IArray`. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang33
-rw-r--r--source/slang/slang-check-constraint.cpp34
-rw-r--r--source/slang/slang-check-decl.cpp92
-rw-r--r--source/slang/slang-check-expr.cpp63
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-spirv.cpp2
-rw-r--r--source/slang/slang-ir-inst-pass-base.h2
-rw-r--r--source/slang/slang-lower-to-ir.cpp25
-rw-r--r--source/slang/slang-syntax.h5
10 files changed, 209 insertions, 55 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 3e51d7028..0f600d7c6 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -317,10 +317,27 @@ interface __FlagsEnumType : __EnumType
{
};
+interface IArray<T>
+{
+ int getCount();
+ __subscript(int index) -> T
+ {
+ get;
+ }
+}
+
__generic<T, let N:int>
__magic_type(ArrayExpressionType)
-struct Array
+struct Array : IArray<T>
{
+ [ForceInline]
+ int getCount() { return N; }
+
+ __subscript(int index) -> T
+ {
+ __intrinsic_op($(kIROp_GetElement))
+ get;
+ }
}
// The "comma operator" is effectively just a generic function that returns its second
@@ -885,7 +902,7 @@ extension int16_t
/// An `N` component vector with elements of type `T`.
__generic<T = float, let N : int = 4>
__magic_type(VectorExpressionType)
-struct vector
+struct vector : IArray<T>
{
/// The element type of the vector
typedef T Element;
@@ -900,6 +917,11 @@ struct vector
// TODO: we should revise semantic checking so this kind of "identity" conversion is not required
__intrinsic_op(0)
__init(vector<T,N> value);
+
+ [ForceInline]
+ int getCount() { return N; }
+
+ __subscript(int index) -> T { __intrinsic_op($(kIROp_GetElement)) get; }
}
const int kRowMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_ROW_MAJOR);
@@ -908,10 +930,15 @@ const int kColumnMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_COLUMN_MAJOR);
/// A matrix with `R` rows and `C` columns, with elements of type `T`.
__generic<T = float, let R : int = 4, let C : int = 4, let L : int = $(SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)>
__magic_type(MatrixExpressionType)
-struct matrix
+struct matrix : IArray<vector<T,C>>
{
__intrinsic_op($(kIROp_MakeMatrixFromScalar))
__init(T val);
+
+ [ForceInline]
+ int getCount() { return R; }
+
+ __subscript(int index) -> vector<T,C> { __intrinsic_op($(kIROp_GetElement)) get; }
}
${{{{
diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp
index 996b88f48..8fd4061db 100644
--- a/source/slang/slang-check-constraint.cpp
+++ b/source/slang/slang-check-constraint.cpp
@@ -716,8 +716,38 @@ namespace Slang
if (typeParamDecl->parentDecl == constraints.genericDecl)
return TryUnifyTypeParam(constraints, typeParamDecl, fst);
- // can't be unified if they refer to different declarations.
- if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false;
+ // If they refer to different declarations, we need to check if one type's super type
+ // matches the other type, if so we can unify them.
+ if (fstDeclRef.getDecl() != sndDeclRef.getDecl())
+ {
+ {
+ auto fstTypeInheritanceInfo = getShared()->getInheritanceInfo(fstDeclRefType);
+ for (auto supType : fstTypeInheritanceInfo.facets)
+ {
+ if (supType->origin.declRef.getDecl() == sndDeclRef.getDecl())
+ {
+ fstDeclRef = supType->origin.declRef;
+ goto endMatch;
+ }
+ }
+ }
+ // try the other direction
+ {
+ auto sndTypeInheritanceInfo = getShared()->getInheritanceInfo(sndDeclRefType);
+ for (auto supType : sndTypeInheritanceInfo.facets)
+ {
+ if (supType->origin.declRef.getDecl() == fstDeclRef.getDecl())
+ {
+ sndDeclRef = supType->origin.declRef;
+ goto endMatch;
+ }
+ }
+ }
+ endMatch:;
+ // If they still refer to different decls, then we can't unify them.
+ if (fstDeclRef.getDecl() != sndDeclRef.getDecl())
+ return false;
+ }
// next we need to unify the substitutions applied
// to each declaration reference.
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 8df5ae618..f75f84e21 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2058,6 +2058,90 @@ namespace Slang
return true;
}
+ bool SemanticsVisitor::doesSubscriptMatchRequirement(
+ DeclRef<SubscriptDecl> satisfyingMemberDeclRef,
+ DeclRef<SubscriptDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ // The result type and parameters of the satisfying member must match the type of the required member.
+ //
+ auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray();
+ auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray();
+ auto paramCount = requiredParams.getCount();
+ if (satisfyingParams.getCount() != paramCount)
+ return false;
+
+ for (Index paramIndex = 0; paramIndex < paramCount; ++paramIndex)
+ {
+ auto requiredParam = requiredParams[paramIndex];
+ auto satisfyingParam = satisfyingParams[paramIndex];
+
+ auto requiredParamType = getType(m_astBuilder, requiredParam);
+ auto satisfyingParamType = getType(m_astBuilder, satisfyingParam);
+
+ if (!requiredParamType->equals(satisfyingParamType))
+ return false;
+ }
+
+ auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef);
+ auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef);
+ if (!requiredResultType->equals(satisfyingResultType))
+ return false;
+
+ // Each accessor in the requirement must be accounted for by an accessor
+ // in the satisfying member.
+ //
+ // Note: it is fine for the satisfying member to provide *more* accessors
+ // than the original declaration.
+ //
+ Dictionary<DeclRef<AccessorDecl>, DeclRef<AccessorDecl>> mapRequiredToSatisfyingAccessorDeclRef;
+ for (auto requiredAccessorDeclRef : getMembersOfType<AccessorDecl>(m_astBuilder, requiredMemberDeclRef))
+ {
+ // We need to search for an accessor that can satisfy the requirement.
+ //
+ // For now we will do the simplest (and slowest) thing of a linear search,
+ // which is mostly fine because the number of accessors is bounded.
+ //
+ bool found = false;
+ for (auto satisfyingAccessorDeclRef : getMembersOfType<AccessorDecl>(m_astBuilder, satisfyingMemberDeclRef))
+ {
+ if (doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef))
+ {
+ // When we find a match on an accessor, we record it so that
+ // we can set up the witness values later, but we do *not*
+ // record it into the actual witness table yet, in case
+ // a later accessor comes along that doesn't find a match.
+ //
+ mapRequiredToSatisfyingAccessorDeclRef.add(requiredAccessorDeclRef, satisfyingAccessorDeclRef);
+ found = true;
+ break;
+ }
+ }
+ if (!found)
+ return false;
+ }
+
+ // Once things are done, we will install the satisfying values
+ // into the witness table for the requirements.
+ //
+ for (const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef)
+ {
+ witnessTable->add(
+ key.getDecl(),
+ RequirementWitness(value));
+ }
+ //
+ // Note: the subscript declaration itself isn't something that
+ // has a useful value/representation in downstream passes, so
+ // we are mostly just installing it into the witness table
+ // as a way to mark this requirement as being satisfied.
+ //
+ witnessTable->add(
+ requiredMemberDeclRef.getDecl(),
+ RequirementWitness(satisfyingMemberDeclRef));
+ return true;
+ }
+
bool SemanticsVisitor::doesVarMatchRequirement(
DeclRef<VarDeclBase> satisfyingMemberDeclRef,
DeclRef<VarDeclBase> requiredMemberDeclRef,
@@ -2514,6 +2598,14 @@ namespace Slang
return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable);
}
}
+ else if (auto subscriptDeclRef = memberDeclRef.as<SubscriptDecl>())
+ {
+ if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as<SubscriptDecl>())
+ {
+ ensureDecl(subscriptDeclRef, DeclCheckState::CanUseFuncSignature);
+ return doesSubscriptMatchRequirement(subscriptDeclRef, requiredSubscriptDeclRef, witnessTable);
+ }
+ }
// Default: just assume that thing aren't being satisfied.
return false;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 22bc2cae8..b76fe0003 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1752,8 +1752,7 @@ namespace Slang
Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr)
{
- auto baseExpr = subscriptExpr->baseExpression;
- baseExpr = CheckExpr(baseExpr);
+ auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression);
for (auto& arg : subscriptExpr->indexExprs)
{
@@ -1822,47 +1821,34 @@ namespace Slang
// Default behavior is to look at all available `__subscript`
// declarations on the type and try to call one of them.
- {
- Name* name = getName("operator[]");
- LookupResult lookupResult = lookUpMember(
- m_astBuilder,
- this,
- name,
- baseType,
- LookupMask::Default,
- LookupOptions::NoDeref);
- if (!lookupResult.isValid())
- {
- goto fail;
- }
+ auto operatorName = getName("operator[]");
- // Now that we know there is at least one subscript member,
- // we will construct a reference to it and try to call it.
- //
- // Note: the expression may be an `OverloadedExpr`, in which
- // case the attempt to call it will trigger overload
- // resolution.
- Expr* subscriptFuncExpr = createLookupResultExpr(
- name,
- lookupResult,
- subscriptExpr->baseExpression,
- subscriptExpr->loc,
- subscriptExpr);
-
- InvokeExpr* subscriptCallExpr = m_astBuilder->create<InvokeExpr>();
- subscriptCallExpr->loc = subscriptExpr->loc;
- subscriptCallExpr->functionExpr = subscriptFuncExpr;
- subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs);
- subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs);
-
- return CheckInvokeExprWithCheckedOperands(subscriptCallExpr);
- }
-
- fail:
+ LookupResult lookupResult = lookUpMember(
+ m_astBuilder,
+ this,
+ operatorName,
+ baseType,
+ LookupMask::Default,
+ LookupOptions::NoDeref);
+ if (!lookupResult.isValid())
{
getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType);
return CreateErrorExpr(subscriptExpr);
}
+ auto subscriptFuncExpr = createLookupResultExpr(
+ operatorName,
+ lookupResult,
+ subscriptExpr->baseExpression,
+ subscriptExpr->loc,
+ subscriptExpr);
+
+ InvokeExpr* subscriptCallExpr = m_astBuilder->create<InvokeExpr>();
+ subscriptCallExpr->loc = subscriptExpr->loc;
+ subscriptCallExpr->functionExpr = subscriptFuncExpr;
+ subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs);
+ subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs);
+
+ return CheckInvokeExprWithCheckedOperands(subscriptCallExpr);
}
Expr* SemanticsExprVisitor::visitParenExpr(ParenExpr* expr)
@@ -2306,7 +2292,6 @@ namespace Slang
expr->type = GetTypeForDeclRef(expr->declRef, expr->loc);
return expr;
}
-
expr->type = QualType(m_astBuilder->getErrorType());
auto lookupResult = lookUp(
m_astBuilder,
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 544bbe170..072a4a5fa 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1445,6 +1445,12 @@ namespace Slang
DeclRef<PropertyDecl> satisfyingMemberDeclRef,
DeclRef<PropertyDecl> requiredMemberDeclRef,
RefPtr<WitnessTable> witnessTable);
+
+ bool doesSubscriptMatchRequirement(
+ DeclRef<SubscriptDecl> satisfyingMemberDeclRef,
+ DeclRef<SubscriptDecl> requiredMemberDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
bool doesVarMatchRequirement(
DeclRef<VarDeclBase> satisfyingMemberDeclRef,
DeclRef<VarDeclBase> requiredMemberDeclRef,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 5c169a4ba..656e28701 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -274,7 +274,7 @@ DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must eva
DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.")
DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).")
DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).")
-DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
+DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "invalid use of 'as' operator: expression evaluates to '$0', which is not in the same type hierarchy as target type '$1'.")
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index aa21127e9..3274c3223 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1397,7 +1397,7 @@ struct SPIRVEmitContext
case kIROp_RateQualifiedType:
{
- auto result = emitGlobalInst(as<IRRateQualifiedType>(inst)->getValueType());
+ auto result = ensureInst(as<IRRateQualifiedType>(inst)->getValueType());
registerInst(inst, result);
return result;
}
diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h
index 1ebc4e350..9051e74df 100644
--- a/source/slang/slang-ir-inst-pass-base.h
+++ b/source/slang/slang-ir-inst-pass-base.h
@@ -17,6 +17,7 @@ namespace Slang
InstHashSet workListSet;
void addToWorkList(IRInst* inst)
{
+ SLANG_ASSERT(inst);
if (workListSet.contains(inst))
return;
@@ -139,6 +140,7 @@ namespace Slang
default:
break;
}
+ SLANG_ASSERT(child);
if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions()))
addToWorkList(child);
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index d75b66a9b..b4dac190a 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1432,9 +1432,10 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl)
{
if (const auto funcDecl = as<CallableDecl>(requirementDecl))
{
- }
- else if (const auto propertyDecl = as<PropertyDecl>(requirementDecl))
- {
+ // Subscript decl itself won't have a witness table entry.
+ // But its accessors will.
+ if (const auto subscriptDecl = as<SubscriptDecl>(requirementDecl))
+ return false;
}
else if (const auto assocTypeDecl = as<AssocTypeDecl>(requirementDecl))
{
@@ -1451,6 +1452,9 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl)
}
else
{
+ // We will return false for PropertyDecl because the property decl itself
+ // won't have a witness table entry. Instead there will be witness entries
+ // for its accessors.
return false;
}
return true;
@@ -3020,6 +3024,7 @@ void _lowerFuncDeclBaseTypeInfo(
auto& parameterLists = outInfo.parameterLists;
collectParameterLists(
context,
+
declRef,
&parameterLists, kParameterListCollectMode_Default, kParameterDirection_In);
@@ -3501,6 +3506,14 @@ struct ExprLoweringContext
// appropriately.
auto funcDeclRef = resolvedInfo.funcDeclRef;
auto baseExpr = resolvedInfo.baseExpr;
+ if (baseExpr)
+ {
+ // The base expression might be an "upcast" to a base interface, in
+ // which case we don't want to emit the result of the cast, but instead
+ // the source.
+ //
+ baseExpr = this->maybeIgnoreCastToInterface(baseExpr);
+ }
// If the thing being invoked is a subscript operation,
// then we need to handle multiple extra details
@@ -3550,12 +3563,6 @@ struct ExprLoweringContext
// a member function:
if (baseExpr)
{
- // The base expression might be an "upcast" to a base interface, in
- // which case we don't want to emit the result of the cast, but instead
- // the source.
- //
- baseExpr = this->maybeIgnoreCastToInterface(baseExpr);
-
auto thisType = getThisParamTypeForCallable(context, funcDeclRef);
auto irThisType = lowerType(context, thisType);
addCallArgsForParam(
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index 4addb1d53..b2b5deb22 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -133,6 +133,11 @@ namespace Slang
return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr());
}
+ inline Type* getType(ASTBuilder* astBuilder, DeclRef<SubscriptDecl> declRef)
+ {
+ return declRef.substitute(astBuilder, declRef.getDecl()->returnType.Ptr());
+ }
+
inline Type* getType(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> declRef)
{
return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr());