diff options
Diffstat (limited to 'source/slang/check.cpp')
| -rw-r--r-- | source/slang/check.cpp | 281 |
1 files changed, 257 insertions, 24 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 1af91e9fd..070d4c606 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -1407,6 +1407,30 @@ namespace Slang return constIntVal; } + // Check an expression, coerce it to the `String` type, and then + // ensure that it has a literal (not just compile-time constant) value. + bool checkLiteralStringVal( + RefPtr<Expr> expr, + String* outVal) + { + // TODO: This should actually perform semantic checking, etc., + // but for now we are just going to look for a direct string + // literal AST node. + + if(auto stringLitExpr = expr.As<StringLiteralExpr>()) + { + if(outVal) + { + *outVal = stringLitExpr->value; + } + return true; + } + + getSink()->diagnose(expr, Diagnostics::expectedAStringLiteral); + + return false; + } + void visitSyntaxDecl(SyntaxDecl*) { // These are only used in the stdlib, so no checking is needed @@ -1456,6 +1480,36 @@ namespace Slang } } + Stage findStageByName(String const& name) + { + static const struct + { + char const* name; + Stage stage; + } kStages[] = + { + { "vertex", Stage::Vertex }, + { "hull", Stage::Hull }, + { "domain", Stage::Domain }, + { "geometry", Stage::Geometry }, + { "fragment", Stage::Fragment }, + { "compute", Stage::Compute }, + + // Allow `pixel` as an alias of `fragment` + { "pixel", Stage::Fragment }, + }; + + for(auto entry : kStages) + { + if(name == entry.name) + { + return entry.stage; + } + } + + return Stage::Unknown; + } + bool validateAttribute(RefPtr<Attribute> attr) { if(auto numThreadsAttr = attr.As<NumThreadsAttribute>()) @@ -1487,6 +1541,24 @@ namespace Slang instanceAttr->value = (int32_t)val->value; } + else if(auto entryPointAttr = attr.As<EntryPointAttribute>()) + { + SLANG_ASSERT(attr->args.Count() == 1); + + String stageName; + if(!checkLiteralStringVal(attr->args[0], &stageName)) + { + return false; + } + + auto stage = findStageByName(stageName); + if(stage == Stage::Unknown) + { + getSink()->diagnose(attr->args[0], Diagnostics::unknownStageName, stageName); + } + + entryPointAttr->stage = stage; + } else { if(attr->args.Count() == 0) @@ -3172,35 +3244,51 @@ namespace Slang stmt->Expression = CheckExpr(stmt->Expression); } - RefPtr<Expr> visitConstantExpr(ConstantExpr *expr) + RefPtr<Expr> visitBoolLiteralExpr(BoolLiteralExpr* expr) { - // The expression might already have a type, determined by its suffix - if(expr->type.type) - return expr; + expr->type = getSession()->getBoolType(); + return expr; + } - switch (expr->ConstType) + RefPtr<Expr> visitIntegerLiteralExpr(IntegerLiteralExpr* expr) + { + // The expression might already have a type, determined by its suffix. + // It it doesn't, we will give it a default type. + // + // TODO: We should be careful to pick a "big enough" type + // based on the size of the value (e.g., don't try to stuff + // a constant in an `int` if it requires 64 or more bits). + // + // The long-term solution here is to give a type to a literal + // based on the context where it is used, but that requires + // a more sophisticated type system than we have today. + // + if(!expr->type.type) { - case ConstantExpr::ConstantType::Int: expr->type = getSession()->getIntType(); - break; - case ConstantExpr::ConstantType::Bool: - expr->type = getSession()->getBoolType(); - break; - case ConstantExpr::ConstantType::Float: + } + return expr; + } + + RefPtr<Expr> visitFloatingPointLiteralExpr(FloatingPointLiteralExpr* expr) + { + if(!expr->type.type) + { expr->type = getSession()->getFloatType(); - break; - default: - expr->type = QualType(getSession()->getErrorType()); - throw "Invalid constant type."; - break; } return expr; } - IntVal* GetIntVal(ConstantExpr* expr) + RefPtr<Expr> visitStringLiteralExpr(StringLiteralExpr* expr) + { + expr->type = getSession()->getStringType(); + return expr; + } + + IntVal* GetIntVal(IntegerLiteralExpr* expr) { // TODO(tfoley): don't keep allocating here! - return new ConstantIntVal(expr->integerValue); + return new ConstantIntVal(expr->value); } Name* getName(String const& text) @@ -3339,9 +3427,9 @@ namespace Slang } // TODO(tfoley): more serious constant folding here - if (auto constExp = dynamic_cast<ConstantExpr*>(expr)) + if (auto intLitExpr = dynamic_cast<IntegerLiteralExpr*>(expr)) { - return GetIntVal(constExp); + return GetIntVal(intLitExpr); } // it is possible that we are referring to a generic value param @@ -6898,7 +6986,46 @@ namespace Slang return type; } + // Validate that an entry point function conforms to any additional + // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( + EntryPointRequest* /*entryPoint*/) + { + // TODO: We currently don't do any checking here, but this is the + // right place to perform the following validation checks: + // + // * Does the entry point specify all of the attributes required + // by the chosen stage (e.g., a `[domain(...)]` attribute for] + // a hull shader. + // + // * Are the function input/output parameters and result type + // all valid for the chosen stage? (e.g., there shouldn't be + // an `OutputStream<X>` type in a vertex shader signature) + // + // * For any varying input/output, are there semantics specified + // (Note: this potentially overlaps with layout logic...), and + // are the system-value semantics valid for the given stage? + // + // There's actually a lot of detail to semantic checking, in + // that the AST-level code should probably be validating the + // use of system-value semantics by linking them to explicit + // declarations in the standard library. We should also be + // using profile information on those declarations to infer + // appropriate profile restrictions on the entry point. + // + // * Is the entry point actually usable on the given stage/profile? + // E.g., if we have a vertex shader that (transitively) calls + // `Texture2D.Sample`, then that should produce an error because + // that function is specific to the fragment profile/stage. + // + } + + // Given an `EntryPointRequest` specified via API or command line options, + // attempt to find a matching AST declaration that implements the specified + // entry point. If such a function is found, then validate that it actually + // meets the requirements for the selected stage/profile. + // + void findAndValidateEntryPoint( EntryPointRequest* entryPoint) { // The first step in validating the entry point is to find @@ -7070,10 +7197,116 @@ namespace Slang } if (sink->errorCount != 0) return; - // TODO: after all that work, we are now in a position to start - // validating the declaration itself. E.g., we should check if - // the declared input/output parameters have suitable semantics, - // if they are of types that are appropriate to the stage, etc. + + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. + validateEntryPoint(entryPoint); + } + + void validateEntryPoints( + CompileRequest* compileRequest) + { + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. + // + // TODO: We may want to make this choice explicit rather than implicit. + // + // First, check if the user request any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPointRequests = false; + for (auto& translationUnit : compileRequest->translationUnits) + { + if( translationUnit->entryPoints.Count() != 0) + { + anyExplicitEntryPointRequests = true; + break; + } + } + + if( anyExplicitEntryPointRequests ) + { + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + + for (auto& translationUnit : compileRequest->translationUnits) + { + for (auto entryPoint : translationUnit->entryPoints) + { + findAndValidateEntryPoint(entryPoint); + } + } + } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPointRequest`s to + // represent them. + // + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. + // + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. + // + UInt translationUnitCount = compileRequest->translationUnits.Count(); + for(UInt tt = 0; tt < translationUnitCount; ++tt) + { + auto translationUnit = compileRequest->translationUnits[tt]; + for( auto globalDecl : translationUnit->SyntaxNode->Members ) + { + auto maybeFuncDecl = globalDecl; + if( auto genericDecl = maybeFuncDecl->As<GenericDecl>() ) + { + maybeFuncDecl = genericDecl->inner; + } + + auto funcDecl = maybeFuncDecl->As<FuncDecl>(); + if(!funcDecl) + continue; + + auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); + if(!entryPointAttr) + continue; + + // We've discovered a valid entry point. It is a function (possibly + // generic) that has a `[shader(...)]` attribute to mark it as an + // entry point. + // + // We will now register that entry point as an `EntryPointRequest` + // with an appropriately chosen profile. + // + // The profile will only include a stage, so that the profile "family" + // and "version" are left unspecified. Downstream code will need + // to be able to handle this case. + // + Profile profile; + profile.setStage(entryPointAttr->stage); + + // We manually fill in the entry point request object. + RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); + entryPointReq->compileRequest = compileRequest; + entryPointReq->translationUnitIndex = int(tt); + entryPointReq->decl = funcDecl; + entryPointReq->name = funcDecl->getName(); + entryPointReq->profile = profile; + + // Apply the common validation logic to this entry point. + validateEntryPoint(entryPointReq); + + // Add the entry point to the list in the translation unit, + // and also the global list in the compile request. + translationUnit->entryPoints.Add(entryPointReq); + compileRequest->entryPoints.Add(entryPointReq); + } + } + } } void checkTranslationUnit( |
