summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-modifier.h8
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-ast-synthesis.cpp10
-rw-r--r--source/slang/slang-ast-synthesis.h1
-rw-r--r--source/slang/slang-check-conversion.cpp147
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-expr.cpp6
-rw-r--r--source/slang/slang-check-impl.h9
-rw-r--r--tests/language-feature/lambda/coerce-failure.slang19
-rw-r--r--tests/language-feature/lambda/coerce-to-functype.slang13
10 files changed, 213 insertions, 3 deletions
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index e566eca9e..ac1960685 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -158,6 +158,14 @@ class SynthesizedModifier : public Modifier
FIDDLE(...)
};
+// Marks that the definition of a func decl is synthesized static invoke func for
+// a lambda that doesn't capture anything.
+FIDDLE()
+class SynthesizedStaticLambdaFuncModifier : public Modifier
+{
+ FIDDLE(...)
+};
+
// Marks a synthesized variable as local temporary variable.
FIDDLE()
class LocalTempVarModifier : public Modifier
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 87715d9e0..b5ebe1884 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -93,6 +93,7 @@ FIDDLE() namespace Slang
kConversionCost_None = 0,
kConversionCost_GenericParamUpcast = 1,
+ kConversionCost_LambdaToFunc = 1,
kConversionCost_UnconstraintGenericParam = 20,
kConversionCost_SizedArrayToUnsizedArray = 30,
diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp
index c7291f526..5f7b527e2 100644
--- a/source/slang/slang-ast-synthesis.cpp
+++ b/source/slang/slang-ast-synthesis.cpp
@@ -168,6 +168,16 @@ Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name)
return rs;
}
+Expr* ASTSynthesizer::emitMemberExpr(QualType exprType, Expr* base, DeclRef<Decl> declRef)
+{
+ auto rs = m_builder->create<MemberExpr>();
+ rs->baseExpression = base;
+ rs->declRef = declRef;
+ rs->type = exprType;
+ rs->checked = base->checked;
+ return rs;
+}
+
Expr* ASTSynthesizer::emitIndexExpr(Expr* base, Expr* index)
{
auto rs = m_builder->create<IndexExpr>();
diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h
index c1072c705..bdfd90aeb 100644
--- a/source/slang/slang-ast-synthesis.h
+++ b/source/slang/slang-ast-synthesis.h
@@ -126,6 +126,7 @@ public:
Expr* emitMemberExpr(Expr* base, Name* name);
Expr* emitMemberExpr(Type* base, Name* name);
+ Expr* emitMemberExpr(QualType exprType, Expr* base, DeclRef<Decl> declRef);
Expr* emitIndexExpr(Expr* base, Expr* index);
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 68e5df7ce..8f2ea3ca7 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1,4 +1,5 @@
// slang-check-conversion.cpp
+#include "slang-ast-synthesis.h"
#include "slang-check-impl.h"
// This file contains semantic-checking logic for dealing
@@ -1388,6 +1389,21 @@ bool SemanticsVisitor::_coerce(
}
}
+ if (auto toFuncType = as<FuncType>(toType))
+ {
+ if (auto fromLambdaType = isDeclRefTypeOf<StructDecl>(fromType))
+ {
+ if (tryCoerceLambdaToFuncType(fromLambdaType, toFuncType, fromExpr, outToExpr))
+ {
+ if (outCost)
+ {
+ *outCost = kConversionCost_LambdaToFunc;
+ }
+ return true;
+ }
+ }
+ }
+
// A type is always convertible to any of its supertypes.
//
if (auto witness = tryGetSubtypeWitness(fromType, toType))
@@ -1813,6 +1829,137 @@ bool SemanticsVisitor::_coerce(
return _failedCoercion(toType, outToExpr, fromExpr, sink);
}
+bool SemanticsVisitor::tryCoerceLambdaToFuncType(
+ DeclRef<StructDecl> lambdaStruct,
+ FuncType* toFuncType,
+ Expr* fromExpr,
+ Expr** outToExpr)
+{
+ FuncDecl* synStaticFunc = nullptr;
+ FuncDecl* invokeFunc = nullptr;
+
+ // First, check if `lambdaStruct` contains any fields.
+ // If it does, we can't convert it to a function type.
+ auto operatorName = getName("()");
+
+ for (auto member : lambdaStruct.getDecl()->members)
+ {
+ if (auto field = as<VarDecl>(member))
+ {
+ if (!isEffectivelyStatic(field))
+ return false;
+ }
+ else if (auto inheritanceDecl = as<InheritanceDecl>(member))
+ {
+ // If the struct inherits from anything that is not an interface,
+ // we will consider it to be non-empty and not convertible to a function type.
+ if (!isDeclRefTypeOf<InterfaceDecl>(inheritanceDecl->base.type))
+ return false;
+ }
+ else if (auto funcDecl = as<FuncDecl>(member))
+ {
+ // If the struct already contains a synthesized static invoke member, use it.
+ if (isEffectivelyStatic(funcDecl) &&
+ funcDecl->findModifier<SynthesizedStaticLambdaFuncModifier>() &&
+ funcDecl->returnType.type == toFuncType->getResultType())
+ synStaticFunc = funcDecl;
+ if (funcDecl->getName() == operatorName)
+ {
+ // If we found operator(), keep it for later.
+ invokeFunc = funcDecl;
+ }
+ }
+ }
+
+ if (!invokeFunc)
+ {
+ return false;
+ }
+
+ auto invokeFuncDeclRef = m_astBuilder->getMemberDeclRef(lambdaStruct, invokeFunc);
+
+ // Verify that the function parameter types are exactly the same as toFuncType.
+ if (invokeFunc->getParameters().getCount() != toFuncType->getParamCount())
+ {
+ return false;
+ }
+ Index paramId = 0;
+ for (auto param : invokeFunc->getParameters())
+ {
+ auto paramType = getParamTypeWithDirectionWrapper(m_astBuilder, param);
+ auto toParamType = toFuncType->getParamType(paramId);
+ if (!paramType->equals(toParamType))
+ {
+ return false;
+ }
+ paramId++;
+ }
+
+ // Verify that the return type of the function is convertible to the function type.
+ if (!canCoerce(toFuncType->getResultType(), invokeFunc->returnType.type, nullptr))
+ {
+ return false;
+ }
+
+ if (!synStaticFunc)
+ {
+ // If the struct doesn't contain a static method for operator(), we try to synthesize one.
+ synStaticFunc = m_astBuilder->create<FuncDecl>();
+ synStaticFunc->nameAndLoc.name = getName("__syn_static_invoke");
+ addModifier(synStaticFunc, m_astBuilder->create<SynthesizedStaticLambdaFuncModifier>());
+ addModifier(synStaticFunc, m_astBuilder->create<HLSLStaticModifier>());
+ addModifier(synStaticFunc, m_astBuilder->create<SynthesizedModifier>());
+
+ synStaticFunc->ownedScope = m_astBuilder->create<Scope>();
+ synStaticFunc->ownedScope->containerDecl = synStaticFunc;
+ synStaticFunc->ownedScope->parent = getScope(lambdaStruct.getDecl());
+ synStaticFunc->parentDecl = lambdaStruct.getDecl();
+ synStaticFunc->returnType.type = toFuncType->getResultType();
+
+ List<Expr*> synArgs;
+ addRequiredParamsToSynthesizedDecl(invokeFuncDeclRef, synStaticFunc, synArgs);
+ ThisExpr* synThis = nullptr;
+ addModifiersToSynthesizedDecl(nullptr, invokeFuncDeclRef, synStaticFunc, synThis);
+
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
+ synth.pushContainerScope(synStaticFunc);
+ auto blockStmt = m_astBuilder->create<BlockStmt>();
+ synStaticFunc->body = blockStmt;
+ auto seqStmt = synth.pushSeqStmtScope();
+ blockStmt->body = seqStmt;
+
+ synth.pushVarScope();
+
+ // emit `return LambdaStructType().operator()(args...)`.
+ auto tempThis = synth.emitInvokeExpr(
+ synth.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStruct)),
+ List<Expr*>());
+ tempThis = dispatchExpr(tempThis, *this);
+
+ Expr* operatorRefExpr = synth.emitMemberExpr(toFuncType, tempThis, invokeFuncDeclRef);
+
+ auto invokeExpr = synth.emitInvokeExpr(operatorRefExpr, _Move(synArgs));
+ invokeExpr = dispatchExpr(invokeExpr, *this);
+ auto resultValue =
+ coerce(CoercionSite::Return, toFuncType->getResultType(), invokeExpr, getSink());
+ synth.emitReturnStmt(resultValue);
+
+ lambdaStruct.getDecl()->addMember(synStaticFunc);
+ }
+
+ // If we have a static method for operator(), we can convert the lambda to a function type.
+ if (outToExpr)
+ {
+ VarExpr* expr = m_astBuilder->create<VarExpr>();
+ expr->loc = fromExpr->loc;
+ expr->declRef = m_astBuilder->getMemberDeclRef(lambdaStruct, synStaticFunc);
+ expr->type = QualType(toFuncType);
+ expr->checked = true;
+ *outToExpr = expr;
+ }
+ return true;
+}
+
bool SemanticsVisitor::canCoerce(
Type* toType,
QualType fromType,
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f340bd6fd..79482ea6e 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -4449,7 +4449,7 @@ void SemanticsVisitor::addModifiersToSynthesizedDecl(
auto synStaticModifier = m_astBuilder->create<HLSLStaticModifier>();
synthesized->modifiers.first = synStaticModifier;
}
- else
+ else if (context)
{
// For a non-`static` requirement, we need a `this` parameter.
//
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 41f945763..ad36a7e4a 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2887,6 +2887,7 @@ Expr* SemanticsVisitor::CheckInvokeExprWithCheckedOperands(InvokeExpr* expr)
}
}
}
+ rs->checked = true;
return rs;
}
@@ -4209,9 +4210,9 @@ Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr)
if (m_parentFunc)
{
nameBuilder << getText(m_parentFunc->getName());
+ nameBuilder << "_";
+ nameBuilder << m_parentFunc->members.getCount();
}
- nameBuilder << "_";
- nameBuilder << m_parentFunc->members.getCount();
auto name = getName(nameBuilder.getBuffer());
lambdaStructDecl->nameAndLoc.name = name;
lambdaStructDecl->nameAndLoc.loc = lambdaExpr->loc;
@@ -4283,6 +4284,7 @@ Expr* SemanticsExprVisitor::visitLambdaExpr(LambdaExpr* lambdaExpr)
auto resultLambdaObj = synthesizer.emitCtorInvokeExpr(
synthesizer.emitStaticTypeExpr(DeclRefType::create(m_astBuilder, lambdaStructDecl)),
_Move(args));
+ resultLambdaObj->loc = lambdaExpr->loc;
auto checkedResultExpr = dispatchExpr(resultLambdaObj, *this);
return checkedResultExpr;
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index f0884503f..87f2df53d 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1451,6 +1451,15 @@ public:
///
void ensureDeclBase(DeclBase* decl, DeclCheckState state, SemanticsContext* baseContext);
+ // Check if `lambdaStruct` can be coerced to `funcType`, if so returns the coerced
+ // expression in `outExpr`. The coercion is only valid if the lambda struct
+ // does not contain any captures.
+ bool tryCoerceLambdaToFuncType(
+ DeclRef<StructDecl> lambdaStruct,
+ FuncType* funcType,
+ Expr* fromExpr,
+ Expr** outExpr);
+
// A "proper" type is one that can be used as the type of an expression.
// Put simply, it can be a concrete type like `int`, or a generic
// type that is applied to arguments, like `Texture2D<float4>`.
diff --git a/tests/language-feature/lambda/coerce-failure.slang b/tests/language-feature/lambda/coerce-failure.slang
new file mode 100644
index 000000000..a5286611c
--- /dev/null
+++ b/tests/language-feature/lambda/coerce-failure.slang
@@ -0,0 +1,19 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+
+func test(f: functype(int, int)->float) -> float
+{
+ return f(2,3) + 10.0f;
+}
+
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ int c = 2;
+ // CHECK: ([[# @LINE+1]]): error 30019
+ let result = test((int x, int y)=> x + y + c);
+
+ // CHECK: ([[# @LINE+1]]): error 30019
+ let result1 = test((int x, float y) => x + y);
+} \ No newline at end of file
diff --git a/tests/language-feature/lambda/coerce-to-functype.slang b/tests/language-feature/lambda/coerce-to-functype.slang
new file mode 100644
index 000000000..9ec430def
--- /dev/null
+++ b/tests/language-feature/lambda/coerce-to-functype.slang
@@ -0,0 +1,13 @@
+//TEST:INTERPRET(filecheck=CHECK):
+
+func test(f: functype(int, int)->float) -> float
+{
+ return f(2,3) + 10.0f;
+}
+
+func main()
+{
+ let result = test((int x, int y)=>x+y);
+ // CHECK: 15.0
+ printf("%f\n", result);
+} \ No newline at end of file