diff options
| author | Yong He <yonghe@outlook.com> | 2023-10-25 07:50:14 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-25 22:50:14 +0800 |
| commit | 5dc3c2f57963de93ad03724a01ea48b8585dc15a (patch) | |
| tree | 072748b952eb03da7950110ed3a8f87da9b5e72f /source | |
| parent | f8bf75cf1ae0aeee155996a917c2925bc500f3e2 (diff) | |
Add `IArray`. (#3281)
* Initial support for generic interfaces.
* Cleanup.
* Add generic syntax for interfaces.
* Add `IArray`.
* Fix.
* Fix.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 33 | ||||
| -rw-r--r-- | source/slang/slang-check-constraint.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 92 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 63 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-pass-base.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 5 |
10 files changed, 209 insertions, 55 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 3e51d7028..0f600d7c6 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -317,10 +317,27 @@ interface __FlagsEnumType : __EnumType { }; +interface IArray<T> +{ + int getCount(); + __subscript(int index) -> T + { + get; + } +} + __generic<T, let N:int> __magic_type(ArrayExpressionType) -struct Array +struct Array : IArray<T> { + [ForceInline] + int getCount() { return N; } + + __subscript(int index) -> T + { + __intrinsic_op($(kIROp_GetElement)) + get; + } } // The "comma operator" is effectively just a generic function that returns its second @@ -885,7 +902,7 @@ extension int16_t /// An `N` component vector with elements of type `T`. __generic<T = float, let N : int = 4> __magic_type(VectorExpressionType) -struct vector +struct vector : IArray<T> { /// The element type of the vector typedef T Element; @@ -900,6 +917,11 @@ struct vector // TODO: we should revise semantic checking so this kind of "identity" conversion is not required __intrinsic_op(0) __init(vector<T,N> value); + + [ForceInline] + int getCount() { return N; } + + __subscript(int index) -> T { __intrinsic_op($(kIROp_GetElement)) get; } } const int kRowMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_ROW_MAJOR); @@ -908,10 +930,15 @@ const int kColumnMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_COLUMN_MAJOR); /// A matrix with `R` rows and `C` columns, with elements of type `T`. __generic<T = float, let R : int = 4, let C : int = 4, let L : int = $(SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)> __magic_type(MatrixExpressionType) -struct matrix +struct matrix : IArray<vector<T,C>> { __intrinsic_op($(kIROp_MakeMatrixFromScalar)) __init(T val); + + [ForceInline] + int getCount() { return R; } + + __subscript(int index) -> vector<T,C> { __intrinsic_op($(kIROp_GetElement)) get; } } ${{{{ diff --git a/source/slang/slang-check-constraint.cpp b/source/slang/slang-check-constraint.cpp index 996b88f48..8fd4061db 100644 --- a/source/slang/slang-check-constraint.cpp +++ b/source/slang/slang-check-constraint.cpp @@ -716,8 +716,38 @@ namespace Slang if (typeParamDecl->parentDecl == constraints.genericDecl) return TryUnifyTypeParam(constraints, typeParamDecl, fst); - // can't be unified if they refer to different declarations. - if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) return false; + // If they refer to different declarations, we need to check if one type's super type + // matches the other type, if so we can unify them. + if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) + { + { + auto fstTypeInheritanceInfo = getShared()->getInheritanceInfo(fstDeclRefType); + for (auto supType : fstTypeInheritanceInfo.facets) + { + if (supType->origin.declRef.getDecl() == sndDeclRef.getDecl()) + { + fstDeclRef = supType->origin.declRef; + goto endMatch; + } + } + } + // try the other direction + { + auto sndTypeInheritanceInfo = getShared()->getInheritanceInfo(sndDeclRefType); + for (auto supType : sndTypeInheritanceInfo.facets) + { + if (supType->origin.declRef.getDecl() == fstDeclRef.getDecl()) + { + sndDeclRef = supType->origin.declRef; + goto endMatch; + } + } + } + endMatch:; + // If they still refer to different decls, then we can't unify them. + if (fstDeclRef.getDecl() != sndDeclRef.getDecl()) + return false; + } // next we need to unify the substitutions applied // to each declaration reference. diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 8df5ae618..f75f84e21 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2058,6 +2058,90 @@ namespace Slang return true; } + bool SemanticsVisitor::doesSubscriptMatchRequirement( + DeclRef<SubscriptDecl> satisfyingMemberDeclRef, + DeclRef<SubscriptDecl> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable) + { + // The result type and parameters of the satisfying member must match the type of the required member. + // + auto requiredParams = getParameters(m_astBuilder, requiredMemberDeclRef).toArray(); + auto satisfyingParams = getParameters(m_astBuilder, satisfyingMemberDeclRef).toArray(); + auto paramCount = requiredParams.getCount(); + if (satisfyingParams.getCount() != paramCount) + return false; + + for (Index paramIndex = 0; paramIndex < paramCount; ++paramIndex) + { + auto requiredParam = requiredParams[paramIndex]; + auto satisfyingParam = satisfyingParams[paramIndex]; + + auto requiredParamType = getType(m_astBuilder, requiredParam); + auto satisfyingParamType = getType(m_astBuilder, satisfyingParam); + + if (!requiredParamType->equals(satisfyingParamType)) + return false; + } + + auto requiredResultType = getResultType(m_astBuilder, requiredMemberDeclRef); + auto satisfyingResultType = getResultType(m_astBuilder, satisfyingMemberDeclRef); + if (!requiredResultType->equals(satisfyingResultType)) + return false; + + // Each accessor in the requirement must be accounted for by an accessor + // in the satisfying member. + // + // Note: it is fine for the satisfying member to provide *more* accessors + // than the original declaration. + // + Dictionary<DeclRef<AccessorDecl>, DeclRef<AccessorDecl>> mapRequiredToSatisfyingAccessorDeclRef; + for (auto requiredAccessorDeclRef : getMembersOfType<AccessorDecl>(m_astBuilder, requiredMemberDeclRef)) + { + // We need to search for an accessor that can satisfy the requirement. + // + // For now we will do the simplest (and slowest) thing of a linear search, + // which is mostly fine because the number of accessors is bounded. + // + bool found = false; + for (auto satisfyingAccessorDeclRef : getMembersOfType<AccessorDecl>(m_astBuilder, satisfyingMemberDeclRef)) + { + if (doesAccessorMatchRequirement(satisfyingAccessorDeclRef, requiredAccessorDeclRef)) + { + // When we find a match on an accessor, we record it so that + // we can set up the witness values later, but we do *not* + // record it into the actual witness table yet, in case + // a later accessor comes along that doesn't find a match. + // + mapRequiredToSatisfyingAccessorDeclRef.add(requiredAccessorDeclRef, satisfyingAccessorDeclRef); + found = true; + break; + } + } + if (!found) + return false; + } + + // Once things are done, we will install the satisfying values + // into the witness table for the requirements. + // + for (const auto& [key, value] : mapRequiredToSatisfyingAccessorDeclRef) + { + witnessTable->add( + key.getDecl(), + RequirementWitness(value)); + } + // + // Note: the subscript declaration itself isn't something that + // has a useful value/representation in downstream passes, so + // we are mostly just installing it into the witness table + // as a way to mark this requirement as being satisfied. + // + witnessTable->add( + requiredMemberDeclRef.getDecl(), + RequirementWitness(satisfyingMemberDeclRef)); + return true; + } + bool SemanticsVisitor::doesVarMatchRequirement( DeclRef<VarDeclBase> satisfyingMemberDeclRef, DeclRef<VarDeclBase> requiredMemberDeclRef, @@ -2514,6 +2598,14 @@ namespace Slang return doesVarMatchRequirement(varDeclRef, requiredVarDeclRef, witnessTable); } } + else if (auto subscriptDeclRef = memberDeclRef.as<SubscriptDecl>()) + { + if (auto requiredSubscriptDeclRef = requiredMemberDeclRef.as<SubscriptDecl>()) + { + ensureDecl(subscriptDeclRef, DeclCheckState::CanUseFuncSignature); + return doesSubscriptMatchRequirement(subscriptDeclRef, requiredSubscriptDeclRef, witnessTable); + } + } // Default: just assume that thing aren't being satisfied. return false; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 22bc2cae8..b76fe0003 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1752,8 +1752,7 @@ namespace Slang Expr* SemanticsExprVisitor::visitIndexExpr(IndexExpr* subscriptExpr) { - auto baseExpr = subscriptExpr->baseExpression; - baseExpr = CheckExpr(baseExpr); + auto baseExpr = checkBaseForMemberExpr(subscriptExpr->baseExpression); for (auto& arg : subscriptExpr->indexExprs) { @@ -1822,47 +1821,34 @@ namespace Slang // Default behavior is to look at all available `__subscript` // declarations on the type and try to call one of them. - { - Name* name = getName("operator[]"); - LookupResult lookupResult = lookUpMember( - m_astBuilder, - this, - name, - baseType, - LookupMask::Default, - LookupOptions::NoDeref); - if (!lookupResult.isValid()) - { - goto fail; - } + auto operatorName = getName("operator[]"); - // Now that we know there is at least one subscript member, - // we will construct a reference to it and try to call it. - // - // Note: the expression may be an `OverloadedExpr`, in which - // case the attempt to call it will trigger overload - // resolution. - Expr* subscriptFuncExpr = createLookupResultExpr( - name, - lookupResult, - subscriptExpr->baseExpression, - subscriptExpr->loc, - subscriptExpr); - - InvokeExpr* subscriptCallExpr = m_astBuilder->create<InvokeExpr>(); - subscriptCallExpr->loc = subscriptExpr->loc; - subscriptCallExpr->functionExpr = subscriptFuncExpr; - subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs); - subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs); - - return CheckInvokeExprWithCheckedOperands(subscriptCallExpr); - } - - fail: + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + operatorName, + baseType, + LookupMask::Default, + LookupOptions::NoDeref); + if (!lookupResult.isValid()) { getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); return CreateErrorExpr(subscriptExpr); } + auto subscriptFuncExpr = createLookupResultExpr( + operatorName, + lookupResult, + subscriptExpr->baseExpression, + subscriptExpr->loc, + subscriptExpr); + + InvokeExpr* subscriptCallExpr = m_astBuilder->create<InvokeExpr>(); + subscriptCallExpr->loc = subscriptExpr->loc; + subscriptCallExpr->functionExpr = subscriptFuncExpr; + subscriptCallExpr->arguments.addRange(subscriptExpr->indexExprs); + subscriptCallExpr->argumentDelimeterLocs.addRange(subscriptExpr->argumentDelimeterLocs); + + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr); } Expr* SemanticsExprVisitor::visitParenExpr(ParenExpr* expr) @@ -2306,7 +2292,6 @@ namespace Slang expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); return expr; } - expr->type = QualType(m_astBuilder->getErrorType()); auto lookupResult = lookUp( m_astBuilder, diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 544bbe170..072a4a5fa 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1445,6 +1445,12 @@ namespace Slang DeclRef<PropertyDecl> satisfyingMemberDeclRef, DeclRef<PropertyDecl> requiredMemberDeclRef, RefPtr<WitnessTable> witnessTable); + + bool doesSubscriptMatchRequirement( + DeclRef<SubscriptDecl> satisfyingMemberDeclRef, + DeclRef<SubscriptDecl> requiredMemberDeclRef, + RefPtr<WitnessTable> witnessTable); + bool doesVarMatchRequirement( DeclRef<VarDeclBase> satisfyingMemberDeclRef, DeclRef<VarDeclBase> requiredMemberDeclRef, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5c169a4ba..656e28701 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -274,7 +274,7 @@ DIAGNOSTIC(30010, Error, whilePredicateTypeError2, "'while': expression must eva DIAGNOSTIC(30011, Error, assignNonLValue, "left of '=' is not an l-value.") DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for operator $0 ($1).") DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") -DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") +DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") DIAGNOSTIC(30018, Error, typeNotInTheSameHierarchy, "invalid use of 'as' operator: expression evaluates to '$0', which is not in the same type hierarchy as target type '$1'.") diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index aa21127e9..3274c3223 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1397,7 +1397,7 @@ struct SPIRVEmitContext case kIROp_RateQualifiedType: { - auto result = emitGlobalInst(as<IRRateQualifiedType>(inst)->getValueType()); + auto result = ensureInst(as<IRRateQualifiedType>(inst)->getValueType()); registerInst(inst, result); return result; } diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index 1ebc4e350..9051e74df 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -17,6 +17,7 @@ namespace Slang InstHashSet workListSet; void addToWorkList(IRInst* inst) { + SLANG_ASSERT(inst); if (workListSet.contains(inst)) return; @@ -139,6 +140,7 @@ namespace Slang default: break; } + SLANG_ASSERT(child); if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions())) addToWorkList(child); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d75b66a9b..b4dac190a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1432,9 +1432,10 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl) { if (const auto funcDecl = as<CallableDecl>(requirementDecl)) { - } - else if (const auto propertyDecl = as<PropertyDecl>(requirementDecl)) - { + // Subscript decl itself won't have a witness table entry. + // But its accessors will. + if (const auto subscriptDecl = as<SubscriptDecl>(requirementDecl)) + return false; } else if (const auto assocTypeDecl = as<AssocTypeDecl>(requirementDecl)) { @@ -1451,6 +1452,9 @@ bool shouldDeclBeTreatedAsInterfaceRequirement(Decl* requirementDecl) } else { + // We will return false for PropertyDecl because the property decl itself + // won't have a witness table entry. Instead there will be witness entries + // for its accessors. return false; } return true; @@ -3020,6 +3024,7 @@ void _lowerFuncDeclBaseTypeInfo( auto& parameterLists = outInfo.parameterLists; collectParameterLists( context, + declRef, ¶meterLists, kParameterListCollectMode_Default, kParameterDirection_In); @@ -3501,6 +3506,14 @@ struct ExprLoweringContext // appropriately. auto funcDeclRef = resolvedInfo.funcDeclRef; auto baseExpr = resolvedInfo.baseExpr; + if (baseExpr) + { + // The base expression might be an "upcast" to a base interface, in + // which case we don't want to emit the result of the cast, but instead + // the source. + // + baseExpr = this->maybeIgnoreCastToInterface(baseExpr); + } // If the thing being invoked is a subscript operation, // then we need to handle multiple extra details @@ -3550,12 +3563,6 @@ struct ExprLoweringContext // a member function: if (baseExpr) { - // The base expression might be an "upcast" to a base interface, in - // which case we don't want to emit the result of the cast, but instead - // the source. - // - baseExpr = this->maybeIgnoreCastToInterface(baseExpr); - auto thisType = getThisParamTypeForCallable(context, funcDeclRef); auto irThisType = lowerType(context, thisType); addCallArgsForParam( diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 4addb1d53..b2b5deb22 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -133,6 +133,11 @@ namespace Slang return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); } + inline Type* getType(ASTBuilder* astBuilder, DeclRef<SubscriptDecl> declRef) + { + return declRef.substitute(astBuilder, declRef.getDecl()->returnType.Ptr()); + } + inline Type* getType(ASTBuilder* astBuilder, DeclRef<EnumCaseDecl> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->type.Ptr()); |
