diff options
| -rw-r--r-- | docs/user-guide/03-convenience-features.md | 16 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 8 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 49 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 10 | ||||
| -rw-r--r-- | tests/language-feature/ifunc/diff-functor.slang | 8 | ||||
| -rw-r--r-- | tests/language-feature/ifunc/functor.slang | 40 | ||||
| -rw-r--r-- | tests/language-feature/ifunc/ifunc.slang | 6 |
10 files changed, 125 insertions, 19 deletions
diff --git a/docs/user-guide/03-convenience-features.md b/docs/user-guide/03-convenience-features.md index 559e4c9b4..1561d6605 100644 --- a/docs/user-guide/03-convenience-features.md +++ b/docs/user-guide/03-convenience-features.md @@ -331,6 +331,22 @@ int test() ``` Slang currently supports overloading the following operators: `+`, `-`, `*`, `/`, `%`, `&`, `|`, `<`, `>`, `<=`, `>=`, `==`, `!=`, unary `-`, `~` and `!`. Please note that the `&&` and `||` operators are not supported. +In addition, you can overload operator `()` as a member method: +```csharp +struct MyFunctor +{ + int operator()(float v) + { + // ... + } +} +void test() +{ + MyFunctor f; + int x = f(1.0f); // calls MyFunctor::operator(). + int y = f.operator()(1.0f); // explicitly calling operator(). +} +``` ## Subscript Operator diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index db0acb3ed..629737d6c 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -957,25 +957,25 @@ extension Tuple<T> : IComparable interface IMutatingFunc<TR, each TP> { [mutating] - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IFunc<TR, each TP> : IMutatingFunc<TR, expand each TP> { - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IDifferentiableMutatingFunc<TR : IDifferentiable, each TP : IDifferentiable> : IMutatingFunc<TR, expand each TP> { [Differentiable] [mutating] - TR __call(expand each TP p); + TR operator()(expand each TP p); } interface IDifferentiableFunc<TR : IDifferentiable, each TP : IDifferentiable> : IFunc<TR, expand each TP>, IDifferentiableMutatingFunc<TR, expand each TP> { [Differentiable] - TR __call(expand each TP p); + TR operator()(expand each TP p); } __generic<T> diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 66707fc56..3bd6bd327 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7881,7 +7881,7 @@ namespace Slang ctorToInvoke->declRef = declInfo.defaultCtor->getDefaultDeclRef(); ctorToInvoke->name = declInfo.defaultCtor->getName(); ctorToInvoke->loc = declInfo.defaultCtor->loc; - ctorToInvoke->type = structDeclInfo.defaultCtor->returnType.type; + ctorToInvoke->type = m_astBuilder->getFuncType(ArrayView<Type*>(), structDeclInfo.defaultCtor->returnType.type); auto invoke = m_astBuilder->create<InvokeExpr>(); invoke->functionExpr = ctorToInvoke; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3072c3257..500407e26 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2693,6 +2693,42 @@ namespace Slang return newExpr; expr->functionExpr = CheckTerm(expr->functionExpr); + + if (auto baseType = as<DeclRefType>(expr->functionExpr->type)) + { + // If callee is a value of DeclRefType, then it is a functor. + // We need to look for `operator()` member within the type and + // call that instead. + auto operatorName = getName("()"); + + bool needDeref = false; + expr->functionExpr = maybeInsertImplicitOpForMemberBase(expr->functionExpr, needDeref); + + LookupResult lookupResult = lookUpMember( + m_astBuilder, + this, + operatorName, + expr->functionExpr->type, + m_outerScope, + LookupMask::Default, + LookupOptions::NoDeref); + bool diagnosed = false; + lookupResult = filterLookupResultByVisibilityAndDiagnose(lookupResult, expr->loc, diagnosed); + if (!lookupResult.isValid()) + { + if (!diagnosed) + getSink()->diagnose(expr, Diagnostics::callOperatorNotFound, baseType); + return CreateErrorExpr(expr); + } + auto callFuncExpr = createLookupResultExpr( + operatorName, + lookupResult, + expr->functionExpr, + expr->loc, + expr->functionExpr); + expr->functionExpr = callFuncExpr; + } + m_treatAsDifferentiableExpr = treatAsDifferentiableExpr; // If we are in a differentiable function, register differential witness tables involved in @@ -4400,12 +4436,8 @@ namespace Slang return expr; } - Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) + Expr* SemanticsVisitor::maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref) { - auto baseExpr = inBaseExpr; - - baseExpr = CheckTerm(baseExpr); - auto derefExpr = MaybeDereference(baseExpr); if (derefExpr != baseExpr) @@ -4459,6 +4491,13 @@ namespace Slang return baseExpr; } + Expr* SemanticsVisitor::checkBaseForMemberExpr(Expr* inBaseExpr, bool& outNeedDeref) + { + auto baseExpr = inBaseExpr; + baseExpr = CheckTerm(baseExpr); + return maybeInsertImplicitOpForMemberBase(baseExpr, outNeedDeref); + } + Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType) { LookupResult lookupResult = lookUpMember( diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f997abc57..96b467b0c 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2669,6 +2669,10 @@ namespace Slang /// Perform checking operations required for the "base" expression of a member-reference like `base.someField` Expr* checkBaseForMemberExpr(Expr* baseExpr, bool& outNeedDeref); + /// Prepare baseExpr for use as the base of a member expr. + /// This include inserting implicit open-existential operations as needed. + Expr* maybeInsertImplicitOpForMemberBase(Expr* baseExpr, bool& outNeedDeref); + Expr* lookupMemberResultFailure( DeclRefExpr* expr, QualType const& baseType, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index e058fcd91..d23ae8a3e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -283,6 +283,7 @@ DIAGNOSTIC(30012, Error, noApplicationUnaryOperator, "no overload found for oper DIAGNOSTIC(30012, Error, noOverloadFoundForBinOperatorOnTypes, "no overload found for operator $0 ($1, $2).") DIAGNOSTIC(30013, Error, subscriptNonArray, "no subscript operation found for type '$0'") DIAGNOSTIC(30014, Error, subscriptIndexNonInteger, "index expression must evaluate to int.") +DIAGNOSTIC(30016, Error, callOperatorNotFound, "no call operation found for type '$0'") DIAGNOSTIC(30015, Error, undefinedIdentifier2, "undefined identifier '$0'.") DIAGNOSTIC(30019, Error, typeMismatch, "expected an expression of type '$0', got '$1'") DIAGNOSTIC(30021, Error, noApplicationFunction, "$0: no overload takes arguments ($1)") diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 5881e5796..b3a18f8a8 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1367,6 +1367,9 @@ namespace Slang case TokenType::Comma: case TokenType::OpAssign: break; + case TokenType::LParent: + parser->ReadToken(TokenType::RParent); + break; // Note(tfoley): Even more of a hack! case TokenType::QuestionMark: @@ -1382,6 +1385,9 @@ namespace Slang break; } + if (nameToken.type == TokenType::LParent) + return NameLoc(getName(parser, "()"), nameToken.loc); + return NameLoc( getName(parser, nameToken.getContent()), nameToken.loc); @@ -7051,7 +7057,7 @@ namespace Slang varExpr->scope = parser->currentScope; parser->FillPosition(varExpr); - auto nameAndLoc = NameLoc(parser->ReadToken()); + auto nameAndLoc = ParseDeclName(parser); varExpr->name = nameAndLoc.name; if(peekTokenType(parser) == TokenType::OpLess) @@ -7177,7 +7183,7 @@ namespace Slang memberExpr->baseExpression = expr; parser->ReadToken(nextTokenType); parser->FillPosition(memberExpr); - memberExpr->name = expectIdentifier(parser).name; + memberExpr->name = ParseDeclName(parser).name; if (peekTokenType(parser) == TokenType::OpLess) expr = maybeParseGenericApp(parser, memberExpr); diff --git a/tests/language-feature/ifunc/diff-functor.slang b/tests/language-feature/ifunc/diff-functor.slang index 04b0be44f..117cce76b 100644 --- a/tests/language-feature/ifunc/diff-functor.slang +++ b/tests/language-feature/ifunc/diff-functor.slang @@ -6,7 +6,7 @@ struct DiffFunctor : IDifferentiableFunc<float, float> { [Differentiable] - float __call(float p) + float operator()(float p) { return p + 1; } @@ -14,19 +14,19 @@ struct DiffFunctor : IDifferentiableFunc<float, float> float apply(IMutatingFunc<float, float> f, float p) { - return f.__call(p); + return f(p); } [Differentiable] float applyDiff(IDifferentiableFunc<float, float> f, float p) { - return f.__call(p); + return f(p); } [Differentiable] TR applyDiffGen<TR : IDifferentiable, each TP : IDifferentiable>(IDifferentiableFunc<TR, TP> f, expand each TP p) { - return f.__call(expand each p); + return f(expand each p); } //TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer diff --git a/tests/language-feature/ifunc/functor.slang b/tests/language-feature/ifunc/functor.slang new file mode 100644 index 000000000..73987cbbf --- /dev/null +++ b/tests/language-feature/ifunc/functor.slang @@ -0,0 +1,40 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-dx12 -use-dxil -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type + +struct Functor : IMutatingFunc<float, float> +{ + int context; + + [mutating] + float operator()(float p) + { + context += (int)p; + return context; + } +} + +float apply<T:IMutatingFunc<float,float>>(inout T f, float p) +{ + return f(p); +} + +//TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint tid: SV_DispatchThreadID) +{ + Functor f; + f.context = 0; + + f(1.0f); + f.operator()(1.0f); // explicit operator () call should also work. + + apply(f, 2.0f); + apply(f, 3.0f); + + // CHECK: 7 + outputBuffer[0] = (uint)f.context; +} diff --git a/tests/language-feature/ifunc/ifunc.slang b/tests/language-feature/ifunc/ifunc.slang index f270299b3..6c946f4df 100644 --- a/tests/language-feature/ifunc/ifunc.slang +++ b/tests/language-feature/ifunc/ifunc.slang @@ -5,7 +5,7 @@ struct Functor : IFunc<int, int, bool> { - int __call(int p, bool t) + int operator()(int p, bool t) { return p + 1; } @@ -15,7 +15,7 @@ struct MutatingFunctor : IMutatingFunc<int, int, bool> { int data = 0; [mutating] - int __call(int p, bool t) + int operator()(int p, bool t) { data++; return p + 1; @@ -24,7 +24,7 @@ struct MutatingFunctor : IMutatingFunc<int, int, bool> int apply(IMutatingFunc<int, int, bool> f, int p) { - return f.__call(p, true); + return f(p, true); } //TEST_INPUT:ubuffer(data=[0 3 2 2], stride=4):out,name=outputBuffer |
