summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-27 18:48:41 -0700
committerGitHub <noreply@github.com>2024-08-27 18:48:41 -0700
commit4f6f827e26ffcb9b850ef8a8b7f7b4beb5addb7a (patch)
treee8f20e798866df7e10067ce5b7ae22f9dc57ff84 /source/slang
parentfbaa444d890f58fabc5933b0c28048d2c5d862c0 (diff)
Add functor syntax support. (#4926)
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang8
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp49
-rw-r--r--source/slang/slang-check-impl.h4
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-parser.cpp10
6 files changed, 62 insertions, 12 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index db0acb3ed..629737d6c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -957,25 +957,25 @@ extension Tuple<T> : IComparable
interface IMutatingFunc<TR, each TP>
{
[mutating]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IFunc<TR, each TP> : IMutatingFunc<TR, expand each TP>
{
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IDifferentiableMutatingFunc<TR : IDifferentiable, each TP : IDifferentiable> : IMutatingFunc<TR, expand each TP>
{
[Differentiable]
[mutating]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IDifferentiableFunc<TR : IDifferentiable, each TP : IDifferentiable> : IFunc<TR, expand each TP>, IDifferentiableMutatingFunc<TR, expand each TP>
{
[Differentiable]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
__generic<T>
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 66707fc56..3bd6bd327 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7881,7 +7881,7 @@ namespace Slang
ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef();
ctorToInvoke->name = declInfo.defaultCtor->getName();
ctorToInvoke->loc = declInfo.defaultCtor->loc;
- ctorToInvoke->type = structDeclInfo.defaultCtor->returnType.type;
+ ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView<Type*>(), structDeclInfo.defaultCtor->returnType.type);
auto invoke = m_astBuilder->create<InvokeExpr>();
invoke->functionExpr = ctorToInvoke;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3072c3257..500407e26 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2693,6 +2693,42 @@ namespace Slang
return newExpr;
expr->functionExpr = CheckTerm(expr->functionExpr);
+
+ if (auto baseType = as<DeclRefType>(expr->functionExpr->type))
+ {
+ // If callee is a value of DeclRefType, then it is a functor.
+ // We need to look for `operator()` member within the type and
+ // call that instead.
+ auto operatorName = getName("()");
+
+ bool needDeref = false;
+ expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref);
+
+ LookupResult lookupResult = lookUpMember(
+ m_astBuilder,
+ this,
+ operatorName,
+ expr->functionExpr->type,
+ m_outerScope,
+ LookupMask::Default,
+ LookupOptions::NoDeref);
+ bool diagnosed = false;
+ lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed);
+ if (!lookupResult.isValid())
+ {
+ if (!diagnosed)
+ getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType);
+ return CreateErrorExpr(expr);
+ }
+ auto callFuncExpr = createLookupResultExpr(
+ operatorName,
+ lookupResult,
+ expr->functionExpr,
+ expr->loc,
+ expr->functionExpr);
+ expr->functionExpr = callFuncExpr;
+ }
+
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in
@@ -4400,12 +4436,8 @@ namespace Slang
return expr;
}
- Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
+ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref)
{
- auto baseExpr = inBaseExpr;
-
- baseExpr = CheckTerm(baseExpr);
-
auto derefExpr = MaybeDereference(baseExpr);
if (derefExpr != baseExpr)
@@ -4459,6 +4491,13 @@ namespace Slang
return baseExpr;
}
+ Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
+ {
+ auto baseExpr = inBaseExpr;
+ baseExpr = CheckTerm(baseExpr);
+ return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref);
+ }
+
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
{
LookupResult lookupResult = lookUpMember(
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index f997abc57..96b467b0c 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2669,6 +2669,10 @@ namespace Slang
/// Perform checking operations required for the "base" expression of a member-reference like `base.someField`
Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref);
+ /// Prepare baseExpr for use as the base of a member expr.
+ /// This include inserting implicit open-existential operations as needed.
+ Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref);
+
Expr* lookupMemberResultFailure(
DeclRefExpr* expr,
QualType const& baseType,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e058fcd91..d23ae8a3e 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -283,6 +283,7 @@ DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for oper
DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).")
DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
+DIAGNOSTIC(30016, Error, callOperatorNotFound, "no call operation found for type '$0'")
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)")
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 5881e5796..b3a18f8a8 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1367,6 +1367,9 @@ namespace Slang
case TokenType::Comma:
case TokenType::OpAssign:
break;
+ case TokenType::LParent:
+ parser->ReadToken(TokenType::RParent);
+ break;
// Note(tfoley): Even more of a hack!
case TokenType::QuestionMark:
@@ -1382,6 +1385,9 @@ namespace Slang
break;
}
+ if (nameToken.type == TokenType::LParent)
+ return NameLoc(getName(parser, "()"), nameToken.loc);
+
return NameLoc(
getName(parser, nameToken.getContent()),
nameToken.loc);
@@ -7051,7 +7057,7 @@ namespace Slang
varExpr->scope = parser->currentScope;
parser->FillPosition(varExpr);
- auto nameAndLoc = NameLoc(parser->ReadToken());
+ auto nameAndLoc = ParseDeclName(parser);
varExpr->name = nameAndLoc.name;
if(peekTokenType(parser) == TokenType::OpLess)
@@ -7177,7 +7183,7 @@ namespace Slang
memberExpr->baseExpression = expr;
parser->ReadToken(nextTokenType);
parser->FillPosition(memberExpr);
- memberExpr->name = expectIdentifier(parser).name;
+ memberExpr->name = ParseDeclName(parser).name;
if (peekTokenType(parser) == TokenType::OpLess)
expr = maybeParseGenericApp(parser, memberExpr);