summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-04-30 14:17:45 -0700
committerGitHub <noreply@github.com>2025-04-30 14:17:45 -0700
commit7f1df9d0b31413e59846cc955d2a955d3f361e2a (patch)
tree8cfcb7b6dde96f90e9581f9a904a25158a7358cb /source/slang/slang-check-expr.cpp
parent678de6547bc8cac15e31de30b400e9a3b45c216f (diff)
Initial support for immutable lambda expressions. (#6914)
* Initial support for immutable lambda expressions. * More diagnostics, and langauge server fix. * Language server fix. * Fix bug identified in review. * Add expected result. * Update expected result.
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp207
1 files changed, 205 insertions, 2 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 2c595dd4a..87f29d367 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -12,8 +12,10 @@
// * `slang-check-conversion.cpp` is responsible for the logic of handling type conversion/coercion
#include "core/slang-char-util.h"
+#include "slang-ast-decl.h"
#include "slang-ast-natural-layout.h"
#include "slang-ast-print.h"
+#include "slang-ast-synthesis.h"
#include "slang-lookup-spirv.h"
#include "slang-lookup.h"
@@ -3125,15 +3127,116 @@ Expr* SemanticsExprVisitor::visitVarExpr(VarExpr* expr)
return expr;
}
+ Expr* resultExpr = expr;
+
if (lookupResult.isValid())
{
- return createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr);
+ auto lookupResultExpr =
+ createLookupResultExpr(expr->name, lookupResult, nullptr, expr->loc, expr);
+ if (m_parentLambdaExpr)
+ return maybeRegisterLambdaCapture(lookupResultExpr);
+ return lookupResultExpr;
}
if (!diagnosed)
getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->name);
- return expr;
+ return resultExpr;
+}
+
+Expr* SemanticsExprVisitor::maybeRegisterLambdaCapture(Expr* exprIn)
+{
+ if (auto memberExpr = as<MemberExpr>(exprIn))
+ {
+ memberExpr->baseExpression = maybeRegisterLambdaCapture(memberExpr->baseExpression);
+ return memberExpr;
+ }
+ else if (auto subscriptExpr = as<IndexExpr>(exprIn))
+ {
+ subscriptExpr->baseExpression = maybeRegisterLambdaCapture(subscriptExpr->baseExpression);
+ return subscriptExpr;
+ }
+ auto thisExpr = as<ThisExpr>(exprIn);
+ auto varExpr = as<VarExpr>(exprIn);
+ if (!thisExpr && !varExpr)
+ return exprIn;
+
+ Decl* srcDecl = nullptr;
+ if (varExpr)
+ srcDecl = as<VarDeclBase>(varExpr->declRef.getDecl());
+ else
+ {
+ // If we see a `this` expression inside a lambda, it is referencing the
+ // `this` value of the parent type of the outer function, not the lambda struct
+ // itself. Since we don't have a VarDecl representing `this`, we will just use
+ // the AggTypeDecl as the key to register in the lambda capture map.
+ auto thisTypeDecl = isDeclRefTypeOf<Decl>(thisExpr->type.type);
+ if (!thisTypeDecl)
+ return exprIn;
+ srcDecl = thisTypeDecl.getDecl();
+ }
+
+ if (!srcDecl)
+ return exprIn;
+
+ if (as<VarDeclBase>(srcDecl) && isGlobalDecl(srcDecl))
+ return exprIn;
+
+ auto lambdaScope = m_parentLambdaExpr->paramScopeDecl;
+ bool isDefinedInLambdaScope = false;
+ for (auto parentDecl = srcDecl->parentDecl; parentDecl; parentDecl = parentDecl->parentDecl)
+ {
+ if (parentDecl == lambdaScope)
+ {
+ isDefinedInLambdaScope = true;
+ break;
+ }
+ }
+ if (isDefinedInLambdaScope)
+ return exprIn;
+
+ // We are referencing something that doesn't belong to the lambda scope, we need to
+ // capture it in the current lambda function.
+
+ // If we have already captured the variable, just return the captured variable.
+ VarDeclBase* capturedVarDecl = nullptr;
+ if (!m_mapSrcDeclToCapturedLambdaDecl->tryGetValue(srcDecl, capturedVarDecl))
+ {
+ // If not already captured, create a captured variable in the lambda struct decl.
+ capturedVarDecl = m_astBuilder->create<VarDecl>();
+ capturedVarDecl->nameAndLoc = srcDecl->nameAndLoc;
+ SLANG_ASSERT(exprIn->type.type);
+ capturedVarDecl->type.type = exprIn->type.type;
+ m_mapSrcDeclToCapturedLambdaDecl->add(srcDecl, capturedVarDecl);
+ m_parentLambdaDecl->addMember(capturedVarDecl);
+
+ // Is captured value NonCopyable? If so, it needs to be an error.
+ if (isNonCopyableType(capturedVarDecl->type.type))
+ {
+ getSink()->diagnose(
+ exprIn,
+ Diagnostics::nonCopyableTypeCapturedInLambda,
+ capturedVarDecl->type.type);
+ }
+ }
+
+ // Return a VarExpr referencing the capturedVarDecl.
+ auto thisLambdaExpr = m_astBuilder->create<ThisExpr>();
+ thisLambdaExpr->scope = m_parentLambdaDecl->ownedScope;
+ thisLambdaExpr->type = QualType(DeclRefType::create(m_astBuilder, m_parentLambdaDecl));
+ thisLambdaExpr->checked = true;
+
+ auto resultMemberExpr = m_astBuilder->create<MemberExpr>();
+ resultMemberExpr->declRef = capturedVarDecl;
+ resultMemberExpr->baseExpression = thisLambdaExpr;
+ resultMemberExpr->type = exprIn->type;
+ resultMemberExpr->loc = exprIn->loc;
+
+ // For captured variables, we need to set the type to be a non-lvalue to prevent
+ // lambda expression body from mutating their values.
+ resultMemberExpr->type.isLeftValue = false;
+ resultMemberExpr->checked = true;
+ return resultMemberExpr;
}
Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType)
@@ -4075,6 +4178,102 @@ error:;
return expr;
}
+Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr)
+{
+ ASTSynthesizer synthesizer = ASTSynthesizer(m_astBuilder, getNamePool());
+ synthesizer.pushContainerScope(m_outerScope->containerDecl);
+
+ Dictionary<Decl*, VarDeclBase*> mapSrcDeclToCapturedDecl;
+ ensureAllDeclsRec(lambdaExpr->paramScopeDecl, DeclCheckState::DefinitionChecked);
+ LambdaDecl* lambdaStructDecl = m_astBuilder->create<LambdaDecl>();
+ auto subContext = withParentLambdaExpr(lambdaExpr, lambdaStructDecl, &mapSrcDeclToCapturedDecl);
+ addModifier(lambdaStructDecl, m_astBuilder->create<SynthesizedModifier>());
+ m_parentFunc->addMember(lambdaStructDecl);
+ synthesizer.pushScopeForContainer(lambdaStructDecl);
+ lambdaStructDecl->loc = lambdaExpr->loc;
+ StringBuilder nameBuilder;
+ nameBuilder << "_slang_Lambda_";
+ if (m_parentFunc)
+ {
+ nameBuilder << getText(m_parentFunc->getName());
+ }
+ nameBuilder << "_";
+ nameBuilder << m_parentFunc->members.getCount();
+ auto name = getName(nameBuilder.getBuffer());
+ lambdaStructDecl->nameAndLoc.name = name;
+ lambdaStructDecl->nameAndLoc.loc = lambdaExpr->loc;
+
+ auto funcDecl = m_astBuilder->create<FuncDecl>();
+ synthesizer.pushScopeForContainer(funcDecl);
+ funcDecl->loc = lambdaExpr->loc;
+ funcDecl->nameAndLoc.name = getName("()");
+ lambdaStructDecl->addMember(funcDecl);
+ lambdaStructDecl->funcDecl = funcDecl;
+ addModifier(funcDecl, m_astBuilder->create<SynthesizedModifier>());
+
+ // As we check the body, we will fill in the result type when we visit `ReturnStmt`.
+ dispatchStmt(lambdaExpr->bodyStmt, subContext);
+
+ // If the lambda has no return type, we will set it to `void`.
+ if (!funcDecl->returnType.type)
+ funcDecl->returnType.type = m_astBuilder->getVoidType();
+
+ synthesizer.popScope();
+ synthesizer.popScope();
+
+ funcDecl->body = lambdaExpr->bodyStmt;
+ for (auto param : lambdaExpr->paramScopeDecl->members)
+ {
+ funcDecl->addMember(param);
+ }
+
+ // LambdaDecl should inherit from `IFunc<>`.
+ if (funcDecl->returnType.type)
+ {
+ auto genApp = m_astBuilder->create<GenericAppExpr>();
+ genApp->functionExpr = synthesizer.emitVarExpr(getName("IFunc"));
+ auto returnTypeExp = synthesizer.emitStaticTypeExpr(funcDecl->returnType.type);
+ genApp->arguments.add(returnTypeExp);
+ for (auto param : getMembersOfType<ParamDecl>(m_astBuilder, lambdaExpr->paramScopeDecl))
+ {
+ auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param);
+ auto paramTypeExp = synthesizer.emitStaticTypeExpr(paramType);
+ genApp->arguments.add(paramTypeExp);
+ }
+ auto inheritanceDecl = m_astBuilder->create<InheritanceDecl>();
+ inheritanceDecl->base.exp = genApp;
+ lambdaStructDecl->addMember(inheritanceDecl);
+ }
+
+ // Synthesizer the ctor signature, and `IFunc` witness.
+ ensureDecl(lambdaStructDecl, DeclCheckState::AttributesChecked);
+
+ // Return an expr that represents `SynthesizedLambdaStruct.__init(captured_args...)`.
+ List<Expr*> args;
+ Dictionary<VarDeclBase*, Decl*> mapCapturedDeclToSrcDecl;
+ for (auto kv : mapSrcDeclToCapturedDecl)
+ {
+ mapCapturedDeclToSrcDecl[kv.second] = kv.first;
+ }
+ for (auto capturedField : getMembersOfType<VarDecl>(m_astBuilder, lambdaStructDecl))
+ {
+ auto src = mapCapturedDeclToSrcDecl[capturedField.getDecl()];
+ if (auto srcVarDecl = as<VarDeclBase>(src))
+ {
+ args.add(synthesizer.emitVarExpr(srcVarDecl));
+ }
+ else
+ {
+ args.add(synthesizer.emitThisExpr());
+ }
+ }
+ auto resultLambdaObj = synthesizer.emitCtorInvokeExpr(
+ synthesizer.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStructDecl)),
+ _Move(args));
+ auto checkedResultExpr = dispatchExpr(resultLambdaObj, *this);
+ return checkedResultExpr;
+}
+
void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr)
{
auto checkedInvokeExpr = as<InvokeExpr>(invokeExpr);
@@ -5039,6 +5238,10 @@ Expr* SemanticsExprVisitor::visitThisExpr(ThisExpr* expr)
else if (auto typeOrExtensionDecl = as<AggTypeDeclBase>(containerDecl))
{
expr->type.type = calcThisType(makeDeclRef(typeOrExtensionDecl));
+ if (m_parentLambdaExpr)
+ {
+ return maybeRegisterLambdaCapture(expr);
+ }
return expr;
}
#if 0