diff options
47 files changed, 3294 insertions, 1905 deletions
diff --git a/examples/model-viewer/main.cpp b/examples/model-viewer/main.cpp index c1f7980cf..8bbc8c979 100644 --- a/examples/model-viewer/main.cpp +++ b/examples/model-viewer/main.cpp @@ -881,22 +881,27 @@ RefPtr<EffectVariant> createEffectVaraint( int translationUnitIndex = spAddTranslationUnit(slangRequest, SLANG_SOURCE_LANGUAGE_SLANG, nullptr); spAddTranslationUnitSourceFile(slangRequest, translationUnitIndex, program->shaderModule->inputPath.c_str()); - const int entryPointCont = int(program->entryPoints.size()); - for(int ii = 0; ii < entryPointCont; ++ii) + // Because our shader code uses global generic parameters for + // specialization, we need to specify the concrete argument + // types for the compiler to use when generating code. + // + spSetGlobalGenericArgs( + slangRequest, + int(genericArgs.size()), + genericArgs.data()); + + // Next we tell the Slang compiler about all of the entry points + // we plan to use. + // + const int entryPointCount = int(program->entryPoints.size()); + for(int ii = 0; ii < entryPointCount; ++ii) { auto entryPoint = program->entryPoints[ii]; - - // We are using the `spAddEntryPointEx` API so that we - // can specify the type names to use for the generic - // type parameters of the program. - // - spAddEntryPointEx( + spAddEntryPoint( slangRequest, translationUnitIndex, entryPoint->name.c_str(), - entryPoint->slangStage, - int(genericArgs.size()), - genericArgs.data()); + entryPoint->slangStage); } // We expect compilation to go through without a hitch, because the @@ -923,7 +928,7 @@ RefPtr<EffectVariant> createEffectVaraint( // std::vector<ISlangBlob*> kernelBlobs; std::vector<gfx::ShaderProgram::KernelDesc> kernelDescs; - for(int ii = 0; ii < entryPointCont; ++ii) + for(int ii = 0; ii < entryPointCount; ++ii) { auto entryPoint = program->entryPoints[ii]; diff --git a/external/glslang b/external/glslang -Subproject f6e7c4d2de0d59724ea07739df70c466d169a2c +Subproject 4207c97b938078818140edad101a032cf768191 @@ -923,6 +923,9 @@ extern "C" */ typedef struct SlangSession SlangSession; + typedef struct SlangLinkage SlangLinkage; + typedef struct SlangModule SlangModule; + /*! @brief A request for one or more compilation actions to be performed. */ @@ -989,6 +992,20 @@ extern "C" char const* sourcePath, char const* sourceString); + + + SLANG_API SlangLinkage* spCreateLinkage( + SlangSession* session); + + SLANG_API void spDestroyLinkage( + SlangLinkage* linkage); + + SLANG_API SlangModule* spLoadModule( + SlangLinkage* linkage, + char const* moduleName); + + + /*! @brief Create a compile request. */ @@ -1263,15 +1280,22 @@ extern "C" /** Add an entry point in a particular translation unit, with additional arguments that specify the concrete - type names for global generic type parameters. + type names for entry-point generic type parameters. */ SLANG_API int spAddEntryPointEx( SlangCompileRequest* request, int translationUnitIndex, char const* name, SlangStage stage, - int genericTypeNameCount, - char const** genericTypeNames); + int genericArgCount, + char const** genericArgs); + + /** Specify the arguments to use for global generic parameters. + */ + SLANG_API SlangResult spSetGlobalGenericArgs( + SlangCompileRequest* request, + int genericArgCount, + char const** genericArgs); /** Execute the compilation request. diff --git a/source/core/smart-pointer.h b/source/core/smart-pointer.h index e19ed6a4d..0b03deb8f 100644 --- a/source/core/smart-pointer.h +++ b/source/core/smart-pointer.h @@ -2,6 +2,7 @@ #define FUNDAMENTAL_LIB_SMART_POINTER_H #include "common.h" +#include "hash.h" #include "type-traits.h" #include <assert.h> @@ -157,10 +158,14 @@ namespace Slang releaseReference(old); } - int GetHashCode() - { - return (int)(long long)(void*)pointer; - } + int GetHashCode() + { + // Note: We need a `RefPtr<T>` to hash the same as a `T*`, + // so that a `T*` can be used as a key in a dictionary with + // `RefPtr<T>` keys, and vice versa. + // + return Slang::GetHashCode(pointer); + } bool operator==(const T * ptr) const { diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 3485afeea..483db60bb 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -400,22 +400,18 @@ namespace Slang else return DeclCheckState::CheckedHeader; } - DiagnosticSink* sink = nullptr; + + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + DiagnosticSink* getSink() { - return sink; + return m_sink; } // ModuleDecl * program = nullptr; FuncDecl * function = nullptr; - CompileRequest* request = nullptr; - TranslationUnitRequest* translationUnit = nullptr; - - SourceLanguage getSourceLanguage() - { - return translationUnit->sourceLanguage; - } // lexical outer statements List<Stmt*> outerStmts; @@ -429,20 +425,15 @@ namespace Slang public: SemanticsVisitor( - DiagnosticSink* sink, - CompileRequest* request, - TranslationUnitRequest* translationUnit) - : sink(sink) - , request(request) - , translationUnit(translationUnit) - { - } + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) + {} - CompileRequest* getCompileRequest() { return request; } - TranslationUnitRequest* getTranslationUnit() { return translationUnit; } Session* getSession() { - return getCompileRequest()->mSession; + return m_linkage->getSession(); } public: @@ -985,7 +976,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(decl->loc); + getSink()->noteInternalErrorLoc(decl->loc); throw; } } @@ -998,7 +989,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(stmt->loc); + getSink()->noteInternalErrorLoc(stmt->loc); throw; } } @@ -1011,7 +1002,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(expr->loc); + getSink()->noteInternalErrorLoc(expr->loc); throw; } } @@ -1030,7 +1021,7 @@ namespace Slang // being checked on the stack, so that we can report the full // chain that leads from this declaration back to itself. // - sink->diagnose(decl, Diagnostics::cyclicReference, decl); + getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); return; } @@ -1050,7 +1041,7 @@ namespace Slang // TODO: This diagnostic should be emitted on the line that is referencing // the declaration. That requires `EnsureDecl` to take the requesting // location as a parameter. - sink->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); + getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); return; } } @@ -3019,7 +3010,7 @@ namespace Slang checkDecl(func); } - if (sink->GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return; // Force everything to be fully checked, just in case @@ -4921,9 +4912,12 @@ namespace Slang return new ConstantIntVal(expr->value); } + Linkage* getLinkage() { return m_linkage; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + Name* getName(String const& text) { - return getCompileRequest()->getNamePool()->getName(text); + return getNamePool()->getName(text); } RefPtr<IntVal> TryConstantFoldExpr( @@ -5079,58 +5073,18 @@ namespace Slang { auto varDecl = varRef.getDecl(); - switch(getSourceLanguage()) + // In HLSL, `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) { - default: - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - // HLSL: `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->FindModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(varRef)) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - break; - - case SourceLanguage::GLSL: - // GLSL: `const` indicates compile-time constant expression - // - // TODO(tfoley): The current logic here isn't robust against - // GLSL "specialization constants" - we will extract the - // initializer for a `const` variable and use it to extract - // a value, when we really should be using an opaque - // reference to the variable. if(auto constAttr = varDecl->FindModifier<ConstModifier>()) { - // We need to handle a "specialization constant" (with a `constant_id` layout modifier) - // differently from an ordinary compile-time constant. The latter can/should be reduced - // to a value, while the former should be kept as a symbolic reference - - if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) - { - // Retain the specialization constant as a symbolic reference - // - // TODO(tfoley): handle the case of non-`int` value parameters... - // - // TODO(tfoley): this is cloned from the case above that handles generic value parameters - return new GenericParamIntVal(varRef); - } - else if(auto initExpr = getInitExpr(varRef)) + // HLSL `static const` can be used as a constant expression + if(auto initExpr = getInitExpr(varRef)) { - // This is an ordinary constant, and not a specialization constant, so we - // can try to fold its value right now. return TryConstantFoldExpr(initExpr.Ptr()); } } - break; } - } else if(auto enumRef = declRef.as<EnumCaseDecl>()) { @@ -9060,17 +9014,32 @@ namespace Slang auto scope = decl->scope; // Try to load a module matching the name - auto importedModuleDecl = findOrImportModule(request, name, decl->moduleNameAndLoc.loc); + auto importedModule = findOrImportModule( + getLinkage(), + name, + decl->moduleNameAndLoc.loc, + getSink()); // If we didn't find a matching module, then bail out - if (!importedModuleDecl) + if (!importedModule) return; // Record the module that was imported, so that we can use // it later during code generation. + auto importedModuleDecl = importedModule->getModuleDecl(); decl->importedModuleDecl = importedModuleDecl; - importModuleIntoScope(scope.Ptr(), importedModuleDecl.Ptr()); + // Add the declarations from the imported module into the scope + // that the `import` declaration is set to extend. + // + importModuleIntoScope(scope.Ptr(), importedModuleDecl); + + // Record the `import`ed module (and everything it depends on) + // as a dependency of the module we are compiling. + if(auto module = getModule(decl)) + { + module->addModuleDependency(importedModule); + } decl->SetCheckState(getCheckedState()); } @@ -9142,29 +9111,25 @@ namespace Slang return (!decl->primaryDecl) || (decl == decl->primaryDecl); } - RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp) + RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink) { - RefPtr<Type> type; - DiagnosticSink nSink; - nSink.sourceManager = tu->compileRequest->sourceManager; SemanticsVisitor visitor( - &nSink, - tu->compileRequest, - tu); + linkage, + sink); auto typeOut = visitor.CheckProperType(typeExp); - if (!nSink.errorCount) - { - type = typeOut.type; - } - return type; + return typeOut.type; } - FuncDecl* findFunctionDeclByName(EntryPointRequest* entryPoint, Name* name) + FuncDecl* findFunctionDeclByName( + Module* translationUnit, + Name* name, + DiagnosticSink* sink) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + auto translationUnitSyntax = translationUnit->getModuleDecl(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9270,7 +9235,9 @@ namespace Slang // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( - EntryPointRequest* entryPoint) + FuncDecl* entryPointFuncDecl, + Stage stage, + DiagnosticSink* sink) { // TODO: We currently do minimal checking here, but this is the // right place to perform the following validation checks: @@ -9297,28 +9264,32 @@ namespace Slang // that function is specific to the fragment profile/stage. // - auto sink = &entryPoint->compileRequest->mSink; + auto entryPointName = entryPointFuncDecl->getName(); + + auto module = getModule(entryPointFuncDecl); + auto linkage = module->getLinkage(); + // Every entry point needs to have a stage specified either via // command-line/API options, or via an explicit `[shader("...")]` attribute. // - if( entryPoint->getStage() == Stage::Unknown ) + if( stage == Stage::Unknown ) { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name); + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); } - if (entryPoint->getStage() == Stage::Hull) + if( stage == Stage::Hull ) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` + // attributes, so that they need to resolve to a function. - auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>(); + auto attr = entryPointFuncDecl->FindModifier<PatchConstantFuncAttribute>(); if (attr) { if (attr->args.Count() != 1) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } @@ -9327,40 +9298,52 @@ namespace Slang if (!stringLit) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } - Name* name = entryPoint->compileRequest->getNamePool()->getName(stringLit->value); - FuncDecl* funcDecl = findFunctionDeclByName(entryPoint, name); - if (!funcDecl) + // We look up the patch-constant function by its name in the module + // scope of the translation unit that declared the HS entry point. + // + // TODO: Eventually we probably want to do the lookup in the scope + // of the parent declarations of the entry point. E.g., if the entry + // point is a member function of a `struct`, then its patch-constant + // function should be allowed to be another member function of + // the same `struct`. + // + // In the extremely long run we may want to support an alternative to + // this attribute-based linkage between the two functions that + // make up the entry point. + // + Name* name = linkage->getNamePool()->getName(stringLit->value); + FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( + module, + name, + sink); + if (!patchConstantFuncDecl) { - sink->diagnose(translationUnitSyntax, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); + sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); return; } - attr->patchConstantFuncDecl = funcDecl; + attr->patchConstantFuncDecl = patchConstantFuncDecl; } } - else if (entryPoint->getStage() == Stage::Compute) + else if(stage == Stage::Compute) { - auto funcDecl = entryPoint->getFuncDecl(); - - auto params = funcDecl->GetParameters(); - - for (const auto& param : params) + for(const auto& param : entryPointFuncDecl->GetParameters()) { - if (auto semantic = param->FindModifier<HLSLSimpleSemantic>()) + if(auto semantic = param->FindModifier<HLSLSimpleSemantic>()) { const auto& semanticToken = semantic->name; String lowerName = String(semanticToken.Content).ToLower(); - if (lowerName == "sv_dispatchthreadid") + if(lowerName == "sv_dispatchthreadid") { Type* paramType = param->getType(); - if (!isValidThreadDispatchIDType(paramType)) + if(!isValidThreadDispatchIDType(paramType)) { String typeString = paramType->ToString(); sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); @@ -9372,26 +9355,30 @@ namespace Slang } } - // Given an `EntryPointRequest` specified via API or command line options, + // Given an entry point specified via API or command line options, // attempt to find a matching AST declaration that implements the specified // entry point. If such a function is found, then validate that it actually // meets the requirements for the selected stage/profile. // - void findAndValidateEntryPoint( - EntryPointRequest* entryPoint) + // Returns an `EntryPoint` object representing the (unspecialized) + // entry point if it is found and validated, and null otherwise. + // + RefPtr<EntryPoint> findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq) { // The first step in validating the entry point is to find // the (unique) function declaration that matches its name. // - // TODO: We will eventually need to update this logic - // to work by parsing the provided `entryPoint->name` string - // as an expression, so that we can handle more complex - // names like `foo<int>` or `SomeType.vs`. - - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We may eventually want/need to extend this to + // account for nested names like `SomeStruct.vsMain`, or + // indeed even to handle generics. + // + auto compileRequest = entryPointReq->getCompileRequest(); + auto translationUnit = entryPointReq->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto translationUnitSyntax = translationUnit->getModuleDecl(); + auto entryPointName = entryPointReq->getName(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9399,12 +9386,12 @@ namespace Slang // We will look up any global-scope declarations in the translation // unit that match the name of our entry point. Decl* firstDeclWithName = nullptr; - if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, firstDeclWithName) ) + if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) { // If there doesn't appear to be any such declaration, then // we need to diagnose it as an error, and then bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPoint->name); - return; + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName); + return nullptr; } // We found at least one global-scope declaration with the right name, @@ -9448,7 +9435,7 @@ namespace Slang // name before, so the whole thing is ambiguous. We need // to diagnose and bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPoint->name); + sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName); // List all of the declarations that the user *might* mean for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) @@ -9460,7 +9447,7 @@ namespace Slang } // Bail out. - return; + return nullptr; } } } @@ -9471,127 +9458,197 @@ namespace Slang // If not, then we need to diagnose the error. // For convenience, we will point to the first // declaration with the right name, that wasn't a function. - sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPoint->name); - return; + sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName); + return nullptr; } + // TODO: it is possible that the entry point was declared with + // profile or target overloading. Is there anything that we need + // to do at this point to filter out declarations that aren't + // relevant to the selected profile for the entry point? + + // We found something, and can start doing some basic checking. + // // If the entry point specifies a stage via a `[shader("...")]` attribute, // then we might be able to infer a stage for the entry point request if // it didn't have one, *or* issue a diagnostic if there is a mismatch. // + auto entryPointProfile = entryPointReq->getProfile(); if( auto entryPointAttribute = entryPointFuncDecl->FindModifier<EntryPointAttribute>() ) { - if( entryPoint->getStage() == Stage::Unknown ) + auto entryPointStage = entryPointProfile.GetStage(); + if( entryPointStage == Stage::Unknown ) { - entryPoint->profile.setStage(entryPointAttribute->stage); + entryPointProfile.setStage(entryPointAttribute->stage); } - else if( entryPointAttribute->stage != entryPoint->getStage() ) + else if( entryPointAttribute->stage != entryPointStage ) { - sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPoint->name, entryPoint->getStage(), entryPointAttribute->stage); + sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage); } } + else + { + // TODO: Should we attach a `[shader(...)]` attribute to an + // entry point that didn't have one, so that we can have + // a more uniform representation in the AST? + } - // TODO: it is possible that the entry point was declared with - // profile or target overloading. Is there anything that we need - // to do at this point to filter out declarations that aren't - // relevant to the selected profile for the entry point? - // Phew, we have at least found a suitable decl. - // Let's record that in the entry-point request so - // that we don't have to re-do this effort again later. - // - // Note: we may replace the decl-ref we store at this point - // later in this function, when we (potentially) specialize - // a generic entry point to generic arguments provided - // via the API. + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. // - entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl); + validateEntryPoint( + entryPointFuncDecl, + entryPointProfile.GetStage(), + sink); - // If the user specified generic arguments for the entry point, - // then we will want to parse those arguments as expressions - // in a scope that includes the tanslation unit that holds - // the entry point, as well as any other modules that got - // transitively loaded via `import`. - // - // TODO: This would be better handled by giving the user - // more explicit ways to parse/build types at the API level, - // rather than keeping things string-based this far along. + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(entryPointFuncDecl), + entryPointProfile); + + return entryPoint; + } + + /// Create a `Program` to represent the compiled code. + /// + /// The created program will comprise all of the translation + /// units that were compiled as part of the request, as + /// well as any entry points in those translation units. + /// + RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest) + { + // We want our resulting program to depend on + // all the translation units the user specified, + // even if some of them don't contain entry points + // (this is important for parameter layout/binding). // - // TODO: Building a list of `scopesToTry` here shouldn't - // be required, since the `Scope` type itself has the ability - // for form chains for lookup purposes (e.g., the way that - // `import` is handled by modifying a scope). + // We also want to ensure that the modules for the + // translation units comes first in the enumerated + // order for dependencies, to match the pre-existing + // compiler behavior (at least for now). // - List<RefPtr<Scope>> scopesToTry; - scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); - for (auto & module : entryPoint->compileRequest->loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + auto program = new Program(linkage); + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedLeafModule(translationUnit->getModule()); + } + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedModule(translationUnit->getModule()); + } - // We are going to do some semantic checking, so we need to - // set up a `SemanticsVistitor` that we can use. - // - SemanticsVisitor semantics( - &entryPoint->compileRequest->mSink, - entryPoint->compileRequest, - entryPoint->getTranslationUnit()); - // We will be looping over the generic argument strings - // that the user provided via the API (or command line), - // and parsing+checking each into an `Expr`. + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. // - // This loop will *not* handle coercing the arguments - // to be types. + // TODO: We may want to make this choice explicit rather than implicit. // - List<RefPtr<Expr>> genericArgs; - for (auto name : entryPoint->genericArgStrings) + // First, check if the user requested any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + + if( anyExplicitEntryPoints ) { - RefPtr<Expr> argExpr; - for (auto & s : scopesToTry) + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + // + for(auto entryPointReq : compileRequest->getEntryPointReqs()) { - argExpr = entryPoint->compileRequest->parseTypeString( - entryPoint->getTranslationUnit(), - name, - s); - argExpr = semantics.CheckTerm(argExpr); - if( argExpr ) + auto entryPoint = findAndValidateEntryPoint( + entryPointReq); + if( entryPoint ) { - break; + program->addEntryPoint(entryPoint); + entryPointReq->getTranslationUnit()->entryPoints.Add(entryPoint); } } - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a module (and the stuff it imports). + // TODO: We should consider always processing both categories, + // and just making sure to only check each entry point function + // declaration once... + } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPoint`s to + // represent them. // - // The Right Way to handle this would probably be to have each `ModuleDecl` track - // any tagged union types that get created in the context of that module, and - // then combine those lists later. + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point request. + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. // - if( auto typeType = as<TypeType>(argExpr->type) ) + UInt translationUnitCount = compileRequest->translationUnits.Count(); + for(UInt tt = 0; tt < translationUnitCount; ++tt) { - auto type = typeType->type; - if( auto taggedUnionType = as<TaggedUnionType>(type) ) + auto translationUnit = compileRequest->translationUnits[tt]; + for( auto globalDecl : translationUnit->getModuleDecl()->Members ) { - entryPoint->taggedUnionTypes.Add(taggedUnionType); + auto maybeFuncDecl = globalDecl; + if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) + { + maybeFuncDecl = genericDecl->inner; + } + + auto funcDecl = as<FuncDecl>(maybeFuncDecl); + if(!funcDecl) + continue; + + auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); + if(!entryPointAttr) + continue; + + // We've discovered a valid entry point. It is a function (possibly + // generic) that has a `[shader(...)]` attribute to mark it as an + // entry point. + // + // We will now register that entry point as an `EntryPoint` + // with an appropriately chosen profile. + // + // The profile will only include a stage, so that the profile "family" + // and "version" are left unspecified. Downstream code will need + // to be able to handle this case. + // + Profile profile; + profile.setStage(entryPointAttr->stage); + + validateEntryPoint(funcDecl, entryPointAttr->stage, sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(funcDecl), + profile); + program->addEntryPoint(entryPoint); + translationUnit->entryPoints.Add(entryPoint); } } - - genericArgs.Add(argExpr); } - // There are two cases we care about here, and we are going to treat them - // as mutually exclusive for simplicity. - // - // The first case is when the entry point function is itself generic, - // in which case we will assume that `genericArgs` lines up one-to-one - // with the explicit generic parameters of the entry point. - // + return program; + } + + /// Create a specialization an existing entry point based on generic arguments. + DeclRef<FuncDecl> specializeEntryPoint( + Linkage* linkage, + FuncDecl* entryPointFuncDecl, + List<RefPtr<Expr>> const& genericArgs, + DiagnosticSink* sink) + { + SemanticsVisitor semantics( + linkage, + sink); + + DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl); if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) { // We will construct a suitable `GenericAppExpr` to represent @@ -9601,7 +9658,7 @@ namespace Slang // generic application like `F<A,B,C>` if it were // encountered in the source code. - auto session = entryPoint->compileRequest->mSession; + auto session = linkage->getSession(); auto genericDeclRef = makeDeclRef(genericDecl); // The first pieces is a `VarExpr` that refers to `genericDecl`. @@ -9639,12 +9696,13 @@ namespace Slang // The basic `VarExpr` and `StaticMemberExpr` cases // should be allow-able. - entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>(); + entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); } else if( semantics.IsErrorExpr(checkedExpr) ) { // Any semantic error that occured should have been // reported already. + return DeclRef<FuncDecl>(); } else { @@ -9652,302 +9710,359 @@ namespace Slang // function should always be a `DeclRefExpr` // SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + UNREACHABLE_RETURN(DeclRef<FuncDecl>()); } } - else - { - // The other case is when the entry point function is *not* itself - // generic, so we assume that any generic arguments must have been intended - // to match up with global generic parameters instead. - // - // We will only validate global generic type arguments when we are going - // to generate code, since in a no-codegen pass we will typically *not* - // have arguments to associate with the parameters. - // - if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) - { - // check that user-provioded type arguments conforms to the generic type - // parameter declaration of this translation unit - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - - if (globalGenericParams.Count() != genericArgs.Count()) - { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument, - globalGenericParams.Count(), - genericArgs.Count()); - return; - } - - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. - // - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick<A>; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - UInt argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < genericArgs.Count()); - auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex])); - if (!globalGenericArg) - { - sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]); - return; - } + return entryPointFuncDeclRef; + } - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - } - } + /// Parse an array of strings as generic arguments. + /// + /// Names in the strings will be parsed in the context of + /// the code loaded into the given compile request. + /// + void parseGenericArgStrings( + EndToEndCompileRequest* endToEndReq, + List<String> const& genericArgStrings, + List<RefPtr<Expr>>& outGenericArgs) + { + auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // + List<RefPtr<Scope>> scopesToTry; + for( auto module : unspecialiedProgram->getModuleDependencies() ) + scopesToTry.Add(module->getModuleDecl()->scope); - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + SemanticsVisitor semantics( + linkage, + sink); - // Use our semantic-checking logic to search for a witness to the required conformance - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + for(auto name : genericArgStrings) + { + RefPtr<Expr> argExpr; + for (auto & s : scopesToTry) + { + argExpr = linkage->parseTypeString(name, s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) + { + break; + } + } - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.Add(constraintArg); - } + outGenericArgs.Add(argExpr); + } + } - // Add the substitution for this parameter to the global substitution - // set that we are building. + /// Specialize a program to global generic arguments + RefPtr<Program> createSpecializedProgram( + Linkage* linkage, + Program* unspecializedProgram, + List<RefPtr<Expr>> const& globalGenericArgs, + DiagnosticSink* sink) + { + // The given `unspecializedProgram` should be one that + // was checked through the front-end, so that now we + // only need to check if the given arguments can satisfy + // the requirements of the global generic parameters. + // + // The new program needs to start off with the same + // module dependency list as the original. + // + RefPtr<Program> specializedProgram = new Program(linkage); + for(auto module : unspecializedProgram->getModuleDependencies()) + { + specializedProgram->addReferencedLeafModule(module); + } - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } - entryPoint->globalGenericSubst = globalGenericSubsts; - } + // We will collect all the global generic parameters + // defined in the modules being referenced, to find + // the global generic parameter signature of the + // program. + // + // TODO: Note that this doesn't handle the case where one + // or more of the type *arguments* that we are specifying + // ends up requiring additional modules to be referenced, + // which might in turn introduce new global generic parameters. + // + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + for(auto module : unspecializedProgram->getModuleDependencies()) + { + for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>()) + globalGenericParams.Add(param); } - // If any errors occured while we were checking the generic arguments - // of the entry point, then we should bail out rather than try to - // perform the next step of validation. + // Next, we will check whether the supplied arguments can + // satisfy those parameters. + // + // An easy early-out case will be if the number of + // arguments isn't correct. // - if (sink->errorCount != 0) - return; + if (globalGenericParams.Count() != globalGenericArgs.Count()) + { + sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, + globalGenericParams.Count(), + globalGenericArgs.Count()); + return nullptr; + } - // Now that we've *found* the entry point, it is time to validate - // that it actually meets the constraints for the chosen stage/profile. + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. // - // TODO: This validation should (probably?) be performed "under" any global generic - // parameter substitution we might have created, so that we can validate - // based on knowledge of actual types. + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. // - validateEntryPoint(entryPoint); - } - - void validateEntryPoints( - CompileRequest* compileRequest) - { - // The validation of entry points here will be modal, and controlled - // by whether the user specified any entry points directly via - // API or command-line options. + RefPtr<Substitutions> globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; // - // TODO: We may want to make this choice explicit rather than implicit. + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: // - // First, check if the user request any entry points explicitly via - // the API or command line. + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). // - bool anyExplicitEntryPointRequests = false; - for (auto& translationUnit : compileRequest->translationUnits) + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) { - if( translationUnit->entryPoints.Count() != 0) + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < globalGenericArgs.Count()); + auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); + if (!globalGenericArg) { - anyExplicitEntryPointRequests = true; - break; + sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); + return nullptr; } - } - if( anyExplicitEntryPointRequests ) - { - // If there were any explicit requests for entry points to be - // checked, then we will *only* check those. - - for (auto& translationUnit : compileRequest->translationUnits) + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) { - for (auto entryPoint : translationUnit->entryPoints) + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { - findAndValidateEntryPoint(entryPoint); + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + continue; + } } } - } - else - { - // Otherwise, scan for any `[shader(...)]` attributes in - // the user's code, and construct `EntryPointRequest`s to - // represent them. - // - // This ensures that downstream code only has to consider - // the central list of entry point requests, and doesn't - // have to know where they came from. - // TODO: A comprehensive approach here would need to search - // recursively for entry points, because they might appear - // as, e.g., member function of a `struct` type. - // - // For now we'll start with an extremely basic approach that - // should work for typical HLSL code. - // - UInt translationUnitCount = compileRequest->translationUnits.Count(); - for(UInt tt = 0; tt < translationUnitCount; ++tt) + // Create a substitution for this parameter/argument. + RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) { - auto translationUnit = compileRequest->translationUnits[tt]; - for( auto globalDecl : translationUnit->SyntaxNode->Members ) + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + SemanticsVisitor visitor(linkage, sink); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) { - auto maybeFuncDecl = globalDecl; - if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) - { - maybeFuncDecl = genericDecl->inner; - } + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } - auto funcDecl = as<FuncDecl>(maybeFuncDecl); - if(!funcDecl) - continue; + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); + } - auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); - if(!entryPointAttr) - continue; + // Add the substitution for this parameter to the global substitution + // set that we are building. - // We've discovered a valid entry point. It is a function (possibly - // generic) that has a `[shader(...)]` attribute to mark it as an - // entry point. - // - // We will now register that entry point as an `EntryPointRequest` - // with an appropriately chosen profile. - // - // The profile will only include a stage, so that the profile "family" - // and "version" are left unspecified. Downstream code will need - // to be able to handle this case. - // - Profile profile; - profile.setStage(entryPointAttr->stage); + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } + if(sink->GetErrorCount()) + return nullptr; - // We manually fill in the entry point request object. - RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); - entryPointReq->compileRequest = compileRequest; - entryPointReq->translationUnitIndex = int(tt); - entryPointReq->funcDeclRef = makeDeclRef(funcDecl); - entryPointReq->name = funcDecl->getName(); - entryPointReq->profile = profile; + specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - // Apply the common validation logic to this entry point. - validateEntryPoint(entryPointReq); + return specializedProgram; + } - // Add the entry point to the list in the translation unit, - // and also the global list in the compile request. - translationUnit->entryPoints.Add(entryPointReq); - compileRequest->entryPoints.Add(entryPointReq); - } - } + /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// + /// If the end-to-end compile request included generic argument strings + /// for this entry point, then they will be parsed, checked, and used + /// as arguments to the generic entry point. + /// + /// Returns a specialized entry point if everything worked as expected. + /// Returns null and diagnoses errors if anything goes wrong. + /// + RefPtr<EntryPoint> specializeEntryPoint( + EndToEndCompileRequest* endToEndReq, + EntryPoint* unspecializedEntryPoint, + EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) + { + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + + // If the user specified generic arguments for the entry point, + // then we will need to parse the arguments first. + // + List<RefPtr<Expr>> genericArgs; + parseGenericArgStrings( + endToEndReq, + entryPointInfo.genericArgStrings, + genericArgs); + + // Next we specialize the entry point function given the parsed + // generic argument expressions. + // + auto entryPointFuncDeclRef = specializeEntryPoint( + linkage, + entryPointFuncDecl, + genericArgs, + sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + entryPointFuncDeclRef, + unspecializedEntryPoint->getProfile()); + + return entryPoint; + } + + /// Create a specialized program based on the given compile request. + /// + RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq) + { + // The compile request must have already completed front-end processing, + // so that we have an unspecialized program available, and now only need + // to parse and check any generic arguments that are being supplied for + // global or entry-point generic parameters. + // + auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + + // First, let's parse the generic argument strings that were + // provided via the API, so taht we can match them + // against what was declared in the program. + // + List<RefPtr<Expr>> globalGenericArgs; + parseGenericArgStrings( + endToEndReq, + endToEndReq->globalGenericArgStrings, + globalGenericArgs); + + // Now we create the initial specialized program by + // applying the global generic arguments (if any) to the + // unspecialized program. + // + auto specializedProgram = createSpecializedProgram( + endToEndReq->getLinkage(), + unspecializedProgram, + globalGenericArgs, + endToEndReq->getSink()); + + // If anything went wrong with the global generic + // arguments, then bail out now. + // + if(!specializedProgram) + return nullptr; + + // Next we will deal with the entry points for the + // new specialized program. + // + // If the user specified explicit entry points as part of the + // end-to-end request, then we only want to process those (and + // ignore any other `[shader(...)]`-attributed entry points). + // + // However, if the user specified *no* entry points as part + // of the end-to-end request, then we would like to go + // ahead and consider all the entry points that were found + // by the front-end. + // + UInt entryPointCount = endToEndReq->entryPoints.Count(); + if( entryPointCount == 0 ) + { + entryPointCount = unspecializedProgram->getEntryPointCount(); + endToEndReq->entryPoints.SetSize(entryPointCount); } + + for( UInt ii = 0; ii < entryPointCount; ++ii ) + { + auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + + auto specializedEntryPoint = specializeEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + specializedProgram->addEntryPoint(specializedEntryPoint); + } + + return specializedProgram; } void checkTranslationUnit( TranslationUnitRequest* translationUnit) { SemanticsVisitor visitor( - &translationUnit->compileRequest->mSink, - translationUnit->compileRequest, - translationUnit); + translationUnit->compileRequest->getLinkage(), + translationUnit->compileRequest->getSink()); // Apply the visitor to do the main semantic // checking that is required on all declarations // in the translation unit. - visitor.checkDecl(translationUnit->SyntaxNode); + visitor.checkDecl(translationUnit->getModuleDecl()); } diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 21f56c9ee..3bc34692d 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -120,23 +120,123 @@ namespace Slang return blob; } - // EntryPointRequest + // + // FrontEndEntryPointRequest + // + + FrontEndEntryPointRequest::FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile) + : m_compileRequest(compileRequest) + , m_translationUnitIndex(translationUnitIndex) + , m_name(name) + , m_profile(profile) + {} - TranslationUnitRequest* EntryPointRequest::getTranslationUnit() + + TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() { - return compileRequest->translationUnits[translationUnitIndex].Ptr(); + return getCompileRequest()->translationUnits[m_translationUnitIndex]; } - DeclRef<FuncDecl> EntryPointRequest::getFuncDeclRef() + // + // EntryPoint + // + + RefPtr<EntryPoint> EntryPoint::create( + DeclRef<FuncDecl> funcDeclRef, + Profile profile) { - return funcDeclRef; + RefPtr<EntryPoint> entryPoint = new EntryPoint( + funcDeclRef.GetName(), + profile, + funcDeclRef); + return entryPoint; } - RefPtr<FuncDecl> EntryPointRequest::getFuncDecl() + RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough( + Name* name, + Profile profile) { - return getFuncDeclRef().getDecl(); + RefPtr<EntryPoint> entryPoint = new EntryPoint( + name, + profile, + DeclRef<FuncDecl>()); + return entryPoint; } + EntryPoint::EntryPoint( + Name* name, + Profile profile, + DeclRef<FuncDecl> funcDeclRef) + : m_name(name) + , m_profile(profile) + , m_funcDeclRef(funcDeclRef) + { + // In order for later code generation to work, we need to track what + // modules each entry point depends on. We will build up the dependency + // list here when an `EntryPoint` gets created. + // + // We know an entry point depends on the module that declared the + // entry-point function itself. + // + // Note: we are carefully handling the case where `module` could + // be null, becase of "dummy" entry points created for pass-through + // compilation. + // + if(auto module = getModule()) + { + m_dependencyList.addDependency(module); + } + // + // TODO: We also need to include the modules needed by any generic + // arguments in the dependency list, since in the general case they + // might come from modules other than the one defining the entry point. + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) + { + if( auto genericSubst = as<GenericSubstitution>(subst) ) + { + for( auto arg : genericSubst->args ) + { + if( auto taggedUnionType = as<TaggedUnionType>(arg) ) + { + m_taggedUnionTypes.Add(taggedUnionType); + } + } + } + } + } + + Module* EntryPoint::getModule() + { + return Slang::getModule(getFuncDecl()); + } + + Linkage* EntryPoint::getLinkage() + { + return getModule()->getLinkage(); + } // @@ -279,13 +379,35 @@ namespace Slang // + /// If there is a pass-through compile going on, find the translation unit for the given entry point. + TranslationUnitRequest* findPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + // If there isn't an end-to-end compile going on, + // there can be no pass-through. + // + if(!endToEndReq) return nullptr; + + // And if pass-through isn't set, we don't need + // access to the translation unit. + // + if(endToEndReq->passThrough == PassThroughMode::None) return nullptr; + + auto frontEndReq = endToEndReq->getFrontEndReq(); + auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); + auto translationUnit = entryPointReq->getTranslationUnit(); + return translationUnit; + } + String emitHLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - auto compileRequest = entryPoint->compileRequest; - auto translationUnit = entryPoint->getTranslationUnit(); - if (compileRequest->passThrough != PassThroughMode::None) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { // Generate a string that includes the content of // the source file(s), along with a line directive @@ -294,7 +416,7 @@ namespace Slang // mode. StringBuilder codeBuilder; - for(auto sourceFile : translationUnit->sourceFiles) + for(auto sourceFile : translationUnit->getSourceFiles()) { codeBuilder << "#line 1 \""; @@ -323,21 +445,21 @@ namespace Slang else { return emitEntryPoint( + compileRequest, entryPoint, - targetReq->layout.Ptr(), CodeGenTarget::HLSL, targetReq); } } String emitGLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - auto compileRequest = entryPoint->compileRequest; - auto translationUnit = entryPoint->getTranslationUnit(); - - if (compileRequest->passThrough != PassThroughMode::None) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { // Generate a string that includes the content of // the source file(s), along with a line directive @@ -347,7 +469,7 @@ namespace Slang StringBuilder codeBuilder; int translationUnitCounter = 0; - for(auto sourceFile : translationUnit->sourceFiles) + for(auto sourceFile : translationUnit->getSourceFiles()) { int translationUnitIndex = translationUnitCounter++; @@ -370,8 +492,8 @@ namespace Slang // TODO(tfoley): need to pass along the entry point // so that we properly emit it as the `main` function. return emitEntryPoint( + compileRequest, entryPoint, - targetReq->layout.Ptr(), CodeGenTarget::GLSL, targetReq); } @@ -484,9 +606,9 @@ namespace Slang sink->diagnoseRaw(SLANG_FAILED(res) ? Severity::Error : Severity::Warning, builder.getUnownedSlice()); } - static String _getDisplayPath(const DiagnosticSink& sink, SourceFile* sourceFile) + static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) { - if (sink.flags & DiagnosticSink::Flag::VerbosePath) + if (sink->flags & DiagnosticSink::Flag::VerbosePath) { return sourceFile->calcVerbosePath(); } @@ -496,17 +618,17 @@ namespace Slang } } - String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest) + String calcSourcePathForEntryPoint( + EndToEndCompileRequest* endToEndReq, + UInt entryPointIndex) { - CompileRequest* compileRequest = translationUnitRequest->compileRequest; - if (compileRequest->passThrough == PassThroughMode::None) - { + auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); + if(!translationUnitRequest) return "slang-generated"; - } - auto& sink = translationUnitRequest->compileRequest->mSink; + auto sink = endToEndReq->getSink(); - const auto& sourceFiles = translationUnitRequest->sourceFiles; + const auto& sourceFiles = translationUnitRequest->getSourceFiles(); const int numSourceFiles = int(sourceFiles.Count()); @@ -542,22 +664,26 @@ namespace Slang } SlangResult emitDXBytecodeForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& byteCodeOut) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& byteCodeOut) { byteCodeOut.Clear(); - auto session = entryPoint->compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); - auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, &entryPoint->compileRequest->mSink); + auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, sink); if (!compileFunc) { return SLANG_FAIL; } - auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); + auto hlslCode = emitHLSLForEntryPoint(compileRequest, entryPoint, entryPointIndex, targetReq, endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); auto profile = getEffectiveProfile(entryPoint, targetReq); @@ -569,16 +695,16 @@ namespace Slang // List<D3D_SHADER_MACRO> dxMacrosStorage; D3D_SHADER_MACRO const* dxMacros = nullptr; - if( entryPoint->compileRequest->passThrough != PassThroughMode::None ) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { - for( auto& define : entryPoint->compileRequest->preprocessorDefinitions ) + for( auto& define : translationUnit->compileRequest->preprocessorDefinitions ) { D3D_SHADER_MACRO dxMacro; dxMacro.Name = define.Key.Buffer(); dxMacro.Definition = define.Value.Buffer(); dxMacrosStorage.Add(dxMacro); } - for( auto& define : entryPoint->getTranslationUnit()->preprocessorDefinitions ) + for( auto& define : translationUnit->preprocessorDefinitions ) { D3D_SHADER_MACRO dxMacro; dxMacro.Name = define.Key.Buffer(); @@ -616,7 +742,7 @@ namespace Slang flags |= D3DCOMPILE_ENABLE_STRICTNESS; flags |= D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES; - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = "slang-geneated";// calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); ComPtr<ID3DBlob> codeBlob; ComPtr<ID3DBlob> diagnosticsBlob; @@ -626,7 +752,7 @@ namespace Slang sourcePath.Buffer(), dxMacros, nullptr, - getText(entryPoint->name).begin(), + getText(entryPoint->getName()).begin(), GetHLSLProfileName(profile).Buffer(), flags, 0, // unused: effect flags @@ -640,23 +766,24 @@ namespace Slang if (FAILED(hr)) { - reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), &entryPoint->compileRequest->mSink); + reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), sink); } return hr; } SlangResult dissassembleDXBC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& assemOut) + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& assemOut) { assemOut = String(); - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); - auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, &compileRequest->mSink); + auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, sink); if (!disassembleFunc) { return SLANG_E_NOT_FOUND; @@ -677,25 +804,34 @@ namespace Slang if (FAILED(res)) { // TODO(tfoley): need to figure out what to diagnose here... - reportExternalCompileError("fxc", res, UnownedStringSlice(), &compileRequest->mSink); + reportExternalCompileError("fxc", res, UnownedStringSlice(), sink); } return res; } SlangResult emitDXBytecodeAssemblyForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - String& assemOut) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemOut) { List<uint8_t> dxbc; - SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint(entryPoint, targetReq, dxbc)); + SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + dxbc)); if (!dxbc.Count()) { return SLANG_FAIL; } - return dissassembleDXBC(entryPoint->compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut); + return dissassembleDXBC(compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut); } #endif @@ -704,26 +840,30 @@ namespace Slang // Implementations in `dxc-support.cpp` int emitDXILForEntryPointUsingDXC( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& outCode); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& outCode); SlangResult dissassembleDXILUsingDXC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut); + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut); #endif #if SLANG_ENABLE_GLSLANG_SUPPORT SlangResult invokeGLSLCompiler( - CompileRequest* slangCompileRequest, + BackEndCompileRequest* slangCompileRequest, glslang_CompileRequest& request) { - Session* session = slangCompileRequest->mSession; + Session* session = slangCompileRequest->getSession(); + auto sink = slangCompileRequest->getSink(); - auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, &slangCompileRequest->mSink); + auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, sink); if (!glslang_compile) { return SLANG_FAIL; @@ -743,7 +883,7 @@ SlangResult dissassembleDXILUsingDXC( if (err) { - reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), &slangCompileRequest->mSink); + reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), sink); return SLANG_FAIL; } @@ -751,10 +891,10 @@ SlangResult dissassembleDXILUsingDXC( } SlangResult dissassembleSPIRV( - CompileRequest* slangRequest, - void const* data, - size_t size, - String& stringOut) + BackEndCompileRequest* slangRequest, + void const* data, + size_t size, + String& stringOut) { stringOut = String(); @@ -782,21 +922,29 @@ SlangResult dissassembleDXILUsingDXC( } SlangResult emitSPIRVForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& spirvOut) + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& spirvOut) { spirvOut.Clear(); - String rawGLSL = emitGLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL); + String rawGLSL = emitGLSLForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(slangRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL); auto outputFunc = [](void const* data, size_t size, void* userData) { ((List<uint8_t>*)userData)->AddRange((uint8_t*)data, size); }; - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); glslang_CompileRequest request; request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV; @@ -809,40 +957,56 @@ SlangResult dissassembleDXILUsingDXC( request.outputFunc = outputFunc; request.outputUserData = &spirvOut; - SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(entryPoint->compileRequest, request)); + SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); return SLANG_OK; } SlangResult emitSPIRVAssemblyForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - String& assemblyOut) + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemblyOut) { List<uint8_t> spirv; - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint(entryPoint, targetReq, spirv)); + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + spirv)); if (spirv.Count() == 0) return SLANG_FAIL; - return dissassembleSPIRV(entryPoint->compileRequest, spirv.begin(), spirv.Count(), assemblyOut); + return dissassembleSPIRV(slangRequest, spirv.begin(), spirv.Count(), assemblyOut); } #endif // Do emit logic for a single entry point CompileResult emitEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { CompileResult result; - auto compileRequest = entryPoint->compileRequest; auto target = targetReq->target; switch (target) { case CodeGenTarget::HLSL: { - String code = emitHLSLForEntryPoint(entryPoint, targetReq); + String code = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); } @@ -850,7 +1014,12 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::GLSL: { - String code = emitGLSLForEntryPoint(entryPoint, targetReq); + String code = emitGLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); } @@ -860,7 +1029,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecode: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -871,7 +1046,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecodeAssembly: { String code; - if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); @@ -884,7 +1065,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXIL: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -895,7 +1082,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXILAssembly: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { String assembly; dissassembleDXILUsingDXC( @@ -915,7 +1108,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRV: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -926,7 +1125,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRVAssembly: { String code; - if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); @@ -957,16 +1162,16 @@ SlangResult dissassembleDXILUsingDXC( }; static void writeOutputFile( - CompileRequest* compileRequest, - FILE* file, - String const& path, - void const* data, - size_t size) + BackEndCompileRequest* compileRequest, + FILE* file, + String const& path, + void const* data, + size_t size) { size_t count = fwrite(data, size, 1, file); if (count != 1) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -974,16 +1179,16 @@ SlangResult dissassembleDXILUsingDXC( } static void writeOutputFile( - CompileRequest* compileRequest, - ISlangWriter* writer, - String const& path, - void const* data, - size_t size) + BackEndCompileRequest* compileRequest, + ISlangWriter* writer, + String const& path, + void const* data, + size_t size) { if (SLANG_FAILED(writer->write((const char*)data, size))) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -991,18 +1196,18 @@ SlangResult dissassembleDXILUsingDXC( } static void writeOutputFile( - CompileRequest* compileRequest, - String const& path, - void const* data, - size_t size, - OutputFileKind kind) + BackEndCompileRequest* compileRequest, + String const& path, + void const* data, + size_t size, + OutputFileKind kind) { FILE* file = fopen( path.Buffer(), kind == OutputFileKind::Binary ? "wb" : "w"); if (!file) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -1014,11 +1219,12 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResultToFile( - EntryPointRequest* entryPoint, + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, String const& outputPath, CompileResult const& result) { - auto compileRequest = entryPoint->compileRequest; + SLANG_UNUSED(entryPoint); switch (result.format) { @@ -1059,13 +1265,15 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResultToStandardOutput( - EntryPointRequest* entryPoint, + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, TargetRequest* targetReq, CompileResult const& result) { - auto compileRequest = entryPoint->compileRequest; + SLANG_UNUSED(entryPoint); ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); + auto backEndReq = compileRequest->getBackEndReq(); switch (result.format) { @@ -1087,7 +1295,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecode: { String assembly; - dissassembleDXBC(compileRequest, + dissassembleDXBC(backEndReq, data.begin(), data.end() - data.begin(), assembly); writeOutputToConsole(writer, assembly); @@ -1099,7 +1307,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXIL: { String assembly; - dissassembleDXILUsingDXC(compileRequest, + dissassembleDXILUsingDXC(backEndReq, data.begin(), data.end() - data.begin(), assembly); @@ -1111,7 +1319,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRV: { String assembly; - dissassembleSPIRV(compileRequest, + dissassembleSPIRV(backEndReq, data.begin(), data.end() - data.begin(), assembly); writeOutputToConsole(writer, assembly); @@ -1129,7 +1337,7 @@ SlangResult dissassembleDXILUsingDXC( writer->setMode(SLANG_WRITER_MODE_BINARY); writeOutputFile( - compileRequest, + backEndReq, writer, "stdout", data.begin(), @@ -1146,89 +1354,108 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResult( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - UInt entryPointIndex) + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + TargetRequest* targetReq, + Int entryPointIndex) { - // It is possible that we are dynamically discovering entry - // points (using `[shader(...)]` attributes), so that the - // number of entry points on the compile request does not - // match the number of entries in the `entryPointOutputPaths` - // array. - // - String outputPath; - if( entryPointIndex < targetReq->entryPointOutputPaths.Count() ) - { - outputPath = targetReq->entryPointOutputPaths[entryPointIndex]; - } + auto program = compileRequest->getSpecializedProgram(); + auto targetProgram = program->getTargetProgram(targetReq); + auto backEndReq = compileRequest->getBackEndReq(); - auto& result = targetReq->entryPointResults[entryPointIndex]; + auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); // Skip the case with no output if (result.format == ResultFormat::None) return; - if (outputPath.Length()) - { - writeEntryPointResultToFile(entryPoint, outputPath, result); - } - else + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) { - writeEntryPointResultToStandardOutput(entryPoint, targetReq, result); + String outputPath; + if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) + { + writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); + return; + } } + + writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); } void generateOutputForTarget( - TargetRequest* targetReq) + BackEndCompileRequest* compileReq, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - CompileRequest* compileReq = targetReq->compileRequest; + auto program = compileReq->getProgram(); + auto targetProgram = program->getTargetProgram(targetReq); // Generate target code any entry points that // have been requested for compilation. - for (auto& entryPoint : compileReq->entryPoints) + auto entryPointCount = program->getEntryPointCount(); + for(UInt ii = 0; ii < entryPointCount; ++ii) { - CompileResult entryPointResult = emitEntryPoint(entryPoint, targetReq); - targetReq->entryPointResults.Add(entryPointResult); + auto entryPoint = program->getEntryPoint(ii); + CompileResult entryPointResult = emitEntryPoint( + compileReq, + entryPoint, + ii, + targetReq, + endToEndReq); + targetProgram->setEntryPointResult(ii, entryPointResult); } } - void generateOutput( - CompileRequest* compileRequest) + static void _generateOutput( + BackEndCompileRequest* compileRequest, + EndToEndCompileRequest* endToEndReq) { // Go through the code-generation targets that the user // has specified, and generate code for each of them. // - for (auto targetReq : compileRequest->targets) + auto linkage = compileRequest->getLinkage(); + for (auto targetReq : linkage->targets) { - generateOutputForTarget(targetReq); + generateOutputForTarget(compileRequest, targetReq, endToEndReq); } + } + + void generateOutput( + BackEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest, nullptr); + } + + void generateOutput( + EndToEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest->getBackEndReq(), compileRequest); // If we are in command-line mode, we might be expected to actually // write output to one or more files here. if (compileRequest->isCommandLineCompile) { - for (auto targetReq : compileRequest->targets) + auto linkage = compileRequest->getLinkage(); + auto program = compileRequest->getSpecializedProgram(); + for (auto targetReq : linkage->targets) { - UInt entryPointCount = compileRequest->entryPoints.Count(); + UInt entryPointCount = program->getEntryPointCount(); for (UInt ee = 0; ee < entryPointCount; ++ee) { writeEntryPointResult( - compileRequest->entryPoints[ee], + compileRequest, + program->getEntryPoint(ee), targetReq, ee); } } - - if (compileRequest->containerOutputPath.Length() != 0) - { - auto& data = compileRequest->generatedBytecode; - writeOutputFile(compileRequest, - compileRequest->containerOutputPath, - data.begin(), - data.end() - data.begin(), - OutputFileKind::Binary); - } } } @@ -1237,7 +1464,7 @@ SlangResult dissassembleDXILUsingDXC( // void dumpIntermediate( - CompileRequest*, + BackEndCompileRequest*, void const* data, size_t size, char const* ext, @@ -1271,7 +1498,7 @@ SlangResult dissassembleDXILUsingDXC( } void dumpIntermediateText( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, char const* ext) @@ -1280,7 +1507,7 @@ SlangResult dissassembleDXILUsingDXC( } void dumpIntermediateBinary( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, char const* ext) @@ -1289,7 +1516,7 @@ SlangResult dissassembleDXILUsingDXC( } void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, CodeGenTarget target) @@ -1362,7 +1589,7 @@ SlangResult dissassembleDXILUsingDXC( } void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, char const* text, CodeGenTarget target) { diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 39199a62f..c975c1c2b 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -20,6 +20,8 @@ namespace Slang class CompileRequest; class ProgramLayout; class PtrType; + class TargetProgram; + class TargetRequest; class TypeLayout; enum class CompilerMode @@ -88,8 +90,12 @@ namespace Slang kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, }; - - class CompileRequest; + class Linkage; + class Module; + class Program; + class FrontEndCompileRequest; + class BackEndCompileRequest; + class EndToEndCompileRequest; class TranslationUnitRequest; // Result of compiling an entry point. @@ -112,18 +118,165 @@ namespace Slang ComPtr<ISlangBlob> blob; }; - // Describes an entry point that we've been requested to compile - class EntryPointRequest : public RefObject + /// A request for the front-end to find and validate an entry-point function + struct FrontEndEntryPointRequest : RefObject { public: + /// Create a request for an entry point. + FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile); + + /// Get the parent front-end compile request. + FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + + /// Get the translation unit that contains the entry point. + TranslationUnitRequest* getTranslationUnit(); + + /// Get the name of the entry point to find. + Name* getName() { return m_name; } + + /// Get the stage that the entry point is to be compiled for + Stage getStage() { return m_profile.GetStage(); } + + /// Get the profile that the entry point is to be compiled for + Profile getProfile() { return m_profile; } + + private: // The parent compile request - CompileRequest* compileRequest = nullptr; + FrontEndCompileRequest* m_compileRequest; + + // The index of the translation unit that will hold the entry point + int m_translationUnitIndex; + + // The name of the entry point function to look for + Name* m_name; + + // The profile to compile for (including stage) + Profile m_profile; + }; + + /// Tracks an ordered list of modules that something depends on. + struct ModuleDependencyList + { + public: + /// Get the list of modules that are depended on. + List<RefPtr<Module>> const& getModuleList() { return m_moduleList; } + + /// Add a module and everything it depends on to the list. + void addDependency(Module* module); + + /// Add a module to the list, but not the modules it depends on. + void addLeafDependency(Module* module); + + private: + void _addDependency(Module* module); + + List<RefPtr<Module>> m_moduleList; + HashSet<Module*> m_moduleSet; + }; + + /// Tracks an unordered list of filesystem paths that something depends on + struct FilePathDependencyList + { + public: + /// Get the list of paths that are depended on. + List<String> const& getFilePathList() { return m_filePathList; } + + /// Add a path to the list, if it is not already present + void addDependency(String const& path); + + /// Add all of the paths that `module` depends on to the list + void addDependency(Module* module); + + private: + + // TODO: We are using a `HashSet` here to deduplicate + // the paths so that we don't return the same path + // multiple times from `getFilePathList`, but because + // order isn't important, we could potentially do better + // in terms of memory (at some cost in performance) by + // just sorting the `m_filePathList` every once in + // a while and then deduplicating. + + List<String> m_filePathList; + HashSet<String> m_filePathSet; + }; + + /// Describes an entry point for the purposes of layout and code generation. + /// + /// This class also tracks any generic arguments to the entry point, + /// in the case that it is a specialization of a generic entry point. + /// + /// There is also a provision for creating a "dummy" entry point for + /// the purposes of pass-through compilation modes. Only the + /// `getName()` and `getProfile()` methods should be expected to + /// return useful data on pass-through entry points. + /// + class EntryPoint : public RefObject + { + public: + /// Create an entry point that refers to the given function. + static RefPtr<EntryPoint> create( + DeclRef<FuncDecl> funcDeclRef, + Profile profile); + + /// Get the function decl-ref, including any generic arguments. + DeclRef<FuncDecl> getFuncDeclRef() { return m_funcDeclRef; } + + /// Get the function declaration (without generic arguments). + RefPtr<FuncDecl> getFuncDecl() { return m_funcDeclRef.getDecl(); } + + /// Get the name of the entry point + Name* getName() { return m_name; } + + /// Get the profile associated with the entry point + /// + /// Note: only the stage part of the profile is expected + /// to contain useful data, but certain legacy code paths + /// allow for "shader model" information to come via this path. + /// + Profile getProfile() { return m_profile; } + + /// Get the stage that the entry point is for. + Stage getStage() { return m_profile.GetStage(); } + + /// Get the module that contains the entry point. + Module* getModule(); + + /// Get the linkage that contains the module for this entry pooint. + Linkage* getLinkage(); + + /// Get a list of modules that this entry point depends on. + /// + /// This will include the module that defines the entry point (see `getModule()`), + /// but may also include modules that are required by its generic type arguments. + /// + List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); } + + /// Get a list of tagged-union types referenced by the entry point's generic parameters. + List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + + /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. + static RefPtr<EntryPoint> createDummyForPassThrough( + Name* name, + Profile profile); + + private: + EntryPoint( + Name* name, + Profile profile, + DeclRef<FuncDecl> funcDeclRef); // The name of the entry point function (e.g., `main`) - Name* name; + // + Name* m_name = nullptr; - /// Source code for the generic arguments to use for the generic parameters of the entry point. - List<String> genericArgStrings; + // The declaration of the entry-point function itself. + // + DeclRef<FuncDecl> m_funcDeclRef; // The profile that the entry point will be compiled for // (this is a combination of the target stage, and also @@ -135,33 +288,13 @@ namespace Slang // from the target, while the stage part is all that is // intrinsic to the entry point. // - Profile profile; + Profile m_profile; - // Get the stage that the entry point is being compiled for. - Stage getStage() { return profile.GetStage(); } + // Any tagged union types that were referenced by the generic arguments of the entry point. + List<RefPtr<TaggedUnionType>> m_taggedUnionTypes; - // The index of the translation unit (within the parent - // compile request) that the entry point function is - // supposed to be defined in. - int translationUnitIndex; - - // The translation unit that this entry point came from - TranslationUnitRequest* getTranslationUnit(); - - // The declaration of the entry-point function itself. - // This will be filled in as part of semantic analysis; - // it should not be assumed to be available in cases - // where any errors were diagnosed. - // - DeclRef<FuncDecl> funcDeclRef; - - DeclRef<FuncDecl> getFuncDeclRef(); - RefPtr<FuncDecl> getFuncDecl(); - - RefPtr<Substitutions> globalGenericSubst; - - /// Any tagged union types that were referenced by the generic arguments of the entry point. - List<RefPtr<TaggedUnionType>> taggedUnionTypes; + // Modules the entry point depends on. + ModuleDependencyList m_dependencyList; }; enum class PassThroughMode : SlangPassThrough @@ -174,13 +307,78 @@ namespace Slang class SourceFile; - // A single translation unit requested to be compiled. - // + /// A module of code that has been compiled through the front-end + /// + /// A module comprises all the code from one translation unit (which + /// may span multiple Slang source files), and provides access + /// to both the AST and IR representations of that code. + /// + class Module : public RefObject + { + public: + /// Create a module (initially empty). + Module(Linkage* linkage); + + /// Get the parent linkage of this module. + Linkage* getLinkage() { return m_linkage; } + + /// Get the AST for the module (if it has been parsed) + ModuleDecl* getModuleDecl() { return m_moduleDecl; } + + /// The the IR for the module (if it has been generated) + IRModule* getIRModule() { return m_irModule; } + + /// Get the list of other modules this module depends on + List<RefPtr<Module>> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + + /// Get the list of filesystem paths this module depends on + List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } + + /// Register a module that this module depends on + void addModuleDependency(Module* module); + + /// Register a filesystem path that this module depends on + void addFilePathDependency(String const& path); + + /// Set the AST for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } + + /// Set the IR for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setIRModule(IRModule* irModule) { m_irModule = irModule; } + + private: + // The parent linkage + Linkage* m_linkage = nullptr; + + // The AST for the module + RefPtr<ModuleDecl> m_moduleDecl; + + // The IR for the module + RefPtr<IRModule> m_irModule = nullptr; + + // List of modules this module depends on + ModuleDependencyList m_moduleDependencyList; + + // List of filesystem paths this module depends on + FilePathDependencyList m_filePathDependencyList; + }; + typedef Module LoadedModule; + + /// A request for the front-end to compile a translation unit. class TranslationUnitRequest : public RefObject { public: + TranslationUnitRequest( + FrontEndCompileRequest* compileRequest); + // The parent compile request - CompileRequest* compileRequest = nullptr; + FrontEndCompileRequest* compileRequest = nullptr; // The language in which the source file(s) // are assumed to be written @@ -189,26 +387,30 @@ namespace Slang // The source file(s) that will be compiled to form this translation unit // // Usually, for HLSL or GLSL there will be only one file. - List<SourceFile*> sourceFiles; + List<SourceFile*> m_sourceFiles; + + List<SourceFile*> const& getSourceFiles() { return m_sourceFiles; } + void addSourceFile(SourceFile* sourceFile); // The entry points associated with this translation unit - List<RefPtr<EntryPointRequest> > entryPoints; + List<RefPtr<EntryPoint>> entryPoints; // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `CompileOptions` will be shared) + // (whereas the ones on `compileRequest` will be shared) Dictionary<String, String> preprocessorDefinitions; - // Compile flags for this translation unit - SlangCompileFlags compileFlags = 0; + /// The name that will be used for the module this translation unit produces. + Name* moduleName = nullptr; + + /// Result of compiling this translation unit (a module) + RefPtr<Module> module; - // The parsed syntax for the translation unit - RefPtr<ModuleDecl> SyntaxNode; + Module* getModule() { return module; } + RefPtr<ModuleDecl> getModuleDecl() { return module->getModuleDecl(); } - // The IR-level code for this translation unit. - // This will only be valid/non-null after semantic - // checking and IR generation are complete, so it - // is not safe to use this field without testing for NULL. - RefPtr<IRModule> irModule; + Session* getSession(); + NamePool* getNamePool(); + SourceManager* getSourceManager(); }; enum class FloatingPointMode : SlangFloatingPointMode @@ -232,33 +434,28 @@ namespace Slang Binary = SLANG_WRITER_MODE_BINARY, }; - // A request to generate output in some target format + /// A request to generate output in some target format. class TargetRequest : public RefObject { public: - CompileRequest* compileRequest; + Linkage* linkage; CodeGenTarget target; SlangTargetFlags targetFlags = 0; Slang::Profile targetProfile = Slang::Profile(); FloatingPointMode floatingPointMode = FloatingPointMode::Default; - // Requested output paths for each entry point. - // An empty string indices no output desired for - // the given entry point. - List<String> entryPointOutputPaths; + Linkage* getLinkage() { return linkage; } + CodeGenTarget getTarget() { return target; } + Profile getTargetProfile() { return targetProfile; } + FloatingPointMode getFloatingPointMode() { return floatingPointMode; } - // The resulting reflection layout information - RefPtr<ProgramLayout> layout; - - // Generated compile results for each entry point - // in the parent compile request (indexing matches - // the order they are given in the compile request) - List<CompileResult> entryPointResults; + Session* getSession(); + MatrixLayoutMode getDefaultMatrixLayoutMode(); // TypeLayouts created on the fly by reflection API Dictionary<Type*, RefPtr<TypeLayout>> typeLayouts; - MatrixLayoutMode getDefaultMatrixLayoutMode(); + Dictionary<Type*, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; } }; /// Are we generating code for a D3D API? @@ -280,7 +477,8 @@ namespace Slang // - If the entry point and target disagree on the profile family, always use the // profile family and version from the target. // - Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target); + Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); + // A directory to be searched when looking for files (e.g., `#include`) struct SearchDirectory @@ -294,110 +492,53 @@ namespace Slang String path; }; - // Represents a module that has been loaded through the front-end - // (up through IR generation). - // - class LoadedModule : public RefObject + /// A list of directories to search for files (e.g., `#include`) + struct SearchDirectoryList { - public: - // The AST for the module - RefPtr<ModuleDecl> moduleDecl; + // A parent list that should also be searched + SearchDirectoryList* parent = nullptr; - // The IR for the module - RefPtr<IRModule> irModule = nullptr; + // Directories to be searched + List<SearchDirectory> searchDirectories; }; - class Session; - - /// Create a blob that will retain (a copy of) raw data. /// ComPtr<ISlangBlob> createRawBlob(void const* data, size_t size); - class CompileRequest : public RefObject + /// A context for loading and re-using code modules. + class Linkage : public RefObject { public: - // Pointer to parent session - Session* mSession; + /// Create an initially-empty linkage + Linkage(Session* session); + + /// Get the parent session for this linkage + Session* getSession() { return m_session; } // Information on the targets we are being asked to // generate code for. List<RefPtr<TargetRequest>> targets; - // What container format are we being asked to generate? - ContainerFormat containerFormat = ContainerFormat::None; - - // Path to output container to - String containerOutputPath; - // Directories to search for `#include` files or `import`ed modules - List<SearchDirectory> searchDirectories; + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } // Definitions to provide during preprocessing Dictionary<String, String> preprocessorDefinitions; - // Translation units we are being asked to compile - List<RefPtr<TranslationUnitRequest> > translationUnits; - - // Entry points we've been asked to compile (each - // associated with a translation unit). - List<RefPtr<EntryPointRequest> > entryPoints; - - // Types constructed by reflection API - Dictionary<String, RefPtr<Type>> types; - - /// The layout to use for matrices by default (row/column major) - MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } - - // Should we just pass the input to another compiler? - PassThroughMode passThrough = PassThroughMode::None; - - // Compile flags to be shared by all translation units - SlangCompileFlags compileFlags = 0; - - // Should we dump intermediate results along the way, for debugging? - bool shouldDumpIntermediates = false; - - bool shouldDumpIR = false; - bool shouldValidateIR = false; - bool shouldSkipCodegen = false; - - // If true then generateIR will serialize out IR, and serialize back in again. Making - // serialization a bottleneck or firewall between the front end and the backend - bool useSerialIRBottleneck = false; - - // If true will serialize and de-serialize with debug information - bool verifyDebugSerialization = false; - - // How should `#line` directives be emitted (if at all)? - LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; - // Are we being driven by the command-line `slangc`, and should act accordingly? - bool isCommandLineCompile = false; // Source manager to help track files loaded - SourceManager sourceManagerStorage; - SourceManager* sourceManager; + SourceManager m_defaultSourceManager; + SourceManager* m_sourceManager = nullptr; // Name pool for looking up names NamePool namePool; NamePool* getNamePool() { return &namePool; } - // Output stuff - DiagnosticSink mSink; - String mDiagnosticOutput; - - /// A blob holding the diagnostic output - ComPtr<ISlangBlob> diagnosticOutputBlob; - - // Files that compilation depended on - List<String> mDependencyFilePaths; - - // Generated bytecode representation of all the code - List<uint8_t> generatedBytecode; - // Modules that have been dynamically loaded via `import` // // This is a list of unique modules loaded, in the order they were encountered. @@ -424,11 +565,7 @@ namespace Slang /// or a wrapped impl that makes fileSystem operate as fileSystemExt ComPtr<ISlangFileSystemExt> fileSystemExt; - // For output - ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; - - void setWriter(WriterChannel chan, ISlangWriter* writer); - ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } + ISlangFileSystemExt* getFileSystemExt() { return fileSystemExt; } /// Load a file into memory using the configured file system. /// @@ -438,11 +575,177 @@ namespace Slang /// SlangResult loadFile(String const& path, ISlangBlob** outBlob); - CompileRequest(Session* session); - RefPtr<Expr> parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope); + RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope); + + /// Add a mew target amd return its index. + UInt addTarget( + CodeGenTarget target); + + RefPtr<Module> loadModule( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + void loadParsedModule( + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + PathInfo const& pathInfo); + + /// Load a module of the given name. + Module* loadModule(String const& name); + + RefPtr<Module> findOrImportModule( + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); - Type* getTypeFromString(String typeStr); + SourceManager* getSourceManager() + { + return m_sourceManager; + } + + /// Override the source manager for the linakge. + /// + /// This is only used to install a temporary override when + /// parsing stuff from strings (where we don't want to retain + /// full source files for the parsed result). + /// + /// TODO: We should remove the need for this hack. + /// + void setSourceManager(SourceManager* sourceManager) + { + m_sourceManager = sourceManager; + } + + void setFileSystem(ISlangFileSystem* fileSystem); + + /// The layout to use for matrices by default (row/column major) + MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; + MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } + + private: + Session* m_session = nullptr; + + /// Tracks state of modules currently being loaded. + /// + /// This information is used to diagnose cases where + /// a user tries to recursively import the same module + /// (possibly along a transitive chain of `import`s). + /// + struct ModuleBeingImportedRAII + { + public: + ModuleBeingImportedRAII( + Linkage* linkage, + Module* module) + : linkage(linkage) + , module(module) + { + next = linkage->m_modulesBeingImported; + linkage->m_modulesBeingImported = this; + } + + ~ModuleBeingImportedRAII() + { + linkage->m_modulesBeingImported = next; + } + + Linkage* linkage; + Module* module; + ModuleBeingImportedRAII* next; + }; + + // Any modules currently being imported will be listed here + ModuleBeingImportedRAII* m_modulesBeingImported; + + /// Is the given module in the middle of being imported? + bool isBeingImported(Module* module); + }; + + /// Shared functionality between front- and back-end compile requests. + /// + /// This is the base class for both `FrontEndCompileRequest` and + /// `BackEndCompileRequest`, and allows a small number of parts of + /// the compiler to be easily invocable from either front-end or + /// back-end work. + /// + class CompileRequestBase : public RefObject + { + // TODO: We really shouldn't need this type in the long run. + // The few places that rely on it should be refactored to just + // depend on the unerlying information (a linkage and a diagnostic + // sink) directly. + // + // The flags to control dumping and validation of IR should be + // moved to some kind of shared settings/options `struct` that + // both front-end and back-end requests can store. + + public: + Session* getSession(); + Linkage* getLinkage() { return m_linkage; } + DiagnosticSink* getSink() { return m_sink; } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + SlangResult loadFile(String const& path, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outBlob); } + + bool shouldDumpIR = false; + bool shouldValidateIR = false; + + protected: + CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink); + + private: + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + }; + + /// A request to compile source code to an AST + IR. + class FrontEndCompileRequest : public CompileRequestBase + { + public: + FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink); + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile); + + // Translation units we are being asked to compile + List<RefPtr<TranslationUnitRequest> > translationUnits; + + RefPtr<TranslationUnitRequest> getTranslationUnit(UInt index) { return translationUnits[index]; } + + // Compile flags to be shared by all translation units + SlangCompileFlags compileFlags = 0; + + // If true then generateIR will serialize out IR, and serialize back in again. Making + // serialization a bottleneck or firewall between the front end and the backend + bool useSerialIRBottleneck = false; + + // If true will serialize and de-serialize with debug information + bool verifyDebugSerialization = false; + + List<RefPtr<FrontEndEntryPointRequest>> m_entryPointReqs; + + List<RefPtr<FrontEndEntryPointRequest>> const& getEntryPointReqs() { return m_entryPointReqs; } + UInt getEntryPointReqCount() { return m_entryPointReqs.Count(); } + FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } + + // Directories to search for `#include` files or `import`ed modules + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } + + // Definitions to provide during preprocessing + Dictionary<String, String> preprocessorDefinitions; void parseTranslationUnit( TranslationUnitRequest* translationUnit); @@ -454,9 +757,24 @@ namespace Slang void generateIR(); SlangResult executeActionsInner(); - SlangResult executeActions(); - int addTranslationUnit(SourceLanguage language, String const& name); + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @param moduleName The name that will be used for the module compile from the translation unit. + /// @return The zero-based index of the translation unit in this compile request. + int addTranslationUnit(SourceLanguage language, Name* moduleName); + + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @return The zero-based index of the translation unit in this compile request. + /// + /// The module name for the translation unit will be automatically generated. + /// If all translation units in a compile request use automatically generated + /// module names, then they are guaranteed not to conflict with one another. + /// + int addTranslationUnit(SourceLanguage language); void addTranslationUnitSourceFile( int translationUnitIndex, @@ -476,63 +794,337 @@ namespace Slang int translationUnitIndex, String const& path); - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile profile, - List<String> const & genericTypeNames); + Program* getProgram() { return m_program; } - UInt addTarget( - CodeGenTarget target); + private: + RefPtr<Program> m_program; + }; - RefPtr<ModuleDecl> loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc); + /// A collection of code modules and entry points that are intended to be used together. + /// + /// A `Program` establishes that certain pieces of code are intended + /// to be used togehter so that, e.g., layout can make sure to allocate + /// space for the global shader parameters in all referenced modules. + /// + class Program : public RefObject + { + public: + /// Create a new program, initially empty. + /// + /// All code loaded into the program must come + /// from the given `linkage`. + Program( + Linkage* linkage); - void loadParsedModule( - RefPtr<TranslationUnitRequest> const& translationUnit, - Name* name, - PathInfo const& pathInfo); + /// Get the linkage that this program uses. + Linkage* getLinkage() { return m_linkage; } - RefPtr<ModuleDecl> findOrImportModule( - Name* name, - SourceLoc const& loc); + /// Get the number of entry points added to the program + UInt getEntryPointCount() { return m_entryPoints.Count(); } - Decl* lookupGlobalDecl(Name* name); + /// Get the entry point at the given `index`. + RefPtr<EntryPoint> getEntryPoint(UInt index) { return m_entryPoints[index]; } - SourceManager* getSourceManager() + /// Get the full ist of entry points on the program. + List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; } + + /// Get the substitution (if any) that represents how global generics are specialized. + RefPtr<Substitutions> getGlobalGenericSubstitution() { return m_globalGenericSubst; } + + /// Get the full list of modules this program depends on + List<RefPtr<Module>> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } + + /// Get the full list of filesystem paths this program depends on + List<String> getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Add a module (and everything it depends on) to the list of references + void addReferencedModule(Module* module); + + /// Add a module (but not the things it depends on) to the list of references + /// + /// This is a compatiblity hack for legacy compiler behavior. + void addReferencedLeafModule(Module* module); + + + /// Add an entry point to the program + /// + /// This also adds everything the entry point depends on to the list of references. + /// + void addEntryPoint(EntryPoint* entryPoint); + + /// Set the global generic argument substitution to use. + void setGlobalGenericSubsitution(RefPtr<Substitutions> subst) { - return sourceManager; + m_globalGenericSubst = subst; } - void setSourceManager(SourceManager* sm) + /// Parse a type from a string, in the context of this program. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + Type* getTypeFromString(String typeStr, DiagnosticSink* sink); + + /// Get the IR module that represents this program and its entry points. + /// + /// The IR module for a program tries to be minimal, and in the + /// common case will only include symbols with `[import]` declarations + /// for the entry point(s) of the program, and any types they + /// depend on. + /// + /// This IR module is intended to be linked against the IR modules + /// for all of the dependencies (see `getModuleDependencies()`) to + /// provide complete code. + /// + RefPtr<IRModule> getOrCreateIRModule(DiagnosticSink* sink); + + private: + // The linakge this program is associated with. + // + // Note that a `Program` keeps its associated linkage alive, + // and not vice versa. + // + RefPtr<Linkage> m_linkage; + + // Tracking data for the list of modules dependend on + ModuleDependencyList m_moduleDependencyList; + + // Tracking data for the list of filesystem paths dependend on + FilePathDependencyList m_filePathDependencyList; + + // Entry points that are part of the program. + List<RefPtr<EntryPoint> > m_entryPoints; + + // Specializations for global generic parameters (if any) + RefPtr<Substitutions> m_globalGenericSubst; + + // Generated IR for this program. + RefPtr<IRModule> m_irModule; + + // Cache of target-specific programs for each target. + Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + Dictionary<String, RefPtr<Type>> m_types; + }; + + /// A `Program` specialized for a particular `TargetRequest` + class TargetProgram : public RefObject + { + public: + TargetProgram( + Program* program, + TargetRequest* targetReq); + + /// Get the underlying program + Program* getProgram() { return m_program; } + + /// Get the underlying target + TargetRequest* getTargetReq() { return m_targetReq; } + + /// Get the layout for the program on the target. + /// + /// If this is the first time the layout has been + /// requested, report any errors that arise during + /// layout to the given `sink`. + /// + ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); + + /// Get the layout for the program on the taarget. + /// + /// This routine assumes that `getOrCreateLayout` + /// has already been called previously. + /// + ProgramLayout* getExistingLayout() { - sourceManager = sm; - mSink.sourceManager = sm; + SLANG_ASSERT(m_layout); + return m_layout; } - void setFileSystem(ISlangFileSystem* fileSystem); + /// Get the compiled code for an entry point on the target. + /// + /// This routine assumes code generation has already been + /// performed and called `setEntryPointResult`. + /// + CompileResult& getExistingEntryPointResult(Int entryPointIndex) + { + return m_entryPointResults[entryPointIndex]; + } + + // TODO: Need a lazy `getOrCreateEntryPointResult` + + /// Set the compiled code for an entry point. + /// + /// Should only be called by code generation. + void setEntryPointResult(Int entryPointIndex, CompileResult const& result) + { + m_entryPointResults[entryPointIndex] = result; + } + + private: + // The program being compiled or laid out + Program* m_program; + + // The target that code/layout will be generated for + TargetRequest* m_targetReq; + + // The computed layout, if it has been generated yet + RefPtr<ProgramLayout> m_layout; + + // Generated compile results for each entry point + // in the parent `Program` (indexing matches + // the order they are given in the `Program`) + List<CompileResult> m_entryPointResults; + }; + + /// A request to generate code for a program + class BackEndCompileRequest : public CompileRequestBase + { + public: + BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program = nullptr); + + // Should we dump intermediate results along the way, for debugging? + bool shouldDumpIntermediates = false; + + // How should `#line` directives be emitted (if at all)? + LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; + + LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } - /// During propagation of an exception for an internal - /// error, note that this source location was involved - void noteInternalErrorLoc(SourceLoc const& loc); + Program* getProgram() { return m_program; } + void setProgram(Program* program) { m_program = program; } - int internalErrorLocsNoted = 0; + private: + RefPtr<Program> m_program; + }; + + /// A compile request that spans the front and back ends of the compiler + /// + /// This is what the command-line `slangc` uses, as well as the legacy + /// C API. It ties together the functionality of `Linkage`, + /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small + /// number of additional features that primarily make sense for + /// command-line usage. + /// + class EndToEndCompileRequest : public RefObject + { + public: + EndToEndCompileRequest( + Session* session); + + // What container format are we being asked to generate? + // + // Note: This field is unused except by the options-parsing + // logic; it exists to support wriiting out binary modules + // once that feature is ready. + // + ContainerFormat containerFormat = ContainerFormat::None; + + // Path to output container to + // + // Note: This field exists to support wriiting out binary modules + // once that feature is ready. + // + String containerOutputPath; + + // Should we just pass the input to another compiler? + PassThroughMode passThrough = PassThroughMode::None; + + /// Source code for the generic arguments to use for the global generic parameters of the program. + List<String> globalGenericArgStrings; + + + bool shouldSkipCodegen = false; + + // Are we being driven by the command-line `slangc`, and should act accordingly? + bool isCommandLineCompile = false; + + String mDiagnosticOutput; + + /// A blob holding the diagnostic output + ComPtr<ISlangBlob> diagnosticOutputBlob; + + /// Per-entry-point information not tracked by other compile requests + class EntryPointInfo : public RefObject + { + public: + /// Source code for the generic arguments to use for the generic parameters of the entry point. + List<String> genericArgStrings; + }; + List<EntryPointInfo> entryPoints; + + /// Per-target information only needed for command-line compiles + class TargetInfo : public RefObject + { + public: + // Requested output paths for each entry point. + // An empty string indices no output desired for + // the given entry point. + Dictionary<Int, String> entryPointOutputPaths; + }; + Dictionary<TargetRequest*, RefPtr<TargetInfo>> targetInfos; + + Linkage* getLinkage() { return m_linkage; } + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile, + List<String> const & genericTypeNames); + + void setWriter(WriterChannel chan, ISlangWriter* writer); + ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } + + SlangResult executeActionsInner(); + SlangResult executeActions(); + + Session* getSession() { return m_session; } + DiagnosticSink* getSink() { return &m_sink; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + + FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } + BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } + Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } + Program* getSpecializedProgram() { return m_specializedProgram; } + + private: + Session* m_session = nullptr; + RefPtr<Linkage> m_linkage; + DiagnosticSink m_sink; + RefPtr<FrontEndCompileRequest> m_frontEndReq; + RefPtr<Program> m_unspecializedProgram; + RefPtr<Program> m_specializedProgram; + RefPtr<BackEndCompileRequest> m_backEndReq; + + // For output + ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; }; void generateOutput( - CompileRequest* compileRequest); + BackEndCompileRequest* compileRequest); + + void generateOutput( + EndToEndCompileRequest* compileRequest); // Helper to dump intermediate output when debugging void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, CodeGenTarget target); void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, char const* text, CodeGenTarget target); @@ -548,12 +1140,14 @@ namespace Slang @param sink The diagnostic sink to report to */ void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink); - /* Given a translationUnitRequest determines a filename that is most suitable to identify the input. - If the translation is a pass through will attempt to get the source file pathname. If the source is slang generated - there is no equivalent name so will return 'slang-generated' - @param translationUnitRequest The request to find an appropriate source path for + /* Determines a suitable filename to identify the input for a given entry point being compiled. + If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file + pathname for the translation unit containing the entry point at `entryPointIndex. + If the compilation is not in a pass-through case, then always returns `"slang-generated"`. + @param endToEndReq The end-to-end compile request which might be using pass-through copmilation + @param entryPointIndex The index of the entry point to compute a filename for. @return the appropriate source filename */ - String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest); + String calcSourcePathForEntryPoint(EndToEndCompileRequest* endToEndReq, UInt entryPointIndex); struct TypeCheckingCache; // @@ -696,6 +1290,10 @@ namespace Slang String const& path, String const& source); ~Session(); + + private: + /// Linkage used for all built-in (stdlib) code. + RefPtr<Linkage> m_builtinLinkage; }; } diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index 0480eb934..10dcefe19 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -240,6 +240,14 @@ SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase) // that provides a scope for some number of declarations. SYNTAX_CLASS(ModuleDecl, ContainerDecl) FIELD(RefPtr<Scope>, scope) + + // The API-level module that this declaration belong to. + // + // This field allows lookup of the `Module` based on a + // declaration nested under a `ModuleDecl` by following + // its chain of parents. + // + RAW(Module* module = nullptr;) END_SYNTAX_CLASS() SYNTAX_CLASS(ImportDecl, Decl) diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 51076bde1..2d0dd7fdd 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -336,7 +336,7 @@ DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entr DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") -DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type") +DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type") DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage <name>' command-line option to specify a stage") @@ -356,6 +356,10 @@ DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_Dispatc DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") +DIAGNOSTIC(38020, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") +DIAGNOSTIC(38021, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") + + DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h index e3aba32e6..9efc5efc6 100644 --- a/source/slang/diagnostics.h +++ b/source/slang/diagnostics.h @@ -2,6 +2,7 @@ #define RASTER_RENDERER_COMPILE_ERROR_H #include "../core/basic.h" +#include "../core/slang-writer.h" #include "source-loc.h" #include "token.h" @@ -153,7 +154,7 @@ namespace Slang StringBuilder outputBuffer; // List<Diagnostic> diagnostics; int errorCount = 0; - + int internalErrorLocsNoted = 0; ISlangWriter* writer = nullptr; Flags flags = 0; @@ -217,6 +218,32 @@ namespace Slang void diagnoseRaw( Severity severity, const UnownedStringSlice& message); + + /// During propagation of an exception for an internal + /// error, note that this source location was involved + void noteInternalErrorLoc(SourceLoc const& loc); + }; + + /// An `ISlangWriter` that writes directly to a diagnostic sink. + class DiagnosticSinkWriter : public AppendBufferWriter + { + public: + typedef AppendBufferWriter Super; + + DiagnosticSinkWriter(DiagnosticSink* sink) + : Super(WriterFlag::IsStatic) + , m_sink(sink) + {} + + // ISlangWriter + SLANG_NO_THROW virtual SlangResult SLANG_MCALL write(const char* chars, size_t numChars) SLANG_OVERRIDE + { + m_sink->diagnoseRaw(Severity::Note, UnownedStringSlice(chars, chars+numChars)); + return SLANG_OK; + } + + private: + DiagnosticSink* m_sink = nullptr; }; namespace Diagnostics diff --git a/source/slang/dxc-support.cpp b/source/slang/dxc-support.cpp index 7d2d6dc0c..603a59ea7 100644 --- a/source/slang/dxc-support.cpp +++ b/source/slang/dxc-support.cpp @@ -30,8 +30,11 @@ namespace Slang { String GetHLSLProfileName(Profile profile); String emitHLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq); static UnownedStringSlice _getSlice(IDxcBlob* blob) { @@ -46,19 +49,22 @@ namespace Slang } SlangResult emitDXILForEntryPointUsingDXC( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& outCode) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& outCode) { - auto compileRequest = entryPoint->compileRequest; - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); // First deal with all the rigamarole of loading // the `dxcompiler` library, and creating the // top-level COM objects that will be used to // compile things. - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink); + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); if (!dxcCreateInstance) { return SLANG_FAIL; @@ -69,9 +75,7 @@ namespace Slang { // If can't load dxil - dxc will not be able to sign output // Output a suitable warning to the user - auto& sink = entryPoint->compileRequest->mSink; - - sink.diagnose(SourceLoc(), Diagnostics::dxilNotFound); + sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); } } @@ -89,8 +93,13 @@ namespace Slang // Now let's go ahead and generate HLSL for the entry // point, since we'll need that to feed into dxc. - auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); + auto hlslCode = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); // Wrap the @@ -122,7 +131,7 @@ namespace Slang break; } - switch( targetReq->floatingPointMode ) + switch( targetReq->getFloatingPointMode() ) { default: break; @@ -149,7 +158,7 @@ namespace Slang // args[argCount++] = L"-no-warnings"; - String entryPointName = getText(entryPoint->name); + String entryPointName = getText(entryPoint->getName()); OSString wideEntryPointName = entryPointName.ToWString(); auto profile = getEffectiveProfile(entryPoint, targetReq); @@ -172,7 +181,7 @@ namespace Slang args[argCount++] = L"-enable-16bit-types"; } - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); ComPtr<IDxcOperationResult> dxcResult; SLANG_RETURN_ON_FAIL(dxcCompiler->Compile(dxcSourceBlob, @@ -208,7 +217,7 @@ namespace Slang // into a string for safety. // - reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), &entryPoint->compileRequest->mSink); + reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), compileRequest->getSink()); return resultCode; } @@ -225,20 +234,21 @@ namespace Slang } SlangResult dissassembleDXILUsingDXC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut) + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut) { stringOut = String(); - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); // First deal with all the rigamarole of loading // the `dxcompiler` library, and creating the // top-level COM objects that will be used to // compile things. - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink); + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); if (!dxcCreateInstance) { return SLANG_FAIL; diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index c0e5e4296..2603df11e 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -79,8 +79,10 @@ void requireGLSLVersionImpl( // Shared state for an entire emit session struct SharedEmitContext { + BackEndCompileRequest* compileRequest = nullptr; + // The entry point we are being asked to compile - EntryPointRequest* entryPoint; + EntryPoint* entryPoint; // The layout for the entry point EntryPointLayout* entryPointLayout; @@ -153,7 +155,7 @@ struct SharedEmitContext // to use for it when emitting code. Dictionary<IRInst*, String> mapInstToName; - DiagnosticSink* getSink() { return &entryPoint->compileRequest->mSink; } + DiagnosticSink* getSink() { return compileRequest->getSink(); } Dictionary<IRInst*, UInt> mapIRValueToRayPayloadLocation; Dictionary<IRInst*, UInt> mapIRValueToCallablePayloadLocation; @@ -165,6 +167,10 @@ struct EmitContext SharedEmitContext* shared; DiagnosticSink* getSink() { return shared->getSink(); } + + LineDirectiveMode getLineDirectiveMode() { return shared->compileRequest->getLineDirectiveMode(); } + SourceManager* getSourceManager() { return shared->compileRequest->getSourceManager(); } + void noteInternalErrorLoc(SourceLoc loc) { return getSink()->noteInternalErrorLoc(loc); } }; // @@ -334,11 +340,6 @@ struct EmitVisitor : context(context) {} - Session* getSession() - { - return context->shared->entryPoint->compileRequest->mSession; - } - // Low-level emit logic void emitRawTextSpan(char const* textBegin, char const* textEnd) @@ -556,7 +557,7 @@ struct EmitVisitor bool shouldUseGLSLStyleLineDirective = false; - auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode; + auto mode = context->getLineDirectiveMode(); switch (mode) { case LineDirectiveMode::None: @@ -664,7 +665,7 @@ struct EmitVisitor { // Don't do any of this work if the user has requested that we // not emit line directives. - auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode; + auto mode = context->getLineDirectiveMode(); switch(mode) { case LineDirectiveMode::None: @@ -723,7 +724,7 @@ struct EmitVisitor SourceManager* getSourceManager() { - return context->shared->entryPoint->compileRequest->getSourceManager(); + return context->getSourceManager(); } void advanceToSourceLocation( @@ -747,7 +748,7 @@ struct EmitVisitor DiagnosticSink* getSink() { - return &context->shared->entryPoint->compileRequest->mSink; + return context->getSink(); } // @@ -1875,8 +1876,7 @@ struct EmitVisitor } } - void emitGLSLVersionDirective( - ModuleDecl* /*program*/) + void emitGLSLVersionDirective() { auto effectiveProfile = context->shared->effectiveProfile; if(effectiveProfile.getFamily() == ProfileFamily::GLSL) @@ -1931,8 +1931,7 @@ struct EmitVisitor Emit("#version 420\n"); } - void emitGLSLPreprocessorDirectives( - RefPtr<ModuleDecl> program) + void emitGLSLPreprocessorDirectives() { switch(context->shared->target) { @@ -1944,24 +1943,7 @@ struct EmitVisitor break; } - emitGLSLVersionDirective(program); - - - // TODO: when cross-compiling we may need to output additional `#extension` directives - // based on the features that we have used. - - for( auto extensionDirective : program->GetModifiersOfType<GLSLExtensionDirective>() ) - { - // TODO(tfoley): Emit an appropriate `#line` directive... - - Emit("#extension "); - emit(extensionDirective->extensionNameToken.Content); - Emit(" : "); - emit(extensionDirective->dispositionToken.Content); - Emit("\n"); - } - - // TODO: handle other cases... + emitGLSLVersionDirective(); } /// Emit directives to control overall layout computation for the emitted code. @@ -4115,7 +4097,7 @@ struct EmitVisitor catch(AbortCompilationException&) { throw; } catch(...) { - ctx->shared->entryPoint->compileRequest->noteInternalErrorLoc(inst->sourceLoc); + ctx->noteInternalErrorLoc(inst->sourceLoc); throw; } } @@ -6583,26 +6565,26 @@ struct EmitVisitor // EntryPointLayout* findEntryPointLayout( - ProgramLayout* programLayout, - EntryPointRequest* entryPointRequest) + ProgramLayout* programLayout, + EntryPoint* entryPoint) { for( auto entryPointLayout : programLayout->entryPoints ) { - if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) continue; // TODO: We need to be careful about this check, since it relies on // the profile information in the layout matching that in the request. // // What we really seem to want here is some dictionary mapping the - // `EntryPointRequest` directly to the `EntryPointLayout`, and maybe + // `EntryPoint` directly to the `EntryPointLayout`, and maybe // that is precisely what we should build... // - if(entryPointLayout->profile != entryPointRequest->profile) + if(entryPointLayout->profile != entryPoint->getProfile()) continue; // TODO: can't easily filter on translation unit here... - // Ideally the `EntryPointRequest` should get filled in with a pointer + // Ideally the `EntryPoint` should get filled in with a pointer // the specific function declaration that represents the entry point. return entryPointLayout.Ptr(); @@ -6662,13 +6644,14 @@ void legalizeTypes( IRModule* module); static void dumpIRIfEnabled( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, IRModule* irModule, char const* label = nullptr) { if(compileRequest->shouldDumpIR) { - WriterHelper writer(compileRequest->getWriter(WriterChannel::StdError)); + DiagnosticSinkWriter writerImpl(compileRequest->getSink()); + WriterHelper writer(&writerImpl); if(label) { @@ -6687,16 +6670,22 @@ static void dumpIRIfEnabled( } String emitEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetRequest) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + CodeGenTarget target, + TargetRequest* targetRequest) { - auto translationUnit = entryPoint->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto program = compileRequest->getProgram(); + auto targetProgram = program->getTargetProgram(targetRequest); + auto programLayout = targetProgram->getOrCreateLayout(sink); + +// auto translationUnit = entryPoint->getTranslationUnit(); SharedEmitContext sharedContext; + sharedContext.compileRequest = compileRequest; sharedContext.target = target; - sharedContext.finalTarget = targetRequest->target; + sharedContext.finalTarget = targetRequest->getTarget(); sharedContext.entryPoint = entryPoint; sharedContext.effectiveProfile = getEffectiveProfile(entryPoint, targetRequest); @@ -6715,16 +6704,13 @@ String emitEntryPoint( StructTypeLayout* globalStructLayout = getGlobalStructLayout(programLayout); sharedContext.globalStructLayout = globalStructLayout; - auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); - EmitContext context; context.shared = &sharedContext; EmitVisitor visitor(&context); { - auto compileRequest = translationUnit->compileRequest; - auto session = compileRequest->mSession; + auto session = targetRequest->getSession(); // We start out by performing "linking" at the level of the IR. // This step will create a fresh IR module to be used for @@ -6735,6 +6721,7 @@ String emitEntryPoint( // any "profile-overloaded" symbols. // auto linkedIR = linkIR( + compileRequest, entryPoint, programLayout, target, @@ -6880,7 +6867,7 @@ String emitEntryPoint( session, irModule, irEntryPoint, - &compileRequest->mSink, + compileRequest->getSink(), &sharedContext.extensionUsageTracker); } break; @@ -6916,10 +6903,6 @@ String emitEntryPoint( // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? visitor.emitIRModule(&context, irModule); - - // retain the specialized ir module, because the current - // GlobalGenericParamSubstitution implementation may reference ir objects - targetRequest->compileRequest->compiledModules.Add(irModule); } // Deal with cases where a particular stage requires certain GLSL versions @@ -6950,7 +6933,7 @@ String emitEntryPoint( // it is time to stich together the final output. // There may be global-scope modifiers that we should emit now - visitor.emitGLSLPreprocessorDirectives(translationUnitSyntax); + visitor.emitGLSLPreprocessorDirectives(); visitor.emitLayoutDirectives(targetRequest); diff --git a/source/slang/emit.h b/source/slang/emit.h index 98845f9c6..317afcf6b 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -8,7 +8,7 @@ namespace Slang { - class EntryPointRequest; + class EntryPoint; class ProgramLayout; class TranslationUnitRequest; @@ -20,13 +20,13 @@ namespace Slang // Emit code for a single entry point, based on // the input translation unit. String emitEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, // The target language to generate code in (e.g., HLSL/GLSL) - CodeGenTarget target, + CodeGenTarget target, // The full target request - TargetRequest* targetRequest); + TargetRequest* targetRequest); } #endif diff --git a/source/slang/ir-dce.cpp b/source/slang/ir-dce.cpp index 0f037bfe5..ba6d7adb9 100644 --- a/source/slang/ir-dce.cpp +++ b/source/slang/ir-dce.cpp @@ -16,8 +16,8 @@ struct DeadCodeEliminationContext // the parameters that were passed to the top-level // `eliminateDeadCode` function. // - CompileRequest* compileRequest; - IRModule* module; + BackEndCompileRequest* compileRequest; + IRModule* module; // Our overall process is going to be to determine // which instructions in the module are "live" @@ -235,9 +235,9 @@ struct DeadCodeEliminationContext // we'll just go ahead and eliminate every single function/type // in a module. There needs to be a way to identify the // functions we want to keep around, and for right now - // that is handled with the `[entryPoint]` decoration. + // that is handled with the `[keepAlive]` decoration. // - if(inst->findDecorationImpl(kIROp_EntryPointDecoration)) + if(inst->findDecorationImpl(kIROp_KeepAliveDecoration)) return true; // // TODO: Eventually it would make sense to consider everything @@ -312,7 +312,7 @@ struct DeadCodeEliminationContext // and then defer to it for the real work. // void eliminateDeadCode( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, IRModule* module) { DeadCodeEliminationContext context; diff --git a/source/slang/ir-dce.h b/source/slang/ir-dce.h index fd56616d9..6089b404a 100644 --- a/source/slang/ir-dce.h +++ b/source/slang/ir-dce.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class BackEndCompileRequest; struct IRModule; /// Eliminate "dead" code from the given IR module. @@ -14,6 +14,6 @@ namespace Slang /// etc. /// void eliminateDeadCode( - CompileRequest* compileRequest, - IRModule* module); + BackEndCompileRequest* compileRequest, + IRModule* module); } diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index b6f8ce547..eada52e4d 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -392,6 +392,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// even if it does not otherwise reference it. INST(DependsOnDecoration, dependsOn, 1, 0) + /// A `[keepAlive]` decoration marks an instruction that should not be eliminated. + INST(KeepAliveDecoration, keepAlive, 0, 0) + + /* LinkageDecoration */ INST(ImportDecoration, import, 1, 0) INST(ExportDecoration, export, 1, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 6b12612ef..b7c1b2744 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -1201,6 +1201,11 @@ struct IRBuilder addDecoration(value, kIROp_EntryPointDecoration); } + void addKeepAliveDecoration(IRInst* value) + { + addDecoration(value, kIROp_KeepAliveDecoration); + } + /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. /// /// This decoration can be used to ensure that a value that an instruction diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index 35e0f46b8..2eef10614 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -14,7 +14,7 @@ namespace Slang // instead of the input/request layer. EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPointRequest* entryPointRequest); + EntryPoint* EntryPoint); struct IRSpecSymbol : RefObject { @@ -39,9 +39,6 @@ struct IRSharedSpecContext // The specialized module we are building RefPtr<IRModule> module; - // The original, unspecialized module we are copying - IRModule* originalModule; - // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. @@ -67,8 +64,6 @@ struct IRSpecContextBase IRModule* getModule() { return getShared()->module; } - IRModule* getOriginalModule() { return getShared()->originalModule; } - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } // The current specialization environment to use. @@ -668,8 +663,8 @@ IRInst* specializeGeneric( IRSpecialize* specializeInst); IRFunc* specializeIRForEntryPoint( - IRSpecContext* context, - EntryPointRequest* entryPointRequest, + IRSpecContext* context, + EntryPoint* entryPoint, EntryPointLayout* entryPointLayout) { // We start by looking up the IR symbol that @@ -681,7 +676,7 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - auto mangledName = getMangledName(entryPointRequest->getFuncDeclRef()); + auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); RefPtr<IRSpecSymbol> sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -743,9 +738,9 @@ IRFunc* specializeIRForEntryPoint( return nullptr; } - if( !clonedFunc->findDecorationImpl(kIROp_EntryPointDecoration) ) + if( !clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration) ) { - context->builder->addEntryPointDecoration(clonedFunc); + context->builder->addKeepAliveDecoration(clonedFunc); } // We need to attach the layout information for @@ -1148,7 +1143,6 @@ void initializeSharedSpecContext( IRSharedSpecContext* sharedContext, Session* session, IRModule* module, - IRModule* originalModule, CodeGenTarget target) { @@ -1166,19 +1160,15 @@ void initializeSharedSpecContext( sharedBuilder->module = module; sharedContext->module = module; - sharedContext->originalModule = originalModule; sharedContext->target = target; - // We will populate a map with all of the IR values - // that use the same mangled name, to make lookup easier - // in other steps. - insertGlobalValueSymbols(sharedContext, originalModule); } // implementation provided in parameter-binding.cpp RefPtr<ProgramLayout> specializeProgramLayout( TargetRequest * targetReq, - ProgramLayout* programLayout, - SubstitutionSet typeSubst); + ProgramLayout* programLayout, + SubstitutionSet typeSubst, + DiagnosticSink* sink); struct IRSpecializationState { @@ -1211,11 +1201,14 @@ struct IRSpecializationState }; LinkedIR linkIR( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq) { + auto sink = compileRequest->getSink(); + IRSpecializationState stateStorage; auto state = &stateStorage; @@ -1223,26 +1216,27 @@ LinkedIR linkIR( state->target = target; state->targetReq = targetReq; - - auto compileRequest = entryPointRequest->compileRequest; - auto translationUnit = entryPointRequest->getTranslationUnit(); - auto originalIRModule = translationUnit->irModule; + auto program = compileRequest->getProgram(); auto sharedContext = state->getSharedContext(); initializeSharedSpecContext( sharedContext, - compileRequest->mSession, + compileRequest->getSession(), nullptr, - originalIRModule, target); state->irModule = sharedContext->module; - // We also need to attach the IR definitions for symbols from - // any loaded modules: - for (auto loadedModule : compileRequest->loadedModulesList) + // We need to be able to look up IR definitions for any symbols in + // modules that the program depends on (transitively). To + // accelerate lookup, we will create a symbol table for looking + // up IR definitions by their mangled name. + // + auto originalProgramIRModule = program->getOrCreateIRModule(sink); + insertGlobalValueSymbols(sharedContext, originalProgramIRModule); + for (auto module : program->getModuleDependencies()) { - insertGlobalValueSymbols(sharedContext, loadedModule->irModule); + insertGlobalValueSymbols(sharedContext, module->getIRModule()); } auto context = state->getContext(); @@ -1257,7 +1251,8 @@ LinkedIR linkIR( RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout( targetReq, programLayout, - SubstitutionSet(entryPointRequest->globalGenericSubst)); + SubstitutionSet(program->getGlobalGenericSubstitution()), + compileRequest->getSink()); // TODO: we need to register the (IR-level) arguments of the global generic parameters as the // substitutions for the generic parameters in the original IR. @@ -1267,13 +1262,22 @@ LinkedIR linkIR( state->newProgramLayout = newProgramLayout; - // Next, we want to optimize lookup for layout infromation + // Next, we want to optimize lookup for layout information // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) + // + // Note: We are scanning over all the key-value pairs for + // entries in the global scope, to account for the fact + // that the "same" shader parameter could be declared in + // multiple translation units, and thus end up with + // multiple mangled names (when the unique translation + // unit name gets involved). + // auto globalStructLayout = getScopeStructLayout(newProgramLayout); - for (auto globalVarLayout : globalStructLayout->fields) + for(auto entry : globalStructLayout->mapVarToLayout) { - auto mangledName = getMangledName(globalVarLayout->varDecl); + auto mangledName = getMangledName(entry.Key); + auto globalVarLayout = entry.Value; context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } @@ -1290,19 +1294,19 @@ LinkedIR linkIR( cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } - auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPoint); // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else // it might reference. - auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout); + auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, entryPointLayout); // HACK: right now the bindings for global generic parameters are coming in // as part of the original IR module, and we need to make sure these get // copied over, even if they aren't referenced. // - for(auto inst : originalIRModule->getGlobalInsts()) + for(auto inst : originalProgramIRModule->getGlobalInsts()) { auto bindInst = as<IRBindGlobalGenericParam>(inst); if(!bindInst) diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h index 4fcdb4618..dba3ccc97 100644 --- a/source/slang/ir-link.h +++ b/source/slang/ir-link.h @@ -19,8 +19,9 @@ namespace Slang // used. // LinkedIR linkIR( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq); } diff --git a/source/slang/ir-specialize-resources.cpp b/source/slang/ir-specialize-resources.cpp index 0108a91f8..e974ffdb7 100644 --- a/source/slang/ir-specialize-resources.cpp +++ b/source/slang/ir-specialize-resources.cpp @@ -18,7 +18,7 @@ struct ResourceParameterSpecializationContext // the parameters that were passed to the top-level // `specializeResourceParameters` function. // - CompileRequest* compileRequest; + BackEndCompileRequest* compileRequest; TargetRequest* targetRequest; IRModule* module; @@ -372,7 +372,7 @@ struct ResourceParameterSpecializationContext // If we didn't find a pre-existing specialized // function, then we will go ahead and create one. // - // We start by gathering the infromation from the call + // We start by gathering the information from the call // site that is relevant to generating a specialized // callee function, which we avoided doing earlier // because it might have been throwaway work. @@ -850,7 +850,7 @@ struct ResourceParameterSpecializationContext // and then defer to it for the real work. // void specializeResourceParameters( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module) { diff --git a/source/slang/ir-specialize-resources.h b/source/slang/ir-specialize-resources.h index 3d6ead130..0e636318c 100644 --- a/source/slang/ir-specialize-resources.h +++ b/source/slang/ir-specialize-resources.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class BackEndCompileRequest; class TargetRequest; struct IRModule; @@ -18,7 +18,7 @@ namespace Slang /// global shader parameters directly). /// void specializeResourceParameters( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module); } diff --git a/source/slang/ir-validate.cpp b/source/slang/ir-validate.cpp index 924ec71b3..9564873b1 100644 --- a/source/slang/ir-validate.cpp +++ b/source/slang/ir-validate.cpp @@ -195,13 +195,13 @@ namespace Slang } void validateIRModuleIfEnabled( - CompileRequest* compileRequest, - IRModule* module) + CompileRequestBase* compileRequest, + IRModule* module) { if (!compileRequest->shouldValidateIR) return; - auto sink = &compileRequest->mSink; + auto sink = compileRequest->getSink(); validateIRModule(module, sink); } } diff --git a/source/slang/ir-validate.h b/source/slang/ir-validate.h index 0ebc69019..1cb30961d 100644 --- a/source/slang/ir-validate.h +++ b/source/slang/ir-validate.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class CompileRequestBase; class DiagnosticSink; struct IRModule; @@ -30,6 +30,6 @@ namespace Slang // A wrapper that calls `validateIRModule` only when IR validation is enabled // for the given compile request. void validateIRModuleIfEnabled( - CompileRequest* compileRequest, - IRModule* module); + CompileRequestBase* compileRequest, + IRModule* module); } diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 0d9427b08..b53bb8ebb 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -300,8 +300,18 @@ struct IRGenEnv struct SharedIRGenContext { - CompileRequest* compileRequest; - ModuleDecl* mainModuleDecl; + SharedIRGenContext( + Session* session, + DiagnosticSink* sink, + ModuleDecl* mainModuleDecl = nullptr) + : m_session(session) + , m_sink(sink) + , m_mainModuleDecl(mainModuleDecl) + {} + + Session* m_session = nullptr; + DiagnosticSink* m_sink = nullptr; + ModuleDecl* m_mainModuleDecl = nullptr; // The "global" environment for mapping declarations to their IR values. IRGenEnv globalEnv; @@ -356,17 +366,17 @@ struct IRGenContext Session* getSession() { - return shared->compileRequest->mSession; + return shared->m_session; } - CompileRequest* getCompileRequest() + DiagnosticSink* getSink() { - return shared->compileRequest; + return shared->m_sink; } - DiagnosticSink* getSink() + ModuleDecl* getMainModuleDecl() { - return &getCompileRequest()->mSink; + return shared->m_mainModuleDecl; } }; @@ -422,7 +432,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl) if (isFromStdLib(decl)) return false; - if (moduleDecl != context->shared->mainModuleDecl) + if (moduleDecl != context->getMainModuleDecl()) return true; return false; @@ -1735,16 +1745,31 @@ static String getNameForNameHint( if(auto genericParentDecl = as<GenericDecl>(parentDecl)) parentDecl = genericParentDecl->ParentDecl; + // A `ModuleDecl` can have a name too, but in the common case + // we don't want to generate name hints that include the module + // name, simply because they would lead to every global symbol + // getting a much longer name. + // + // TODO: We should probably include the module name for symbols + // being `import`ed, and not for symbols being compiled directly + // (those coming from a module that had no name given to it). + // + // For now we skip past a `ModuleDecl` parent. + // + if(auto moduleParentDecl = as<ModuleDecl>(parentDecl)) + parentDecl = moduleParentDecl->ParentDecl; + + if(!parentDecl) + { + return leafName->text; + } + auto parentName = getNameForNameHint(context, parentDecl); if(parentName.Length() == 0) { return leafName->text; } - // TODO: at some point we will start giving `ModuleDecl`s names, - // and in that case we need to think carefully about whether to - // include their names here or not. - // We will now construct a new `Name` to use as the hint, // combining the name of the parent and the leaf declaration. @@ -3603,7 +3628,7 @@ void lowerStmt( catch(AbortCompilationException&) { throw; } catch(...) { - context->getCompileRequest()->noteInternalErrorLoc(stmt->loc); + context->getSink()->noteInternalErrorLoc(stmt->loc); throw; } } @@ -5877,7 +5902,7 @@ LoweredValInfo lowerDecl( catch(AbortCompilationException&) { throw; } catch(...) { - context->getCompileRequest()->noteInternalErrorLoc(decl->loc); + context->getSink()->noteInternalErrorLoc(decl->loc); throw; } } @@ -6108,56 +6133,59 @@ LoweredValInfo emitDeclRef( type); } -static void lowerEntryPointToIR( - IRGenContext* context, - EntryPointRequest* entryPointRequest) +static void lowerFrontEndEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) { - // First, lower the entry point like an ordinary function + // TODO: We should emit an entry point as a dedicated IR function + // (distinct from the IR function used if it were called normally), + // with a mangled name based on the original function name plus + // the stage for which it is being compiled as an entry point (so + // that entry points for distinct stages always have distinct names). + // + // For now we just have an (implicit) constraint that a given + // function should only be used as an entry point for one stage, + // and any such function should *not* be used as an ordinary function. - auto session = context->getSession(); - auto entryPointFuncDeclRef = entryPointRequest->getFuncDeclRef(); - auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + auto entryPointFuncDecl = entryPoint->getFuncDecl(); auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); auto loweredEntryPointFunc = getSimpleVal(context, - emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); + ensureDecl(context, entryPointFuncDecl)); // Attach a marker decoration so that we recognize // this as an entry point. // - builder->addEntryPointDecoration(loweredEntryPointFunc); - - // - if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + IRInst* instToDecorate = loweredEntryPointFunc; + if(auto irGeneric = as<IRGeneric>(instToDecorate)) { - builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + instToDecorate = findGenericReturnVal(irGeneric); } + builder->addEntryPointDecoration(instToDecorate); +} - // Now lower all the arguments supplied for global generic - // type parameters. - // - for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) - { - auto gSubst = subst.as<GlobalGenericParamSubstitution>(); - if(!gSubst) - continue; +static void lowerProgramEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) +{ + // First, lower the entry point like an ordinary function - IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); - IRType* typeVal = lowerType(context, gSubst->actualType); + auto session = context->getSession(); + auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); - // bind `typeParam` to `typeVal` - builder->emitBindGlobalGenericParam(typeParam, typeVal); + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); - for (auto& constraintArg : gSubst->constraintArgs) - { - IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); - IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + auto loweredEntryPointFunc = getSimpleVal(context, + emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); - // bind `constraintParam` to `constraintVal` - builder->emitBindGlobalGenericParam(constraintParam, constraintVal); - } + // + if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + { + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); } } @@ -6191,19 +6219,19 @@ IRModule* generateIRForTranslationUnit( { auto compileRequest = translationUnit->compileRequest; - SharedIRGenContext sharedContextStorage; + SharedIRGenContext sharedContextStorage( + translationUnit->getSession(), + translationUnit->compileRequest->getSink(), + translationUnit->getModuleDecl()); SharedIRGenContext* sharedContext = &sharedContextStorage; - sharedContext->compileRequest = compileRequest; - sharedContext->mainModuleDecl = translationUnit->SyntaxNode; - IRGenContext contextStorage(sharedContext); IRGenContext* context = &contextStorage; SharedIRBuilder sharedBuilderStorage; SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->module = nullptr; - sharedBuilder->session = compileRequest->mSession; + sharedBuilder->session = compileRequest->getSession(); IRBuilder builderStorage; IRBuilder* builder = &builderStorage; @@ -6224,12 +6252,13 @@ IRModule* generateIRForTranslationUnit( // in case they require special handling. for (auto entryPoint : translationUnit->entryPoints) { - lowerEntryPointToIR(context, entryPoint); + lowerFrontEndEntryPointToIR(context, entryPoint); } + // // Next, ensure that all other global declarations have // been emitted. - for (auto decl : translationUnit->SyntaxNode->Members) + for (auto decl : translationUnit->getModuleDecl()->Members) { ensureAllDeclsRec(context, decl); } @@ -6271,12 +6300,12 @@ IRModule* generateIRForTranslationUnit( // Propagate `constexpr`-ness through the dataflow graph (and the // call graph) based on constraints imposed by different instructions. - propagateConstExpr(module, &compileRequest->mSink); + propagateConstExpr(module, compileRequest->getSink()); // TODO: give error messages if any `undefined` or // `unreachable` instructions remain. - checkForMissingReturns(module, &compileRequest->mSink); + checkForMissingReturns(module, compileRequest->getSink()); // TODO: consider doing some more aggressive optimizations // (in particular specialization of generics) here, so @@ -6293,28 +6322,82 @@ IRModule* generateIRForTranslationUnit( // then we can dump the initial IR for the module here. if(compileRequest->shouldDumpIR) { - ISlangWriter* writer = translationUnit->compileRequest->getWriter(WriterChannel::StdError); - - dumpIR(module, writer); + DiagnosticSinkWriter writer(compileRequest->getSink()); + dumpIR(module, &writer); } return module; } -#if 0 -String emitSlangIRAssemblyForEntryPoint( - EntryPointRequest* entryPoint) +RefPtr<IRModule> generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink) { - auto compileRequest = entryPoint->compileRequest; - auto irModule = lowerEntryPointToIR( - entryPoint, - compileRequest->layout.Ptr(), - // TODO: we need to pick the target more carefully here - CodeGenTarget::HLSL); - - return getSlangIRAssembly(irModule); -} -#endif +// auto compileRequest = translationUnit->compileRequest; + + SharedIRGenContext sharedContextStorage( + session, + sink); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext); + IRGenContext* context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + RefPtr<IRModule> module = builder->createModule(); + sharedBuilder->module = module; + + context->irBuilder = builder; + + // We need to emit symbols for all of the entry + // points in the program; this is especially + // important in the case where a generic entry + // point is being specialized. + // + for(auto entryPoint : program->getEntryPoints()) + { + lowerProgramEntryPointToIR(context, entryPoint); + } + + // Now lower all the arguments supplied for global generic + // type parameters. + // + for (RefPtr<Substitutions> subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) + { + auto gSubst = subst.as<GlobalGenericParamSubstitution>(); + if(!gSubst) + continue; + + IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); + IRType* typeVal = lowerType(context, gSubst->actualType); + + // bind `typeParam` to `typeVal` + builder->emitBindGlobalGenericParam(typeParam, typeVal); + + for (auto& constraintArg : gSubst->constraintArgs) + { + IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); + IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + + // bind `constraintParam` to `constraintVal` + builder->emitBindGlobalGenericParam(constraintParam, constraintVal); + } + } + + // TODO: Should we apply any of the validation or + // mandatory optimization passes here? + + return module; +} } // namespace Slang diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h index bd878d6fa..f607e852c 100644 --- a/source/slang/lower-to-ir.h +++ b/source/slang/lower-to-ir.h @@ -14,7 +14,7 @@ namespace Slang { class CompileRequest; - class EntryPointRequest; + class EntryPoint; class ProgramLayout; class TranslationUnitRequest; @@ -22,5 +22,10 @@ namespace Slang IRModule* generateIRForTranslationUnit( TranslationUnitRequest* translationUnit); + + RefPtr<IRModule> generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink); } #endif diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 8a4cf35b7..65fbfe068 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -41,7 +41,7 @@ struct OptionsParser SlangSession* session = nullptr; SlangCompileRequest* compileRequest = nullptr; - Slang::CompileRequest* requestImpl = nullptr; + Slang::EndToEndCompileRequest* requestImpl = nullptr; Slang::RefPtr<Slang::ConfigurableSharedLibraryLoader> sharedLibraryLoader; @@ -313,7 +313,7 @@ struct OptionsParser if (sourceLanguage == SLANG_SOURCE_LANGUAGE_UNKNOWN) { - requestImpl->mSink.diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); + requestImpl->getSink()->diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); return SLANG_FAIL; } @@ -425,9 +425,9 @@ struct OptionsParser { // Copy some state out of the current request, in case we've been called // after some other initialization has been performed. - flags = requestImpl->compileFlags; + flags = requestImpl->getFrontEndReq()->compileFlags; - DiagnosticSink* sink = &requestImpl->mSink; + DiagnosticSink* sink = requestImpl->getSink(); SlangMatrixLayoutMode defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; @@ -450,23 +450,24 @@ struct OptionsParser } else if(argStr == "-dump-ir" ) { - requestImpl->shouldDumpIR = true; + requestImpl->getFrontEndReq()->shouldDumpIR = true; + requestImpl->getBackEndReq()->shouldDumpIR = true; } else if (argStr == "-serial-ir") { - requestImpl->useSerialIRBottleneck = true; + requestImpl->getFrontEndReq()->useSerialIRBottleneck = true; } else if (argStr == "-verbose-paths") { - requestImpl->mSink.flags |= DiagnosticSink::Flag::VerbosePath; + requestImpl->getSink()->flags |= DiagnosticSink::Flag::VerbosePath; } else if (argStr == "-verify-debug-serial-ir") { - requestImpl->verifyDebugSerialization = true; + requestImpl->getFrontEndReq()->verifyDebugSerialization = true; } else if(argStr == "-validate-ir" ) { - requestImpl->shouldValidateIR = true; + requestImpl->getFrontEndReq()->shouldValidateIR = true; } else if(argStr == "-skip-codegen" ) { @@ -1222,18 +1223,7 @@ struct OptionsParser // Now that we've diagnosed the output paths, we can add them // to the compile request at the appropriate locations. // - // We start by allocating the arrays for per-entry-point output - // paths on each of the requested targets. - // - for(auto rawTarget : rawTargets) - { - auto targetID = rawTarget.targetID; - auto targetReq = requestImpl->targets[targetID]; - - targetReq->entryPointOutputPaths.SetSize(rawEntryPoints.Count()); - } - - // Consider the output files specified via `-o` and try to figure + // We will consider the output files specified via `-o` and try to figure // out how to deal with them. // for(auto& rawOutput : rawOutputs) @@ -1242,18 +1232,26 @@ struct OptionsParser if(rawOutput.entryPointIndex == -1) continue; auto targetID = rawTargets[rawOutput.targetIndex].targetID; - auto entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; + Int entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; + + auto target = requestImpl->getLinkage()->targets[targetID]; + auto entryPointReq = requestImpl->getFrontEndReq()->getEntryPointReqs()[entryPointID]; - auto targetReq = requestImpl->targets[targetID]; + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if( !requestImpl->targetInfos.TryGetValue(target, targetInfo) ) + { + targetInfo = new EndToEndCompileRequest::TargetInfo(); + requestImpl->targetInfos[target] = targetInfo; + } - if(targetReq->entryPointOutputPaths[entryPointID].Length()) + String outputPath; + if( targetInfo->entryPointOutputPaths.ContainsKey(entryPointID) ) { - auto entryPointReq = requestImpl->entryPoints[entryPointID]; - sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->name, targetReq->target); + sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->getName(), target->getTarget()); } else { - targetReq->entryPointOutputPaths[entryPointID] = rawOutput.path; + targetInfo->entryPointOutputPaths[entryPointID] = rawOutput.path; } } @@ -1272,16 +1270,16 @@ SlangResult parseOptions( int argc, char const* const* argv) { - Slang::CompileRequest* compileRequest = (Slang::CompileRequest*) compileRequestIn; + Slang::EndToEndCompileRequest* compileRequest = (Slang::EndToEndCompileRequest*) compileRequestIn; OptionsParser parser; parser.compileRequest = compileRequestIn; parser.requestImpl = compileRequest; - parser.session = (SlangSession*)compileRequest->mSession; + parser.session = (SlangSession*)compileRequest->getSession(); Result res = parser.parse(argc, argv); - DiagnosticSink* sink = &compileRequest->mSink; + DiagnosticSink* sink = compileRequest->getSink(); if (sink->GetErrorCount() > 0) { // Put the errors in the diagnostic diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index f0abfe31c..56c5d7c1d 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -309,9 +309,6 @@ struct ParameterInfo : RefObject ParameterBindingInfo bindingInfo[kLayoutResourceKindCount]; - // The next parameter that has the same name... - ParameterInfo* nextOfSameName; - // The translation unit this parameter is specific to, if any TranslationUnitRequest* translationUnit = nullptr; @@ -335,8 +332,22 @@ struct EntryPointParameterBindingContext // across all translation units struct SharedParameterBindingContext { - // The base compile request - CompileRequest* compileRequest; + SharedParameterBindingContext( + LayoutRulesFamilyImpl* defaultLayoutRules, + ProgramLayout* programLayout, + TargetRequest* targetReq, + DiagnosticSink* sink) + : defaultLayoutRules(defaultLayoutRules) + , programLayout(programLayout) + , targetRequest(targetReq) + , m_sink(sink) + { + } + + DiagnosticSink* m_sink = nullptr; + + // The program that we are laying out +// Program* program = nullptr; // The target request that is triggering layout // @@ -365,20 +376,18 @@ struct SharedParameterBindingContext UInt defaultSpace = 0; TargetRequest* getTargetRequest() { return targetRequest; } + DiagnosticSink* getSink() { return m_sink; } }; static DiagnosticSink* getSink(SharedParameterBindingContext* shared) { - return &shared->compileRequest->mSink; + return shared->getSink(); } // State that might be specific to a single translation unit // or event to an entry point. struct ParameterBindingContext { - // The translation unit we are processing right now - TranslationUnitRequest* translationUnit; - // All the shared state needs to be available SharedParameterBindingContext* shared; @@ -386,7 +395,7 @@ struct ParameterBindingContext // the resource usage of shader parameters. TypeLayoutContext layoutContext; - // A dictionary to accellerate looking up parameters by name + // A dictionary to accelerate looking up parameters by name Dictionary<Name*, ParameterInfo*> mapNameToParameterInfo; // What stage (if any) are we compiling for? @@ -395,9 +404,6 @@ struct ParameterBindingContext // The entry point that is being processed right now. EntryPointLayout* entryPointLayout = nullptr; - // The source language we are trying to use - SourceLanguage sourceLanguage; - TargetRequest* getTargetRequest() { return shared->getTargetRequest(); } LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); } }; @@ -1217,6 +1223,10 @@ static void collectGlobalScopeParameter( // If that is the case, we want to re-use the same `VarLayout` // across both parameters. // + // TODO: This logic currently detects *any* global-scope parameters + // with matching names, but it should eventually be narrowly + // scoped so that it only applies to parameters from unnamed modules. + // // First we look for an existing entry matching the name // of this parameter: auto parameterName = getReflectionName(varDecl); @@ -2477,7 +2487,7 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( /// static void collectEntryPointParameters( ParameterBindingContext* context, - EntryPointRequest* entryPoint, + EntryPoint* entryPoint, SubstitutionSet typeSubst) { DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); @@ -2486,7 +2496,7 @@ static void collectEntryPointParameters( // the `EntryPointLayout` object here. // RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); - entryPointLayout->profile = entryPoint->profile; + entryPointLayout->profile = entryPoint->getProfile(); entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); // The entry point layout must be added to the output @@ -2501,10 +2511,10 @@ static void collectEntryPointParameters( // Note: this isn't really the best place for this logic to sit, // but it is the simplest place where we have a direct correspondence - // between a single `EntryPointRequest` and its matching `EntryPointLayout`, + // between a single `EntryPoint` and its matching `EntryPointLayout`, // so we'll use it. // - for( auto taggedUnionType : entryPoint->taggedUnionTypes ) + for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) { SLANG_ASSERT(taggedUnionType); auto substType = taggedUnionType->Substitute(typeSubst).as<Type>(); @@ -2645,59 +2655,9 @@ static void collectEntryPointParameters( } } -// When doing parameter binding for global-scope stuff in GLSL, -// we may need to know what stage we are compiling for, so that -// we can handle special cases appropriately (e.g., "arrayed" -// inputs and outputs). -static Stage -inferStageForTranslationUnit( - TranslationUnitRequest* translationUnit) -{ - // In the specific case where we are compiling GLSL input, - // and have only a single entry point, use the stage - // of the entry point. - // - // TODO: now that we've dropped official GLSL support, - // we probably should drop this as well. - // - if( translationUnit->sourceLanguage == SourceLanguage::GLSL ) - { - if( translationUnit->entryPoints.Count() == 1 ) - { - return translationUnit->entryPoints[0]->getStage(); - } - } - - return Stage::Unknown; -} - -static void collectModuleParameters( - ParameterBindingContext* inContext, - ModuleDecl* module) -{ - // Each loaded module provides a separate (logical) namespace for - // parameters, so that two parameters with the same name, in - // distinct modules, should yield different bindings. - // - ParameterBindingContext contextData = *inContext; - auto context = &contextData; - - context->translationUnit = nullptr; - - context->stage = Stage::Unknown; - - // All imported modules are implicitly Slang code - context->sourceLanguage = SourceLanguage::Slang; - - // A loaded module cannot define entry points that - // we'll expose (for now), so we just need to - // consider global-scope parameters. - collectGlobalScopeParameters(context, module); -} - static void collectParameters( ParameterBindingContext* inContext, - CompileRequest* request) + Program* program) { // All of the parameters in translation units directly // referenced in the compile request are part of one @@ -2707,29 +2667,21 @@ static void collectParameters( ParameterBindingContext contextData = *inContext; auto context = &contextData; - for( auto& translationUnit : request->translationUnits ) + for(RefPtr<Module> module : program->getModuleDependencies()) { - context->translationUnit = translationUnit; - context->stage = inferStageForTranslationUnit(translationUnit.Ptr()); - context->sourceLanguage = translationUnit->sourceLanguage; + context->stage = Stage::Unknown; // First look at global-scope parameters - collectGlobalScopeParameters(context, translationUnit->SyntaxNode.Ptr()); - - // Next consider parameters for entry points - for( auto& entryPoint : translationUnit->entryPoints ) - { - context->stage = entryPoint->getStage(); - collectEntryPointParameters(context, entryPoint.Ptr(), SubstitutionSet()); - } - context->entryPointLayout = nullptr; + collectGlobalScopeParameters(context, module->getModuleDecl()); } - // Now collect parameters from loaded modules - for (auto& loadedModule : request->loadedModulesList) + // Next consider parameters for entry points + for(auto entryPoint : program->getEntryPoints()) { - collectModuleParameters(context, loadedModule->moduleDecl.Ptr()); + context->stage = entryPoint->getStage(); + collectEntryPointParameters(context, entryPoint, SubstitutionSet()); } + context->entryPointLayout = nullptr; } /// Emit a diagnostic about a uniform parameter at global scope. @@ -2770,41 +2722,40 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte return numUsed; } -void generateParameterBindings( - TargetRequest* targetReq) +RefPtr<ProgramLayout> generateParameterBindings( + TargetProgram* targetProgram, + DiagnosticSink* sink) { - CompileRequest* compileReq = targetReq->compileRequest; + auto program = targetProgram->getProgram(); + auto targetReq = targetProgram->getTargetReq(); + + RefPtr<ProgramLayout> programLayout = new ProgramLayout(); + programLayout->targetProgram = targetProgram; // Try to find rules based on the selected code-generation target - auto layoutContext = getInitialLayoutContextForTarget(targetReq); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout); // If there was no target, or there are no rules for the target, // then bail out here. if (!layoutContext.rules) - return; - - RefPtr<ProgramLayout> programLayout = new ProgramLayout(); - programLayout->targetRequest = targetReq; - - targetReq->layout = programLayout; + return nullptr; // Create a context to hold shared state during the process // of generating parameter bindings - SharedParameterBindingContext sharedContext; - sharedContext.compileRequest = compileReq; - sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); - sharedContext.programLayout = programLayout; - sharedContext.targetRequest = targetReq; + SharedParameterBindingContext sharedContext( + layoutContext.getRulesFamily(), + programLayout, + targetReq, + sink); // Create a sub-context to collect parameters that get // declared into the global scope ParameterBindingContext context; context.shared = &sharedContext; - context.translationUnit = nullptr; context.layoutContext = layoutContext; // Walk through AST to discover all the parameters - collectParameters(&context, compileReq); + collectParameters(&context, program); // Now walk through the parameters to generate initial binding information for( auto& parameter : sharedContext.parameters ) @@ -2978,17 +2929,35 @@ void generateParameterBindings( const int numShaderRecordRegs = _calcTotalNumUsedRegistersForLayoutResourceKind(&context, LayoutResourceKind::ShaderRecord); if (numShaderRecordRegs > 1) { - compileReq->mSink.diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); - return; + sink->diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); } } + return programLayout; +} + +ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) +{ + if( !m_layout ) + { + m_layout = generateParameterBindings(this, sink); + } + return m_layout; +} + +void generateParameterBindings( + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink) +{ + program->getTargetProgram(targetReq)->getOrCreateLayout(sink); } RefPtr<ProgramLayout> specializeProgramLayout( TargetRequest* targetReq, ProgramLayout* oldProgramLayout, - SubstitutionSet typeSubst) + SubstitutionSet typeSubst, + DiagnosticSink* sink) { // The goal of the layout specialization step is to take an existing `ProgramLayout`, // and add a layout to any parameter(s) that could not be laid out previously, because @@ -3006,7 +2975,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( RefPtr<ProgramLayout> newProgramLayout; newProgramLayout = new ProgramLayout(); - newProgramLayout->targetRequest = targetReq; + newProgramLayout->targetProgram = oldProgramLayout->targetProgram; newProgramLayout->globalGenericParams = oldProgramLayout->globalGenericParams; // The basic idea will be to iterate over the parameters in the old layout, @@ -3020,18 +2989,17 @@ RefPtr<ProgramLayout> specializeProgramLayout( // We will use the same kind of context type as the original parameter binding // step did, so we initialize its state here: - auto layoutContext = getInitialLayoutContextForTarget(targetReq); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, newProgramLayout); SLANG_ASSERT(layoutContext.rules); - SharedParameterBindingContext sharedContext; - sharedContext.compileRequest = targetReq->compileRequest; - sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); - sharedContext.programLayout = newProgramLayout; - sharedContext.targetRequest = targetReq; + SharedParameterBindingContext sharedContext( + layoutContext.getRulesFamily(), + newProgramLayout, + targetReq, + sink); ParameterBindingContext context; context.shared = &sharedContext; - context.translationUnit = nullptr; context.layoutContext = layoutContext; // We will also need state for laying out any global-scope parameters @@ -3119,12 +3087,9 @@ RefPtr<ProgramLayout> specializeProgramLayout( // parameter, the layout of its parameter list strictly follows // the declaration order. // - for (auto & translationUnit : targetReq->compileRequest->translationUnits) + for( auto entryPoint : oldProgramLayout->getProgram()->getEntryPoints() ) { - for (auto & entryPoint : translationUnit->entryPoints) - { - collectEntryPointParameters(&context, entryPoint, typeSubst); - } + collectEntryPointParameters(&context, entryPoint, typeSubst); context.entryPointLayout = nullptr; } diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h index eb093821f..82b114021 100644 --- a/source/slang/parameter-binding.h +++ b/source/slang/parameter-binding.h @@ -8,6 +8,7 @@ namespace Slang { +class Program; class TargetRequest; // The parameter-binding interface is responsible for assigning @@ -24,7 +25,9 @@ class TargetRequest; // of the program. void generateParameterBindings( - TargetRequest* targetReq); + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink); } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index e2085eb7d..3abc47ede 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -8,7 +8,7 @@ namespace Slang { - // Pre-declare + // pre-declare static Name* getName(Parser* parser, String const& text); // Helper class useful to build a list of modifiers. @@ -79,7 +79,11 @@ namespace Slang class Parser { public: - TranslationUnitRequest* translationUnit; + NamePool* namePool; + SourceLanguage sourceLanguage; + + NamePool* getNamePool() { return namePool; } + SourceLanguage getSourceLanguage() { return sourceLanguage; } int anonymousCounter = 0; @@ -124,27 +128,26 @@ namespace Slang currentScope = currentScope->parent; } Parser( + Session* session, TokenSpan const& _tokens, DiagnosticSink * sink, RefPtr<Scope> const& outerScope) : tokenReader(_tokens) , sink(sink) , outerScope(outerScope) + , m_session(session) {} Parser(const Parser & other) = default; - Session* getSession() - { - return translationUnit->compileRequest->mSession; - } - RefPtr<ModuleDecl> Parse(); + Session* m_session = nullptr; + Session* getSession() { return m_session; } + Token ReadToken(); Token ReadToken(TokenType type); Token ReadToken(const char * string); bool LookAheadToken(TokenType type, int offset = 0); bool LookAheadToken(const char * string, int offset = 0); void parseSourceFile(ModuleDecl* program); - RefPtr<ModuleDecl> ParseProgram(); RefPtr<Decl> ParseStruct(); RefPtr<ClassDecl> ParseClass(); RefPtr<Stmt> ParseStatement(); @@ -578,11 +581,6 @@ namespace Slang return false; } - RefPtr<ModuleDecl> Parser::Parse() - { - return ParseProgram(); - } - RefPtr<RefObject> ParseTypeDef(Parser* parser, void* /*userData*/) { RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); @@ -694,7 +692,7 @@ namespace Slang Token token(TokenType::Identifier, scopedIdentifier, scopedIdSourceLoc); // Get the name pool - auto namePool = parser->translationUnit->compileRequest->getNamePool(); + auto namePool = parser->getNamePool(); // Since it's an Identifier have to set the name. token.ptrValue = namePool->getName(token.Content); @@ -910,7 +908,7 @@ namespace Slang static Name* getName(Parser* parser, String const& text) { - return parser->translationUnit->compileRequest->getNamePool()->getName(text); + return parser->getNamePool()->getName(text); } static NameLoc expectIdentifier(Parser* parser) @@ -1859,7 +1857,7 @@ namespace Slang } // GLSL allows `[]` directly in a type specifier - if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL) + if (parser->getSourceLanguage() == SourceLanguage::GLSL) { typeExpr = parsePostfixTypeSuffix(parser, typeExpr); } @@ -1929,7 +1927,7 @@ namespace Slang // Just as a safety net, only apply this logic for // a file that is being passed in as "true" Slang code. // - if(parser->translationUnit->sourceLanguage == SourceLanguage::Slang) + if(parser->getSourceLanguage() == SourceLanguage::Slang) { if(typeSpec.decl) { @@ -2313,171 +2311,6 @@ namespace Slang return ParseHLSLBufferDecl(parser, "TextureBuffer"); } - static void removeModifier( - Modifiers& modifiers, - RefPtr<Modifier> modifier) - { - RefPtr<Modifier>* link = &modifiers.first; - while (*link) - { - if (*link == modifier) - { - *link = (*link)->next; - return; - } - - link = &(*link)->next; - } - } - - static RefPtr<Decl> parseGLSLBlockDecl( - Parser* parser, - Modifiers& modifiers) - { - // An GLSL block like this: - // - // uniform Foo { int a; float b; } foo; - // - // is treated as syntax sugar for a type declaration - // and then a global variable declaration using that type: - // - // struct $anonymous { int a; float b; }; - // Block<$anonymous> foo; - // - // where `$anonymous` is a fresh name. - // - // If a "local name" like `foo` is not given, then - // we make the declaration "transparent" so that lookup - // will see through it to the members inside. - - - SourceLoc pos = parser->tokenReader.PeekLoc(); - - // The initial name before the `{` is only supposed - // to be made visible to reflection - auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); - - // Look at the qualifiers present on the block to decide what kind - // of block we are looking at. Also *remove* those qualifiers so - // that they don't interfere with downstream work. - String blockWrapperTypeName; - if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() ) - { - removeModifier(modifiers, uniformMod); - blockWrapperTypeName = "ConstantBuffer"; - } - else if( auto inMod = modifiers.findModifier<InModifier>() ) - { - removeModifier(modifiers, inMod); - blockWrapperTypeName = "__GLSLInputParameterGroup"; - } - else if( auto outMod = modifiers.findModifier<OutModifier>() ) - { - removeModifier(modifiers, outMod); - blockWrapperTypeName = "__GLSLOutputParameterGroup"; - } - else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() ) - { - removeModifier(modifiers, bufferMod); - blockWrapperTypeName = "__GLSLShaderStorageBuffer"; - } - else - { - // Unknown case: just map to a constant buffer and hope for the best - blockWrapperTypeName = "ConstantBuffer"; - } - - // We are going to represent each buffer as a pair of declarations. - // The first is a type declaration that holds all the members, while - // the second is a variable declaration that uses the buffer type. - RefPtr<StructDecl> blockDataTypeDecl = new StructDecl(); - RefPtr<VarDecl> blockVarDecl = new VarDecl(); - - addModifier(blockDataTypeDecl, new ImplicitParameterGroupElementTypeModifier()); - addModifier(blockVarDecl, new ImplicitParameterGroupVariableModifier()); - - // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterGroupReflectionName(); - reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken); - addModifier(blockVarDecl, reflectionNameModifier); - - // Both declarations will have a location that points to the name - parser->FillPosition(blockDataTypeDecl.Ptr()); - parser->FillPosition(blockVarDecl.Ptr()); - - // Generate a unique name for the data type - blockDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.Content)); - - // TODO(tfoley): We end up constructing unchecked syntax here that - // is expected to type check into the right form, but it might be - // cleaner to have a more explicit desugaring pass where we parse - // these constructs directly into the AST and *then* desugar them. - - // Construct a type expression to reference the buffer data type - auto blockDataTypeExpr = new VarExpr(); - blockDataTypeExpr->loc = blockDataTypeDecl->loc; - blockDataTypeExpr->name = blockDataTypeDecl->getName(); - blockDataTypeExpr->scope = parser->currentScope.Ptr(); - - // Construct a type exrpession to reference the type constructor - auto blockWrapperTypeExpr = new VarExpr(); - blockWrapperTypeExpr->loc = pos; - blockWrapperTypeExpr->name = getName(parser, blockWrapperTypeName); - // Always need to look this up in the outer scope, - // so that it won't collide with, e.g., a local variable called `ConstantBuffer` - blockWrapperTypeExpr->scope = parser->outerScope; - - // Construct a type expression that represents the type for the variable, - // which is the wrapper type applied to the data type - auto blockVarTypeExpr = new GenericAppExpr(); - blockVarTypeExpr->loc = blockVarDecl->loc; - blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr; - blockVarTypeExpr->Arguments.Add(blockDataTypeExpr); - - blockVarDecl->type.exp = blockVarTypeExpr; - - // The declarations in the body belong to the data type. - parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr()); - - if( parser->LookAheadToken(TokenType::Identifier) ) - { - // The user gave an explicit name to the block, - // so we need to use that as our variable name - blockVarDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - // TODO: in this case we make actually have a more complex - // declarator, including `[]` brackets. - } - else - { - // synthesize a dummy name - blockVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.Content)); - - // Otherwise we have a transparent declaration, similar - // to an HLSL `cbuffer` - auto transparentModifier = new TransparentModifier(); - transparentModifier->loc = pos; - addModifier(blockVarDecl, transparentModifier); - } - - // Expect a trailing `;` - parser->ReadToken(TokenType::Semicolon); - - // Because we are constructing two declarations, we have a thorny - // issue that were are only supposed to return one. - // For now we handle this by adding the type declaration to - // the current scope manually, and then returning the variable - // declaration. - // - // Note: this means that any modifiers that have already been parsed - // will get attached to the variable declaration, not the type. - // There might be cases where we need to shuffle things around. - - AddMember(parser->currentScope, blockDataTypeDecl); - - return blockVarDecl; - } - static void parseOptionalInheritanceClause(Parser* parser, AggTypeDeclBase* decl) { if (AdvanceIf(parser, TokenType::Colon)) @@ -3020,27 +2853,8 @@ namespace Slang // // - A keyword-based declaration (e.g., `cbuffer ...`) // - The beginning of a type in a declarator-based declaration (e.g., `int ...`) - // - A GLSL block declaration (e.g., `uniform Foo { ... }`) - - // Let's deal with the GLSL block case first. This is something like: - // - // uniform Foo { ... }; - // - // The `uniform` keyword has already been parsed as a modifier, - // so the identifier we are looking at is `Foo`. If the token - // after that is `{`, we assume this is a block. - // - // Of course, we only want to allow this syntax when parsing GLSL... - if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL) - { - if( parser->LookAheadToken(TokenType::LBrace, 1) ) - { - decl = parseGLSLBlockDecl(parser, modifiers); - break; - } - } - // Next we will check whether we can use the identifier token + // First we will check whether we can use the identifier token // as a declaration keyword and parse a declaration using // its associated callback: RefPtr<Decl> parsedDecl; @@ -3184,15 +2998,6 @@ namespace Slang currentScope = nullptr; } - RefPtr<ModuleDecl> Parser::ParseProgram() - { - RefPtr<ModuleDecl> program = new ModuleDecl(); - - parseSourceFile(program.Ptr()); - - return program; - } - RefPtr<Decl> Parser::ParseStruct() { RefPtr<StructDecl> rs = new StructDecl(); @@ -3591,7 +3396,7 @@ namespace Slang // parsing HLSL code. // - bool brokenScoping = translationUnit->sourceLanguage == SourceLanguage::HLSL; + bool brokenScoping = getSourceLanguage() == SourceLanguage::HLSL; // We will create a distinct syntax node class for the unscoped // case, just so that we can correctly handle it in downstream @@ -4439,14 +4244,18 @@ namespace Slang return parsePrefixExpr(this); } - RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + RefPtr<Expr> parseTypeFromSourceFile( + Session* session, TokenSpan const& tokens, DiagnosticSink* sink, - RefPtr<Scope> const& outerScope) + RefPtr<Scope> const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage) { - Parser parser(tokens, sink, outerScope); - parser.translationUnit = translationUnit; + Parser parser(session, tokens, sink, outerScope); parser.currentScope = outerScope; + parser.namePool = namePool; + parser.sourceLanguage = sourceLanguage; return parser.ParseType(); } @@ -4457,12 +4266,11 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope) { - Parser parser(tokens, sink, outerScope); - - parser.translationUnit = translationUnit; - + Parser parser(translationUnit->getSession(), tokens, sink, outerScope); + parser.namePool = translationUnit->getNamePool(); + parser.sourceLanguage = translationUnit->sourceLanguage; - return parser.parseSourceFile(translationUnit->SyntaxNode.Ptr()); + return parser.parseSourceFile(translationUnit->getModuleDecl()); } static void addBuiltinSyntaxImpl( diff --git a/source/slang/parser.h b/source/slang/parser.h index 785b6e345..abad902da 100644 --- a/source/slang/parser.h +++ b/source/slang/parser.h @@ -14,10 +14,13 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope); - RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + RefPtr<Expr> parseTypeFromSourceFile( + Session* session, TokenSpan const& tokens, DiagnosticSink* sink, - RefPtr<Scope> const& outerScope); + RefPtr<Scope> const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage); RefPtr<ModuleDecl> populateBaseLanguageModule( Session* session, diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp index c6c438ef6..103db7dcb 100644 --- a/source/slang/preprocessor.cpp +++ b/source/slang/preprocessor.cpp @@ -194,27 +194,18 @@ struct Preprocessor // represent end-of-input situations. Token endOfFileToken; - // The translation unit that is being parsed - TranslationUnitRequest* translationUnit; + /// The linkage the provides the context for preprocessing + Linkage* linkage = nullptr; + + /// The module, if any, that the preprocessed result will belong to + Module* parentModule = nullptr; // The unique identities of any paths that have issued `#pragma once` directives to // stop them from being included again. HashSet<String> pragmaOnceUniqueIdentities; - TranslationUnitRequest* getTranslationUnit() - { - return translationUnit; - } - - ModuleDecl* getSyntax() - { - return getTranslationUnit()->SyntaxNode.Ptr(); - } - - CompileRequest* getCompileRequest() - { - return getTranslationUnit()->compileRequest; - } + NamePool* getNamePool() { return linkage->getNamePool(); } + SourceManager* getSourceManager() { return linkage->getSourceManager(); } }; // Convenience routine to access the diagnostic sink @@ -255,11 +246,6 @@ static void destroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInput delete inputStream; } -static NamePool* getNamePool(Preprocessor* preprocessor) -{ - return preprocessor->translationUnit->compileRequest->getNamePool(); -} - // Create an input stream to represent a pre-tokenized input file. // TODO(tfoley): pre-tokenizing files isn't going to work in the long run. static PreprocessorInputStream* CreateInputStreamForSource( @@ -272,7 +258,7 @@ static PreprocessorInputStream* CreateInputStreamForSource( initializePrimaryInputStream(preprocessor, inputStream); // initialize the embedded lexer so that it can generate a token stream - inputStream->lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), memoryArena); + inputStream->lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), memoryArena); inputStream->token = inputStream->lexer.lexToken(); return inputStream; @@ -836,7 +822,7 @@ top: // Now re-lex the input - SourceManager* sourceManager = preprocessor->getCompileRequest()->getSourceManager(); + SourceManager* sourceManager = preprocessor->getSourceManager(); // We create a dummy file to represent the token-paste operation PathInfo pathInfo = PathInfo::makeTokenPaste(); @@ -845,7 +831,7 @@ top: SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); Lexer lexer; - lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena()); + lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); initializeInputStream(preprocessor, inputStream); @@ -1564,7 +1550,7 @@ static void HandleEndIfDirective(PreprocessorDirectiveContext* context) // we expect it. // // Most directives do not need to call this directly, since we have -// a catch-all case in the main `HandleDirective()` funciton. +// a catch-all case in the main `HandleDirective()` function. // The `#include` case will call it directly to avoid complications // when it switches the input stream. static void expectEndOfDirective(PreprocessorDirectiveContext* context) @@ -1589,6 +1575,31 @@ static void expectEndOfDirective(PreprocessorDirectiveContext* context) AdvanceRawToken(context->preprocessor); } + /// Read a file in the context of handling a preprocessor directive +static SlangResult readFile( + PreprocessorDirectiveContext* context, + String const& path, + ISlangBlob** outBlob) +{ + // The actual file loading will be handled by the file system + // associated with the parent linkage. + // + auto linkage = context->preprocessor->linkage; + auto fileSystemExt = linkage->getFileSystemExt(); + SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.Buffer(), outBlob)); + + // If we are running the preprocessor as part of compiling a + // specific module, then we must keep track of the file we've + // read as yet another file that the module will depend on. + // + if(auto module = context->preprocessor->parentModule) + { + module->addFilePathDependency(path); + } + + return SLANG_OK; +} + // Handle a `#include` directive static void HandleIncludeDirective(PreprocessorDirectiveContext* context) { @@ -1603,7 +1614,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) auto directiveLoc = GetDirectiveLoc(context); - PathInfo includedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + PathInfo includedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); IncludeHandler* includeHandler = context->preprocessor->includeHandler; if (!includeHandler) @@ -1644,7 +1655,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) // Push the new file onto our stack of input streams // TODO(tfoley): check if we have made our include stack too deep - auto sourceManager = context->preprocessor->getCompileRequest()->getSourceManager(); + auto sourceManager = context->preprocessor->getSourceManager(); // See if this an already loaded source file SourceFile* sourceFile = sourceManager->findSourceFileRecursively(filePathInfo.uniqueIdentity); @@ -1652,7 +1663,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) if (!sourceFile) { ComPtr<ISlangBlob> foundSourceBlob; - if (SLANG_FAILED(includeHandler->readFile(filePathInfo.foundPath, foundSourceBlob.writeRef()))) + if (SLANG_FAILED(readFile(context, filePathInfo.foundPath, foundSourceBlob.writeRef()))) { GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); return; @@ -1843,7 +1854,7 @@ static void HandleLineDirective(PreprocessorDirectiveContext* context) return; } - auto sourceManager = context->preprocessor->translationUnit->compileRequest->getSourceManager(); + auto sourceManager = context->preprocessor->getSourceManager(); String file; if (PeekTokenType(context) == TokenType::EndOfDirective) @@ -1891,7 +1902,7 @@ SLANG_PRAGMA_DIRECTIVE_CALLBACK(handlePragmaOnceDirective) // We are using the 'uniqueIdentity' as determined by the ISlangFileSystemEx interface to determine file identities. auto directiveLoc = GetDirectiveLoc(context); - auto issuedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + auto issuedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); // Must have uniqueIdentity for a #pragma once to work if (!issuedFromPathInfo.hasUniqueIdentity()) @@ -1962,82 +1973,6 @@ static void HandlePragmaDirective(PreprocessorDirectiveContext* context) (subDirective->callback)(context, subDirectiveToken); } -// Handle a `#version` directive -static void handleGLSLVersionDirective(PreprocessorDirectiveContext* context) -{ - Token versionNumberToken; - if(!ExpectRaw( - context, - TokenType::IntegerLiteral, - Diagnostics::expectedTokenInPreprocessorDirective, - &versionNumberToken)) - { - return; - } - - Token glslProfileToken; - if(PeekTokenType(context) == TokenType::Identifier) - { - glslProfileToken = AdvanceToken(context); - } - - // Need to construct a representation taht we can hook into our compilation result - - auto modifier = new GLSLVersionDirective(); - modifier->versionNumberToken = versionNumberToken; - modifier->glslProfileToken = glslProfileToken; - - // Attach the modifier to the program we are parsing! - - addModifier( - context->preprocessor->getSyntax(), - modifier); -} - -// Handle a `#extension` directive, e.g., -// -// #extension some_extension_name : enable -// -static void handleGLSLExtensionDirective(PreprocessorDirectiveContext* context) -{ - Token extensionNameToken; - if(!ExpectRaw( - context, - TokenType::Identifier, - Diagnostics::expectedTokenInPreprocessorDirective, - &extensionNameToken)) - { - return; - } - - if( !ExpectRaw(context, TokenType::Colon, Diagnostics::expectedTokenInPreprocessorDirective) ) - { - return; - } - - Token dispositionToken; - if(!ExpectRaw( - context, - TokenType::Identifier, - Diagnostics::expectedTokenInPreprocessorDirective, - &dispositionToken)) - { - return; - } - - // Need to construct a representation taht we can hook into our compilation result - - auto modifier = new GLSLExtensionDirective(); - modifier->extensionNameToken = extensionNameToken; - modifier->dispositionToken = dispositionToken; - - // Attach the modifier to the program we are parsing! - - addModifier( - context->preprocessor->getSyntax(), - modifier); -} - // Handle an invalid directive static void HandleInvalidDirective(PreprocessorDirectiveContext* context) { @@ -2092,11 +2027,6 @@ static const PreprocessorDirective kDirectives[] = { "line", &HandleLineDirective, 0 }, { "pragma", &HandlePragmaDirective, 0 }, - // TODO(tfoley): These are specific to GLSL, and probably - // shouldn't be enabled for HLSL or Slang - { "version", &handleGLSLVersionDirective, 0 }, - { "extension", &handleGLSLExtensionDirective, 0 }, - { nullptr, nullptr, 0 }, }; @@ -2270,7 +2200,7 @@ static void DefineMacro( PreprocessorMacro* macro = CreateMacro(preprocessor); - auto sourceManager = preprocessor->translationUnit->compileRequest->getSourceManager(); + auto sourceManager = preprocessor->getSourceManager(); SourceFile* keyFile = sourceManager->createSourceFileWithString(pathInfo, key); SourceFile* valueFile = sourceManager->createSourceFileWithString(pathInfo, value); @@ -2280,10 +2210,10 @@ static void DefineMacro( // Use existing `Lexer` to generate a token stream. Lexer lexer; - lexer.initialize(valueView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena()); + lexer.initialize(valueView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); macro->tokens = lexer.lexAllTokens(); - Name* keyName = preprocessor->translationUnit->compileRequest->getNamePool()->getName(key); + Name* keyName = preprocessor->getNamePool()->getName(key); macro->nameAndLoc.name = keyName; macro->nameAndLoc.loc = keyView->getRange().begin; @@ -2321,11 +2251,13 @@ TokenList preprocessSource( DiagnosticSink* sink, IncludeHandler* includeHandler, Dictionary<String, String> defines, - TranslationUnitRequest* translationUnit) + Linkage* linkage, + Module* parentModule) { Preprocessor preprocessor; InitializePreprocessor(&preprocessor, sink); - preprocessor.translationUnit = translationUnit; + preprocessor.linkage = linkage; + preprocessor.parentModule = parentModule; preprocessor.includeHandler = includeHandler; for (auto p : defines) @@ -2333,7 +2265,7 @@ TokenList preprocessSource( DefineMacro(&preprocessor, p.Key, p.Value); } - SourceManager* sourceManager = translationUnit->compileRequest->getSourceManager(); + SourceManager* sourceManager = linkage->getSourceManager(); SourceView* sourceView = sourceManager->createSourceView(file, nullptr); diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h index 4d02cb50b..6e8ac1c69 100644 --- a/source/slang/preprocessor.h +++ b/source/slang/preprocessor.h @@ -8,8 +8,9 @@ namespace Slang { class DiagnosticSink; +class Linkage; +class Module; class ModuleDecl; -class TranslationUnitRequest; // Callback interface for the preprocessor to use when looking // for files in `#include` directives. @@ -20,9 +21,6 @@ struct IncludeHandler const String& pathIncludedFrom, PathInfo& pathInfoOut) = 0; - virtual SlangResult readFile(const String& path, - ISlangBlob** blobOut) = 0; - virtual String simplifyPath(const String& path) = 0; }; @@ -32,7 +30,8 @@ TokenList preprocessSource( DiagnosticSink* sink, IncludeHandler* includeHandler, Dictionary<String, String> defines, - TranslationUnitRequest* translationUnit); + Linkage* linkage, + Module* parentModule); } // namespace Slang diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 2b0be98c9..b40900faf 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -585,10 +585,15 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) { - auto context = convert(reflection); - auto compileRequest = context->targetRequest->compileRequest; + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + // TODO: We should extend this API to support getting error messages + // when type lookup fails. + // + Slang::DiagnosticSink sink; - RefPtr<Type> result = compileRequest->getTypeFromString(name); + RefPtr<Type> result = program->getTypeFromString(name, &sink); return (SlangReflectionType*)result.Ptr(); } @@ -599,12 +604,13 @@ SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( { auto context = convert(reflection); auto type = convert(inType); - auto layoutContext = getInitialLayoutContextForTarget(context->targetRequest); + auto targetReq = context->getTargetReq(); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, context); RefPtr<TypeLayout> result; - if (context->targetRequest->typeLayouts.TryGetValue(type, result)) + if (targetReq->getTypeLayouts().TryGetValue(type, result)) return (SlangReflectionTypeLayout*)result.Ptr(); result = CreateTypeLayout(layoutContext, type); - context->targetRequest->typeLayouts[type] = result; + targetReq->getTypeLayouts()[type] = result; return (SlangReflectionTypeLayout*)result.Ptr(); } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 16dfe8618..7a5b58d07 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -60,6 +60,8 @@ Session::Session() // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); + m_builtinLinkage = new Linkage(this); + // Initialize representations of some very basic types: initializeTypes(); @@ -90,11 +92,12 @@ Session::Session() struct IncludeHandlerImpl : IncludeHandler { - CompileRequest* request; + Linkage* linkage; + SearchDirectoryList* searchDirectories; ISlangFileSystemExt* _getFileSystemExt() { - return request->fileSystemExt; + return linkage->getFileSystemExt(); } SlangResult _findFile(SlangPathType fromPathType, const String& fromPath, const String& path, PathInfo& pathInfoOut) @@ -153,18 +156,22 @@ struct IncludeHandlerImpl : IncludeHandler } // Search all the searchDirectories - for (auto & dir : request->searchDirectories) + for(auto sd = searchDirectories; sd; sd = sd->parent) { - SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); - if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + for(auto& dir : sd->searchDirectories) { - return res; + SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); + if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + { + return res; + } } } return SLANG_E_NOT_FOUND; } +#if 0 virtual SlangResult readFile(const String& path, ISlangBlob** blobOut) override { @@ -175,6 +182,7 @@ struct IncludeHandlerImpl : IncludeHandler return SLANG_OK; } +#endif virtual String simplifyPath(const String& path) override { @@ -192,9 +200,9 @@ struct IncludeHandlerImpl : IncludeHandler // -Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target) +Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) { - auto entryPointProfile = entryPoint->profile; + auto entryPointProfile = entryPoint->getProfile(); auto targetProfile = target->targetProfile; // Depending on the target *format* we might have to restrict the @@ -310,20 +318,13 @@ Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target // -CompileRequest::CompileRequest(Session* session) - : mSession(session) +Linkage::Linkage(Session* session) + : m_session(session) + , m_sourceManager(&m_defaultSourceManager) { getNamePool()->setRootNamePool(session->getRootNamePool()); - setSourceManager(&sourceManagerStorage); - - sourceManager->initialize(session->getBuiltinSourceManager(), nullptr); - - // Set all the default writers - for (int i = 0; i < int(WriterChannel::CountOf); ++i) - { - setWriter(WriterChannel(i), nullptr); - } + m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); setFileSystem(nullptr); } @@ -379,10 +380,61 @@ ComPtr<ISlangBlob> createRawBlob(void const* inData, size_t size) } // +// TargetRequest +// + +Session* TargetRequest::getSession() +{ + return linkage->getSession(); +} MatrixLayoutMode TargetRequest::getDefaultMatrixLayoutMode() { - return compileRequest->getDefaultMatrixLayoutMode(); + return linkage->getDefaultMatrixLayoutMode(); +} + +// +// TranslationUnitRequest +// + +TranslationUnitRequest::TranslationUnitRequest( + FrontEndCompileRequest* compileRequest) + : compileRequest(compileRequest) +{ + module = new Module(compileRequest->getLinkage()); +} + + +Session* TranslationUnitRequest::getSession() +{ + return compileRequest->getSession(); +} + +NamePool* TranslationUnitRequest::getNamePool() +{ + return compileRequest->getNamePool(); +} + +SourceManager* TranslationUnitRequest::getSourceManager() +{ + return compileRequest->getSourceManager(); +} + +void TranslationUnitRequest::addSourceFile(SourceFile* sourceFile) +{ + m_sourceFiles.Add(sourceFile); + + // We want to record that the compiled module has a dependency + // on the path of the source file, but we also need to account + // for cases where the user added a source string/blob without + // an associated path (so that the API passes along an empty + // string). + // + auto path = sourceFile->getPathInfo().foundPath; + if(path.Length()) + { + getModule()->addFilePathDependency(path); + } } @@ -407,7 +459,7 @@ static ISlangWriter* _getDefaultWriter(WriterChannel chan) } } -void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) +void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) { // If the user passed in null, we will use the default writer on that channel m_writers[int(chan)] = writer ? writer : _getDefaultWriter(chan); @@ -415,20 +467,20 @@ void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) // For diagnostic output, if the user passes in nullptr, we set on mSink.writer as that enables buffering on DiagnosticSink if (chan == WriterChannel::Diagnostic) { - mSink.writer = writer; + m_sink.writer = writer; } } -SlangResult CompileRequest::loadFile(String const& path, ISlangBlob** outBlob) +SlangResult Linkage::loadFile(String const& path, ISlangBlob** outBlob) { return fileSystemExt->loadFile(path.Buffer(), outBlob); } -RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope) +RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope) { // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc will be cleaned up SourceManager localSourceManager; - localSourceManager.initialize(sourceManager, nullptr); + localSourceManager.initialize(getSourceManager(), nullptr); Slang::SourceFile* srcFile = localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr); @@ -440,20 +492,20 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio // Use RAII - to make sure everything is reset even if an exception is thrown. struct ScopeReplaceSourceManager { - ScopeReplaceSourceManager(CompileRequest* request, SourceManager* replaceManager): - m_request(request), - m_originalSourceManager(request->getSourceManager()) + ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager): + m_linkage(linkage), + m_originalSourceManager(linkage->getSourceManager()) { - request->setSourceManager(replaceManager); + linkage->setSourceManager(replaceManager); } ~ScopeReplaceSourceManager() { - m_request->setSourceManager(m_originalSourceManager); + m_linkage->setSourceManager(m_originalSourceManager); } private: - CompileRequest* m_request; + Linkage* m_linkage; SourceManager* m_originalSourceManager; }; @@ -465,87 +517,131 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio &sink, nullptr, Dictionary<String,String>(), - translationUnit); + this, + nullptr); - return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope); + return parseTypeFromSourceFile( + getSession(), + tokens, &sink, scope, getNamePool(), SourceLanguage::Slang); } -RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp); -Type* CompileRequest::getTypeFromString(String typeStr) +RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink); + +Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) { + // If we've looked up this type name before, + // then we can re-use it. + // RefPtr<Type> type; - if (types.TryGetValue(typeStr, type)) + if(m_types.TryGetValue(typeStr, type)) return type; - auto translationUnit = translationUnits.First(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + // TODO: This `scopesToTry` idiom appears + // all over the code, and isn't really + // how we should be handling this kind of + // lookup at all. + // List<RefPtr<Scope>> scopesToTry; - for (auto tu : translationUnits) - scopesToTry.Add(tu->SyntaxNode->scope); - for (auto & module : loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); - // parse type name - for (auto & s : scopesToTry) - { - RefPtr<Expr> typeExpr = parseTypeString(translationUnit, + for(auto module : getModuleDependencies()) + scopesToTry.Add(module->getModuleDecl()->scope); + + auto linkage = getLinkage(); + for(auto& s : scopesToTry) + { + RefPtr<Expr> typeExpr = linkage->parseTypeString( typeStr, s); - type = checkProperType(translationUnit, TypeExp(typeExpr)); - if (type) + type = checkProperType(linkage, TypeExp(typeExpr), sink); + if(type) break; } - if (type) + if( type ) { - types[typeStr] = type; + m_types[typeStr] = type; } - return type.Ptr(); + return type; } -void CompileRequest::parseTranslationUnit( +CompileRequestBase::CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) +{} + + +FrontEndCompileRequest::FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink) + : CompileRequestBase(linkage, sink) +{ +} + +void FrontEndCompileRequest::parseTranslationUnit( TranslationUnitRequest* translationUnit) { IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = getLinkage(); + includeHandler.searchDirectories = &searchDirectories; RefPtr<Scope> languageScope; switch (translationUnit->sourceLanguage) { case SourceLanguage::HLSL: - languageScope = mSession->hlslLanguageScope; + languageScope = getSession()->hlslLanguageScope; break; case SourceLanguage::Slang: default: - languageScope = mSession->slangLanguageScope; + languageScope = getSession()->slangLanguageScope; break; } Dictionary<String, String> combinedPreprocessorDefinitions; + for(auto& def : getLinkage()->preprocessorDefinitions) + combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : translationUnit->preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); + auto module = translationUnit->getModule(); RefPtr<ModuleDecl> translationUnitSyntax = new ModuleDecl(); - translationUnit->SyntaxNode = translationUnitSyntax; + translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; + translationUnitSyntax->module = module; + module->setModuleDecl(translationUnitSyntax); - for (auto sourceFile : translationUnit->sourceFiles) + for (auto sourceFile : translationUnit->getSourceFiles()) { auto tokens = preprocessSource( sourceFile, - &mSink, + getSink(), &includeHandler, combinedPreprocessorDefinitions, - translationUnit); + getLinkage(), + module); parseSourceFile( translationUnit, tokens, - &mSink, + getSink(), languageScope); } } -void validateEntryPoints(CompileRequest*); +RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest); -void CompileRequest::checkAllTranslationUnits() +RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq); + +void FrontEndCompileRequest::checkAllTranslationUnits() { // Iterate over all translation units and // apply the semantic checking logic. @@ -553,12 +649,9 @@ void CompileRequest::checkAllTranslationUnits() { checkTranslationUnit(translationUnit.Ptr()); } - - // Next, do follow-up validation on any entry points. - validateEntryPoints(this); } -void CompileRequest::generateIR() +void FrontEndCompileRequest::generateIR() { // Our task in this function is to generate IR code // for all of the declarations in the translation @@ -581,9 +674,9 @@ void CompileRequest::generateIR() if (verifyDebugSerialization) { // Verify debug information - if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, mSession, sourceManager, IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) + if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, getSession(), getSourceManager(), IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) { - mSink.diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); + getSink()->diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); } } @@ -593,7 +686,7 @@ void CompileRequest::generateIR() { // Write IR out to serialData - copying over SourceLoc information directly IRSerialWriter writer; - writer.write(irModule, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); + writer.write(irModule, getSourceManager(), IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); // Destroy irModule such that memory can be used for newly constructed read irReadModule irModule = nullptr; @@ -602,7 +695,7 @@ void CompileRequest::generateIR() { // Read IR back from serialData IRSerialReader reader; - reader.read(serialData, mSession, nullptr, irReadModule); + reader.read(serialData, getSession(), nullptr, irReadModule); } // Set irModule to the read module @@ -610,12 +703,12 @@ void CompileRequest::generateIR() } // Set the module on the translation unit - translationUnit->irModule = irModule; + translationUnit->getModule()->setIRModule(irModule); } } // Try to infer a single common source language for a request -static SourceLanguage inferSourceLanguage(CompileRequest* request) +static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) { SourceLanguage language = SourceLanguage::Unknown; for (auto& translationUnit : request->translationUnits) @@ -639,29 +732,115 @@ static SourceLanguage inferSourceLanguage(CompileRequest* request) return language; } -SlangResult CompileRequest::executeActionsInner() +SlangResult FrontEndCompileRequest::executeActionsInner() { - // Do some cleanup on settings specified by user. - // In particular, we want to propagate flags from the overall request down to - // each translation unit. + // We currently allow GlSL files on the command line so that we can + // drive our "pass-through" mode, but we really want to issue an error + // message if the user is seriously asking us to compile them. for (auto& translationUnit : translationUnits) { - translationUnit->compileFlags |= compileFlags; + switch(translationUnit->sourceLanguage) + { + default: + break; + + case SourceLanguage::GLSL: + getSink()->diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); + return SLANG_FAIL; + } + } + + + // Parse everything from the input files requested + for (auto& translationUnit : translationUnits) + { + parseTranslationUnit(translationUnit.Ptr()); } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Perform semantic checking on the whole collection + checkAllTranslationUnits(); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Look up all the entry points that are expected, + // and use them to populate the `program` member. + // + m_program = createUnspecializedProgram(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + { + // Generate initial IR for all the translation + // units, if we are in a mode where IR is called for. + generateIR(); + } + + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Do parameter binding generation, for each compilation target. + // + for(auto targetReq : getLinkage()->targets) + { + auto targetProgram = m_program->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); + } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + return SLANG_OK; +} + +BackEndCompileRequest::BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program) + : CompileRequestBase(linkage, sink) + , m_program(program) +{} + +EndToEndCompileRequest::EndToEndCompileRequest( + Session* session) + : m_session(session) +{ + m_linkage = new Linkage(session); + + m_sink.sourceManager = m_linkage->getSourceManager(); + + // Set all the default writers + for (int i = 0; i < int(WriterChannel::CountOf); ++i) + { + setWriter(WriterChannel(i), nullptr); + } + + m_frontEndReq = new FrontEndCompileRequest(getLinkage(), getSink()); + + m_backEndReq = new BackEndCompileRequest(getLinkage(), getSink()); +} + +SlangResult EndToEndCompileRequest::executeActionsInner() +{ // If no code-generation target was specified, then try to infer one from the source language, // just to make sure we can do something reasonable when invoked from the command line. - if (targets.Count() == 0) + // + // TODO: This logic should be moved into `options.cpp` or somewhere else + // specific to the command-line tool. + // + if (getLinkage()->targets.Count() == 0) { - auto language = inferSourceLanguage(this); + auto language = inferSourceLanguage(getFrontEndReq()); switch (language) { case SourceLanguage::HLSL: - addTarget(CodeGenTarget::DXBytecode); + getLinkage()->addTarget(CodeGenTarget::DXBytecode); break; case SourceLanguage::GLSL: - addTarget(CodeGenTarget::SPIRV); + getLinkage()->addTarget(CodeGenTarget::SPIRV); break; default: @@ -672,105 +851,117 @@ SlangResult CompileRequest::executeActionsInner() // We only do parsing and semantic checking if we *aren't* doing // a pass-through compilation. // - // Note that we *do* perform output generation as normal in pass-through mode. if (passThrough == PassThroughMode::None) { - // We currently allow GlSL files on the command line so that we can - // drive our "pass-through" mode, but we really want to issue an error - // message if the user is seriously asking us to compile them. - for (auto& translationUnit : translationUnits) - { - switch(translationUnit->sourceLanguage) - { - default: - break; - - case SourceLanguage::GLSL: - mSink.diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); - return SLANG_FAIL; - } - } + SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner()); + } + // If command line specifies to skip codegen, we exit here. + // Note: this is a debugging option. + // + if (shouldSkipCodegen || + ((getFrontEndReq()->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) + { + // We will use the program (and matching layout information) + // that was computed in the front-end for all subsequent + // reflection queries, etc. + // + m_specializedProgram = getUnspecializedProgram(); - // Parse everything from the input files requested - for (auto& translationUnit : translationUnits) - { - parseTranslationUnit(translationUnit.Ptr()); - } - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + return SLANG_OK; + } - // Perform semantic checking on the whole collection - checkAllTranslationUnits(); - if (mSink.GetErrorCount() != 0) + // If codegen is enabled, we need to move along to + // apply any generic specialization that the user asked for. + // + if (passThrough == PassThroughMode::None) + { + m_specializedProgram = createSpecializedProgram(this); + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + // For each code generation target, we will generate specialized + // parameter binding information (taking global generic + // arguments into account at this time). + // + for (auto targetReq : getLinkage()->targets) { - // Generate initial IR for all the translation - // units, if we are in a mode where IR is called for. - generateIR(); + auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); } - - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - - // For each code generation target generate - // parameter binding information. - // This step is done globally, because all translation - // units and entry points need to agree on where - // parameters are allocated. - for (auto targetReq : targets) + } + else + { + // We need to create dummy `EntryPoint` objects + // to make sure that the logic in `generateOutput` + // sees something worth processing. + // + auto specializedProgram = new Program(getLinkage()); + m_specializedProgram = specializedProgram; + for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - generateParameterBindings(targetReq); - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + entryPointReq->getName(), + entryPointReq->getProfile()); + + specializedProgram->addEntryPoint(entryPoint); } } - // If command line specifies to skip codegen, we exit here. - // Note: this is a debugging option. - if (shouldSkipCodegen || - ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) - return SLANG_OK; - // Generate output code, in whatever format was requested + getBackEndReq()->setProgram(getSpecializedProgram()); generateOutput(this); - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; return SLANG_OK; } // Act as expected of the API-based compiler -SlangResult CompileRequest::executeActions() +SlangResult EndToEndCompileRequest::executeActions() { SlangResult res = executeActionsInner(); - mDiagnosticOutput = mSink.outputBuffer.ProduceString(); + mDiagnosticOutput = getSink()->outputBuffer.ProduceString(); return res; } -int CompileRequest::addTranslationUnit(SourceLanguage language, String const&) +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) { UInt result = translationUnits.Count(); - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this); translationUnit->compileRequest = this; translationUnit->sourceLanguage = SourceLanguage(language); + translationUnit->moduleName = moduleName; + translationUnits.Add(translationUnit); return (int) result; } -void CompileRequest::addTranslationUnitSourceFile( +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language) +{ + // We want to ensure that symbols defined in different translation + // units get unique mangled names, so that we can, e.g., tell apart + // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`, + // even when they are being compiled together. + // + String generatedName = "tu"; + generatedName.append(translationUnits.Count()); + return addTranslationUnit(language, getNamePool()->getName(generatedName)); +} + +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile) { - translationUnits[translationUnitIndex]->sourceFiles.Add(sourceFile); + translationUnits[translationUnitIndex]->addSourceFile(sourceFile); } -void CompileRequest::addTranslationUnitSourceBlob( +void FrontEndCompileRequest::addTranslationUnitSourceBlob( int translationUnitIndex, String const& path, ISlangBlob* sourceBlob) @@ -781,7 +972,7 @@ void CompileRequest::addTranslationUnitSourceBlob( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceString( +void FrontEndCompileRequest::addTranslationUnitSourceString( int translationUnitIndex, String const& path, String const& source) @@ -792,7 +983,7 @@ void CompileRequest::addTranslationUnitSourceString( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceFile( +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, String const& path) { @@ -809,7 +1000,7 @@ void CompileRequest::addTranslationUnitSourceFile( if(SLANG_FAILED(result)) { // Emit a diagnostic! - mSink.diagnose( + getSink()->diagnose( SourceLoc(), Diagnostics::cannotOpenFile, path); @@ -820,36 +1011,51 @@ void CompileRequest::addTranslationUnitSourceFile( translationUnitIndex, path, sourceBlob); +} + +int FrontEndCompileRequest::addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile) +{ + auto translationUnitReq = translationUnits[translationUnitIndex]; + + UInt result = m_entryPointReqs.Count(); + + RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest( + this, + translationUnitIndex, + getNamePool()->getName(name), + entryPointProfile); + + m_entryPointReqs.Add(entryPointReq); +// translationUnitReq->entryPoints.Add(entryPointReq); - mDependencyFilePaths.Add(path); + return int(result); } -int CompileRequest::addEntryPoint( +int EndToEndCompileRequest::addEntryPoint( int translationUnitIndex, String const& name, Profile entryPointProfile, List<String> const & genericTypeNames) { - RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest(); - entryPoint->compileRequest = this; - entryPoint->name = getNamePool()->getName(name); - entryPoint->profile = entryPointProfile; - entryPoint->translationUnitIndex = translationUnitIndex; + getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile); + + EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPoint->genericArgStrings.Add(typeName); - auto translationUnit = translationUnits[translationUnitIndex].Ptr(); - translationUnit->entryPoints.Add(entryPoint); + entryPointInfo.genericArgStrings.Add(typeName); UInt result = entryPoints.Count(); - entryPoints.Add(entryPoint); + entryPoints.Add(_Move(entryPointInfo)); return (int) result; } -UInt CompileRequest::addTarget( +UInt Linkage::addTarget( CodeGenTarget target) { RefPtr<TargetRequest> targetReq = new TargetRequest(); - targetReq->compileRequest = this; + targetReq->linkage = this; targetReq->target = target; UInt result = targets.Count(); @@ -857,15 +1063,16 @@ UInt CompileRequest::addTarget( return (int) result; } -void CompileRequest::loadParsedModule( - RefPtr<TranslationUnitRequest> const& translationUnit, - Name* name, - const PathInfo& pathInfo) +void Linkage::loadParsedModule( + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + const PathInfo& pathInfo) { // Note: we add the loaded module to our name->module listing // before doing semantic checking, so that if it tries to // recursively `import` itself, we can detect it. - RefPtr<LoadedModule> loadedModule = new LoadedModule(); + // + RefPtr<Module> loadedModule = translationUnit->getModule(); // Get a path String mostUniqueIdentity = pathInfo.getMostUniqueIdentity(); @@ -874,12 +1081,11 @@ void CompileRequest::loadParsedModule( mapPathToLoadedModule.Add(mostUniqueIdentity, loadedModule); mapNameToLoadedModules.Add(name, loadedModule); - int errorCountBefore = mSink.GetErrorCount(); - checkTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + auto sink = translationUnit->compileRequest->getSink(); - RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; - loadedModule->moduleDecl = moduleDecl; + int errorCountBefore = sink->GetErrorCount(); + checkTranslationUnit(translationUnit.Ptr()); + int errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { @@ -890,39 +1096,56 @@ void CompileRequest::loadParsedModule( // If we didn't run into any errors, then try to generate // IR code for the imported module. SLANG_ASSERT(errorCountAfter == 0); - loadedModule->irModule = generateIRForTranslationUnit(translationUnit); + loadedModule->setIRModule(generateIRForTranslationUnit(translationUnit)); } loadedModulesList.Add(loadedModule); } -RefPtr<ModuleDecl> CompileRequest::loadModule( +Module* Linkage::loadModule(String const& name) +{ + // TODO: We either need to have a diagnostics sink + // get passed into this operation, or associate + // one with the linkage. + // + DiagnosticSink* sink = nullptr; + return findOrImportModule( + getNamePool()->getName(name), + SourceLoc(), + sink); +} + + +RefPtr<Module> Linkage::loadModule( Name* name, const PathInfo& filePathInfo, ISlangBlob* sourceBlob, - SourceLoc const& srcLoc) + SourceLoc const& srcLoc, + DiagnosticSink* sink) { - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); - translationUnit->compileRequest = this; + RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, sink); - // We don't want to use the same options that the user specified - // for loading modules on-demand. In particular, we always want - // semantic checking to be enabled. - // - // TODO: decide which options, if any, should be inherited. - translationUnit->compileFlags = 0; + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq); + translationUnit->compileRequest = frontEndReq; + translationUnit->moduleName = name; + + auto module = translationUnit->getModule(); + + ModuleBeingImportedRAII moduleBeingImported( + this, + module); // Create with the 'friendly' name SourceFile* sourceFile = getSourceManager()->createSourceFileWithBlob(filePathInfo, sourceBlob); - translationUnit->sourceFiles.Add(sourceFile); + translationUnit->addSourceFile(sourceFile); - int errorCountBefore = mSink.GetErrorCount(); - parseTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + int errorCountBefore = sink->GetErrorCount(); + frontEndReq->parseTranslationUnit(translationUnit); + int errorCountAfter = sink->GetErrorCount(); if( errorCountAfter != errorCountBefore ) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); } if (errorCountAfter) { @@ -935,38 +1158,57 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( name, filePathInfo); - errorCountAfter = mSink.GetErrorCount(); + errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); // Something went wrong during the parsing, so we should bail out. return nullptr; } - return translationUnit->SyntaxNode; + return module; +} + +bool Linkage::isBeingImported(Module* module) +{ + for(auto ii = m_modulesBeingImported; ii; ii = ii->next) + { + if(module == ii->module) + return true; + } + return false; } -RefPtr<ModuleDecl> CompileRequest::findOrImportModule( +RefPtr<Module> Linkage::findOrImportModule( Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { // Have we already loaded a module matching this name? - // If so, return it. + // RefPtr<LoadedModule> loadedModule; if (mapNameToLoadedModules.TryGetValue(name, loadedModule)) { + // If the map shows a null module having been loaded, + // then that means there was a prior load attempt, + // but it failed, so we won't bother trying again. + // if (!loadedModule) return nullptr; - if (!loadedModule->moduleDecl) + // If state shows us that the module is already being + // imported deeper on the call stack, then we've + // hit a recursive case, and that is an error. + // + if(isBeingImported(loadedModule)) { // We seem to be in the middle of loading this module - mSink.diagnose(loc, Diagnostics::recursiveModuleImport, name); + sink->diagnose(loc, Diagnostics::recursiveModuleImport, name); return nullptr; } - return loadedModule->moduleDecl; + return loadedModule; } // Derive a file name for the module, by taking the given @@ -991,7 +1233,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // using our ordinary include-handling logic. IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = this; + includeHandler.searchDirectories = &searchDirectories; // Get the original path info PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); @@ -1000,20 +1243,20 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // We have to load via the found path - as that is how file was originally loaded if (SLANG_FAILED(includeHandler.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) { - this->mSink.diagnose(loc, Diagnostics::cannotFindFile, fileName); + sink->diagnose(loc, Diagnostics::cannotFindFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } // Maybe this was loaded previously at a different relative name? if (mapPathToLoadedModule.TryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule)) - return loadedModule->moduleDecl; + return loadedModule; // Try to load it ComPtr<ISlangBlob> fileContents; - if (SLANG_FAILED(includeHandler.readFile(filePathInfo.foundPath, fileContents.writeRef()))) + if(SLANG_FAILED(getFileSystemExt()->loadFile(filePathInfo.foundPath.Buffer(), fileContents.writeRef()))) { - this->mSink.diagnose(loc, Diagnostics::cannotOpenFile, fileName); + sink->diagnose(loc, Diagnostics::cannotOpenFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } @@ -1024,26 +1267,159 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( name, filePathInfo, fileContents, - loc); + loc, + sink); } -Decl * CompileRequest::lookupGlobalDecl(Name * name) +// +// ModuleDependencyList +// + +void ModuleDependencyList::addDependency(Module* module) { - Decl* resultDecl = nullptr; - for (auto module : loadedModulesList) + // If we depend on a module, then we depend on everything it depends on. + // + // Note: We are processing these sub-depenencies before adding the + // `module` itself, so that in the common case a module will always + // appear *after* everything it depends on. + // + // However, this rule is being violated in the compiler right now because + // the modules for hte top-level translation units of a compile request + // will be added to the list first (using `addLeafDependency`) to + // maintain compatibility with old behavior. This may be fixed later. + // + for(auto subDependency : module->getModuleDependencyList()) { - if (module->moduleDecl->memberDictionary.TryGetValue(name, resultDecl)) - break; + _addDependency(subDependency); + } + _addDependency(module); +} + +void ModuleDependencyList::addLeafDependency(Module* module) +{ + _addDependency(module); +} + +void ModuleDependencyList::_addDependency(Module* module) +{ + if(m_moduleSet.Contains(module)) + return; + + m_moduleList.Add(module); + m_moduleSet.Add(module); +} + +// +// FilePathDependencyList +// + +void FilePathDependencyList::addDependency(String const& path) +{ + if(m_filePathSet.Contains(path)) + return; + + m_filePathList.Add(path); + m_filePathSet.Add(path); +} + +void FilePathDependencyList::addDependency(Module* module) +{ + for(auto& path : module->getFilePathDependencyList()) + { + addDependency(path); } - for (auto transUnit : translationUnits) +} + + + +// +// Module +// + +Module::Module(Linkage* linkage) + : m_linkage(linkage) +{} + + +void Module::addModuleDependency(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Module::addFilePathDependency(String const& path) +{ + m_filePathDependencyList.addDependency(path); +} + +// Program + +Program::Program(Linkage* linkage) + : m_linkage(linkage) +{} + +void Program::addReferencedModule(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addReferencedLeafModule(Module* module) +{ + m_moduleDependencyList.addLeafDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addEntryPoint(EntryPoint* entryPoint) +{ + m_entryPoints.Add(entryPoint); + + for(auto module : entryPoint->getModuleDependencies()) { - if (transUnit->SyntaxNode->memberDictionary.TryGetValue(name, resultDecl)) - break; + addReferencedModule(module); } - return resultDecl; } -void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) +RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +{ + if(!m_irModule) + { + m_irModule = generateIRForProgram( + m_linkage->getSession(), + this, + sink); + } + return m_irModule; +} + + +TargetProgram* Program::getTargetProgram(TargetRequest* target) +{ + RefPtr<TargetProgram> targetProgram; + if(!m_targetPrograms.TryGetValue(target, targetProgram)) + { + targetProgram = new TargetProgram(this, target); + m_targetPrograms[target] = targetProgram; + } + return targetProgram; +} + +// +// TargetProgram +// + +TargetProgram::TargetProgram( + Program* program, + TargetRequest* targetReq) + : m_program(program) + , m_targetReq(targetReq) +{ + m_entryPointResults.SetSize(program->getEntryPoints().Count()); +} + +// + +void DiagnosticSink::noteInternalErrorLoc(SourceLoc const& loc) { // Don't consider invalid source locations. if(!loc.isValid()) @@ -1054,14 +1430,19 @@ void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) // code might have confused the compiler. if(internalErrorLocsNoted == 0) { - mSink.diagnose(loc, Diagnostics::noteLocationOfInternalError); + diagnose(loc, Diagnostics::noteLocationOfInternalError); } internalErrorLocsNoted++; } +Session* CompileRequestBase::getSession() +{ + return getLinkage()->getSession(); +} + static const Slang::Guid IID_ISlangFileSystemExt = SLANG_UUID_ISlangFileSystemExt; -void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) +void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) { // Set the fileSystem fileSystem = inFileSystem; @@ -1085,15 +1466,16 @@ void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) } // Set the file system used on the source manager - sourceManager->setFileSystemExt(fileSystemExt); + getSourceManager()->setFileSystemExt(fileSystemExt); } -RefPtr<ModuleDecl> findOrImportModule( - CompileRequest* request, +RefPtr<Module> findOrImportModule( + Linkage* linkage, Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { - return request->findOrImportModule(name, loc); + return linkage->findOrImportModule(name, loc, sink); } void Session::addBuiltinSource( @@ -1101,30 +1483,34 @@ void Session::addBuiltinSource( String const& path, String const& source) { - RefPtr<CompileRequest> compileRequest = new CompileRequest(this); - compileRequest->setSourceManager(getBuiltinSourceManager()); + DiagnosticSink sink; + RefPtr<FrontEndCompileRequest> compileRequest = new FrontEndCompileRequest( + m_builtinLinkage, + &sink); - auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, path); + Name* moduleName = getNamePool()->getName(path); + auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); compileRequest->addTranslationUnitSourceString( translationUnitIndex, path, source); - SlangResult res = compileRequest->executeActions(); + SlangResult res = compileRequest->executeActionsInner(); if (SLANG_FAILED(res)) { - fprintf(stderr, "%s", compileRequest->mDiagnosticOutput.Buffer()); + char const* diagnostics = sink.outputBuffer.Buffer(); + fprintf(stderr, "%s", diagnostics); #ifdef _WIN32 - OutputDebugStringA(compileRequest->mDiagnosticOutput.Buffer()); + OutputDebugStringA(diagnostics); #endif SLANG_UNEXPECTED("error in Slang standard library"); } // Extract the AST for the code we just parsed - auto syntax = compileRequest->translationUnits[translationUnitIndex]->SyntaxNode; + auto syntax = compileRequest->translationUnits[translationUnitIndex]->getModuleDecl(); // HACK(tfoley): mark all declarations in the "stdlib" so // that we can detect them later (e.g., so we don't emit them) @@ -1176,19 +1562,37 @@ Session::~Session() // implementation of C interface -#define SESSION(x) reinterpret_cast<Slang::Session *>(x) -#define REQ(x) reinterpret_cast<Slang::CompileRequest*>(x) +static SlangSession* convert(Slang::Session* session) +{ return reinterpret_cast<SlangSession*>(session); } + +static Slang::Session* convert(SlangSession* session) +{ return reinterpret_cast<Slang::Session*>(session); } + +static SlangCompileRequest* convert(Slang::EndToEndCompileRequest* request) +{ return reinterpret_cast<SlangCompileRequest*>(request); } + +static Slang::EndToEndCompileRequest* convert(SlangCompileRequest* request) +{ return reinterpret_cast<Slang::EndToEndCompileRequest*>(request); } + +static SlangLinkage* convert(Slang::Linkage* linkage) +{ return reinterpret_cast<SlangLinkage*>(linkage); } + +static Slang::Linkage* convert(SlangLinkage* linkage) +{ return reinterpret_cast<Slang::Linkage*>(linkage); } + +static SlangModule* convert(Slang::Module* module) +{ return reinterpret_cast<SlangModule*>(module); } SLANG_API SlangSession* spCreateSession(const char*) { - return reinterpret_cast<SlangSession *>(new Slang::Session()); + return convert(new Slang::Session()); } SLANG_API void spDestroySession( SlangSession* session) { if(!session) return; - delete SESSION(session); + delete convert(session); } SLANG_API void spAddBuiltins( @@ -1196,7 +1600,7 @@ SLANG_API void spAddBuiltins( char const* sourcePath, char const* sourceString) { - auto s = SESSION(session); + auto s = convert(session); s->addBuiltinSource( // TODO(tfoley): Add ability to directly new builtins to the approriate scope @@ -1210,7 +1614,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SlangSession* session, ISlangSharedLibraryLoader* loader) { - auto s = SESSION(session); + auto s = convert(session); if (!loader) { @@ -1237,7 +1641,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader( SlangSession* session) { - auto s = SESSION(session); + auto s = convert(session); return (s->sharedLibraryLoader == Slang::DefaultSharedLibraryLoader::getSingleton()) ? nullptr : s->sharedLibraryLoader.get(); } @@ -1245,7 +1649,7 @@ SLANG_API SlangResult spSessionCheckCompileTargetSupport( SlangSession* session, SlangCompileTarget target) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkCompileTargetSupport(s, Slang::CodeGenTarget(target)); } @@ -1253,16 +1657,45 @@ SLANG_API SlangResult spSessionCheckPassThroughSupport( SlangSession* session, SlangPassThrough passThrough) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkExternalCompilerSupport(s, Slang::PassThroughMode(passThrough)); } + +SLANG_API SlangLinkage* spCreateLinkage( + SlangSession* session) +{ + auto s = convert(session); + auto linkage = new Slang::Linkage(s); + return convert(linkage); +} + +SLANG_API void spDestroyLinkage( + SlangLinkage* linkage) +{ + if(!linkage) return; + auto lnk = convert(linkage); + delete lnk; +} + +SLANG_API SlangModule* spLoadModule( + SlangLinkage* linkage, + char const* moduleName) +{ + if(!linkage) return nullptr; + auto lnk = convert(linkage); + + auto mod = lnk->loadModule(moduleName); + return convert(mod); +} + + SLANG_API SlangCompileRequest* spCreateCompileRequest( SlangSession* session) { - auto s = SESSION(session); - auto req = new Slang::CompileRequest(s); - return reinterpret_cast<SlangCompileRequest*>(req); + auto s = convert(session); + auto req = new Slang::EndToEndCompileRequest(s); + return convert(req); } /*! @@ -1272,7 +1705,7 @@ SLANG_API void spDestroyCompileRequest( SlangCompileRequest* request) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); delete req; } @@ -1281,21 +1714,21 @@ SLANG_API void spSetFileSystem( ISlangFileSystem* fileSystem) { if(!request) return; - REQ(request)->setFileSystem(fileSystem); + convert(request)->getLinkage()->setFileSystem(fileSystem); } SLANG_API void spSetCompileFlags( SlangCompileRequest* request, SlangCompileFlags flags) { - REQ(request)->compileFlags = flags; + convert(request)->getFrontEndReq()->compileFlags = flags; } SLANG_API void spSetDumpIntermediates( SlangCompileRequest* request, int enable) { - REQ(request)->shouldDumpIntermediates = enable != 0; + convert(request)->getBackEndReq()->shouldDumpIntermediates = enable != 0; } SLANG_API void spSetLineDirectiveMode( @@ -1304,13 +1737,13 @@ SLANG_API void spSetLineDirectiveMode( { // TODO: validation - REQ(request)->lineDirectiveMode = Slang::LineDirectiveMode(mode); + convert(request)->getBackEndReq()->lineDirectiveMode = Slang::LineDirectiveMode(mode); } SLANG_API void spSetCommandLineCompilerMode( SlangCompileRequest* request) { - REQ(request)->isCommandLineCompile = true; + convert(request)->isCommandLineCompile = true; } @@ -1318,17 +1751,19 @@ SLANG_API void spSetCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - req->targets.Clear(); - req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets.Clear(); + linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API int spAddCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - return (int) req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + return (int) linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API void spSetTargetProfile( @@ -1336,8 +1771,9 @@ SLANG_API void spSetTargetProfile( int targetIndex, SlangProfileID profile) { - auto req = REQ(request); - req->targets[targetIndex]->targetProfile = Slang::Profile(profile); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetProfile = Slang::Profile(profile); } SLANG_API void spSetTargetFlags( @@ -1345,8 +1781,9 @@ SLANG_API void spSetTargetFlags( int targetIndex, SlangTargetFlags flags) { - auto req = REQ(request); - req->targets[targetIndex]->targetFlags = flags; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetFlags = flags; } SLANG_API void spSetTargetFloatingPointMode( @@ -1354,16 +1791,18 @@ SLANG_API void spSetTargetFloatingPointMode( int targetIndex, SlangFloatingPointMode mode) { - auto req = REQ(request); - req->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); } SLANG_API void spSetMatrixLayoutMode( SlangCompileRequest* request, SlangMatrixLayoutMode mode) { - auto req = REQ(request); - req->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); } SLANG_API void spSetTargetMatrixLayoutMode( @@ -1380,7 +1819,7 @@ SLANG_API void spSetOutputContainerFormat( SlangCompileRequest* request, SlangContainerFormat format) { - auto req = REQ(request); + auto req = convert(request); req->containerFormat = Slang::ContainerFormat(format); } @@ -1389,7 +1828,7 @@ SLANG_API void spSetPassThrough( SlangCompileRequest* request, SlangPassThrough passThrough) { - REQ(request)->passThrough = Slang::PassThroughMode(passThrough); + convert(request)->passThrough = Slang::PassThroughMode(passThrough); } SLANG_API void spSetDiagnosticCallback( @@ -1400,7 +1839,7 @@ SLANG_API void spSetDiagnosticCallback( using namespace Slang; if(!request) return; - auto req = REQ(request); + auto req = convert(request); ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole)); req->setWriter(WriterChannel::Diagnostic, writer); @@ -1412,7 +1851,7 @@ SLANG_API void spSetWriter( ISlangWriter* writer) { if (!request) return; - auto req = REQ(request); + auto req = convert(request); req->setWriter(Slang::WriterChannel(chan), writer); } @@ -1422,15 +1861,17 @@ SLANG_API ISlangWriter* spGetWriter( SlangWriterChannel chan) { if (!request) return nullptr; - auto req = REQ(request); + auto req = convert(request); return req->getWriter(Slang::WriterChannel(chan)); } SLANG_API void spAddSearchPath( - SlangCompileRequest* request, - const char* path) + SlangCompileRequest* request, + const char* path) { - REQ(request)->searchDirectories.Add(Slang::SearchDirectory(path)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->searchDirectories.searchDirectories.Add(Slang::SearchDirectory(path)); } SLANG_API void spAddPreprocessorDefine( @@ -1438,25 +1879,27 @@ SLANG_API void spAddPreprocessorDefine( const char* key, const char* value) { - REQ(request)->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->preprocessorDefinitions[key] = value; } SLANG_API char const* spGetDiagnosticOutput( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); + auto req = convert(request); return req->mDiagnosticOutput.begin(); } SLANG_API SlangResult spGetDiagnosticOutputBlob( - SlangCompileRequest* request, - ISlangBlob** outBlob) + SlangCompileRequest* request, + ISlangBlob** outBlob) { if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); if(!req->diagnosticOutputBlob) { @@ -1475,11 +1918,13 @@ SLANG_API int spAddTranslationUnit( SlangSourceLanguage language, char const* name) { - auto req = REQ(request); + SLANG_UNUSED(name); + + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); - return req->addTranslationUnit( - Slang::SourceLanguage(language), - name ? name : ""); + return frontEndReq->addTranslationUnit( + Slang::SourceLanguage(language)); } SLANG_API void spTranslationUnit_addPreprocessorDefine( @@ -1488,10 +1933,10 @@ SLANG_API void spTranslationUnit_addPreprocessorDefine( const char* key, const char* value) { - auto req = REQ(request); - - req->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + frontEndReq->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; } SLANG_API void spAddTranslationUnitSourceFile( @@ -1500,12 +1945,13 @@ SLANG_API void spAddTranslationUnitSourceFile( char const* path) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!path) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; - req->addTranslationUnitSourceFile( + frontEndReq->addTranslationUnitSourceFile( translationUnitIndex, path); } @@ -1533,14 +1979,15 @@ SLANG_API void spAddTranslationUnitSourceStringSpan( char const* sourceEnd) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBegin) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceString( + frontEndReq->addTranslationUnitSourceString( translationUnitIndex, path, Slang::UnownedStringSlice(sourceBegin, sourceEnd)); @@ -1553,14 +2000,15 @@ SLANG_API void spAddTranslationUnitSourceBlob( ISlangBlob* sourceBlob) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBlob) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceBlob( + frontEndReq->addTranslationUnitSourceBlob( translationUnitIndex, path, sourceBlob); @@ -1584,17 +2032,13 @@ SLANG_API int spAddEntryPoint( char const* name, SlangStage stage) { - if(!request) return -1; - auto req = REQ(request); - if(!name) return -1; - if(translationUnitIndex < 0) return -1; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; - - return req->addEntryPoint( + return spAddEntryPointEx( + request, translationUnitIndex, name, - Slang::Profile(Slang::Stage(stage)), - Slang::List<Slang::String>()); + stage, + 0, + nullptr); } SLANG_API int spAddEntryPointEx( @@ -1606,10 +2050,11 @@ SLANG_API int spAddEntryPointEx( char const ** genericParamTypeNames) { if (!request) return -1; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if (!name) return -1; if (translationUnitIndex < 0) return -1; - if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; + if (Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return -1; Slang::List<Slang::String> typeNames; for (int i = 0; i < genericParamTypeNameCount; i++) typeNames.Add(genericParamTypeNames[i]); @@ -1620,12 +2065,28 @@ SLANG_API int spAddEntryPointEx( typeNames); } +SLANG_API SlangResult spSetGlobalGenericArgs( + SlangCompileRequest* request, + int genericArgCount, + char const** genericArgs) +{ + if (!request) return SLANG_FAIL; + auto req = convert(request); + + auto& genericArgStrings = req->globalGenericArgStrings; + genericArgStrings.Clear(); + for (int i = 0; i < genericArgCount; i++) + genericArgStrings.Add(genericArgs[i]); + + return SLANG_OK; +} + // Compile in a context that already has its translation units specified SLANG_API SlangResult spCompile( SlangCompileRequest* request) { - auto req = REQ(request); + auto req = convert(request); #if !defined(SLANG_DEBUG_INTERNAL_ERROR) // By default we'd like to catch as many internal errors as possible, @@ -1654,7 +2115,7 @@ SLANG_API SlangResult spCompile( // We will print out information on the exception to help out the user // in either filing a bug, or locating what in their code created // a problem. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); } catch (...) { @@ -1662,9 +2123,9 @@ SLANG_API SlangResult spCompile( // `Exception`, so something really fishy is going on. We want to // let the user know that we messed up, so they know to blame Slang // and not some other component in their system. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); } - req->mDiagnosticOutput = req->mSink.outputBuffer.ProduceString(); + req->mDiagnosticOutput = req->getSink()->outputBuffer.ProduceString(); return res; #else // When debugging, we probably don't want to filter out any errors, since @@ -1680,8 +2141,10 @@ spGetDependencyFileCount( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); - return (int) req->mDependencyFilePaths.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return (int) program->getFilePathDependencies().Count(); } /** Get the path to a file this compilation dependend on. @@ -1692,16 +2155,19 @@ spGetDependencyFilePath( int index) { if(!request) return 0; - auto req = REQ(request); - return req->mDependencyFilePaths[index].begin(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return program->getFilePathDependencies()[index].begin(); } SLANG_API int spGetTranslationUnitCount( SlangCompileRequest* request) { - auto req = REQ(request); - return (int) req->translationUnits.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + return (int) frontEndReq->translationUnits.Count(); } // Get the output code associated with a specific translation unit @@ -1718,15 +2184,26 @@ SLANG_API void const* spGetEntryPointCode( int entryPointIndex, size_t* outSize) { - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // TODO: We should really accept a target index in this API - auto targetCount = req->targets.Count(); - if (targetCount == 0) + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) return nullptr; - auto targetReq = req->targets[0]; + auto targetReq = linkage->targets[targetIndex]; - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + + if(entryPointIndex < 0) return nullptr; + if(Slang::UInt(entryPointIndex) >= req->entryPoints.Count()) return nullptr; + auto entryPoint = program->getEntryPoint(entryPointIndex); + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return nullptr; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); void const* data = nullptr; size_t size = 0; @@ -1761,21 +2238,29 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); - int targetCount = (int) req->targets.Count(); + int targetCount = (int) linkage->targets.Count(); if((targetIndex < 0) || (targetIndex >= targetCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - auto targetReq = req->targets[targetIndex]; + auto targetReq = linkage->targets[targetIndex]; int entryPointCount = (int) req->entryPoints.Count(); if((entryPointIndex < 0) || (entryPointIndex >= entryPointCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + auto entryPointReq = program->getEntryPoint(entryPointIndex); + + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return SLANG_FAIL; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); auto blob = result.getBlob(); *outBlob = blob.detach(); @@ -1793,13 +2278,9 @@ SLANG_API void const* spGetCompileRequestCode( SlangCompileRequest* request, size_t* outSize) { - auto req = REQ(request); - - void const* data = req->generatedBytecode.Buffer(); - size_t size = req->generatedBytecode.Count(); - - if(outSize) *outSize = size; - return data; + SLANG_UNUSED(request); + SLANG_UNUSED(outSize); + return nullptr; } // Reflection API @@ -1808,7 +2289,9 @@ SLANG_API SlangReflection* spGetReflection( SlangCompileRequest* request) { if( !request ) return 0; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection @@ -1818,12 +2301,16 @@ SLANG_API SlangReflection* spGetReflection( // so that we can do this better, and make it clear that // `spGetReflection()` is shorthand for `targetIndex == 0`. // - auto targetCount = req->targets.Count(); - if (targetCount == 0) - return 0; - auto targetReq = req->targets[0]; + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) + return nullptr; + + auto targetReq = linkage->targets[targetIndex]; + auto targetProgram = program->getTargetProgram(targetReq); + auto programLayout = targetProgram->getExistingLayout(); - return (SlangReflection*) targetReq->layout.Ptr(); + return (SlangReflection*) programLayout; } // ... rest of reflection API implementation is in `Reflection.cpp` diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h index 9644deae1..3fca323e8 100644 --- a/source/slang/syntax-visitors.h +++ b/source/slang/syntax-visitors.h @@ -6,8 +6,10 @@ namespace Slang { - class CompileRequest; - class EntryPointRequest; + class DiagnosticSink; + class EntryPoint; + class Linkage; + class Module; class ShaderCompiler; class ShaderLinkInfo; class ShaderSymbol; @@ -24,10 +26,11 @@ namespace Slang // Needed by import declaration checking. // // TODO: need a better location to declare this. - RefPtr<ModuleDecl> findOrImportModule( - CompileRequest* request, + RefPtr<Module> findOrImportModule( + Linkage* linkage, Name* name, - SourceLoc const& loc); + SourceLoc const& loc, + DiagnosticSink* sink); } #endif
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 709206278..b1b9f6d80 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -2713,4 +2713,15 @@ RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int return substWitness; } +Module* getModule(Decl* decl) +{ + for( auto dd = decl; dd; dd = dd->ParentDecl ) + { + if(auto moduleDecl = as<ModuleDecl>(dd)) + return moduleDecl->module; + } + return nullptr; +} + + } // namespace Slang diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 6a404214e..5198a44b2 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -12,6 +12,7 @@ namespace Slang { + class Module; class Name; class Session; class Substitutions; @@ -1360,6 +1361,10 @@ namespace Slang Function = 4, All = 7 }; + + /// Get the module that a declaration is associated with, if any. + Module* getModule(Decl* decl); + } // namespace Slang #endif diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 0bc676cf9..95a92ee2c 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -802,7 +802,7 @@ LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule) LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq) { - switch (targetReq->target) + switch (targetReq->getTarget()) { case CodeGenTarget::HLSL: case CodeGenTarget::DXBytecode: @@ -821,12 +821,13 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe } } -TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq) +TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout) { LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq); TypeLayoutContext context; context.targetReq = targetReq; + context.programLayout = programLayout; context.rules = nullptr; context.matrixLayoutMode = targetReq->getDefaultMatrixLayoutMode(); @@ -962,7 +963,7 @@ static bool isOpenGLTarget(TargetRequest*) bool isD3DTarget(TargetRequest* targetReq) { - switch( targetReq->target ) + switch( targetReq->getTarget() ) { case CodeGenTarget::HLSL: case CodeGenTarget::DXBytecode: @@ -978,7 +979,7 @@ bool isD3DTarget(TargetRequest* targetReq) bool isKhronosTarget(TargetRequest* targetReq) { - switch( targetReq->target ) + switch( targetReq->getTarget() ) { default: return false; @@ -1008,7 +1009,7 @@ static bool isSM5OrEarlier(TargetRequest* targetReq) if(!isD3DTarget(targetReq)) return false; - auto profile = targetReq->targetProfile; + auto profile = targetReq->getTargetProfile(); if(profile.getFamily() == ProfileFamily::DX) { @@ -1024,7 +1025,7 @@ static bool isSM5_1OrLater(TargetRequest* targetReq) if(!isD3DTarget(targetReq)) return false; - auto profile = targetReq->targetProfile; + auto profile = targetReq->getTargetProfile(); if(profile.getFamily() == ProfileFamily::DX) { @@ -2102,7 +2103,7 @@ SimpleLayoutInfo GetLayoutImpl( // // The `maybeAdjustLayoutForArrayElementType` computes an "adjusted" // type layout for the element type which takes the array stride into - // acount. If it returns the same type layout that was passed in, + // account. If it returns the same type layout that was passed in, // then that means no adjustement took place. // // The `additionalSpacesNeededForAdjustedElementType` variable counts @@ -2327,13 +2328,35 @@ SimpleLayoutInfo GetLayoutImpl( // we should have already populated ProgramLayout::genericEntryPointParams list at this point, // so we can find the index of this generic param decl in the list genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); + genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); genParamTypeLayout->rules = rules; genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; *outTypeLayout = genParamTypeLayout; } return info; } + else if( auto simpleGenericParam = declRef.as<GenericTypeParamDecl>() ) + { + // A bare generic type parameter can come up during layout + // of a generic entry point (or an entry point nested in + // a generic type). For now we will just pretend like + // the fields of generic parameter type take no space, + // since there is no reasonable way to account for them + // in the resulting layout. + // + // TODO: It might be better to completely ignore generic + // entry points during initial layout, but doing so would + // mean that users couldn't get layout information on + // any parameters, even those that don't depend on + // generics. + // + SimpleLayoutInfo info; + return GetSimpleLayoutImpl( + info, + type, + rules, + outTypeLayout); + } } else if (auto errorType = as<ErrorType>(type)) { diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index e20db7f56..1d939b18f 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -649,6 +649,14 @@ public: RefPtr<VarLayout> globalScopeLayout; */ + /// The target and program for which layout was computed + TargetProgram* targetProgram; + + TargetProgram* getTargetProgram() { return targetProgram; } + TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } + Program* getProgram() { return targetProgram->getProgram(); } + + // We catalog the requested entry points here, // and any entry-point-specific parameter data // will (eventually) belong there... @@ -656,8 +664,6 @@ public: List<RefPtr<GenericParamLayout>> globalGenericParams; Dictionary<String, GenericParamLayout*> globalGenericParamsMap; - - TargetRequest* targetRequest = nullptr; }; StructTypeLayout* getGlobalStructLayout( @@ -804,6 +810,8 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; }; +typedef List<RefPtr<GenericParamLayout>> GenericParamLayouts; + struct TypeLayoutContext { // The layout rules to use (e.g., we compute @@ -812,7 +820,12 @@ struct TypeLayoutContext LayoutRulesImpl* rules; // The target request that is triggering layout - TargetRequest* targetReq; + TargetRequest* targetReq; + + // A parent program layout that will establish the ordering + // of all global generic type parameters. + // + ProgramLayout* programLayout; // Whether to lay out matrices column-major // or row-major. @@ -840,8 +853,13 @@ struct TypeLayoutContext // Get an appropriate set of layout rules (packaged up // as a `TypeLayoutContext`) to perform type layout // for the given target. +// +// The provided `programLayout` is used to establish +// the ordering of all global generic type paramters. +// TypeLayoutContext getInitialLayoutContextForTarget( - TargetRequest* targetReq); + TargetRequest* targetReq, + ProgramLayout* programLayout); // Get the "simple" layout for a type according to a given set of layout // rules. Note that a "simple" layout can only consume one `LayoutResourceKind`, diff --git a/tests/bindings/glsl-parameter-blocks.slang b/tests/bindings/glsl-parameter-blocks.slang index d356df775..ee385e158 100644 --- a/tests/bindings/glsl-parameter-blocks.slang +++ b/tests/bindings/glsl-parameter-blocks.slang @@ -1,4 +1,3 @@ -#version 450 core //TEST:CROSS_COMPILE: -profile ps_5_0 -entry main -target spirv-assembly struct Test diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang index 9a1e9b054..0386a7d10 100644 --- a/tests/compute/global-type-param-in-entrypoint.slang +++ b/tests/compute/global-type-param-in-entrypoint.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_RENDER_COMPUTE: //TEST_INPUT: cbuffer(data=[1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0], stride=16):dxbinding(0),glbinding(0) //TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(1),glbinding(0),out -//TEST_INPUT: type VertImpl +//TEST_INPUT: global_type VertImpl interface IVertInterpolant { diff --git a/tests/ir/string-literal.slang.expected b/tests/ir/string-literal.slang.expected index b86eab2c8..5cbd56aea 100644 --- a/tests/ir/string-literal.slang.expected +++ b/tests/ir/string-literal.slang.expected @@ -1,7 +1,7 @@ result code = 0 standard error = { [entryPoint] -[export("_S04mainp1puV")] +[export("_S3tu04mainp1puV")] [nameHint("main")] func %main : Func(Void, UInt) { diff --git a/tools/gfx/render.h b/tools/gfx/render.h index bfbe0f82a..bc880d0be 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -141,6 +141,7 @@ struct ShaderCompileRequest EntryPoint vertexShader; EntryPoint fragmentShader; EntryPoint computeShader; + Slang::List<Slang::String> globalTypeArguments; Slang::List<Slang::String> entryPointTypeArguments; }; diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index d3a7acf37..4d2791563 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -387,7 +387,8 @@ Result RenderTestApp::initializeShaders(ShaderCompiler* shaderCompiler) compileRequest.computeShader.source = sourceInfo; compileRequest.computeShader.name = computeEntryPointName; } - compileRequest.entryPointTypeArguments = m_shaderInputLayout.globalTypeArguments; + compileRequest.globalTypeArguments = m_shaderInputLayout.globalTypeArguments; + compileRequest.entryPointTypeArguments = m_shaderInputLayout.entryPointTypeArguments; m_shaderProgram = shaderCompiler->compileProgram(compileRequest); return m_shaderProgram ? SLANG_OK : SLANG_FAIL; diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 0257c9b53..8ab14bf76 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -10,6 +10,7 @@ namespace renderer_test { entries.Clear(); globalTypeArguments.Clear(); + entryPointTypeArguments.Clear(); auto lines = Split(source, '\n'); for (auto & line : lines) { @@ -25,6 +26,14 @@ namespace renderer_test StringBuilder typeExp; while (!parser.IsEnd()) typeExp << parser.ReadToken().Content; + entryPointTypeArguments.Add(typeExp); + } + else if (parser.LookAhead("global_type")) + { + parser.ReadToken(); + StringBuilder typeExp; + while (!parser.IsEnd()) + typeExp << parser.ReadToken().Content; globalTypeArguments.Add(typeExp); } else diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index 92dd516a7..be86d971f 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -73,6 +73,7 @@ class ShaderInputLayout public: Slang::List<ShaderInputLayoutEntry> entries; Slang::List<Slang::String> globalTypeArguments; + Slang::List<Slang::String> entryPointTypeArguments; int numRenderTargets = 1; void Parse(const char * source); }; diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 1dc7323a5..0bf086d43 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -92,16 +92,25 @@ RefPtr<ShaderProgram> ShaderCompiler::compileProgram( RefPtr<ShaderProgram> shaderProgram; - Slang::List<const char*> rawTypeNames; + + Slang::List<const char*> rawGlobalTypeNames; + for (auto typeName : request.globalTypeArguments) + rawGlobalTypeNames.Add(typeName.Buffer()); + spSetGlobalGenericArgs( + slangRequest, + (int)rawGlobalTypeNames.Count(), + rawGlobalTypeNames.Buffer()); + + Slang::List<const char*> rawEntryPointTypeNames; for (auto typeName : request.entryPointTypeArguments) - rawTypeNames.Add(typeName.Buffer()); + rawEntryPointTypeNames.Add(typeName.Buffer()); if (request.computeShader.name) { int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit, computeEntryPointName, SLANG_STAGE_COMPUTE, - (int)rawTypeNames.Count(), - rawTypeNames.Buffer()); + (int)rawEntryPointTypeNames.Count(), + rawEntryPointTypeNames.Buffer()); spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE); const SlangResult res = spCompile(slangRequest); @@ -129,8 +138,8 @@ RefPtr<ShaderProgram> ShaderCompiler::compileProgram( } else { - int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX, (int)rawTypeNames.Count(), rawTypeNames.Buffer()); - int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT, (int)rawTypeNames.Count(), rawTypeNames.Buffer()); + int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX, (int)rawEntryPointTypeNames.Count(), rawEntryPointTypeNames.Buffer()); + int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT, (int)rawEntryPointTypeNames.Count(), rawEntryPointTypeNames.Buffer()); const SlangResult res = spCompile(slangRequest); if (auto diagnostics = spGetDiagnosticOutput(slangRequest)) |
