diff options
| author | Yong He <yonghe@outlook.com> | 2025-05-16 10:42:11 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-16 10:42:11 -0700 |
| commit | 9dfd5244ad2953753535e82acd05e72e5ab2bc5f (patch) | |
| tree | 0f7c8a0bf4ea3dd25348e1d1ac8e5bfcfd8c9724 /source/slang/slang-check-conversion.cpp | |
| parent | 1fd7b2296d8360c245a0c732e7f842876533f92a (diff) | |
Allow lambda exprs without captures to coerce to `functype`. (#7129)
Diffstat (limited to 'source/slang/slang-check-conversion.cpp')
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 147 |
1 files changed, 147 insertions, 0 deletions
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 68e5df7ce..8f2ea3ca7 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1,4 +1,5 @@ // slang-check-conversion.cpp +#include "slang-ast-synthesis.h" #include "slang-check-impl.h" // This file contains semantic-checking logic for dealing @@ -1388,6 +1389,21 @@ bool SemanticsVisitor::_coerce( } } + if (auto toFuncType = as<FuncType>(toType)) + { + if (auto fromLambdaType = isDeclRefTypeOf<StructDecl>(fromType)) + { + if (tryCoerceLambdaToFuncType(fromLambdaType, toFuncType, fromExpr, outToExpr)) + { + if (outCost) + { + *outCost = kConversionCost_LambdaToFunc; + } + return true; + } + } + } + // A type is always convertible to any of its supertypes. // if (auto witness = tryGetSubtypeWitness(fromType, toType)) @@ -1813,6 +1829,137 @@ bool SemanticsVisitor::_coerce( return _failedCoercion(toType, outToExpr, fromExpr, sink); } +bool SemanticsVisitor::tryCoerceLambdaToFuncType( + DeclRef<StructDecl> lambdaStruct, + FuncType* toFuncType, + Expr* fromExpr, + Expr** outToExpr) +{ + FuncDecl* synStaticFunc = nullptr; + FuncDecl* invokeFunc = nullptr; + + // First, check if `lambdaStruct` contains any fields. + // If it does, we can't convert it to a function type. + auto operatorName = getName("()"); + + for (auto member : lambdaStruct.getDecl()->members) + { + if (auto field = as<VarDecl>(member)) + { + if (!isEffectivelyStatic(field)) + return false; + } + else if (auto inheritanceDecl = as<InheritanceDecl>(member)) + { + // If the struct inherits from anything that is not an interface, + // we will consider it to be non-empty and not convertible to a function type. + if (!isDeclRefTypeOf<InterfaceDecl>(inheritanceDecl->base.type)) + return false; + } + else if (auto funcDecl = as<FuncDecl>(member)) + { + // If the struct already contains a synthesized static invoke member, use it. + if (isEffectivelyStatic(funcDecl) && + funcDecl->findModifier<SynthesizedStaticLambdaFuncModifier>() && + funcDecl->returnType.type == toFuncType->getResultType()) + synStaticFunc = funcDecl; + if (funcDecl->getName() == operatorName) + { + // If we found operator(), keep it for later. + invokeFunc = funcDecl; + } + } + } + + if (!invokeFunc) + { + return false; + } + + auto invokeFuncDeclRef = m_astBuilder->getMemberDeclRef(lambdaStruct, invokeFunc); + + // Verify that the function parameter types are exactly the same as toFuncType. + if (invokeFunc->getParameters().getCount() != toFuncType->getParamCount()) + { + return false; + } + Index paramId = 0; + for (auto param : invokeFunc->getParameters()) + { + auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); + auto toParamType = toFuncType->getParamType(paramId); + if (!paramType->equals(toParamType)) + { + return false; + } + paramId++; + } + + // Verify that the return type of the function is convertible to the function type. + if (!canCoerce(toFuncType->getResultType(), invokeFunc->returnType.type, nullptr)) + { + return false; + } + + if (!synStaticFunc) + { + // If the struct doesn't contain a static method for operator(), we try to synthesize one. + synStaticFunc = m_astBuilder->create<FuncDecl>(); + synStaticFunc->nameAndLoc.name = getName("__syn_static_invoke"); + addModifier(synStaticFunc, m_astBuilder->create<SynthesizedStaticLambdaFuncModifier>()); + addModifier(synStaticFunc, m_astBuilder->create<HLSLStaticModifier>()); + addModifier(synStaticFunc, m_astBuilder->create<SynthesizedModifier>()); + + synStaticFunc->ownedScope = m_astBuilder->create<Scope>(); + synStaticFunc->ownedScope->containerDecl = synStaticFunc; + synStaticFunc->ownedScope->parent = getScope(lambdaStruct.getDecl()); + synStaticFunc->parentDecl = lambdaStruct.getDecl(); + synStaticFunc->returnType.type = toFuncType->getResultType(); + + List<Expr*> synArgs; + addRequiredParamsToSynthesizedDecl(invokeFuncDeclRef, synStaticFunc, synArgs); + ThisExpr* synThis = nullptr; + addModifiersToSynthesizedDecl(nullptr, invokeFuncDeclRef, synStaticFunc, synThis); + + ASTSynthesizer synth(m_astBuilder, getNamePool()); + synth.pushContainerScope(synStaticFunc); + auto blockStmt = m_astBuilder->create<BlockStmt>(); + synStaticFunc->body = blockStmt; + auto seqStmt = synth.pushSeqStmtScope(); + blockStmt->body = seqStmt; + + synth.pushVarScope(); + + // emit `return LambdaStructType().operator()(args...)`. + auto tempThis = synth.emitInvokeExpr( + synth.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStruct)), + List<Expr*>()); + tempThis = dispatchExpr(tempThis, *this); + + Expr* operatorRefExpr = synth.emitMemberExpr(toFuncType, tempThis, invokeFuncDeclRef); + + auto invokeExpr = synth.emitInvokeExpr(operatorRefExpr, _Move(synArgs)); + invokeExpr = dispatchExpr(invokeExpr, *this); + auto resultValue = + coerce(CoercionSite::Return, toFuncType->getResultType(), invokeExpr, getSink()); + synth.emitReturnStmt(resultValue); + + lambdaStruct.getDecl()->addMember(synStaticFunc); + } + + // If we have a static method for operator(), we can convert the lambda to a function type. + if (outToExpr) + { + VarExpr* expr = m_astBuilder->create<VarExpr>(); + expr->loc = fromExpr->loc; + expr->declRef = m_astBuilder->getMemberDeclRef(lambdaStruct, synStaticFunc); + expr->type = QualType(toFuncType); + expr->checked = true; + *outToExpr = expr; + } + return true; +} + bool SemanticsVisitor::canCoerce( Type* toType, QualType fromType, |
