diff options
Diffstat (limited to 'source/slang/slang-check-stmt.cpp')
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 41 |
1 files changed, 29 insertions, 12 deletions
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 9525f71c9..0e5ed92aa 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -547,9 +547,19 @@ void SemanticsStmtVisitor::visitDiscardStmt(DiscardStmt*) void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) { auto function = getParentFunc(); + Type* returnType = nullptr; + Type* expectedReturnType = nullptr; + if (m_parentLambdaDecl) + { + expectedReturnType = m_parentLambdaDecl->funcDecl->returnType.type; + } + else if (function) + { + expectedReturnType = function->returnType.type; + } if (!stmt->expression) { - if (function && !function->returnType.equals(m_astBuilder->getVoidType()) && + if (expectedReturnType && !expectedReturnType->equals(m_astBuilder->getVoidType()) && !as<ConstructorDecl>(function)) { getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); @@ -558,24 +568,31 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) else { stmt->expression = CheckTerm(stmt->expression); + returnType = stmt->expression->type.type; if (!stmt->expression->type->equals(m_astBuilder->getErrorType())) { - if (function) + if (!m_parentLambdaExpr && expectedReturnType) { stmt->expression = - coerce(CoercionSite::Return, function->returnType.Ptr(), stmt->expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). - - // getSink()->diagnose(stmt, - // Diagnostics::unimplemented, "case for return stmt"); + coerce(CoercionSite::Return, expectedReturnType, stmt->expression); } } } + if (m_parentLambdaDecl) + { + if (!returnType) + returnType = m_astBuilder->getVoidType(); + if (!m_parentLambdaDecl->funcDecl->returnType.type) + m_parentLambdaDecl->funcDecl->returnType.type = returnType; + if (!m_parentLambdaDecl->funcDecl->returnType.type->equals(returnType)) + { + getSink()->diagnose( + stmt, + Diagnostics::returnTypeMismatchInsideLambda, + returnType, + m_parentLambdaDecl->funcDecl->returnType.type); + } + } if (FindOuterStmt<DeferStmt>()) { |
