summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/user-guide/03-convenience-features.md16
-rw-r--r--source/slang/core.meta.slang8
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp49
-rw-r--r--source/slang/slang-check-impl.h4
-rw-r--r--source/slang/slang-diagnostic-defs.h1
-rw-r--r--source/slang/slang-parser.cpp10
-rw-r--r--tests/language-feature/ifunc/diff-functor.slang8
-rw-r--r--tests/language-feature/ifunc/functor.slang40
-rw-r--r--tests/language-feature/ifunc/ifunc.slang6
10 files changed, 125 insertions, 19 deletions
diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md
index 559e4c9b4..1561d6605 100644
--- a/docs/user-guide/03-convenience-features.md
+++ b/docs/user-guide/03-convenience-features.md
@@ -331,6 +331,22 @@ int test()
```
Slang currently supports overloading the following operators: `+`, `-`, `*`, `/`, `%`, `&`, `|`, `<`, `>`, `<=`, `>=`, `==`, `!=`, unary `-`, `~` and `!`. Please note that the `&&` and `||` operators are not supported.
+In addition, you can overload operator `()` as a member method:
+```csharp
+struct MyFunctor
+{
+ int operator()(float v)
+ {
+ // ...
+ }
+}
+void test()
+{
+ MyFunctor f;
+ int x = f(1.0f); // calls MyFunctor::operator().
+ int y = f.operator()(1.0f); // explicitly calling operator().
+}
+```
## Subscript Operator
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index db0acb3ed..629737d6c 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -957,25 +957,25 @@ extension Tuple<T> : IComparable
interface IMutatingFunc<TR, each TP>
{
[mutating]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IFunc<TR, each TP> : IMutatingFunc<TR, expand each TP>
{
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IDifferentiableMutatingFunc<TR : IDifferentiable, each TP : IDifferentiable> : IMutatingFunc<TR, expand each TP>
{
[Differentiable]
[mutating]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
interface IDifferentiableFunc<TR : IDifferentiable, each TP : IDifferentiable> : IFunc<TR, expand each TP>, IDifferentiableMutatingFunc<TR, expand each TP>
{
[Differentiable]
- TR __call(expand each TP p);
+ TR operator()(expand each TP p);
}
__generic<T>
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 66707fc56..3bd6bd327 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -7881,7 +7881,7 @@ namespace Slang
ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef();
ctorToInvoke->name = declInfo.defaultCtor->getName();
ctorToInvoke->loc = declInfo.defaultCtor->loc;
- ctorToInvoke->type = structDeclInfo.defaultCtor->returnType.type;
+ ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView<Type*>(), structDeclInfo.defaultCtor->returnType.type);
auto invoke = m_astBuilder->create<InvokeExpr>();
invoke->functionExpr = ctorToInvoke;
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3072c3257..500407e26 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2693,6 +2693,42 @@ namespace Slang
return newExpr;
expr->functionExpr = CheckTerm(expr->functionExpr);
+
+ if (auto baseType = as<DeclRefType>(expr->functionExpr->type))
+ {
+ // If callee is a value of DeclRefType, then it is a functor.
+ // We need to look for `operator()` member within the type and
+ // call that instead.
+ auto operatorName = getName("()");
+
+ bool needDeref = false;
+ expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref);
+
+ LookupResult lookupResult = lookUpMember(
+ m_astBuilder,
+ this,
+ operatorName,
+ expr->functionExpr->type,
+ m_outerScope,
+ LookupMask::Default,
+ LookupOptions::NoDeref);
+ bool diagnosed = false;
+ lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed);
+ if (!lookupResult.isValid())
+ {
+ if (!diagnosed)
+ getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType);
+ return CreateErrorExpr(expr);
+ }
+ auto callFuncExpr = createLookupResultExpr(
+ operatorName,
+ lookupResult,
+ expr->functionExpr,
+ expr->loc,
+ expr->functionExpr);
+ expr->functionExpr = callFuncExpr;
+ }
+
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
// If we are in a differentiable function, register differential witness tables involved in
@@ -4400,12 +4436,8 @@ namespace Slang
return expr;
}
- Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
+ Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref)
{
- auto baseExpr = inBaseExpr;
-
- baseExpr = CheckTerm(baseExpr);
-
auto derefExpr = MaybeDereference(baseExpr);
if (derefExpr != baseExpr)
@@ -4459,6 +4491,13 @@ namespace Slang
return baseExpr;
}
+ Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref)
+ {
+ auto baseExpr = inBaseExpr;
+ baseExpr = CheckTerm(baseExpr);
+ return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref);
+ }
+
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)
{
LookupResult lookupResult = lookUpMember(
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index f997abc57..96b467b0c 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2669,6 +2669,10 @@ namespace Slang
/// Perform checking operations required for the "base" expression of a member-reference like `base.someField`
Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref);
+ /// Prepare baseExpr for use as the base of a member expr.
+ /// This include inserting implicit open-existential operations as needed.
+ Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref);
+
Expr* lookupMemberResultFailure(
DeclRefExpr* expr,
QualType const& baseType,
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index e058fcd91..d23ae8a3e 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -283,6 +283,7 @@ DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for oper
DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).")
DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'")
DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.")
+DIAGNOSTIC(30016, Error, callOperatorNotFound, "no call operation found for type '$0'")
DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.")
DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'")
DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)")
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 5881e5796..b3a18f8a8 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1367,6 +1367,9 @@ namespace Slang
case TokenType::Comma:
case TokenType::OpAssign:
break;
+ case TokenType::LParent:
+ parser->ReadToken(TokenType::RParent);
+ break;
// Note(tfoley): Even more of a hack!
case TokenType::QuestionMark:
@@ -1382,6 +1385,9 @@ namespace Slang
break;
}
+ if (nameToken.type == TokenType::LParent)
+ return NameLoc(getName(parser, "()"), nameToken.loc);
+
return NameLoc(
getName(parser, nameToken.getContent()),
nameToken.loc);
@@ -7051,7 +7057,7 @@ namespace Slang
varExpr->scope = parser->currentScope;
parser->FillPosition(varExpr);
- auto nameAndLoc = NameLoc(parser->ReadToken());
+ auto nameAndLoc = ParseDeclName(parser);
varExpr->name = nameAndLoc.name;
if(peekTokenType(parser) == TokenType::OpLess)
@@ -7177,7 +7183,7 @@ namespace Slang
memberExpr->baseExpression = expr;
parser->ReadToken(nextTokenType);
parser->FillPosition(memberExpr);
- memberExpr->name = expectIdentifier(parser).name;
+ memberExpr->name = ParseDeclName(parser).name;
if (peekTokenType(parser) == TokenType::OpLess)
expr = maybeParseGenericApp(parser, memberExpr);
diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang
index 04b0be44f..117cce76b 100644
--- a/tests/language-feature/ifunc/diff-functor.slang
+++ b/tests/language-feature/ifunc/diff-functor.slang
@@ -6,7 +6,7 @@
struct DiffFunctor : IDifferentiableFunc<float, float>
{
[Differentiable]
- float __call(float p)
+ float operator()(float p)
{
return p + 1;
}
@@ -14,19 +14,19 @@ struct DiffFunctor : IDifferentiableFunc<float, float>
float apply(IMutatingFunc<float, float> f, float p)
{
- return f.__call(p);
+ return f(p);
}
[Differentiable]
float applyDiff(IDifferentiableFunc<float, float> f, float p)
{
- return f.__call(p);
+ return f(p);
}
[Differentiable]
TR applyDiffGen<TR : IDifferentiable, each TP : IDifferentiable>(IDifferentiableFunc<TR, TP> f, expand each TP p)
{
- return f.__call(expand each p);
+ return f(expand each p);
}
//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
diff --git a/tests/language-feature/ifunc/functor.slang b/tests/language-feature/ifunc/functor.slang
new file mode 100644
index 000000000..73987cbbf
--- /dev/null
+++ b/tests/language-feature/ifunc/functor.slang
@@ -0,0 +1,40 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+struct Functor : IMutatingFunc<float, float>
+{
+ int context;
+
+ [mutating]
+ float operator()(float p)
+ {
+ context += (int)p;
+ return context;
+ }
+}
+
+float apply<T:IMutatingFunc<float,float>>(inout T f, float p)
+{
+ return f(p);
+}
+
+//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint tid: SV_DispatchThreadID)
+{
+ Functor f;
+ f.context = 0;
+
+ f(1.0f);
+ f.operator()(1.0f); // explicit operator () call should also work.
+
+ apply(f, 2.0f);
+ apply(f, 3.0f);
+
+ // CHECK: 7
+ outputBuffer[0] = (uint)f.context;
+}
diff --git a/tests/language-feature/ifunc/ifunc.slang b/tests/language-feature/ifunc/ifunc.slang
index f270299b3..6c946f4df 100644
--- a/tests/language-feature/ifunc/ifunc.slang
+++ b/tests/language-feature/ifunc/ifunc.slang
@@ -5,7 +5,7 @@
struct Functor : IFunc<int, int, bool>
{
- int __call(int p, bool t)
+ int operator()(int p, bool t)
{
return p + 1;
}
@@ -15,7 +15,7 @@ struct MutatingFunctor : IMutatingFunc<int, int, bool>
{
int data = 0;
[mutating]
- int __call(int p, bool t)
+ int operator()(int p, bool t)
{
data++;
return p + 1;
@@ -24,7 +24,7 @@ struct MutatingFunctor : IMutatingFunc<int, int, bool>
int apply(IMutatingFunc<int, int, bool> f, int p)
{
- return f.__call(p, true);
+ return f(p, true);
}
//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer