summaryrefslogtreecommitdiff
path: root/source/slang/check.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/check.cpp')
-rw-r--r--source/slang/check.cpp281
1 files changed, 257 insertions, 24 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 1af91e9fd..070d4c606 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -1407,6 +1407,30 @@ namespace Slang
return constIntVal;
}
+ // Check an expression, coerce it to the `String` type, and then
+ // ensure that it has a literal (not just compile-time constant) value.
+ bool checkLiteralStringVal(
+ RefPtr<Expr> expr,
+ String* outVal)
+ {
+ // TODO: This should actually perform semantic checking, etc.,
+ // but for now we are just going to look for a direct string
+ // literal AST node.
+
+ if(auto stringLitExpr = expr.As<StringLiteralExpr>())
+ {
+ if(outVal)
+ {
+ *outVal = stringLitExpr->value;
+ }
+ return true;
+ }
+
+ getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral);
+
+ return false;
+ }
+
void visitSyntaxDecl(SyntaxDecl*)
{
// These are only used in the stdlib, so no checking is needed
@@ -1456,6 +1480,36 @@ namespace Slang
}
}
+ Stage findStageByName(String const& name)
+ {
+ static const struct
+ {
+ char const* name;
+ Stage stage;
+ } kStages[] =
+ {
+ { "vertex", Stage::Vertex },
+ { "hull", Stage::Hull },
+ { "domain", Stage::Domain },
+ { "geometry", Stage::Geometry },
+ { "fragment", Stage::Fragment },
+ { "compute", Stage::Compute },
+
+ // Allow `pixel` as an alias of `fragment`
+ { "pixel", Stage::Fragment },
+ };
+
+ for(auto entry : kStages)
+ {
+ if(name == entry.name)
+ {
+ return entry.stage;
+ }
+ }
+
+ return Stage::Unknown;
+ }
+
bool validateAttribute(RefPtr<Attribute> attr)
{
if(auto numThreadsAttr = attr.As<NumThreadsAttribute>())
@@ -1487,6 +1541,24 @@ namespace Slang
instanceAttr->value = (int32_t)val->value;
}
+ else if(auto entryPointAttr = attr.As<EntryPointAttribute>())
+ {
+ SLANG_ASSERT(attr->args.Count() == 1);
+
+ String stageName;
+ if(!checkLiteralStringVal(attr->args[0], &stageName))
+ {
+ return false;
+ }
+
+ auto stage = findStageByName(stageName);
+ if(stage == Stage::Unknown)
+ {
+ getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName);
+ }
+
+ entryPointAttr->stage = stage;
+ }
else
{
if(attr->args.Count() == 0)
@@ -3172,35 +3244,51 @@ namespace Slang
stmt->Expression = CheckExpr(stmt->Expression);
}
- RefPtr<Expr> visitConstantExpr(ConstantExpr *expr)
+ RefPtr<Expr> visitBoolLiteralExpr(BoolLiteralExpr* expr)
{
- // The expression might already have a type, determined by its suffix
- if(expr->type.type)
- return expr;
+ expr->type = getSession()->getBoolType();
+ return expr;
+ }
- switch (expr->ConstType)
+ RefPtr<Expr> visitIntegerLiteralExpr(IntegerLiteralExpr* expr)
+ {
+ // The expression might already have a type, determined by its suffix.
+ // It it doesn't, we will give it a default type.
+ //
+ // TODO: We should be careful to pick a "big enough" type
+ // based on the size of the value (e.g., don't try to stuff
+ // a constant in an `int` if it requires 64 or more bits).
+ //
+ // The long-term solution here is to give a type to a literal
+ // based on the context where it is used, but that requires
+ // a more sophisticated type system than we have today.
+ //
+ if(!expr->type.type)
{
- case ConstantExpr::ConstantType::Int:
expr->type = getSession()->getIntType();
- break;
- case ConstantExpr::ConstantType::Bool:
- expr->type = getSession()->getBoolType();
- break;
- case ConstantExpr::ConstantType::Float:
+ }
+ return expr;
+ }
+
+ RefPtr<Expr> visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr)
+ {
+ if(!expr->type.type)
+ {
expr->type = getSession()->getFloatType();
- break;
- default:
- expr->type = QualType(getSession()->getErrorType());
- throw "Invalid constant type.";
- break;
}
return expr;
}
- IntVal* GetIntVal(ConstantExpr* expr)
+ RefPtr<Expr> visitStringLiteralExpr(StringLiteralExpr* expr)
+ {
+ expr->type = getSession()->getStringType();
+ return expr;
+ }
+
+ IntVal* GetIntVal(IntegerLiteralExpr* expr)
{
// TODO(tfoley): don't keep allocating here!
- return new ConstantIntVal(expr->integerValue);
+ return new ConstantIntVal(expr->value);
}
Name* getName(String const& text)
@@ -3339,9 +3427,9 @@ namespace Slang
}
// TODO(tfoley): more serious constant folding here
- if (auto constExp = dynamic_cast<ConstantExpr*>(expr))
+ if (auto intLitExpr = dynamic_cast<IntegerLiteralExpr*>(expr))
{
- return GetIntVal(constExp);
+ return GetIntVal(intLitExpr);
}
// it is possible that we are referring to a generic value param
@@ -6898,7 +6986,46 @@ namespace Slang
return type;
}
+ // Validate that an entry point function conforms to any additional
+ // constraints based on the stage (and profile?) it specifies.
void validateEntryPoint(
+ EntryPointRequest* /*entryPoint*/)
+ {
+ // TODO: We currently don't do any checking here, but this is the
+ // right place to perform the following validation checks:
+ //
+ // * Does the entry point specify all of the attributes required
+ // by the chosen stage (e.g., a `[domain(...)]` attribute for]
+ // a hull shader.
+ //
+ // * Are the function input/output parameters and result type
+ // all valid for the chosen stage? (e.g., there shouldn't be
+ // an `OutputStream<X>` type in a vertex shader signature)
+ //
+ // * For any varying input/output, are there semantics specified
+ // (Note: this potentially overlaps with layout logic...), and
+ // are the system-value semantics valid for the given stage?
+ //
+ // There's actually a lot of detail to semantic checking, in
+ // that the AST-level code should probably be validating the
+ // use of system-value semantics by linking them to explicit
+ // declarations in the standard library. We should also be
+ // using profile information on those declarations to infer
+ // appropriate profile restrictions on the entry point.
+ //
+ // * Is the entry point actually usable on the given stage/profile?
+ // E.g., if we have a vertex shader that (transitively) calls
+ // `Texture2D.Sample`, then that should produce an error because
+ // that function is specific to the fragment profile/stage.
+ //
+ }
+
+ // Given an `EntryPointRequest` specified via API or command line options,
+ // attempt to find a matching AST declaration that implements the specified
+ // entry point. If such a function is found, then validate that it actually
+ // meets the requirements for the selected stage/profile.
+ //
+ void findAndValidateEntryPoint(
EntryPointRequest* entryPoint)
{
// The first step in validating the entry point is to find
@@ -7070,10 +7197,116 @@ namespace Slang
}
if (sink->errorCount != 0)
return;
- // TODO: after all that work, we are now in a position to start
- // validating the declaration itself. E.g., we should check if
- // the declared input/output parameters have suitable semantics,
- // if they are of types that are appropriate to the stage, etc.
+
+ // Now that we've *found* the entry point, it is time to validate
+ // that it actually meets the constraints for the chosen stage/profile.
+ validateEntryPoint(entryPoint);
+ }
+
+ void validateEntryPoints(
+ CompileRequest* compileRequest)
+ {
+ // The validation of entry points here will be modal, and controlled
+ // by whether the user specified any entry points directly via
+ // API or command-line options.
+ //
+ // TODO: We may want to make this choice explicit rather than implicit.
+ //
+ // First, check if the user request any entry points explicitly via
+ // the API or command line.
+ //
+ bool anyExplicitEntryPointRequests = false;
+ for (auto& translationUnit : compileRequest->translationUnits)
+ {
+ if( translationUnit->entryPoints.Count() != 0)
+ {
+ anyExplicitEntryPointRequests = true;
+ break;
+ }
+ }
+
+ if( anyExplicitEntryPointRequests )
+ {
+ // If there were any explicit requests for entry points to be
+ // checked, then we will *only* check those.
+
+ for (auto& translationUnit : compileRequest->translationUnits)
+ {
+ for (auto entryPoint : translationUnit->entryPoints)
+ {
+ findAndValidateEntryPoint(entryPoint);
+ }
+ }
+ }
+ else
+ {
+ // Otherwise, scan for any `[shader(...)]` attributes in
+ // the user's code, and construct `EntryPointRequest`s to
+ // represent them.
+ //
+ // This ensures that downstream code only has to consider
+ // the central list of entry point requests, and doesn't
+ // have to know where they came from.
+
+ // TODO: A comprehensive approach here would need to search
+ // recursively for entry points, because they might appear
+ // as, e.g., member function of a `struct` type.
+ //
+ // For now we'll start with an extremely basic approach that
+ // should work for typical HLSL code.
+ //
+ UInt translationUnitCount = compileRequest->translationUnits.Count();
+ for(UInt tt = 0; tt < translationUnitCount; ++tt)
+ {
+ auto translationUnit = compileRequest->translationUnits[tt];
+ for( auto globalDecl : translationUnit->SyntaxNode->Members )
+ {
+ auto maybeFuncDecl = globalDecl;
+ if( auto genericDecl = maybeFuncDecl->As<GenericDecl>() )
+ {
+ maybeFuncDecl = genericDecl->inner;
+ }
+
+ auto funcDecl = maybeFuncDecl->As<FuncDecl>();
+ if(!funcDecl)
+ continue;
+
+ auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>();
+ if(!entryPointAttr)
+ continue;
+
+ // We've discovered a valid entry point. It is a function (possibly
+ // generic) that has a `[shader(...)]` attribute to mark it as an
+ // entry point.
+ //
+ // We will now register that entry point as an `EntryPointRequest`
+ // with an appropriately chosen profile.
+ //
+ // The profile will only include a stage, so that the profile "family"
+ // and "version" are left unspecified. Downstream code will need
+ // to be able to handle this case.
+ //
+ Profile profile;
+ profile.setStage(entryPointAttr->stage);
+
+ // We manually fill in the entry point request object.
+ RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest();
+ entryPointReq->compileRequest = compileRequest;
+ entryPointReq->translationUnitIndex = int(tt);
+ entryPointReq->decl = funcDecl;
+ entryPointReq->name = funcDecl->getName();
+ entryPointReq->profile = profile;
+
+ // Apply the common validation logic to this entry point.
+ validateEntryPoint(entryPointReq);
+
+ // Add the entry point to the list in the translation unit,
+ // and also the global list in the compile request.
+ translationUnit->entryPoints.Add(entryPointReq);
+ compileRequest->entryPoints.Add(entryPointReq);
+ }
+ }
+ }
}
void checkTranslationUnit(