diff options
| author | Yong He <yonghe@outlook.com> | 2025-04-30 14:17:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-30 14:17:45 -0700 |
| commit | 7f1df9d0b31413e59846cc955d2a955d3f361e2a (patch) | |
| tree | 8cfcb7b6dde96f90e9581f9a904a25158a7358cb /source/slang/slang-check-expr.cpp | |
| parent | 678de6547bc8cac15e31de30b400e9a3b45c216f (diff) | |
Initial support for immutable lambda expressions. (#6914)
* Initial support for immutable lambda expressions.
* More diagnostics, and langauge server fix.
* Language server fix.
* Fix bug identified in review.
* Add expected result.
* Update expected result.
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 207 |
1 files changed, 205 insertions, 2 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2c595dd4a..87f29d367 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -12,8 +12,10 @@ // * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion #include "core/slang-char-util.h" +#include "slang-ast-decl.h" #include "slang-ast-natural-layout.h" #include "slang-ast-print.h" +#include "slang-ast-synthesis.h" #include "slang-lookup-spirv.h" #include "slang-lookup.h" @@ -3125,15 +3127,116 @@ Expr* SemanticsExprVisitor::visitVarExpr(VarExpr* expr) return expr; } + Expr* resultExpr = expr; + if (lookupResult.isValid()) { - return createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr); + auto lookupResultExpr = + createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr); + if (m_parentLambdaExpr) + return maybeRegisterLambdaCapture(lookupResultExpr); + return lookupResultExpr; } if (!diagnosed) getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name); - return expr; + return resultExpr; +} + +Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn) +{ + if (auto memberExpr = as<MemberExpr>(exprIn)) + { + memberExpr->baseExpression = maybeRegisterLambdaCapture(memberExpr->baseExpression); + return memberExpr; + } + else if (auto subscriptExpr = as<IndexExpr>(exprIn)) + { + subscriptExpr->baseExpression = maybeRegisterLambdaCapture(subscriptExpr->baseExpression); + return subscriptExpr; + } + auto thisExpr = as<ThisExpr>(exprIn); + auto varExpr = as<VarExpr>(exprIn); + if (!thisExpr && !varExpr) + return exprIn; + + Decl* srcDecl = nullptr; + if (varExpr) + srcDecl = as<VarDeclBase>(varExpr->declRef.getDecl()); + else + { + // If we see a `this` expression inside a lambda, it is referencing the + // `this` value of the parent type of the outer function, not the lambda struct + // itself. Since we don't have a VarDecl representing `this`, we will just use + // the AggTypeDecl as the key to register in the lambda capture map. + auto thisTypeDecl = isDeclRefTypeOf<Decl>(thisExpr->type.type); + if (!thisTypeDecl) + return exprIn; + srcDecl = thisTypeDecl.getDecl(); + } + + if (!srcDecl) + return exprIn; + + if (as<VarDeclBase>(srcDecl) && isGlobalDecl(srcDecl)) + return exprIn; + + auto lambdaScope = m_parentLambdaExpr->paramScopeDecl; + bool isDefinedInLambdaScope = false; + for (auto parentDecl = srcDecl->parentDecl; parentDecl; parentDecl = parentDecl->parentDecl) + { + if (parentDecl == lambdaScope) + { + isDefinedInLambdaScope = true; + break; + } + } + if (isDefinedInLambdaScope) + return exprIn; + + // We are referencing something that doesn't belong to the lambda scope, we need to + // capture it in the current lambda function. + + // If we have already captured the variable, just return the captured variable. + VarDeclBase* capturedVarDecl = nullptr; + if (!m_mapSrcDeclToCapturedLambdaDecl->tryGetValue(srcDecl, capturedVarDecl)) + { + // If not already captured, create a captured variable in the lambda struct decl. + capturedVarDecl = m_astBuilder->create<VarDecl>(); + capturedVarDecl->nameAndLoc = srcDecl->nameAndLoc; + SLANG_ASSERT(exprIn->type.type); + capturedVarDecl->type.type = exprIn->type.type; + m_mapSrcDeclToCapturedLambdaDecl->add(srcDecl, capturedVarDecl); + m_parentLambdaDecl->addMember(capturedVarDecl); + + // Is captured value NonCopyable? If so, it needs to be an error. + if (isNonCopyableType(capturedVarDecl->type.type)) + { + getSink()->diagnose( + exprIn, + Diagnostics::nonCopyableTypeCapturedInLambda, + capturedVarDecl->type.type); + } + } + + // Return a VarExpr referencing the capturedVarDecl. + auto thisLambdaExpr = m_astBuilder->create<ThisExpr>(); + thisLambdaExpr->scope = m_parentLambdaDecl->ownedScope; + thisLambdaExpr->type = QualType(DeclRefType::create(m_astBuilder, m_parentLambdaDecl)); + thisLambdaExpr->checked = true; + + auto resultMemberExpr = m_astBuilder->create<MemberExpr>(); + resultMemberExpr->declRef = capturedVarDecl; + resultMemberExpr->baseExpression = thisLambdaExpr; + resultMemberExpr->type = exprIn->type; + resultMemberExpr->loc = exprIn->loc; + + // For captured variables, we need to set the type to be a non-lvalue to prevent + // lambda expression body from mutating their values. + resultMemberExpr->type.isLeftValue = false; + resultMemberExpr->checked = true; + return resultMemberExpr; } Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) @@ -4075,6 +4178,102 @@ error:; return expr; } +Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr) +{ + ASTSynthesizer synthesizer = ASTSynthesizer(m_astBuilder, getNamePool()); + synthesizer.pushContainerScope(m_outerScope->containerDecl); + + Dictionary<Decl*, VarDeclBase*> mapSrcDeclToCapturedDecl; + ensureAllDeclsRec(lambdaExpr->paramScopeDecl, DeclCheckState::DefinitionChecked); + LambdaDecl* lambdaStructDecl = m_astBuilder->create<LambdaDecl>(); + auto subContext = withParentLambdaExpr(lambdaExpr, lambdaStructDecl, &mapSrcDeclToCapturedDecl); + addModifier(lambdaStructDecl, m_astBuilder->create<SynthesizedModifier>()); + m_parentFunc->addMember(lambdaStructDecl); + synthesizer.pushScopeForContainer(lambdaStructDecl); + lambdaStructDecl->loc = lambdaExpr->loc; + StringBuilder nameBuilder; + nameBuilder << "_slang_Lambda_"; + if (m_parentFunc) + { + nameBuilder << getText(m_parentFunc->getName()); + } + nameBuilder << "_"; + nameBuilder << m_parentFunc->members.getCount(); + auto name = getName(nameBuilder.getBuffer()); + lambdaStructDecl->nameAndLoc.name = name; + lambdaStructDecl->nameAndLoc.loc = lambdaExpr->loc; + + auto funcDecl = m_astBuilder->create<FuncDecl>(); + synthesizer.pushScopeForContainer(funcDecl); + funcDecl->loc = lambdaExpr->loc; + funcDecl->nameAndLoc.name = getName("()"); + lambdaStructDecl->addMember(funcDecl); + lambdaStructDecl->funcDecl = funcDecl; + addModifier(funcDecl, m_astBuilder->create<SynthesizedModifier>()); + + // As we check the body, we will fill in the result type when we visit `ReturnStmt`. + dispatchStmt(lambdaExpr->bodyStmt, subContext); + + // If the lambda has no return type, we will set it to `void`. + if (!funcDecl->returnType.type) + funcDecl->returnType.type = m_astBuilder->getVoidType(); + + synthesizer.popScope(); + synthesizer.popScope(); + + funcDecl->body = lambdaExpr->bodyStmt; + for (auto param : lambdaExpr->paramScopeDecl->members) + { + funcDecl->addMember(param); + } + + // LambdaDecl should inherit from `IFunc<>`. + if (funcDecl->returnType.type) + { + auto genApp = m_astBuilder->create<GenericAppExpr>(); + genApp->functionExpr = synthesizer.emitVarExpr(getName("IFunc")); + auto returnTypeExp = synthesizer.emitStaticTypeExpr(funcDecl->returnType.type); + genApp->arguments.add(returnTypeExp); + for (auto param : getMembersOfType<ParamDecl>(m_astBuilder, lambdaExpr->paramScopeDecl)) + { + auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); + auto paramTypeExp = synthesizer.emitStaticTypeExpr(paramType); + genApp->arguments.add(paramTypeExp); + } + auto inheritanceDecl = m_astBuilder->create<InheritanceDecl>(); + inheritanceDecl->base.exp = genApp; + lambdaStructDecl->addMember(inheritanceDecl); + } + + // Synthesizer the ctor signature, and `IFunc` witness. + ensureDecl(lambdaStructDecl, DeclCheckState::AttributesChecked); + + // Return an expr that represents `SynthesizedLambdaStruct.__init(captured_args...)`. + List<Expr*> args; + Dictionary<VarDeclBase*, Decl*> mapCapturedDeclToSrcDecl; + for (auto kv : mapSrcDeclToCapturedDecl) + { + mapCapturedDeclToSrcDecl[kv.second] = kv.first; + } + for (auto capturedField : getMembersOfType<VarDecl>(m_astBuilder, lambdaStructDecl)) + { + auto src = mapCapturedDeclToSrcDecl[capturedField.getDecl()]; + if (auto srcVarDecl = as<VarDeclBase>(src)) + { + args.add(synthesizer.emitVarExpr(srcVarDecl)); + } + else + { + args.add(synthesizer.emitThisExpr()); + } + } + auto resultLambdaObj = synthesizer.emitCtorInvokeExpr( + synthesizer.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStructDecl)), + _Move(args)); + auto checkedResultExpr = dispatchExpr(resultLambdaObj, *this); + return checkedResultExpr; +} + void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr) { auto checkedInvokeExpr = as<InvokeExpr>(invokeExpr); @@ -5039,6 +5238,10 @@ Expr* SemanticsExprVisitor::visitThisExpr(ThisExpr* expr) else if (auto typeOrExtensionDecl = as<AggTypeDeclBase>(containerDecl)) { expr->type.type = calcThisType(makeDeclRef(typeOrExtensionDecl)); + if (m_parentLambdaExpr) + { + return maybeRegisterLambdaCapture(expr); + } return expr; } #if 0 |
