summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-conversion.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-05-16 10:42:11 -0700
committerGitHub <noreply@github.com>2025-05-16 10:42:11 -0700
commit9dfd5244ad2953753535e82acd05e72e5ab2bc5f (patch)
tree0f7c8a0bf4ea3dd25348e1d1ac8e5bfcfd8c9724 /source/slang/slang-check-conversion.cpp
parent1fd7b2296d8360c245a0c732e7f842876533f92a (diff)
Allow lambda exprs without captures to coerce to `functype`. (#7129)
Diffstat (limited to 'source/slang/slang-check-conversion.cpp')
-rw-r--r--source/slang/slang-check-conversion.cpp147
1 files changed, 147 insertions, 0 deletions
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 68e5df7ce..8f2ea3ca7 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1,4 +1,5 @@
// slang-check-conversion.cpp
+#include "slang-ast-synthesis.h"
#include "slang-check-impl.h"
// This file contains semantic-checking logic for dealing
@@ -1388,6 +1389,21 @@ bool SemanticsVisitor::_coerce(
}
}
+ if (auto toFuncType = as<FuncType>(toType))
+ {
+ if (auto fromLambdaType = isDeclRefTypeOf<StructDecl>(fromType))
+ {
+ if (tryCoerceLambdaToFuncType(fromLambdaType, toFuncType, fromExpr, outToExpr))
+ {
+ if (outCost)
+ {
+ *outCost = kConversionCost_LambdaToFunc;
+ }
+ return true;
+ }
+ }
+ }
+
// A type is always convertible to any of its supertypes.
//
if (auto witness = tryGetSubtypeWitness(fromType, toType))
@@ -1813,6 +1829,137 @@ bool SemanticsVisitor::_coerce(
return _failedCoercion(toType, outToExpr, fromExpr, sink);
}
+bool SemanticsVisitor::tryCoerceLambdaToFuncType(
+ DeclRef<StructDecl> lambdaStruct,
+ FuncType* toFuncType,
+ Expr* fromExpr,
+ Expr** outToExpr)
+{
+ FuncDecl* synStaticFunc = nullptr;
+ FuncDecl* invokeFunc = nullptr;
+
+ // First, check if `lambdaStruct` contains any fields.
+ // If it does, we can't convert it to a function type.
+ auto operatorName = getName("()");
+
+ for (auto member : lambdaStruct.getDecl()->members)
+ {
+ if (auto field = as<VarDecl>(member))
+ {
+ if (!isEffectivelyStatic(field))
+ return false;
+ }
+ else if (auto inheritanceDecl = as<InheritanceDecl>(member))
+ {
+ // If the struct inherits from anything that is not an interface,
+ // we will consider it to be non-empty and not convertible to a function type.
+ if (!isDeclRefTypeOf<InterfaceDecl>(inheritanceDecl->base.type))
+ return false;
+ }
+ else if (auto funcDecl = as<FuncDecl>(member))
+ {
+ // If the struct already contains a synthesized static invoke member, use it.
+ if (isEffectivelyStatic(funcDecl) &&
+ funcDecl->findModifier<SynthesizedStaticLambdaFuncModifier>() &&
+ funcDecl->returnType.type == toFuncType->getResultType())
+ synStaticFunc = funcDecl;
+ if (funcDecl->getName() == operatorName)
+ {
+ // If we found operator(), keep it for later.
+ invokeFunc = funcDecl;
+ }
+ }
+ }
+
+ if (!invokeFunc)
+ {
+ return false;
+ }
+
+ auto invokeFuncDeclRef = m_astBuilder->getMemberDeclRef(lambdaStruct, invokeFunc);
+
+ // Verify that the function parameter types are exactly the same as toFuncType.
+ if (invokeFunc->getParameters().getCount() != toFuncType->getParamCount())
+ {
+ return false;
+ }
+ Index paramId = 0;
+ for (auto param : invokeFunc->getParameters())
+ {
+ auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param);
+ auto toParamType = toFuncType->getParamType(paramId);
+ if (!paramType->equals(toParamType))
+ {
+ return false;
+ }
+ paramId++;
+ }
+
+ // Verify that the return type of the function is convertible to the function type.
+ if (!canCoerce(toFuncType->getResultType(), invokeFunc->returnType.type, nullptr))
+ {
+ return false;
+ }
+
+ if (!synStaticFunc)
+ {
+ // If the struct doesn't contain a static method for operator(), we try to synthesize one.
+ synStaticFunc = m_astBuilder->create<FuncDecl>();
+ synStaticFunc->nameAndLoc.name = getName("__syn_static_invoke");
+ addModifier(synStaticFunc, m_astBuilder->create<SynthesizedStaticLambdaFuncModifier>());
+ addModifier(synStaticFunc, m_astBuilder->create<HLSLStaticModifier>());
+ addModifier(synStaticFunc, m_astBuilder->create<SynthesizedModifier>());
+
+ synStaticFunc->ownedScope = m_astBuilder->create<Scope>();
+ synStaticFunc->ownedScope->containerDecl = synStaticFunc;
+ synStaticFunc->ownedScope->parent = getScope(lambdaStruct.getDecl());
+ synStaticFunc->parentDecl = lambdaStruct.getDecl();
+ synStaticFunc->returnType.type = toFuncType->getResultType();
+
+ List<Expr*> synArgs;
+ addRequiredParamsToSynthesizedDecl(invokeFuncDeclRef, synStaticFunc, synArgs);
+ ThisExpr* synThis = nullptr;
+ addModifiersToSynthesizedDecl(nullptr, invokeFuncDeclRef, synStaticFunc, synThis);
+
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
+ synth.pushContainerScope(synStaticFunc);
+ auto blockStmt = m_astBuilder->create<BlockStmt>();
+ synStaticFunc->body = blockStmt;
+ auto seqStmt = synth.pushSeqStmtScope();
+ blockStmt->body = seqStmt;
+
+ synth.pushVarScope();
+
+ // emit `return LambdaStructType().operator()(args...)`.
+ auto tempThis = synth.emitInvokeExpr(
+ synth.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStruct)),
+ List<Expr*>());
+ tempThis = dispatchExpr(tempThis, *this);
+
+ Expr* operatorRefExpr = synth.emitMemberExpr(toFuncType, tempThis, invokeFuncDeclRef);
+
+ auto invokeExpr = synth.emitInvokeExpr(operatorRefExpr, _Move(synArgs));
+ invokeExpr = dispatchExpr(invokeExpr, *this);
+ auto resultValue =
+ coerce(CoercionSite::Return, toFuncType->getResultType(), invokeExpr, getSink());
+ synth.emitReturnStmt(resultValue);
+
+ lambdaStruct.getDecl()->addMember(synStaticFunc);
+ }
+
+ // If we have a static method for operator(), we can convert the lambda to a function type.
+ if (outToExpr)
+ {
+ VarExpr* expr = m_astBuilder->create<VarExpr>();
+ expr->loc = fromExpr->loc;
+ expr->declRef = m_astBuilder->getMemberDeclRef(lambdaStruct, synStaticFunc);
+ expr->type = QualType(toFuncType);
+ expr->checked = true;
+ *outToExpr = expr;
+ }
+ return true;
+}
+
bool SemanticsVisitor::canCoerce(
Type* toType,
QualType fromType,