diff options
Diffstat (limited to 'source/slang/check.cpp')
| -rw-r--r-- | source/slang/check.cpp | 1041 |
1 files changed, 578 insertions, 463 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 3485afeea..483db60bb 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -400,22 +400,18 @@ namespace Slang else return DeclCheckState::CheckedHeader; } - DiagnosticSink* sink = nullptr; + + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + DiagnosticSink* getSink() { - return sink; + return m_sink; } // ModuleDecl * program = nullptr; FuncDecl * function = nullptr; - CompileRequest* request = nullptr; - TranslationUnitRequest* translationUnit = nullptr; - - SourceLanguage getSourceLanguage() - { - return translationUnit->sourceLanguage; - } // lexical outer statements List<Stmt*> outerStmts; @@ -429,20 +425,15 @@ namespace Slang public: SemanticsVisitor( - DiagnosticSink* sink, - CompileRequest* request, - TranslationUnitRequest* translationUnit) - : sink(sink) - , request(request) - , translationUnit(translationUnit) - { - } + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) + {} - CompileRequest* getCompileRequest() { return request; } - TranslationUnitRequest* getTranslationUnit() { return translationUnit; } Session* getSession() { - return getCompileRequest()->mSession; + return m_linkage->getSession(); } public: @@ -985,7 +976,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(decl->loc); + getSink()->noteInternalErrorLoc(decl->loc); throw; } } @@ -998,7 +989,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(stmt->loc); + getSink()->noteInternalErrorLoc(stmt->loc); throw; } } @@ -1011,7 +1002,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(expr->loc); + getSink()->noteInternalErrorLoc(expr->loc); throw; } } @@ -1030,7 +1021,7 @@ namespace Slang // being checked on the stack, so that we can report the full // chain that leads from this declaration back to itself. // - sink->diagnose(decl, Diagnostics::cyclicReference, decl); + getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); return; } @@ -1050,7 +1041,7 @@ namespace Slang // TODO: This diagnostic should be emitted on the line that is referencing // the declaration. That requires `EnsureDecl` to take the requesting // location as a parameter. - sink->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); + getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); return; } } @@ -3019,7 +3010,7 @@ namespace Slang checkDecl(func); } - if (sink->GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return; // Force everything to be fully checked, just in case @@ -4921,9 +4912,12 @@ namespace Slang return new ConstantIntVal(expr->value); } + Linkage* getLinkage() { return m_linkage; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + Name* getName(String const& text) { - return getCompileRequest()->getNamePool()->getName(text); + return getNamePool()->getName(text); } RefPtr<IntVal> TryConstantFoldExpr( @@ -5079,58 +5073,18 @@ namespace Slang { auto varDecl = varRef.getDecl(); - switch(getSourceLanguage()) + // In HLSL, `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) { - default: - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - // HLSL: `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->FindModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(varRef)) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - break; - - case SourceLanguage::GLSL: - // GLSL: `const` indicates compile-time constant expression - // - // TODO(tfoley): The current logic here isn't robust against - // GLSL "specialization constants" - we will extract the - // initializer for a `const` variable and use it to extract - // a value, when we really should be using an opaque - // reference to the variable. if(auto constAttr = varDecl->FindModifier<ConstModifier>()) { - // We need to handle a "specialization constant" (with a `constant_id` layout modifier) - // differently from an ordinary compile-time constant. The latter can/should be reduced - // to a value, while the former should be kept as a symbolic reference - - if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) - { - // Retain the specialization constant as a symbolic reference - // - // TODO(tfoley): handle the case of non-`int` value parameters... - // - // TODO(tfoley): this is cloned from the case above that handles generic value parameters - return new GenericParamIntVal(varRef); - } - else if(auto initExpr = getInitExpr(varRef)) + // HLSL `static const` can be used as a constant expression + if(auto initExpr = getInitExpr(varRef)) { - // This is an ordinary constant, and not a specialization constant, so we - // can try to fold its value right now. return TryConstantFoldExpr(initExpr.Ptr()); } } - break; } - } else if(auto enumRef = declRef.as<EnumCaseDecl>()) { @@ -9060,17 +9014,32 @@ namespace Slang auto scope = decl->scope; // Try to load a module matching the name - auto importedModuleDecl = findOrImportModule(request, name, decl->moduleNameAndLoc.loc); + auto importedModule = findOrImportModule( + getLinkage(), + name, + decl->moduleNameAndLoc.loc, + getSink()); // If we didn't find a matching module, then bail out - if (!importedModuleDecl) + if (!importedModule) return; // Record the module that was imported, so that we can use // it later during code generation. + auto importedModuleDecl = importedModule->getModuleDecl(); decl->importedModuleDecl = importedModuleDecl; - importModuleIntoScope(scope.Ptr(), importedModuleDecl.Ptr()); + // Add the declarations from the imported module into the scope + // that the `import` declaration is set to extend. + // + importModuleIntoScope(scope.Ptr(), importedModuleDecl); + + // Record the `import`ed module (and everything it depends on) + // as a dependency of the module we are compiling. + if(auto module = getModule(decl)) + { + module->addModuleDependency(importedModule); + } decl->SetCheckState(getCheckedState()); } @@ -9142,29 +9111,25 @@ namespace Slang return (!decl->primaryDecl) || (decl == decl->primaryDecl); } - RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp) + RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink) { - RefPtr<Type> type; - DiagnosticSink nSink; - nSink.sourceManager = tu->compileRequest->sourceManager; SemanticsVisitor visitor( - &nSink, - tu->compileRequest, - tu); + linkage, + sink); auto typeOut = visitor.CheckProperType(typeExp); - if (!nSink.errorCount) - { - type = typeOut.type; - } - return type; + return typeOut.type; } - FuncDecl* findFunctionDeclByName(EntryPointRequest* entryPoint, Name* name) + FuncDecl* findFunctionDeclByName( + Module* translationUnit, + Name* name, + DiagnosticSink* sink) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + auto translationUnitSyntax = translationUnit->getModuleDecl(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9270,7 +9235,9 @@ namespace Slang // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( - EntryPointRequest* entryPoint) + FuncDecl* entryPointFuncDecl, + Stage stage, + DiagnosticSink* sink) { // TODO: We currently do minimal checking here, but this is the // right place to perform the following validation checks: @@ -9297,28 +9264,32 @@ namespace Slang // that function is specific to the fragment profile/stage. // - auto sink = &entryPoint->compileRequest->mSink; + auto entryPointName = entryPointFuncDecl->getName(); + + auto module = getModule(entryPointFuncDecl); + auto linkage = module->getLinkage(); + // Every entry point needs to have a stage specified either via // command-line/API options, or via an explicit `[shader("...")]` attribute. // - if( entryPoint->getStage() == Stage::Unknown ) + if( stage == Stage::Unknown ) { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name); + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); } - if (entryPoint->getStage() == Stage::Hull) + if( stage == Stage::Hull ) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` + // attributes, so that they need to resolve to a function. - auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>(); + auto attr = entryPointFuncDecl->FindModifier<PatchConstantFuncAttribute>(); if (attr) { if (attr->args.Count() != 1) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } @@ -9327,40 +9298,52 @@ namespace Slang if (!stringLit) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } - Name* name = entryPoint->compileRequest->getNamePool()->getName(stringLit->value); - FuncDecl* funcDecl = findFunctionDeclByName(entryPoint, name); - if (!funcDecl) + // We look up the patch-constant function by its name in the module + // scope of the translation unit that declared the HS entry point. + // + // TODO: Eventually we probably want to do the lookup in the scope + // of the parent declarations of the entry point. E.g., if the entry + // point is a member function of a `struct`, then its patch-constant + // function should be allowed to be another member function of + // the same `struct`. + // + // In the extremely long run we may want to support an alternative to + // this attribute-based linkage between the two functions that + // make up the entry point. + // + Name* name = linkage->getNamePool()->getName(stringLit->value); + FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( + module, + name, + sink); + if (!patchConstantFuncDecl) { - sink->diagnose(translationUnitSyntax, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); + sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); return; } - attr->patchConstantFuncDecl = funcDecl; + attr->patchConstantFuncDecl = patchConstantFuncDecl; } } - else if (entryPoint->getStage() == Stage::Compute) + else if(stage == Stage::Compute) { - auto funcDecl = entryPoint->getFuncDecl(); - - auto params = funcDecl->GetParameters(); - - for (const auto& param : params) + for(const auto& param : entryPointFuncDecl->GetParameters()) { - if (auto semantic = param->FindModifier<HLSLSimpleSemantic>()) + if(auto semantic = param->FindModifier<HLSLSimpleSemantic>()) { const auto& semanticToken = semantic->name; String lowerName = String(semanticToken.Content).ToLower(); - if (lowerName == "sv_dispatchthreadid") + if(lowerName == "sv_dispatchthreadid") { Type* paramType = param->getType(); - if (!isValidThreadDispatchIDType(paramType)) + if(!isValidThreadDispatchIDType(paramType)) { String typeString = paramType->ToString(); sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); @@ -9372,26 +9355,30 @@ namespace Slang } } - // Given an `EntryPointRequest` specified via API or command line options, + // Given an entry point specified via API or command line options, // attempt to find a matching AST declaration that implements the specified // entry point. If such a function is found, then validate that it actually // meets the requirements for the selected stage/profile. // - void findAndValidateEntryPoint( - EntryPointRequest* entryPoint) + // Returns an `EntryPoint` object representing the (unspecialized) + // entry point if it is found and validated, and null otherwise. + // + RefPtr<EntryPoint> findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq) { // The first step in validating the entry point is to find // the (unique) function declaration that matches its name. // - // TODO: We will eventually need to update this logic - // to work by parsing the provided `entryPoint->name` string - // as an expression, so that we can handle more complex - // names like `foo<int>` or `SomeType.vs`. - - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We may eventually want/need to extend this to + // account for nested names like `SomeStruct.vsMain`, or + // indeed even to handle generics. + // + auto compileRequest = entryPointReq->getCompileRequest(); + auto translationUnit = entryPointReq->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto translationUnitSyntax = translationUnit->getModuleDecl(); + auto entryPointName = entryPointReq->getName(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9399,12 +9386,12 @@ namespace Slang // We will look up any global-scope declarations in the translation // unit that match the name of our entry point. Decl* firstDeclWithName = nullptr; - if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, firstDeclWithName) ) + if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) { // If there doesn't appear to be any such declaration, then // we need to diagnose it as an error, and then bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPoint->name); - return; + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName); + return nullptr; } // We found at least one global-scope declaration with the right name, @@ -9448,7 +9435,7 @@ namespace Slang // name before, so the whole thing is ambiguous. We need // to diagnose and bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPoint->name); + sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName); // List all of the declarations that the user *might* mean for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) @@ -9460,7 +9447,7 @@ namespace Slang } // Bail out. - return; + return nullptr; } } } @@ -9471,127 +9458,197 @@ namespace Slang // If not, then we need to diagnose the error. // For convenience, we will point to the first // declaration with the right name, that wasn't a function. - sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPoint->name); - return; + sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName); + return nullptr; } + // TODO: it is possible that the entry point was declared with + // profile or target overloading. Is there anything that we need + // to do at this point to filter out declarations that aren't + // relevant to the selected profile for the entry point? + + // We found something, and can start doing some basic checking. + // // If the entry point specifies a stage via a `[shader("...")]` attribute, // then we might be able to infer a stage for the entry point request if // it didn't have one, *or* issue a diagnostic if there is a mismatch. // + auto entryPointProfile = entryPointReq->getProfile(); if( auto entryPointAttribute = entryPointFuncDecl->FindModifier<EntryPointAttribute>() ) { - if( entryPoint->getStage() == Stage::Unknown ) + auto entryPointStage = entryPointProfile.GetStage(); + if( entryPointStage == Stage::Unknown ) { - entryPoint->profile.setStage(entryPointAttribute->stage); + entryPointProfile.setStage(entryPointAttribute->stage); } - else if( entryPointAttribute->stage != entryPoint->getStage() ) + else if( entryPointAttribute->stage != entryPointStage ) { - sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPoint->name, entryPoint->getStage(), entryPointAttribute->stage); + sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage); } } + else + { + // TODO: Should we attach a `[shader(...)]` attribute to an + // entry point that didn't have one, so that we can have + // a more uniform representation in the AST? + } - // TODO: it is possible that the entry point was declared with - // profile or target overloading. Is there anything that we need - // to do at this point to filter out declarations that aren't - // relevant to the selected profile for the entry point? - // Phew, we have at least found a suitable decl. - // Let's record that in the entry-point request so - // that we don't have to re-do this effort again later. - // - // Note: we may replace the decl-ref we store at this point - // later in this function, when we (potentially) specialize - // a generic entry point to generic arguments provided - // via the API. + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. // - entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl); + validateEntryPoint( + entryPointFuncDecl, + entryPointProfile.GetStage(), + sink); - // If the user specified generic arguments for the entry point, - // then we will want to parse those arguments as expressions - // in a scope that includes the tanslation unit that holds - // the entry point, as well as any other modules that got - // transitively loaded via `import`. - // - // TODO: This would be better handled by giving the user - // more explicit ways to parse/build types at the API level, - // rather than keeping things string-based this far along. + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(entryPointFuncDecl), + entryPointProfile); + + return entryPoint; + } + + /// Create a `Program` to represent the compiled code. + /// + /// The created program will comprise all of the translation + /// units that were compiled as part of the request, as + /// well as any entry points in those translation units. + /// + RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest) + { + // We want our resulting program to depend on + // all the translation units the user specified, + // even if some of them don't contain entry points + // (this is important for parameter layout/binding). // - // TODO: Building a list of `scopesToTry` here shouldn't - // be required, since the `Scope` type itself has the ability - // for form chains for lookup purposes (e.g., the way that - // `import` is handled by modifying a scope). + // We also want to ensure that the modules for the + // translation units comes first in the enumerated + // order for dependencies, to match the pre-existing + // compiler behavior (at least for now). // - List<RefPtr<Scope>> scopesToTry; - scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); - for (auto & module : entryPoint->compileRequest->loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + auto program = new Program(linkage); + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedLeafModule(translationUnit->getModule()); + } + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedModule(translationUnit->getModule()); + } - // We are going to do some semantic checking, so we need to - // set up a `SemanticsVistitor` that we can use. - // - SemanticsVisitor semantics( - &entryPoint->compileRequest->mSink, - entryPoint->compileRequest, - entryPoint->getTranslationUnit()); - // We will be looping over the generic argument strings - // that the user provided via the API (or command line), - // and parsing+checking each into an `Expr`. + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. // - // This loop will *not* handle coercing the arguments - // to be types. + // TODO: We may want to make this choice explicit rather than implicit. // - List<RefPtr<Expr>> genericArgs; - for (auto name : entryPoint->genericArgStrings) + // First, check if the user requested any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + + if( anyExplicitEntryPoints ) { - RefPtr<Expr> argExpr; - for (auto & s : scopesToTry) + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + // + for(auto entryPointReq : compileRequest->getEntryPointReqs()) { - argExpr = entryPoint->compileRequest->parseTypeString( - entryPoint->getTranslationUnit(), - name, - s); - argExpr = semantics.CheckTerm(argExpr); - if( argExpr ) + auto entryPoint = findAndValidateEntryPoint( + entryPointReq); + if( entryPoint ) { - break; + program->addEntryPoint(entryPoint); + entryPointReq->getTranslationUnit()->entryPoints.Add(entryPoint); } } - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a module (and the stuff it imports). + // TODO: We should consider always processing both categories, + // and just making sure to only check each entry point function + // declaration once... + } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPoint`s to + // represent them. // - // The Right Way to handle this would probably be to have each `ModuleDecl` track - // any tagged union types that get created in the context of that module, and - // then combine those lists later. + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point request. + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. // - if( auto typeType = as<TypeType>(argExpr->type) ) + UInt translationUnitCount = compileRequest->translationUnits.Count(); + for(UInt tt = 0; tt < translationUnitCount; ++tt) { - auto type = typeType->type; - if( auto taggedUnionType = as<TaggedUnionType>(type) ) + auto translationUnit = compileRequest->translationUnits[tt]; + for( auto globalDecl : translationUnit->getModuleDecl()->Members ) { - entryPoint->taggedUnionTypes.Add(taggedUnionType); + auto maybeFuncDecl = globalDecl; + if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) + { + maybeFuncDecl = genericDecl->inner; + } + + auto funcDecl = as<FuncDecl>(maybeFuncDecl); + if(!funcDecl) + continue; + + auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); + if(!entryPointAttr) + continue; + + // We've discovered a valid entry point. It is a function (possibly + // generic) that has a `[shader(...)]` attribute to mark it as an + // entry point. + // + // We will now register that entry point as an `EntryPoint` + // with an appropriately chosen profile. + // + // The profile will only include a stage, so that the profile "family" + // and "version" are left unspecified. Downstream code will need + // to be able to handle this case. + // + Profile profile; + profile.setStage(entryPointAttr->stage); + + validateEntryPoint(funcDecl, entryPointAttr->stage, sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(funcDecl), + profile); + program->addEntryPoint(entryPoint); + translationUnit->entryPoints.Add(entryPoint); } } - - genericArgs.Add(argExpr); } - // There are two cases we care about here, and we are going to treat them - // as mutually exclusive for simplicity. - // - // The first case is when the entry point function is itself generic, - // in which case we will assume that `genericArgs` lines up one-to-one - // with the explicit generic parameters of the entry point. - // + return program; + } + + /// Create a specialization an existing entry point based on generic arguments. + DeclRef<FuncDecl> specializeEntryPoint( + Linkage* linkage, + FuncDecl* entryPointFuncDecl, + List<RefPtr<Expr>> const& genericArgs, + DiagnosticSink* sink) + { + SemanticsVisitor semantics( + linkage, + sink); + + DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl); if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) { // We will construct a suitable `GenericAppExpr` to represent @@ -9601,7 +9658,7 @@ namespace Slang // generic application like `F<A,B,C>` if it were // encountered in the source code. - auto session = entryPoint->compileRequest->mSession; + auto session = linkage->getSession(); auto genericDeclRef = makeDeclRef(genericDecl); // The first pieces is a `VarExpr` that refers to `genericDecl`. @@ -9639,12 +9696,13 @@ namespace Slang // The basic `VarExpr` and `StaticMemberExpr` cases // should be allow-able. - entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>(); + entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); } else if( semantics.IsErrorExpr(checkedExpr) ) { // Any semantic error that occured should have been // reported already. + return DeclRef<FuncDecl>(); } else { @@ -9652,302 +9710,359 @@ namespace Slang // function should always be a `DeclRefExpr` // SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + UNREACHABLE_RETURN(DeclRef<FuncDecl>()); } } - else - { - // The other case is when the entry point function is *not* itself - // generic, so we assume that any generic arguments must have been intended - // to match up with global generic parameters instead. - // - // We will only validate global generic type arguments when we are going - // to generate code, since in a no-codegen pass we will typically *not* - // have arguments to associate with the parameters. - // - if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) - { - // check that user-provioded type arguments conforms to the generic type - // parameter declaration of this translation unit - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - - if (globalGenericParams.Count() != genericArgs.Count()) - { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument, - globalGenericParams.Count(), - genericArgs.Count()); - return; - } - - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. - // - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick<A>; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - UInt argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < genericArgs.Count()); - auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex])); - if (!globalGenericArg) - { - sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]); - return; - } + return entryPointFuncDeclRef; + } - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - } - } + /// Parse an array of strings as generic arguments. + /// + /// Names in the strings will be parsed in the context of + /// the code loaded into the given compile request. + /// + void parseGenericArgStrings( + EndToEndCompileRequest* endToEndReq, + List<String> const& genericArgStrings, + List<RefPtr<Expr>>& outGenericArgs) + { + auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // + List<RefPtr<Scope>> scopesToTry; + for( auto module : unspecialiedProgram->getModuleDependencies() ) + scopesToTry.Add(module->getModuleDecl()->scope); - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + SemanticsVisitor semantics( + linkage, + sink); - // Use our semantic-checking logic to search for a witness to the required conformance - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + for(auto name : genericArgStrings) + { + RefPtr<Expr> argExpr; + for (auto & s : scopesToTry) + { + argExpr = linkage->parseTypeString(name, s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) + { + break; + } + } - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.Add(constraintArg); - } + outGenericArgs.Add(argExpr); + } + } - // Add the substitution for this parameter to the global substitution - // set that we are building. + /// Specialize a program to global generic arguments + RefPtr<Program> createSpecializedProgram( + Linkage* linkage, + Program* unspecializedProgram, + List<RefPtr<Expr>> const& globalGenericArgs, + DiagnosticSink* sink) + { + // The given `unspecializedProgram` should be one that + // was checked through the front-end, so that now we + // only need to check if the given arguments can satisfy + // the requirements of the global generic parameters. + // + // The new program needs to start off with the same + // module dependency list as the original. + // + RefPtr<Program> specializedProgram = new Program(linkage); + for(auto module : unspecializedProgram->getModuleDependencies()) + { + specializedProgram->addReferencedLeafModule(module); + } - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } - entryPoint->globalGenericSubst = globalGenericSubsts; - } + // We will collect all the global generic parameters + // defined in the modules being referenced, to find + // the global generic parameter signature of the + // program. + // + // TODO: Note that this doesn't handle the case where one + // or more of the type *arguments* that we are specifying + // ends up requiring additional modules to be referenced, + // which might in turn introduce new global generic parameters. + // + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + for(auto module : unspecializedProgram->getModuleDependencies()) + { + for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>()) + globalGenericParams.Add(param); } - // If any errors occured while we were checking the generic arguments - // of the entry point, then we should bail out rather than try to - // perform the next step of validation. + // Next, we will check whether the supplied arguments can + // satisfy those parameters. + // + // An easy early-out case will be if the number of + // arguments isn't correct. // - if (sink->errorCount != 0) - return; + if (globalGenericParams.Count() != globalGenericArgs.Count()) + { + sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, + globalGenericParams.Count(), + globalGenericArgs.Count()); + return nullptr; + } - // Now that we've *found* the entry point, it is time to validate - // that it actually meets the constraints for the chosen stage/profile. + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. // - // TODO: This validation should (probably?) be performed "under" any global generic - // parameter substitution we might have created, so that we can validate - // based on knowledge of actual types. + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. // - validateEntryPoint(entryPoint); - } - - void validateEntryPoints( - CompileRequest* compileRequest) - { - // The validation of entry points here will be modal, and controlled - // by whether the user specified any entry points directly via - // API or command-line options. + RefPtr<Substitutions> globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; // - // TODO: We may want to make this choice explicit rather than implicit. + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: // - // First, check if the user request any entry points explicitly via - // the API or command line. + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). // - bool anyExplicitEntryPointRequests = false; - for (auto& translationUnit : compileRequest->translationUnits) + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) { - if( translationUnit->entryPoints.Count() != 0) + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < globalGenericArgs.Count()); + auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); + if (!globalGenericArg) { - anyExplicitEntryPointRequests = true; - break; + sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); + return nullptr; } - } - if( anyExplicitEntryPointRequests ) - { - // If there were any explicit requests for entry points to be - // checked, then we will *only* check those. - - for (auto& translationUnit : compileRequest->translationUnits) + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) { - for (auto entryPoint : translationUnit->entryPoints) + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { - findAndValidateEntryPoint(entryPoint); + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + continue; + } } } - } - else - { - // Otherwise, scan for any `[shader(...)]` attributes in - // the user's code, and construct `EntryPointRequest`s to - // represent them. - // - // This ensures that downstream code only has to consider - // the central list of entry point requests, and doesn't - // have to know where they came from. - // TODO: A comprehensive approach here would need to search - // recursively for entry points, because they might appear - // as, e.g., member function of a `struct` type. - // - // For now we'll start with an extremely basic approach that - // should work for typical HLSL code. - // - UInt translationUnitCount = compileRequest->translationUnits.Count(); - for(UInt tt = 0; tt < translationUnitCount; ++tt) + // Create a substitution for this parameter/argument. + RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) { - auto translationUnit = compileRequest->translationUnits[tt]; - for( auto globalDecl : translationUnit->SyntaxNode->Members ) + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + SemanticsVisitor visitor(linkage, sink); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) { - auto maybeFuncDecl = globalDecl; - if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) - { - maybeFuncDecl = genericDecl->inner; - } + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } - auto funcDecl = as<FuncDecl>(maybeFuncDecl); - if(!funcDecl) - continue; + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); + } - auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); - if(!entryPointAttr) - continue; + // Add the substitution for this parameter to the global substitution + // set that we are building. - // We've discovered a valid entry point. It is a function (possibly - // generic) that has a `[shader(...)]` attribute to mark it as an - // entry point. - // - // We will now register that entry point as an `EntryPointRequest` - // with an appropriately chosen profile. - // - // The profile will only include a stage, so that the profile "family" - // and "version" are left unspecified. Downstream code will need - // to be able to handle this case. - // - Profile profile; - profile.setStage(entryPointAttr->stage); + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } + if(sink->GetErrorCount()) + return nullptr; - // We manually fill in the entry point request object. - RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); - entryPointReq->compileRequest = compileRequest; - entryPointReq->translationUnitIndex = int(tt); - entryPointReq->funcDeclRef = makeDeclRef(funcDecl); - entryPointReq->name = funcDecl->getName(); - entryPointReq->profile = profile; + specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - // Apply the common validation logic to this entry point. - validateEntryPoint(entryPointReq); + return specializedProgram; + } - // Add the entry point to the list in the translation unit, - // and also the global list in the compile request. - translationUnit->entryPoints.Add(entryPointReq); - compileRequest->entryPoints.Add(entryPointReq); - } - } + /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// + /// If the end-to-end compile request included generic argument strings + /// for this entry point, then they will be parsed, checked, and used + /// as arguments to the generic entry point. + /// + /// Returns a specialized entry point if everything worked as expected. + /// Returns null and diagnoses errors if anything goes wrong. + /// + RefPtr<EntryPoint> specializeEntryPoint( + EndToEndCompileRequest* endToEndReq, + EntryPoint* unspecializedEntryPoint, + EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) + { + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + + // If the user specified generic arguments for the entry point, + // then we will need to parse the arguments first. + // + List<RefPtr<Expr>> genericArgs; + parseGenericArgStrings( + endToEndReq, + entryPointInfo.genericArgStrings, + genericArgs); + + // Next we specialize the entry point function given the parsed + // generic argument expressions. + // + auto entryPointFuncDeclRef = specializeEntryPoint( + linkage, + entryPointFuncDecl, + genericArgs, + sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + entryPointFuncDeclRef, + unspecializedEntryPoint->getProfile()); + + return entryPoint; + } + + /// Create a specialized program based on the given compile request. + /// + RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq) + { + // The compile request must have already completed front-end processing, + // so that we have an unspecialized program available, and now only need + // to parse and check any generic arguments that are being supplied for + // global or entry-point generic parameters. + // + auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + + // First, let's parse the generic argument strings that were + // provided via the API, so taht we can match them + // against what was declared in the program. + // + List<RefPtr<Expr>> globalGenericArgs; + parseGenericArgStrings( + endToEndReq, + endToEndReq->globalGenericArgStrings, + globalGenericArgs); + + // Now we create the initial specialized program by + // applying the global generic arguments (if any) to the + // unspecialized program. + // + auto specializedProgram = createSpecializedProgram( + endToEndReq->getLinkage(), + unspecializedProgram, + globalGenericArgs, + endToEndReq->getSink()); + + // If anything went wrong with the global generic + // arguments, then bail out now. + // + if(!specializedProgram) + return nullptr; + + // Next we will deal with the entry points for the + // new specialized program. + // + // If the user specified explicit entry points as part of the + // end-to-end request, then we only want to process those (and + // ignore any other `[shader(...)]`-attributed entry points). + // + // However, if the user specified *no* entry points as part + // of the end-to-end request, then we would like to go + // ahead and consider all the entry points that were found + // by the front-end. + // + UInt entryPointCount = endToEndReq->entryPoints.Count(); + if( entryPointCount == 0 ) + { + entryPointCount = unspecializedProgram->getEntryPointCount(); + endToEndReq->entryPoints.SetSize(entryPointCount); } + + for( UInt ii = 0; ii < entryPointCount; ++ii ) + { + auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + + auto specializedEntryPoint = specializeEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + specializedProgram->addEntryPoint(specializedEntryPoint); + } + + return specializedProgram; } void checkTranslationUnit( TranslationUnitRequest* translationUnit) { SemanticsVisitor visitor( - &translationUnit->compileRequest->mSink, - translationUnit->compileRequest, - translationUnit); + translationUnit->compileRequest->getLinkage(), + translationUnit->compileRequest->getSink()); // Apply the visitor to do the main semantic // checking that is required on all declarations // in the translation unit. - visitor.checkDecl(translationUnit->SyntaxNode); + visitor.checkDecl(translationUnit->getModuleDecl()); } |
