diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2019-02-15 09:08:19 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-02-15 09:08:19 -0800 |
| commit | a3fd4e2bc40cfc77db953b14744c30e7a18e7c1d (patch) | |
| tree | 5c226a6a4304086412c051f642a5f45fb043083c | |
| parent | 4cd317bcae0a13dc2bbb78448c8d60cd1dcc76bd (diff) | |
Split front- and back-ends (#846)
* Split front- and back-ends
This change is a major refactor of several of the types that provide the behind-the-scenes implementation of the public C API.
The goal of this refactor is primarily to allow for future API services that let the user operate both the front- and back-ends of the compiler in a more complex fashion.
For example, as user should be able to compile a bunch of source code into modules, look up types, functions, etc. in those modules, specialize generic types/functions to the types they've looked up, and then finally request target code to be gernerated for specialized entry points.
The back-end code generation they trigger should re-use the front-end compilation work (parsing, semantic checking, IR generation) that was already performed.
The most visible change is that `CompileRequest` has been split up into several smaller types that take responsibility for parts of what it did:
* The `Linkage` type owns the storage for `import`ed modules, and well as the `TargetRequest`s that represent code-generation targets. The intention is that an application could use a single `Linkage` for the duration of its runtime (so long as it was okay with the memory usage), so that each `import`ed module only gets loaded once. For now, this type needs to manage the search paths, file system, and source manager, because of its responsibility for loading files.
* A `FrontEndCompileRequest` owns the stuff related to parsing, semantic checking, and initial IR generation. This most notably includes the `TranslationUnitRequest`s and the `FrontEndEntryPointRequest`s (which used to be just `EntryPointRequest`s). It's main job is to produce AST and IR modules for each translation unit, and to find and validate the entry points. The front-end request does *not* interact with generic arguments for global or entry-point generic parameters.
* The main output of both `import` operations and front-end translation units is the `Module` type, which is just a simple container for both the AST module (to service the reflection/layout APIs, and also for semantic checking of code that `import`s the module) and the IR module (for linking and code generation). This type captures the commonalities between the old `LoadedModule` (which is now just an alias for `Module`) and `TranslationUnitRequest` (which now owns a `Module`).
* The secondary output of front-end compilation is a `Program`, which comprises a list of referenced `Module`s and validated `EntryPoint`s that will be used together. Layout and code generation both need a `Program` to tell them what modules and entry points will be used together (we don't want to just code-gen everythin that has ever been loaded into the linakge). The `Program`s created by the front-end do not include generic arguments, so they may provide incomplete layout information and/or be unsuitable for code generation.
* A `BackEndCompileRequest` owns stuff related to turning a `Program` into output kernels for the targets of a `Linkage`. Most of the data it owns beyond the `Program` to be compiled is minor, so this is a good candidate for demotion from a heap-allocated object to just a `struct` of options that gets passed around.
* The `CompileRequestBase` type is an attempt to wrap up the common functionality of both front-end and back-end compile requests. Most of it is just exposing the availability of a linkage and `DiagnosticSink`, so this type is a good candidate for subsequent removal. The main interesting thing it has is the flags related to dumping and validation of IR, so there is probably a good refactoring still to be made around deciding how options should be handled going forward.
* Behind the scenes, the `Program` type is set up to handle some level of on-line compilation and layout work. The `Program` knows the `Linkage` it belongs to, and allows for a `TargetProgram` to be looked up based on a specific `TargetRequest`. A `TargetProgram` then allows layout information and compiled kernel code to be asked for on-demand, in order to support eventual "live" compilation scenarios.
* The `EndToEndCompileRequest` type is a composition/coordination type that replaces the old `CompileRequest` in a way that uses the services of the various other types. It owns a few pieces of state that only make sense in the context of an end-to-end compile (e.g., there is really no way to "pass through" code when the front- and back-ends are run separately) or a command-line compile (everything to do with specifying output paths for files is really just for the benefit of `slangc`, and might even be moved there over time).
* One important detail is that the `EndToEndCompilRequest` owns all of the string-based generic arguments for both global and entry-point generic parameters. The logic in `check.cpp` for dealing with those arguments has been heavily refactored to separate out the parsings steps that are specific to end-to-end compilation with string-based type arguments, and the semantic checking steps that result in a specialized `Program` (which can be exposed through new APIs that aren't tied to end-to-end compilation).
It is perhaps not surprising that this change had a lot of consequences, so I'll briefly run over some of the main categories of changes required:
* I changed the way that global generic arguments are passed via API (use `spSetGlobalGenericArgs` instead of the generic arguments for `spAddEntryPointEx`, which are not just for entry-point generics), which has been a change that we've needed for a long time. This is technically a breaking API change, although we should have very few client applications that care about it.
* A bunch of places that used to take "big" objects like `CompileRequest` now just take the sub-pieces they care about (e.g., a function might have only needed a `Linkage` and a `DiagnosticSink`). This makes many subroutines or "context" struct types more generally useful, at the cost of taking more parameters.
* In a few cases the conceptually clean separation of the layers breaks down (often for edge-case or compatibility features), and so we may pass along additional objects that are allowed to be null, but are used when present. A big example of this is how the back-end code generation routines accept an `EndToEndCompileRequest` that is optional, and only used to check whether "pass through" compilation is needed. We should probably look into cleaning this kind of logic up over time so that we don't need to violate the apparent separation of phases of compilation.
* In cases where separation of layers was being broken for the sake of GLSL features, I went ahead and ripped them out, since all of that should be dead code anyway.
* In many cases I increased the encapsulation of data in the core types to help track down use sites and make sure they are following invariants better.
* In cases where code was doing, e.g., `context->shared->compileRequest->session->getThing()` I have tried to introduce convenience routines so that the usage site is just `context->getThing()` to improve encapsulation and allow changes to be made more easily going forward.
* The `noteInternalErrorLoc` functionality was moved off of the compile request and into `DiagnosticSink`, since that is the one type you can rely on having around when you want to note an internal error. We may consider going forward if (and how) it should reset the counter used for noting locations on internal errors.
* A few APIs now take `DiagnosticSink*` arguments where they didn't before, and as a result some public APIs need to create `DiagnosticSink`s to pass in, before going ahead and ignoring the messages. In the future there should be variations of these APIs that accept an `ISlangBlob**` parameter for the output.
* fixup: missing include for compilers with accurate template checking (non-VS)
* fixup: review feedback
47 files changed, 3294 insertions, 1905 deletions
diff --git a/examples/model-viewer/main.cpp b/examples/model-viewer/main.cpp index c1f7980cf..8bbc8c979 100644 --- a/examples/model-viewer/main.cpp +++ b/examples/model-viewer/main.cpp @@ -881,22 +881,27 @@ RefPtr<EffectVariant> createEffectVaraint( int translationUnitIndex = spAddTranslationUnit(slangRequest, SLANG_SOURCE_LANGUAGE_SLANG, nullptr); spAddTranslationUnitSourceFile(slangRequest, translationUnitIndex, program->shaderModule->inputPath.c_str()); - const int entryPointCont = int(program->entryPoints.size()); - for(int ii = 0; ii < entryPointCont; ++ii) + // Because our shader code uses global generic parameters for + // specialization, we need to specify the concrete argument + // types for the compiler to use when generating code. + // + spSetGlobalGenericArgs( + slangRequest, + int(genericArgs.size()), + genericArgs.data()); + + // Next we tell the Slang compiler about all of the entry points + // we plan to use. + // + const int entryPointCount = int(program->entryPoints.size()); + for(int ii = 0; ii < entryPointCount; ++ii) { auto entryPoint = program->entryPoints[ii]; - - // We are using the `spAddEntryPointEx` API so that we - // can specify the type names to use for the generic - // type parameters of the program. - // - spAddEntryPointEx( + spAddEntryPoint( slangRequest, translationUnitIndex, entryPoint->name.c_str(), - entryPoint->slangStage, - int(genericArgs.size()), - genericArgs.data()); + entryPoint->slangStage); } // We expect compilation to go through without a hitch, because the @@ -923,7 +928,7 @@ RefPtr<EffectVariant> createEffectVaraint( // std::vector<ISlangBlob*> kernelBlobs; std::vector<gfx::ShaderProgram::KernelDesc> kernelDescs; - for(int ii = 0; ii < entryPointCont; ++ii) + for(int ii = 0; ii < entryPointCount; ++ii) { auto entryPoint = program->entryPoints[ii]; diff --git a/external/glslang b/external/glslang -Subproject f6e7c4d2de0d59724ea07739df70c466d169a2c +Subproject 4207c97b938078818140edad101a032cf768191 @@ -923,6 +923,9 @@ extern "C" */ typedef struct SlangSession SlangSession; + typedef struct SlangLinkage SlangLinkage; + typedef struct SlangModule SlangModule; + /*! @brief A request for one or more compilation actions to be performed. */ @@ -989,6 +992,20 @@ extern "C" char const* sourcePath, char const* sourceString); + + + SLANG_API SlangLinkage* spCreateLinkage( + SlangSession* session); + + SLANG_API void spDestroyLinkage( + SlangLinkage* linkage); + + SLANG_API SlangModule* spLoadModule( + SlangLinkage* linkage, + char const* moduleName); + + + /*! @brief Create a compile request. */ @@ -1263,15 +1280,22 @@ extern "C" /** Add an entry point in a particular translation unit, with additional arguments that specify the concrete - type names for global generic type parameters. + type names for entry-point generic type parameters. */ SLANG_API int spAddEntryPointEx( SlangCompileRequest* request, int translationUnitIndex, char const* name, SlangStage stage, - int genericTypeNameCount, - char const** genericTypeNames); + int genericArgCount, + char const** genericArgs); + + /** Specify the arguments to use for global generic parameters. + */ + SLANG_API SlangResult spSetGlobalGenericArgs( + SlangCompileRequest* request, + int genericArgCount, + char const** genericArgs); /** Execute the compilation request. diff --git a/source/core/smart-pointer.h b/source/core/smart-pointer.h index e19ed6a4d..0b03deb8f 100644 --- a/source/core/smart-pointer.h +++ b/source/core/smart-pointer.h @@ -2,6 +2,7 @@ #define FUNDAMENTAL_LIB_SMART_POINTER_H #include "common.h" +#include "hash.h" #include "type-traits.h" #include <assert.h> @@ -157,10 +158,14 @@ namespace Slang releaseReference(old); } - int GetHashCode() - { - return (int)(long long)(void*)pointer; - } + int GetHashCode() + { + // Note: We need a `RefPtr<T>` to hash the same as a `T*`, + // so that a `T*` can be used as a key in a dictionary with + // `RefPtr<T>` keys, and vice versa. + // + return Slang::GetHashCode(pointer); + } bool operator==(const T * ptr) const { diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 3485afeea..483db60bb 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -400,22 +400,18 @@ namespace Slang else return DeclCheckState::CheckedHeader; } - DiagnosticSink* sink = nullptr; + + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + DiagnosticSink* getSink() { - return sink; + return m_sink; } // ModuleDecl * program = nullptr; FuncDecl * function = nullptr; - CompileRequest* request = nullptr; - TranslationUnitRequest* translationUnit = nullptr; - - SourceLanguage getSourceLanguage() - { - return translationUnit->sourceLanguage; - } // lexical outer statements List<Stmt*> outerStmts; @@ -429,20 +425,15 @@ namespace Slang public: SemanticsVisitor( - DiagnosticSink* sink, - CompileRequest* request, - TranslationUnitRequest* translationUnit) - : sink(sink) - , request(request) - , translationUnit(translationUnit) - { - } + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) + {} - CompileRequest* getCompileRequest() { return request; } - TranslationUnitRequest* getTranslationUnit() { return translationUnit; } Session* getSession() { - return getCompileRequest()->mSession; + return m_linkage->getSession(); } public: @@ -985,7 +976,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(decl->loc); + getSink()->noteInternalErrorLoc(decl->loc); throw; } } @@ -998,7 +989,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(stmt->loc); + getSink()->noteInternalErrorLoc(stmt->loc); throw; } } @@ -1011,7 +1002,7 @@ namespace Slang catch(AbortCompilationException&) { throw; } catch(...) { - getCompileRequest()->noteInternalErrorLoc(expr->loc); + getSink()->noteInternalErrorLoc(expr->loc); throw; } } @@ -1030,7 +1021,7 @@ namespace Slang // being checked on the stack, so that we can report the full // chain that leads from this declaration back to itself. // - sink->diagnose(decl, Diagnostics::cyclicReference, decl); + getSink()->diagnose(decl, Diagnostics::cyclicReference, decl); return; } @@ -1050,7 +1041,7 @@ namespace Slang // TODO: This diagnostic should be emitted on the line that is referencing // the declaration. That requires `EnsureDecl` to take the requesting // location as a parameter. - sink->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); + getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl); return; } } @@ -3019,7 +3010,7 @@ namespace Slang checkDecl(func); } - if (sink->GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return; // Force everything to be fully checked, just in case @@ -4921,9 +4912,12 @@ namespace Slang return new ConstantIntVal(expr->value); } + Linkage* getLinkage() { return m_linkage; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + Name* getName(String const& text) { - return getCompileRequest()->getNamePool()->getName(text); + return getNamePool()->getName(text); } RefPtr<IntVal> TryConstantFoldExpr( @@ -5079,58 +5073,18 @@ namespace Slang { auto varDecl = varRef.getDecl(); - switch(getSourceLanguage()) + // In HLSL, `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) { - default: - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - // HLSL: `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->FindModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = getInitExpr(varRef)) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - break; - - case SourceLanguage::GLSL: - // GLSL: `const` indicates compile-time constant expression - // - // TODO(tfoley): The current logic here isn't robust against - // GLSL "specialization constants" - we will extract the - // initializer for a `const` variable and use it to extract - // a value, when we really should be using an opaque - // reference to the variable. if(auto constAttr = varDecl->FindModifier<ConstModifier>()) { - // We need to handle a "specialization constant" (with a `constant_id` layout modifier) - // differently from an ordinary compile-time constant. The latter can/should be reduced - // to a value, while the former should be kept as a symbolic reference - - if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) - { - // Retain the specialization constant as a symbolic reference - // - // TODO(tfoley): handle the case of non-`int` value parameters... - // - // TODO(tfoley): this is cloned from the case above that handles generic value parameters - return new GenericParamIntVal(varRef); - } - else if(auto initExpr = getInitExpr(varRef)) + // HLSL `static const` can be used as a constant expression + if(auto initExpr = getInitExpr(varRef)) { - // This is an ordinary constant, and not a specialization constant, so we - // can try to fold its value right now. return TryConstantFoldExpr(initExpr.Ptr()); } } - break; } - } else if(auto enumRef = declRef.as<EnumCaseDecl>()) { @@ -9060,17 +9014,32 @@ namespace Slang auto scope = decl->scope; // Try to load a module matching the name - auto importedModuleDecl = findOrImportModule(request, name, decl->moduleNameAndLoc.loc); + auto importedModule = findOrImportModule( + getLinkage(), + name, + decl->moduleNameAndLoc.loc, + getSink()); // If we didn't find a matching module, then bail out - if (!importedModuleDecl) + if (!importedModule) return; // Record the module that was imported, so that we can use // it later during code generation. + auto importedModuleDecl = importedModule->getModuleDecl(); decl->importedModuleDecl = importedModuleDecl; - importModuleIntoScope(scope.Ptr(), importedModuleDecl.Ptr()); + // Add the declarations from the imported module into the scope + // that the `import` declaration is set to extend. + // + importModuleIntoScope(scope.Ptr(), importedModuleDecl); + + // Record the `import`ed module (and everything it depends on) + // as a dependency of the module we are compiling. + if(auto module = getModule(decl)) + { + module->addModuleDependency(importedModule); + } decl->SetCheckState(getCheckedState()); } @@ -9142,29 +9111,25 @@ namespace Slang return (!decl->primaryDecl) || (decl == decl->primaryDecl); } - RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp) + RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink) { - RefPtr<Type> type; - DiagnosticSink nSink; - nSink.sourceManager = tu->compileRequest->sourceManager; SemanticsVisitor visitor( - &nSink, - tu->compileRequest, - tu); + linkage, + sink); auto typeOut = visitor.CheckProperType(typeExp); - if (!nSink.errorCount) - { - type = typeOut.type; - } - return type; + return typeOut.type; } - FuncDecl* findFunctionDeclByName(EntryPointRequest* entryPoint, Name* name) + FuncDecl* findFunctionDeclByName( + Module* translationUnit, + Name* name, + DiagnosticSink* sink) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + auto translationUnitSyntax = translationUnit->getModuleDecl(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9270,7 +9235,9 @@ namespace Slang // Validate that an entry point function conforms to any additional // constraints based on the stage (and profile?) it specifies. void validateEntryPoint( - EntryPointRequest* entryPoint) + FuncDecl* entryPointFuncDecl, + Stage stage, + DiagnosticSink* sink) { // TODO: We currently do minimal checking here, but this is the // right place to perform the following validation checks: @@ -9297,28 +9264,32 @@ namespace Slang // that function is specific to the fragment profile/stage. // - auto sink = &entryPoint->compileRequest->mSink; + auto entryPointName = entryPointFuncDecl->getName(); + + auto module = getModule(entryPointFuncDecl); + auto linkage = module->getLinkage(); + // Every entry point needs to have a stage specified either via // command-line/API options, or via an explicit `[shader("...")]` attribute. // - if( entryPoint->getStage() == Stage::Unknown ) + if( stage == Stage::Unknown ) { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name); + sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName); } - if (entryPoint->getStage() == Stage::Hull) + if( stage == Stage::Hull ) { - auto translationUnit = entryPoint->getTranslationUnit(); - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We could consider *always* checking any `[patchconsantfunc("...")]` + // attributes, so that they need to resolve to a function. - auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>(); + auto attr = entryPointFuncDecl->FindModifier<PatchConstantFuncAttribute>(); if (attr) { if (attr->args.Count() != 1) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } @@ -9327,40 +9298,52 @@ namespace Slang if (!stringLit) { - sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name); + sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName); return; } - Name* name = entryPoint->compileRequest->getNamePool()->getName(stringLit->value); - FuncDecl* funcDecl = findFunctionDeclByName(entryPoint, name); - if (!funcDecl) + // We look up the patch-constant function by its name in the module + // scope of the translation unit that declared the HS entry point. + // + // TODO: Eventually we probably want to do the lookup in the scope + // of the parent declarations of the entry point. E.g., if the entry + // point is a member function of a `struct`, then its patch-constant + // function should be allowed to be another member function of + // the same `struct`. + // + // In the extremely long run we may want to support an alternative to + // this attribute-based linkage between the two functions that + // make up the entry point. + // + Name* name = linkage->getNamePool()->getName(stringLit->value); + FuncDecl* patchConstantFuncDecl = findFunctionDeclByName( + module, + name, + sink); + if (!patchConstantFuncDecl) { - sink->diagnose(translationUnitSyntax, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); + sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc"); return; } - attr->patchConstantFuncDecl = funcDecl; + attr->patchConstantFuncDecl = patchConstantFuncDecl; } } - else if (entryPoint->getStage() == Stage::Compute) + else if(stage == Stage::Compute) { - auto funcDecl = entryPoint->getFuncDecl(); - - auto params = funcDecl->GetParameters(); - - for (const auto& param : params) + for(const auto& param : entryPointFuncDecl->GetParameters()) { - if (auto semantic = param->FindModifier<HLSLSimpleSemantic>()) + if(auto semantic = param->FindModifier<HLSLSimpleSemantic>()) { const auto& semanticToken = semantic->name; String lowerName = String(semanticToken.Content).ToLower(); - if (lowerName == "sv_dispatchthreadid") + if(lowerName == "sv_dispatchthreadid") { Type* paramType = param->getType(); - if (!isValidThreadDispatchIDType(paramType)) + if(!isValidThreadDispatchIDType(paramType)) { String typeString = paramType->ToString(); sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString); @@ -9372,26 +9355,30 @@ namespace Slang } } - // Given an `EntryPointRequest` specified via API or command line options, + // Given an entry point specified via API or command line options, // attempt to find a matching AST declaration that implements the specified // entry point. If such a function is found, then validate that it actually // meets the requirements for the selected stage/profile. // - void findAndValidateEntryPoint( - EntryPointRequest* entryPoint) + // Returns an `EntryPoint` object representing the (unspecialized) + // entry point if it is found and validated, and null otherwise. + // + RefPtr<EntryPoint> findAndValidateEntryPoint( + FrontEndEntryPointRequest* entryPointReq) { // The first step in validating the entry point is to find // the (unique) function declaration that matches its name. // - // TODO: We will eventually need to update this logic - // to work by parsing the provided `entryPoint->name` string - // as an expression, so that we can handle more complex - // names like `foo<int>` or `SomeType.vs`. - - auto translationUnit = entryPoint->getTranslationUnit(); - auto sink = &entryPoint->compileRequest->mSink; - auto translationUnitSyntax = translationUnit->SyntaxNode; + // TODO: We may eventually want/need to extend this to + // account for nested names like `SomeStruct.vsMain`, or + // indeed even to handle generics. + // + auto compileRequest = entryPointReq->getCompileRequest(); + auto translationUnit = entryPointReq->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto translationUnitSyntax = translationUnit->getModuleDecl(); + auto entryPointName = entryPointReq->getName(); // Make sure we've got a query-able member dictionary buildMemberDictionary(translationUnitSyntax); @@ -9399,12 +9386,12 @@ namespace Slang // We will look up any global-scope declarations in the translation // unit that match the name of our entry point. Decl* firstDeclWithName = nullptr; - if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, firstDeclWithName) ) + if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) ) { // If there doesn't appear to be any such declaration, then // we need to diagnose it as an error, and then bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPoint->name); - return; + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName); + return nullptr; } // We found at least one global-scope declaration with the right name, @@ -9448,7 +9435,7 @@ namespace Slang // name before, so the whole thing is ambiguous. We need // to diagnose and bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPoint->name); + sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName); // List all of the declarations that the user *might* mean for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName) @@ -9460,7 +9447,7 @@ namespace Slang } // Bail out. - return; + return nullptr; } } } @@ -9471,127 +9458,197 @@ namespace Slang // If not, then we need to diagnose the error. // For convenience, we will point to the first // declaration with the right name, that wasn't a function. - sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPoint->name); - return; + sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName); + return nullptr; } + // TODO: it is possible that the entry point was declared with + // profile or target overloading. Is there anything that we need + // to do at this point to filter out declarations that aren't + // relevant to the selected profile for the entry point? + + // We found something, and can start doing some basic checking. + // // If the entry point specifies a stage via a `[shader("...")]` attribute, // then we might be able to infer a stage for the entry point request if // it didn't have one, *or* issue a diagnostic if there is a mismatch. // + auto entryPointProfile = entryPointReq->getProfile(); if( auto entryPointAttribute = entryPointFuncDecl->FindModifier<EntryPointAttribute>() ) { - if( entryPoint->getStage() == Stage::Unknown ) + auto entryPointStage = entryPointProfile.GetStage(); + if( entryPointStage == Stage::Unknown ) { - entryPoint->profile.setStage(entryPointAttribute->stage); + entryPointProfile.setStage(entryPointAttribute->stage); } - else if( entryPointAttribute->stage != entryPoint->getStage() ) + else if( entryPointAttribute->stage != entryPointStage ) { - sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPoint->name, entryPoint->getStage(), entryPointAttribute->stage); + sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage); } } + else + { + // TODO: Should we attach a `[shader(...)]` attribute to an + // entry point that didn't have one, so that we can have + // a more uniform representation in the AST? + } - // TODO: it is possible that the entry point was declared with - // profile or target overloading. Is there anything that we need - // to do at this point to filter out declarations that aren't - // relevant to the selected profile for the entry point? - // Phew, we have at least found a suitable decl. - // Let's record that in the entry-point request so - // that we don't have to re-do this effort again later. - // - // Note: we may replace the decl-ref we store at this point - // later in this function, when we (potentially) specialize - // a generic entry point to generic arguments provided - // via the API. + // Now that we've *found* the entry point, it is time to validate + // that it actually meets the constraints for the chosen stage/profile. // - entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl); + validateEntryPoint( + entryPointFuncDecl, + entryPointProfile.GetStage(), + sink); - // If the user specified generic arguments for the entry point, - // then we will want to parse those arguments as expressions - // in a scope that includes the tanslation unit that holds - // the entry point, as well as any other modules that got - // transitively loaded via `import`. - // - // TODO: This would be better handled by giving the user - // more explicit ways to parse/build types at the API level, - // rather than keeping things string-based this far along. + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(entryPointFuncDecl), + entryPointProfile); + + return entryPoint; + } + + /// Create a `Program` to represent the compiled code. + /// + /// The created program will comprise all of the translation + /// units that were compiled as part of the request, as + /// well as any entry points in those translation units. + /// + RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest) + { + // We want our resulting program to depend on + // all the translation units the user specified, + // even if some of them don't contain entry points + // (this is important for parameter layout/binding). // - // TODO: Building a list of `scopesToTry` here shouldn't - // be required, since the `Scope` type itself has the ability - // for form chains for lookup purposes (e.g., the way that - // `import` is handled by modifying a scope). + // We also want to ensure that the modules for the + // translation units comes first in the enumerated + // order for dependencies, to match the pre-existing + // compiler behavior (at least for now). // - List<RefPtr<Scope>> scopesToTry; - scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); - for (auto & module : entryPoint->compileRequest->loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); + auto linkage = compileRequest->getLinkage(); + auto sink = compileRequest->getSink(); + auto program = new Program(linkage); + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedLeafModule(translationUnit->getModule()); + } + for(auto translationUnit : compileRequest->translationUnits ) + { + program->addReferencedModule(translationUnit->getModule()); + } - // We are going to do some semantic checking, so we need to - // set up a `SemanticsVistitor` that we can use. - // - SemanticsVisitor semantics( - &entryPoint->compileRequest->mSink, - entryPoint->compileRequest, - entryPoint->getTranslationUnit()); - // We will be looping over the generic argument strings - // that the user provided via the API (or command line), - // and parsing+checking each into an `Expr`. + // The validation of entry points here will be modal, and controlled + // by whether the user specified any entry points directly via + // API or command-line options. // - // This loop will *not* handle coercing the arguments - // to be types. + // TODO: We may want to make this choice explicit rather than implicit. // - List<RefPtr<Expr>> genericArgs; - for (auto name : entryPoint->genericArgStrings) + // First, check if the user requested any entry points explicitly via + // the API or command line. + // + bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0; + + if( anyExplicitEntryPoints ) { - RefPtr<Expr> argExpr; - for (auto & s : scopesToTry) + // If there were any explicit requests for entry points to be + // checked, then we will *only* check those. + // + for(auto entryPointReq : compileRequest->getEntryPointReqs()) { - argExpr = entryPoint->compileRequest->parseTypeString( - entryPoint->getTranslationUnit(), - name, - s); - argExpr = semantics.CheckTerm(argExpr); - if( argExpr ) + auto entryPoint = findAndValidateEntryPoint( + entryPointReq); + if( entryPoint ) { - break; + program->addEntryPoint(entryPoint); + entryPointReq->getTranslationUnit()->entryPoints.Add(entryPoint); } } - // The following is a bit of a hack. - // - // Back-end code generation relies on us having computed layouts for all tagged - // unions that end up being used in the code, which means we need a way to find - // all such types that get used in a module (and the stuff it imports). + // TODO: We should consider always processing both categories, + // and just making sure to only check each entry point function + // declaration once... + } + else + { + // Otherwise, scan for any `[shader(...)]` attributes in + // the user's code, and construct `EntryPoint`s to + // represent them. // - // The Right Way to handle this would probably be to have each `ModuleDecl` track - // any tagged union types that get created in the context of that module, and - // then combine those lists later. + // This ensures that downstream code only has to consider + // the central list of entry point requests, and doesn't + // have to know where they came from. + + // TODO: A comprehensive approach here would need to search + // recursively for entry points, because they might appear + // as, e.g., member function of a `struct` type. // - // For now we are assuming a tagged union type only comes into existence - // as a (top-level) argument for a generic type parameter, so that we - // can check for them here and cache them on the entry point request. + // For now we'll start with an extremely basic approach that + // should work for typical HLSL code. // - if( auto typeType = as<TypeType>(argExpr->type) ) + UInt translationUnitCount = compileRequest->translationUnits.Count(); + for(UInt tt = 0; tt < translationUnitCount; ++tt) { - auto type = typeType->type; - if( auto taggedUnionType = as<TaggedUnionType>(type) ) + auto translationUnit = compileRequest->translationUnits[tt]; + for( auto globalDecl : translationUnit->getModuleDecl()->Members ) { - entryPoint->taggedUnionTypes.Add(taggedUnionType); + auto maybeFuncDecl = globalDecl; + if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) + { + maybeFuncDecl = genericDecl->inner; + } + + auto funcDecl = as<FuncDecl>(maybeFuncDecl); + if(!funcDecl) + continue; + + auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); + if(!entryPointAttr) + continue; + + // We've discovered a valid entry point. It is a function (possibly + // generic) that has a `[shader(...)]` attribute to mark it as an + // entry point. + // + // We will now register that entry point as an `EntryPoint` + // with an appropriately chosen profile. + // + // The profile will only include a stage, so that the profile "family" + // and "version" are left unspecified. Downstream code will need + // to be able to handle this case. + // + Profile profile; + profile.setStage(entryPointAttr->stage); + + validateEntryPoint(funcDecl, entryPointAttr->stage, sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + makeDeclRef(funcDecl), + profile); + program->addEntryPoint(entryPoint); + translationUnit->entryPoints.Add(entryPoint); } } - - genericArgs.Add(argExpr); } - // There are two cases we care about here, and we are going to treat them - // as mutually exclusive for simplicity. - // - // The first case is when the entry point function is itself generic, - // in which case we will assume that `genericArgs` lines up one-to-one - // with the explicit generic parameters of the entry point. - // + return program; + } + + /// Create a specialization an existing entry point based on generic arguments. + DeclRef<FuncDecl> specializeEntryPoint( + Linkage* linkage, + FuncDecl* entryPointFuncDecl, + List<RefPtr<Expr>> const& genericArgs, + DiagnosticSink* sink) + { + SemanticsVisitor semantics( + linkage, + sink); + + DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl); if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) ) { // We will construct a suitable `GenericAppExpr` to represent @@ -9601,7 +9658,7 @@ namespace Slang // generic application like `F<A,B,C>` if it were // encountered in the source code. - auto session = entryPoint->compileRequest->mSession; + auto session = linkage->getSession(); auto genericDeclRef = makeDeclRef(genericDecl); // The first pieces is a `VarExpr` that refers to `genericDecl`. @@ -9639,12 +9696,13 @@ namespace Slang // The basic `VarExpr` and `StaticMemberExpr` cases // should be allow-able. - entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>(); + entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); } else if( semantics.IsErrorExpr(checkedExpr) ) { // Any semantic error that occured should have been // reported already. + return DeclRef<FuncDecl>(); } else { @@ -9652,302 +9710,359 @@ namespace Slang // function should always be a `DeclRefExpr` // SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`"); + UNREACHABLE_RETURN(DeclRef<FuncDecl>()); } } - else - { - // The other case is when the entry point function is *not* itself - // generic, so we assume that any generic arguments must have been intended - // to match up with global generic parameters instead. - // - // We will only validate global generic type arguments when we are going - // to generate code, since in a no-codegen pass we will typically *not* - // have arguments to associate with the parameters. - // - if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) - { - // check that user-provioded type arguments conforms to the generic type - // parameter declaration of this translation unit - // collect global generic parameters from all imported modules - List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; - // add current translation unit first - { - auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - // add imported modules - for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) - { - auto moduleDecl = loadedModule->moduleDecl; - auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); - for (auto p : globalGenParams) - globalGenericParams.Add(p); - } - - if (globalGenericParams.Count() != genericArgs.Count()) - { - sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument, - globalGenericParams.Count(), - genericArgs.Count()); - return; - } - - // We have an appropriate number of arguments for the global generic parameters, - // and now we need to check that the arguments conform to the declared constraints. - // - // Along the way, we will build up an appropriate set of substitutions to represent - // the generic arguments and their conformances. - // - RefPtr<Substitutions> globalGenericSubsts; - auto globalGenericSubstLink = &globalGenericSubsts; - // - // TODO: There is a serious flaw to this checking logic if we ever have cases where - // the constraints on one `type_param` can depend on another `type_param`, e.g.: - // - // type_param A; - // type_param B : ISidekick<A>; - // - // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to - // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being - // set to `Batman` to know whether the setting for `B` is valid. In this limit - // the constraints can be mutually recursive (so `A : IMentor<B>`). - // - // The only way to check things correctly is to validate each conformance under - // a set of assumptions (substitutions) that includes all the type substitutions, - // and possibly also all the other constraints *except* the one to be validated. - // - // We will punt on this for now, and just check each constraint in isolation. - // - UInt argCounter = 0; - for(auto& globalGenericParam : globalGenericParams) - { - // Get the argument that matches this parameter. - UInt argIndex = argCounter++; - SLANG_ASSERT(argIndex < genericArgs.Count()); - auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex])); - if (!globalGenericArg) - { - sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]); - return; - } + return entryPointFuncDeclRef; + } - // As a quick sanity check, see if the argument that is being supplied for a parameter - // is just the parameter itself, because this should always be an error: - // - if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) - { - auto argDeclRef = argDeclRefType->declRef; - if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) - { - if(argGenericParamDeclRef.getDecl() == globalGenericParam) - { - // We are trying to specialize a generic parameter using itself. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToItself, - globalGenericParam->getName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - else - { - // We are trying to specialize a generic parameter using a *different* - // global generic type parameter. - sink->diagnose(globalGenericParam, - Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, - globalGenericParam->getName(), - argGenericParamDeclRef.GetName()); - sink->diagnose(entryPointFuncDecl, - Diagnostics::noteWhenCompilingEntryPoint, - entryPointFuncDecl->getName()); - continue; - } - } - } + /// Parse an array of strings as generic arguments. + /// + /// Names in the strings will be parsed in the context of + /// the code loaded into the given compile request. + /// + void parseGenericArgStrings( + EndToEndCompileRequest* endToEndReq, + List<String> const& genericArgStrings, + List<RefPtr<Expr>>& outGenericArgs) + { + auto unspecialiedProgram = endToEndReq->getUnspecializedProgram(); - // Create a substitution for this parameter/argument. - RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); - subst->paramDecl = globalGenericParam; - subst->actualType = globalGenericArg; + // TODO: Building a list of `scopesToTry` here shouldn't + // be required, since the `Scope` type itself has the ability + // for form chains for lookup purposes (e.g., the way that + // `import` is handled by modifying a scope). + // + List<RefPtr<Scope>> scopesToTry; + for( auto module : unspecialiedProgram->getModuleDependencies() ) + scopesToTry.Add(module->getModuleDecl()->scope); - // Walk through the declared constraints for the parameter, - // and check that the argument actually satisfies them. - for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) - { - // Get the type that the constraint is enforcing conformance to - auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + // We are going to do some semantic checking, so we need to + // set up a `SemanticsVistitor` that we can use. + // + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + SemanticsVisitor semantics( + linkage, + sink); - // Use our semantic-checking logic to search for a witness to the required conformance - SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); - auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); - if (!witness) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(globalGenericParam, - Diagnostics::typeArgumentDoesNotConformToInterface, - globalGenericParam->nameAndLoc.name, - globalGenericArg, - interfaceType); - } + // We will be looping over the generic argument strings + // that the user provided via the API (or command line), + // and parsing+checking each into an `Expr`. + // + // This loop will *not* handle coercing the arguments + // to be types. + // + for(auto name : genericArgStrings) + { + RefPtr<Expr> argExpr; + for (auto & s : scopesToTry) + { + argExpr = linkage->parseTypeString(name, s); + argExpr = semantics.CheckTerm(argExpr); + if( argExpr ) + { + break; + } + } - // Attach the concrete witness for this conformance to the - // substutiton - GlobalGenericParamSubstitution::ConstraintArg constraintArg; - constraintArg.decl = constraint; - constraintArg.val = witness; - subst->constraintArgs.Add(constraintArg); - } + outGenericArgs.Add(argExpr); + } + } - // Add the substitution for this parameter to the global substitution - // set that we are building. + /// Specialize a program to global generic arguments + RefPtr<Program> createSpecializedProgram( + Linkage* linkage, + Program* unspecializedProgram, + List<RefPtr<Expr>> const& globalGenericArgs, + DiagnosticSink* sink) + { + // The given `unspecializedProgram` should be one that + // was checked through the front-end, so that now we + // only need to check if the given arguments can satisfy + // the requirements of the global generic parameters. + // + // The new program needs to start off with the same + // module dependency list as the original. + // + RefPtr<Program> specializedProgram = new Program(linkage); + for(auto module : unspecializedProgram->getModuleDependencies()) + { + specializedProgram->addReferencedLeafModule(module); + } - *globalGenericSubstLink = subst; - globalGenericSubstLink = &subst->outer; - } - entryPoint->globalGenericSubst = globalGenericSubsts; - } + // We will collect all the global generic parameters + // defined in the modules being referenced, to find + // the global generic parameter signature of the + // program. + // + // TODO: Note that this doesn't handle the case where one + // or more of the type *arguments* that we are specifying + // ends up requiring additional modules to be referenced, + // which might in turn introduce new global generic parameters. + // + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + for(auto module : unspecializedProgram->getModuleDependencies()) + { + for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>()) + globalGenericParams.Add(param); } - // If any errors occured while we were checking the generic arguments - // of the entry point, then we should bail out rather than try to - // perform the next step of validation. + // Next, we will check whether the supplied arguments can + // satisfy those parameters. + // + // An easy early-out case will be if the number of + // arguments isn't correct. // - if (sink->errorCount != 0) - return; + if (globalGenericParams.Count() != globalGenericArgs.Count()) + { + sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments, + globalGenericParams.Count(), + globalGenericArgs.Count()); + return nullptr; + } - // Now that we've *found* the entry point, it is time to validate - // that it actually meets the constraints for the chosen stage/profile. + // We have an appropriate number of arguments for the global generic parameters, + // and now we need to check that the arguments conform to the declared constraints. // - // TODO: This validation should (probably?) be performed "under" any global generic - // parameter substitution we might have created, so that we can validate - // based on knowledge of actual types. + // Along the way, we will build up an appropriate set of substitutions to represent + // the generic arguments and their conformances. // - validateEntryPoint(entryPoint); - } - - void validateEntryPoints( - CompileRequest* compileRequest) - { - // The validation of entry points here will be modal, and controlled - // by whether the user specified any entry points directly via - // API or command-line options. + RefPtr<Substitutions> globalGenericSubsts; + auto globalGenericSubstLink = &globalGenericSubsts; // - // TODO: We may want to make this choice explicit rather than implicit. + // TODO: There is a serious flaw to this checking logic if we ever have cases where + // the constraints on one `type_param` can depend on another `type_param`, e.g.: // - // First, check if the user request any entry points explicitly via - // the API or command line. + // type_param A; + // type_param B : ISidekick<A>; + // + // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to + // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being + // set to `Batman` to know whether the setting for `B` is valid. In this limit + // the constraints can be mutually recursive (so `A : IMentor<B>`). // - bool anyExplicitEntryPointRequests = false; - for (auto& translationUnit : compileRequest->translationUnits) + // The only way to check things correctly is to validate each conformance under + // a set of assumptions (substitutions) that includes all the type substitutions, + // and possibly also all the other constraints *except* the one to be validated. + // + // We will punt on this for now, and just check each constraint in isolation. + // + UInt argCounter = 0; + for(auto& globalGenericParam : globalGenericParams) { - if( translationUnit->entryPoints.Count() != 0) + // Get the argument that matches this parameter. + UInt argIndex = argCounter++; + SLANG_ASSERT(argIndex < globalGenericArgs.Count()); + auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink); + if (!globalGenericArg) { - anyExplicitEntryPointRequests = true; - break; + sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName()); + return nullptr; } - } - if( anyExplicitEntryPointRequests ) - { - // If there were any explicit requests for entry points to be - // checked, then we will *only* check those. - - for (auto& translationUnit : compileRequest->translationUnits) + // As a quick sanity check, see if the argument that is being supplied for a parameter + // is just the parameter itself, because this should always be an error: + // + if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() ) { - for (auto entryPoint : translationUnit->entryPoints) + auto argDeclRef = argDeclRefType->declRef; + if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>()) { - findAndValidateEntryPoint(entryPoint); + if(argGenericParamDeclRef.getDecl() == globalGenericParam) + { + // We are trying to specialize a generic parameter using itself. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToItself, + globalGenericParam->getName()); + continue; + } + else + { + // We are trying to specialize a generic parameter using a *different* + // global generic type parameter. + sink->diagnose(globalGenericParam, + Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam, + globalGenericParam->getName(), + argGenericParamDeclRef.GetName()); + continue; + } } } - } - else - { - // Otherwise, scan for any `[shader(...)]` attributes in - // the user's code, and construct `EntryPointRequest`s to - // represent them. - // - // This ensures that downstream code only has to consider - // the central list of entry point requests, and doesn't - // have to know where they came from. - // TODO: A comprehensive approach here would need to search - // recursively for entry points, because they might appear - // as, e.g., member function of a `struct` type. - // - // For now we'll start with an extremely basic approach that - // should work for typical HLSL code. - // - UInt translationUnitCount = compileRequest->translationUnits.Count(); - for(UInt tt = 0; tt < translationUnitCount; ++tt) + // Create a substitution for this parameter/argument. + RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution(); + subst->paramDecl = globalGenericParam; + subst->actualType = globalGenericArg; + + // Walk through the declared constraints for the parameter, + // and check that the argument actually satisfies them. + for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>()) { - auto translationUnit = compileRequest->translationUnits[tt]; - for( auto globalDecl : translationUnit->SyntaxNode->Members ) + // Get the type that the constraint is enforcing conformance to + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + + // Use our semantic-checking logic to search for a witness to the required conformance + SemanticsVisitor visitor(linkage, sink); + auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType); + if (!witness) { - auto maybeFuncDecl = globalDecl; - if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) ) - { - maybeFuncDecl = genericDecl->inner; - } + // If no witness was found, then we will be unable to satisfy + // the conformances required. + sink->diagnose(globalGenericParam, + Diagnostics::typeArgumentDoesNotConformToInterface, + globalGenericParam->nameAndLoc.name, + globalGenericArg, + interfaceType); + } - auto funcDecl = as<FuncDecl>(maybeFuncDecl); - if(!funcDecl) - continue; + // Attach the concrete witness for this conformance to the + // substutiton + GlobalGenericParamSubstitution::ConstraintArg constraintArg; + constraintArg.decl = constraint; + constraintArg.val = witness; + subst->constraintArgs.Add(constraintArg); + } - auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>(); - if(!entryPointAttr) - continue; + // Add the substitution for this parameter to the global substitution + // set that we are building. - // We've discovered a valid entry point. It is a function (possibly - // generic) that has a `[shader(...)]` attribute to mark it as an - // entry point. - // - // We will now register that entry point as an `EntryPointRequest` - // with an appropriately chosen profile. - // - // The profile will only include a stage, so that the profile "family" - // and "version" are left unspecified. Downstream code will need - // to be able to handle this case. - // - Profile profile; - profile.setStage(entryPointAttr->stage); + *globalGenericSubstLink = subst; + globalGenericSubstLink = &subst->outer; + } + if(sink->GetErrorCount()) + return nullptr; - // We manually fill in the entry point request object. - RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest(); - entryPointReq->compileRequest = compileRequest; - entryPointReq->translationUnitIndex = int(tt); - entryPointReq->funcDeclRef = makeDeclRef(funcDecl); - entryPointReq->name = funcDecl->getName(); - entryPointReq->profile = profile; + specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - // Apply the common validation logic to this entry point. - validateEntryPoint(entryPointReq); + return specializedProgram; + } - // Add the entry point to the list in the translation unit, - // and also the global list in the compile request. - translationUnit->entryPoints.Add(entryPointReq); - compileRequest->entryPoints.Add(entryPointReq); - } - } + /// Specialize an entry point that was checked by the front-end, based on generic arguments. + /// + /// If the end-to-end compile request included generic argument strings + /// for this entry point, then they will be parsed, checked, and used + /// as arguments to the generic entry point. + /// + /// Returns a specialized entry point if everything worked as expected. + /// Returns null and diagnoses errors if anything goes wrong. + /// + RefPtr<EntryPoint> specializeEntryPoint( + EndToEndCompileRequest* endToEndReq, + EntryPoint* unspecializedEntryPoint, + EndToEndCompileRequest::EntryPointInfo const& entryPointInfo) + { + auto linkage = endToEndReq->getLinkage(); + auto sink = endToEndReq->getSink(); + auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl(); + + // If the user specified generic arguments for the entry point, + // then we will need to parse the arguments first. + // + List<RefPtr<Expr>> genericArgs; + parseGenericArgStrings( + endToEndReq, + entryPointInfo.genericArgStrings, + genericArgs); + + // Next we specialize the entry point function given the parsed + // generic argument expressions. + // + auto entryPointFuncDeclRef = specializeEntryPoint( + linkage, + entryPointFuncDecl, + genericArgs, + sink); + + RefPtr<EntryPoint> entryPoint = EntryPoint::create( + entryPointFuncDeclRef, + unspecializedEntryPoint->getProfile()); + + return entryPoint; + } + + /// Create a specialized program based on the given compile request. + /// + RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq) + { + // The compile request must have already completed front-end processing, + // so that we have an unspecialized program available, and now only need + // to parse and check any generic arguments that are being supplied for + // global or entry-point generic parameters. + // + auto unspecializedProgram = endToEndReq->getUnspecializedProgram(); + + // First, let's parse the generic argument strings that were + // provided via the API, so taht we can match them + // against what was declared in the program. + // + List<RefPtr<Expr>> globalGenericArgs; + parseGenericArgStrings( + endToEndReq, + endToEndReq->globalGenericArgStrings, + globalGenericArgs); + + // Now we create the initial specialized program by + // applying the global generic arguments (if any) to the + // unspecialized program. + // + auto specializedProgram = createSpecializedProgram( + endToEndReq->getLinkage(), + unspecializedProgram, + globalGenericArgs, + endToEndReq->getSink()); + + // If anything went wrong with the global generic + // arguments, then bail out now. + // + if(!specializedProgram) + return nullptr; + + // Next we will deal with the entry points for the + // new specialized program. + // + // If the user specified explicit entry points as part of the + // end-to-end request, then we only want to process those (and + // ignore any other `[shader(...)]`-attributed entry points). + // + // However, if the user specified *no* entry points as part + // of the end-to-end request, then we would like to go + // ahead and consider all the entry points that were found + // by the front-end. + // + UInt entryPointCount = endToEndReq->entryPoints.Count(); + if( entryPointCount == 0 ) + { + entryPointCount = unspecializedProgram->getEntryPointCount(); + endToEndReq->entryPoints.SetSize(entryPointCount); } + + for( UInt ii = 0; ii < entryPointCount; ++ii ) + { + auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii); + auto& entryPointInfo = endToEndReq->entryPoints[ii]; + + auto specializedEntryPoint = specializeEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo); + specializedProgram->addEntryPoint(specializedEntryPoint); + } + + return specializedProgram; } void checkTranslationUnit( TranslationUnitRequest* translationUnit) { SemanticsVisitor visitor( - &translationUnit->compileRequest->mSink, - translationUnit->compileRequest, - translationUnit); + translationUnit->compileRequest->getLinkage(), + translationUnit->compileRequest->getSink()); // Apply the visitor to do the main semantic // checking that is required on all declarations // in the translation unit. - visitor.checkDecl(translationUnit->SyntaxNode); + visitor.checkDecl(translationUnit->getModuleDecl()); } diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 21f56c9ee..3bc34692d 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -120,23 +120,123 @@ namespace Slang return blob; } - // EntryPointRequest + // + // FrontEndEntryPointRequest + // + + FrontEndEntryPointRequest::FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile) + : m_compileRequest(compileRequest) + , m_translationUnitIndex(translationUnitIndex) + , m_name(name) + , m_profile(profile) + {} - TranslationUnitRequest* EntryPointRequest::getTranslationUnit() + + TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit() { - return compileRequest->translationUnits[translationUnitIndex].Ptr(); + return getCompileRequest()->translationUnits[m_translationUnitIndex]; } - DeclRef<FuncDecl> EntryPointRequest::getFuncDeclRef() + // + // EntryPoint + // + + RefPtr<EntryPoint> EntryPoint::create( + DeclRef<FuncDecl> funcDeclRef, + Profile profile) { - return funcDeclRef; + RefPtr<EntryPoint> entryPoint = new EntryPoint( + funcDeclRef.GetName(), + profile, + funcDeclRef); + return entryPoint; } - RefPtr<FuncDecl> EntryPointRequest::getFuncDecl() + RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough( + Name* name, + Profile profile) { - return getFuncDeclRef().getDecl(); + RefPtr<EntryPoint> entryPoint = new EntryPoint( + name, + profile, + DeclRef<FuncDecl>()); + return entryPoint; } + EntryPoint::EntryPoint( + Name* name, + Profile profile, + DeclRef<FuncDecl> funcDeclRef) + : m_name(name) + , m_profile(profile) + , m_funcDeclRef(funcDeclRef) + { + // In order for later code generation to work, we need to track what + // modules each entry point depends on. We will build up the dependency + // list here when an `EntryPoint` gets created. + // + // We know an entry point depends on the module that declared the + // entry-point function itself. + // + // Note: we are carefully handling the case where `module` could + // be null, becase of "dummy" entry points created for pass-through + // compilation. + // + if(auto module = getModule()) + { + m_dependencyList.addDependency(module); + } + // + // TODO: We also need to include the modules needed by any generic + // arguments in the dependency list, since in the general case they + // might come from modules other than the one defining the entry point. + + // The following is a bit of a hack. + // + // Back-end code generation relies on us having computed layouts for all tagged + // unions that end up being used in the code, which means we need a way to find + // all such types that get used in a program (and the stuff it imports). + // + // For now we are assuming a tagged union type only comes into existence + // as a (top-level) argument for a generic type parameter, so that we + // can check for them here and cache them on the entry point. + // + // A longer-term strategy might need to consider any (tagged or untagged) + // union types that get used inside of a module, and also take + // those lists into account. + // + // An even longer-term strategy would be to allow type layout to + // be performed on IR types, so taht we don't need to have front-end + // code worrying about this stuff. + // + for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer ) + { + if( auto genericSubst = as<GenericSubstitution>(subst) ) + { + for( auto arg : genericSubst->args ) + { + if( auto taggedUnionType = as<TaggedUnionType>(arg) ) + { + m_taggedUnionTypes.Add(taggedUnionType); + } + } + } + } + } + + Module* EntryPoint::getModule() + { + return Slang::getModule(getFuncDecl()); + } + + Linkage* EntryPoint::getLinkage() + { + return getModule()->getLinkage(); + } // @@ -279,13 +379,35 @@ namespace Slang // + /// If there is a pass-through compile going on, find the translation unit for the given entry point. + TranslationUnitRequest* findPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + // If there isn't an end-to-end compile going on, + // there can be no pass-through. + // + if(!endToEndReq) return nullptr; + + // And if pass-through isn't set, we don't need + // access to the translation unit. + // + if(endToEndReq->passThrough == PassThroughMode::None) return nullptr; + + auto frontEndReq = endToEndReq->getFrontEndReq(); + auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); + auto translationUnit = entryPointReq->getTranslationUnit(); + return translationUnit; + } + String emitHLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - auto compileRequest = entryPoint->compileRequest; - auto translationUnit = entryPoint->getTranslationUnit(); - if (compileRequest->passThrough != PassThroughMode::None) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { // Generate a string that includes the content of // the source file(s), along with a line directive @@ -294,7 +416,7 @@ namespace Slang // mode. StringBuilder codeBuilder; - for(auto sourceFile : translationUnit->sourceFiles) + for(auto sourceFile : translationUnit->getSourceFiles()) { codeBuilder << "#line 1 \""; @@ -323,21 +445,21 @@ namespace Slang else { return emitEntryPoint( + compileRequest, entryPoint, - targetReq->layout.Ptr(), CodeGenTarget::HLSL, targetReq); } } String emitGLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - auto compileRequest = entryPoint->compileRequest; - auto translationUnit = entryPoint->getTranslationUnit(); - - if (compileRequest->passThrough != PassThroughMode::None) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { // Generate a string that includes the content of // the source file(s), along with a line directive @@ -347,7 +469,7 @@ namespace Slang StringBuilder codeBuilder; int translationUnitCounter = 0; - for(auto sourceFile : translationUnit->sourceFiles) + for(auto sourceFile : translationUnit->getSourceFiles()) { int translationUnitIndex = translationUnitCounter++; @@ -370,8 +492,8 @@ namespace Slang // TODO(tfoley): need to pass along the entry point // so that we properly emit it as the `main` function. return emitEntryPoint( + compileRequest, entryPoint, - targetReq->layout.Ptr(), CodeGenTarget::GLSL, targetReq); } @@ -484,9 +606,9 @@ namespace Slang sink->diagnoseRaw(SLANG_FAILED(res) ? Severity::Error : Severity::Warning, builder.getUnownedSlice()); } - static String _getDisplayPath(const DiagnosticSink& sink, SourceFile* sourceFile) + static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile) { - if (sink.flags & DiagnosticSink::Flag::VerbosePath) + if (sink->flags & DiagnosticSink::Flag::VerbosePath) { return sourceFile->calcVerbosePath(); } @@ -496,17 +618,17 @@ namespace Slang } } - String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest) + String calcSourcePathForEntryPoint( + EndToEndCompileRequest* endToEndReq, + UInt entryPointIndex) { - CompileRequest* compileRequest = translationUnitRequest->compileRequest; - if (compileRequest->passThrough == PassThroughMode::None) - { + auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); + if(!translationUnitRequest) return "slang-generated"; - } - auto& sink = translationUnitRequest->compileRequest->mSink; + auto sink = endToEndReq->getSink(); - const auto& sourceFiles = translationUnitRequest->sourceFiles; + const auto& sourceFiles = translationUnitRequest->getSourceFiles(); const int numSourceFiles = int(sourceFiles.Count()); @@ -542,22 +664,26 @@ namespace Slang } SlangResult emitDXBytecodeForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& byteCodeOut) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& byteCodeOut) { byteCodeOut.Clear(); - auto session = entryPoint->compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); - auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, &entryPoint->compileRequest->mSink); + auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, sink); if (!compileFunc) { return SLANG_FAIL; } - auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); + auto hlslCode = emitHLSLForEntryPoint(compileRequest, entryPoint, entryPointIndex, targetReq, endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); auto profile = getEffectiveProfile(entryPoint, targetReq); @@ -569,16 +695,16 @@ namespace Slang // List<D3D_SHADER_MACRO> dxMacrosStorage; D3D_SHADER_MACRO const* dxMacros = nullptr; - if( entryPoint->compileRequest->passThrough != PassThroughMode::None ) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) { - for( auto& define : entryPoint->compileRequest->preprocessorDefinitions ) + for( auto& define : translationUnit->compileRequest->preprocessorDefinitions ) { D3D_SHADER_MACRO dxMacro; dxMacro.Name = define.Key.Buffer(); dxMacro.Definition = define.Value.Buffer(); dxMacrosStorage.Add(dxMacro); } - for( auto& define : entryPoint->getTranslationUnit()->preprocessorDefinitions ) + for( auto& define : translationUnit->preprocessorDefinitions ) { D3D_SHADER_MACRO dxMacro; dxMacro.Name = define.Key.Buffer(); @@ -616,7 +742,7 @@ namespace Slang flags |= D3DCOMPILE_ENABLE_STRICTNESS; flags |= D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES; - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = "slang-geneated";// calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); ComPtr<ID3DBlob> codeBlob; ComPtr<ID3DBlob> diagnosticsBlob; @@ -626,7 +752,7 @@ namespace Slang sourcePath.Buffer(), dxMacros, nullptr, - getText(entryPoint->name).begin(), + getText(entryPoint->getName()).begin(), GetHLSLProfileName(profile).Buffer(), flags, 0, // unused: effect flags @@ -640,23 +766,24 @@ namespace Slang if (FAILED(hr)) { - reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), &entryPoint->compileRequest->mSink); + reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), sink); } return hr; } SlangResult dissassembleDXBC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& assemOut) + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& assemOut) { assemOut = String(); - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); - auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, &compileRequest->mSink); + auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, sink); if (!disassembleFunc) { return SLANG_E_NOT_FOUND; @@ -677,25 +804,34 @@ namespace Slang if (FAILED(res)) { // TODO(tfoley): need to figure out what to diagnose here... - reportExternalCompileError("fxc", res, UnownedStringSlice(), &compileRequest->mSink); + reportExternalCompileError("fxc", res, UnownedStringSlice(), sink); } return res; } SlangResult emitDXBytecodeAssemblyForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - String& assemOut) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemOut) { List<uint8_t> dxbc; - SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint(entryPoint, targetReq, dxbc)); + SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + dxbc)); if (!dxbc.Count()) { return SLANG_FAIL; } - return dissassembleDXBC(entryPoint->compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut); + return dissassembleDXBC(compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut); } #endif @@ -704,26 +840,30 @@ namespace Slang // Implementations in `dxc-support.cpp` int emitDXILForEntryPointUsingDXC( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& outCode); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& outCode); SlangResult dissassembleDXILUsingDXC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut); + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut); #endif #if SLANG_ENABLE_GLSLANG_SUPPORT SlangResult invokeGLSLCompiler( - CompileRequest* slangCompileRequest, + BackEndCompileRequest* slangCompileRequest, glslang_CompileRequest& request) { - Session* session = slangCompileRequest->mSession; + Session* session = slangCompileRequest->getSession(); + auto sink = slangCompileRequest->getSink(); - auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, &slangCompileRequest->mSink); + auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, sink); if (!glslang_compile) { return SLANG_FAIL; @@ -743,7 +883,7 @@ SlangResult dissassembleDXILUsingDXC( if (err) { - reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), &slangCompileRequest->mSink); + reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), sink); return SLANG_FAIL; } @@ -751,10 +891,10 @@ SlangResult dissassembleDXILUsingDXC( } SlangResult dissassembleSPIRV( - CompileRequest* slangRequest, - void const* data, - size_t size, - String& stringOut) + BackEndCompileRequest* slangRequest, + void const* data, + size_t size, + String& stringOut) { stringOut = String(); @@ -782,21 +922,29 @@ SlangResult dissassembleDXILUsingDXC( } SlangResult emitSPIRVForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& spirvOut) + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& spirvOut) { spirvOut.Clear(); - String rawGLSL = emitGLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL); + String rawGLSL = emitGLSLForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(slangRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL); auto outputFunc = [](void const* data, size_t size, void* userData) { ((List<uint8_t>*)userData)->AddRange((uint8_t*)data, size); }; - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); glslang_CompileRequest request; request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV; @@ -809,40 +957,56 @@ SlangResult dissassembleDXILUsingDXC( request.outputFunc = outputFunc; request.outputUserData = &spirvOut; - SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(entryPoint->compileRequest, request)); + SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request)); return SLANG_OK; } SlangResult emitSPIRVAssemblyForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - String& assemblyOut) + BackEndCompileRequest* slangRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemblyOut) { List<uint8_t> spirv; - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint(entryPoint, targetReq, spirv)); + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint( + slangRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + spirv)); if (spirv.Count() == 0) return SLANG_FAIL; - return dissassembleSPIRV(entryPoint->compileRequest, spirv.begin(), spirv.Count(), assemblyOut); + return dissassembleSPIRV(slangRequest, spirv.begin(), spirv.Count(), assemblyOut); } #endif // Do emit logic for a single entry point CompileResult emitEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { CompileResult result; - auto compileRequest = entryPoint->compileRequest; auto target = targetReq->target; switch (target) { case CodeGenTarget::HLSL: { - String code = emitHLSLForEntryPoint(entryPoint, targetReq); + String code = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); } @@ -850,7 +1014,12 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::GLSL: { - String code = emitGLSLForEntryPoint(entryPoint, targetReq); + String code = emitGLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); } @@ -860,7 +1029,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecode: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -871,7 +1046,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecodeAssembly: { String code; - if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); @@ -884,7 +1065,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXIL: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -895,7 +1082,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXILAssembly: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { String assembly; dissassembleDXILUsingDXC( @@ -915,7 +1108,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRV: { List<uint8_t> code; - if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); result = CompileResult(code); @@ -926,7 +1125,13 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRVAssembly: { String code; - if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint(entryPoint, targetReq, code))) + if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq, + code))) { maybeDumpIntermediate(compileRequest, code.Buffer(), target); result = CompileResult(code); @@ -957,16 +1162,16 @@ SlangResult dissassembleDXILUsingDXC( }; static void writeOutputFile( - CompileRequest* compileRequest, - FILE* file, - String const& path, - void const* data, - size_t size) + BackEndCompileRequest* compileRequest, + FILE* file, + String const& path, + void const* data, + size_t size) { size_t count = fwrite(data, size, 1, file); if (count != 1) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -974,16 +1179,16 @@ SlangResult dissassembleDXILUsingDXC( } static void writeOutputFile( - CompileRequest* compileRequest, - ISlangWriter* writer, - String const& path, - void const* data, - size_t size) + BackEndCompileRequest* compileRequest, + ISlangWriter* writer, + String const& path, + void const* data, + size_t size) { if (SLANG_FAILED(writer->write((const char*)data, size))) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -991,18 +1196,18 @@ SlangResult dissassembleDXILUsingDXC( } static void writeOutputFile( - CompileRequest* compileRequest, - String const& path, - void const* data, - size_t size, - OutputFileKind kind) + BackEndCompileRequest* compileRequest, + String const& path, + void const* data, + size_t size, + OutputFileKind kind) { FILE* file = fopen( path.Buffer(), kind == OutputFileKind::Binary ? "wb" : "w"); if (!file) { - compileRequest->mSink.diagnose( + compileRequest->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -1014,11 +1219,12 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResultToFile( - EntryPointRequest* entryPoint, + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, String const& outputPath, CompileResult const& result) { - auto compileRequest = entryPoint->compileRequest; + SLANG_UNUSED(entryPoint); switch (result.format) { @@ -1059,13 +1265,15 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResultToStandardOutput( - EntryPointRequest* entryPoint, + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, TargetRequest* targetReq, CompileResult const& result) { - auto compileRequest = entryPoint->compileRequest; + SLANG_UNUSED(entryPoint); ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); + auto backEndReq = compileRequest->getBackEndReq(); switch (result.format) { @@ -1087,7 +1295,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecode: { String assembly; - dissassembleDXBC(compileRequest, + dissassembleDXBC(backEndReq, data.begin(), data.end() - data.begin(), assembly); writeOutputToConsole(writer, assembly); @@ -1099,7 +1307,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXIL: { String assembly; - dissassembleDXILUsingDXC(compileRequest, + dissassembleDXILUsingDXC(backEndReq, data.begin(), data.end() - data.begin(), assembly); @@ -1111,7 +1319,7 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRV: { String assembly; - dissassembleSPIRV(compileRequest, + dissassembleSPIRV(backEndReq, data.begin(), data.end() - data.begin(), assembly); writeOutputToConsole(writer, assembly); @@ -1129,7 +1337,7 @@ SlangResult dissassembleDXILUsingDXC( writer->setMode(SLANG_WRITER_MODE_BINARY); writeOutputFile( - compileRequest, + backEndReq, writer, "stdout", data.begin(), @@ -1146,89 +1354,108 @@ SlangResult dissassembleDXILUsingDXC( } static void writeEntryPointResult( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - UInt entryPointIndex) + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + TargetRequest* targetReq, + Int entryPointIndex) { - // It is possible that we are dynamically discovering entry - // points (using `[shader(...)]` attributes), so that the - // number of entry points on the compile request does not - // match the number of entries in the `entryPointOutputPaths` - // array. - // - String outputPath; - if( entryPointIndex < targetReq->entryPointOutputPaths.Count() ) - { - outputPath = targetReq->entryPointOutputPaths[entryPointIndex]; - } + auto program = compileRequest->getSpecializedProgram(); + auto targetProgram = program->getTargetProgram(targetReq); + auto backEndReq = compileRequest->getBackEndReq(); - auto& result = targetReq->entryPointResults[entryPointIndex]; + auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); // Skip the case with no output if (result.format == ResultFormat::None) return; - if (outputPath.Length()) - { - writeEntryPointResultToFile(entryPoint, outputPath, result); - } - else + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) { - writeEntryPointResultToStandardOutput(entryPoint, targetReq, result); + String outputPath; + if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) + { + writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); + return; + } } + + writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); } void generateOutputForTarget( - TargetRequest* targetReq) + BackEndCompileRequest* compileReq, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { - CompileRequest* compileReq = targetReq->compileRequest; + auto program = compileReq->getProgram(); + auto targetProgram = program->getTargetProgram(targetReq); // Generate target code any entry points that // have been requested for compilation. - for (auto& entryPoint : compileReq->entryPoints) + auto entryPointCount = program->getEntryPointCount(); + for(UInt ii = 0; ii < entryPointCount; ++ii) { - CompileResult entryPointResult = emitEntryPoint(entryPoint, targetReq); - targetReq->entryPointResults.Add(entryPointResult); + auto entryPoint = program->getEntryPoint(ii); + CompileResult entryPointResult = emitEntryPoint( + compileReq, + entryPoint, + ii, + targetReq, + endToEndReq); + targetProgram->setEntryPointResult(ii, entryPointResult); } } - void generateOutput( - CompileRequest* compileRequest) + static void _generateOutput( + BackEndCompileRequest* compileRequest, + EndToEndCompileRequest* endToEndReq) { // Go through the code-generation targets that the user // has specified, and generate code for each of them. // - for (auto targetReq : compileRequest->targets) + auto linkage = compileRequest->getLinkage(); + for (auto targetReq : linkage->targets) { - generateOutputForTarget(targetReq); + generateOutputForTarget(compileRequest, targetReq, endToEndReq); } + } + + void generateOutput( + BackEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest, nullptr); + } + + void generateOutput( + EndToEndCompileRequest* compileRequest) + { + _generateOutput(compileRequest->getBackEndReq(), compileRequest); // If we are in command-line mode, we might be expected to actually // write output to one or more files here. if (compileRequest->isCommandLineCompile) { - for (auto targetReq : compileRequest->targets) + auto linkage = compileRequest->getLinkage(); + auto program = compileRequest->getSpecializedProgram(); + for (auto targetReq : linkage->targets) { - UInt entryPointCount = compileRequest->entryPoints.Count(); + UInt entryPointCount = program->getEntryPointCount(); for (UInt ee = 0; ee < entryPointCount; ++ee) { writeEntryPointResult( - compileRequest->entryPoints[ee], + compileRequest, + program->getEntryPoint(ee), targetReq, ee); } } - - if (compileRequest->containerOutputPath.Length() != 0) - { - auto& data = compileRequest->generatedBytecode; - writeOutputFile(compileRequest, - compileRequest->containerOutputPath, - data.begin(), - data.end() - data.begin(), - OutputFileKind::Binary); - } } } @@ -1237,7 +1464,7 @@ SlangResult dissassembleDXILUsingDXC( // void dumpIntermediate( - CompileRequest*, + BackEndCompileRequest*, void const* data, size_t size, char const* ext, @@ -1271,7 +1498,7 @@ SlangResult dissassembleDXILUsingDXC( } void dumpIntermediateText( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, char const* ext) @@ -1280,7 +1507,7 @@ SlangResult dissassembleDXILUsingDXC( } void dumpIntermediateBinary( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, char const* ext) @@ -1289,7 +1516,7 @@ SlangResult dissassembleDXILUsingDXC( } void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, CodeGenTarget target) @@ -1362,7 +1589,7 @@ SlangResult dissassembleDXILUsingDXC( } void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, char const* text, CodeGenTarget target) { diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 39199a62f..c975c1c2b 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -20,6 +20,8 @@ namespace Slang class CompileRequest; class ProgramLayout; class PtrType; + class TargetProgram; + class TargetRequest; class TypeLayout; enum class CompilerMode @@ -88,8 +90,12 @@ namespace Slang kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR, }; - - class CompileRequest; + class Linkage; + class Module; + class Program; + class FrontEndCompileRequest; + class BackEndCompileRequest; + class EndToEndCompileRequest; class TranslationUnitRequest; // Result of compiling an entry point. @@ -112,18 +118,165 @@ namespace Slang ComPtr<ISlangBlob> blob; }; - // Describes an entry point that we've been requested to compile - class EntryPointRequest : public RefObject + /// A request for the front-end to find and validate an entry-point function + struct FrontEndEntryPointRequest : RefObject { public: + /// Create a request for an entry point. + FrontEndEntryPointRequest( + FrontEndCompileRequest* compileRequest, + int translationUnitIndex, + Name* name, + Profile profile); + + /// Get the parent front-end compile request. + FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; } + + /// Get the translation unit that contains the entry point. + TranslationUnitRequest* getTranslationUnit(); + + /// Get the name of the entry point to find. + Name* getName() { return m_name; } + + /// Get the stage that the entry point is to be compiled for + Stage getStage() { return m_profile.GetStage(); } + + /// Get the profile that the entry point is to be compiled for + Profile getProfile() { return m_profile; } + + private: // The parent compile request - CompileRequest* compileRequest = nullptr; + FrontEndCompileRequest* m_compileRequest; + + // The index of the translation unit that will hold the entry point + int m_translationUnitIndex; + + // The name of the entry point function to look for + Name* m_name; + + // The profile to compile for (including stage) + Profile m_profile; + }; + + /// Tracks an ordered list of modules that something depends on. + struct ModuleDependencyList + { + public: + /// Get the list of modules that are depended on. + List<RefPtr<Module>> const& getModuleList() { return m_moduleList; } + + /// Add a module and everything it depends on to the list. + void addDependency(Module* module); + + /// Add a module to the list, but not the modules it depends on. + void addLeafDependency(Module* module); + + private: + void _addDependency(Module* module); + + List<RefPtr<Module>> m_moduleList; + HashSet<Module*> m_moduleSet; + }; + + /// Tracks an unordered list of filesystem paths that something depends on + struct FilePathDependencyList + { + public: + /// Get the list of paths that are depended on. + List<String> const& getFilePathList() { return m_filePathList; } + + /// Add a path to the list, if it is not already present + void addDependency(String const& path); + + /// Add all of the paths that `module` depends on to the list + void addDependency(Module* module); + + private: + + // TODO: We are using a `HashSet` here to deduplicate + // the paths so that we don't return the same path + // multiple times from `getFilePathList`, but because + // order isn't important, we could potentially do better + // in terms of memory (at some cost in performance) by + // just sorting the `m_filePathList` every once in + // a while and then deduplicating. + + List<String> m_filePathList; + HashSet<String> m_filePathSet; + }; + + /// Describes an entry point for the purposes of layout and code generation. + /// + /// This class also tracks any generic arguments to the entry point, + /// in the case that it is a specialization of a generic entry point. + /// + /// There is also a provision for creating a "dummy" entry point for + /// the purposes of pass-through compilation modes. Only the + /// `getName()` and `getProfile()` methods should be expected to + /// return useful data on pass-through entry points. + /// + class EntryPoint : public RefObject + { + public: + /// Create an entry point that refers to the given function. + static RefPtr<EntryPoint> create( + DeclRef<FuncDecl> funcDeclRef, + Profile profile); + + /// Get the function decl-ref, including any generic arguments. + DeclRef<FuncDecl> getFuncDeclRef() { return m_funcDeclRef; } + + /// Get the function declaration (without generic arguments). + RefPtr<FuncDecl> getFuncDecl() { return m_funcDeclRef.getDecl(); } + + /// Get the name of the entry point + Name* getName() { return m_name; } + + /// Get the profile associated with the entry point + /// + /// Note: only the stage part of the profile is expected + /// to contain useful data, but certain legacy code paths + /// allow for "shader model" information to come via this path. + /// + Profile getProfile() { return m_profile; } + + /// Get the stage that the entry point is for. + Stage getStage() { return m_profile.GetStage(); } + + /// Get the module that contains the entry point. + Module* getModule(); + + /// Get the linkage that contains the module for this entry pooint. + Linkage* getLinkage(); + + /// Get a list of modules that this entry point depends on. + /// + /// This will include the module that defines the entry point (see `getModule()`), + /// but may also include modules that are required by its generic type arguments. + /// + List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); } + + /// Get a list of tagged-union types referenced by the entry point's generic parameters. + List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; } + + /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. + static RefPtr<EntryPoint> createDummyForPassThrough( + Name* name, + Profile profile); + + private: + EntryPoint( + Name* name, + Profile profile, + DeclRef<FuncDecl> funcDeclRef); // The name of the entry point function (e.g., `main`) - Name* name; + // + Name* m_name = nullptr; - /// Source code for the generic arguments to use for the generic parameters of the entry point. - List<String> genericArgStrings; + // The declaration of the entry-point function itself. + // + DeclRef<FuncDecl> m_funcDeclRef; // The profile that the entry point will be compiled for // (this is a combination of the target stage, and also @@ -135,33 +288,13 @@ namespace Slang // from the target, while the stage part is all that is // intrinsic to the entry point. // - Profile profile; + Profile m_profile; - // Get the stage that the entry point is being compiled for. - Stage getStage() { return profile.GetStage(); } + // Any tagged union types that were referenced by the generic arguments of the entry point. + List<RefPtr<TaggedUnionType>> m_taggedUnionTypes; - // The index of the translation unit (within the parent - // compile request) that the entry point function is - // supposed to be defined in. - int translationUnitIndex; - - // The translation unit that this entry point came from - TranslationUnitRequest* getTranslationUnit(); - - // The declaration of the entry-point function itself. - // This will be filled in as part of semantic analysis; - // it should not be assumed to be available in cases - // where any errors were diagnosed. - // - DeclRef<FuncDecl> funcDeclRef; - - DeclRef<FuncDecl> getFuncDeclRef(); - RefPtr<FuncDecl> getFuncDecl(); - - RefPtr<Substitutions> globalGenericSubst; - - /// Any tagged union types that were referenced by the generic arguments of the entry point. - List<RefPtr<TaggedUnionType>> taggedUnionTypes; + // Modules the entry point depends on. + ModuleDependencyList m_dependencyList; }; enum class PassThroughMode : SlangPassThrough @@ -174,13 +307,78 @@ namespace Slang class SourceFile; - // A single translation unit requested to be compiled. - // + /// A module of code that has been compiled through the front-end + /// + /// A module comprises all the code from one translation unit (which + /// may span multiple Slang source files), and provides access + /// to both the AST and IR representations of that code. + /// + class Module : public RefObject + { + public: + /// Create a module (initially empty). + Module(Linkage* linkage); + + /// Get the parent linkage of this module. + Linkage* getLinkage() { return m_linkage; } + + /// Get the AST for the module (if it has been parsed) + ModuleDecl* getModuleDecl() { return m_moduleDecl; } + + /// The the IR for the module (if it has been generated) + IRModule* getIRModule() { return m_irModule; } + + /// Get the list of other modules this module depends on + List<RefPtr<Module>> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } + + /// Get the list of filesystem paths this module depends on + List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } + + /// Register a module that this module depends on + void addModuleDependency(Module* module); + + /// Register a filesystem path that this module depends on + void addFilePathDependency(String const& path); + + /// Set the AST for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; } + + /// Set the IR for this module. + /// + /// This should only be called once, during creation of the module. + /// + void setIRModule(IRModule* irModule) { m_irModule = irModule; } + + private: + // The parent linkage + Linkage* m_linkage = nullptr; + + // The AST for the module + RefPtr<ModuleDecl> m_moduleDecl; + + // The IR for the module + RefPtr<IRModule> m_irModule = nullptr; + + // List of modules this module depends on + ModuleDependencyList m_moduleDependencyList; + + // List of filesystem paths this module depends on + FilePathDependencyList m_filePathDependencyList; + }; + typedef Module LoadedModule; + + /// A request for the front-end to compile a translation unit. class TranslationUnitRequest : public RefObject { public: + TranslationUnitRequest( + FrontEndCompileRequest* compileRequest); + // The parent compile request - CompileRequest* compileRequest = nullptr; + FrontEndCompileRequest* compileRequest = nullptr; // The language in which the source file(s) // are assumed to be written @@ -189,26 +387,30 @@ namespace Slang // The source file(s) that will be compiled to form this translation unit // // Usually, for HLSL or GLSL there will be only one file. - List<SourceFile*> sourceFiles; + List<SourceFile*> m_sourceFiles; + + List<SourceFile*> const& getSourceFiles() { return m_sourceFiles; } + void addSourceFile(SourceFile* sourceFile); // The entry points associated with this translation unit - List<RefPtr<EntryPointRequest> > entryPoints; + List<RefPtr<EntryPoint>> entryPoints; // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `CompileOptions` will be shared) + // (whereas the ones on `compileRequest` will be shared) Dictionary<String, String> preprocessorDefinitions; - // Compile flags for this translation unit - SlangCompileFlags compileFlags = 0; + /// The name that will be used for the module this translation unit produces. + Name* moduleName = nullptr; + + /// Result of compiling this translation unit (a module) + RefPtr<Module> module; - // The parsed syntax for the translation unit - RefPtr<ModuleDecl> SyntaxNode; + Module* getModule() { return module; } + RefPtr<ModuleDecl> getModuleDecl() { return module->getModuleDecl(); } - // The IR-level code for this translation unit. - // This will only be valid/non-null after semantic - // checking and IR generation are complete, so it - // is not safe to use this field without testing for NULL. - RefPtr<IRModule> irModule; + Session* getSession(); + NamePool* getNamePool(); + SourceManager* getSourceManager(); }; enum class FloatingPointMode : SlangFloatingPointMode @@ -232,33 +434,28 @@ namespace Slang Binary = SLANG_WRITER_MODE_BINARY, }; - // A request to generate output in some target format + /// A request to generate output in some target format. class TargetRequest : public RefObject { public: - CompileRequest* compileRequest; + Linkage* linkage; CodeGenTarget target; SlangTargetFlags targetFlags = 0; Slang::Profile targetProfile = Slang::Profile(); FloatingPointMode floatingPointMode = FloatingPointMode::Default; - // Requested output paths for each entry point. - // An empty string indices no output desired for - // the given entry point. - List<String> entryPointOutputPaths; + Linkage* getLinkage() { return linkage; } + CodeGenTarget getTarget() { return target; } + Profile getTargetProfile() { return targetProfile; } + FloatingPointMode getFloatingPointMode() { return floatingPointMode; } - // The resulting reflection layout information - RefPtr<ProgramLayout> layout; - - // Generated compile results for each entry point - // in the parent compile request (indexing matches - // the order they are given in the compile request) - List<CompileResult> entryPointResults; + Session* getSession(); + MatrixLayoutMode getDefaultMatrixLayoutMode(); // TypeLayouts created on the fly by reflection API Dictionary<Type*, RefPtr<TypeLayout>> typeLayouts; - MatrixLayoutMode getDefaultMatrixLayoutMode(); + Dictionary<Type*, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; } }; /// Are we generating code for a D3D API? @@ -280,7 +477,8 @@ namespace Slang // - If the entry point and target disagree on the profile family, always use the // profile family and version from the target. // - Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target); + Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target); + // A directory to be searched when looking for files (e.g., `#include`) struct SearchDirectory @@ -294,110 +492,53 @@ namespace Slang String path; }; - // Represents a module that has been loaded through the front-end - // (up through IR generation). - // - class LoadedModule : public RefObject + /// A list of directories to search for files (e.g., `#include`) + struct SearchDirectoryList { - public: - // The AST for the module - RefPtr<ModuleDecl> moduleDecl; + // A parent list that should also be searched + SearchDirectoryList* parent = nullptr; - // The IR for the module - RefPtr<IRModule> irModule = nullptr; + // Directories to be searched + List<SearchDirectory> searchDirectories; }; - class Session; - - /// Create a blob that will retain (a copy of) raw data. /// ComPtr<ISlangBlob> createRawBlob(void const* data, size_t size); - class CompileRequest : public RefObject + /// A context for loading and re-using code modules. + class Linkage : public RefObject { public: - // Pointer to parent session - Session* mSession; + /// Create an initially-empty linkage + Linkage(Session* session); + + /// Get the parent session for this linkage + Session* getSession() { return m_session; } // Information on the targets we are being asked to // generate code for. List<RefPtr<TargetRequest>> targets; - // What container format are we being asked to generate? - ContainerFormat containerFormat = ContainerFormat::None; - - // Path to output container to - String containerOutputPath; - // Directories to search for `#include` files or `import`ed modules - List<SearchDirectory> searchDirectories; + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } // Definitions to provide during preprocessing Dictionary<String, String> preprocessorDefinitions; - // Translation units we are being asked to compile - List<RefPtr<TranslationUnitRequest> > translationUnits; - - // Entry points we've been asked to compile (each - // associated with a translation unit). - List<RefPtr<EntryPointRequest> > entryPoints; - - // Types constructed by reflection API - Dictionary<String, RefPtr<Type>> types; - - /// The layout to use for matrices by default (row/column major) - MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } - - // Should we just pass the input to another compiler? - PassThroughMode passThrough = PassThroughMode::None; - - // Compile flags to be shared by all translation units - SlangCompileFlags compileFlags = 0; - - // Should we dump intermediate results along the way, for debugging? - bool shouldDumpIntermediates = false; - - bool shouldDumpIR = false; - bool shouldValidateIR = false; - bool shouldSkipCodegen = false; - - // If true then generateIR will serialize out IR, and serialize back in again. Making - // serialization a bottleneck or firewall between the front end and the backend - bool useSerialIRBottleneck = false; - - // If true will serialize and de-serialize with debug information - bool verifyDebugSerialization = false; - - // How should `#line` directives be emitted (if at all)? - LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; - // Are we being driven by the command-line `slangc`, and should act accordingly? - bool isCommandLineCompile = false; // Source manager to help track files loaded - SourceManager sourceManagerStorage; - SourceManager* sourceManager; + SourceManager m_defaultSourceManager; + SourceManager* m_sourceManager = nullptr; // Name pool for looking up names NamePool namePool; NamePool* getNamePool() { return &namePool; } - // Output stuff - DiagnosticSink mSink; - String mDiagnosticOutput; - - /// A blob holding the diagnostic output - ComPtr<ISlangBlob> diagnosticOutputBlob; - - // Files that compilation depended on - List<String> mDependencyFilePaths; - - // Generated bytecode representation of all the code - List<uint8_t> generatedBytecode; - // Modules that have been dynamically loaded via `import` // // This is a list of unique modules loaded, in the order they were encountered. @@ -424,11 +565,7 @@ namespace Slang /// or a wrapped impl that makes fileSystem operate as fileSystemExt ComPtr<ISlangFileSystemExt> fileSystemExt; - // For output - ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; - - void setWriter(WriterChannel chan, ISlangWriter* writer); - ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } + ISlangFileSystemExt* getFileSystemExt() { return fileSystemExt; } /// Load a file into memory using the configured file system. /// @@ -438,11 +575,177 @@ namespace Slang /// SlangResult loadFile(String const& path, ISlangBlob** outBlob); - CompileRequest(Session* session); - RefPtr<Expr> parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope); + RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope); + + /// Add a mew target amd return its index. + UInt addTarget( + CodeGenTarget target); + + RefPtr<Module> loadModule( + Name* name, + const PathInfo& filePathInfo, + ISlangBlob* fileContentsBlob, + SourceLoc const& loc, + DiagnosticSink* sink); + + void loadParsedModule( + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + PathInfo const& pathInfo); + + /// Load a module of the given name. + Module* loadModule(String const& name); + + RefPtr<Module> findOrImportModule( + Name* name, + SourceLoc const& loc, + DiagnosticSink* sink); - Type* getTypeFromString(String typeStr); + SourceManager* getSourceManager() + { + return m_sourceManager; + } + + /// Override the source manager for the linakge. + /// + /// This is only used to install a temporary override when + /// parsing stuff from strings (where we don't want to retain + /// full source files for the parsed result). + /// + /// TODO: We should remove the need for this hack. + /// + void setSourceManager(SourceManager* sourceManager) + { + m_sourceManager = sourceManager; + } + + void setFileSystem(ISlangFileSystem* fileSystem); + + /// The layout to use for matrices by default (row/column major) + MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor; + MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; } + + private: + Session* m_session = nullptr; + + /// Tracks state of modules currently being loaded. + /// + /// This information is used to diagnose cases where + /// a user tries to recursively import the same module + /// (possibly along a transitive chain of `import`s). + /// + struct ModuleBeingImportedRAII + { + public: + ModuleBeingImportedRAII( + Linkage* linkage, + Module* module) + : linkage(linkage) + , module(module) + { + next = linkage->m_modulesBeingImported; + linkage->m_modulesBeingImported = this; + } + + ~ModuleBeingImportedRAII() + { + linkage->m_modulesBeingImported = next; + } + + Linkage* linkage; + Module* module; + ModuleBeingImportedRAII* next; + }; + + // Any modules currently being imported will be listed here + ModuleBeingImportedRAII* m_modulesBeingImported; + + /// Is the given module in the middle of being imported? + bool isBeingImported(Module* module); + }; + + /// Shared functionality between front- and back-end compile requests. + /// + /// This is the base class for both `FrontEndCompileRequest` and + /// `BackEndCompileRequest`, and allows a small number of parts of + /// the compiler to be easily invocable from either front-end or + /// back-end work. + /// + class CompileRequestBase : public RefObject + { + // TODO: We really shouldn't need this type in the long run. + // The few places that rely on it should be refactored to just + // depend on the unerlying information (a linkage and a diagnostic + // sink) directly. + // + // The flags to control dumping and validation of IR should be + // moved to some kind of shared settings/options `struct` that + // both front-end and back-end requests can store. + + public: + Session* getSession(); + Linkage* getLinkage() { return m_linkage; } + DiagnosticSink* getSink() { return m_sink; } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); } + SlangResult loadFile(String const& path, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outBlob); } + + bool shouldDumpIR = false; + bool shouldValidateIR = false; + + protected: + CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink); + + private: + Linkage* m_linkage = nullptr; + DiagnosticSink* m_sink = nullptr; + }; + + /// A request to compile source code to an AST + IR. + class FrontEndCompileRequest : public CompileRequestBase + { + public: + FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink); + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile); + + // Translation units we are being asked to compile + List<RefPtr<TranslationUnitRequest> > translationUnits; + + RefPtr<TranslationUnitRequest> getTranslationUnit(UInt index) { return translationUnits[index]; } + + // Compile flags to be shared by all translation units + SlangCompileFlags compileFlags = 0; + + // If true then generateIR will serialize out IR, and serialize back in again. Making + // serialization a bottleneck or firewall between the front end and the backend + bool useSerialIRBottleneck = false; + + // If true will serialize and de-serialize with debug information + bool verifyDebugSerialization = false; + + List<RefPtr<FrontEndEntryPointRequest>> m_entryPointReqs; + + List<RefPtr<FrontEndEntryPointRequest>> const& getEntryPointReqs() { return m_entryPointReqs; } + UInt getEntryPointReqCount() { return m_entryPointReqs.Count(); } + FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; } + + // Directories to search for `#include` files or `import`ed modules + SearchDirectoryList searchDirectories; + + SearchDirectoryList const& getSearchDirectories() { return searchDirectories; } + + // Definitions to provide during preprocessing + Dictionary<String, String> preprocessorDefinitions; void parseTranslationUnit( TranslationUnitRequest* translationUnit); @@ -454,9 +757,24 @@ namespace Slang void generateIR(); SlangResult executeActionsInner(); - SlangResult executeActions(); - int addTranslationUnit(SourceLanguage language, String const& name); + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @param moduleName The name that will be used for the module compile from the translation unit. + /// @return The zero-based index of the translation unit in this compile request. + int addTranslationUnit(SourceLanguage language, Name* moduleName); + + /// Add a translation unit to be compiled. + /// + /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang` + /// @return The zero-based index of the translation unit in this compile request. + /// + /// The module name for the translation unit will be automatically generated. + /// If all translation units in a compile request use automatically generated + /// module names, then they are guaranteed not to conflict with one another. + /// + int addTranslationUnit(SourceLanguage language); void addTranslationUnitSourceFile( int translationUnitIndex, @@ -476,63 +794,337 @@ namespace Slang int translationUnitIndex, String const& path); - int addEntryPoint( - int translationUnitIndex, - String const& name, - Profile profile, - List<String> const & genericTypeNames); + Program* getProgram() { return m_program; } - UInt addTarget( - CodeGenTarget target); + private: + RefPtr<Program> m_program; + }; - RefPtr<ModuleDecl> loadModule( - Name* name, - const PathInfo& filePathInfo, - ISlangBlob* fileContentsBlob, - SourceLoc const& loc); + /// A collection of code modules and entry points that are intended to be used together. + /// + /// A `Program` establishes that certain pieces of code are intended + /// to be used togehter so that, e.g., layout can make sure to allocate + /// space for the global shader parameters in all referenced modules. + /// + class Program : public RefObject + { + public: + /// Create a new program, initially empty. + /// + /// All code loaded into the program must come + /// from the given `linkage`. + Program( + Linkage* linkage); - void loadParsedModule( - RefPtr<TranslationUnitRequest> const& translationUnit, - Name* name, - PathInfo const& pathInfo); + /// Get the linkage that this program uses. + Linkage* getLinkage() { return m_linkage; } - RefPtr<ModuleDecl> findOrImportModule( - Name* name, - SourceLoc const& loc); + /// Get the number of entry points added to the program + UInt getEntryPointCount() { return m_entryPoints.Count(); } - Decl* lookupGlobalDecl(Name* name); + /// Get the entry point at the given `index`. + RefPtr<EntryPoint> getEntryPoint(UInt index) { return m_entryPoints[index]; } - SourceManager* getSourceManager() + /// Get the full ist of entry points on the program. + List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; } + + /// Get the substitution (if any) that represents how global generics are specialized. + RefPtr<Substitutions> getGlobalGenericSubstitution() { return m_globalGenericSubst; } + + /// Get the full list of modules this program depends on + List<RefPtr<Module>> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); } + + /// Get the full list of filesystem paths this program depends on + List<String> getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); } + + /// Get the target-specific version of this program for the given `target`. + /// + /// The `target` must be a target on the `Linkage` that was used to create this program. + TargetProgram* getTargetProgram(TargetRequest* target); + + /// Add a module (and everything it depends on) to the list of references + void addReferencedModule(Module* module); + + /// Add a module (but not the things it depends on) to the list of references + /// + /// This is a compatiblity hack for legacy compiler behavior. + void addReferencedLeafModule(Module* module); + + + /// Add an entry point to the program + /// + /// This also adds everything the entry point depends on to the list of references. + /// + void addEntryPoint(EntryPoint* entryPoint); + + /// Set the global generic argument substitution to use. + void setGlobalGenericSubsitution(RefPtr<Substitutions> subst) { - return sourceManager; + m_globalGenericSubst = subst; } - void setSourceManager(SourceManager* sm) + /// Parse a type from a string, in the context of this program. + /// + /// Any names in the string will be resolved using the modules + /// referenced by the program. + /// + /// On an error, returns null and reports diagnostic messages + /// to the provided `sink`. + /// + Type* getTypeFromString(String typeStr, DiagnosticSink* sink); + + /// Get the IR module that represents this program and its entry points. + /// + /// The IR module for a program tries to be minimal, and in the + /// common case will only include symbols with `[import]` declarations + /// for the entry point(s) of the program, and any types they + /// depend on. + /// + /// This IR module is intended to be linked against the IR modules + /// for all of the dependencies (see `getModuleDependencies()`) to + /// provide complete code. + /// + RefPtr<IRModule> getOrCreateIRModule(DiagnosticSink* sink); + + private: + // The linakge this program is associated with. + // + // Note that a `Program` keeps its associated linkage alive, + // and not vice versa. + // + RefPtr<Linkage> m_linkage; + + // Tracking data for the list of modules dependend on + ModuleDependencyList m_moduleDependencyList; + + // Tracking data for the list of filesystem paths dependend on + FilePathDependencyList m_filePathDependencyList; + + // Entry points that are part of the program. + List<RefPtr<EntryPoint> > m_entryPoints; + + // Specializations for global generic parameters (if any) + RefPtr<Substitutions> m_globalGenericSubst; + + // Generated IR for this program. + RefPtr<IRModule> m_irModule; + + // Cache of target-specific programs for each target. + Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms; + + // Any types looked up dynamically using `getTypeFromString` + Dictionary<String, RefPtr<Type>> m_types; + }; + + /// A `Program` specialized for a particular `TargetRequest` + class TargetProgram : public RefObject + { + public: + TargetProgram( + Program* program, + TargetRequest* targetReq); + + /// Get the underlying program + Program* getProgram() { return m_program; } + + /// Get the underlying target + TargetRequest* getTargetReq() { return m_targetReq; } + + /// Get the layout for the program on the target. + /// + /// If this is the first time the layout has been + /// requested, report any errors that arise during + /// layout to the given `sink`. + /// + ProgramLayout* getOrCreateLayout(DiagnosticSink* sink); + + /// Get the layout for the program on the taarget. + /// + /// This routine assumes that `getOrCreateLayout` + /// has already been called previously. + /// + ProgramLayout* getExistingLayout() { - sourceManager = sm; - mSink.sourceManager = sm; + SLANG_ASSERT(m_layout); + return m_layout; } - void setFileSystem(ISlangFileSystem* fileSystem); + /// Get the compiled code for an entry point on the target. + /// + /// This routine assumes code generation has already been + /// performed and called `setEntryPointResult`. + /// + CompileResult& getExistingEntryPointResult(Int entryPointIndex) + { + return m_entryPointResults[entryPointIndex]; + } + + // TODO: Need a lazy `getOrCreateEntryPointResult` + + /// Set the compiled code for an entry point. + /// + /// Should only be called by code generation. + void setEntryPointResult(Int entryPointIndex, CompileResult const& result) + { + m_entryPointResults[entryPointIndex] = result; + } + + private: + // The program being compiled or laid out + Program* m_program; + + // The target that code/layout will be generated for + TargetRequest* m_targetReq; + + // The computed layout, if it has been generated yet + RefPtr<ProgramLayout> m_layout; + + // Generated compile results for each entry point + // in the parent `Program` (indexing matches + // the order they are given in the `Program`) + List<CompileResult> m_entryPointResults; + }; + + /// A request to generate code for a program + class BackEndCompileRequest : public CompileRequestBase + { + public: + BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program = nullptr); + + // Should we dump intermediate results along the way, for debugging? + bool shouldDumpIntermediates = false; + + // How should `#line` directives be emitted (if at all)? + LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default; + + LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; } - /// During propagation of an exception for an internal - /// error, note that this source location was involved - void noteInternalErrorLoc(SourceLoc const& loc); + Program* getProgram() { return m_program; } + void setProgram(Program* program) { m_program = program; } - int internalErrorLocsNoted = 0; + private: + RefPtr<Program> m_program; + }; + + /// A compile request that spans the front and back ends of the compiler + /// + /// This is what the command-line `slangc` uses, as well as the legacy + /// C API. It ties together the functionality of `Linkage`, + /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small + /// number of additional features that primarily make sense for + /// command-line usage. + /// + class EndToEndCompileRequest : public RefObject + { + public: + EndToEndCompileRequest( + Session* session); + + // What container format are we being asked to generate? + // + // Note: This field is unused except by the options-parsing + // logic; it exists to support wriiting out binary modules + // once that feature is ready. + // + ContainerFormat containerFormat = ContainerFormat::None; + + // Path to output container to + // + // Note: This field exists to support wriiting out binary modules + // once that feature is ready. + // + String containerOutputPath; + + // Should we just pass the input to another compiler? + PassThroughMode passThrough = PassThroughMode::None; + + /// Source code for the generic arguments to use for the global generic parameters of the program. + List<String> globalGenericArgStrings; + + + bool shouldSkipCodegen = false; + + // Are we being driven by the command-line `slangc`, and should act accordingly? + bool isCommandLineCompile = false; + + String mDiagnosticOutput; + + /// A blob holding the diagnostic output + ComPtr<ISlangBlob> diagnosticOutputBlob; + + /// Per-entry-point information not tracked by other compile requests + class EntryPointInfo : public RefObject + { + public: + /// Source code for the generic arguments to use for the generic parameters of the entry point. + List<String> genericArgStrings; + }; + List<EntryPointInfo> entryPoints; + + /// Per-target information only needed for command-line compiles + class TargetInfo : public RefObject + { + public: + // Requested output paths for each entry point. + // An empty string indices no output desired for + // the given entry point. + Dictionary<Int, String> entryPointOutputPaths; + }; + Dictionary<TargetRequest*, RefPtr<TargetInfo>> targetInfos; + + Linkage* getLinkage() { return m_linkage; } + + int addEntryPoint( + int translationUnitIndex, + String const& name, + Profile profile, + List<String> const & genericTypeNames); + + void setWriter(WriterChannel chan, ISlangWriter* writer); + ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; } + + SlangResult executeActionsInner(); + SlangResult executeActions(); + + Session* getSession() { return m_session; } + DiagnosticSink* getSink() { return &m_sink; } + NamePool* getNamePool() { return getLinkage()->getNamePool(); } + + FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; } + BackEndCompileRequest* getBackEndReq() { return m_backEndReq; } + Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); } + Program* getSpecializedProgram() { return m_specializedProgram; } + + private: + Session* m_session = nullptr; + RefPtr<Linkage> m_linkage; + DiagnosticSink m_sink; + RefPtr<FrontEndCompileRequest> m_frontEndReq; + RefPtr<Program> m_unspecializedProgram; + RefPtr<Program> m_specializedProgram; + RefPtr<BackEndCompileRequest> m_backEndReq; + + // For output + ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF]; }; void generateOutput( - CompileRequest* compileRequest); + BackEndCompileRequest* compileRequest); + + void generateOutput( + EndToEndCompileRequest* compileRequest); // Helper to dump intermediate output when debugging void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, void const* data, size_t size, CodeGenTarget target); void maybeDumpIntermediate( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, char const* text, CodeGenTarget target); @@ -548,12 +1140,14 @@ namespace Slang @param sink The diagnostic sink to report to */ void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink); - /* Given a translationUnitRequest determines a filename that is most suitable to identify the input. - If the translation is a pass through will attempt to get the source file pathname. If the source is slang generated - there is no equivalent name so will return 'slang-generated' - @param translationUnitRequest The request to find an appropriate source path for + /* Determines a suitable filename to identify the input for a given entry point being compiled. + If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file + pathname for the translation unit containing the entry point at `entryPointIndex. + If the compilation is not in a pass-through case, then always returns `"slang-generated"`. + @param endToEndReq The end-to-end compile request which might be using pass-through copmilation + @param entryPointIndex The index of the entry point to compute a filename for. @return the appropriate source filename */ - String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest); + String calcSourcePathForEntryPoint(EndToEndCompileRequest* endToEndReq, UInt entryPointIndex); struct TypeCheckingCache; // @@ -696,6 +1290,10 @@ namespace Slang String const& path, String const& source); ~Session(); + + private: + /// Linkage used for all built-in (stdlib) code. + RefPtr<Linkage> m_builtinLinkage; }; } diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index 0480eb934..10dcefe19 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -240,6 +240,14 @@ SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase) // that provides a scope for some number of declarations. SYNTAX_CLASS(ModuleDecl, ContainerDecl) FIELD(RefPtr<Scope>, scope) + + // The API-level module that this declaration belong to. + // + // This field allows lookup of the `Module` based on a + // declaration nested under a `ModuleDecl` by following + // its chain of parents. + // + RAW(Module* module = nullptr;) END_SYNTAX_CLASS() SYNTAX_CLASS(ImportDecl, Decl) diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 51076bde1..2d0dd7fdd 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -336,7 +336,7 @@ DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entr DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") -DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type") +DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type") DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage") DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage <name>' command-line option to specify a stage") @@ -356,6 +356,10 @@ DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_Dispatc DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'") +DIAGNOSTIC(38020, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)") +DIAGNOSTIC(38021, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.") + + DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself") DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.") diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h index e3aba32e6..9efc5efc6 100644 --- a/source/slang/diagnostics.h +++ b/source/slang/diagnostics.h @@ -2,6 +2,7 @@ #define RASTER_RENDERER_COMPILE_ERROR_H #include "../core/basic.h" +#include "../core/slang-writer.h" #include "source-loc.h" #include "token.h" @@ -153,7 +154,7 @@ namespace Slang StringBuilder outputBuffer; // List<Diagnostic> diagnostics; int errorCount = 0; - + int internalErrorLocsNoted = 0; ISlangWriter* writer = nullptr; Flags flags = 0; @@ -217,6 +218,32 @@ namespace Slang void diagnoseRaw( Severity severity, const UnownedStringSlice& message); + + /// During propagation of an exception for an internal + /// error, note that this source location was involved + void noteInternalErrorLoc(SourceLoc const& loc); + }; + + /// An `ISlangWriter` that writes directly to a diagnostic sink. + class DiagnosticSinkWriter : public AppendBufferWriter + { + public: + typedef AppendBufferWriter Super; + + DiagnosticSinkWriter(DiagnosticSink* sink) + : Super(WriterFlag::IsStatic) + , m_sink(sink) + {} + + // ISlangWriter + SLANG_NO_THROW virtual SlangResult SLANG_MCALL write(const char* chars, size_t numChars) SLANG_OVERRIDE + { + m_sink->diagnoseRaw(Severity::Note, UnownedStringSlice(chars, chars+numChars)); + return SLANG_OK; + } + + private: + DiagnosticSink* m_sink = nullptr; }; namespace Diagnostics diff --git a/source/slang/dxc-support.cpp b/source/slang/dxc-support.cpp index 7d2d6dc0c..603a59ea7 100644 --- a/source/slang/dxc-support.cpp +++ b/source/slang/dxc-support.cpp @@ -30,8 +30,11 @@ namespace Slang { String GetHLSLProfileName(Profile profile); String emitHLSLForEntryPoint( - EntryPointRequest* entryPoint, - TargetRequest* targetReq); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq); static UnownedStringSlice _getSlice(IDxcBlob* blob) { @@ -46,19 +49,22 @@ namespace Slang } SlangResult emitDXILForEntryPointUsingDXC( - EntryPointRequest* entryPoint, - TargetRequest* targetReq, - List<uint8_t>& outCode) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + Int entryPointIndex, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List<uint8_t>& outCode) { - auto compileRequest = entryPoint->compileRequest; - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); // First deal with all the rigamarole of loading // the `dxcompiler` library, and creating the // top-level COM objects that will be used to // compile things. - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink); + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); if (!dxcCreateInstance) { return SLANG_FAIL; @@ -69,9 +75,7 @@ namespace Slang { // If can't load dxil - dxc will not be able to sign output // Output a suitable warning to the user - auto& sink = entryPoint->compileRequest->mSink; - - sink.diagnose(SourceLoc(), Diagnostics::dxilNotFound); + sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound); } } @@ -89,8 +93,13 @@ namespace Slang // Now let's go ahead and generate HLSL for the entry // point, since we'll need that to feed into dxc. - auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq); - maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); + auto hlslCode = emitHLSLForEntryPoint( + compileRequest, + entryPoint, + entryPointIndex, + targetReq, + endToEndReq); + maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL); // Wrap the @@ -122,7 +131,7 @@ namespace Slang break; } - switch( targetReq->floatingPointMode ) + switch( targetReq->getFloatingPointMode() ) { default: break; @@ -149,7 +158,7 @@ namespace Slang // args[argCount++] = L"-no-warnings"; - String entryPointName = getText(entryPoint->name); + String entryPointName = getText(entryPoint->getName()); OSString wideEntryPointName = entryPointName.ToWString(); auto profile = getEffectiveProfile(entryPoint, targetReq); @@ -172,7 +181,7 @@ namespace Slang args[argCount++] = L"-enable-16bit-types"; } - const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit()); + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); ComPtr<IDxcOperationResult> dxcResult; SLANG_RETURN_ON_FAIL(dxcCompiler->Compile(dxcSourceBlob, @@ -208,7 +217,7 @@ namespace Slang // into a string for safety. // - reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), &entryPoint->compileRequest->mSink); + reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), compileRequest->getSink()); return resultCode; } @@ -225,20 +234,21 @@ namespace Slang } SlangResult dissassembleDXILUsingDXC( - CompileRequest* compileRequest, - void const* data, - size_t size, - String& stringOut) + BackEndCompileRequest* compileRequest, + void const* data, + size_t size, + String& stringOut) { stringOut = String(); - auto session = compileRequest->mSession; + auto session = compileRequest->getSession(); + auto sink = compileRequest->getSink(); // First deal with all the rigamarole of loading // the `dxcompiler` library, and creating the // top-level COM objects that will be used to // compile things. - auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink); + auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink); if (!dxcCreateInstance) { return SLANG_FAIL; diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index c0e5e4296..2603df11e 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -79,8 +79,10 @@ void requireGLSLVersionImpl( // Shared state for an entire emit session struct SharedEmitContext { + BackEndCompileRequest* compileRequest = nullptr; + // The entry point we are being asked to compile - EntryPointRequest* entryPoint; + EntryPoint* entryPoint; // The layout for the entry point EntryPointLayout* entryPointLayout; @@ -153,7 +155,7 @@ struct SharedEmitContext // to use for it when emitting code. Dictionary<IRInst*, String> mapInstToName; - DiagnosticSink* getSink() { return &entryPoint->compileRequest->mSink; } + DiagnosticSink* getSink() { return compileRequest->getSink(); } Dictionary<IRInst*, UInt> mapIRValueToRayPayloadLocation; Dictionary<IRInst*, UInt> mapIRValueToCallablePayloadLocation; @@ -165,6 +167,10 @@ struct EmitContext SharedEmitContext* shared; DiagnosticSink* getSink() { return shared->getSink(); } + + LineDirectiveMode getLineDirectiveMode() { return shared->compileRequest->getLineDirectiveMode(); } + SourceManager* getSourceManager() { return shared->compileRequest->getSourceManager(); } + void noteInternalErrorLoc(SourceLoc loc) { return getSink()->noteInternalErrorLoc(loc); } }; // @@ -334,11 +340,6 @@ struct EmitVisitor : context(context) {} - Session* getSession() - { - return context->shared->entryPoint->compileRequest->mSession; - } - // Low-level emit logic void emitRawTextSpan(char const* textBegin, char const* textEnd) @@ -556,7 +557,7 @@ struct EmitVisitor bool shouldUseGLSLStyleLineDirective = false; - auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode; + auto mode = context->getLineDirectiveMode(); switch (mode) { case LineDirectiveMode::None: @@ -664,7 +665,7 @@ struct EmitVisitor { // Don't do any of this work if the user has requested that we // not emit line directives. - auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode; + auto mode = context->getLineDirectiveMode(); switch(mode) { case LineDirectiveMode::None: @@ -723,7 +724,7 @@ struct EmitVisitor SourceManager* getSourceManager() { - return context->shared->entryPoint->compileRequest->getSourceManager(); + return context->getSourceManager(); } void advanceToSourceLocation( @@ -747,7 +748,7 @@ struct EmitVisitor DiagnosticSink* getSink() { - return &context->shared->entryPoint->compileRequest->mSink; + return context->getSink(); } // @@ -1875,8 +1876,7 @@ struct EmitVisitor } } - void emitGLSLVersionDirective( - ModuleDecl* /*program*/) + void emitGLSLVersionDirective() { auto effectiveProfile = context->shared->effectiveProfile; if(effectiveProfile.getFamily() == ProfileFamily::GLSL) @@ -1931,8 +1931,7 @@ struct EmitVisitor Emit("#version 420\n"); } - void emitGLSLPreprocessorDirectives( - RefPtr<ModuleDecl> program) + void emitGLSLPreprocessorDirectives() { switch(context->shared->target) { @@ -1944,24 +1943,7 @@ struct EmitVisitor break; } - emitGLSLVersionDirective(program); - - - // TODO: when cross-compiling we may need to output additional `#extension` directives - // based on the features that we have used. - - for( auto extensionDirective : program->GetModifiersOfType<GLSLExtensionDirective>() ) - { - // TODO(tfoley): Emit an appropriate `#line` directive... - - Emit("#extension "); - emit(extensionDirective->extensionNameToken.Content); - Emit(" : "); - emit(extensionDirective->dispositionToken.Content); - Emit("\n"); - } - - // TODO: handle other cases... + emitGLSLVersionDirective(); } /// Emit directives to control overall layout computation for the emitted code. @@ -4115,7 +4097,7 @@ struct EmitVisitor catch(AbortCompilationException&) { throw; } catch(...) { - ctx->shared->entryPoint->compileRequest->noteInternalErrorLoc(inst->sourceLoc); + ctx->noteInternalErrorLoc(inst->sourceLoc); throw; } } @@ -6583,26 +6565,26 @@ struct EmitVisitor // EntryPointLayout* findEntryPointLayout( - ProgramLayout* programLayout, - EntryPointRequest* entryPointRequest) + ProgramLayout* programLayout, + EntryPoint* entryPoint) { for( auto entryPointLayout : programLayout->entryPoints ) { - if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + if(entryPointLayout->entryPoint->getName() != entryPoint->getName()) continue; // TODO: We need to be careful about this check, since it relies on // the profile information in the layout matching that in the request. // // What we really seem to want here is some dictionary mapping the - // `EntryPointRequest` directly to the `EntryPointLayout`, and maybe + // `EntryPoint` directly to the `EntryPointLayout`, and maybe // that is precisely what we should build... // - if(entryPointLayout->profile != entryPointRequest->profile) + if(entryPointLayout->profile != entryPoint->getProfile()) continue; // TODO: can't easily filter on translation unit here... - // Ideally the `EntryPointRequest` should get filled in with a pointer + // Ideally the `EntryPoint` should get filled in with a pointer // the specific function declaration that represents the entry point. return entryPointLayout.Ptr(); @@ -6662,13 +6644,14 @@ void legalizeTypes( IRModule* module); static void dumpIRIfEnabled( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, IRModule* irModule, char const* label = nullptr) { if(compileRequest->shouldDumpIR) { - WriterHelper writer(compileRequest->getWriter(WriterChannel::StdError)); + DiagnosticSinkWriter writerImpl(compileRequest->getSink()); + WriterHelper writer(&writerImpl); if(label) { @@ -6687,16 +6670,22 @@ static void dumpIRIfEnabled( } String emitEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetRequest) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + CodeGenTarget target, + TargetRequest* targetRequest) { - auto translationUnit = entryPoint->getTranslationUnit(); + auto sink = compileRequest->getSink(); + auto program = compileRequest->getProgram(); + auto targetProgram = program->getTargetProgram(targetRequest); + auto programLayout = targetProgram->getOrCreateLayout(sink); + +// auto translationUnit = entryPoint->getTranslationUnit(); SharedEmitContext sharedContext; + sharedContext.compileRequest = compileRequest; sharedContext.target = target; - sharedContext.finalTarget = targetRequest->target; + sharedContext.finalTarget = targetRequest->getTarget(); sharedContext.entryPoint = entryPoint; sharedContext.effectiveProfile = getEffectiveProfile(entryPoint, targetRequest); @@ -6715,16 +6704,13 @@ String emitEntryPoint( StructTypeLayout* globalStructLayout = getGlobalStructLayout(programLayout); sharedContext.globalStructLayout = globalStructLayout; - auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); - EmitContext context; context.shared = &sharedContext; EmitVisitor visitor(&context); { - auto compileRequest = translationUnit->compileRequest; - auto session = compileRequest->mSession; + auto session = targetRequest->getSession(); // We start out by performing "linking" at the level of the IR. // This step will create a fresh IR module to be used for @@ -6735,6 +6721,7 @@ String emitEntryPoint( // any "profile-overloaded" symbols. // auto linkedIR = linkIR( + compileRequest, entryPoint, programLayout, target, @@ -6880,7 +6867,7 @@ String emitEntryPoint( session, irModule, irEntryPoint, - &compileRequest->mSink, + compileRequest->getSink(), &sharedContext.extensionUsageTracker); } break; @@ -6916,10 +6903,6 @@ String emitEntryPoint( // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? visitor.emitIRModule(&context, irModule); - - // retain the specialized ir module, because the current - // GlobalGenericParamSubstitution implementation may reference ir objects - targetRequest->compileRequest->compiledModules.Add(irModule); } // Deal with cases where a particular stage requires certain GLSL versions @@ -6950,7 +6933,7 @@ String emitEntryPoint( // it is time to stich together the final output. // There may be global-scope modifiers that we should emit now - visitor.emitGLSLPreprocessorDirectives(translationUnitSyntax); + visitor.emitGLSLPreprocessorDirectives(); visitor.emitLayoutDirectives(targetRequest); diff --git a/source/slang/emit.h b/source/slang/emit.h index 98845f9c6..317afcf6b 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -8,7 +8,7 @@ namespace Slang { - class EntryPointRequest; + class EntryPoint; class ProgramLayout; class TranslationUnitRequest; @@ -20,13 +20,13 @@ namespace Slang // Emit code for a single entry point, based on // the input translation unit. String emitEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, // The target language to generate code in (e.g., HLSL/GLSL) - CodeGenTarget target, + CodeGenTarget target, // The full target request - TargetRequest* targetRequest); + TargetRequest* targetRequest); } #endif diff --git a/source/slang/ir-dce.cpp b/source/slang/ir-dce.cpp index 0f037bfe5..ba6d7adb9 100644 --- a/source/slang/ir-dce.cpp +++ b/source/slang/ir-dce.cpp @@ -16,8 +16,8 @@ struct DeadCodeEliminationContext // the parameters that were passed to the top-level // `eliminateDeadCode` function. // - CompileRequest* compileRequest; - IRModule* module; + BackEndCompileRequest* compileRequest; + IRModule* module; // Our overall process is going to be to determine // which instructions in the module are "live" @@ -235,9 +235,9 @@ struct DeadCodeEliminationContext // we'll just go ahead and eliminate every single function/type // in a module. There needs to be a way to identify the // functions we want to keep around, and for right now - // that is handled with the `[entryPoint]` decoration. + // that is handled with the `[keepAlive]` decoration. // - if(inst->findDecorationImpl(kIROp_EntryPointDecoration)) + if(inst->findDecorationImpl(kIROp_KeepAliveDecoration)) return true; // // TODO: Eventually it would make sense to consider everything @@ -312,7 +312,7 @@ struct DeadCodeEliminationContext // and then defer to it for the real work. // void eliminateDeadCode( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, IRModule* module) { DeadCodeEliminationContext context; diff --git a/source/slang/ir-dce.h b/source/slang/ir-dce.h index fd56616d9..6089b404a 100644 --- a/source/slang/ir-dce.h +++ b/source/slang/ir-dce.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class BackEndCompileRequest; struct IRModule; /// Eliminate "dead" code from the given IR module. @@ -14,6 +14,6 @@ namespace Slang /// etc. /// void eliminateDeadCode( - CompileRequest* compileRequest, - IRModule* module); + BackEndCompileRequest* compileRequest, + IRModule* module); } diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index b6f8ce547..eada52e4d 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -392,6 +392,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// even if it does not otherwise reference it. INST(DependsOnDecoration, dependsOn, 1, 0) + /// A `[keepAlive]` decoration marks an instruction that should not be eliminated. + INST(KeepAliveDecoration, keepAlive, 0, 0) + + /* LinkageDecoration */ INST(ImportDecoration, import, 1, 0) INST(ExportDecoration, export, 1, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 6b12612ef..b7c1b2744 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -1201,6 +1201,11 @@ struct IRBuilder addDecoration(value, kIROp_EntryPointDecoration); } + void addKeepAliveDecoration(IRInst* value) + { + addDecoration(value, kIROp_KeepAliveDecoration); + } + /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. /// /// This decoration can be used to ensure that a value that an instruction diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp index 35e0f46b8..2eef10614 100644 --- a/source/slang/ir-link.cpp +++ b/source/slang/ir-link.cpp @@ -14,7 +14,7 @@ namespace Slang // instead of the input/request layer. EntryPointLayout* findEntryPointLayout( ProgramLayout* programLayout, - EntryPointRequest* entryPointRequest); + EntryPoint* EntryPoint); struct IRSpecSymbol : RefObject { @@ -39,9 +39,6 @@ struct IRSharedSpecContext // The specialized module we are building RefPtr<IRModule> module; - // The original, unspecialized module we are copying - IRModule* originalModule; - // A map from mangled symbol names to zero or // more global IR values that have that name, // in the *original* module. @@ -67,8 +64,6 @@ struct IRSpecContextBase IRModule* getModule() { return getShared()->module; } - IRModule* getOriginalModule() { return getShared()->originalModule; } - IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; } // The current specialization environment to use. @@ -668,8 +663,8 @@ IRInst* specializeGeneric( IRSpecialize* specializeInst); IRFunc* specializeIRForEntryPoint( - IRSpecContext* context, - EntryPointRequest* entryPointRequest, + IRSpecContext* context, + EntryPoint* entryPoint, EntryPointLayout* entryPointLayout) { // We start by looking up the IR symbol that @@ -681,7 +676,7 @@ IRFunc* specializeIRForEntryPoint( // so that the mangled name of the decl-ref is // not the same as the mangled name of the decl. // - auto mangledName = getMangledName(entryPointRequest->getFuncDeclRef()); + auto mangledName = getMangledName(entryPoint->getFuncDeclRef()); RefPtr<IRSpecSymbol> sym; if (!context->getSymbols().TryGetValue(mangledName, sym)) { @@ -743,9 +738,9 @@ IRFunc* specializeIRForEntryPoint( return nullptr; } - if( !clonedFunc->findDecorationImpl(kIROp_EntryPointDecoration) ) + if( !clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration) ) { - context->builder->addEntryPointDecoration(clonedFunc); + context->builder->addKeepAliveDecoration(clonedFunc); } // We need to attach the layout information for @@ -1148,7 +1143,6 @@ void initializeSharedSpecContext( IRSharedSpecContext* sharedContext, Session* session, IRModule* module, - IRModule* originalModule, CodeGenTarget target) { @@ -1166,19 +1160,15 @@ void initializeSharedSpecContext( sharedBuilder->module = module; sharedContext->module = module; - sharedContext->originalModule = originalModule; sharedContext->target = target; - // We will populate a map with all of the IR values - // that use the same mangled name, to make lookup easier - // in other steps. - insertGlobalValueSymbols(sharedContext, originalModule); } // implementation provided in parameter-binding.cpp RefPtr<ProgramLayout> specializeProgramLayout( TargetRequest * targetReq, - ProgramLayout* programLayout, - SubstitutionSet typeSubst); + ProgramLayout* programLayout, + SubstitutionSet typeSubst, + DiagnosticSink* sink); struct IRSpecializationState { @@ -1211,11 +1201,14 @@ struct IRSpecializationState }; LinkedIR linkIR( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq) + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq) { + auto sink = compileRequest->getSink(); + IRSpecializationState stateStorage; auto state = &stateStorage; @@ -1223,26 +1216,27 @@ LinkedIR linkIR( state->target = target; state->targetReq = targetReq; - - auto compileRequest = entryPointRequest->compileRequest; - auto translationUnit = entryPointRequest->getTranslationUnit(); - auto originalIRModule = translationUnit->irModule; + auto program = compileRequest->getProgram(); auto sharedContext = state->getSharedContext(); initializeSharedSpecContext( sharedContext, - compileRequest->mSession, + compileRequest->getSession(), nullptr, - originalIRModule, target); state->irModule = sharedContext->module; - // We also need to attach the IR definitions for symbols from - // any loaded modules: - for (auto loadedModule : compileRequest->loadedModulesList) + // We need to be able to look up IR definitions for any symbols in + // modules that the program depends on (transitively). To + // accelerate lookup, we will create a symbol table for looking + // up IR definitions by their mangled name. + // + auto originalProgramIRModule = program->getOrCreateIRModule(sink); + insertGlobalValueSymbols(sharedContext, originalProgramIRModule); + for (auto module : program->getModuleDependencies()) { - insertGlobalValueSymbols(sharedContext, loadedModule->irModule); + insertGlobalValueSymbols(sharedContext, module->getIRModule()); } auto context = state->getContext(); @@ -1257,7 +1251,8 @@ LinkedIR linkIR( RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout( targetReq, programLayout, - SubstitutionSet(entryPointRequest->globalGenericSubst)); + SubstitutionSet(program->getGlobalGenericSubstitution()), + compileRequest->getSink()); // TODO: we need to register the (IR-level) arguments of the global generic parameters as the // substitutions for the generic parameters in the original IR. @@ -1267,13 +1262,22 @@ LinkedIR linkIR( state->newProgramLayout = newProgramLayout; - // Next, we want to optimize lookup for layout infromation + // Next, we want to optimize lookup for layout information // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) + // + // Note: We are scanning over all the key-value pairs for + // entries in the global scope, to account for the fact + // that the "same" shader parameter could be declared in + // multiple translation units, and thus end up with + // multiple mangled names (when the unique translation + // unit name gets involved). + // auto globalStructLayout = getScopeStructLayout(newProgramLayout); - for (auto globalVarLayout : globalStructLayout->fields) + for(auto entry : globalStructLayout->mapVarToLayout) { - auto mangledName = getMangledName(globalVarLayout->varDecl); + auto mangledName = getMangledName(entry.Key); + auto globalVarLayout = entry.Value; context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } @@ -1290,19 +1294,19 @@ LinkedIR linkIR( cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); } - auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPoint); // Next, we make sure to clone the global value for // the entry point function itself, and rely on // this step to recursively copy over anything else // it might reference. - auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout); + auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, entryPointLayout); // HACK: right now the bindings for global generic parameters are coming in // as part of the original IR module, and we need to make sure these get // copied over, even if they aren't referenced. // - for(auto inst : originalIRModule->getGlobalInsts()) + for(auto inst : originalProgramIRModule->getGlobalInsts()) { auto bindInst = as<IRBindGlobalGenericParam>(inst); if(!bindInst) diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h index 4fcdb4618..dba3ccc97 100644 --- a/source/slang/ir-link.h +++ b/source/slang/ir-link.h @@ -19,8 +19,9 @@ namespace Slang // used. // LinkedIR linkIR( - EntryPointRequest* entryPointRequest, - ProgramLayout* programLayout, - CodeGenTarget target, - TargetRequest* targetReq); + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + TargetRequest* targetReq); } diff --git a/source/slang/ir-specialize-resources.cpp b/source/slang/ir-specialize-resources.cpp index 0108a91f8..e974ffdb7 100644 --- a/source/slang/ir-specialize-resources.cpp +++ b/source/slang/ir-specialize-resources.cpp @@ -18,7 +18,7 @@ struct ResourceParameterSpecializationContext // the parameters that were passed to the top-level // `specializeResourceParameters` function. // - CompileRequest* compileRequest; + BackEndCompileRequest* compileRequest; TargetRequest* targetRequest; IRModule* module; @@ -372,7 +372,7 @@ struct ResourceParameterSpecializationContext // If we didn't find a pre-existing specialized // function, then we will go ahead and create one. // - // We start by gathering the infromation from the call + // We start by gathering the information from the call // site that is relevant to generating a specialized // callee function, which we avoided doing earlier // because it might have been throwaway work. @@ -850,7 +850,7 @@ struct ResourceParameterSpecializationContext // and then defer to it for the real work. // void specializeResourceParameters( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module) { diff --git a/source/slang/ir-specialize-resources.h b/source/slang/ir-specialize-resources.h index 3d6ead130..0e636318c 100644 --- a/source/slang/ir-specialize-resources.h +++ b/source/slang/ir-specialize-resources.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class BackEndCompileRequest; class TargetRequest; struct IRModule; @@ -18,7 +18,7 @@ namespace Slang /// global shader parameters directly). /// void specializeResourceParameters( - CompileRequest* compileRequest, + BackEndCompileRequest* compileRequest, TargetRequest* targetRequest, IRModule* module); } diff --git a/source/slang/ir-validate.cpp b/source/slang/ir-validate.cpp index 924ec71b3..9564873b1 100644 --- a/source/slang/ir-validate.cpp +++ b/source/slang/ir-validate.cpp @@ -195,13 +195,13 @@ namespace Slang } void validateIRModuleIfEnabled( - CompileRequest* compileRequest, - IRModule* module) + CompileRequestBase* compileRequest, + IRModule* module) { if (!compileRequest->shouldValidateIR) return; - auto sink = &compileRequest->mSink; + auto sink = compileRequest->getSink(); validateIRModule(module, sink); } } diff --git a/source/slang/ir-validate.h b/source/slang/ir-validate.h index 0ebc69019..1cb30961d 100644 --- a/source/slang/ir-validate.h +++ b/source/slang/ir-validate.h @@ -3,7 +3,7 @@ namespace Slang { - class CompileRequest; + class CompileRequestBase; class DiagnosticSink; struct IRModule; @@ -30,6 +30,6 @@ namespace Slang // A wrapper that calls `validateIRModule` only when IR validation is enabled // for the given compile request. void validateIRModuleIfEnabled( - CompileRequest* compileRequest, - IRModule* module); + CompileRequestBase* compileRequest, + IRModule* module); } diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 0d9427b08..b53bb8ebb 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -300,8 +300,18 @@ struct IRGenEnv struct SharedIRGenContext { - CompileRequest* compileRequest; - ModuleDecl* mainModuleDecl; + SharedIRGenContext( + Session* session, + DiagnosticSink* sink, + ModuleDecl* mainModuleDecl = nullptr) + : m_session(session) + , m_sink(sink) + , m_mainModuleDecl(mainModuleDecl) + {} + + Session* m_session = nullptr; + DiagnosticSink* m_sink = nullptr; + ModuleDecl* m_mainModuleDecl = nullptr; // The "global" environment for mapping declarations to their IR values. IRGenEnv globalEnv; @@ -356,17 +366,17 @@ struct IRGenContext Session* getSession() { - return shared->compileRequest->mSession; + return shared->m_session; } - CompileRequest* getCompileRequest() + DiagnosticSink* getSink() { - return shared->compileRequest; + return shared->m_sink; } - DiagnosticSink* getSink() + ModuleDecl* getMainModuleDecl() { - return &getCompileRequest()->mSink; + return shared->m_mainModuleDecl; } }; @@ -422,7 +432,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl) if (isFromStdLib(decl)) return false; - if (moduleDecl != context->shared->mainModuleDecl) + if (moduleDecl != context->getMainModuleDecl()) return true; return false; @@ -1735,16 +1745,31 @@ static String getNameForNameHint( if(auto genericParentDecl = as<GenericDecl>(parentDecl)) parentDecl = genericParentDecl->ParentDecl; + // A `ModuleDecl` can have a name too, but in the common case + // we don't want to generate name hints that include the module + // name, simply because they would lead to every global symbol + // getting a much longer name. + // + // TODO: We should probably include the module name for symbols + // being `import`ed, and not for symbols being compiled directly + // (those coming from a module that had no name given to it). + // + // For now we skip past a `ModuleDecl` parent. + // + if(auto moduleParentDecl = as<ModuleDecl>(parentDecl)) + parentDecl = moduleParentDecl->ParentDecl; + + if(!parentDecl) + { + return leafName->text; + } + auto parentName = getNameForNameHint(context, parentDecl); if(parentName.Length() == 0) { return leafName->text; } - // TODO: at some point we will start giving `ModuleDecl`s names, - // and in that case we need to think carefully about whether to - // include their names here or not. - // We will now construct a new `Name` to use as the hint, // combining the name of the parent and the leaf declaration. @@ -3603,7 +3628,7 @@ void lowerStmt( catch(AbortCompilationException&) { throw; } catch(...) { - context->getCompileRequest()->noteInternalErrorLoc(stmt->loc); + context->getSink()->noteInternalErrorLoc(stmt->loc); throw; } } @@ -5877,7 +5902,7 @@ LoweredValInfo lowerDecl( catch(AbortCompilationException&) { throw; } catch(...) { - context->getCompileRequest()->noteInternalErrorLoc(decl->loc); + context->getSink()->noteInternalErrorLoc(decl->loc); throw; } } @@ -6108,56 +6133,59 @@ LoweredValInfo emitDeclRef( type); } -static void lowerEntryPointToIR( - IRGenContext* context, - EntryPointRequest* entryPointRequest) +static void lowerFrontEndEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) { - // First, lower the entry point like an ordinary function + // TODO: We should emit an entry point as a dedicated IR function + // (distinct from the IR function used if it were called normally), + // with a mangled name based on the original function name plus + // the stage for which it is being compiled as an entry point (so + // that entry points for distinct stages always have distinct names). + // + // For now we just have an (implicit) constraint that a given + // function should only be used as an entry point for one stage, + // and any such function should *not* be used as an ordinary function. - auto session = context->getSession(); - auto entryPointFuncDeclRef = entryPointRequest->getFuncDeclRef(); - auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); + auto entryPointFuncDecl = entryPoint->getFuncDecl(); auto builder = context->irBuilder; builder->setInsertInto(builder->getModule()->getModuleInst()); auto loweredEntryPointFunc = getSimpleVal(context, - emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); + ensureDecl(context, entryPointFuncDecl)); // Attach a marker decoration so that we recognize // this as an entry point. // - builder->addEntryPointDecoration(loweredEntryPointFunc); - - // - if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + IRInst* instToDecorate = loweredEntryPointFunc; + if(auto irGeneric = as<IRGeneric>(instToDecorate)) { - builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); + instToDecorate = findGenericReturnVal(irGeneric); } + builder->addEntryPointDecoration(instToDecorate); +} - // Now lower all the arguments supplied for global generic - // type parameters. - // - for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer) - { - auto gSubst = subst.as<GlobalGenericParamSubstitution>(); - if(!gSubst) - continue; +static void lowerProgramEntryPointToIR( + IRGenContext* context, + EntryPoint* entryPoint) +{ + // First, lower the entry point like an ordinary function - IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); - IRType* typeVal = lowerType(context, gSubst->actualType); + auto session = context->getSession(); + auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); + auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef)); - // bind `typeParam` to `typeVal` - builder->emitBindGlobalGenericParam(typeParam, typeVal); + auto builder = context->irBuilder; + builder->setInsertInto(builder->getModule()->getModuleInst()); - for (auto& constraintArg : gSubst->constraintArgs) - { - IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); - IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + auto loweredEntryPointFunc = getSimpleVal(context, + emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType)); - // bind `constraintParam` to `constraintVal` - builder->emitBindGlobalGenericParam(constraintParam, constraintVal); - } + // + if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>()) + { + builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice()); } } @@ -6191,19 +6219,19 @@ IRModule* generateIRForTranslationUnit( { auto compileRequest = translationUnit->compileRequest; - SharedIRGenContext sharedContextStorage; + SharedIRGenContext sharedContextStorage( + translationUnit->getSession(), + translationUnit->compileRequest->getSink(), + translationUnit->getModuleDecl()); SharedIRGenContext* sharedContext = &sharedContextStorage; - sharedContext->compileRequest = compileRequest; - sharedContext->mainModuleDecl = translationUnit->SyntaxNode; - IRGenContext contextStorage(sharedContext); IRGenContext* context = &contextStorage; SharedIRBuilder sharedBuilderStorage; SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->module = nullptr; - sharedBuilder->session = compileRequest->mSession; + sharedBuilder->session = compileRequest->getSession(); IRBuilder builderStorage; IRBuilder* builder = &builderStorage; @@ -6224,12 +6252,13 @@ IRModule* generateIRForTranslationUnit( // in case they require special handling. for (auto entryPoint : translationUnit->entryPoints) { - lowerEntryPointToIR(context, entryPoint); + lowerFrontEndEntryPointToIR(context, entryPoint); } + // // Next, ensure that all other global declarations have // been emitted. - for (auto decl : translationUnit->SyntaxNode->Members) + for (auto decl : translationUnit->getModuleDecl()->Members) { ensureAllDeclsRec(context, decl); } @@ -6271,12 +6300,12 @@ IRModule* generateIRForTranslationUnit( // Propagate `constexpr`-ness through the dataflow graph (and the // call graph) based on constraints imposed by different instructions. - propagateConstExpr(module, &compileRequest->mSink); + propagateConstExpr(module, compileRequest->getSink()); // TODO: give error messages if any `undefined` or // `unreachable` instructions remain. - checkForMissingReturns(module, &compileRequest->mSink); + checkForMissingReturns(module, compileRequest->getSink()); // TODO: consider doing some more aggressive optimizations // (in particular specialization of generics) here, so @@ -6293,28 +6322,82 @@ IRModule* generateIRForTranslationUnit( // then we can dump the initial IR for the module here. if(compileRequest->shouldDumpIR) { - ISlangWriter* writer = translationUnit->compileRequest->getWriter(WriterChannel::StdError); - - dumpIR(module, writer); + DiagnosticSinkWriter writer(compileRequest->getSink()); + dumpIR(module, &writer); } return module; } -#if 0 -String emitSlangIRAssemblyForEntryPoint( - EntryPointRequest* entryPoint) +RefPtr<IRModule> generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink) { - auto compileRequest = entryPoint->compileRequest; - auto irModule = lowerEntryPointToIR( - entryPoint, - compileRequest->layout.Ptr(), - // TODO: we need to pick the target more carefully here - CodeGenTarget::HLSL); - - return getSlangIRAssembly(irModule); -} -#endif +// auto compileRequest = translationUnit->compileRequest; + + SharedIRGenContext sharedContextStorage( + session, + sink); + SharedIRGenContext* sharedContext = &sharedContextStorage; + + IRGenContext contextStorage(sharedContext); + IRGenContext* context = &contextStorage; + + SharedIRBuilder sharedBuilderStorage; + SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; + sharedBuilder->module = nullptr; + sharedBuilder->session = session; + + IRBuilder builderStorage; + IRBuilder* builder = &builderStorage; + builder->sharedBuilder = sharedBuilder; + + RefPtr<IRModule> module = builder->createModule(); + sharedBuilder->module = module; + + context->irBuilder = builder; + + // We need to emit symbols for all of the entry + // points in the program; this is especially + // important in the case where a generic entry + // point is being specialized. + // + for(auto entryPoint : program->getEntryPoints()) + { + lowerProgramEntryPointToIR(context, entryPoint); + } + + // Now lower all the arguments supplied for global generic + // type parameters. + // + for (RefPtr<Substitutions> subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer) + { + auto gSubst = subst.as<GlobalGenericParamSubstitution>(); + if(!gSubst) + continue; + + IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl)); + IRType* typeVal = lowerType(context, gSubst->actualType); + + // bind `typeParam` to `typeVal` + builder->emitBindGlobalGenericParam(typeParam, typeVal); + + for (auto& constraintArg : gSubst->constraintArgs) + { + IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl)); + IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val); + + // bind `constraintParam` to `constraintVal` + builder->emitBindGlobalGenericParam(constraintParam, constraintVal); + } + } + + // TODO: Should we apply any of the validation or + // mandatory optimization passes here? + + return module; +} } // namespace Slang diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h index bd878d6fa..f607e852c 100644 --- a/source/slang/lower-to-ir.h +++ b/source/slang/lower-to-ir.h @@ -14,7 +14,7 @@ namespace Slang { class CompileRequest; - class EntryPointRequest; + class EntryPoint; class ProgramLayout; class TranslationUnitRequest; @@ -22,5 +22,10 @@ namespace Slang IRModule* generateIRForTranslationUnit( TranslationUnitRequest* translationUnit); + + RefPtr<IRModule> generateIRForProgram( + Session* session, + Program* program, + DiagnosticSink* sink); } #endif diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 8a4cf35b7..65fbfe068 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -41,7 +41,7 @@ struct OptionsParser SlangSession* session = nullptr; SlangCompileRequest* compileRequest = nullptr; - Slang::CompileRequest* requestImpl = nullptr; + Slang::EndToEndCompileRequest* requestImpl = nullptr; Slang::RefPtr<Slang::ConfigurableSharedLibraryLoader> sharedLibraryLoader; @@ -313,7 +313,7 @@ struct OptionsParser if (sourceLanguage == SLANG_SOURCE_LANGUAGE_UNKNOWN) { - requestImpl->mSink.diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); + requestImpl->getSink()->diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath); return SLANG_FAIL; } @@ -425,9 +425,9 @@ struct OptionsParser { // Copy some state out of the current request, in case we've been called // after some other initialization has been performed. - flags = requestImpl->compileFlags; + flags = requestImpl->getFrontEndReq()->compileFlags; - DiagnosticSink* sink = &requestImpl->mSink; + DiagnosticSink* sink = requestImpl->getSink(); SlangMatrixLayoutMode defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_MODE_UNKNOWN; @@ -450,23 +450,24 @@ struct OptionsParser } else if(argStr == "-dump-ir" ) { - requestImpl->shouldDumpIR = true; + requestImpl->getFrontEndReq()->shouldDumpIR = true; + requestImpl->getBackEndReq()->shouldDumpIR = true; } else if (argStr == "-serial-ir") { - requestImpl->useSerialIRBottleneck = true; + requestImpl->getFrontEndReq()->useSerialIRBottleneck = true; } else if (argStr == "-verbose-paths") { - requestImpl->mSink.flags |= DiagnosticSink::Flag::VerbosePath; + requestImpl->getSink()->flags |= DiagnosticSink::Flag::VerbosePath; } else if (argStr == "-verify-debug-serial-ir") { - requestImpl->verifyDebugSerialization = true; + requestImpl->getFrontEndReq()->verifyDebugSerialization = true; } else if(argStr == "-validate-ir" ) { - requestImpl->shouldValidateIR = true; + requestImpl->getFrontEndReq()->shouldValidateIR = true; } else if(argStr == "-skip-codegen" ) { @@ -1222,18 +1223,7 @@ struct OptionsParser // Now that we've diagnosed the output paths, we can add them // to the compile request at the appropriate locations. // - // We start by allocating the arrays for per-entry-point output - // paths on each of the requested targets. - // - for(auto rawTarget : rawTargets) - { - auto targetID = rawTarget.targetID; - auto targetReq = requestImpl->targets[targetID]; - - targetReq->entryPointOutputPaths.SetSize(rawEntryPoints.Count()); - } - - // Consider the output files specified via `-o` and try to figure + // We will consider the output files specified via `-o` and try to figure // out how to deal with them. // for(auto& rawOutput : rawOutputs) @@ -1242,18 +1232,26 @@ struct OptionsParser if(rawOutput.entryPointIndex == -1) continue; auto targetID = rawTargets[rawOutput.targetIndex].targetID; - auto entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; + Int entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID; + + auto target = requestImpl->getLinkage()->targets[targetID]; + auto entryPointReq = requestImpl->getFrontEndReq()->getEntryPointReqs()[entryPointID]; - auto targetReq = requestImpl->targets[targetID]; + RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; + if( !requestImpl->targetInfos.TryGetValue(target, targetInfo) ) + { + targetInfo = new EndToEndCompileRequest::TargetInfo(); + requestImpl->targetInfos[target] = targetInfo; + } - if(targetReq->entryPointOutputPaths[entryPointID].Length()) + String outputPath; + if( targetInfo->entryPointOutputPaths.ContainsKey(entryPointID) ) { - auto entryPointReq = requestImpl->entryPoints[entryPointID]; - sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->name, targetReq->target); + sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->getName(), target->getTarget()); } else { - targetReq->entryPointOutputPaths[entryPointID] = rawOutput.path; + targetInfo->entryPointOutputPaths[entryPointID] = rawOutput.path; } } @@ -1272,16 +1270,16 @@ SlangResult parseOptions( int argc, char const* const* argv) { - Slang::CompileRequest* compileRequest = (Slang::CompileRequest*) compileRequestIn; + Slang::EndToEndCompileRequest* compileRequest = (Slang::EndToEndCompileRequest*) compileRequestIn; OptionsParser parser; parser.compileRequest = compileRequestIn; parser.requestImpl = compileRequest; - parser.session = (SlangSession*)compileRequest->mSession; + parser.session = (SlangSession*)compileRequest->getSession(); Result res = parser.parse(argc, argv); - DiagnosticSink* sink = &compileRequest->mSink; + DiagnosticSink* sink = compileRequest->getSink(); if (sink->GetErrorCount() > 0) { // Put the errors in the diagnostic diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index f0abfe31c..56c5d7c1d 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -309,9 +309,6 @@ struct ParameterInfo : RefObject ParameterBindingInfo bindingInfo[kLayoutResourceKindCount]; - // The next parameter that has the same name... - ParameterInfo* nextOfSameName; - // The translation unit this parameter is specific to, if any TranslationUnitRequest* translationUnit = nullptr; @@ -335,8 +332,22 @@ struct EntryPointParameterBindingContext // across all translation units struct SharedParameterBindingContext { - // The base compile request - CompileRequest* compileRequest; + SharedParameterBindingContext( + LayoutRulesFamilyImpl* defaultLayoutRules, + ProgramLayout* programLayout, + TargetRequest* targetReq, + DiagnosticSink* sink) + : defaultLayoutRules(defaultLayoutRules) + , programLayout(programLayout) + , targetRequest(targetReq) + , m_sink(sink) + { + } + + DiagnosticSink* m_sink = nullptr; + + // The program that we are laying out +// Program* program = nullptr; // The target request that is triggering layout // @@ -365,20 +376,18 @@ struct SharedParameterBindingContext UInt defaultSpace = 0; TargetRequest* getTargetRequest() { return targetRequest; } + DiagnosticSink* getSink() { return m_sink; } }; static DiagnosticSink* getSink(SharedParameterBindingContext* shared) { - return &shared->compileRequest->mSink; + return shared->getSink(); } // State that might be specific to a single translation unit // or event to an entry point. struct ParameterBindingContext { - // The translation unit we are processing right now - TranslationUnitRequest* translationUnit; - // All the shared state needs to be available SharedParameterBindingContext* shared; @@ -386,7 +395,7 @@ struct ParameterBindingContext // the resource usage of shader parameters. TypeLayoutContext layoutContext; - // A dictionary to accellerate looking up parameters by name + // A dictionary to accelerate looking up parameters by name Dictionary<Name*, ParameterInfo*> mapNameToParameterInfo; // What stage (if any) are we compiling for? @@ -395,9 +404,6 @@ struct ParameterBindingContext // The entry point that is being processed right now. EntryPointLayout* entryPointLayout = nullptr; - // The source language we are trying to use - SourceLanguage sourceLanguage; - TargetRequest* getTargetRequest() { return shared->getTargetRequest(); } LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); } }; @@ -1217,6 +1223,10 @@ static void collectGlobalScopeParameter( // If that is the case, we want to re-use the same `VarLayout` // across both parameters. // + // TODO: This logic currently detects *any* global-scope parameters + // with matching names, but it should eventually be narrowly + // scoped so that it only applies to parameters from unnamed modules. + // // First we look for an existing entry matching the name // of this parameter: auto parameterName = getReflectionName(varDecl); @@ -2477,7 +2487,7 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding( /// static void collectEntryPointParameters( ParameterBindingContext* context, - EntryPointRequest* entryPoint, + EntryPoint* entryPoint, SubstitutionSet typeSubst) { DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef(); @@ -2486,7 +2496,7 @@ static void collectEntryPointParameters( // the `EntryPointLayout` object here. // RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout(); - entryPointLayout->profile = entryPoint->profile; + entryPointLayout->profile = entryPoint->getProfile(); entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl(); // The entry point layout must be added to the output @@ -2501,10 +2511,10 @@ static void collectEntryPointParameters( // Note: this isn't really the best place for this logic to sit, // but it is the simplest place where we have a direct correspondence - // between a single `EntryPointRequest` and its matching `EntryPointLayout`, + // between a single `EntryPoint` and its matching `EntryPointLayout`, // so we'll use it. // - for( auto taggedUnionType : entryPoint->taggedUnionTypes ) + for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() ) { SLANG_ASSERT(taggedUnionType); auto substType = taggedUnionType->Substitute(typeSubst).as<Type>(); @@ -2645,59 +2655,9 @@ static void collectEntryPointParameters( } } -// When doing parameter binding for global-scope stuff in GLSL, -// we may need to know what stage we are compiling for, so that -// we can handle special cases appropriately (e.g., "arrayed" -// inputs and outputs). -static Stage -inferStageForTranslationUnit( - TranslationUnitRequest* translationUnit) -{ - // In the specific case where we are compiling GLSL input, - // and have only a single entry point, use the stage - // of the entry point. - // - // TODO: now that we've dropped official GLSL support, - // we probably should drop this as well. - // - if( translationUnit->sourceLanguage == SourceLanguage::GLSL ) - { - if( translationUnit->entryPoints.Count() == 1 ) - { - return translationUnit->entryPoints[0]->getStage(); - } - } - - return Stage::Unknown; -} - -static void collectModuleParameters( - ParameterBindingContext* inContext, - ModuleDecl* module) -{ - // Each loaded module provides a separate (logical) namespace for - // parameters, so that two parameters with the same name, in - // distinct modules, should yield different bindings. - // - ParameterBindingContext contextData = *inContext; - auto context = &contextData; - - context->translationUnit = nullptr; - - context->stage = Stage::Unknown; - - // All imported modules are implicitly Slang code - context->sourceLanguage = SourceLanguage::Slang; - - // A loaded module cannot define entry points that - // we'll expose (for now), so we just need to - // consider global-scope parameters. - collectGlobalScopeParameters(context, module); -} - static void collectParameters( ParameterBindingContext* inContext, - CompileRequest* request) + Program* program) { // All of the parameters in translation units directly // referenced in the compile request are part of one @@ -2707,29 +2667,21 @@ static void collectParameters( ParameterBindingContext contextData = *inContext; auto context = &contextData; - for( auto& translationUnit : request->translationUnits ) + for(RefPtr<Module> module : program->getModuleDependencies()) { - context->translationUnit = translationUnit; - context->stage = inferStageForTranslationUnit(translationUnit.Ptr()); - context->sourceLanguage = translationUnit->sourceLanguage; + context->stage = Stage::Unknown; // First look at global-scope parameters - collectGlobalScopeParameters(context, translationUnit->SyntaxNode.Ptr()); - - // Next consider parameters for entry points - for( auto& entryPoint : translationUnit->entryPoints ) - { - context->stage = entryPoint->getStage(); - collectEntryPointParameters(context, entryPoint.Ptr(), SubstitutionSet()); - } - context->entryPointLayout = nullptr; + collectGlobalScopeParameters(context, module->getModuleDecl()); } - // Now collect parameters from loaded modules - for (auto& loadedModule : request->loadedModulesList) + // Next consider parameters for entry points + for(auto entryPoint : program->getEntryPoints()) { - collectModuleParameters(context, loadedModule->moduleDecl.Ptr()); + context->stage = entryPoint->getStage(); + collectEntryPointParameters(context, entryPoint, SubstitutionSet()); } + context->entryPointLayout = nullptr; } /// Emit a diagnostic about a uniform parameter at global scope. @@ -2770,41 +2722,40 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte return numUsed; } -void generateParameterBindings( - TargetRequest* targetReq) +RefPtr<ProgramLayout> generateParameterBindings( + TargetProgram* targetProgram, + DiagnosticSink* sink) { - CompileRequest* compileReq = targetReq->compileRequest; + auto program = targetProgram->getProgram(); + auto targetReq = targetProgram->getTargetReq(); + + RefPtr<ProgramLayout> programLayout = new ProgramLayout(); + programLayout->targetProgram = targetProgram; // Try to find rules based on the selected code-generation target - auto layoutContext = getInitialLayoutContextForTarget(targetReq); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout); // If there was no target, or there are no rules for the target, // then bail out here. if (!layoutContext.rules) - return; - - RefPtr<ProgramLayout> programLayout = new ProgramLayout(); - programLayout->targetRequest = targetReq; - - targetReq->layout = programLayout; + return nullptr; // Create a context to hold shared state during the process // of generating parameter bindings - SharedParameterBindingContext sharedContext; - sharedContext.compileRequest = compileReq; - sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); - sharedContext.programLayout = programLayout; - sharedContext.targetRequest = targetReq; + SharedParameterBindingContext sharedContext( + layoutContext.getRulesFamily(), + programLayout, + targetReq, + sink); // Create a sub-context to collect parameters that get // declared into the global scope ParameterBindingContext context; context.shared = &sharedContext; - context.translationUnit = nullptr; context.layoutContext = layoutContext; // Walk through AST to discover all the parameters - collectParameters(&context, compileReq); + collectParameters(&context, program); // Now walk through the parameters to generate initial binding information for( auto& parameter : sharedContext.parameters ) @@ -2978,17 +2929,35 @@ void generateParameterBindings( const int numShaderRecordRegs = _calcTotalNumUsedRegistersForLayoutResourceKind(&context, LayoutResourceKind::ShaderRecord); if (numShaderRecordRegs > 1) { - compileReq->mSink.diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); - return; + sink->diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs); } } + return programLayout; +} + +ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink) +{ + if( !m_layout ) + { + m_layout = generateParameterBindings(this, sink); + } + return m_layout; +} + +void generateParameterBindings( + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink) +{ + program->getTargetProgram(targetReq)->getOrCreateLayout(sink); } RefPtr<ProgramLayout> specializeProgramLayout( TargetRequest* targetReq, ProgramLayout* oldProgramLayout, - SubstitutionSet typeSubst) + SubstitutionSet typeSubst, + DiagnosticSink* sink) { // The goal of the layout specialization step is to take an existing `ProgramLayout`, // and add a layout to any parameter(s) that could not be laid out previously, because @@ -3006,7 +2975,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( RefPtr<ProgramLayout> newProgramLayout; newProgramLayout = new ProgramLayout(); - newProgramLayout->targetRequest = targetReq; + newProgramLayout->targetProgram = oldProgramLayout->targetProgram; newProgramLayout->globalGenericParams = oldProgramLayout->globalGenericParams; // The basic idea will be to iterate over the parameters in the old layout, @@ -3020,18 +2989,17 @@ RefPtr<ProgramLayout> specializeProgramLayout( // We will use the same kind of context type as the original parameter binding // step did, so we initialize its state here: - auto layoutContext = getInitialLayoutContextForTarget(targetReq); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, newProgramLayout); SLANG_ASSERT(layoutContext.rules); - SharedParameterBindingContext sharedContext; - sharedContext.compileRequest = targetReq->compileRequest; - sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); - sharedContext.programLayout = newProgramLayout; - sharedContext.targetRequest = targetReq; + SharedParameterBindingContext sharedContext( + layoutContext.getRulesFamily(), + newProgramLayout, + targetReq, + sink); ParameterBindingContext context; context.shared = &sharedContext; - context.translationUnit = nullptr; context.layoutContext = layoutContext; // We will also need state for laying out any global-scope parameters @@ -3119,12 +3087,9 @@ RefPtr<ProgramLayout> specializeProgramLayout( // parameter, the layout of its parameter list strictly follows // the declaration order. // - for (auto & translationUnit : targetReq->compileRequest->translationUnits) + for( auto entryPoint : oldProgramLayout->getProgram()->getEntryPoints() ) { - for (auto & entryPoint : translationUnit->entryPoints) - { - collectEntryPointParameters(&context, entryPoint, typeSubst); - } + collectEntryPointParameters(&context, entryPoint, typeSubst); context.entryPointLayout = nullptr; } diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h index eb093821f..82b114021 100644 --- a/source/slang/parameter-binding.h +++ b/source/slang/parameter-binding.h @@ -8,6 +8,7 @@ namespace Slang { +class Program; class TargetRequest; // The parameter-binding interface is responsible for assigning @@ -24,7 +25,9 @@ class TargetRequest; // of the program. void generateParameterBindings( - TargetRequest* targetReq); + Program* program, + TargetRequest* targetReq, + DiagnosticSink* sink); } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index e2085eb7d..3abc47ede 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -8,7 +8,7 @@ namespace Slang { - // Pre-declare + // pre-declare static Name* getName(Parser* parser, String const& text); // Helper class useful to build a list of modifiers. @@ -79,7 +79,11 @@ namespace Slang class Parser { public: - TranslationUnitRequest* translationUnit; + NamePool* namePool; + SourceLanguage sourceLanguage; + + NamePool* getNamePool() { return namePool; } + SourceLanguage getSourceLanguage() { return sourceLanguage; } int anonymousCounter = 0; @@ -124,27 +128,26 @@ namespace Slang currentScope = currentScope->parent; } Parser( + Session* session, TokenSpan const& _tokens, DiagnosticSink * sink, RefPtr<Scope> const& outerScope) : tokenReader(_tokens) , sink(sink) , outerScope(outerScope) + , m_session(session) {} Parser(const Parser & other) = default; - Session* getSession() - { - return translationUnit->compileRequest->mSession; - } - RefPtr<ModuleDecl> Parse(); + Session* m_session = nullptr; + Session* getSession() { return m_session; } + Token ReadToken(); Token ReadToken(TokenType type); Token ReadToken(const char * string); bool LookAheadToken(TokenType type, int offset = 0); bool LookAheadToken(const char * string, int offset = 0); void parseSourceFile(ModuleDecl* program); - RefPtr<ModuleDecl> ParseProgram(); RefPtr<Decl> ParseStruct(); RefPtr<ClassDecl> ParseClass(); RefPtr<Stmt> ParseStatement(); @@ -578,11 +581,6 @@ namespace Slang return false; } - RefPtr<ModuleDecl> Parser::Parse() - { - return ParseProgram(); - } - RefPtr<RefObject> ParseTypeDef(Parser* parser, void* /*userData*/) { RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); @@ -694,7 +692,7 @@ namespace Slang Token token(TokenType::Identifier, scopedIdentifier, scopedIdSourceLoc); // Get the name pool - auto namePool = parser->translationUnit->compileRequest->getNamePool(); + auto namePool = parser->getNamePool(); // Since it's an Identifier have to set the name. token.ptrValue = namePool->getName(token.Content); @@ -910,7 +908,7 @@ namespace Slang static Name* getName(Parser* parser, String const& text) { - return parser->translationUnit->compileRequest->getNamePool()->getName(text); + return parser->getNamePool()->getName(text); } static NameLoc expectIdentifier(Parser* parser) @@ -1859,7 +1857,7 @@ namespace Slang } // GLSL allows `[]` directly in a type specifier - if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL) + if (parser->getSourceLanguage() == SourceLanguage::GLSL) { typeExpr = parsePostfixTypeSuffix(parser, typeExpr); } @@ -1929,7 +1927,7 @@ namespace Slang // Just as a safety net, only apply this logic for // a file that is being passed in as "true" Slang code. // - if(parser->translationUnit->sourceLanguage == SourceLanguage::Slang) + if(parser->getSourceLanguage() == SourceLanguage::Slang) { if(typeSpec.decl) { @@ -2313,171 +2311,6 @@ namespace Slang return ParseHLSLBufferDecl(parser, "TextureBuffer"); } - static void removeModifier( - Modifiers& modifiers, - RefPtr<Modifier> modifier) - { - RefPtr<Modifier>* link = &modifiers.first; - while (*link) - { - if (*link == modifier) - { - *link = (*link)->next; - return; - } - - link = &(*link)->next; - } - } - - static RefPtr<Decl> parseGLSLBlockDecl( - Parser* parser, - Modifiers& modifiers) - { - // An GLSL block like this: - // - // uniform Foo { int a; float b; } foo; - // - // is treated as syntax sugar for a type declaration - // and then a global variable declaration using that type: - // - // struct $anonymous { int a; float b; }; - // Block<$anonymous> foo; - // - // where `$anonymous` is a fresh name. - // - // If a "local name" like `foo` is not given, then - // we make the declaration "transparent" so that lookup - // will see through it to the members inside. - - - SourceLoc pos = parser->tokenReader.PeekLoc(); - - // The initial name before the `{` is only supposed - // to be made visible to reflection - auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); - - // Look at the qualifiers present on the block to decide what kind - // of block we are looking at. Also *remove* those qualifiers so - // that they don't interfere with downstream work. - String blockWrapperTypeName; - if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() ) - { - removeModifier(modifiers, uniformMod); - blockWrapperTypeName = "ConstantBuffer"; - } - else if( auto inMod = modifiers.findModifier<InModifier>() ) - { - removeModifier(modifiers, inMod); - blockWrapperTypeName = "__GLSLInputParameterGroup"; - } - else if( auto outMod = modifiers.findModifier<OutModifier>() ) - { - removeModifier(modifiers, outMod); - blockWrapperTypeName = "__GLSLOutputParameterGroup"; - } - else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() ) - { - removeModifier(modifiers, bufferMod); - blockWrapperTypeName = "__GLSLShaderStorageBuffer"; - } - else - { - // Unknown case: just map to a constant buffer and hope for the best - blockWrapperTypeName = "ConstantBuffer"; - } - - // We are going to represent each buffer as a pair of declarations. - // The first is a type declaration that holds all the members, while - // the second is a variable declaration that uses the buffer type. - RefPtr<StructDecl> blockDataTypeDecl = new StructDecl(); - RefPtr<VarDecl> blockVarDecl = new VarDecl(); - - addModifier(blockDataTypeDecl, new ImplicitParameterGroupElementTypeModifier()); - addModifier(blockVarDecl, new ImplicitParameterGroupVariableModifier()); - - // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterGroupReflectionName(); - reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken); - addModifier(blockVarDecl, reflectionNameModifier); - - // Both declarations will have a location that points to the name - parser->FillPosition(blockDataTypeDecl.Ptr()); - parser->FillPosition(blockVarDecl.Ptr()); - - // Generate a unique name for the data type - blockDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.Content)); - - // TODO(tfoley): We end up constructing unchecked syntax here that - // is expected to type check into the right form, but it might be - // cleaner to have a more explicit desugaring pass where we parse - // these constructs directly into the AST and *then* desugar them. - - // Construct a type expression to reference the buffer data type - auto blockDataTypeExpr = new VarExpr(); - blockDataTypeExpr->loc = blockDataTypeDecl->loc; - blockDataTypeExpr->name = blockDataTypeDecl->getName(); - blockDataTypeExpr->scope = parser->currentScope.Ptr(); - - // Construct a type exrpession to reference the type constructor - auto blockWrapperTypeExpr = new VarExpr(); - blockWrapperTypeExpr->loc = pos; - blockWrapperTypeExpr->name = getName(parser, blockWrapperTypeName); - // Always need to look this up in the outer scope, - // so that it won't collide with, e.g., a local variable called `ConstantBuffer` - blockWrapperTypeExpr->scope = parser->outerScope; - - // Construct a type expression that represents the type for the variable, - // which is the wrapper type applied to the data type - auto blockVarTypeExpr = new GenericAppExpr(); - blockVarTypeExpr->loc = blockVarDecl->loc; - blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr; - blockVarTypeExpr->Arguments.Add(blockDataTypeExpr); - - blockVarDecl->type.exp = blockVarTypeExpr; - - // The declarations in the body belong to the data type. - parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr()); - - if( parser->LookAheadToken(TokenType::Identifier) ) - { - // The user gave an explicit name to the block, - // so we need to use that as our variable name - blockVarDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier)); - - // TODO: in this case we make actually have a more complex - // declarator, including `[]` brackets. - } - else - { - // synthesize a dummy name - blockVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.Content)); - - // Otherwise we have a transparent declaration, similar - // to an HLSL `cbuffer` - auto transparentModifier = new TransparentModifier(); - transparentModifier->loc = pos; - addModifier(blockVarDecl, transparentModifier); - } - - // Expect a trailing `;` - parser->ReadToken(TokenType::Semicolon); - - // Because we are constructing two declarations, we have a thorny - // issue that were are only supposed to return one. - // For now we handle this by adding the type declaration to - // the current scope manually, and then returning the variable - // declaration. - // - // Note: this means that any modifiers that have already been parsed - // will get attached to the variable declaration, not the type. - // There might be cases where we need to shuffle things around. - - AddMember(parser->currentScope, blockDataTypeDecl); - - return blockVarDecl; - } - static void parseOptionalInheritanceClause(Parser* parser, AggTypeDeclBase* decl) { if (AdvanceIf(parser, TokenType::Colon)) @@ -3020,27 +2853,8 @@ namespace Slang // // - A keyword-based declaration (e.g., `cbuffer ...`) // - The beginning of a type in a declarator-based declaration (e.g., `int ...`) - // - A GLSL block declaration (e.g., `uniform Foo { ... }`) - - // Let's deal with the GLSL block case first. This is something like: - // - // uniform Foo { ... }; - // - // The `uniform` keyword has already been parsed as a modifier, - // so the identifier we are looking at is `Foo`. If the token - // after that is `{`, we assume this is a block. - // - // Of course, we only want to allow this syntax when parsing GLSL... - if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL) - { - if( parser->LookAheadToken(TokenType::LBrace, 1) ) - { - decl = parseGLSLBlockDecl(parser, modifiers); - break; - } - } - // Next we will check whether we can use the identifier token + // First we will check whether we can use the identifier token // as a declaration keyword and parse a declaration using // its associated callback: RefPtr<Decl> parsedDecl; @@ -3184,15 +2998,6 @@ namespace Slang currentScope = nullptr; } - RefPtr<ModuleDecl> Parser::ParseProgram() - { - RefPtr<ModuleDecl> program = new ModuleDecl(); - - parseSourceFile(program.Ptr()); - - return program; - } - RefPtr<Decl> Parser::ParseStruct() { RefPtr<StructDecl> rs = new StructDecl(); @@ -3591,7 +3396,7 @@ namespace Slang // parsing HLSL code. // - bool brokenScoping = translationUnit->sourceLanguage == SourceLanguage::HLSL; + bool brokenScoping = getSourceLanguage() == SourceLanguage::HLSL; // We will create a distinct syntax node class for the unscoped // case, just so that we can correctly handle it in downstream @@ -4439,14 +4244,18 @@ namespace Slang return parsePrefixExpr(this); } - RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + RefPtr<Expr> parseTypeFromSourceFile( + Session* session, TokenSpan const& tokens, DiagnosticSink* sink, - RefPtr<Scope> const& outerScope) + RefPtr<Scope> const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage) { - Parser parser(tokens, sink, outerScope); - parser.translationUnit = translationUnit; + Parser parser(session, tokens, sink, outerScope); parser.currentScope = outerScope; + parser.namePool = namePool; + parser.sourceLanguage = sourceLanguage; return parser.ParseType(); } @@ -4457,12 +4266,11 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope) { - Parser parser(tokens, sink, outerScope); - - parser.translationUnit = translationUnit; - + Parser parser(translationUnit->getSession(), tokens, sink, outerScope); + parser.namePool = translationUnit->getNamePool(); + parser.sourceLanguage = translationUnit->sourceLanguage; - return parser.parseSourceFile(translationUnit->SyntaxNode.Ptr()); + return parser.parseSourceFile(translationUnit->getModuleDecl()); } static void addBuiltinSyntaxImpl( diff --git a/source/slang/parser.h b/source/slang/parser.h index 785b6e345..abad902da 100644 --- a/source/slang/parser.h +++ b/source/slang/parser.h @@ -14,10 +14,13 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope); - RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + RefPtr<Expr> parseTypeFromSourceFile( + Session* session, TokenSpan const& tokens, DiagnosticSink* sink, - RefPtr<Scope> const& outerScope); + RefPtr<Scope> const& outerScope, + NamePool* namePool, + SourceLanguage sourceLanguage); RefPtr<ModuleDecl> populateBaseLanguageModule( Session* session, diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp index c6c438ef6..103db7dcb 100644 --- a/source/slang/preprocessor.cpp +++ b/source/slang/preprocessor.cpp @@ -194,27 +194,18 @@ struct Preprocessor // represent end-of-input situations. Token endOfFileToken; - // The translation unit that is being parsed - TranslationUnitRequest* translationUnit; + /// The linkage the provides the context for preprocessing + Linkage* linkage = nullptr; + + /// The module, if any, that the preprocessed result will belong to + Module* parentModule = nullptr; // The unique identities of any paths that have issued `#pragma once` directives to // stop them from being included again. HashSet<String> pragmaOnceUniqueIdentities; - TranslationUnitRequest* getTranslationUnit() - { - return translationUnit; - } - - ModuleDecl* getSyntax() - { - return getTranslationUnit()->SyntaxNode.Ptr(); - } - - CompileRequest* getCompileRequest() - { - return getTranslationUnit()->compileRequest; - } + NamePool* getNamePool() { return linkage->getNamePool(); } + SourceManager* getSourceManager() { return linkage->getSourceManager(); } }; // Convenience routine to access the diagnostic sink @@ -255,11 +246,6 @@ static void destroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInput delete inputStream; } -static NamePool* getNamePool(Preprocessor* preprocessor) -{ - return preprocessor->translationUnit->compileRequest->getNamePool(); -} - // Create an input stream to represent a pre-tokenized input file. // TODO(tfoley): pre-tokenizing files isn't going to work in the long run. static PreprocessorInputStream* CreateInputStreamForSource( @@ -272,7 +258,7 @@ static PreprocessorInputStream* CreateInputStreamForSource( initializePrimaryInputStream(preprocessor, inputStream); // initialize the embedded lexer so that it can generate a token stream - inputStream->lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), memoryArena); + inputStream->lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), memoryArena); inputStream->token = inputStream->lexer.lexToken(); return inputStream; @@ -836,7 +822,7 @@ top: // Now re-lex the input - SourceManager* sourceManager = preprocessor->getCompileRequest()->getSourceManager(); + SourceManager* sourceManager = preprocessor->getSourceManager(); // We create a dummy file to represent the token-paste operation PathInfo pathInfo = PathInfo::makeTokenPaste(); @@ -845,7 +831,7 @@ top: SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr); Lexer lexer; - lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena()); + lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); SimpleTokenInputStream* inputStream = new SimpleTokenInputStream(); initializeInputStream(preprocessor, inputStream); @@ -1564,7 +1550,7 @@ static void HandleEndIfDirective(PreprocessorDirectiveContext* context) // we expect it. // // Most directives do not need to call this directly, since we have -// a catch-all case in the main `HandleDirective()` funciton. +// a catch-all case in the main `HandleDirective()` function. // The `#include` case will call it directly to avoid complications // when it switches the input stream. static void expectEndOfDirective(PreprocessorDirectiveContext* context) @@ -1589,6 +1575,31 @@ static void expectEndOfDirective(PreprocessorDirectiveContext* context) AdvanceRawToken(context->preprocessor); } + /// Read a file in the context of handling a preprocessor directive +static SlangResult readFile( + PreprocessorDirectiveContext* context, + String const& path, + ISlangBlob** outBlob) +{ + // The actual file loading will be handled by the file system + // associated with the parent linkage. + // + auto linkage = context->preprocessor->linkage; + auto fileSystemExt = linkage->getFileSystemExt(); + SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.Buffer(), outBlob)); + + // If we are running the preprocessor as part of compiling a + // specific module, then we must keep track of the file we've + // read as yet another file that the module will depend on. + // + if(auto module = context->preprocessor->parentModule) + { + module->addFilePathDependency(path); + } + + return SLANG_OK; +} + // Handle a `#include` directive static void HandleIncludeDirective(PreprocessorDirectiveContext* context) { @@ -1603,7 +1614,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) auto directiveLoc = GetDirectiveLoc(context); - PathInfo includedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + PathInfo includedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); IncludeHandler* includeHandler = context->preprocessor->includeHandler; if (!includeHandler) @@ -1644,7 +1655,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) // Push the new file onto our stack of input streams // TODO(tfoley): check if we have made our include stack too deep - auto sourceManager = context->preprocessor->getCompileRequest()->getSourceManager(); + auto sourceManager = context->preprocessor->getSourceManager(); // See if this an already loaded source file SourceFile* sourceFile = sourceManager->findSourceFileRecursively(filePathInfo.uniqueIdentity); @@ -1652,7 +1663,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) if (!sourceFile) { ComPtr<ISlangBlob> foundSourceBlob; - if (SLANG_FAILED(includeHandler->readFile(filePathInfo.foundPath, foundSourceBlob.writeRef()))) + if (SLANG_FAILED(readFile(context, filePathInfo.foundPath, foundSourceBlob.writeRef()))) { GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path); return; @@ -1843,7 +1854,7 @@ static void HandleLineDirective(PreprocessorDirectiveContext* context) return; } - auto sourceManager = context->preprocessor->translationUnit->compileRequest->getSourceManager(); + auto sourceManager = context->preprocessor->getSourceManager(); String file; if (PeekTokenType(context) == TokenType::EndOfDirective) @@ -1891,7 +1902,7 @@ SLANG_PRAGMA_DIRECTIVE_CALLBACK(handlePragmaOnceDirective) // We are using the 'uniqueIdentity' as determined by the ISlangFileSystemEx interface to determine file identities. auto directiveLoc = GetDirectiveLoc(context); - auto issuedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); + auto issuedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual); // Must have uniqueIdentity for a #pragma once to work if (!issuedFromPathInfo.hasUniqueIdentity()) @@ -1962,82 +1973,6 @@ static void HandlePragmaDirective(PreprocessorDirectiveContext* context) (subDirective->callback)(context, subDirectiveToken); } -// Handle a `#version` directive -static void handleGLSLVersionDirective(PreprocessorDirectiveContext* context) -{ - Token versionNumberToken; - if(!ExpectRaw( - context, - TokenType::IntegerLiteral, - Diagnostics::expectedTokenInPreprocessorDirective, - &versionNumberToken)) - { - return; - } - - Token glslProfileToken; - if(PeekTokenType(context) == TokenType::Identifier) - { - glslProfileToken = AdvanceToken(context); - } - - // Need to construct a representation taht we can hook into our compilation result - - auto modifier = new GLSLVersionDirective(); - modifier->versionNumberToken = versionNumberToken; - modifier->glslProfileToken = glslProfileToken; - - // Attach the modifier to the program we are parsing! - - addModifier( - context->preprocessor->getSyntax(), - modifier); -} - -// Handle a `#extension` directive, e.g., -// -// #extension some_extension_name : enable -// -static void handleGLSLExtensionDirective(PreprocessorDirectiveContext* context) -{ - Token extensionNameToken; - if(!ExpectRaw( - context, - TokenType::Identifier, - Diagnostics::expectedTokenInPreprocessorDirective, - &extensionNameToken)) - { - return; - } - - if( !ExpectRaw(context, TokenType::Colon, Diagnostics::expectedTokenInPreprocessorDirective) ) - { - return; - } - - Token dispositionToken; - if(!ExpectRaw( - context, - TokenType::Identifier, - Diagnostics::expectedTokenInPreprocessorDirective, - &dispositionToken)) - { - return; - } - - // Need to construct a representation taht we can hook into our compilation result - - auto modifier = new GLSLExtensionDirective(); - modifier->extensionNameToken = extensionNameToken; - modifier->dispositionToken = dispositionToken; - - // Attach the modifier to the program we are parsing! - - addModifier( - context->preprocessor->getSyntax(), - modifier); -} - // Handle an invalid directive static void HandleInvalidDirective(PreprocessorDirectiveContext* context) { @@ -2092,11 +2027,6 @@ static const PreprocessorDirective kDirectives[] = { "line", &HandleLineDirective, 0 }, { "pragma", &HandlePragmaDirective, 0 }, - // TODO(tfoley): These are specific to GLSL, and probably - // shouldn't be enabled for HLSL or Slang - { "version", &handleGLSLVersionDirective, 0 }, - { "extension", &handleGLSLExtensionDirective, 0 }, - { nullptr, nullptr, 0 }, }; @@ -2270,7 +2200,7 @@ static void DefineMacro( PreprocessorMacro* macro = CreateMacro(preprocessor); - auto sourceManager = preprocessor->translationUnit->compileRequest->getSourceManager(); + auto sourceManager = preprocessor->getSourceManager(); SourceFile* keyFile = sourceManager->createSourceFileWithString(pathInfo, key); SourceFile* valueFile = sourceManager->createSourceFileWithString(pathInfo, value); @@ -2280,10 +2210,10 @@ static void DefineMacro( // Use existing `Lexer` to generate a token stream. Lexer lexer; - lexer.initialize(valueView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena()); + lexer.initialize(valueView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena()); macro->tokens = lexer.lexAllTokens(); - Name* keyName = preprocessor->translationUnit->compileRequest->getNamePool()->getName(key); + Name* keyName = preprocessor->getNamePool()->getName(key); macro->nameAndLoc.name = keyName; macro->nameAndLoc.loc = keyView->getRange().begin; @@ -2321,11 +2251,13 @@ TokenList preprocessSource( DiagnosticSink* sink, IncludeHandler* includeHandler, Dictionary<String, String> defines, - TranslationUnitRequest* translationUnit) + Linkage* linkage, + Module* parentModule) { Preprocessor preprocessor; InitializePreprocessor(&preprocessor, sink); - preprocessor.translationUnit = translationUnit; + preprocessor.linkage = linkage; + preprocessor.parentModule = parentModule; preprocessor.includeHandler = includeHandler; for (auto p : defines) @@ -2333,7 +2265,7 @@ TokenList preprocessSource( DefineMacro(&preprocessor, p.Key, p.Value); } - SourceManager* sourceManager = translationUnit->compileRequest->getSourceManager(); + SourceManager* sourceManager = linkage->getSourceManager(); SourceView* sourceView = sourceManager->createSourceView(file, nullptr); diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h index 4d02cb50b..6e8ac1c69 100644 --- a/source/slang/preprocessor.h +++ b/source/slang/preprocessor.h @@ -8,8 +8,9 @@ namespace Slang { class DiagnosticSink; +class Linkage; +class Module; class ModuleDecl; -class TranslationUnitRequest; // Callback interface for the preprocessor to use when looking // for files in `#include` directives. @@ -20,9 +21,6 @@ struct IncludeHandler const String& pathIncludedFrom, PathInfo& pathInfoOut) = 0; - virtual SlangResult readFile(const String& path, - ISlangBlob** blobOut) = 0; - virtual String simplifyPath(const String& path) = 0; }; @@ -32,7 +30,8 @@ TokenList preprocessSource( DiagnosticSink* sink, IncludeHandler* includeHandler, Dictionary<String, String> defines, - TranslationUnitRequest* translationUnit); + Linkage* linkage, + Module* parentModule); } // namespace Slang diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 2b0be98c9..b40900faf 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -585,10 +585,15 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType) SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) { - auto context = convert(reflection); - auto compileRequest = context->targetRequest->compileRequest; + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + // TODO: We should extend this API to support getting error messages + // when type lookup fails. + // + Slang::DiagnosticSink sink; - RefPtr<Type> result = compileRequest->getTypeFromString(name); + RefPtr<Type> result = program->getTypeFromString(name, &sink); return (SlangReflectionType*)result.Ptr(); } @@ -599,12 +604,13 @@ SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( { auto context = convert(reflection); auto type = convert(inType); - auto layoutContext = getInitialLayoutContextForTarget(context->targetRequest); + auto targetReq = context->getTargetReq(); + auto layoutContext = getInitialLayoutContextForTarget(targetReq, context); RefPtr<TypeLayout> result; - if (context->targetRequest->typeLayouts.TryGetValue(type, result)) + if (targetReq->getTypeLayouts().TryGetValue(type, result)) return (SlangReflectionTypeLayout*)result.Ptr(); result = CreateTypeLayout(layoutContext, type); - context->targetRequest->typeLayouts[type] = result; + targetReq->getTypeLayouts()[type] = result; return (SlangReflectionTypeLayout*)result.Ptr(); } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 16dfe8618..7a5b58d07 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -60,6 +60,8 @@ Session::Session() // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); + m_builtinLinkage = new Linkage(this); + // Initialize representations of some very basic types: initializeTypes(); @@ -90,11 +92,12 @@ Session::Session() struct IncludeHandlerImpl : IncludeHandler { - CompileRequest* request; + Linkage* linkage; + SearchDirectoryList* searchDirectories; ISlangFileSystemExt* _getFileSystemExt() { - return request->fileSystemExt; + return linkage->getFileSystemExt(); } SlangResult _findFile(SlangPathType fromPathType, const String& fromPath, const String& path, PathInfo& pathInfoOut) @@ -153,18 +156,22 @@ struct IncludeHandlerImpl : IncludeHandler } // Search all the searchDirectories - for (auto & dir : request->searchDirectories) + for(auto sd = searchDirectories; sd; sd = sd->parent) { - SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); - if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + for(auto& dir : sd->searchDirectories) { - return res; + SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); + if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + { + return res; + } } } return SLANG_E_NOT_FOUND; } +#if 0 virtual SlangResult readFile(const String& path, ISlangBlob** blobOut) override { @@ -175,6 +182,7 @@ struct IncludeHandlerImpl : IncludeHandler return SLANG_OK; } +#endif virtual String simplifyPath(const String& path) override { @@ -192,9 +200,9 @@ struct IncludeHandlerImpl : IncludeHandler // -Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target) +Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) { - auto entryPointProfile = entryPoint->profile; + auto entryPointProfile = entryPoint->getProfile(); auto targetProfile = target->targetProfile; // Depending on the target *format* we might have to restrict the @@ -310,20 +318,13 @@ Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target // -CompileRequest::CompileRequest(Session* session) - : mSession(session) +Linkage::Linkage(Session* session) + : m_session(session) + , m_sourceManager(&m_defaultSourceManager) { getNamePool()->setRootNamePool(session->getRootNamePool()); - setSourceManager(&sourceManagerStorage); - - sourceManager->initialize(session->getBuiltinSourceManager(), nullptr); - - // Set all the default writers - for (int i = 0; i < int(WriterChannel::CountOf); ++i) - { - setWriter(WriterChannel(i), nullptr); - } + m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); setFileSystem(nullptr); } @@ -379,10 +380,61 @@ ComPtr<ISlangBlob> createRawBlob(void const* inData, size_t size) } // +// TargetRequest +// + +Session* TargetRequest::getSession() +{ + return linkage->getSession(); +} MatrixLayoutMode TargetRequest::getDefaultMatrixLayoutMode() { - return compileRequest->getDefaultMatrixLayoutMode(); + return linkage->getDefaultMatrixLayoutMode(); +} + +// +// TranslationUnitRequest +// + +TranslationUnitRequest::TranslationUnitRequest( + FrontEndCompileRequest* compileRequest) + : compileRequest(compileRequest) +{ + module = new Module(compileRequest->getLinkage()); +} + + +Session* TranslationUnitRequest::getSession() +{ + return compileRequest->getSession(); +} + +NamePool* TranslationUnitRequest::getNamePool() +{ + return compileRequest->getNamePool(); +} + +SourceManager* TranslationUnitRequest::getSourceManager() +{ + return compileRequest->getSourceManager(); +} + +void TranslationUnitRequest::addSourceFile(SourceFile* sourceFile) +{ + m_sourceFiles.Add(sourceFile); + + // We want to record that the compiled module has a dependency + // on the path of the source file, but we also need to account + // for cases where the user added a source string/blob without + // an associated path (so that the API passes along an empty + // string). + // + auto path = sourceFile->getPathInfo().foundPath; + if(path.Length()) + { + getModule()->addFilePathDependency(path); + } } @@ -407,7 +459,7 @@ static ISlangWriter* _getDefaultWriter(WriterChannel chan) } } -void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) +void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) { // If the user passed in null, we will use the default writer on that channel m_writers[int(chan)] = writer ? writer : _getDefaultWriter(chan); @@ -415,20 +467,20 @@ void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) // For diagnostic output, if the user passes in nullptr, we set on mSink.writer as that enables buffering on DiagnosticSink if (chan == WriterChannel::Diagnostic) { - mSink.writer = writer; + m_sink.writer = writer; } } -SlangResult CompileRequest::loadFile(String const& path, ISlangBlob** outBlob) +SlangResult Linkage::loadFile(String const& path, ISlangBlob** outBlob) { return fileSystemExt->loadFile(path.Buffer(), outBlob); } -RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope) +RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope) { // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc will be cleaned up SourceManager localSourceManager; - localSourceManager.initialize(sourceManager, nullptr); + localSourceManager.initialize(getSourceManager(), nullptr); Slang::SourceFile* srcFile = localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr); @@ -440,20 +492,20 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio // Use RAII - to make sure everything is reset even if an exception is thrown. struct ScopeReplaceSourceManager { - ScopeReplaceSourceManager(CompileRequest* request, SourceManager* replaceManager): - m_request(request), - m_originalSourceManager(request->getSourceManager()) + ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager): + m_linkage(linkage), + m_originalSourceManager(linkage->getSourceManager()) { - request->setSourceManager(replaceManager); + linkage->setSourceManager(replaceManager); } ~ScopeReplaceSourceManager() { - m_request->setSourceManager(m_originalSourceManager); + m_linkage->setSourceManager(m_originalSourceManager); } private: - CompileRequest* m_request; + Linkage* m_linkage; SourceManager* m_originalSourceManager; }; @@ -465,87 +517,131 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio &sink, nullptr, Dictionary<String,String>(), - translationUnit); + this, + nullptr); - return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope); + return parseTypeFromSourceFile( + getSession(), + tokens, &sink, scope, getNamePool(), SourceLanguage::Slang); } -RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp); -Type* CompileRequest::getTypeFromString(String typeStr) +RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink); + +Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) { + // If we've looked up this type name before, + // then we can re-use it. + // RefPtr<Type> type; - if (types.TryGetValue(typeStr, type)) + if(m_types.TryGetValue(typeStr, type)) return type; - auto translationUnit = translationUnits.First(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + // TODO: This `scopesToTry` idiom appears + // all over the code, and isn't really + // how we should be handling this kind of + // lookup at all. + // List<RefPtr<Scope>> scopesToTry; - for (auto tu : translationUnits) - scopesToTry.Add(tu->SyntaxNode->scope); - for (auto & module : loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); - // parse type name - for (auto & s : scopesToTry) - { - RefPtr<Expr> typeExpr = parseTypeString(translationUnit, + for(auto module : getModuleDependencies()) + scopesToTry.Add(module->getModuleDecl()->scope); + + auto linkage = getLinkage(); + for(auto& s : scopesToTry) + { + RefPtr<Expr> typeExpr = linkage->parseTypeString( typeStr, s); - type = checkProperType(translationUnit, TypeExp(typeExpr)); - if (type) + type = checkProperType(linkage, TypeExp(typeExpr), sink); + if(type) break; } - if (type) + if( type ) { - types[typeStr] = type; + m_types[typeStr] = type; } - return type.Ptr(); + return type; } -void CompileRequest::parseTranslationUnit( +CompileRequestBase::CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) +{} + + +FrontEndCompileRequest::FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink) + : CompileRequestBase(linkage, sink) +{ +} + +void FrontEndCompileRequest::parseTranslationUnit( TranslationUnitRequest* translationUnit) { IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = getLinkage(); + includeHandler.searchDirectories = &searchDirectories; RefPtr<Scope> languageScope; switch (translationUnit->sourceLanguage) { case SourceLanguage::HLSL: - languageScope = mSession->hlslLanguageScope; + languageScope = getSession()->hlslLanguageScope; break; case SourceLanguage::Slang: default: - languageScope = mSession->slangLanguageScope; + languageScope = getSession()->slangLanguageScope; break; } Dictionary<String, String> combinedPreprocessorDefinitions; + for(auto& def : getLinkage()->preprocessorDefinitions) + combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : translationUnit->preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); + auto module = translationUnit->getModule(); RefPtr<ModuleDecl> translationUnitSyntax = new ModuleDecl(); - translationUnit->SyntaxNode = translationUnitSyntax; + translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; + translationUnitSyntax->module = module; + module->setModuleDecl(translationUnitSyntax); - for (auto sourceFile : translationUnit->sourceFiles) + for (auto sourceFile : translationUnit->getSourceFiles()) { auto tokens = preprocessSource( sourceFile, - &mSink, + getSink(), &includeHandler, combinedPreprocessorDefinitions, - translationUnit); + getLinkage(), + module); parseSourceFile( translationUnit, tokens, - &mSink, + getSink(), languageScope); } } -void validateEntryPoints(CompileRequest*); +RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest); -void CompileRequest::checkAllTranslationUnits() +RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq); + +void FrontEndCompileRequest::checkAllTranslationUnits() { // Iterate over all translation units and // apply the semantic checking logic. @@ -553,12 +649,9 @@ void CompileRequest::checkAllTranslationUnits() { checkTranslationUnit(translationUnit.Ptr()); } - - // Next, do follow-up validation on any entry points. - validateEntryPoints(this); } -void CompileRequest::generateIR() +void FrontEndCompileRequest::generateIR() { // Our task in this function is to generate IR code // for all of the declarations in the translation @@ -581,9 +674,9 @@ void CompileRequest::generateIR() if (verifyDebugSerialization) { // Verify debug information - if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, mSession, sourceManager, IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) + if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, getSession(), getSourceManager(), IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) { - mSink.diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); + getSink()->diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); } } @@ -593,7 +686,7 @@ void CompileRequest::generateIR() { // Write IR out to serialData - copying over SourceLoc information directly IRSerialWriter writer; - writer.write(irModule, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); + writer.write(irModule, getSourceManager(), IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); // Destroy irModule such that memory can be used for newly constructed read irReadModule irModule = nullptr; @@ -602,7 +695,7 @@ void CompileRequest::generateIR() { // Read IR back from serialData IRSerialReader reader; - reader.read(serialData, mSession, nullptr, irReadModule); + reader.read(serialData, getSession(), nullptr, irReadModule); } // Set irModule to the read module @@ -610,12 +703,12 @@ void CompileRequest::generateIR() } // Set the module on the translation unit - translationUnit->irModule = irModule; + translationUnit->getModule()->setIRModule(irModule); } } // Try to infer a single common source language for a request -static SourceLanguage inferSourceLanguage(CompileRequest* request) +static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) { SourceLanguage language = SourceLanguage::Unknown; for (auto& translationUnit : request->translationUnits) @@ -639,29 +732,115 @@ static SourceLanguage inferSourceLanguage(CompileRequest* request) return language; } -SlangResult CompileRequest::executeActionsInner() +SlangResult FrontEndCompileRequest::executeActionsInner() { - // Do some cleanup on settings specified by user. - // In particular, we want to propagate flags from the overall request down to - // each translation unit. + // We currently allow GlSL files on the command line so that we can + // drive our "pass-through" mode, but we really want to issue an error + // message if the user is seriously asking us to compile them. for (auto& translationUnit : translationUnits) { - translationUnit->compileFlags |= compileFlags; + switch(translationUnit->sourceLanguage) + { + default: + break; + + case SourceLanguage::GLSL: + getSink()->diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); + return SLANG_FAIL; + } + } + + + // Parse everything from the input files requested + for (auto& translationUnit : translationUnits) + { + parseTranslationUnit(translationUnit.Ptr()); } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Perform semantic checking on the whole collection + checkAllTranslationUnits(); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Look up all the entry points that are expected, + // and use them to populate the `program` member. + // + m_program = createUnspecializedProgram(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + { + // Generate initial IR for all the translation + // units, if we are in a mode where IR is called for. + generateIR(); + } + + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Do parameter binding generation, for each compilation target. + // + for(auto targetReq : getLinkage()->targets) + { + auto targetProgram = m_program->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); + } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + return SLANG_OK; +} + +BackEndCompileRequest::BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program) + : CompileRequestBase(linkage, sink) + , m_program(program) +{} + +EndToEndCompileRequest::EndToEndCompileRequest( + Session* session) + : m_session(session) +{ + m_linkage = new Linkage(session); + + m_sink.sourceManager = m_linkage->getSourceManager(); + + // Set all the default writers + for (int i = 0; i < int(WriterChannel::CountOf); ++i) + { + setWriter(WriterChannel(i), nullptr); + } + + m_frontEndReq = new FrontEndCompileRequest(getLinkage(), getSink()); + + m_backEndReq = new BackEndCompileRequest(getLinkage(), getSink()); +} + +SlangResult EndToEndCompileRequest::executeActionsInner() +{ // If no code-generation target was specified, then try to infer one from the source language, // just to make sure we can do something reasonable when invoked from the command line. - if (targets.Count() == 0) + // + // TODO: This logic should be moved into `options.cpp` or somewhere else + // specific to the command-line tool. + // + if (getLinkage()->targets.Count() == 0) { - auto language = inferSourceLanguage(this); + auto language = inferSourceLanguage(getFrontEndReq()); switch (language) { case SourceLanguage::HLSL: - addTarget(CodeGenTarget::DXBytecode); + getLinkage()->addTarget(CodeGenTarget::DXBytecode); break; case SourceLanguage::GLSL: - addTarget(CodeGenTarget::SPIRV); + getLinkage()->addTarget(CodeGenTarget::SPIRV); break; default: @@ -672,105 +851,117 @@ SlangResult CompileRequest::executeActionsInner() // We only do parsing and semantic checking if we *aren't* doing // a pass-through compilation. // - // Note that we *do* perform output generation as normal in pass-through mode. if (passThrough == PassThroughMode::None) { - // We currently allow GlSL files on the command line so that we can - // drive our "pass-through" mode, but we really want to issue an error - // message if the user is seriously asking us to compile them. - for (auto& translationUnit : translationUnits) - { - switch(translationUnit->sourceLanguage) - { - default: - break; - - case SourceLanguage::GLSL: - mSink.diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); - return SLANG_FAIL; - } - } + SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner()); + } + // If command line specifies to skip codegen, we exit here. + // Note: this is a debugging option. + // + if (shouldSkipCodegen || + ((getFrontEndReq()->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) + { + // We will use the program (and matching layout information) + // that was computed in the front-end for all subsequent + // reflection queries, etc. + // + m_specializedProgram = getUnspecializedProgram(); - // Parse everything from the input files requested - for (auto& translationUnit : translationUnits) - { - parseTranslationUnit(translationUnit.Ptr()); - } - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + return SLANG_OK; + } - // Perform semantic checking on the whole collection - checkAllTranslationUnits(); - if (mSink.GetErrorCount() != 0) + // If codegen is enabled, we need to move along to + // apply any generic specialization that the user asked for. + // + if (passThrough == PassThroughMode::None) + { + m_specializedProgram = createSpecializedProgram(this); + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + // For each code generation target, we will generate specialized + // parameter binding information (taking global generic + // arguments into account at this time). + // + for (auto targetReq : getLinkage()->targets) { - // Generate initial IR for all the translation - // units, if we are in a mode where IR is called for. - generateIR(); + auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); } - - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - - // For each code generation target generate - // parameter binding information. - // This step is done globally, because all translation - // units and entry points need to agree on where - // parameters are allocated. - for (auto targetReq : targets) + } + else + { + // We need to create dummy `EntryPoint` objects + // to make sure that the logic in `generateOutput` + // sees something worth processing. + // + auto specializedProgram = new Program(getLinkage()); + m_specializedProgram = specializedProgram; + for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - generateParameterBindings(targetReq); - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + entryPointReq->getName(), + entryPointReq->getProfile()); + + specializedProgram->addEntryPoint(entryPoint); } } - // If command line specifies to skip codegen, we exit here. - // Note: this is a debugging option. - if (shouldSkipCodegen || - ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) - return SLANG_OK; - // Generate output code, in whatever format was requested + getBackEndReq()->setProgram(getSpecializedProgram()); generateOutput(this); - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; return SLANG_OK; } // Act as expected of the API-based compiler -SlangResult CompileRequest::executeActions() +SlangResult EndToEndCompileRequest::executeActions() { SlangResult res = executeActionsInner(); - mDiagnosticOutput = mSink.outputBuffer.ProduceString(); + mDiagnosticOutput = getSink()->outputBuffer.ProduceString(); return res; } -int CompileRequest::addTranslationUnit(SourceLanguage language, String const&) +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) { UInt result = translationUnits.Count(); - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this); translationUnit->compileRequest = this; translationUnit->sourceLanguage = SourceLanguage(language); + translationUnit->moduleName = moduleName; + translationUnits.Add(translationUnit); return (int) result; } -void CompileRequest::addTranslationUnitSourceFile( +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language) +{ + // We want to ensure that symbols defined in different translation + // units get unique mangled names, so that we can, e.g., tell apart + // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`, + // even when they are being compiled together. + // + String generatedName = "tu"; + generatedName.append(translationUnits.Count()); + return addTranslationUnit(language, getNamePool()->getName(generatedName)); +} + +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile) { - translationUnits[translationUnitIndex]->sourceFiles.Add(sourceFile); + translationUnits[translationUnitIndex]->addSourceFile(sourceFile); } -void CompileRequest::addTranslationUnitSourceBlob( +void FrontEndCompileRequest::addTranslationUnitSourceBlob( int translationUnitIndex, String const& path, ISlangBlob* sourceBlob) @@ -781,7 +972,7 @@ void CompileRequest::addTranslationUnitSourceBlob( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceString( +void FrontEndCompileRequest::addTranslationUnitSourceString( int translationUnitIndex, String const& path, String const& source) @@ -792,7 +983,7 @@ void CompileRequest::addTranslationUnitSourceString( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceFile( +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, String const& path) { @@ -809,7 +1000,7 @@ void CompileRequest::addTranslationUnitSourceFile( if(SLANG_FAILED(result)) { // Emit a diagnostic! - mSink.diagnose( + getSink()->diagnose( SourceLoc(), Diagnostics::cannotOpenFile, path); @@ -820,36 +1011,51 @@ void CompileRequest::addTranslationUnitSourceFile( translationUnitIndex, path, sourceBlob); +} + +int FrontEndCompileRequest::addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile) +{ + auto translationUnitReq = translationUnits[translationUnitIndex]; + + UInt result = m_entryPointReqs.Count(); + + RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest( + this, + translationUnitIndex, + getNamePool()->getName(name), + entryPointProfile); + + m_entryPointReqs.Add(entryPointReq); +// translationUnitReq->entryPoints.Add(entryPointReq); - mDependencyFilePaths.Add(path); + return int(result); } -int CompileRequest::addEntryPoint( +int EndToEndCompileRequest::addEntryPoint( int translationUnitIndex, String const& name, Profile entryPointProfile, List<String> const & genericTypeNames) { - RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest(); - entryPoint->compileRequest = this; - entryPoint->name = getNamePool()->getName(name); - entryPoint->profile = entryPointProfile; - entryPoint->translationUnitIndex = translationUnitIndex; + getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile); + + EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPoint->genericArgStrings.Add(typeName); - auto translationUnit = translationUnits[translationUnitIndex].Ptr(); - translationUnit->entryPoints.Add(entryPoint); + entryPointInfo.genericArgStrings.Add(typeName); UInt result = entryPoints.Count(); - entryPoints.Add(entryPoint); + entryPoints.Add(_Move(entryPointInfo)); return (int) result; } -UInt CompileRequest::addTarget( +UInt Linkage::addTarget( CodeGenTarget target) { RefPtr<TargetRequest> targetReq = new TargetRequest(); - targetReq->compileRequest = this; + targetReq->linkage = this; targetReq->target = target; UInt result = targets.Count(); @@ -857,15 +1063,16 @@ UInt CompileRequest::addTarget( return (int) result; } -void CompileRequest::loadParsedModule( - RefPtr<TranslationUnitRequest> const& translationUnit, - Name* name, - const PathInfo& pathInfo) +void Linkage::loadParsedModule( + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + const PathInfo& pathInfo) { // Note: we add the loaded module to our name->module listing // before doing semantic checking, so that if it tries to // recursively `import` itself, we can detect it. - RefPtr<LoadedModule> loadedModule = new LoadedModule(); + // + RefPtr<Module> loadedModule = translationUnit->getModule(); // Get a path String mostUniqueIdentity = pathInfo.getMostUniqueIdentity(); @@ -874,12 +1081,11 @@ void CompileRequest::loadParsedModule( mapPathToLoadedModule.Add(mostUniqueIdentity, loadedModule); mapNameToLoadedModules.Add(name, loadedModule); - int errorCountBefore = mSink.GetErrorCount(); - checkTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + auto sink = translationUnit->compileRequest->getSink(); - RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; - loadedModule->moduleDecl = moduleDecl; + int errorCountBefore = sink->GetErrorCount(); + checkTranslationUnit(translationUnit.Ptr()); + int errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { @@ -890,39 +1096,56 @@ void CompileRequest::loadParsedModule( // If we didn't run into any errors, then try to generate // IR code for the imported module. SLANG_ASSERT(errorCountAfter == 0); - loadedModule->irModule = generateIRForTranslationUnit(translationUnit); + loadedModule->setIRModule(generateIRForTranslationUnit(translationUnit)); } loadedModulesList.Add(loadedModule); } -RefPtr<ModuleDecl> CompileRequest::loadModule( +Module* Linkage::loadModule(String const& name) +{ + // TODO: We either need to have a diagnostics sink + // get passed into this operation, or associate + // one with the linkage. + // + DiagnosticSink* sink = nullptr; + return findOrImportModule( + getNamePool()->getName(name), + SourceLoc(), + sink); +} + + +RefPtr<Module> Linkage::loadModule( Name* name, const PathInfo& filePathInfo, ISlangBlob* sourceBlob, - SourceLoc const& srcLoc) + SourceLoc const& srcLoc, + DiagnosticSink* sink) { - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); - translationUnit->compileRequest = this; + RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, sink); - // We don't want to use the same options that the user specified - // for loading modules on-demand. In particular, we always want - // semantic checking to be enabled. - // - // TODO: decide which options, if any, should be inherited. - translationUnit->compileFlags = 0; + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq); + translationUnit->compileRequest = frontEndReq; + translationUnit->moduleName = name; + + auto module = translationUnit->getModule(); + + ModuleBeingImportedRAII moduleBeingImported( + this, + module); // Create with the 'friendly' name SourceFile* sourceFile = getSourceManager()->createSourceFileWithBlob(filePathInfo, sourceBlob); - translationUnit->sourceFiles.Add(sourceFile); + translationUnit->addSourceFile(sourceFile); - int errorCountBefore = mSink.GetErrorCount(); - parseTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + int errorCountBefore = sink->GetErrorCount(); + frontEndReq->parseTranslationUnit(translationUnit); + int errorCountAfter = sink->GetErrorCount(); if( errorCountAfter != errorCountBefore ) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); } if (errorCountAfter) { @@ -935,38 +1158,57 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( name, filePathInfo); - errorCountAfter = mSink.GetErrorCount(); + errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); // Something went wrong during the parsing, so we should bail out. return nullptr; } - return translationUnit->SyntaxNode; + return module; +} + +bool Linkage::isBeingImported(Module* module) +{ + for(auto ii = m_modulesBeingImported; ii; ii = ii->next) + { + if(module == ii->module) + return true; + } + return false; } -RefPtr<ModuleDecl> CompileRequest::findOrImportModule( +RefPtr<Module> Linkage::findOrImportModule( Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { // Have we already loaded a module matching this name? - // If so, return it. + // RefPtr<LoadedModule> loadedModule; if (mapNameToLoadedModules.TryGetValue(name, loadedModule)) { + // If the map shows a null module having been loaded, + // then that means there was a prior load attempt, + // but it failed, so we won't bother trying again. + // if (!loadedModule) return nullptr; - if (!loadedModule->moduleDecl) + // If state shows us that the module is already being + // imported deeper on the call stack, then we've + // hit a recursive case, and that is an error. + // + if(isBeingImported(loadedModule)) { // We seem to be in the middle of loading this module - mSink.diagnose(loc, Diagnostics::recursiveModuleImport, name); + sink->diagnose(loc, Diagnostics::recursiveModuleImport, name); return nullptr; } - return loadedModule->moduleDecl; + return loadedModule; } // Derive a file name for the module, by taking the given @@ -991,7 +1233,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // using our ordinary include-handling logic. IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = this; + includeHandler.searchDirectories = &searchDirectories; // Get the original path info PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); @@ -1000,20 +1243,20 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // We have to load via the found path - as that is how file was originally loaded if (SLANG_FAILED(includeHandler.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) { - this->mSink.diagnose(loc, Diagnostics::cannotFindFile, fileName); + sink->diagnose(loc, Diagnostics::cannotFindFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } // Maybe this was loaded previously at a different relative name? if (mapPathToLoadedModule.TryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule)) - return loadedModule->moduleDecl; + return loadedModule; // Try to load it ComPtr<ISlangBlob> fileContents; - if (SLANG_FAILED(includeHandler.readFile(filePathInfo.foundPath, fileContents.writeRef()))) + if(SLANG_FAILED(getFileSystemExt()->loadFile(filePathInfo.foundPath.Buffer(), fileContents.writeRef()))) { - this->mSink.diagnose(loc, Diagnostics::cannotOpenFile, fileName); + sink->diagnose(loc, Diagnostics::cannotOpenFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } @@ -1024,26 +1267,159 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( name, filePathInfo, fileContents, - loc); + loc, + sink); } -Decl * CompileRequest::lookupGlobalDecl(Name * name) +// +// ModuleDependencyList +// + +void ModuleDependencyList::addDependency(Module* module) { - Decl* resultDecl = nullptr; - for (auto module : loadedModulesList) + // If we depend on a module, then we depend on everything it depends on. + // + // Note: We are processing these sub-depenencies before adding the + // `module` itself, so that in the common case a module will always + // appear *after* everything it depends on. + // + // However, this rule is being violated in the compiler right now because + // the modules for hte top-level translation units of a compile request + // will be added to the list first (using `addLeafDependency`) to + // maintain compatibility with old behavior. This may be fixed later. + // + for(auto subDependency : module->getModuleDependencyList()) { - if (module->moduleDecl->memberDictionary.TryGetValue(name, resultDecl)) - break; + _addDependency(subDependency); + } + _addDependency(module); +} + +void ModuleDependencyList::addLeafDependency(Module* module) +{ + _addDependency(module); +} + +void ModuleDependencyList::_addDependency(Module* module) +{ + if(m_moduleSet.Contains(module)) + return; + + m_moduleList.Add(module); + m_moduleSet.Add(module); +} + +// +// FilePathDependencyList +// + +void FilePathDependencyList::addDependency(String const& path) +{ + if(m_filePathSet.Contains(path)) + return; + + m_filePathList.Add(path); + m_filePathSet.Add(path); +} + +void FilePathDependencyList::addDependency(Module* module) +{ + for(auto& path : module->getFilePathDependencyList()) + { + addDependency(path); } - for (auto transUnit : translationUnits) +} + + + +// +// Module +// + +Module::Module(Linkage* linkage) + : m_linkage(linkage) +{} + + +void Module::addModuleDependency(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Module::addFilePathDependency(String const& path) +{ + m_filePathDependencyList.addDependency(path); +} + +// Program + +Program::Program(Linkage* linkage) + : m_linkage(linkage) +{} + +void Program::addReferencedModule(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addReferencedLeafModule(Module* module) +{ + m_moduleDependencyList.addLeafDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addEntryPoint(EntryPoint* entryPoint) +{ + m_entryPoints.Add(entryPoint); + + for(auto module : entryPoint->getModuleDependencies()) { - if (transUnit->SyntaxNode->memberDictionary.TryGetValue(name, resultDecl)) - break; + addReferencedModule(module); } - return resultDecl; } -void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) +RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +{ + if(!m_irModule) + { + m_irModule = generateIRForProgram( + m_linkage->getSession(), + this, + sink); + } + return m_irModule; +} + + +TargetProgram* Program::getTargetProgram(TargetRequest* target) +{ + RefPtr<TargetProgram> targetProgram; + if(!m_targetPrograms.TryGetValue(target, targetProgram)) + { + targetProgram = new TargetProgram(this, target); + m_targetPrograms[target] = targetProgram; + } + return targetProgram; +} + +// +// TargetProgram +// + +TargetProgram::TargetProgram( + Program* program, + TargetRequest* targetReq) + : m_program(program) + , m_targetReq(targetReq) +{ + m_entryPointResults.SetSize(program->getEntryPoints().Count()); +} + +// + +void DiagnosticSink::noteInternalErrorLoc(SourceLoc const& loc) { // Don't consider invalid source locations. if(!loc.isValid()) @@ -1054,14 +1430,19 @@ void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) // code might have confused the compiler. if(internalErrorLocsNoted == 0) { - mSink.diagnose(loc, Diagnostics::noteLocationOfInternalError); + diagnose(loc, Diagnostics::noteLocationOfInternalError); } internalErrorLocsNoted++; } +Session* CompileRequestBase::getSession() +{ + return getLinkage()->getSession(); +} + static const Slang::Guid IID_ISlangFileSystemExt = SLANG_UUID_ISlangFileSystemExt; -void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) +void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) { // Set the fileSystem fileSystem = inFileSystem; @@ -1085,15 +1466,16 @@ void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) } // Set the file system used on the source manager - sourceManager->setFileSystemExt(fileSystemExt); + getSourceManager()->setFileSystemExt(fileSystemExt); } -RefPtr<ModuleDecl> findOrImportModule( - CompileRequest* request, +RefPtr<Module> findOrImportModule( + Linkage* linkage, Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { - return request->findOrImportModule(name, loc); + return linkage->findOrImportModule(name, loc, sink); } void Session::addBuiltinSource( @@ -1101,30 +1483,34 @@ void Session::addBuiltinSource( String const& path, String const& source) { - RefPtr<CompileRequest> compileRequest = new CompileRequest(this); - compileRequest->setSourceManager(getBuiltinSourceManager()); + DiagnosticSink sink; + RefPtr<FrontEndCompileRequest> compileRequest = new FrontEndCompileRequest( + m_builtinLinkage, + &sink); - auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, path); + Name* moduleName = getNamePool()->getName(path); + auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); compileRequest->addTranslationUnitSourceString( translationUnitIndex, path, source); - SlangResult res = compileRequest->executeActions(); + SlangResult res = compileRequest->executeActionsInner(); if (SLANG_FAILED(res)) { - fprintf(stderr, "%s", compileRequest->mDiagnosticOutput.Buffer()); + char const* diagnostics = sink.outputBuffer.Buffer(); + fprintf(stderr, "%s", diagnostics); #ifdef _WIN32 - OutputDebugStringA(compileRequest->mDiagnosticOutput.Buffer()); + OutputDebugStringA(diagnostics); #endif SLANG_UNEXPECTED("error in Slang standard library"); } // Extract the AST for the code we just parsed - auto syntax = compileRequest->translationUnits[translationUnitIndex]->SyntaxNode; + auto syntax = compileRequest->translationUnits[translationUnitIndex]->getModuleDecl(); // HACK(tfoley): mark all declarations in the "stdlib" so // that we can detect them later (e.g., so we don't emit them) @@ -1176,19 +1562,37 @@ Session::~Session() // implementation of C interface -#define SESSION(x) reinterpret_cast<Slang::Session *>(x) -#define REQ(x) reinterpret_cast<Slang::CompileRequest*>(x) +static SlangSession* convert(Slang::Session* session) +{ return reinterpret_cast<SlangSession*>(session); } + +static Slang::Session* convert(SlangSession* session) +{ return reinterpret_cast<Slang::Session*>(session); } + +static SlangCompileRequest* convert(Slang::EndToEndCompileRequest* request) +{ return reinterpret_cast<SlangCompileRequest*>(request); } + +static Slang::EndToEndCompileRequest* convert(SlangCompileRequest* request) +{ return reinterpret_cast<Slang::EndToEndCompileRequest*>(request); } + +static SlangLinkage* convert(Slang::Linkage* linkage) +{ return reinterpret_cast<SlangLinkage*>(linkage); } + +static Slang::Linkage* convert(SlangLinkage* linkage) +{ return reinterpret_cast<Slang::Linkage*>(linkage); } + +static SlangModule* convert(Slang::Module* module) +{ return reinterpret_cast<SlangModule*>(module); } SLANG_API SlangSession* spCreateSession(const char*) { - return reinterpret_cast<SlangSession *>(new Slang::Session()); + return convert(new Slang::Session()); } SLANG_API void spDestroySession( SlangSession* session) { if(!session) return; - delete SESSION(session); + delete convert(session); } SLANG_API void spAddBuiltins( @@ -1196,7 +1600,7 @@ SLANG_API void spAddBuiltins( char const* sourcePath, char const* sourceString) { - auto s = SESSION(session); + auto s = convert(session); s->addBuiltinSource( // TODO(tfoley): Add ability to directly new builtins to the approriate scope @@ -1210,7 +1614,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SlangSession* session, ISlangSharedLibraryLoader* loader) { - auto s = SESSION(session); + auto s = convert(session); if (!loader) { @@ -1237,7 +1641,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader( SlangSession* session) { - auto s = SESSION(session); + auto s = convert(session); return (s->sharedLibraryLoader == Slang::DefaultSharedLibraryLoader::getSingleton()) ? nullptr : s->sharedLibraryLoader.get(); } @@ -1245,7 +1649,7 @@ SLANG_API SlangResult spSessionCheckCompileTargetSupport( SlangSession* session, SlangCompileTarget target) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkCompileTargetSupport(s, Slang::CodeGenTarget(target)); } @@ -1253,16 +1657,45 @@ SLANG_API SlangResult spSessionCheckPassThroughSupport( SlangSession* session, SlangPassThrough passThrough) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkExternalCompilerSupport(s, Slang::PassThroughMode(passThrough)); } + +SLANG_API SlangLinkage* spCreateLinkage( + SlangSession* session) +{ + auto s = convert(session); + auto linkage = new Slang::Linkage(s); + return convert(linkage); +} + +SLANG_API void spDestroyLinkage( + SlangLinkage* linkage) +{ + if(!linkage) return; + auto lnk = convert(linkage); + delete lnk; +} + +SLANG_API SlangModule* spLoadModule( + SlangLinkage* linkage, + char const* moduleName) +{ + if(!linkage) return nullptr; + auto lnk = convert(linkage); + + auto mod = lnk->loadModule(moduleName); + return convert(mod); +} + + SLANG_API SlangCompileRequest* spCreateCompileRequest( SlangSession* session) { - auto s = SESSION(session); - auto req = new Slang::CompileRequest(s); - return reinterpret_cast<SlangCompileRequest*>(req); + auto s = convert(session); + auto req = new Slang::EndToEndCompileRequest(s); + return convert(req); } /*! @@ -1272,7 +1705,7 @@ SLANG_API void spDestroyCompileRequest( SlangCompileRequest* request) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); delete req; } @@ -1281,21 +1714,21 @@ SLANG_API void spSetFileSystem( ISlangFileSystem* fileSystem) { if(!request) return; - REQ(request)->setFileSystem(fileSystem); + convert(request)->getLinkage()->setFileSystem(fileSystem); } SLANG_API void spSetCompileFlags( SlangCompileRequest* request, SlangCompileFlags flags) { - REQ(request)->compileFlags = flags; + convert(request)->getFrontEndReq()->compileFlags = flags; } SLANG_API void spSetDumpIntermediates( SlangCompileRequest* request, int enable) { - REQ(request)->shouldDumpIntermediates = enable != 0; + convert(request)->getBackEndReq()->shouldDumpIntermediates = enable != 0; } SLANG_API void spSetLineDirectiveMode( @@ -1304,13 +1737,13 @@ SLANG_API void spSetLineDirectiveMode( { // TODO: validation - REQ(request)->lineDirectiveMode = Slang::LineDirectiveMode(mode); + convert(request)->getBackEndReq()->lineDirectiveMode = Slang::LineDirectiveMode(mode); } SLANG_API void spSetCommandLineCompilerMode( SlangCompileRequest* request) { - REQ(request)->isCommandLineCompile = true; + convert(request)->isCommandLineCompile = true; } @@ -1318,17 +1751,19 @@ SLANG_API void spSetCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - req->targets.Clear(); - req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets.Clear(); + linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API int spAddCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - return (int) req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + return (int) linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API void spSetTargetProfile( @@ -1336,8 +1771,9 @@ SLANG_API void spSetTargetProfile( int targetIndex, SlangProfileID profile) { - auto req = REQ(request); - req->targets[targetIndex]->targetProfile = Slang::Profile(profile); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetProfile = Slang::Profile(profile); } SLANG_API void spSetTargetFlags( @@ -1345,8 +1781,9 @@ SLANG_API void spSetTargetFlags( int targetIndex, SlangTargetFlags flags) { - auto req = REQ(request); - req->targets[targetIndex]->targetFlags = flags; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetFlags = flags; } SLANG_API void spSetTargetFloatingPointMode( @@ -1354,16 +1791,18 @@ SLANG_API void spSetTargetFloatingPointMode( int targetIndex, SlangFloatingPointMode mode) { - auto req = REQ(request); - req->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); } SLANG_API void spSetMatrixLayoutMode( SlangCompileRequest* request, SlangMatrixLayoutMode mode) { - auto req = REQ(request); - req->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); } SLANG_API void spSetTargetMatrixLayoutMode( @@ -1380,7 +1819,7 @@ SLANG_API void spSetOutputContainerFormat( SlangCompileRequest* request, SlangContainerFormat format) { - auto req = REQ(request); + auto req = convert(request); req->containerFormat = Slang::ContainerFormat(format); } @@ -1389,7 +1828,7 @@ SLANG_API void spSetPassThrough( SlangCompileRequest* request, SlangPassThrough passThrough) { - REQ(request)->passThrough = Slang::PassThroughMode(passThrough); + convert(request)->passThrough = Slang::PassThroughMode(passThrough); } SLANG_API void spSetDiagnosticCallback( @@ -1400,7 +1839,7 @@ SLANG_API void spSetDiagnosticCallback( using namespace Slang; if(!request) return; - auto req = REQ(request); + auto req = convert(request); ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole)); req->setWriter(WriterChannel::Diagnostic, writer); @@ -1412,7 +1851,7 @@ SLANG_API void spSetWriter( ISlangWriter* writer) { if (!request) return; - auto req = REQ(request); + auto req = convert(request); req->setWriter(Slang::WriterChannel(chan), writer); } @@ -1422,15 +1861,17 @@ SLANG_API ISlangWriter* spGetWriter( SlangWriterChannel chan) { if (!request) return nullptr; - auto req = REQ(request); + auto req = convert(request); return req->getWriter(Slang::WriterChannel(chan)); } SLANG_API void spAddSearchPath( - SlangCompileRequest* request, - const char* path) + SlangCompileRequest* request, + const char* path) { - REQ(request)->searchDirectories.Add(Slang::SearchDirectory(path)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->searchDirectories.searchDirectories.Add(Slang::SearchDirectory(path)); } SLANG_API void spAddPreprocessorDefine( @@ -1438,25 +1879,27 @@ SLANG_API void spAddPreprocessorDefine( const char* key, const char* value) { - REQ(request)->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->preprocessorDefinitions[key] = value; } SLANG_API char const* spGetDiagnosticOutput( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); + auto req = convert(request); return req->mDiagnosticOutput.begin(); } SLANG_API SlangResult spGetDiagnosticOutputBlob( - SlangCompileRequest* request, - ISlangBlob** outBlob) + SlangCompileRequest* request, + ISlangBlob** outBlob) { if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); if(!req->diagnosticOutputBlob) { @@ -1475,11 +1918,13 @@ SLANG_API int spAddTranslationUnit( SlangSourceLanguage language, char const* name) { - auto req = REQ(request); + SLANG_UNUSED(name); + + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); - return req->addTranslationUnit( - Slang::SourceLanguage(language), - name ? name : ""); + return frontEndReq->addTranslationUnit( + Slang::SourceLanguage(language)); } SLANG_API void spTranslationUnit_addPreprocessorDefine( @@ -1488,10 +1933,10 @@ SLANG_API void spTranslationUnit_addPreprocessorDefine( const char* key, const char* value) { - auto req = REQ(request); - - req->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + frontEndReq->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; } SLANG_API void spAddTranslationUnitSourceFile( @@ -1500,12 +1945,13 @@ SLANG_API void spAddTranslationUnitSourceFile( char const* path) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!path) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; - req->addTranslationUnitSourceFile( + frontEndReq->addTranslationUnitSourceFile( translationUnitIndex, path); } @@ -1533,14 +1979,15 @@ SLANG_API void spAddTranslationUnitSourceStringSpan( char const* sourceEnd) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBegin) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceString( + frontEndReq->addTranslationUnitSourceString( translationUnitIndex, path, Slang::UnownedStringSlice(sourceBegin, sourceEnd)); @@ -1553,14 +2000,15 @@ SLANG_API void spAddTranslationUnitSourceBlob( ISlangBlob* sourceBlob) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBlob) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceBlob( + frontEndReq->addTranslationUnitSourceBlob( translationUnitIndex, path, sourceBlob); @@ -1584,17 +2032,13 @@ SLANG_API int spAddEntryPoint( char const* name, SlangStage stage) { - if(!request) return -1; - auto req = REQ(request); - if(!name) return -1; - if(translationUnitIndex < 0) return -1; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; - - return req->addEntryPoint( + return spAddEntryPointEx( + request, translationUnitIndex, name, - Slang::Profile(Slang::Stage(stage)), - Slang::List<Slang::String>()); + stage, + 0, + nullptr); } SLANG_API int spAddEntryPointEx( @@ -1606,10 +2050,11 @@ SLANG_API int spAddEntryPointEx( char const ** genericParamTypeNames) { if (!request) return -1; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if (!name) return -1; if (translationUnitIndex < 0) return -1; - if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; + if (Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return -1; Slang::List<Slang::String> typeNames; for (int i = 0; i < genericParamTypeNameCount; i++) typeNames.Add(genericParamTypeNames[i]); @@ -1620,12 +2065,28 @@ SLANG_API int spAddEntryPointEx( typeNames); } +SLANG_API SlangResult spSetGlobalGenericArgs( + SlangCompileRequest* request, + int genericArgCount, + char const** genericArgs) +{ + if (!request) return SLANG_FAIL; + auto req = convert(request); + + auto& genericArgStrings = req->globalGenericArgStrings; + genericArgStrings.Clear(); + for (int i = 0; i < genericArgCount; i++) + genericArgStrings.Add(genericArgs[i]); + + return SLANG_OK; +} + // Compile in a context that already has its translation units specified SLANG_API SlangResult spCompile( SlangCompileRequest* request) { - auto req = REQ(request); + auto req = convert(request); #if !defined(SLANG_DEBUG_INTERNAL_ERROR) // By default we'd like to catch as many internal errors as possible, @@ -1654,7 +2115,7 @@ SLANG_API SlangResult spCompile( // We will print out information on the exception to help out the user // in either filing a bug, or locating what in their code created // a problem. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); } catch (...) { @@ -1662,9 +2123,9 @@ SLANG_API SlangResult spCompile( // `Exception`, so something really fishy is going on. We want to // let the user know that we messed up, so they know to blame Slang // and not some other component in their system. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); } - req->mDiagnosticOutput = req->mSink.outputBuffer.ProduceString(); + req->mDiagnosticOutput = req->getSink()->outputBuffer.ProduceString(); return res; #else // When debugging, we probably don't want to filter out any errors, since @@ -1680,8 +2141,10 @@ spGetDependencyFileCount( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); - return (int) req->mDependencyFilePaths.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return (int) program->getFilePathDependencies().Count(); } /** Get the path to a file this compilation dependend on. @@ -1692,16 +2155,19 @@ spGetDependencyFilePath( int index) { if(!request) return 0; - auto req = REQ(request); - return req->mDependencyFilePaths[index].begin(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return program->getFilePathDependencies()[index].begin(); } SLANG_API int spGetTranslationUnitCount( SlangCompileRequest* request) { - auto req = REQ(request); - return (int) req->translationUnits.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + return (int) frontEndReq->translationUnits.Count(); } // Get the output code associated with a specific translation unit @@ -1718,15 +2184,26 @@ SLANG_API void const* spGetEntryPointCode( int entryPointIndex, size_t* outSize) { - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // TODO: We should really accept a target index in this API - auto targetCount = req->targets.Count(); - if (targetCount == 0) + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) return nullptr; - auto targetReq = req->targets[0]; + auto targetReq = linkage->targets[targetIndex]; - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + + if(entryPointIndex < 0) return nullptr; + if(Slang::UInt(entryPointIndex) >= req->entryPoints.Count()) return nullptr; + auto entryPoint = program->getEntryPoint(entryPointIndex); + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return nullptr; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); void const* data = nullptr; size_t size = 0; @@ -1761,21 +2238,29 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); - int targetCount = (int) req->targets.Count(); + int targetCount = (int) linkage->targets.Count(); if((targetIndex < 0) || (targetIndex >= targetCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - auto targetReq = req->targets[targetIndex]; + auto targetReq = linkage->targets[targetIndex]; int entryPointCount = (int) req->entryPoints.Count(); if((entryPointIndex < 0) || (entryPointIndex >= entryPointCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + auto entryPointReq = program->getEntryPoint(entryPointIndex); + + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return SLANG_FAIL; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); auto blob = result.getBlob(); *outBlob = blob.detach(); @@ -1793,13 +2278,9 @@ SLANG_API void const* spGetCompileRequestCode( SlangCompileRequest* request, size_t* outSize) { - auto req = REQ(request); - - void const* data = req->generatedBytecode.Buffer(); - size_t size = req->generatedBytecode.Count(); - - if(outSize) *outSize = size; - return data; + SLANG_UNUSED(request); + SLANG_UNUSED(outSize); + return nullptr; } // Reflection API @@ -1808,7 +2289,9 @@ SLANG_API SlangReflection* spGetReflection( SlangCompileRequest* request) { if( !request ) return 0; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection @@ -1818,12 +2301,16 @@ SLANG_API SlangReflection* spGetReflection( // so that we can do this better, and make it clear that // `spGetReflection()` is shorthand for `targetIndex == 0`. // - auto targetCount = req->targets.Count(); - if (targetCount == 0) - return 0; - auto targetReq = req->targets[0]; + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) + return nullptr; + + auto targetReq = linkage->targets[targetIndex]; + auto targetProgram = program->getTargetProgram(targetReq); + auto programLayout = targetProgram->getExistingLayout(); - return (SlangReflection*) targetReq->layout.Ptr(); + return (SlangReflection*) programLayout; } // ... rest of reflection API implementation is in `Reflection.cpp` diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h index 9644deae1..3fca323e8 100644 --- a/source/slang/syntax-visitors.h +++ b/source/slang/syntax-visitors.h @@ -6,8 +6,10 @@ namespace Slang { - class CompileRequest; - class EntryPointRequest; + class DiagnosticSink; + class EntryPoint; + class Linkage; + class Module; class ShaderCompiler; class ShaderLinkInfo; class ShaderSymbol; @@ -24,10 +26,11 @@ namespace Slang // Needed by import declaration checking. // // TODO: need a better location to declare this. - RefPtr<ModuleDecl> findOrImportModule( - CompileRequest* request, + RefPtr<Module> findOrImportModule( + Linkage* linkage, Name* name, - SourceLoc const& loc); + SourceLoc const& loc, + DiagnosticSink* sink); } #endif
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 709206278..b1b9f6d80 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -2713,4 +2713,15 @@ RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int return substWitness; } +Module* getModule(Decl* decl) +{ + for( auto dd = decl; dd; dd = dd->ParentDecl ) + { + if(auto moduleDecl = as<ModuleDecl>(dd)) + return moduleDecl->module; + } + return nullptr; +} + + } // namespace Slang diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 6a404214e..5198a44b2 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -12,6 +12,7 @@ namespace Slang { + class Module; class Name; class Session; class Substitutions; @@ -1360,6 +1361,10 @@ namespace Slang Function = 4, All = 7 }; + + /// Get the module that a declaration is associated with, if any. + Module* getModule(Decl* decl); + } // namespace Slang #endif diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 0bc676cf9..95a92ee2c 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -802,7 +802,7 @@ LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule) LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq) { - switch (targetReq->target) + switch (targetReq->getTarget()) { case CodeGenTarget::HLSL: case CodeGenTarget::DXBytecode: @@ -821,12 +821,13 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe } } -TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq) +TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout) { LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq); TypeLayoutContext context; context.targetReq = targetReq; + context.programLayout = programLayout; context.rules = nullptr; context.matrixLayoutMode = targetReq->getDefaultMatrixLayoutMode(); @@ -962,7 +963,7 @@ static bool isOpenGLTarget(TargetRequest*) bool isD3DTarget(TargetRequest* targetReq) { - switch( targetReq->target ) + switch( targetReq->getTarget() ) { case CodeGenTarget::HLSL: case CodeGenTarget::DXBytecode: @@ -978,7 +979,7 @@ bool isD3DTarget(TargetRequest* targetReq) bool isKhronosTarget(TargetRequest* targetReq) { - switch( targetReq->target ) + switch( targetReq->getTarget() ) { default: return false; @@ -1008,7 +1009,7 @@ static bool isSM5OrEarlier(TargetRequest* targetReq) if(!isD3DTarget(targetReq)) return false; - auto profile = targetReq->targetProfile; + auto profile = targetReq->getTargetProfile(); if(profile.getFamily() == ProfileFamily::DX) { @@ -1024,7 +1025,7 @@ static bool isSM5_1OrLater(TargetRequest* targetReq) if(!isD3DTarget(targetReq)) return false; - auto profile = targetReq->targetProfile; + auto profile = targetReq->getTargetProfile(); if(profile.getFamily() == ProfileFamily::DX) { @@ -2102,7 +2103,7 @@ SimpleLayoutInfo GetLayoutImpl( // // The `maybeAdjustLayoutForArrayElementType` computes an "adjusted" // type layout for the element type which takes the array stride into - // acount. If it returns the same type layout that was passed in, + // account. If it returns the same type layout that was passed in, // then that means no adjustement took place. // // The `additionalSpacesNeededForAdjustedElementType` variable counts @@ -2327,13 +2328,35 @@ SimpleLayoutInfo GetLayoutImpl( // we should have already populated ProgramLayout::genericEntryPointParams list at this point, // so we can find the index of this generic param decl in the list genParamTypeLayout->type = type; - genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); + genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); genParamTypeLayout->rules = rules; genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1; *outTypeLayout = genParamTypeLayout; } return info; } + else if( auto simpleGenericParam = declRef.as<GenericTypeParamDecl>() ) + { + // A bare generic type parameter can come up during layout + // of a generic entry point (or an entry point nested in + // a generic type). For now we will just pretend like + // the fields of generic parameter type take no space, + // since there is no reasonable way to account for them + // in the resulting layout. + // + // TODO: It might be better to completely ignore generic + // entry points during initial layout, but doing so would + // mean that users couldn't get layout information on + // any parameters, even those that don't depend on + // generics. + // + SimpleLayoutInfo info; + return GetSimpleLayoutImpl( + info, + type, + rules, + outTypeLayout); + } } else if (auto errorType = as<ErrorType>(type)) { diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index e20db7f56..1d939b18f 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -649,6 +649,14 @@ public: RefPtr<VarLayout> globalScopeLayout; */ + /// The target and program for which layout was computed + TargetProgram* targetProgram; + + TargetProgram* getTargetProgram() { return targetProgram; } + TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); } + Program* getProgram() { return targetProgram->getProgram(); } + + // We catalog the requested entry points here, // and any entry-point-specific parameter data // will (eventually) belong there... @@ -656,8 +664,6 @@ public: List<RefPtr<GenericParamLayout>> globalGenericParams; Dictionary<String, GenericParamLayout*> globalGenericParamsMap; - - TargetRequest* targetRequest = nullptr; }; StructTypeLayout* getGlobalStructLayout( @@ -804,6 +810,8 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0; }; +typedef List<RefPtr<GenericParamLayout>> GenericParamLayouts; + struct TypeLayoutContext { // The layout rules to use (e.g., we compute @@ -812,7 +820,12 @@ struct TypeLayoutContext LayoutRulesImpl* rules; // The target request that is triggering layout - TargetRequest* targetReq; + TargetRequest* targetReq; + + // A parent program layout that will establish the ordering + // of all global generic type parameters. + // + ProgramLayout* programLayout; // Whether to lay out matrices column-major // or row-major. @@ -840,8 +853,13 @@ struct TypeLayoutContext // Get an appropriate set of layout rules (packaged up // as a `TypeLayoutContext`) to perform type layout // for the given target. +// +// The provided `programLayout` is used to establish +// the ordering of all global generic type paramters. +// TypeLayoutContext getInitialLayoutContextForTarget( - TargetRequest* targetReq); + TargetRequest* targetReq, + ProgramLayout* programLayout); // Get the "simple" layout for a type according to a given set of layout // rules. Note that a "simple" layout can only consume one `LayoutResourceKind`, diff --git a/tests/bindings/glsl-parameter-blocks.slang b/tests/bindings/glsl-parameter-blocks.slang index d356df775..ee385e158 100644 --- a/tests/bindings/glsl-parameter-blocks.slang +++ b/tests/bindings/glsl-parameter-blocks.slang @@ -1,4 +1,3 @@ -#version 450 core //TEST:CROSS_COMPILE: -profile ps_5_0 -entry main -target spirv-assembly struct Test diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang index 9a1e9b054..0386a7d10 100644 --- a/tests/compute/global-type-param-in-entrypoint.slang +++ b/tests/compute/global-type-param-in-entrypoint.slang @@ -1,7 +1,7 @@ //TEST(compute):COMPARE_RENDER_COMPUTE: //TEST_INPUT: cbuffer(data=[1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0], stride=16):dxbinding(0),glbinding(0) //TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(1),glbinding(0),out -//TEST_INPUT: type VertImpl +//TEST_INPUT: global_type VertImpl interface IVertInterpolant { diff --git a/tests/ir/string-literal.slang.expected b/tests/ir/string-literal.slang.expected index b86eab2c8..5cbd56aea 100644 --- a/tests/ir/string-literal.slang.expected +++ b/tests/ir/string-literal.slang.expected @@ -1,7 +1,7 @@ result code = 0 standard error = { [entryPoint] -[export("_S04mainp1puV")] +[export("_S3tu04mainp1puV")] [nameHint("main")] func %main : Func(Void, UInt) { diff --git a/tools/gfx/render.h b/tools/gfx/render.h index bfbe0f82a..bc880d0be 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -141,6 +141,7 @@ struct ShaderCompileRequest EntryPoint vertexShader; EntryPoint fragmentShader; EntryPoint computeShader; + Slang::List<Slang::String> globalTypeArguments; Slang::List<Slang::String> entryPointTypeArguments; }; diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index d3a7acf37..4d2791563 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -387,7 +387,8 @@ Result RenderTestApp::initializeShaders(ShaderCompiler* shaderCompiler) compileRequest.computeShader.source = sourceInfo; compileRequest.computeShader.name = computeEntryPointName; } - compileRequest.entryPointTypeArguments = m_shaderInputLayout.globalTypeArguments; + compileRequest.globalTypeArguments = m_shaderInputLayout.globalTypeArguments; + compileRequest.entryPointTypeArguments = m_shaderInputLayout.entryPointTypeArguments; m_shaderProgram = shaderCompiler->compileProgram(compileRequest); return m_shaderProgram ? SLANG_OK : SLANG_FAIL; diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 0257c9b53..8ab14bf76 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -10,6 +10,7 @@ namespace renderer_test { entries.Clear(); globalTypeArguments.Clear(); + entryPointTypeArguments.Clear(); auto lines = Split(source, '\n'); for (auto & line : lines) { @@ -25,6 +26,14 @@ namespace renderer_test StringBuilder typeExp; while (!parser.IsEnd()) typeExp << parser.ReadToken().Content; + entryPointTypeArguments.Add(typeExp); + } + else if (parser.LookAhead("global_type")) + { + parser.ReadToken(); + StringBuilder typeExp; + while (!parser.IsEnd()) + typeExp << parser.ReadToken().Content; globalTypeArguments.Add(typeExp); } else diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index 92dd516a7..be86d971f 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -73,6 +73,7 @@ class ShaderInputLayout public: Slang::List<ShaderInputLayoutEntry> entries; Slang::List<Slang::String> globalTypeArguments; + Slang::List<Slang::String> entryPointTypeArguments; int numRenderTargets = 1; void Parse(const char * source); }; diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 1dc7323a5..0bf086d43 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -92,16 +92,25 @@ RefPtr<ShaderProgram> ShaderCompiler::compileProgram( RefPtr<ShaderProgram> shaderProgram; - Slang::List<const char*> rawTypeNames; + + Slang::List<const char*> rawGlobalTypeNames; + for (auto typeName : request.globalTypeArguments) + rawGlobalTypeNames.Add(typeName.Buffer()); + spSetGlobalGenericArgs( + slangRequest, + (int)rawGlobalTypeNames.Count(), + rawGlobalTypeNames.Buffer()); + + Slang::List<const char*> rawEntryPointTypeNames; for (auto typeName : request.entryPointTypeArguments) - rawTypeNames.Add(typeName.Buffer()); + rawEntryPointTypeNames.Add(typeName.Buffer()); if (request.computeShader.name) { int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit, computeEntryPointName, SLANG_STAGE_COMPUTE, - (int)rawTypeNames.Count(), - rawTypeNames.Buffer()); + (int)rawEntryPointTypeNames.Count(), + rawEntryPointTypeNames.Buffer()); spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE); const SlangResult res = spCompile(slangRequest); @@ -129,8 +138,8 @@ RefPtr<ShaderProgram> ShaderCompiler::compileProgram( } else { - int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX, (int)rawTypeNames.Count(), rawTypeNames.Buffer()); - int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT, (int)rawTypeNames.Count(), rawTypeNames.Buffer()); + int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, SLANG_STAGE_VERTEX, (int)rawEntryPointTypeNames.Count(), rawEntryPointTypeNames.Buffer()); + int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, SLANG_STAGE_FRAGMENT, (int)rawEntryPointTypeNames.Count(), rawEntryPointTypeNames.Buffer()); const SlangResult res = spCompile(slangRequest); if (auto diagnostics = spGetDiagnosticOutput(slangRequest)) |
