diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-decl.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.cpp | 27 | ||||
| -rw-r--r-- | source/slang/slang-ast-synthesis.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 207 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 41 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-language-server-semantic-tokens.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 5 |
15 files changed, 423 insertions, 29 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 3201a3e1d..6fb281247 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -431,6 +431,14 @@ class ConstructorDecl : public FunctionDeclBase bool containsFlavor(ConstructorFlavor flavor) { return m_flavor & (int)flavor; } }; +FIDDLE() +class LambdaDecl : public StructDecl +{ + FIDDLE(...) + + FIDDLE() FunctionDeclBase* funcDecl; +}; + // A subscript operation used to index instances of a type FIDDLE() class SubscriptDecl : public CallableDecl diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index cd5f9b6e8..5d8caefa9 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -673,6 +673,14 @@ class DispatchKernelExpr : public HigherOrderInvokeExpr FIDDLE() Expr* dispatchSize; }; +FIDDLE() +class LambdaExpr : public Expr +{ + FIDDLE(...) + FIDDLE() ScopeDecl* paramScopeDecl; + FIDDLE() Stmt* bodyStmt; +}; + /// An express to mark its inner expression as an intended non-differential call. FIDDLE() class TreatAsDifferentiableExpr : public Expr diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 2112d452e..0a4b17cc0 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -568,6 +568,8 @@ void iterateASTWithLanguageServerFilter( { auto filter = [&](DeclBase* decl) { + if (as<ConstructorDecl>(decl) && decl->findModifier<SynthesizedModifier>()) + return false; return as<NamespaceDeclBase>(decl) || sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual) .pathInfo.foundPath.getUnownedSlice() diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index 58c68c369..c7291f526 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -50,6 +50,13 @@ ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl*& outInd return stmt; } +Expr* ASTSynthesizer::emitThisExpr() +{ + auto varExpr = m_builder->create<ThisExpr>(); + varExpr->scope = getCurrentScope().m_scope; + return varExpr; +} + Expr* ASTSynthesizer::emitVarExpr(Name* name) { auto scope = getCurrentScope(); @@ -60,7 +67,7 @@ Expr* ASTSynthesizer::emitVarExpr(Name* name) return varExpr; } -Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) +Expr* ASTSynthesizer::emitVarExpr(VarDeclBase* varDecl) { auto varExpr = m_builder->create<VarExpr>(); varExpr->declRef = makeDeclRef<Decl>(varDecl); @@ -68,7 +75,7 @@ Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) return varExpr; } -Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type) +Expr* ASTSynthesizer::emitVarExpr(VarDeclBase* var, Type* type) { auto expr = m_builder->create<VarExpr>(); expr->declRef = makeDeclRef<Decl>(var); @@ -86,6 +93,14 @@ Expr* ASTSynthesizer::emitVarExpr(DeclStmt* varStmt, Type* type) return expr; } +Expr* ASTSynthesizer::emitStaticTypeExpr(Type* type) +{ + auto expr = m_builder->create<SharedTypeExpr>(); + expr->type.type = m_builder->getTypeType(type); + expr->checked = true; + return expr; +} + Expr* ASTSynthesizer::emitIntConst(int value) { auto expr = m_builder->create<IntegerLiteralExpr>(); @@ -126,6 +141,14 @@ Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List<Expr*>&& args) return rs; } +Expr* ASTSynthesizer::emitCtorInvokeExpr(Expr* callee, List<Expr*>&& args) +{ + auto rs = m_builder->create<ExplicitCtorInvokeExpr>(); + rs->functionExpr = callee; + rs->arguments = _Move(args); + return rs; +} + Expr* ASTSynthesizer::emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args) { auto rs = m_builder->create<GenericAppExpr>(); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index 591b7edde..c1072c705 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -113,10 +113,12 @@ public: Expr* emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base); + Expr* emitThisExpr(); Expr* emitVarExpr(Name* name); - Expr* emitVarExpr(VarDecl* var); - Expr* emitVarExpr(VarDecl* var, Type* type); + Expr* emitVarExpr(VarDeclBase* var); + Expr* emitVarExpr(VarDeclBase* var, Type* type); Expr* emitVarExpr(DeclStmt* varStmt, Type* type); + Expr* emitStaticTypeExpr(Type* type); Expr* emitIntConst(int value); @@ -134,6 +136,7 @@ public: } Expr* emitInvokeExpr(Expr* callee, List<Expr*>&& args); + Expr* emitCtorInvokeExpr(Expr* callee, List<Expr*>&& args); Expr* emitGenericAppExpr(Expr* genericExpr, List<Expr*>&& args); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2c595dd4a..87f29d367 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -12,8 +12,10 @@ // * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion #include "core/slang-char-util.h" +#include "slang-ast-decl.h" #include "slang-ast-natural-layout.h" #include "slang-ast-print.h" +#include "slang-ast-synthesis.h" #include "slang-lookup-spirv.h" #include "slang-lookup.h" @@ -3125,15 +3127,116 @@ Expr* SemanticsExprVisitor::visitVarExpr(VarExpr* expr) return expr; } + Expr* resultExpr = expr; + if (lookupResult.isValid()) { - return createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr); + auto lookupResultExpr = + createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr); + if (m_parentLambdaExpr) + return maybeRegisterLambdaCapture(lookupResultExpr); + return lookupResultExpr; } if (!diagnosed) getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); - return expr; + return resultExpr; +} + +Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn) +{ + if (auto memberExpr = as<MemberExpr>(exprIn)) + { + memberExpr->baseExpression = maybeRegisterLambdaCapture(memberExpr->baseExpression); + return memberExpr; + } + else if (auto subscriptExpr = as<IndexExpr>(exprIn)) + { + subscriptExpr->baseExpression = maybeRegisterLambdaCapture(subscriptExpr->baseExpression); + return subscriptExpr; + } + auto thisExpr = as<ThisExpr>(exprIn); + auto varExpr = as<VarExpr>(exprIn); + if (!thisExpr && !varExpr) + return exprIn; + + Decl* srcDecl = nullptr; + if (varExpr) + srcDecl = as<VarDeclBase>(varExpr->declRef.getDecl()); + else + { + // If we see a `this` expression inside a lambda, it is referencing the + // `this` value of the parent type of the outer function, not the lambda struct + // itself. Since we don't have a VarDecl representing `this`, we will just use + // the AggTypeDecl as the key to register in the lambda capture map. + auto thisTypeDecl = isDeclRefTypeOf<Decl>(thisExpr->type.type); + if (!thisTypeDecl) + return exprIn; + srcDecl = thisTypeDecl.getDecl(); + } + + if (!srcDecl) + return exprIn; + + if (as<VarDeclBase>(srcDecl) && isGlobalDecl(srcDecl)) + return exprIn; + + auto lambdaScope = m_parentLambdaExpr->paramScopeDecl; + bool isDefinedInLambdaScope = false; + for (auto parentDecl = srcDecl->parentDecl; parentDecl; parentDecl = parentDecl->parentDecl) + { + if (parentDecl == lambdaScope) + { + isDefinedInLambdaScope = true; + break; + } + } + if (isDefinedInLambdaScope) + return exprIn; + + // We are referencing something that doesn't belong to the lambda scope, we need to + // capture it in the current lambda function. + + // If we have already captured the variable, just return the captured variable. + VarDeclBase* capturedVarDecl = nullptr; + if (!m_mapSrcDeclToCapturedLambdaDecl->tryGetValue(srcDecl, capturedVarDecl)) + { + // If not already captured, create a captured variable in the lambda struct decl. + capturedVarDecl = m_astBuilder->create<VarDecl>(); + capturedVarDecl->nameAndLoc = srcDecl->nameAndLoc; + SLANG_ASSERT(exprIn->type.type); + capturedVarDecl->type.type = exprIn->type.type; + m_mapSrcDeclToCapturedLambdaDecl->add(srcDecl, capturedVarDecl); + m_parentLambdaDecl->addMember(capturedVarDecl); + + // Is captured value NonCopyable? If so, it needs to be an error. + if (isNonCopyableType(capturedVarDecl->type.type)) + { + getSink()->diagnose( + exprIn, + Diagnostics::nonCopyableTypeCapturedInLambda, + capturedVarDecl->type.type); + } + } + + // Return a VarExpr referencing the capturedVarDecl. + auto thisLambdaExpr = m_astBuilder->create<ThisExpr>(); + thisLambdaExpr->scope = m_parentLambdaDecl->ownedScope; + thisLambdaExpr->type = QualType(DeclRefType::create(m_astBuilder, m_parentLambdaDecl)); + thisLambdaExpr->checked = true; + + auto resultMemberExpr = m_astBuilder->create<MemberExpr>(); + resultMemberExpr->declRef = capturedVarDecl; + resultMemberExpr->baseExpression = thisLambdaExpr; + resultMemberExpr->type = exprIn->type; + resultMemberExpr->loc = exprIn->loc; + + // For captured variables, we need to set the type to be a non-lvalue to prevent + // lambda expression body from mutating their values. + resultMemberExpr->type.isLeftValue = false; + resultMemberExpr->checked = true; + return resultMemberExpr; } Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) @@ -4075,6 +4178,102 @@ error:; return expr; } +Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr) +{ + ASTSynthesizer synthesizer = ASTSynthesizer(m_astBuilder, getNamePool()); + synthesizer.pushContainerScope(m_outerScope->containerDecl); + + Dictionary<Decl*, VarDeclBase*> mapSrcDeclToCapturedDecl; + ensureAllDeclsRec(lambdaExpr->paramScopeDecl, DeclCheckState::DefinitionChecked); + LambdaDecl* lambdaStructDecl = m_astBuilder->create<LambdaDecl>(); + auto subContext = withParentLambdaExpr(lambdaExpr, lambdaStructDecl, &mapSrcDeclToCapturedDecl); + addModifier(lambdaStructDecl, m_astBuilder->create<SynthesizedModifier>()); + m_parentFunc->addMember(lambdaStructDecl); + synthesizer.pushScopeForContainer(lambdaStructDecl); + lambdaStructDecl->loc = lambdaExpr->loc; + StringBuilder nameBuilder; + nameBuilder << "_slang_Lambda_"; + if (m_parentFunc) + { + nameBuilder << getText(m_parentFunc->getName()); + } + nameBuilder << "_"; + nameBuilder << m_parentFunc->members.getCount(); + auto name = getName(nameBuilder.getBuffer()); + lambdaStructDecl->nameAndLoc.name = name; + lambdaStructDecl->nameAndLoc.loc = lambdaExpr->loc; + + auto funcDecl = m_astBuilder->create<FuncDecl>(); + synthesizer.pushScopeForContainer(funcDecl); + funcDecl->loc = lambdaExpr->loc; + funcDecl->nameAndLoc.name = getName("()"); + lambdaStructDecl->addMember(funcDecl); + lambdaStructDecl->funcDecl = funcDecl; + addModifier(funcDecl, m_astBuilder->create<SynthesizedModifier>()); + + // As we check the body, we will fill in the result type when we visit `ReturnStmt`. + dispatchStmt(lambdaExpr->bodyStmt, subContext); + + // If the lambda has no return type, we will set it to `void`. + if (!funcDecl->returnType.type) + funcDecl->returnType.type = m_astBuilder->getVoidType(); + + synthesizer.popScope(); + synthesizer.popScope(); + + funcDecl->body = lambdaExpr->bodyStmt; + for (auto param : lambdaExpr->paramScopeDecl->members) + { + funcDecl->addMember(param); + } + + // LambdaDecl should inherit from `IFunc<>`. + if (funcDecl->returnType.type) + { + auto genApp = m_astBuilder->create<GenericAppExpr>(); + genApp->functionExpr = synthesizer.emitVarExpr(getName("IFunc")); + auto returnTypeExp = synthesizer.emitStaticTypeExpr(funcDecl->returnType.type); + genApp->arguments.add(returnTypeExp); + for (auto param : getMembersOfType<ParamDecl>(m_astBuilder, lambdaExpr->paramScopeDecl)) + { + auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); + auto paramTypeExp = synthesizer.emitStaticTypeExpr(paramType); + genApp->arguments.add(paramTypeExp); + } + auto inheritanceDecl = m_astBuilder->create<InheritanceDecl>(); + inheritanceDecl->base.exp = genApp; + lambdaStructDecl->addMember(inheritanceDecl); + } + + // Synthesizer the ctor signature, and `IFunc` witness. + ensureDecl(lambdaStructDecl, DeclCheckState::AttributesChecked); + + // Return an expr that represents `SynthesizedLambdaStruct.__init(captured_args...)`. + List<Expr*> args; + Dictionary<VarDeclBase*, Decl*> mapCapturedDeclToSrcDecl; + for (auto kv : mapSrcDeclToCapturedDecl) + { + mapCapturedDeclToSrcDecl[kv.second] = kv.first; + } + for (auto capturedField : getMembersOfType<VarDecl>(m_astBuilder, lambdaStructDecl)) + { + auto src = mapCapturedDeclToSrcDecl[capturedField.getDecl()]; + if (auto srcVarDecl = as<VarDeclBase>(src)) + { + args.add(synthesizer.emitVarExpr(srcVarDecl)); + } + else + { + args.add(synthesizer.emitThisExpr()); + } + } + auto resultLambdaObj = synthesizer.emitCtorInvokeExpr( + synthesizer.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStructDecl)), + _Move(args)); + auto checkedResultExpr = dispatchExpr(resultLambdaObj, *this); + return checkedResultExpr; +} + void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr) { auto checkedInvokeExpr = as<InvokeExpr>(invokeExpr); @@ -5039,6 +5238,10 @@ Expr* SemanticsExprVisitor::visitThisExpr(ThisExpr* expr) else if (auto typeOrExtensionDecl = as<AggTypeDeclBase>(containerDecl)) { expr->type.type = calcThisType(makeDeclRef(typeOrExtensionDecl)); + if (m_parentLambdaExpr) + { + return maybeRegisterLambdaCapture(expr); + } return expr; } #if 0 diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index c9406cd1f..1e93d8381 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -989,6 +989,18 @@ public: return result; } + SemanticsContext withParentLambdaExpr( + LambdaExpr* expr, + LambdaDecl* decl, + Dictionary<Decl*, VarDeclBase*>* mapSrcDeclToCapturedLambdaDecl) + { + SemanticsContext result(*this); + result.m_parentLambdaExpr = expr; + result.m_mapSrcDeclToCapturedLambdaDecl = mapSrcDeclToCapturedLambdaDecl; + result.m_parentLambdaDecl = decl; + return result; + } + /// Information for tracking one or more outer statements. /// /// During checking of statements, we need to track what @@ -1161,6 +1173,13 @@ protected: ExpandExpr* m_parentExpandExpr = nullptr; OrderedHashSet<Type*>* m_capturedTypePacks = nullptr; + + // If we are checking inside a lambda expression, we need + // to track the referenced variables that should be captured + // by the lambda. + LambdaExpr* m_parentLambdaExpr = nullptr; + LambdaDecl* m_parentLambdaDecl = nullptr; + Dictionary<Decl*, VarDeclBase*>* m_mapSrcDeclToCapturedLambdaDecl = nullptr; }; struct OuterScopeContextRAII @@ -2900,8 +2919,11 @@ public: Expr* visitEachExpr(EachExpr* expr); + Expr* visitLambdaExpr(LambdaExpr* expr); + void maybeCheckKnownBuiltinInvocation(Expr* invokeExpr); + Expr* maybeRegisterLambdaCapture(Expr* exprIn); // // Some syntax nodes should not occur in the concrete input syntax, // and will only appear *after* checking is complete. We need to diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 4547281e1..6a713f412 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -727,6 +727,27 @@ Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef) return paramType; } +Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef) +{ + auto result = getParamType(astBuilder, paramDeclRef); + auto direction = getParameterDirection(paramDeclRef.getDecl()); + switch (direction) + { + case kParameterDirection_In: + return result; + case kParameterDirection_ConstRef: + return astBuilder->getConstRefType(result); + case kParameterDirection_Out: + return astBuilder->getOutType(result); + case kParameterDirection_InOut: + return astBuilder->getInOutType(result); + case kParameterDirection_Ref: + return astBuilder->getRefType(result, AddressSpace::Generic); + default: + return result; + } +} + void Module::_collectShaderParams() { // We are going to walk the global declarations in the body of the diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 9525f71c9..0e5ed92aa 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -547,9 +547,19 @@ void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*) void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) { auto function = getParentFunc(); + Type* returnType = nullptr; + Type* expectedReturnType = nullptr; + if (m_parentLambdaDecl) + { + expectedReturnType = m_parentLambdaDecl->funcDecl->returnType.type; + } + else if (function) + { + expectedReturnType = function->returnType.type; + } if (!stmt->expression) { - if (function && !function->returnType.equals(m_astBuilder->getVoidType()) && + if (expectedReturnType && !expectedReturnType->equals(m_astBuilder->getVoidType()) && !as<ConstructorDecl>(function)) { getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); @@ -558,24 +568,31 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) else { stmt->expression = CheckTerm(stmt->expression); + returnType = stmt->expression->type.type; if (!stmt->expression->type->equals(m_astBuilder->getErrorType())) { - if (function) + if (!m_parentLambdaExpr && expectedReturnType) { stmt->expression = - coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). - - // getSink()->diagnose(stmt, - // Diagnostics::unimplemented, "case for return stmt"); + coerce(CoercionSite::Return, expectedReturnType, stmt->expression); } } } + if (m_parentLambdaDecl) + { + if (!returnType) + returnType = m_astBuilder->getVoidType(); + if (!m_parentLambdaDecl->funcDecl->returnType.type) + m_parentLambdaDecl->funcDecl->returnType.type = returnType; + if (!m_parentLambdaDecl->funcDecl->returnType.type->equals(returnType)) + { + getSink()->diagnose( + stmt, + Diagnostics::returnTypeMismatchInsideLambda, + returnType, + m_parentLambdaDecl->funcDecl->returnType.type); + } + } if (FindOuterStmt<DeferStmt>()) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 5dcc41bd8..ac4008b7f 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -932,7 +932,18 @@ DIAGNOSTIC( continueInsideDefer, "'continue' must not appear inside a defer statement.") DIAGNOSTIC(30110, Error, returnInsideDefer, "'return' must not appear inside a defer statement.") +DIAGNOSTIC( + 30111, + Error, + returnTypeMismatchInsideLambda, + "returned values must have the same type among all 'return' statements inside a lambda " + "expression: returned '$0' here, but '$1' previously.") +DIAGNOSTIC( + 30112, + Error, + nonCopyableTypeCapturedInLambda, + "cannot capture non-copyable type '$0' in a lambda expression.") // Include DIAGNOSTIC( diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 7375756f5..b08388aee 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -62,9 +62,10 @@ static Index _getDeclNameLength(Name* name, Decl* optionalDecl = nullptr) return 0; } // HACK: our __subscript functions currently have a name "operator[]". + // and our operator() functions have a name "()". // Since this isn't the name that actually appears in user's code, // we need to shorten its reported length to 1 for now. - if (name->text.startsWith("operator")) + if (name->text.startsWith("operator") || name->text.startsWith("()")) { return 1; } @@ -75,7 +76,7 @@ bool _isLocInRange(ASTLookupContext* context, SourceLoc loc, Int length) { auto humaneLoc = context->sourceManager->getHumaneLoc(loc, SourceLocType::Actual); return humaneLoc.line == context->line && context->col >= humaneLoc.column && - context->col <= humaneLoc.column + length && + context->col < humaneLoc.column + length && humaneLoc.pathInfo.foundPath.getUnownedSlice().endsWithCaseInsensitive( context->sourceFileName); } @@ -87,7 +88,7 @@ bool _isLocInRange(ASTLookupContext* context, SourceLoc start, SourceLoc end) Loc s{startLoc.line, startLoc.column}; Loc e{endLoc.line, endLoc.column}; Loc c{context->line, context->col}; - return s <= c && c <= e && + return s <= c && c < e && startLoc.pathInfo.foundPath.getUnownedSlice().endsWithCaseInsensitive( context->sourceFileName); } @@ -667,17 +668,27 @@ bool _findAstNodeImpl(ASTLookupContext& context, SyntaxNode* node) { if (_isLocInRange(&context, decl->nameAndLoc.loc, _getDeclNameLength(decl->getName()))) { + bool isRealDeclName = true; for (auto modifier : decl->modifiers) { if (as<SynthesizedModifier>(modifier)) - return false; + { + isRealDeclName = false; + break; + } if (as<ImplicitParameterGroupElementTypeModifier>(modifier)) - return false; + { + isRealDeclName = false; + break; + } + } + if (isRealDeclName) + { + ASTLookupResult result; + result.path = context.nodePath; + context.results.add(_Move(result)); + return true; } - ASTLookupResult result; - result.path = context.nodePath; - context.results.add(_Move(result)); - return true; } } if (auto funcDecl = as<FunctionDeclBase>(node)) diff --git a/source/slang/slang-language-server-semantic-tokens.cpp b/source/slang/slang-language-server-semantic-tokens.cpp index 5d0e41ecb..54528d7fc 100644 --- a/source/slang/slang-language-server-semantic-tokens.cpp +++ b/source/slang/slang-language-server-semantic-tokens.cpp @@ -39,6 +39,20 @@ SemanticToken _createSemanticToken(SourceManager* manager, SourceLoc loc, Name* return token; } +// We don't want to semantic highlight a synthetic name, like $init, () etc, +// because they aren't actually appearing in the source code. +bool isHighlightableName(Name* name) +{ + if (!name) + return false; + for (auto ch : name->text) + { + if (!CharUtil::isAlphaOrDigit(ch) && ch != '_') + return false; + } + return true; +} + List<SemanticToken> getSemanticTokens( Linkage* linkage, Module* module, @@ -67,7 +81,7 @@ List<SemanticToken> getSemanticTokens( return; if (!name) name = declRef.getDecl()->getName(); - if (!name) + if (!isHighlightableName(name)) return; // Don't look at the expr if it is defined in a different file. if (!manager->getHumaneLoc(loc, SourceLocType::Actual) @@ -221,7 +235,7 @@ List<SemanticToken> getSemanticTokens( } else if (auto funcDecl = as<FuncDecl>(node)) { - if (funcDecl->getName()) + if (isHighlightableName(funcDecl->getName())) { SemanticToken token = _createSemanticToken(manager, funcDecl->getNameLoc(), funcDecl->getName()); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 2b94a1fa7..990090537 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -4268,6 +4268,12 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitLambdaExpr(LambdaExpr*) + { + SLANG_UNEXPECTED("a valid ast should not contain an LambdaExpr."); + UNREACHABLE_RETURN(LoweredValInfo()); + } + LoweredValInfo visitSPIRVAsmExpr(SPIRVAsmExpr* expr) { // Although the surface syntax can have an empty ASM block, the IR asm diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index a8573c909..c17a086a7 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -7149,6 +7149,36 @@ static bool tryParseExpression(Parser* parser, Expr*& outExpr, TokenType tokenTy return false; } +static Expr* parseLambdaExpr(Parser* parser) +{ + auto lambdaExpr = parser->astBuilder->create<LambdaExpr>(); + parser->ReadToken(TokenType::LParent); + lambdaExpr->paramScopeDecl = parser->astBuilder->create<ScopeDecl>(); + parser->pushScopeAndSetParent(lambdaExpr->paramScopeDecl); + while (!AdvanceIfMatch(parser, MatchedTokenType::Parentheses)) + { + AddMember(lambdaExpr->paramScopeDecl, parser->ParseParameter()); + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + parser->FillPosition(lambdaExpr); + parser->ReadToken(TokenType::DoubleRightArrow); + if (parser->LookAheadToken(TokenType::LBrace)) + { + lambdaExpr->bodyStmt = parser->parseBlockStatement(); + } + else + { + auto returnStmt = parser->astBuilder->create<ReturnStmt>(); + parser->FillPosition(returnStmt); + returnStmt->expression = parser->ParseArgExpr(); + lambdaExpr->bodyStmt = returnStmt; + } + parser->PopScope(); + return lambdaExpr; +} + static Expr* parseAtomicExpr(Parser* parser) { switch (peekTokenType(parser)) @@ -7161,12 +7191,22 @@ static Expr* parseAtomicExpr(Parser* parser) // Either: // - parenthesized expression `(exp)` // - cast `(type) exp` + // - lambda expressions (paramList)=>x // // Proper disambiguation requires mixing up parsing // and semantic checking (which we should do eventually) // but for now we will follow some heuristics. case TokenType::LParent: { + // Disambiguate between a lambda expression and other cases. + auto tokenReader = parser->tokenReader; + SkipBalancedToken(&tokenReader); + auto nextTokenAfterParent = tokenReader.peekTokenType(); + if (nextTokenAfterParent == TokenType::DoubleRightArrow) + { + return parseLambdaExpr(parser); + } + Token openParen = parser->ReadToken(TokenType::LParent); // Only handles cases of `(type)`, where type is a single identifier, diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 8d78872a6..97f829b59 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -146,6 +146,11 @@ inline Type* getType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> declRef) /// modifier list and return a ModifiedType if such modifiers exist. Type* getParamType(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef); +/// Get the parameter type, wrapped with `Out<>`, `InOut<>` or `Ref<>` if the parameter has +/// an non-trivial direction. +Type* getParamTypeWithDirectionWrapper(ASTBuilder* astBuilder, DeclRef<VarDeclBase> paramDeclRef); + + inline SubstExpr<Expr> getInitExpr(ASTBuilder* astBuilder, DeclRef<VarDeclBase> declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->initExpr); |
