From 4f6f827e26ffcb9b850ef8a8b7f7b4beb5addb7a Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 27 Aug 2024 18:48:41 -0700 Subject: Add functor syntax support. (#4926) --- source/slang/core.meta.slang | 8 +++--- source/slang/slang-check-decl.cpp | 2 +- source/slang/slang-check-expr.cpp | 49 ++++++++++++++++++++++++++++++++---- source/slang/slang-check-impl.h | 4 +++ source/slang/slang-diagnostic-defs.h | 1 + source/slang/slang-parser.cpp | 10 ++++++-- 6 files changed, 62 insertions(+), 12 deletions(-) (limited to 'source/slang') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index db0acb3ed..629737d6c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -957,25 +957,25 @@ extension Tuple : IComparable interface IMutatingFunc { [mutating] - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IFunc : IMutatingFunc { - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IDifferentiableMutatingFunc : IMutatingFunc { [Differentiable] [mutating] - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IDifferentiableFunc : IFunc, IDifferentiableMutatingFunc { [Differentiable] - TR __call(expand each TP p); + TR operator()(expand each TP p); } __generic diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 66707fc56..3bd6bd327 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7881,7 +7881,7 @@ namespace Slang ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef(); ctorToInvoke->name = declInfo.defaultCtor->getName(); ctorToInvoke->loc = declInfo.defaultCtor->loc; - ctorToInvoke->type = structDeclInfo.defaultCtor->returnType.type; + ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView(), structDeclInfo.defaultCtor->returnType.type); auto invoke = m_astBuilder->create(); invoke->functionExpr = ctorToInvoke; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3072c3257..500407e26 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2693,6 +2693,42 @@ namespace Slang return newExpr; expr->functionExpr = CheckTerm(expr->functionExpr); + + if (auto baseType = as(expr->functionExpr->type)) + { + // If callee is a value of DeclRefType, then it is a functor. + // We need to look for `operator()` member within the type and + // call that instead. + auto operatorName = getName("()"); + + bool needDeref = false; + expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref); + + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + operatorName, + expr->functionExpr->type, + m_outerScope, + LookupMask::Default, + LookupOptions::NoDeref); + bool diagnosed = false; + lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!lookupResult.isValid()) + { + if (!diagnosed) + getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType); + return CreateErrorExpr(expr); + } + auto callFuncExpr = createLookupResultExpr( + operatorName, + lookupResult, + expr->functionExpr, + expr->loc, + expr->functionExpr); + expr->functionExpr = callFuncExpr; + } + m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in @@ -4400,12 +4436,8 @@ namespace Slang return expr; } - Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) + Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref) { - auto baseExpr = inBaseExpr; - - baseExpr = CheckTerm(baseExpr); - auto derefExpr = MaybeDereference(baseExpr); if (derefExpr != baseExpr) @@ -4459,6 +4491,13 @@ namespace Slang return baseExpr; } + Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) + { + auto baseExpr = inBaseExpr; + baseExpr = CheckTerm(baseExpr); + return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref); + } + Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) { LookupResult lookupResult = lookUpMember( diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f997abc57..96b467b0c 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2669,6 +2669,10 @@ namespace Slang /// Perform checking operations required for the "base" expression of a member-reference like `base.someField` Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref); + /// Prepare baseExpr for use as the base of a member expr. + /// This include inserting implicit open-existential operations as needed. + Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref); + Expr* lookupMemberResultFailure( DeclRefExpr* expr, QualType const& baseType, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e058fcd91..d23ae8a3e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -283,6 +283,7 @@ DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for oper DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") +DIAGNOSTIC(30016, Error, callOperatorNotFound, "no call operation found for type '$0'") DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 5881e5796..b3a18f8a8 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1367,6 +1367,9 @@ namespace Slang case TokenType::Comma: case TokenType::OpAssign: break; + case TokenType::LParent: + parser->ReadToken(TokenType::RParent); + break; // Note(tfoley): Even more of a hack! case TokenType::QuestionMark: @@ -1382,6 +1385,9 @@ namespace Slang break; } + if (nameToken.type == TokenType::LParent) + return NameLoc(getName(parser, "()"), nameToken.loc); + return NameLoc( getName(parser, nameToken.getContent()), nameToken.loc); @@ -7051,7 +7057,7 @@ namespace Slang varExpr->scope = parser->currentScope; parser->FillPosition(varExpr); - auto nameAndLoc = NameLoc(parser->ReadToken()); + auto nameAndLoc = ParseDeclName(parser); varExpr->name = nameAndLoc.name; if(peekTokenType(parser) == TokenType::OpLess) @@ -7177,7 +7183,7 @@ namespace Slang memberExpr->baseExpression = expr; parser->ReadToken(nextTokenType); parser->FillPosition(memberExpr); - memberExpr->name = expectIdentifier(parser).name; + memberExpr->name = ParseDeclName(parser).name; if (peekTokenType(parser) == TokenType::OpLess) expr = maybeParseGenericApp(parser, memberExpr); -- cgit v1.2.3