From 9dfd5244ad2953753535e82acd05e72e5ab2bc5f Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 16 May 2025 10:42:11 -0700 Subject: Allow lambda exprs without captures to coerce to `functype`. (#7129) --- source/slang/slang-ast-modifier.h | 8 ++ source/slang/slang-ast-support-types.h | 1 + source/slang/slang-ast-synthesis.cpp | 10 ++ source/slang/slang-ast-synthesis.h | 1 + source/slang/slang-check-conversion.cpp | 147 +++++++++++++++++++++ source/slang/slang-check-decl.cpp | 2 +- source/slang/slang-check-expr.cpp | 6 +- source/slang/slang-check-impl.h | 9 ++ tests/language-feature/lambda/coerce-failure.slang | 19 +++ .../lambda/coerce-to-functype.slang | 13 ++ 10 files changed, 213 insertions(+), 3 deletions(-) create mode 100644 tests/language-feature/lambda/coerce-failure.slang create mode 100644 tests/language-feature/lambda/coerce-to-functype.slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index e566eca9e..ac1960685 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -158,6 +158,14 @@ class SynthesizedModifier : public Modifier FIDDLE(...) }; +// Marks that the definition of a func decl is synthesized static invoke func for +// a lambda that doesn't capture anything. +FIDDLE() +class SynthesizedStaticLambdaFuncModifier : public Modifier +{ + FIDDLE(...) +}; + // Marks a synthesized variable as local temporary variable. FIDDLE() class LocalTempVarModifier : public Modifier diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 87715d9e0..b5ebe1884 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -93,6 +93,7 @@ FIDDLE() namespace Slang kConversionCost_None = 0, kConversionCost_GenericParamUpcast = 1, + kConversionCost_LambdaToFunc = 1, kConversionCost_UnconstraintGenericParam = 20, kConversionCost_SizedArrayToUnsizedArray = 30, diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp index c7291f526..5f7b527e2 100644 --- a/source/slang/slang-ast-synthesis.cpp +++ b/source/slang/slang-ast-synthesis.cpp @@ -168,6 +168,16 @@ Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) return rs; } +Expr* ASTSynthesizer::emitMemberExpr(QualType exprType, Expr* base, DeclRef declRef) +{ + auto rs = m_builder->create(); + rs->baseExpression = base; + rs->declRef = declRef; + rs->type = exprType; + rs->checked = base->checked; + return rs; +} + Expr* ASTSynthesizer::emitIndexExpr(Expr* base, Expr* index) { auto rs = m_builder->create(); diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h index c1072c705..bdfd90aeb 100644 --- a/source/slang/slang-ast-synthesis.h +++ b/source/slang/slang-ast-synthesis.h @@ -126,6 +126,7 @@ public: Expr* emitMemberExpr(Expr* base, Name* name); Expr* emitMemberExpr(Type* base, Name* name); + Expr* emitMemberExpr(QualType exprType, Expr* base, DeclRef declRef); Expr* emitIndexExpr(Expr* base, Expr* index); diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 68e5df7ce..8f2ea3ca7 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1,4 +1,5 @@ // slang-check-conversion.cpp +#include "slang-ast-synthesis.h" #include "slang-check-impl.h" // This file contains semantic-checking logic for dealing @@ -1388,6 +1389,21 @@ bool SemanticsVisitor::_coerce( } } + if (auto toFuncType = as(toType)) + { + if (auto fromLambdaType = isDeclRefTypeOf(fromType)) + { + if (tryCoerceLambdaToFuncType(fromLambdaType, toFuncType, fromExpr, outToExpr)) + { + if (outCost) + { + *outCost = kConversionCost_LambdaToFunc; + } + return true; + } + } + } + // A type is always convertible to any of its supertypes. // if (auto witness = tryGetSubtypeWitness(fromType, toType)) @@ -1813,6 +1829,137 @@ bool SemanticsVisitor::_coerce( return _failedCoercion(toType, outToExpr, fromExpr, sink); } +bool SemanticsVisitor::tryCoerceLambdaToFuncType( + DeclRef lambdaStruct, + FuncType* toFuncType, + Expr* fromExpr, + Expr** outToExpr) +{ + FuncDecl* synStaticFunc = nullptr; + FuncDecl* invokeFunc = nullptr; + + // First, check if `lambdaStruct` contains any fields. + // If it does, we can't convert it to a function type. + auto operatorName = getName("()"); + + for (auto member : lambdaStruct.getDecl()->members) + { + if (auto field = as(member)) + { + if (!isEffectivelyStatic(field)) + return false; + } + else if (auto inheritanceDecl = as(member)) + { + // If the struct inherits from anything that is not an interface, + // we will consider it to be non-empty and not convertible to a function type. + if (!isDeclRefTypeOf(inheritanceDecl->base.type)) + return false; + } + else if (auto funcDecl = as(member)) + { + // If the struct already contains a synthesized static invoke member, use it. + if (isEffectivelyStatic(funcDecl) && + funcDecl->findModifier() && + funcDecl->returnType.type == toFuncType->getResultType()) + synStaticFunc = funcDecl; + if (funcDecl->getName() == operatorName) + { + // If we found operator(), keep it for later. + invokeFunc = funcDecl; + } + } + } + + if (!invokeFunc) + { + return false; + } + + auto invokeFuncDeclRef = m_astBuilder->getMemberDeclRef(lambdaStruct, invokeFunc); + + // Verify that the function parameter types are exactly the same as toFuncType. + if (invokeFunc->getParameters().getCount() != toFuncType->getParamCount()) + { + return false; + } + Index paramId = 0; + for (auto param : invokeFunc->getParameters()) + { + auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param); + auto toParamType = toFuncType->getParamType(paramId); + if (!paramType->equals(toParamType)) + { + return false; + } + paramId++; + } + + // Verify that the return type of the function is convertible to the function type. + if (!canCoerce(toFuncType->getResultType(), invokeFunc->returnType.type, nullptr)) + { + return false; + } + + if (!synStaticFunc) + { + // If the struct doesn't contain a static method for operator(), we try to synthesize one. + synStaticFunc = m_astBuilder->create(); + synStaticFunc->nameAndLoc.name = getName("__syn_static_invoke"); + addModifier(synStaticFunc, m_astBuilder->create()); + addModifier(synStaticFunc, m_astBuilder->create()); + addModifier(synStaticFunc, m_astBuilder->create()); + + synStaticFunc->ownedScope = m_astBuilder->create(); + synStaticFunc->ownedScope->containerDecl = synStaticFunc; + synStaticFunc->ownedScope->parent = getScope(lambdaStruct.getDecl()); + synStaticFunc->parentDecl = lambdaStruct.getDecl(); + synStaticFunc->returnType.type = toFuncType->getResultType(); + + List synArgs; + addRequiredParamsToSynthesizedDecl(invokeFuncDeclRef, synStaticFunc, synArgs); + ThisExpr* synThis = nullptr; + addModifiersToSynthesizedDecl(nullptr, invokeFuncDeclRef, synStaticFunc, synThis); + + ASTSynthesizer synth(m_astBuilder, getNamePool()); + synth.pushContainerScope(synStaticFunc); + auto blockStmt = m_astBuilder->create(); + synStaticFunc->body = blockStmt; + auto seqStmt = synth.pushSeqStmtScope(); + blockStmt->body = seqStmt; + + synth.pushVarScope(); + + // emit `return LambdaStructType().operator()(args...)`. + auto tempThis = synth.emitInvokeExpr( + synth.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStruct)), + List()); + tempThis = dispatchExpr(tempThis, *this); + + Expr* operatorRefExpr = synth.emitMemberExpr(toFuncType, tempThis, invokeFuncDeclRef); + + auto invokeExpr = synth.emitInvokeExpr(operatorRefExpr, _Move(synArgs)); + invokeExpr = dispatchExpr(invokeExpr, *this); + auto resultValue = + coerce(CoercionSite::Return, toFuncType->getResultType(), invokeExpr, getSink()); + synth.emitReturnStmt(resultValue); + + lambdaStruct.getDecl()->addMember(synStaticFunc); + } + + // If we have a static method for operator(), we can convert the lambda to a function type. + if (outToExpr) + { + VarExpr* expr = m_astBuilder->create(); + expr->loc = fromExpr->loc; + expr->declRef = m_astBuilder->getMemberDeclRef(lambdaStruct, synStaticFunc); + expr->type = QualType(toFuncType); + expr->checked = true; + *outToExpr = expr; + } + return true; +} + bool SemanticsVisitor::canCoerce( Type* toType, QualType fromType, diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f340bd6fd..79482ea6e 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -4449,7 +4449,7 @@ void SemanticsVisitor::addModifiersToSynthesizedDecl( auto synStaticModifier = m_astBuilder->create(); synthesized->modifiers.first = synStaticModifier; } - else + else if (context) { // For a non-`static` requirement, we need a `this` parameter. // diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 41f945763..ad36a7e4a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2887,6 +2887,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr) } } } + rs->checked = true; return rs; } @@ -4209,9 +4210,9 @@ Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr) if (m_parentFunc) { nameBuilder << getText(m_parentFunc->getName()); + nameBuilder << "_"; + nameBuilder << m_parentFunc->members.getCount(); } - nameBuilder << "_"; - nameBuilder << m_parentFunc->members.getCount(); auto name = getName(nameBuilder.getBuffer()); lambdaStructDecl->nameAndLoc.name = name; lambdaStructDecl->nameAndLoc.loc = lambdaExpr->loc; @@ -4283,6 +4284,7 @@ Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr) auto resultLambdaObj = synthesizer.emitCtorInvokeExpr( synthesizer.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStructDecl)), _Move(args)); + resultLambdaObj->loc = lambdaExpr->loc; auto checkedResultExpr = dispatchExpr(resultLambdaObj, *this); return checkedResultExpr; } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f0884503f..87f2df53d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1451,6 +1451,15 @@ public: /// void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext); + // Check if `lambdaStruct` can be coerced to `funcType`, if so returns the coerced + // expression in `outExpr`. The coercion is only valid if the lambda struct + // does not contain any captures. + bool tryCoerceLambdaToFuncType( + DeclRef lambdaStruct, + FuncType* funcType, + Expr* fromExpr, + Expr** outExpr); + // A "proper" type is one that can be used as the type of an expression. // Put simply, it can be a concrete type like `int`, or a generic // type that is applied to arguments, like `Texture2D`. diff --git a/tests/language-feature/lambda/coerce-failure.slang b/tests/language-feature/lambda/coerce-failure.slang new file mode 100644 index 000000000..a5286611c --- /dev/null +++ b/tests/language-feature/lambda/coerce-failure.slang @@ -0,0 +1,19 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv + + +func test(f: functype(int, int)->float) -> float +{ + return f(2,3) + 10.0f; +} + + +[numthreads(1,1,1)] +void computeMain() +{ + int c = 2; + // CHECK: ([[# @LINE+1]]): error 30019 + let result = test((int x, int y)=> x + y + c); + + // CHECK: ([[# @LINE+1]]): error 30019 + let result1 = test((int x, float y) => x + y); +} \ No newline at end of file diff --git a/tests/language-feature/lambda/coerce-to-functype.slang b/tests/language-feature/lambda/coerce-to-functype.slang new file mode 100644 index 000000000..9ec430def --- /dev/null +++ b/tests/language-feature/lambda/coerce-to-functype.slang @@ -0,0 +1,13 @@ +//TEST:INTERPRET(filecheck=CHECK): + +func test(f: functype(int, int)->float) -> float +{ + return f(2,3) + 10.0f; +} + +func main() +{ + let result = test((int x, int y)=>x+y); + // CHECK: 15.0 + printf("%f\n", result); +} \ No newline at end of file -- cgit v1.2.3