summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/core/smart-pointer.h13
-rw-r--r--source/slang/check.cpp1041
-rw-r--r--source/slang/compiler.cpp563
-rw-r--r--source/slang/compiler.h970
-rw-r--r--source/slang/decl-defs.h8
-rw-r--r--source/slang/diagnostic-defs.h6
-rw-r--r--source/slang/diagnostics.h29
-rw-r--r--source/slang/dxc-support.cpp56
-rw-r--r--source/slang/emit.cpp99
-rw-r--r--source/slang/emit.h10
-rw-r--r--source/slang/ir-dce.cpp10
-rw-r--r--source/slang/ir-dce.h6
-rw-r--r--source/slang/ir-inst-defs.h4
-rw-r--r--source/slang/ir-insts.h5
-rw-r--r--source/slang/ir-link.cpp84
-rw-r--r--source/slang/ir-link.h9
-rw-r--r--source/slang/ir-specialize-resources.cpp6
-rw-r--r--source/slang/ir-specialize-resources.h4
-rw-r--r--source/slang/ir-validate.cpp6
-rw-r--r--source/slang/ir-validate.h6
-rw-r--r--source/slang/lower-to-ir.cpp225
-rw-r--r--source/slang/lower-to-ir.h7
-rw-r--r--source/slang/options.cpp58
-rw-r--r--source/slang/parameter-binding.cpp199
-rw-r--r--source/slang/parameter-binding.h5
-rw-r--r--source/slang/parser.cpp250
-rw-r--r--source/slang/parser.h7
-rw-r--r--source/slang/preprocessor.cpp166
-rw-r--r--source/slang/preprocessor.h9
-rw-r--r--source/slang/reflection.cpp18
-rw-r--r--source/slang/slang.cpp1127
-rw-r--r--source/slang/syntax-visitors.h13
-rw-r--r--source/slang/syntax.cpp11
-rw-r--r--source/slang/syntax.h5
-rw-r--r--source/slang/type-layout.cpp39
-rw-r--r--source/slang/type-layout.h26
36 files changed, 3220 insertions, 1880 deletions
diff --git a/source/core/smart-pointer.h b/source/core/smart-pointer.h
index e19ed6a4d..0b03deb8f 100644
--- a/source/core/smart-pointer.h
+++ b/source/core/smart-pointer.h
@@ -2,6 +2,7 @@
#define FUNDAMENTAL_LIB_SMART_POINTER_H
#include "common.h"
+#include "hash.h"
#include "type-traits.h"
#include <assert.h>
@@ -157,10 +158,14 @@ namespace Slang
releaseReference(old);
}
- int GetHashCode()
- {
- return (int)(long long)(void*)pointer;
- }
+ int GetHashCode()
+ {
+ // Note: We need a `RefPtr<T>` to hash the same as a `T*`,
+ // so that a `T*` can be used as a key in a dictionary with
+ // `RefPtr<T>` keys, and vice versa.
+ //
+ return Slang::GetHashCode(pointer);
+ }
bool operator==(const T * ptr) const
{
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 3485afeea..483db60bb 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -400,22 +400,18 @@ namespace Slang
else
return DeclCheckState::CheckedHeader;
}
- DiagnosticSink* sink = nullptr;
+
+ Linkage* m_linkage = nullptr;
+ DiagnosticSink* m_sink = nullptr;
+
DiagnosticSink* getSink()
{
- return sink;
+ return m_sink;
}
// ModuleDecl * program = nullptr;
FuncDecl * function = nullptr;
- CompileRequest* request = nullptr;
- TranslationUnitRequest* translationUnit = nullptr;
-
- SourceLanguage getSourceLanguage()
- {
- return translationUnit->sourceLanguage;
- }
// lexical outer statements
List<Stmt*> outerStmts;
@@ -429,20 +425,15 @@ namespace Slang
public:
SemanticsVisitor(
- DiagnosticSink* sink,
- CompileRequest* request,
- TranslationUnitRequest* translationUnit)
- : sink(sink)
- , request(request)
- , translationUnit(translationUnit)
- {
- }
+ Linkage* linkage,
+ DiagnosticSink* sink)
+ : m_linkage(linkage)
+ , m_sink(sink)
+ {}
- CompileRequest* getCompileRequest() { return request; }
- TranslationUnitRequest* getTranslationUnit() { return translationUnit; }
Session* getSession()
{
- return getCompileRequest()->mSession;
+ return m_linkage->getSession();
}
public:
@@ -985,7 +976,7 @@ namespace Slang
catch(AbortCompilationException&) { throw; }
catch(...)
{
- getCompileRequest()->noteInternalErrorLoc(decl->loc);
+ getSink()->noteInternalErrorLoc(decl->loc);
throw;
}
}
@@ -998,7 +989,7 @@ namespace Slang
catch(AbortCompilationException&) { throw; }
catch(...)
{
- getCompileRequest()->noteInternalErrorLoc(stmt->loc);
+ getSink()->noteInternalErrorLoc(stmt->loc);
throw;
}
}
@@ -1011,7 +1002,7 @@ namespace Slang
catch(AbortCompilationException&) { throw; }
catch(...)
{
- getCompileRequest()->noteInternalErrorLoc(expr->loc);
+ getSink()->noteInternalErrorLoc(expr->loc);
throw;
}
}
@@ -1030,7 +1021,7 @@ namespace Slang
// being checked on the stack, so that we can report the full
// chain that leads from this declaration back to itself.
//
- sink->diagnose(decl, Diagnostics::cyclicReference, decl);
+ getSink()->diagnose(decl, Diagnostics::cyclicReference, decl);
return;
}
@@ -1050,7 +1041,7 @@ namespace Slang
// TODO: This diagnostic should be emitted on the line that is referencing
// the declaration. That requires `EnsureDecl` to take the requesting
// location as a parameter.
- sink->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl);
+ getSink()->diagnose(decl, Diagnostics::localVariableUsedBeforeDeclared, decl);
return;
}
}
@@ -3019,7 +3010,7 @@ namespace Slang
checkDecl(func);
}
- if (sink->GetErrorCount() != 0)
+ if (getSink()->GetErrorCount() != 0)
return;
// Force everything to be fully checked, just in case
@@ -4921,9 +4912,12 @@ namespace Slang
return new ConstantIntVal(expr->value);
}
+ Linkage* getLinkage() { return m_linkage; }
+ NamePool* getNamePool() { return getLinkage()->getNamePool(); }
+
Name* getName(String const& text)
{
- return getCompileRequest()->getNamePool()->getName(text);
+ return getNamePool()->getName(text);
}
RefPtr<IntVal> TryConstantFoldExpr(
@@ -5079,58 +5073,18 @@ namespace Slang
{
auto varDecl = varRef.getDecl();
- switch(getSourceLanguage())
+ // In HLSL, `static const` is used to mark compile-time constant expressions
+ if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>())
{
- default:
- case SourceLanguage::Slang:
- case SourceLanguage::HLSL:
- // HLSL: `static const` is used to mark compile-time constant expressions
- if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>())
- {
- if(auto constAttr = varDecl->FindModifier<ConstModifier>())
- {
- // HLSL `static const` can be used as a constant expression
- if(auto initExpr = getInitExpr(varRef))
- {
- return TryConstantFoldExpr(initExpr.Ptr());
- }
- }
- }
- break;
-
- case SourceLanguage::GLSL:
- // GLSL: `const` indicates compile-time constant expression
- //
- // TODO(tfoley): The current logic here isn't robust against
- // GLSL "specialization constants" - we will extract the
- // initializer for a `const` variable and use it to extract
- // a value, when we really should be using an opaque
- // reference to the variable.
if(auto constAttr = varDecl->FindModifier<ConstModifier>())
{
- // We need to handle a "specialization constant" (with a `constant_id` layout modifier)
- // differently from an ordinary compile-time constant. The latter can/should be reduced
- // to a value, while the former should be kept as a symbolic reference
-
- if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>())
- {
- // Retain the specialization constant as a symbolic reference
- //
- // TODO(tfoley): handle the case of non-`int` value parameters...
- //
- // TODO(tfoley): this is cloned from the case above that handles generic value parameters
- return new GenericParamIntVal(varRef);
- }
- else if(auto initExpr = getInitExpr(varRef))
+ // HLSL `static const` can be used as a constant expression
+ if(auto initExpr = getInitExpr(varRef))
{
- // This is an ordinary constant, and not a specialization constant, so we
- // can try to fold its value right now.
return TryConstantFoldExpr(initExpr.Ptr());
}
}
- break;
}
-
}
else if(auto enumRef = declRef.as<EnumCaseDecl>())
{
@@ -9060,17 +9014,32 @@ namespace Slang
auto scope = decl->scope;
// Try to load a module matching the name
- auto importedModuleDecl = findOrImportModule(request, name, decl->moduleNameAndLoc.loc);
+ auto importedModule = findOrImportModule(
+ getLinkage(),
+ name,
+ decl->moduleNameAndLoc.loc,
+ getSink());
// If we didn't find a matching module, then bail out
- if (!importedModuleDecl)
+ if (!importedModule)
return;
// Record the module that was imported, so that we can use
// it later during code generation.
+ auto importedModuleDecl = importedModule->getModuleDecl();
decl->importedModuleDecl = importedModuleDecl;
- importModuleIntoScope(scope.Ptr(), importedModuleDecl.Ptr());
+ // Add the declarations from the imported module into the scope
+ // that the `import` declaration is set to extend.
+ //
+ importModuleIntoScope(scope.Ptr(), importedModuleDecl);
+
+ // Record the `import`ed module (and everything it depends on)
+ // as a dependency of the module we are compiling.
+ if(auto module = getModule(decl))
+ {
+ module->addModuleDependency(importedModule);
+ }
decl->SetCheckState(getCheckedState());
}
@@ -9142,29 +9111,25 @@ namespace Slang
return (!decl->primaryDecl) || (decl == decl->primaryDecl);
}
- RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp)
+ RefPtr<Type> checkProperType(
+ Linkage* linkage,
+ TypeExp typeExp,
+ DiagnosticSink* sink)
{
- RefPtr<Type> type;
- DiagnosticSink nSink;
- nSink.sourceManager = tu->compileRequest->sourceManager;
SemanticsVisitor visitor(
- &nSink,
- tu->compileRequest,
- tu);
+ linkage,
+ sink);
auto typeOut = visitor.CheckProperType(typeExp);
- if (!nSink.errorCount)
- {
- type = typeOut.type;
- }
- return type;
+ return typeOut.type;
}
- FuncDecl* findFunctionDeclByName(EntryPointRequest* entryPoint, Name* name)
+ FuncDecl* findFunctionDeclByName(
+ Module* translationUnit,
+ Name* name,
+ DiagnosticSink* sink)
{
- auto translationUnit = entryPoint->getTranslationUnit();
- auto sink = &entryPoint->compileRequest->mSink;
- auto translationUnitSyntax = translationUnit->SyntaxNode;
+ auto translationUnitSyntax = translationUnit->getModuleDecl();
// Make sure we've got a query-able member dictionary
buildMemberDictionary(translationUnitSyntax);
@@ -9270,7 +9235,9 @@ namespace Slang
// Validate that an entry point function conforms to any additional
// constraints based on the stage (and profile?) it specifies.
void validateEntryPoint(
- EntryPointRequest* entryPoint)
+ FuncDecl* entryPointFuncDecl,
+ Stage stage,
+ DiagnosticSink* sink)
{
// TODO: We currently do minimal checking here, but this is the
// right place to perform the following validation checks:
@@ -9297,28 +9264,32 @@ namespace Slang
// that function is specific to the fragment profile/stage.
//
- auto sink = &entryPoint->compileRequest->mSink;
+ auto entryPointName = entryPointFuncDecl->getName();
+
+ auto module = getModule(entryPointFuncDecl);
+ auto linkage = module->getLinkage();
+
// Every entry point needs to have a stage specified either via
// command-line/API options, or via an explicit `[shader("...")]` attribute.
//
- if( entryPoint->getStage() == Stage::Unknown )
+ if( stage == Stage::Unknown )
{
- sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::entryPointHasNoStage, entryPoint->name);
+ sink->diagnose(entryPointFuncDecl, Diagnostics::entryPointHasNoStage, entryPointName);
}
- if (entryPoint->getStage() == Stage::Hull)
+ if( stage == Stage::Hull )
{
- auto translationUnit = entryPoint->getTranslationUnit();
- auto translationUnitSyntax = translationUnit->SyntaxNode;
+ // TODO: We could consider *always* checking any `[patchconsantfunc("...")]`
+ // attributes, so that they need to resolve to a function.
- auto attr = entryPoint->getFuncDecl()->FindModifier<PatchConstantFuncAttribute>();
+ auto attr = entryPointFuncDecl->FindModifier<PatchConstantFuncAttribute>();
if (attr)
{
if (attr->args.Count() != 1)
{
- sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name);
+ sink->diagnose(attr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName);
return;
}
@@ -9327,40 +9298,52 @@ namespace Slang
if (!stringLit)
{
- sink->diagnose(translationUnitSyntax, Diagnostics::badlyDefinedPatchConstantFunc, entryPoint->name);
+ sink->diagnose(expr, Diagnostics::badlyDefinedPatchConstantFunc, entryPointName);
return;
}
- Name* name = entryPoint->compileRequest->getNamePool()->getName(stringLit->value);
- FuncDecl* funcDecl = findFunctionDeclByName(entryPoint, name);
- if (!funcDecl)
+ // We look up the patch-constant function by its name in the module
+ // scope of the translation unit that declared the HS entry point.
+ //
+ // TODO: Eventually we probably want to do the lookup in the scope
+ // of the parent declarations of the entry point. E.g., if the entry
+ // point is a member function of a `struct`, then its patch-constant
+ // function should be allowed to be another member function of
+ // the same `struct`.
+ //
+ // In the extremely long run we may want to support an alternative to
+ // this attribute-based linkage between the two functions that
+ // make up the entry point.
+ //
+ Name* name = linkage->getNamePool()->getName(stringLit->value);
+ FuncDecl* patchConstantFuncDecl = findFunctionDeclByName(
+ module,
+ name,
+ sink);
+ if (!patchConstantFuncDecl)
{
- sink->diagnose(translationUnitSyntax, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc");
+ sink->diagnose(expr, Diagnostics::attributeFunctionNotFound, name, "patchconstantfunc");
return;
}
- attr->patchConstantFuncDecl = funcDecl;
+ attr->patchConstantFuncDecl = patchConstantFuncDecl;
}
}
- else if (entryPoint->getStage() == Stage::Compute)
+ else if(stage == Stage::Compute)
{
- auto funcDecl = entryPoint->getFuncDecl();
-
- auto params = funcDecl->GetParameters();
-
- for (const auto& param : params)
+ for(const auto& param : entryPointFuncDecl->GetParameters())
{
- if (auto semantic = param->FindModifier<HLSLSimpleSemantic>())
+ if(auto semantic = param->FindModifier<HLSLSimpleSemantic>())
{
const auto& semanticToken = semantic->name;
String lowerName = String(semanticToken.Content).ToLower();
- if (lowerName == "sv_dispatchthreadid")
+ if(lowerName == "sv_dispatchthreadid")
{
Type* paramType = param->getType();
- if (!isValidThreadDispatchIDType(paramType))
+ if(!isValidThreadDispatchIDType(paramType))
{
String typeString = paramType->ToString();
sink->diagnose(param->loc, Diagnostics::invalidDispatchThreadIDType, typeString);
@@ -9372,26 +9355,30 @@ namespace Slang
}
}
- // Given an `EntryPointRequest` specified via API or command line options,
+ // Given an entry point specified via API or command line options,
// attempt to find a matching AST declaration that implements the specified
// entry point. If such a function is found, then validate that it actually
// meets the requirements for the selected stage/profile.
//
- void findAndValidateEntryPoint(
- EntryPointRequest* entryPoint)
+ // Returns an `EntryPoint` object representing the (unspecialized)
+ // entry point if it is found and validated, and null otherwise.
+ //
+ RefPtr<EntryPoint> findAndValidateEntryPoint(
+ FrontEndEntryPointRequest* entryPointReq)
{
// The first step in validating the entry point is to find
// the (unique) function declaration that matches its name.
//
- // TODO: We will eventually need to update this logic
- // to work by parsing the provided `entryPoint->name` string
- // as an expression, so that we can handle more complex
- // names like `foo<int>` or `SomeType.vs`.
-
- auto translationUnit = entryPoint->getTranslationUnit();
- auto sink = &entryPoint->compileRequest->mSink;
- auto translationUnitSyntax = translationUnit->SyntaxNode;
+ // TODO: We may eventually want/need to extend this to
+ // account for nested names like `SomeStruct.vsMain`, or
+ // indeed even to handle generics.
+ //
+ auto compileRequest = entryPointReq->getCompileRequest();
+ auto translationUnit = entryPointReq->getTranslationUnit();
+ auto sink = compileRequest->getSink();
+ auto translationUnitSyntax = translationUnit->getModuleDecl();
+ auto entryPointName = entryPointReq->getName();
// Make sure we've got a query-able member dictionary
buildMemberDictionary(translationUnitSyntax);
@@ -9399,12 +9386,12 @@ namespace Slang
// We will look up any global-scope declarations in the translation
// unit that match the name of our entry point.
Decl* firstDeclWithName = nullptr;
- if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, firstDeclWithName) )
+ if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPointName, firstDeclWithName) )
{
// If there doesn't appear to be any such declaration, then
// we need to diagnose it as an error, and then bail out.
- sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPoint->name);
- return;
+ sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPointName);
+ return nullptr;
}
// We found at least one global-scope declaration with the right name,
@@ -9448,7 +9435,7 @@ namespace Slang
// name before, so the whole thing is ambiguous. We need
// to diagnose and bail out.
- sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPoint->name);
+ sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPointName);
// List all of the declarations that the user *might* mean
for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName)
@@ -9460,7 +9447,7 @@ namespace Slang
}
// Bail out.
- return;
+ return nullptr;
}
}
}
@@ -9471,127 +9458,197 @@ namespace Slang
// If not, then we need to diagnose the error.
// For convenience, we will point to the first
// declaration with the right name, that wasn't a function.
- sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPoint->name);
- return;
+ sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPointName);
+ return nullptr;
}
+ // TODO: it is possible that the entry point was declared with
+ // profile or target overloading. Is there anything that we need
+ // to do at this point to filter out declarations that aren't
+ // relevant to the selected profile for the entry point?
+
+ // We found something, and can start doing some basic checking.
+ //
// If the entry point specifies a stage via a `[shader("...")]` attribute,
// then we might be able to infer a stage for the entry point request if
// it didn't have one, *or* issue a diagnostic if there is a mismatch.
//
+ auto entryPointProfile = entryPointReq->getProfile();
if( auto entryPointAttribute = entryPointFuncDecl->FindModifier<EntryPointAttribute>() )
{
- if( entryPoint->getStage() == Stage::Unknown )
+ auto entryPointStage = entryPointProfile.GetStage();
+ if( entryPointStage == Stage::Unknown )
{
- entryPoint->profile.setStage(entryPointAttribute->stage);
+ entryPointProfile.setStage(entryPointAttribute->stage);
}
- else if( entryPointAttribute->stage != entryPoint->getStage() )
+ else if( entryPointAttribute->stage != entryPointStage )
{
- sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPoint->name, entryPoint->getStage(), entryPointAttribute->stage);
+ sink->diagnose(entryPointFuncDecl, Diagnostics::specifiedStageDoesntMatchAttribute, entryPointName, entryPointStage, entryPointAttribute->stage);
}
}
+ else
+ {
+ // TODO: Should we attach a `[shader(...)]` attribute to an
+ // entry point that didn't have one, so that we can have
+ // a more uniform representation in the AST?
+ }
- // TODO: it is possible that the entry point was declared with
- // profile or target overloading. Is there anything that we need
- // to do at this point to filter out declarations that aren't
- // relevant to the selected profile for the entry point?
- // Phew, we have at least found a suitable decl.
- // Let's record that in the entry-point request so
- // that we don't have to re-do this effort again later.
- //
- // Note: we may replace the decl-ref we store at this point
- // later in this function, when we (potentially) specialize
- // a generic entry point to generic arguments provided
- // via the API.
+ // Now that we've *found* the entry point, it is time to validate
+ // that it actually meets the constraints for the chosen stage/profile.
//
- entryPoint->funcDeclRef = makeDeclRef(entryPointFuncDecl);
+ validateEntryPoint(
+ entryPointFuncDecl,
+ entryPointProfile.GetStage(),
+ sink);
- // If the user specified generic arguments for the entry point,
- // then we will want to parse those arguments as expressions
- // in a scope that includes the tanslation unit that holds
- // the entry point, as well as any other modules that got
- // transitively loaded via `import`.
- //
- // TODO: This would be better handled by giving the user
- // more explicit ways to parse/build types at the API level,
- // rather than keeping things string-based this far along.
+ RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ makeDeclRef(entryPointFuncDecl),
+ entryPointProfile);
+
+ return entryPoint;
+ }
+
+ /// Create a `Program` to represent the compiled code.
+ ///
+ /// The created program will comprise all of the translation
+ /// units that were compiled as part of the request, as
+ /// well as any entry points in those translation units.
+ ///
+ RefPtr<Program> createUnspecializedProgram(
+ FrontEndCompileRequest* compileRequest)
+ {
+ // We want our resulting program to depend on
+ // all the translation units the user specified,
+ // even if some of them don't contain entry points
+ // (this is important for parameter layout/binding).
//
- // TODO: Building a list of `scopesToTry` here shouldn't
- // be required, since the `Scope` type itself has the ability
- // for form chains for lookup purposes (e.g., the way that
- // `import` is handled by modifying a scope).
+ // We also want to ensure that the modules for the
+ // translation units comes first in the enumerated
+ // order for dependencies, to match the pre-existing
+ // compiler behavior (at least for now).
//
- List<RefPtr<Scope>> scopesToTry;
- scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope);
- for (auto & module : entryPoint->compileRequest->loadedModulesList)
- scopesToTry.Add(module->moduleDecl->scope);
+ auto linkage = compileRequest->getLinkage();
+ auto sink = compileRequest->getSink();
+ auto program = new Program(linkage);
+ for(auto translationUnit : compileRequest->translationUnits )
+ {
+ program->addReferencedLeafModule(translationUnit->getModule());
+ }
+ for(auto translationUnit : compileRequest->translationUnits )
+ {
+ program->addReferencedModule(translationUnit->getModule());
+ }
- // We are going to do some semantic checking, so we need to
- // set up a `SemanticsVistitor` that we can use.
- //
- SemanticsVisitor semantics(
- &entryPoint->compileRequest->mSink,
- entryPoint->compileRequest,
- entryPoint->getTranslationUnit());
- // We will be looping over the generic argument strings
- // that the user provided via the API (or command line),
- // and parsing+checking each into an `Expr`.
+ // The validation of entry points here will be modal, and controlled
+ // by whether the user specified any entry points directly via
+ // API or command-line options.
//
- // This loop will *not* handle coercing the arguments
- // to be types.
+ // TODO: We may want to make this choice explicit rather than implicit.
//
- List<RefPtr<Expr>> genericArgs;
- for (auto name : entryPoint->genericArgStrings)
+ // First, check if the user requested any entry points explicitly via
+ // the API or command line.
+ //
+ bool anyExplicitEntryPoints = compileRequest->getEntryPointReqCount() != 0;
+
+ if( anyExplicitEntryPoints )
{
- RefPtr<Expr> argExpr;
- for (auto & s : scopesToTry)
+ // If there were any explicit requests for entry points to be
+ // checked, then we will *only* check those.
+ //
+ for(auto entryPointReq : compileRequest->getEntryPointReqs())
{
- argExpr = entryPoint->compileRequest->parseTypeString(
- entryPoint->getTranslationUnit(),
- name,
- s);
- argExpr = semantics.CheckTerm(argExpr);
- if( argExpr )
+ auto entryPoint = findAndValidateEntryPoint(
+ entryPointReq);
+ if( entryPoint )
{
- break;
+ program->addEntryPoint(entryPoint);
+ entryPointReq->getTranslationUnit()->entryPoints.Add(entryPoint);
}
}
- // The following is a bit of a hack.
- //
- // Back-end code generation relies on us having computed layouts for all tagged
- // unions that end up being used in the code, which means we need a way to find
- // all such types that get used in a module (and the stuff it imports).
+ // TODO: We should consider always processing both categories,
+ // and just making sure to only check each entry point function
+ // declaration once...
+ }
+ else
+ {
+ // Otherwise, scan for any `[shader(...)]` attributes in
+ // the user's code, and construct `EntryPoint`s to
+ // represent them.
//
- // The Right Way to handle this would probably be to have each `ModuleDecl` track
- // any tagged union types that get created in the context of that module, and
- // then combine those lists later.
+ // This ensures that downstream code only has to consider
+ // the central list of entry point requests, and doesn't
+ // have to know where they came from.
+
+ // TODO: A comprehensive approach here would need to search
+ // recursively for entry points, because they might appear
+ // as, e.g., member function of a `struct` type.
//
- // For now we are assuming a tagged union type only comes into existence
- // as a (top-level) argument for a generic type parameter, so that we
- // can check for them here and cache them on the entry point request.
+ // For now we'll start with an extremely basic approach that
+ // should work for typical HLSL code.
//
- if( auto typeType = as<TypeType>(argExpr->type) )
+ UInt translationUnitCount = compileRequest->translationUnits.Count();
+ for(UInt tt = 0; tt < translationUnitCount; ++tt)
{
- auto type = typeType->type;
- if( auto taggedUnionType = as<TaggedUnionType>(type) )
+ auto translationUnit = compileRequest->translationUnits[tt];
+ for( auto globalDecl : translationUnit->getModuleDecl()->Members )
{
- entryPoint->taggedUnionTypes.Add(taggedUnionType);
+ auto maybeFuncDecl = globalDecl;
+ if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) )
+ {
+ maybeFuncDecl = genericDecl->inner;
+ }
+
+ auto funcDecl = as<FuncDecl>(maybeFuncDecl);
+ if(!funcDecl)
+ continue;
+
+ auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>();
+ if(!entryPointAttr)
+ continue;
+
+ // We've discovered a valid entry point. It is a function (possibly
+ // generic) that has a `[shader(...)]` attribute to mark it as an
+ // entry point.
+ //
+ // We will now register that entry point as an `EntryPoint`
+ // with an appropriately chosen profile.
+ //
+ // The profile will only include a stage, so that the profile "family"
+ // and "version" are left unspecified. Downstream code will need
+ // to be able to handle this case.
+ //
+ Profile profile;
+ profile.setStage(entryPointAttr->stage);
+
+ validateEntryPoint(funcDecl, entryPointAttr->stage, sink);
+
+ RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ makeDeclRef(funcDecl),
+ profile);
+ program->addEntryPoint(entryPoint);
+ translationUnit->entryPoints.Add(entryPoint);
}
}
-
- genericArgs.Add(argExpr);
}
- // There are two cases we care about here, and we are going to treat them
- // as mutually exclusive for simplicity.
- //
- // The first case is when the entry point function is itself generic,
- // in which case we will assume that `genericArgs` lines up one-to-one
- // with the explicit generic parameters of the entry point.
- //
+ return program;
+ }
+
+ /// Create a specialization an existing entry point based on generic arguments.
+ DeclRef<FuncDecl> specializeEntryPoint(
+ Linkage* linkage,
+ FuncDecl* entryPointFuncDecl,
+ List<RefPtr<Expr>> const& genericArgs,
+ DiagnosticSink* sink)
+ {
+ SemanticsVisitor semantics(
+ linkage,
+ sink);
+
+ DeclRef<FuncDecl> entryPointFuncDeclRef = makeDeclRef(entryPointFuncDecl);
if( auto genericDecl = as<GenericDecl>(entryPointFuncDecl->ParentDecl) )
{
// We will construct a suitable `GenericAppExpr` to represent
@@ -9601,7 +9658,7 @@ namespace Slang
// generic application like `F<A,B,C>` if it were
// encountered in the source code.
- auto session = entryPoint->compileRequest->mSession;
+ auto session = linkage->getSession();
auto genericDeclRef = makeDeclRef(genericDecl);
// The first pieces is a `VarExpr` that refers to `genericDecl`.
@@ -9639,12 +9696,13 @@ namespace Slang
// The basic `VarExpr` and `StaticMemberExpr` cases
// should be allow-able.
- entryPoint->funcDeclRef = declRefExpr->declRef.as<FuncDecl>();
+ entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>();
}
else if( semantics.IsErrorExpr(checkedExpr) )
{
// Any semantic error that occured should have been
// reported already.
+ return DeclRef<FuncDecl>();
}
else
{
@@ -9652,302 +9710,359 @@ namespace Slang
// function should always be a `DeclRefExpr`
//
SLANG_UNEXPECTED("reference to generic decl wasn't a `DeclRefExpr`");
+ UNREACHABLE_RETURN(DeclRef<FuncDecl>());
}
}
- else
- {
- // The other case is when the entry point function is *not* itself
- // generic, so we assume that any generic arguments must have been intended
- // to match up with global generic parameters instead.
- //
- // We will only validate global generic type arguments when we are going
- // to generate code, since in a no-codegen pass we will typically *not*
- // have arguments to associate with the parameters.
- //
- if ((entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0)
- {
- // check that user-provioded type arguments conforms to the generic type
- // parameter declaration of this translation unit
- // collect global generic parameters from all imported modules
- List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
- // add current translation unit first
- {
- auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>();
- for (auto p : globalGenParams)
- globalGenericParams.Add(p);
- }
- // add imported modules
- for (auto loadedModule : entryPoint->compileRequest->loadedModulesList)
- {
- auto moduleDecl = loadedModule->moduleDecl;
- auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>();
- for (auto p : globalGenParams)
- globalGenericParams.Add(p);
- }
-
- if (globalGenericParams.Count() != genericArgs.Count())
- {
- sink->diagnose(entryPoint->getFuncDecl(), Diagnostics::mismatchEntryPointTypeArgument,
- globalGenericParams.Count(),
- genericArgs.Count());
- return;
- }
-
- // We have an appropriate number of arguments for the global generic parameters,
- // and now we need to check that the arguments conform to the declared constraints.
- //
- // Along the way, we will build up an appropriate set of substitutions to represent
- // the generic arguments and their conformances.
- //
- RefPtr<Substitutions> globalGenericSubsts;
- auto globalGenericSubstLink = &globalGenericSubsts;
- //
- // TODO: There is a serious flaw to this checking logic if we ever have cases where
- // the constraints on one `type_param` can depend on another `type_param`, e.g.:
- //
- // type_param A;
- // type_param B : ISidekick<A>;
- //
- // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to
- // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being
- // set to `Batman` to know whether the setting for `B` is valid. In this limit
- // the constraints can be mutually recursive (so `A : IMentor<B>`).
- //
- // The only way to check things correctly is to validate each conformance under
- // a set of assumptions (substitutions) that includes all the type substitutions,
- // and possibly also all the other constraints *except* the one to be validated.
- //
- // We will punt on this for now, and just check each constraint in isolation.
- //
- UInt argCounter = 0;
- for(auto& globalGenericParam : globalGenericParams)
- {
- // Get the argument that matches this parameter.
- UInt argIndex = argCounter++;
- SLANG_ASSERT(argIndex < genericArgs.Count());
- auto globalGenericArg = checkProperType(translationUnit, TypeExp(genericArgs[argIndex]));
- if (!globalGenericArg)
- {
- sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, entryPoint->genericArgStrings[argIndex]);
- return;
- }
+ return entryPointFuncDeclRef;
+ }
- // As a quick sanity check, see if the argument that is being supplied for a parameter
- // is just the parameter itself, because this should always be an error:
- //
- if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() )
- {
- auto argDeclRef = argDeclRefType->declRef;
- if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>())
- {
- if(argGenericParamDeclRef.getDecl() == globalGenericParam)
- {
- // We are trying to specialize a generic parameter using itself.
- sink->diagnose(globalGenericParam,
- Diagnostics::cannotSpecializeGlobalGenericToItself,
- globalGenericParam->getName());
- sink->diagnose(entryPointFuncDecl,
- Diagnostics::noteWhenCompilingEntryPoint,
- entryPointFuncDecl->getName());
- continue;
- }
- else
- {
- // We are trying to specialize a generic parameter using a *different*
- // global generic type parameter.
- sink->diagnose(globalGenericParam,
- Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam,
- globalGenericParam->getName(),
- argGenericParamDeclRef.GetName());
- sink->diagnose(entryPointFuncDecl,
- Diagnostics::noteWhenCompilingEntryPoint,
- entryPointFuncDecl->getName());
- continue;
- }
- }
- }
+ /// Parse an array of strings as generic arguments.
+ ///
+ /// Names in the strings will be parsed in the context of
+ /// the code loaded into the given compile request.
+ ///
+ void parseGenericArgStrings(
+ EndToEndCompileRequest* endToEndReq,
+ List<String> const& genericArgStrings,
+ List<RefPtr<Expr>>& outGenericArgs)
+ {
+ auto unspecialiedProgram = endToEndReq->getUnspecializedProgram();
- // Create a substitution for this parameter/argument.
- RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution();
- subst->paramDecl = globalGenericParam;
- subst->actualType = globalGenericArg;
+ // TODO: Building a list of `scopesToTry` here shouldn't
+ // be required, since the `Scope` type itself has the ability
+ // for form chains for lookup purposes (e.g., the way that
+ // `import` is handled by modifying a scope).
+ //
+ List<RefPtr<Scope>> scopesToTry;
+ for( auto module : unspecialiedProgram->getModuleDependencies() )
+ scopesToTry.Add(module->getModuleDecl()->scope);
- // Walk through the declared constraints for the parameter,
- // and check that the argument actually satisfies them.
- for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>())
- {
- // Get the type that the constraint is enforcing conformance to
- auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
+ // We are going to do some semantic checking, so we need to
+ // set up a `SemanticsVistitor` that we can use.
+ //
+ auto linkage = endToEndReq->getLinkage();
+ auto sink = endToEndReq->getSink();
+ SemanticsVisitor semantics(
+ linkage,
+ sink);
- // Use our semantic-checking logic to search for a witness to the required conformance
- SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit);
- auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType);
- if (!witness)
- {
- // If no witness was found, then we will be unable to satisfy
- // the conformances required.
- sink->diagnose(globalGenericParam,
- Diagnostics::typeArgumentDoesNotConformToInterface,
- globalGenericParam->nameAndLoc.name,
- globalGenericArg,
- interfaceType);
- }
+ // We will be looping over the generic argument strings
+ // that the user provided via the API (or command line),
+ // and parsing+checking each into an `Expr`.
+ //
+ // This loop will *not* handle coercing the arguments
+ // to be types.
+ //
+ for(auto name : genericArgStrings)
+ {
+ RefPtr<Expr> argExpr;
+ for (auto & s : scopesToTry)
+ {
+ argExpr = linkage->parseTypeString(name, s);
+ argExpr = semantics.CheckTerm(argExpr);
+ if( argExpr )
+ {
+ break;
+ }
+ }
- // Attach the concrete witness for this conformance to the
- // substutiton
- GlobalGenericParamSubstitution::ConstraintArg constraintArg;
- constraintArg.decl = constraint;
- constraintArg.val = witness;
- subst->constraintArgs.Add(constraintArg);
- }
+ outGenericArgs.Add(argExpr);
+ }
+ }
- // Add the substitution for this parameter to the global substitution
- // set that we are building.
+ /// Specialize a program to global generic arguments
+ RefPtr<Program> createSpecializedProgram(
+ Linkage* linkage,
+ Program* unspecializedProgram,
+ List<RefPtr<Expr>> const& globalGenericArgs,
+ DiagnosticSink* sink)
+ {
+ // The given `unspecializedProgram` should be one that
+ // was checked through the front-end, so that now we
+ // only need to check if the given arguments can satisfy
+ // the requirements of the global generic parameters.
+ //
+ // The new program needs to start off with the same
+ // module dependency list as the original.
+ //
+ RefPtr<Program> specializedProgram = new Program(linkage);
+ for(auto module : unspecializedProgram->getModuleDependencies())
+ {
+ specializedProgram->addReferencedLeafModule(module);
+ }
- *globalGenericSubstLink = subst;
- globalGenericSubstLink = &subst->outer;
- }
- entryPoint->globalGenericSubst = globalGenericSubsts;
- }
+ // We will collect all the global generic parameters
+ // defined in the modules being referenced, to find
+ // the global generic parameter signature of the
+ // program.
+ //
+ // TODO: Note that this doesn't handle the case where one
+ // or more of the type *arguments* that we are specifying
+ // ends up requiring additional modules to be referenced,
+ // which might in turn introduce new global generic parameters.
+ //
+ List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
+ for(auto module : unspecializedProgram->getModuleDependencies())
+ {
+ for(auto param : module->getModuleDecl()->getMembersOfType<GlobalGenericParamDecl>())
+ globalGenericParams.Add(param);
}
- // If any errors occured while we were checking the generic arguments
- // of the entry point, then we should bail out rather than try to
- // perform the next step of validation.
+ // Next, we will check whether the supplied arguments can
+ // satisfy those parameters.
+ //
+ // An easy early-out case will be if the number of
+ // arguments isn't correct.
//
- if (sink->errorCount != 0)
- return;
+ if (globalGenericParams.Count() != globalGenericArgs.Count())
+ {
+ sink->diagnose(SourceLoc(), Diagnostics::mismatchGlobalGenericArguments,
+ globalGenericParams.Count(),
+ globalGenericArgs.Count());
+ return nullptr;
+ }
- // Now that we've *found* the entry point, it is time to validate
- // that it actually meets the constraints for the chosen stage/profile.
+ // We have an appropriate number of arguments for the global generic parameters,
+ // and now we need to check that the arguments conform to the declared constraints.
//
- // TODO: This validation should (probably?) be performed "under" any global generic
- // parameter substitution we might have created, so that we can validate
- // based on knowledge of actual types.
+ // Along the way, we will build up an appropriate set of substitutions to represent
+ // the generic arguments and their conformances.
//
- validateEntryPoint(entryPoint);
- }
-
- void validateEntryPoints(
- CompileRequest* compileRequest)
- {
- // The validation of entry points here will be modal, and controlled
- // by whether the user specified any entry points directly via
- // API or command-line options.
+ RefPtr<Substitutions> globalGenericSubsts;
+ auto globalGenericSubstLink = &globalGenericSubsts;
//
- // TODO: We may want to make this choice explicit rather than implicit.
+ // TODO: There is a serious flaw to this checking logic if we ever have cases where
+ // the constraints on one `type_param` can depend on another `type_param`, e.g.:
//
- // First, check if the user request any entry points explicitly via
- // the API or command line.
+ // type_param A;
+ // type_param B : ISidekick<A>;
+ //
+ // In that case, if a user tries to set `B` to `Robin` and `Robin` conforms to
+ // `ISidekick<Batman>`, then the compiler needs to know whether `A` is being
+ // set to `Batman` to know whether the setting for `B` is valid. In this limit
+ // the constraints can be mutually recursive (so `A : IMentor<B>`).
//
- bool anyExplicitEntryPointRequests = false;
- for (auto& translationUnit : compileRequest->translationUnits)
+ // The only way to check things correctly is to validate each conformance under
+ // a set of assumptions (substitutions) that includes all the type substitutions,
+ // and possibly also all the other constraints *except* the one to be validated.
+ //
+ // We will punt on this for now, and just check each constraint in isolation.
+ //
+ UInt argCounter = 0;
+ for(auto& globalGenericParam : globalGenericParams)
{
- if( translationUnit->entryPoints.Count() != 0)
+ // Get the argument that matches this parameter.
+ UInt argIndex = argCounter++;
+ SLANG_ASSERT(argIndex < globalGenericArgs.Count());
+ auto globalGenericArg = checkProperType(linkage, TypeExp(globalGenericArgs[argIndex]), sink);
+ if (!globalGenericArg)
{
- anyExplicitEntryPointRequests = true;
- break;
+ sink->diagnose(globalGenericParam, Diagnostics::globalGenericArgumentNotAType, globalGenericParam->getName());
+ return nullptr;
}
- }
- if( anyExplicitEntryPointRequests )
- {
- // If there were any explicit requests for entry points to be
- // checked, then we will *only* check those.
-
- for (auto& translationUnit : compileRequest->translationUnits)
+ // As a quick sanity check, see if the argument that is being supplied for a parameter
+ // is just the parameter itself, because this should always be an error:
+ //
+ if( auto argDeclRefType = globalGenericArg.as<DeclRefType>() )
{
- for (auto entryPoint : translationUnit->entryPoints)
+ auto argDeclRef = argDeclRefType->declRef;
+ if(auto argGenericParamDeclRef = argDeclRef.as<GlobalGenericParamDecl>())
{
- findAndValidateEntryPoint(entryPoint);
+ if(argGenericParamDeclRef.getDecl() == globalGenericParam)
+ {
+ // We are trying to specialize a generic parameter using itself.
+ sink->diagnose(globalGenericParam,
+ Diagnostics::cannotSpecializeGlobalGenericToItself,
+ globalGenericParam->getName());
+ continue;
+ }
+ else
+ {
+ // We are trying to specialize a generic parameter using a *different*
+ // global generic type parameter.
+ sink->diagnose(globalGenericParam,
+ Diagnostics::cannotSpecializeGlobalGenericToAnotherGenericParam,
+ globalGenericParam->getName(),
+ argGenericParamDeclRef.GetName());
+ continue;
+ }
}
}
- }
- else
- {
- // Otherwise, scan for any `[shader(...)]` attributes in
- // the user's code, and construct `EntryPointRequest`s to
- // represent them.
- //
- // This ensures that downstream code only has to consider
- // the central list of entry point requests, and doesn't
- // have to know where they came from.
- // TODO: A comprehensive approach here would need to search
- // recursively for entry points, because they might appear
- // as, e.g., member function of a `struct` type.
- //
- // For now we'll start with an extremely basic approach that
- // should work for typical HLSL code.
- //
- UInt translationUnitCount = compileRequest->translationUnits.Count();
- for(UInt tt = 0; tt < translationUnitCount; ++tt)
+ // Create a substitution for this parameter/argument.
+ RefPtr<GlobalGenericParamSubstitution> subst = new GlobalGenericParamSubstitution();
+ subst->paramDecl = globalGenericParam;
+ subst->actualType = globalGenericArg;
+
+ // Walk through the declared constraints for the parameter,
+ // and check that the argument actually satisfies them.
+ for(auto constraint : globalGenericParam->getMembersOfType<GenericTypeConstraintDecl>())
{
- auto translationUnit = compileRequest->translationUnits[tt];
- for( auto globalDecl : translationUnit->SyntaxNode->Members )
+ // Get the type that the constraint is enforcing conformance to
+ auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
+
+ // Use our semantic-checking logic to search for a witness to the required conformance
+ SemanticsVisitor visitor(linkage, sink);
+ auto witness = visitor.tryGetSubtypeWitness(globalGenericArg, interfaceType);
+ if (!witness)
{
- auto maybeFuncDecl = globalDecl;
- if( auto genericDecl = as<GenericDecl>(maybeFuncDecl) )
- {
- maybeFuncDecl = genericDecl->inner;
- }
+ // If no witness was found, then we will be unable to satisfy
+ // the conformances required.
+ sink->diagnose(globalGenericParam,
+ Diagnostics::typeArgumentDoesNotConformToInterface,
+ globalGenericParam->nameAndLoc.name,
+ globalGenericArg,
+ interfaceType);
+ }
- auto funcDecl = as<FuncDecl>(maybeFuncDecl);
- if(!funcDecl)
- continue;
+ // Attach the concrete witness for this conformance to the
+ // substutiton
+ GlobalGenericParamSubstitution::ConstraintArg constraintArg;
+ constraintArg.decl = constraint;
+ constraintArg.val = witness;
+ subst->constraintArgs.Add(constraintArg);
+ }
- auto entryPointAttr = funcDecl->FindModifier<EntryPointAttribute>();
- if(!entryPointAttr)
- continue;
+ // Add the substitution for this parameter to the global substitution
+ // set that we are building.
- // We've discovered a valid entry point. It is a function (possibly
- // generic) that has a `[shader(...)]` attribute to mark it as an
- // entry point.
- //
- // We will now register that entry point as an `EntryPointRequest`
- // with an appropriately chosen profile.
- //
- // The profile will only include a stage, so that the profile "family"
- // and "version" are left unspecified. Downstream code will need
- // to be able to handle this case.
- //
- Profile profile;
- profile.setStage(entryPointAttr->stage);
+ *globalGenericSubstLink = subst;
+ globalGenericSubstLink = &subst->outer;
+ }
+ if(sink->GetErrorCount())
+ return nullptr;
- // We manually fill in the entry point request object.
- RefPtr<EntryPointRequest> entryPointReq = new EntryPointRequest();
- entryPointReq->compileRequest = compileRequest;
- entryPointReq->translationUnitIndex = int(tt);
- entryPointReq->funcDeclRef = makeDeclRef(funcDecl);
- entryPointReq->name = funcDecl->getName();
- entryPointReq->profile = profile;
+ specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts);
- // Apply the common validation logic to this entry point.
- validateEntryPoint(entryPointReq);
+ return specializedProgram;
+ }
- // Add the entry point to the list in the translation unit,
- // and also the global list in the compile request.
- translationUnit->entryPoints.Add(entryPointReq);
- compileRequest->entryPoints.Add(entryPointReq);
- }
- }
+ /// Specialize an entry point that was checked by the front-end, based on generic arguments.
+ ///
+ /// If the end-to-end compile request included generic argument strings
+ /// for this entry point, then they will be parsed, checked, and used
+ /// as arguments to the generic entry point.
+ ///
+ /// Returns a specialized entry point if everything worked as expected.
+ /// Returns null and diagnoses errors if anything goes wrong.
+ ///
+ RefPtr<EntryPoint> specializeEntryPoint(
+ EndToEndCompileRequest* endToEndReq,
+ EntryPoint* unspecializedEntryPoint,
+ EndToEndCompileRequest::EntryPointInfo const& entryPointInfo)
+ {
+ auto linkage = endToEndReq->getLinkage();
+ auto sink = endToEndReq->getSink();
+ auto entryPointFuncDecl = unspecializedEntryPoint->getFuncDecl();
+
+ // If the user specified generic arguments for the entry point,
+ // then we will need to parse the arguments first.
+ //
+ List<RefPtr<Expr>> genericArgs;
+ parseGenericArgStrings(
+ endToEndReq,
+ entryPointInfo.genericArgStrings,
+ genericArgs);
+
+ // Next we specialize the entry point function given the parsed
+ // generic argument expressions.
+ //
+ auto entryPointFuncDeclRef = specializeEntryPoint(
+ linkage,
+ entryPointFuncDecl,
+ genericArgs,
+ sink);
+
+ RefPtr<EntryPoint> entryPoint = EntryPoint::create(
+ entryPointFuncDeclRef,
+ unspecializedEntryPoint->getProfile());
+
+ return entryPoint;
+ }
+
+ /// Create a specialized program based on the given compile request.
+ ///
+ RefPtr<Program> createSpecializedProgram(
+ EndToEndCompileRequest* endToEndReq)
+ {
+ // The compile request must have already completed front-end processing,
+ // so that we have an unspecialized program available, and now only need
+ // to parse and check any generic arguments that are being supplied for
+ // global or entry-point generic parameters.
+ //
+ auto unspecializedProgram = endToEndReq->getUnspecializedProgram();
+
+ // First, let's parse the generic argument strings that were
+ // provided via the API, so taht we can match them
+ // against what was declared in the program.
+ //
+ List<RefPtr<Expr>> globalGenericArgs;
+ parseGenericArgStrings(
+ endToEndReq,
+ endToEndReq->globalGenericArgStrings,
+ globalGenericArgs);
+
+ // Now we create the initial specialized program by
+ // applying the global generic arguments (if any) to the
+ // unspecialized program.
+ //
+ auto specializedProgram = createSpecializedProgram(
+ endToEndReq->getLinkage(),
+ unspecializedProgram,
+ globalGenericArgs,
+ endToEndReq->getSink());
+
+ // If anything went wrong with the global generic
+ // arguments, then bail out now.
+ //
+ if(!specializedProgram)
+ return nullptr;
+
+ // Next we will deal with the entry points for the
+ // new specialized program.
+ //
+ // If the user specified explicit entry points as part of the
+ // end-to-end request, then we only want to process those (and
+ // ignore any other `[shader(...)]`-attributed entry points).
+ //
+ // However, if the user specified *no* entry points as part
+ // of the end-to-end request, then we would like to go
+ // ahead and consider all the entry points that were found
+ // by the front-end.
+ //
+ UInt entryPointCount = endToEndReq->entryPoints.Count();
+ if( entryPointCount == 0 )
+ {
+ entryPointCount = unspecializedProgram->getEntryPointCount();
+ endToEndReq->entryPoints.SetSize(entryPointCount);
}
+
+ for( UInt ii = 0; ii < entryPointCount; ++ii )
+ {
+ auto unspecializedEntryPoint = unspecializedProgram->getEntryPoint(ii);
+ auto& entryPointInfo = endToEndReq->entryPoints[ii];
+
+ auto specializedEntryPoint = specializeEntryPoint(endToEndReq, unspecializedEntryPoint, entryPointInfo);
+ specializedProgram->addEntryPoint(specializedEntryPoint);
+ }
+
+ return specializedProgram;
}
void checkTranslationUnit(
TranslationUnitRequest* translationUnit)
{
SemanticsVisitor visitor(
- &translationUnit->compileRequest->mSink,
- translationUnit->compileRequest,
- translationUnit);
+ translationUnit->compileRequest->getLinkage(),
+ translationUnit->compileRequest->getSink());
// Apply the visitor to do the main semantic
// checking that is required on all declarations
// in the translation unit.
- visitor.checkDecl(translationUnit->SyntaxNode);
+ visitor.checkDecl(translationUnit->getModuleDecl());
}
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 21f56c9ee..3bc34692d 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -120,23 +120,123 @@ namespace Slang
return blob;
}
- // EntryPointRequest
+ //
+ // FrontEndEntryPointRequest
+ //
+
+ FrontEndEntryPointRequest::FrontEndEntryPointRequest(
+ FrontEndCompileRequest* compileRequest,
+ int translationUnitIndex,
+ Name* name,
+ Profile profile)
+ : m_compileRequest(compileRequest)
+ , m_translationUnitIndex(translationUnitIndex)
+ , m_name(name)
+ , m_profile(profile)
+ {}
- TranslationUnitRequest* EntryPointRequest::getTranslationUnit()
+
+ TranslationUnitRequest* FrontEndEntryPointRequest::getTranslationUnit()
{
- return compileRequest->translationUnits[translationUnitIndex].Ptr();
+ return getCompileRequest()->translationUnits[m_translationUnitIndex];
}
- DeclRef<FuncDecl> EntryPointRequest::getFuncDeclRef()
+ //
+ // EntryPoint
+ //
+
+ RefPtr<EntryPoint> EntryPoint::create(
+ DeclRef<FuncDecl> funcDeclRef,
+ Profile profile)
{
- return funcDeclRef;
+ RefPtr<EntryPoint> entryPoint = new EntryPoint(
+ funcDeclRef.GetName(),
+ profile,
+ funcDeclRef);
+ return entryPoint;
}
- RefPtr<FuncDecl> EntryPointRequest::getFuncDecl()
+ RefPtr<EntryPoint> EntryPoint::createDummyForPassThrough(
+ Name* name,
+ Profile profile)
{
- return getFuncDeclRef().getDecl();
+ RefPtr<EntryPoint> entryPoint = new EntryPoint(
+ name,
+ profile,
+ DeclRef<FuncDecl>());
+ return entryPoint;
}
+ EntryPoint::EntryPoint(
+ Name* name,
+ Profile profile,
+ DeclRef<FuncDecl> funcDeclRef)
+ : m_name(name)
+ , m_profile(profile)
+ , m_funcDeclRef(funcDeclRef)
+ {
+ // In order for later code generation to work, we need to track what
+ // modules each entry point depends on. We will build up the dependency
+ // list here when an `EntryPoint` gets created.
+ //
+ // We know an entry point depends on the module that declared the
+ // entry-point function itself.
+ //
+ // Note: we are carefully handling the case where `module` could
+ // be null, becase of "dummy" entry points created for pass-through
+ // compilation.
+ //
+ if(auto module = getModule())
+ {
+ m_dependencyList.addDependency(module);
+ }
+ //
+ // TODO: We also need to include the modules needed by any generic
+ // arguments in the dependency list, since in the general case they
+ // might come from modules other than the one defining the entry point.
+
+ // The following is a bit of a hack.
+ //
+ // Back-end code generation relies on us having computed layouts for all tagged
+ // unions that end up being used in the code, which means we need a way to find
+ // all such types that get used in a program (and the stuff it imports).
+ //
+ // For now we are assuming a tagged union type only comes into existence
+ // as a (top-level) argument for a generic type parameter, so that we
+ // can check for them here and cache them on the entry point.
+ //
+ // A longer-term strategy might need to consider any (tagged or untagged)
+ // union types that get used inside of a module, and also take
+ // those lists into account.
+ //
+ // An even longer-term strategy would be to allow type layout to
+ // be performed on IR types, so taht we don't need to have front-end
+ // code worrying about this stuff.
+ //
+ for( auto subst = funcDeclRef.substitutions.substitutions; subst; subst = subst->outer )
+ {
+ if( auto genericSubst = as<GenericSubstitution>(subst) )
+ {
+ for( auto arg : genericSubst->args )
+ {
+ if( auto taggedUnionType = as<TaggedUnionType>(arg) )
+ {
+ m_taggedUnionTypes.Add(taggedUnionType);
+ }
+ }
+ }
+ }
+ }
+
+ Module* EntryPoint::getModule()
+ {
+ return Slang::getModule(getFuncDecl());
+ }
+
+ Linkage* EntryPoint::getLinkage()
+ {
+ return getModule()->getLinkage();
+ }
//
@@ -279,13 +379,35 @@ namespace Slang
//
+ /// If there is a pass-through compile going on, find the translation unit for the given entry point.
+ TranslationUnitRequest* findPassThroughTranslationUnit(
+ EndToEndCompileRequest* endToEndReq,
+ Int entryPointIndex)
+ {
+ // If there isn't an end-to-end compile going on,
+ // there can be no pass-through.
+ //
+ if(!endToEndReq) return nullptr;
+
+ // And if pass-through isn't set, we don't need
+ // access to the translation unit.
+ //
+ if(endToEndReq->passThrough == PassThroughMode::None) return nullptr;
+
+ auto frontEndReq = endToEndReq->getFrontEndReq();
+ auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex);
+ auto translationUnit = entryPointReq->getTranslationUnit();
+ return translationUnit;
+ }
+
String emitHLSLForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq)
{
- auto compileRequest = entryPoint->compileRequest;
- auto translationUnit = entryPoint->getTranslationUnit();
- if (compileRequest->passThrough != PassThroughMode::None)
+ if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex))
{
// Generate a string that includes the content of
// the source file(s), along with a line directive
@@ -294,7 +416,7 @@ namespace Slang
// mode.
StringBuilder codeBuilder;
- for(auto sourceFile : translationUnit->sourceFiles)
+ for(auto sourceFile : translationUnit->getSourceFiles())
{
codeBuilder << "#line 1 \"";
@@ -323,21 +445,21 @@ namespace Slang
else
{
return emitEntryPoint(
+ compileRequest,
entryPoint,
- targetReq->layout.Ptr(),
CodeGenTarget::HLSL,
targetReq);
}
}
String emitGLSLForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq)
{
- auto compileRequest = entryPoint->compileRequest;
- auto translationUnit = entryPoint->getTranslationUnit();
-
- if (compileRequest->passThrough != PassThroughMode::None)
+ if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex))
{
// Generate a string that includes the content of
// the source file(s), along with a line directive
@@ -347,7 +469,7 @@ namespace Slang
StringBuilder codeBuilder;
int translationUnitCounter = 0;
- for(auto sourceFile : translationUnit->sourceFiles)
+ for(auto sourceFile : translationUnit->getSourceFiles())
{
int translationUnitIndex = translationUnitCounter++;
@@ -370,8 +492,8 @@ namespace Slang
// TODO(tfoley): need to pass along the entry point
// so that we properly emit it as the `main` function.
return emitEntryPoint(
+ compileRequest,
entryPoint,
- targetReq->layout.Ptr(),
CodeGenTarget::GLSL,
targetReq);
}
@@ -484,9 +606,9 @@ namespace Slang
sink->diagnoseRaw(SLANG_FAILED(res) ? Severity::Error : Severity::Warning, builder.getUnownedSlice());
}
- static String _getDisplayPath(const DiagnosticSink& sink, SourceFile* sourceFile)
+ static String _getDisplayPath(DiagnosticSink* sink, SourceFile* sourceFile)
{
- if (sink.flags & DiagnosticSink::Flag::VerbosePath)
+ if (sink->flags & DiagnosticSink::Flag::VerbosePath)
{
return sourceFile->calcVerbosePath();
}
@@ -496,17 +618,17 @@ namespace Slang
}
}
- String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest)
+ String calcSourcePathForEntryPoint(
+ EndToEndCompileRequest* endToEndReq,
+ UInt entryPointIndex)
{
- CompileRequest* compileRequest = translationUnitRequest->compileRequest;
- if (compileRequest->passThrough == PassThroughMode::None)
- {
+ auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex);
+ if(!translationUnitRequest)
return "slang-generated";
- }
- auto& sink = translationUnitRequest->compileRequest->mSink;
+ auto sink = endToEndReq->getSink();
- const auto& sourceFiles = translationUnitRequest->sourceFiles;
+ const auto& sourceFiles = translationUnitRequest->getSourceFiles();
const int numSourceFiles = int(sourceFiles.Count());
@@ -542,22 +664,26 @@ namespace Slang
}
SlangResult emitDXBytecodeForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- List<uint8_t>& byteCodeOut)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ List<uint8_t>& byteCodeOut)
{
byteCodeOut.Clear();
- auto session = entryPoint->compileRequest->mSession;
+ auto session = compileRequest->getSession();
+ auto sink = compileRequest->getSink();
- auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, &entryPoint->compileRequest->mSink);
+ auto compileFunc = (pD3DCompile)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DCompile, sink);
if (!compileFunc)
{
return SLANG_FAIL;
}
- auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq);
- maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL);
+ auto hlslCode = emitHLSLForEntryPoint(compileRequest, entryPoint, entryPointIndex, targetReq, endToEndReq);
+ maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL);
auto profile = getEffectiveProfile(entryPoint, targetReq);
@@ -569,16 +695,16 @@ namespace Slang
//
List<D3D_SHADER_MACRO> dxMacrosStorage;
D3D_SHADER_MACRO const* dxMacros = nullptr;
- if( entryPoint->compileRequest->passThrough != PassThroughMode::None )
+ if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex))
{
- for( auto& define : entryPoint->compileRequest->preprocessorDefinitions )
+ for( auto& define : translationUnit->compileRequest->preprocessorDefinitions )
{
D3D_SHADER_MACRO dxMacro;
dxMacro.Name = define.Key.Buffer();
dxMacro.Definition = define.Value.Buffer();
dxMacrosStorage.Add(dxMacro);
}
- for( auto& define : entryPoint->getTranslationUnit()->preprocessorDefinitions )
+ for( auto& define : translationUnit->preprocessorDefinitions )
{
D3D_SHADER_MACRO dxMacro;
dxMacro.Name = define.Key.Buffer();
@@ -616,7 +742,7 @@ namespace Slang
flags |= D3DCOMPILE_ENABLE_STRICTNESS;
flags |= D3DCOMPILE_ENABLE_UNBOUNDED_DESCRIPTOR_TABLES;
- const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit());
+ const String sourcePath = "slang-geneated";// calcTranslationUnitSourcePath(entryPoint->getTranslationUnit());
ComPtr<ID3DBlob> codeBlob;
ComPtr<ID3DBlob> diagnosticsBlob;
@@ -626,7 +752,7 @@ namespace Slang
sourcePath.Buffer(),
dxMacros,
nullptr,
- getText(entryPoint->name).begin(),
+ getText(entryPoint->getName()).begin(),
GetHLSLProfileName(profile).Buffer(),
flags,
0, // unused: effect flags
@@ -640,23 +766,24 @@ namespace Slang
if (FAILED(hr))
{
- reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), &entryPoint->compileRequest->mSink);
+ reportExternalCompileError("fxc", hr, _getSlice(diagnosticsBlob), sink);
}
return hr;
}
SlangResult dissassembleDXBC(
- CompileRequest* compileRequest,
- void const* data,
- size_t size,
- String& assemOut)
+ BackEndCompileRequest* compileRequest,
+ void const* data,
+ size_t size,
+ String& assemOut)
{
assemOut = String();
- auto session = compileRequest->mSession;
+ auto session = compileRequest->getSession();
+ auto sink = compileRequest->getSink();
- auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, &compileRequest->mSink);
+ auto disassembleFunc = (pD3DDisassemble)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Fxc_D3DDisassemble, sink);
if (!disassembleFunc)
{
return SLANG_E_NOT_FOUND;
@@ -677,25 +804,34 @@ namespace Slang
if (FAILED(res))
{
// TODO(tfoley): need to figure out what to diagnose here...
- reportExternalCompileError("fxc", res, UnownedStringSlice(), &compileRequest->mSink);
+ reportExternalCompileError("fxc", res, UnownedStringSlice(), sink);
}
return res;
}
SlangResult emitDXBytecodeAssemblyForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- String& assemOut)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ String& assemOut)
{
List<uint8_t> dxbc;
- SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint(entryPoint, targetReq, dxbc));
+ SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ dxbc));
if (!dxbc.Count())
{
return SLANG_FAIL;
}
- return dissassembleDXBC(entryPoint->compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut);
+ return dissassembleDXBC(compileRequest, dxbc.Buffer(), dxbc.Count(), assemOut);
}
#endif
@@ -704,26 +840,30 @@ namespace Slang
// Implementations in `dxc-support.cpp`
int emitDXILForEntryPointUsingDXC(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- List<uint8_t>& outCode);
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ List<uint8_t>& outCode);
SlangResult dissassembleDXILUsingDXC(
- CompileRequest* compileRequest,
- void const* data,
- size_t size,
- String& stringOut);
+ BackEndCompileRequest* compileRequest,
+ void const* data,
+ size_t size,
+ String& stringOut);
#endif
#if SLANG_ENABLE_GLSLANG_SUPPORT
SlangResult invokeGLSLCompiler(
- CompileRequest* slangCompileRequest,
+ BackEndCompileRequest* slangCompileRequest,
glslang_CompileRequest& request)
{
- Session* session = slangCompileRequest->mSession;
+ Session* session = slangCompileRequest->getSession();
+ auto sink = slangCompileRequest->getSink();
- auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, &slangCompileRequest->mSink);
+ auto glslang_compile = (glslang_CompileFunc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Glslang_Compile, sink);
if (!glslang_compile)
{
return SLANG_FAIL;
@@ -743,7 +883,7 @@ SlangResult dissassembleDXILUsingDXC(
if (err)
{
- reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), &slangCompileRequest->mSink);
+ reportExternalCompileError("glslang", SLANG_FAIL, diagnosticOutput.getUnownedSlice(), sink);
return SLANG_FAIL;
}
@@ -751,10 +891,10 @@ SlangResult dissassembleDXILUsingDXC(
}
SlangResult dissassembleSPIRV(
- CompileRequest* slangRequest,
- void const* data,
- size_t size,
- String& stringOut)
+ BackEndCompileRequest* slangRequest,
+ void const* data,
+ size_t size,
+ String& stringOut)
{
stringOut = String();
@@ -782,21 +922,29 @@ SlangResult dissassembleDXILUsingDXC(
}
SlangResult emitSPIRVForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- List<uint8_t>& spirvOut)
+ BackEndCompileRequest* slangRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ List<uint8_t>& spirvOut)
{
spirvOut.Clear();
- String rawGLSL = emitGLSLForEntryPoint(entryPoint, targetReq);
- maybeDumpIntermediate(entryPoint->compileRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL);
+ String rawGLSL = emitGLSLForEntryPoint(
+ slangRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq);
+ maybeDumpIntermediate(slangRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL);
auto outputFunc = [](void const* data, size_t size, void* userData)
{
((List<uint8_t>*)userData)->AddRange((uint8_t*)data, size);
};
- const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit());
+ const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex);
glslang_CompileRequest request;
request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV;
@@ -809,40 +957,56 @@ SlangResult dissassembleDXILUsingDXC(
request.outputFunc = outputFunc;
request.outputUserData = &spirvOut;
- SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(entryPoint->compileRequest, request));
+ SLANG_RETURN_ON_FAIL(invokeGLSLCompiler(slangRequest, request));
return SLANG_OK;
}
SlangResult emitSPIRVAssemblyForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- String& assemblyOut)
+ BackEndCompileRequest* slangRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ String& assemblyOut)
{
List<uint8_t> spirv;
- SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint(entryPoint, targetReq, spirv));
+ SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint(
+ slangRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ spirv));
if (spirv.Count() == 0)
return SLANG_FAIL;
- return dissassembleSPIRV(entryPoint->compileRequest, spirv.begin(), spirv.Count(), assemblyOut);
+ return dissassembleSPIRV(slangRequest, spirv.begin(), spirv.Count(), assemblyOut);
}
#endif
// Do emit logic for a single entry point
CompileResult emitEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq)
{
CompileResult result;
- auto compileRequest = entryPoint->compileRequest;
auto target = targetReq->target;
switch (target)
{
case CodeGenTarget::HLSL:
{
- String code = emitHLSLForEntryPoint(entryPoint, targetReq);
+ String code = emitHLSLForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
@@ -850,7 +1014,12 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::GLSL:
{
- String code = emitGLSLForEntryPoint(entryPoint, targetReq);
+ String code = emitGLSLForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
@@ -860,7 +1029,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXBytecode:
{
List<uint8_t> code;
- if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
result = CompileResult(code);
@@ -871,7 +1046,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXBytecodeAssembly:
{
String code;
- if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
@@ -884,7 +1065,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXIL:
{
List<uint8_t> code;
- if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
result = CompileResult(code);
@@ -895,7 +1082,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXILAssembly:
{
List<uint8_t> code;
- if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
String assembly;
dissassembleDXILUsingDXC(
@@ -915,7 +1108,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::SPIRV:
{
List<uint8_t> code;
- if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
result = CompileResult(code);
@@ -926,7 +1125,13 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::SPIRVAssembly:
{
String code;
- if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint(entryPoint, targetReq, code)))
+ if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq,
+ code)))
{
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
@@ -957,16 +1162,16 @@ SlangResult dissassembleDXILUsingDXC(
};
static void writeOutputFile(
- CompileRequest* compileRequest,
- FILE* file,
- String const& path,
- void const* data,
- size_t size)
+ BackEndCompileRequest* compileRequest,
+ FILE* file,
+ String const& path,
+ void const* data,
+ size_t size)
{
size_t count = fwrite(data, size, 1, file);
if (count != 1)
{
- compileRequest->mSink.diagnose(
+ compileRequest->getSink()->diagnose(
SourceLoc(),
Diagnostics::cannotWriteOutputFile,
path);
@@ -974,16 +1179,16 @@ SlangResult dissassembleDXILUsingDXC(
}
static void writeOutputFile(
- CompileRequest* compileRequest,
- ISlangWriter* writer,
- String const& path,
- void const* data,
- size_t size)
+ BackEndCompileRequest* compileRequest,
+ ISlangWriter* writer,
+ String const& path,
+ void const* data,
+ size_t size)
{
if (SLANG_FAILED(writer->write((const char*)data, size)))
{
- compileRequest->mSink.diagnose(
+ compileRequest->getSink()->diagnose(
SourceLoc(),
Diagnostics::cannotWriteOutputFile,
path);
@@ -991,18 +1196,18 @@ SlangResult dissassembleDXILUsingDXC(
}
static void writeOutputFile(
- CompileRequest* compileRequest,
- String const& path,
- void const* data,
- size_t size,
- OutputFileKind kind)
+ BackEndCompileRequest* compileRequest,
+ String const& path,
+ void const* data,
+ size_t size,
+ OutputFileKind kind)
{
FILE* file = fopen(
path.Buffer(),
kind == OutputFileKind::Binary ? "wb" : "w");
if (!file)
{
- compileRequest->mSink.diagnose(
+ compileRequest->getSink()->diagnose(
SourceLoc(),
Diagnostics::cannotWriteOutputFile,
path);
@@ -1014,11 +1219,12 @@ SlangResult dissassembleDXILUsingDXC(
}
static void writeEntryPointResultToFile(
- EntryPointRequest* entryPoint,
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
String const& outputPath,
CompileResult const& result)
{
- auto compileRequest = entryPoint->compileRequest;
+ SLANG_UNUSED(entryPoint);
switch (result.format)
{
@@ -1059,13 +1265,15 @@ SlangResult dissassembleDXILUsingDXC(
}
static void writeEntryPointResultToStandardOutput(
- EntryPointRequest* entryPoint,
+ EndToEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
TargetRequest* targetReq,
CompileResult const& result)
{
- auto compileRequest = entryPoint->compileRequest;
+ SLANG_UNUSED(entryPoint);
ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput);
+ auto backEndReq = compileRequest->getBackEndReq();
switch (result.format)
{
@@ -1087,7 +1295,7 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXBytecode:
{
String assembly;
- dissassembleDXBC(compileRequest,
+ dissassembleDXBC(backEndReq,
data.begin(),
data.end() - data.begin(), assembly);
writeOutputToConsole(writer, assembly);
@@ -1099,7 +1307,7 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::DXIL:
{
String assembly;
- dissassembleDXILUsingDXC(compileRequest,
+ dissassembleDXILUsingDXC(backEndReq,
data.begin(),
data.end() - data.begin(),
assembly);
@@ -1111,7 +1319,7 @@ SlangResult dissassembleDXILUsingDXC(
case CodeGenTarget::SPIRV:
{
String assembly;
- dissassembleSPIRV(compileRequest,
+ dissassembleSPIRV(backEndReq,
data.begin(),
data.end() - data.begin(), assembly);
writeOutputToConsole(writer, assembly);
@@ -1129,7 +1337,7 @@ SlangResult dissassembleDXILUsingDXC(
writer->setMode(SLANG_WRITER_MODE_BINARY);
writeOutputFile(
- compileRequest,
+ backEndReq,
writer,
"stdout",
data.begin(),
@@ -1146,89 +1354,108 @@ SlangResult dissassembleDXILUsingDXC(
}
static void writeEntryPointResult(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- UInt entryPointIndex)
+ EndToEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ TargetRequest* targetReq,
+ Int entryPointIndex)
{
- // It is possible that we are dynamically discovering entry
- // points (using `[shader(...)]` attributes), so that the
- // number of entry points on the compile request does not
- // match the number of entries in the `entryPointOutputPaths`
- // array.
- //
- String outputPath;
- if( entryPointIndex < targetReq->entryPointOutputPaths.Count() )
- {
- outputPath = targetReq->entryPointOutputPaths[entryPointIndex];
- }
+ auto program = compileRequest->getSpecializedProgram();
+ auto targetProgram = program->getTargetProgram(targetReq);
+ auto backEndReq = compileRequest->getBackEndReq();
- auto& result = targetReq->entryPointResults[entryPointIndex];
+ auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex);
// Skip the case with no output
if (result.format == ResultFormat::None)
return;
- if (outputPath.Length())
- {
- writeEntryPointResultToFile(entryPoint, outputPath, result);
- }
- else
+ // It is possible that we are dynamically discovering entry
+ // points (using `[shader(...)]` attributes), so that there
+ // might be entry points added to the program that did not
+ // get paths specified via command-line options.
+ //
+ RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo;
+ if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo))
{
- writeEntryPointResultToStandardOutput(entryPoint, targetReq, result);
+ String outputPath;
+ if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath))
+ {
+ writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result);
+ return;
+ }
}
+
+ writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result);
}
void generateOutputForTarget(
- TargetRequest* targetReq)
+ BackEndCompileRequest* compileReq,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq)
{
- CompileRequest* compileReq = targetReq->compileRequest;
+ auto program = compileReq->getProgram();
+ auto targetProgram = program->getTargetProgram(targetReq);
// Generate target code any entry points that
// have been requested for compilation.
- for (auto& entryPoint : compileReq->entryPoints)
+ auto entryPointCount = program->getEntryPointCount();
+ for(UInt ii = 0; ii < entryPointCount; ++ii)
{
- CompileResult entryPointResult = emitEntryPoint(entryPoint, targetReq);
- targetReq->entryPointResults.Add(entryPointResult);
+ auto entryPoint = program->getEntryPoint(ii);
+ CompileResult entryPointResult = emitEntryPoint(
+ compileReq,
+ entryPoint,
+ ii,
+ targetReq,
+ endToEndReq);
+ targetProgram->setEntryPointResult(ii, entryPointResult);
}
}
- void generateOutput(
- CompileRequest* compileRequest)
+ static void _generateOutput(
+ BackEndCompileRequest* compileRequest,
+ EndToEndCompileRequest* endToEndReq)
{
// Go through the code-generation targets that the user
// has specified, and generate code for each of them.
//
- for (auto targetReq : compileRequest->targets)
+ auto linkage = compileRequest->getLinkage();
+ for (auto targetReq : linkage->targets)
{
- generateOutputForTarget(targetReq);
+ generateOutputForTarget(compileRequest, targetReq, endToEndReq);
}
+ }
+
+ void generateOutput(
+ BackEndCompileRequest* compileRequest)
+ {
+ _generateOutput(compileRequest, nullptr);
+ }
+
+ void generateOutput(
+ EndToEndCompileRequest* compileRequest)
+ {
+ _generateOutput(compileRequest->getBackEndReq(), compileRequest);
// If we are in command-line mode, we might be expected to actually
// write output to one or more files here.
if (compileRequest->isCommandLineCompile)
{
- for (auto targetReq : compileRequest->targets)
+ auto linkage = compileRequest->getLinkage();
+ auto program = compileRequest->getSpecializedProgram();
+ for (auto targetReq : linkage->targets)
{
- UInt entryPointCount = compileRequest->entryPoints.Count();
+ UInt entryPointCount = program->getEntryPointCount();
for (UInt ee = 0; ee < entryPointCount; ++ee)
{
writeEntryPointResult(
- compileRequest->entryPoints[ee],
+ compileRequest,
+ program->getEntryPoint(ee),
targetReq,
ee);
}
}
-
- if (compileRequest->containerOutputPath.Length() != 0)
- {
- auto& data = compileRequest->generatedBytecode;
- writeOutputFile(compileRequest,
- compileRequest->containerOutputPath,
- data.begin(),
- data.end() - data.begin(),
- OutputFileKind::Binary);
- }
}
}
@@ -1237,7 +1464,7 @@ SlangResult dissassembleDXILUsingDXC(
//
void dumpIntermediate(
- CompileRequest*,
+ BackEndCompileRequest*,
void const* data,
size_t size,
char const* ext,
@@ -1271,7 +1498,7 @@ SlangResult dissassembleDXILUsingDXC(
}
void dumpIntermediateText(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
void const* data,
size_t size,
char const* ext)
@@ -1280,7 +1507,7 @@ SlangResult dissassembleDXILUsingDXC(
}
void dumpIntermediateBinary(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
void const* data,
size_t size,
char const* ext)
@@ -1289,7 +1516,7 @@ SlangResult dissassembleDXILUsingDXC(
}
void maybeDumpIntermediate(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
void const* data,
size_t size,
CodeGenTarget target)
@@ -1362,7 +1589,7 @@ SlangResult dissassembleDXILUsingDXC(
}
void maybeDumpIntermediate(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
char const* text,
CodeGenTarget target)
{
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 39199a62f..c975c1c2b 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -20,6 +20,8 @@ namespace Slang
class CompileRequest;
class ProgramLayout;
class PtrType;
+ class TargetProgram;
+ class TargetRequest;
class TypeLayout;
enum class CompilerMode
@@ -88,8 +90,12 @@ namespace Slang
kMatrixLayoutMode_ColumnMajor = SLANG_MATRIX_LAYOUT_COLUMN_MAJOR,
};
-
- class CompileRequest;
+ class Linkage;
+ class Module;
+ class Program;
+ class FrontEndCompileRequest;
+ class BackEndCompileRequest;
+ class EndToEndCompileRequest;
class TranslationUnitRequest;
// Result of compiling an entry point.
@@ -112,18 +118,165 @@ namespace Slang
ComPtr<ISlangBlob> blob;
};
- // Describes an entry point that we've been requested to compile
- class EntryPointRequest : public RefObject
+ /// A request for the front-end to find and validate an entry-point function
+ struct FrontEndEntryPointRequest : RefObject
{
public:
+ /// Create a request for an entry point.
+ FrontEndEntryPointRequest(
+ FrontEndCompileRequest* compileRequest,
+ int translationUnitIndex,
+ Name* name,
+ Profile profile);
+
+ /// Get the parent front-end compile request.
+ FrontEndCompileRequest* getCompileRequest() { return m_compileRequest; }
+
+ /// Get the translation unit that contains the entry point.
+ TranslationUnitRequest* getTranslationUnit();
+
+ /// Get the name of the entry point to find.
+ Name* getName() { return m_name; }
+
+ /// Get the stage that the entry point is to be compiled for
+ Stage getStage() { return m_profile.GetStage(); }
+
+ /// Get the profile that the entry point is to be compiled for
+ Profile getProfile() { return m_profile; }
+
+ private:
// The parent compile request
- CompileRequest* compileRequest = nullptr;
+ FrontEndCompileRequest* m_compileRequest;
+
+ // The index of the translation unit that will hold the entry point
+ int m_translationUnitIndex;
+
+ // The name of the entry point function to look for
+ Name* m_name;
+
+ // The profile to compile for (including stage)
+ Profile m_profile;
+ };
+
+ /// Tracks an ordered list of modules that something depends on.
+ struct ModuleDependencyList
+ {
+ public:
+ /// Get the list of modules that are depended on.
+ List<RefPtr<Module>> const& getModuleList() { return m_moduleList; }
+
+ /// Add a module and everything it depends on to the list.
+ void addDependency(Module* module);
+
+ /// Add a module to the list, but not the modules it depends on.
+ void addLeafDependency(Module* module);
+
+ private:
+ void _addDependency(Module* module);
+
+ List<RefPtr<Module>> m_moduleList;
+ HashSet<Module*> m_moduleSet;
+ };
+
+ /// Tracks an unordered list of filesystem paths that something depends on
+ struct FilePathDependencyList
+ {
+ public:
+ /// Get the list of paths that are depended on.
+ List<String> const& getFilePathList() { return m_filePathList; }
+
+ /// Add a path to the list, if it is not already present
+ void addDependency(String const& path);
+
+ /// Add all of the paths that `module` depends on to the list
+ void addDependency(Module* module);
+
+ private:
+
+ // TODO: We are using a `HashSet` here to deduplicate
+ // the paths so that we don't return the same path
+ // multiple times from `getFilePathList`, but because
+ // order isn't important, we could potentially do better
+ // in terms of memory (at some cost in performance) by
+ // just sorting the `m_filePathList` every once in
+ // a while and then deduplicating.
+
+ List<String> m_filePathList;
+ HashSet<String> m_filePathSet;
+ };
+
+ /// Describes an entry point for the purposes of layout and code generation.
+ ///
+ /// This class also tracks any generic arguments to the entry point,
+ /// in the case that it is a specialization of a generic entry point.
+ ///
+ /// There is also a provision for creating a "dummy" entry point for
+ /// the purposes of pass-through compilation modes. Only the
+ /// `getName()` and `getProfile()` methods should be expected to
+ /// return useful data on pass-through entry points.
+ ///
+ class EntryPoint : public RefObject
+ {
+ public:
+ /// Create an entry point that refers to the given function.
+ static RefPtr<EntryPoint> create(
+ DeclRef<FuncDecl> funcDeclRef,
+ Profile profile);
+
+ /// Get the function decl-ref, including any generic arguments.
+ DeclRef<FuncDecl> getFuncDeclRef() { return m_funcDeclRef; }
+
+ /// Get the function declaration (without generic arguments).
+ RefPtr<FuncDecl> getFuncDecl() { return m_funcDeclRef.getDecl(); }
+
+ /// Get the name of the entry point
+ Name* getName() { return m_name; }
+
+ /// Get the profile associated with the entry point
+ ///
+ /// Note: only the stage part of the profile is expected
+ /// to contain useful data, but certain legacy code paths
+ /// allow for "shader model" information to come via this path.
+ ///
+ Profile getProfile() { return m_profile; }
+
+ /// Get the stage that the entry point is for.
+ Stage getStage() { return m_profile.GetStage(); }
+
+ /// Get the module that contains the entry point.
+ Module* getModule();
+
+ /// Get the linkage that contains the module for this entry pooint.
+ Linkage* getLinkage();
+
+ /// Get a list of modules that this entry point depends on.
+ ///
+ /// This will include the module that defines the entry point (see `getModule()`),
+ /// but may also include modules that are required by its generic type arguments.
+ ///
+ List<RefPtr<Module>> getModuleDependencies() { return m_dependencyList.getModuleList(); }
+
+ /// Get a list of tagged-union types referenced by the entry point's generic parameters.
+ List<RefPtr<TaggedUnionType>> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
+
+ /// Create a dummy `EntryPoint` that is only usable for pass-through compilation.
+ static RefPtr<EntryPoint> createDummyForPassThrough(
+ Name* name,
+ Profile profile);
+
+ private:
+ EntryPoint(
+ Name* name,
+ Profile profile,
+ DeclRef<FuncDecl> funcDeclRef);
// The name of the entry point function (e.g., `main`)
- Name* name;
+ //
+ Name* m_name = nullptr;
- /// Source code for the generic arguments to use for the generic parameters of the entry point.
- List<String> genericArgStrings;
+ // The declaration of the entry-point function itself.
+ //
+ DeclRef<FuncDecl> m_funcDeclRef;
// The profile that the entry point will be compiled for
// (this is a combination of the target stage, and also
@@ -135,33 +288,13 @@ namespace Slang
// from the target, while the stage part is all that is
// intrinsic to the entry point.
//
- Profile profile;
+ Profile m_profile;
- // Get the stage that the entry point is being compiled for.
- Stage getStage() { return profile.GetStage(); }
+ // Any tagged union types that were referenced by the generic arguments of the entry point.
+ List<RefPtr<TaggedUnionType>> m_taggedUnionTypes;
- // The index of the translation unit (within the parent
- // compile request) that the entry point function is
- // supposed to be defined in.
- int translationUnitIndex;
-
- // The translation unit that this entry point came from
- TranslationUnitRequest* getTranslationUnit();
-
- // The declaration of the entry-point function itself.
- // This will be filled in as part of semantic analysis;
- // it should not be assumed to be available in cases
- // where any errors were diagnosed.
- //
- DeclRef<FuncDecl> funcDeclRef;
-
- DeclRef<FuncDecl> getFuncDeclRef();
- RefPtr<FuncDecl> getFuncDecl();
-
- RefPtr<Substitutions> globalGenericSubst;
-
- /// Any tagged union types that were referenced by the generic arguments of the entry point.
- List<RefPtr<TaggedUnionType>> taggedUnionTypes;
+ // Modules the entry point depends on.
+ ModuleDependencyList m_dependencyList;
};
enum class PassThroughMode : SlangPassThrough
@@ -174,13 +307,78 @@ namespace Slang
class SourceFile;
- // A single translation unit requested to be compiled.
- //
+ /// A module of code that has been compiled through the front-end
+ ///
+ /// A module comprises all the code from one translation unit (which
+ /// may span multiple Slang source files), and provides access
+ /// to both the AST and IR representations of that code.
+ ///
+ class Module : public RefObject
+ {
+ public:
+ /// Create a module (initially empty).
+ Module(Linkage* linkage);
+
+ /// Get the parent linkage of this module.
+ Linkage* getLinkage() { return m_linkage; }
+
+ /// Get the AST for the module (if it has been parsed)
+ ModuleDecl* getModuleDecl() { return m_moduleDecl; }
+
+ /// The the IR for the module (if it has been generated)
+ IRModule* getIRModule() { return m_irModule; }
+
+ /// Get the list of other modules this module depends on
+ List<RefPtr<Module>> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); }
+
+ /// Get the list of filesystem paths this module depends on
+ List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); }
+
+ /// Register a module that this module depends on
+ void addModuleDependency(Module* module);
+
+ /// Register a filesystem path that this module depends on
+ void addFilePathDependency(String const& path);
+
+ /// Set the AST for this module.
+ ///
+ /// This should only be called once, during creation of the module.
+ ///
+ void setModuleDecl(ModuleDecl* moduleDecl) { m_moduleDecl = moduleDecl; }
+
+ /// Set the IR for this module.
+ ///
+ /// This should only be called once, during creation of the module.
+ ///
+ void setIRModule(IRModule* irModule) { m_irModule = irModule; }
+
+ private:
+ // The parent linkage
+ Linkage* m_linkage = nullptr;
+
+ // The AST for the module
+ RefPtr<ModuleDecl> m_moduleDecl;
+
+ // The IR for the module
+ RefPtr<IRModule> m_irModule = nullptr;
+
+ // List of modules this module depends on
+ ModuleDependencyList m_moduleDependencyList;
+
+ // List of filesystem paths this module depends on
+ FilePathDependencyList m_filePathDependencyList;
+ };
+ typedef Module LoadedModule;
+
+ /// A request for the front-end to compile a translation unit.
class TranslationUnitRequest : public RefObject
{
public:
+ TranslationUnitRequest(
+ FrontEndCompileRequest* compileRequest);
+
// The parent compile request
- CompileRequest* compileRequest = nullptr;
+ FrontEndCompileRequest* compileRequest = nullptr;
// The language in which the source file(s)
// are assumed to be written
@@ -189,26 +387,30 @@ namespace Slang
// The source file(s) that will be compiled to form this translation unit
//
// Usually, for HLSL or GLSL there will be only one file.
- List<SourceFile*> sourceFiles;
+ List<SourceFile*> m_sourceFiles;
+
+ List<SourceFile*> const& getSourceFiles() { return m_sourceFiles; }
+ void addSourceFile(SourceFile* sourceFile);
// The entry points associated with this translation unit
- List<RefPtr<EntryPointRequest> > entryPoints;
+ List<RefPtr<EntryPoint>> entryPoints;
// Preprocessor definitions to use for this translation unit only
- // (whereas the ones on `CompileOptions` will be shared)
+ // (whereas the ones on `compileRequest` will be shared)
Dictionary<String, String> preprocessorDefinitions;
- // Compile flags for this translation unit
- SlangCompileFlags compileFlags = 0;
+ /// The name that will be used for the module this translation unit produces.
+ Name* moduleName = nullptr;
+
+ /// Result of compiling this translation unit (a module)
+ RefPtr<Module> module;
- // The parsed syntax for the translation unit
- RefPtr<ModuleDecl> SyntaxNode;
+ Module* getModule() { return module; }
+ RefPtr<ModuleDecl> getModuleDecl() { return module->getModuleDecl(); }
- // The IR-level code for this translation unit.
- // This will only be valid/non-null after semantic
- // checking and IR generation are complete, so it
- // is not safe to use this field without testing for NULL.
- RefPtr<IRModule> irModule;
+ Session* getSession();
+ NamePool* getNamePool();
+ SourceManager* getSourceManager();
};
enum class FloatingPointMode : SlangFloatingPointMode
@@ -232,33 +434,28 @@ namespace Slang
Binary = SLANG_WRITER_MODE_BINARY,
};
- // A request to generate output in some target format
+ /// A request to generate output in some target format.
class TargetRequest : public RefObject
{
public:
- CompileRequest* compileRequest;
+ Linkage* linkage;
CodeGenTarget target;
SlangTargetFlags targetFlags = 0;
Slang::Profile targetProfile = Slang::Profile();
FloatingPointMode floatingPointMode = FloatingPointMode::Default;
- // Requested output paths for each entry point.
- // An empty string indices no output desired for
- // the given entry point.
- List<String> entryPointOutputPaths;
+ Linkage* getLinkage() { return linkage; }
+ CodeGenTarget getTarget() { return target; }
+ Profile getTargetProfile() { return targetProfile; }
+ FloatingPointMode getFloatingPointMode() { return floatingPointMode; }
- // The resulting reflection layout information
- RefPtr<ProgramLayout> layout;
-
- // Generated compile results for each entry point
- // in the parent compile request (indexing matches
- // the order they are given in the compile request)
- List<CompileResult> entryPointResults;
+ Session* getSession();
+ MatrixLayoutMode getDefaultMatrixLayoutMode();
// TypeLayouts created on the fly by reflection API
Dictionary<Type*, RefPtr<TypeLayout>> typeLayouts;
- MatrixLayoutMode getDefaultMatrixLayoutMode();
+ Dictionary<Type*, RefPtr<TypeLayout>>& getTypeLayouts() { return typeLayouts; }
};
/// Are we generating code for a D3D API?
@@ -280,7 +477,8 @@ namespace Slang
// - If the entry point and target disagree on the profile family, always use the
// profile family and version from the target.
//
- Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target);
+ Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target);
+
// A directory to be searched when looking for files (e.g., `#include`)
struct SearchDirectory
@@ -294,110 +492,53 @@ namespace Slang
String path;
};
- // Represents a module that has been loaded through the front-end
- // (up through IR generation).
- //
- class LoadedModule : public RefObject
+ /// A list of directories to search for files (e.g., `#include`)
+ struct SearchDirectoryList
{
- public:
- // The AST for the module
- RefPtr<ModuleDecl> moduleDecl;
+ // A parent list that should also be searched
+ SearchDirectoryList* parent = nullptr;
- // The IR for the module
- RefPtr<IRModule> irModule = nullptr;
+ // Directories to be searched
+ List<SearchDirectory> searchDirectories;
};
- class Session;
-
-
/// Create a blob that will retain (a copy of) raw data.
///
ComPtr<ISlangBlob> createRawBlob(void const* data, size_t size);
- class CompileRequest : public RefObject
+ /// A context for loading and re-using code modules.
+ class Linkage : public RefObject
{
public:
- // Pointer to parent session
- Session* mSession;
+ /// Create an initially-empty linkage
+ Linkage(Session* session);
+
+ /// Get the parent session for this linkage
+ Session* getSession() { return m_session; }
// Information on the targets we are being asked to
// generate code for.
List<RefPtr<TargetRequest>> targets;
- // What container format are we being asked to generate?
- ContainerFormat containerFormat = ContainerFormat::None;
-
- // Path to output container to
- String containerOutputPath;
-
// Directories to search for `#include` files or `import`ed modules
- List<SearchDirectory> searchDirectories;
+ SearchDirectoryList searchDirectories;
+
+ SearchDirectoryList const& getSearchDirectories() { return searchDirectories; }
// Definitions to provide during preprocessing
Dictionary<String, String> preprocessorDefinitions;
- // Translation units we are being asked to compile
- List<RefPtr<TranslationUnitRequest> > translationUnits;
-
- // Entry points we've been asked to compile (each
- // associated with a translation unit).
- List<RefPtr<EntryPointRequest> > entryPoints;
-
- // Types constructed by reflection API
- Dictionary<String, RefPtr<Type>> types;
-
- /// The layout to use for matrices by default (row/column major)
- MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor;
- MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; }
-
- // Should we just pass the input to another compiler?
- PassThroughMode passThrough = PassThroughMode::None;
-
- // Compile flags to be shared by all translation units
- SlangCompileFlags compileFlags = 0;
-
- // Should we dump intermediate results along the way, for debugging?
- bool shouldDumpIntermediates = false;
-
- bool shouldDumpIR = false;
- bool shouldValidateIR = false;
- bool shouldSkipCodegen = false;
-
- // If true then generateIR will serialize out IR, and serialize back in again. Making
- // serialization a bottleneck or firewall between the front end and the backend
- bool useSerialIRBottleneck = false;
-
- // If true will serialize and de-serialize with debug information
- bool verifyDebugSerialization = false;
-
- // How should `#line` directives be emitted (if at all)?
- LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default;
- // Are we being driven by the command-line `slangc`, and should act accordingly?
- bool isCommandLineCompile = false;
// Source manager to help track files loaded
- SourceManager sourceManagerStorage;
- SourceManager* sourceManager;
+ SourceManager m_defaultSourceManager;
+ SourceManager* m_sourceManager = nullptr;
// Name pool for looking up names
NamePool namePool;
NamePool* getNamePool() { return &namePool; }
- // Output stuff
- DiagnosticSink mSink;
- String mDiagnosticOutput;
-
- /// A blob holding the diagnostic output
- ComPtr<ISlangBlob> diagnosticOutputBlob;
-
- // Files that compilation depended on
- List<String> mDependencyFilePaths;
-
- // Generated bytecode representation of all the code
- List<uint8_t> generatedBytecode;
-
// Modules that have been dynamically loaded via `import`
//
// This is a list of unique modules loaded, in the order they were encountered.
@@ -424,11 +565,7 @@ namespace Slang
/// or a wrapped impl that makes fileSystem operate as fileSystemExt
ComPtr<ISlangFileSystemExt> fileSystemExt;
- // For output
- ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF];
-
- void setWriter(WriterChannel chan, ISlangWriter* writer);
- ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; }
+ ISlangFileSystemExt* getFileSystemExt() { return fileSystemExt; }
/// Load a file into memory using the configured file system.
///
@@ -438,11 +575,177 @@ namespace Slang
///
SlangResult loadFile(String const& path, ISlangBlob** outBlob);
- CompileRequest(Session* session);
- RefPtr<Expr> parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope);
+ RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope);
+
+ /// Add a mew target amd return its index.
+ UInt addTarget(
+ CodeGenTarget target);
+
+ RefPtr<Module> loadModule(
+ Name* name,
+ const PathInfo& filePathInfo,
+ ISlangBlob* fileContentsBlob,
+ SourceLoc const& loc,
+ DiagnosticSink* sink);
+
+ void loadParsedModule(
+ RefPtr<TranslationUnitRequest> translationUnit,
+ Name* name,
+ PathInfo const& pathInfo);
+
+ /// Load a module of the given name.
+ Module* loadModule(String const& name);
+
+ RefPtr<Module> findOrImportModule(
+ Name* name,
+ SourceLoc const& loc,
+ DiagnosticSink* sink);
- Type* getTypeFromString(String typeStr);
+ SourceManager* getSourceManager()
+ {
+ return m_sourceManager;
+ }
+
+ /// Override the source manager for the linakge.
+ ///
+ /// This is only used to install a temporary override when
+ /// parsing stuff from strings (where we don't want to retain
+ /// full source files for the parsed result).
+ ///
+ /// TODO: We should remove the need for this hack.
+ ///
+ void setSourceManager(SourceManager* sourceManager)
+ {
+ m_sourceManager = sourceManager;
+ }
+
+ void setFileSystem(ISlangFileSystem* fileSystem);
+
+ /// The layout to use for matrices by default (row/column major)
+ MatrixLayoutMode defaultMatrixLayoutMode = kMatrixLayoutMode_ColumnMajor;
+ MatrixLayoutMode getDefaultMatrixLayoutMode() { return defaultMatrixLayoutMode; }
+
+ private:
+ Session* m_session = nullptr;
+
+ /// Tracks state of modules currently being loaded.
+ ///
+ /// This information is used to diagnose cases where
+ /// a user tries to recursively import the same module
+ /// (possibly along a transitive chain of `import`s).
+ ///
+ struct ModuleBeingImportedRAII
+ {
+ public:
+ ModuleBeingImportedRAII(
+ Linkage* linkage,
+ Module* module)
+ : linkage(linkage)
+ , module(module)
+ {
+ next = linkage->m_modulesBeingImported;
+ linkage->m_modulesBeingImported = this;
+ }
+
+ ~ModuleBeingImportedRAII()
+ {
+ linkage->m_modulesBeingImported = next;
+ }
+
+ Linkage* linkage;
+ Module* module;
+ ModuleBeingImportedRAII* next;
+ };
+
+ // Any modules currently being imported will be listed here
+ ModuleBeingImportedRAII* m_modulesBeingImported;
+
+ /// Is the given module in the middle of being imported?
+ bool isBeingImported(Module* module);
+ };
+
+ /// Shared functionality between front- and back-end compile requests.
+ ///
+ /// This is the base class for both `FrontEndCompileRequest` and
+ /// `BackEndCompileRequest`, and allows a small number of parts of
+ /// the compiler to be easily invocable from either front-end or
+ /// back-end work.
+ ///
+ class CompileRequestBase : public RefObject
+ {
+ // TODO: We really shouldn't need this type in the long run.
+ // The few places that rely on it should be refactored to just
+ // depend on the unerlying information (a linkage and a diagnostic
+ // sink) directly.
+ //
+ // The flags to control dumping and validation of IR should be
+ // moved to some kind of shared settings/options `struct` that
+ // both front-end and back-end requests can store.
+
+ public:
+ Session* getSession();
+ Linkage* getLinkage() { return m_linkage; }
+ DiagnosticSink* getSink() { return m_sink; }
+ SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); }
+ NamePool* getNamePool() { return getLinkage()->getNamePool(); }
+ ISlangFileSystemExt* getFileSystemExt() { return getLinkage()->getFileSystemExt(); }
+ SlangResult loadFile(String const& path, ISlangBlob** outBlob) { return getLinkage()->loadFile(path, outBlob); }
+
+ bool shouldDumpIR = false;
+ bool shouldValidateIR = false;
+
+ protected:
+ CompileRequestBase(
+ Linkage* linkage,
+ DiagnosticSink* sink);
+
+ private:
+ Linkage* m_linkage = nullptr;
+ DiagnosticSink* m_sink = nullptr;
+ };
+
+ /// A request to compile source code to an AST + IR.
+ class FrontEndCompileRequest : public CompileRequestBase
+ {
+ public:
+ FrontEndCompileRequest(
+ Linkage* linkage,
+ DiagnosticSink* sink);
+
+ int addEntryPoint(
+ int translationUnitIndex,
+ String const& name,
+ Profile entryPointProfile);
+
+ // Translation units we are being asked to compile
+ List<RefPtr<TranslationUnitRequest> > translationUnits;
+
+ RefPtr<TranslationUnitRequest> getTranslationUnit(UInt index) { return translationUnits[index]; }
+
+ // Compile flags to be shared by all translation units
+ SlangCompileFlags compileFlags = 0;
+
+ // If true then generateIR will serialize out IR, and serialize back in again. Making
+ // serialization a bottleneck or firewall between the front end and the backend
+ bool useSerialIRBottleneck = false;
+
+ // If true will serialize and de-serialize with debug information
+ bool verifyDebugSerialization = false;
+
+ List<RefPtr<FrontEndEntryPointRequest>> m_entryPointReqs;
+
+ List<RefPtr<FrontEndEntryPointRequest>> const& getEntryPointReqs() { return m_entryPointReqs; }
+ UInt getEntryPointReqCount() { return m_entryPointReqs.Count(); }
+ FrontEndEntryPointRequest* getEntryPointReq(UInt index) { return m_entryPointReqs[index]; }
+
+ // Directories to search for `#include` files or `import`ed modules
+ SearchDirectoryList searchDirectories;
+
+ SearchDirectoryList const& getSearchDirectories() { return searchDirectories; }
+
+ // Definitions to provide during preprocessing
+ Dictionary<String, String> preprocessorDefinitions;
void parseTranslationUnit(
TranslationUnitRequest* translationUnit);
@@ -454,9 +757,24 @@ namespace Slang
void generateIR();
SlangResult executeActionsInner();
- SlangResult executeActions();
- int addTranslationUnit(SourceLanguage language, String const& name);
+ /// Add a translation unit to be compiled.
+ ///
+ /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang`
+ /// @param moduleName The name that will be used for the module compile from the translation unit.
+ /// @return The zero-based index of the translation unit in this compile request.
+ int addTranslationUnit(SourceLanguage language, Name* moduleName);
+
+ /// Add a translation unit to be compiled.
+ ///
+ /// @param language The source language that the translation unit will use (e.g., `SourceLanguage::Slang`
+ /// @return The zero-based index of the translation unit in this compile request.
+ ///
+ /// The module name for the translation unit will be automatically generated.
+ /// If all translation units in a compile request use automatically generated
+ /// module names, then they are guaranteed not to conflict with one another.
+ ///
+ int addTranslationUnit(SourceLanguage language);
void addTranslationUnitSourceFile(
int translationUnitIndex,
@@ -476,63 +794,337 @@ namespace Slang
int translationUnitIndex,
String const& path);
- int addEntryPoint(
- int translationUnitIndex,
- String const& name,
- Profile profile,
- List<String> const & genericTypeNames);
+ Program* getProgram() { return m_program; }
- UInt addTarget(
- CodeGenTarget target);
+ private:
+ RefPtr<Program> m_program;
+ };
- RefPtr<ModuleDecl> loadModule(
- Name* name,
- const PathInfo& filePathInfo,
- ISlangBlob* fileContentsBlob,
- SourceLoc const& loc);
+ /// A collection of code modules and entry points that are intended to be used together.
+ ///
+ /// A `Program` establishes that certain pieces of code are intended
+ /// to be used togehter so that, e.g., layout can make sure to allocate
+ /// space for the global shader parameters in all referenced modules.
+ ///
+ class Program : public RefObject
+ {
+ public:
+ /// Create a new program, initially empty.
+ ///
+ /// All code loaded into the program must come
+ /// from the given `linkage`.
+ Program(
+ Linkage* linkage);
- void loadParsedModule(
- RefPtr<TranslationUnitRequest> const& translationUnit,
- Name* name,
- PathInfo const& pathInfo);
+ /// Get the linkage that this program uses.
+ Linkage* getLinkage() { return m_linkage; }
- RefPtr<ModuleDecl> findOrImportModule(
- Name* name,
- SourceLoc const& loc);
+ /// Get the number of entry points added to the program
+ UInt getEntryPointCount() { return m_entryPoints.Count(); }
- Decl* lookupGlobalDecl(Name* name);
+ /// Get the entry point at the given `index`.
+ RefPtr<EntryPoint> getEntryPoint(UInt index) { return m_entryPoints[index]; }
- SourceManager* getSourceManager()
+ /// Get the full ist of entry points on the program.
+ List<RefPtr<EntryPoint>> const& getEntryPoints() { return m_entryPoints; }
+
+ /// Get the substitution (if any) that represents how global generics are specialized.
+ RefPtr<Substitutions> getGlobalGenericSubstitution() { return m_globalGenericSubst; }
+
+ /// Get the full list of modules this program depends on
+ List<RefPtr<Module>> getModuleDependencies() { return m_moduleDependencyList.getModuleList(); }
+
+ /// Get the full list of filesystem paths this program depends on
+ List<String> getFilePathDependencies() { return m_filePathDependencyList.getFilePathList(); }
+
+ /// Get the target-specific version of this program for the given `target`.
+ ///
+ /// The `target` must be a target on the `Linkage` that was used to create this program.
+ TargetProgram* getTargetProgram(TargetRequest* target);
+
+ /// Add a module (and everything it depends on) to the list of references
+ void addReferencedModule(Module* module);
+
+ /// Add a module (but not the things it depends on) to the list of references
+ ///
+ /// This is a compatiblity hack for legacy compiler behavior.
+ void addReferencedLeafModule(Module* module);
+
+
+ /// Add an entry point to the program
+ ///
+ /// This also adds everything the entry point depends on to the list of references.
+ ///
+ void addEntryPoint(EntryPoint* entryPoint);
+
+ /// Set the global generic argument substitution to use.
+ void setGlobalGenericSubsitution(RefPtr<Substitutions> subst)
{
- return sourceManager;
+ m_globalGenericSubst = subst;
}
- void setSourceManager(SourceManager* sm)
+ /// Parse a type from a string, in the context of this program.
+ ///
+ /// Any names in the string will be resolved using the modules
+ /// referenced by the program.
+ ///
+ /// On an error, returns null and reports diagnostic messages
+ /// to the provided `sink`.
+ ///
+ Type* getTypeFromString(String typeStr, DiagnosticSink* sink);
+
+ /// Get the IR module that represents this program and its entry points.
+ ///
+ /// The IR module for a program tries to be minimal, and in the
+ /// common case will only include symbols with `[import]` declarations
+ /// for the entry point(s) of the program, and any types they
+ /// depend on.
+ ///
+ /// This IR module is intended to be linked against the IR modules
+ /// for all of the dependencies (see `getModuleDependencies()`) to
+ /// provide complete code.
+ ///
+ RefPtr<IRModule> getOrCreateIRModule(DiagnosticSink* sink);
+
+ private:
+ // The linakge this program is associated with.
+ //
+ // Note that a `Program` keeps its associated linkage alive,
+ // and not vice versa.
+ //
+ RefPtr<Linkage> m_linkage;
+
+ // Tracking data for the list of modules dependend on
+ ModuleDependencyList m_moduleDependencyList;
+
+ // Tracking data for the list of filesystem paths dependend on
+ FilePathDependencyList m_filePathDependencyList;
+
+ // Entry points that are part of the program.
+ List<RefPtr<EntryPoint> > m_entryPoints;
+
+ // Specializations for global generic parameters (if any)
+ RefPtr<Substitutions> m_globalGenericSubst;
+
+ // Generated IR for this program.
+ RefPtr<IRModule> m_irModule;
+
+ // Cache of target-specific programs for each target.
+ Dictionary<TargetRequest*, RefPtr<TargetProgram>> m_targetPrograms;
+
+ // Any types looked up dynamically using `getTypeFromString`
+ Dictionary<String, RefPtr<Type>> m_types;
+ };
+
+ /// A `Program` specialized for a particular `TargetRequest`
+ class TargetProgram : public RefObject
+ {
+ public:
+ TargetProgram(
+ Program* program,
+ TargetRequest* targetReq);
+
+ /// Get the underlying program
+ Program* getProgram() { return m_program; }
+
+ /// Get the underlying target
+ TargetRequest* getTargetReq() { return m_targetReq; }
+
+ /// Get the layout for the program on the target.
+ ///
+ /// If this is the first time the layout has been
+ /// requested, report any errors that arise during
+ /// layout to the given `sink`.
+ ///
+ ProgramLayout* getOrCreateLayout(DiagnosticSink* sink);
+
+ /// Get the layout for the program on the taarget.
+ ///
+ /// This routine assumes that `getOrCreateLayout`
+ /// has already been called previously.
+ ///
+ ProgramLayout* getExistingLayout()
{
- sourceManager = sm;
- mSink.sourceManager = sm;
+ SLANG_ASSERT(m_layout);
+ return m_layout;
}
- void setFileSystem(ISlangFileSystem* fileSystem);
+ /// Get the compiled code for an entry point on the target.
+ ///
+ /// This routine assumes code generation has already been
+ /// performed and called `setEntryPointResult`.
+ ///
+ CompileResult& getExistingEntryPointResult(Int entryPointIndex)
+ {
+ return m_entryPointResults[entryPointIndex];
+ }
+
+ // TODO: Need a lazy `getOrCreateEntryPointResult`
+
+ /// Set the compiled code for an entry point.
+ ///
+ /// Should only be called by code generation.
+ void setEntryPointResult(Int entryPointIndex, CompileResult const& result)
+ {
+ m_entryPointResults[entryPointIndex] = result;
+ }
+
+ private:
+ // The program being compiled or laid out
+ Program* m_program;
+
+ // The target that code/layout will be generated for
+ TargetRequest* m_targetReq;
+
+ // The computed layout, if it has been generated yet
+ RefPtr<ProgramLayout> m_layout;
+
+ // Generated compile results for each entry point
+ // in the parent `Program` (indexing matches
+ // the order they are given in the `Program`)
+ List<CompileResult> m_entryPointResults;
+ };
+
+ /// A request to generate code for a program
+ class BackEndCompileRequest : public CompileRequestBase
+ {
+ public:
+ BackEndCompileRequest(
+ Linkage* linkage,
+ DiagnosticSink* sink,
+ Program* program = nullptr);
+
+ // Should we dump intermediate results along the way, for debugging?
+ bool shouldDumpIntermediates = false;
+
+ // How should `#line` directives be emitted (if at all)?
+ LineDirectiveMode lineDirectiveMode = LineDirectiveMode::Default;
+
+ LineDirectiveMode getLineDirectiveMode() { return lineDirectiveMode; }
- /// During propagation of an exception for an internal
- /// error, note that this source location was involved
- void noteInternalErrorLoc(SourceLoc const& loc);
+ Program* getProgram() { return m_program; }
+ void setProgram(Program* program) { m_program = program; }
- int internalErrorLocsNoted = 0;
+ private:
+ RefPtr<Program> m_program;
+ };
+
+ /// A compile request that spans the front and back ends of the compiler
+ ///
+ /// This is what the command-line `slangc` uses, as well as the legacy
+ /// C API. It ties together the functionality of `Linkage`,
+ /// `FrontEndCompileRequest`, and `BackEndCompileRequest`, plus a small
+ /// number of additional features that primarily make sense for
+ /// command-line usage.
+ ///
+ class EndToEndCompileRequest : public RefObject
+ {
+ public:
+ EndToEndCompileRequest(
+ Session* session);
+
+ // What container format are we being asked to generate?
+ //
+ // Note: This field is unused except by the options-parsing
+ // logic; it exists to support wriiting out binary modules
+ // once that feature is ready.
+ //
+ ContainerFormat containerFormat = ContainerFormat::None;
+
+ // Path to output container to
+ //
+ // Note: This field exists to support wriiting out binary modules
+ // once that feature is ready.
+ //
+ String containerOutputPath;
+
+ // Should we just pass the input to another compiler?
+ PassThroughMode passThrough = PassThroughMode::None;
+
+ /// Source code for the generic arguments to use for the global generic parameters of the program.
+ List<String> globalGenericArgStrings;
+
+
+ bool shouldSkipCodegen = false;
+
+ // Are we being driven by the command-line `slangc`, and should act accordingly?
+ bool isCommandLineCompile = false;
+
+ String mDiagnosticOutput;
+
+ /// A blob holding the diagnostic output
+ ComPtr<ISlangBlob> diagnosticOutputBlob;
+
+ /// Per-entry-point information not tracked by other compile requests
+ class EntryPointInfo : public RefObject
+ {
+ public:
+ /// Source code for the generic arguments to use for the generic parameters of the entry point.
+ List<String> genericArgStrings;
+ };
+ List<EntryPointInfo> entryPoints;
+
+ /// Per-target information only needed for command-line compiles
+ class TargetInfo : public RefObject
+ {
+ public:
+ // Requested output paths for each entry point.
+ // An empty string indices no output desired for
+ // the given entry point.
+ Dictionary<Int, String> entryPointOutputPaths;
+ };
+ Dictionary<TargetRequest*, RefPtr<TargetInfo>> targetInfos;
+
+ Linkage* getLinkage() { return m_linkage; }
+
+ int addEntryPoint(
+ int translationUnitIndex,
+ String const& name,
+ Profile profile,
+ List<String> const & genericTypeNames);
+
+ void setWriter(WriterChannel chan, ISlangWriter* writer);
+ ISlangWriter* getWriter(WriterChannel chan) const { return m_writers[int(chan)]; }
+
+ SlangResult executeActionsInner();
+ SlangResult executeActions();
+
+ Session* getSession() { return m_session; }
+ DiagnosticSink* getSink() { return &m_sink; }
+ NamePool* getNamePool() { return getLinkage()->getNamePool(); }
+
+ FrontEndCompileRequest* getFrontEndReq() { return m_frontEndReq; }
+ BackEndCompileRequest* getBackEndReq() { return m_backEndReq; }
+ Program* getUnspecializedProgram() { return getFrontEndReq()->getProgram(); }
+ Program* getSpecializedProgram() { return m_specializedProgram; }
+
+ private:
+ Session* m_session = nullptr;
+ RefPtr<Linkage> m_linkage;
+ DiagnosticSink m_sink;
+ RefPtr<FrontEndCompileRequest> m_frontEndReq;
+ RefPtr<Program> m_unspecializedProgram;
+ RefPtr<Program> m_specializedProgram;
+ RefPtr<BackEndCompileRequest> m_backEndReq;
+
+ // For output
+ ComPtr<ISlangWriter> m_writers[SLANG_WRITER_CHANNEL_COUNT_OF];
};
void generateOutput(
- CompileRequest* compileRequest);
+ BackEndCompileRequest* compileRequest);
+
+ void generateOutput(
+ EndToEndCompileRequest* compileRequest);
// Helper to dump intermediate output when debugging
void maybeDumpIntermediate(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
void const* data,
size_t size,
CodeGenTarget target);
void maybeDumpIntermediate(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
char const* text,
CodeGenTarget target);
@@ -548,12 +1140,14 @@ namespace Slang
@param sink The diagnostic sink to report to */
void reportExternalCompileError(const char* compilerName, SlangResult res, const UnownedStringSlice& diagnostic, DiagnosticSink* sink);
- /* Given a translationUnitRequest determines a filename that is most suitable to identify the input.
- If the translation is a pass through will attempt to get the source file pathname. If the source is slang generated
- there is no equivalent name so will return 'slang-generated'
- @param translationUnitRequest The request to find an appropriate source path for
+ /* Determines a suitable filename to identify the input for a given entry point being compiled.
+ If the end-to-end compile is a pass-through case, will attempt to find the (unique) source file
+ pathname for the translation unit containing the entry point at `entryPointIndex.
+ If the compilation is not in a pass-through case, then always returns `"slang-generated"`.
+ @param endToEndReq The end-to-end compile request which might be using pass-through copmilation
+ @param entryPointIndex The index of the entry point to compute a filename for.
@return the appropriate source filename */
- String calcTranslationUnitSourcePath(TranslationUnitRequest* translationUnitRequest);
+ String calcSourcePathForEntryPoint(EndToEndCompileRequest* endToEndReq, UInt entryPointIndex);
struct TypeCheckingCache;
//
@@ -696,6 +1290,10 @@ namespace Slang
String const& path,
String const& source);
~Session();
+
+ private:
+ /// Linkage used for all built-in (stdlib) code.
+ RefPtr<Linkage> m_builtinLinkage;
};
}
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index 0480eb934..10dcefe19 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -240,6 +240,14 @@ SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase)
// that provides a scope for some number of declarations.
SYNTAX_CLASS(ModuleDecl, ContainerDecl)
FIELD(RefPtr<Scope>, scope)
+
+ // The API-level module that this declaration belong to.
+ //
+ // This field allows lookup of the `Module` based on a
+ // declaration nested under a `ModuleDecl` by following
+ // its chain of parents.
+ //
+ RAW(Module* module = nullptr;)
END_SYNTAX_CLASS()
SYNTAX_CLASS(ImportDecl, Decl)
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 51076bde1..2d0dd7fdd 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -336,7 +336,7 @@ DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entr
DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function")
DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'")
-DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type")
+DIAGNOSTIC(38005, Error, globalGenericArgumentNotAType, "argument for global generic parameter '$0' must be a type")
DIAGNOSTIC(38006, Warning, specifiedStageDoesntMatchAttribute, "entry point '$0' being compiled for the '$1' stage has a '[shader(...)]' attribute that specifies the '$2' stage")
DIAGNOSTIC(38007, Error, entryPointHasNoStage, "no stage specified for entry point '$0'; use either a '[shader(\"name\")]' function attribute or the '-stage <name>' command-line option to specify a stage")
@@ -356,6 +356,10 @@ DIAGNOSTIC(38024, Error, invalidDispatchThreadIDType, "parameter with SV_Dispatc
DIAGNOSTIC(-1, Note, noteWhenCompilingEntryPoint, "when compiling entry point '$0'")
+DIAGNOSTIC(38020, Error, mismatchGlobalGenericArguments, "expected $0 global generic arguments ($1 provided)")
+DIAGNOSTIC(38021, Error, globalTypeArgumentDoesNotConformToInterface, "type argument `$1` for global generic parameter `$0` does not conform to interface `$2`.")
+
+
DIAGNOSTIC(38200, Error, recursiveModuleImport, "module `$0` recursively imports itself")
DIAGNOSTIC(39999, Fatal, errorInImportedModule, "error in imported module, compilation ceased.")
diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h
index e3aba32e6..9efc5efc6 100644
--- a/source/slang/diagnostics.h
+++ b/source/slang/diagnostics.h
@@ -2,6 +2,7 @@
#define RASTER_RENDERER_COMPILE_ERROR_H
#include "../core/basic.h"
+#include "../core/slang-writer.h"
#include "source-loc.h"
#include "token.h"
@@ -153,7 +154,7 @@ namespace Slang
StringBuilder outputBuffer;
// List<Diagnostic> diagnostics;
int errorCount = 0;
-
+ int internalErrorLocsNoted = 0;
ISlangWriter* writer = nullptr;
Flags flags = 0;
@@ -217,6 +218,32 @@ namespace Slang
void diagnoseRaw(
Severity severity,
const UnownedStringSlice& message);
+
+ /// During propagation of an exception for an internal
+ /// error, note that this source location was involved
+ void noteInternalErrorLoc(SourceLoc const& loc);
+ };
+
+ /// An `ISlangWriter` that writes directly to a diagnostic sink.
+ class DiagnosticSinkWriter : public AppendBufferWriter
+ {
+ public:
+ typedef AppendBufferWriter Super;
+
+ DiagnosticSinkWriter(DiagnosticSink* sink)
+ : Super(WriterFlag::IsStatic)
+ , m_sink(sink)
+ {}
+
+ // ISlangWriter
+ SLANG_NO_THROW virtual SlangResult SLANG_MCALL write(const char* chars, size_t numChars) SLANG_OVERRIDE
+ {
+ m_sink->diagnoseRaw(Severity::Note, UnownedStringSlice(chars, chars+numChars));
+ return SLANG_OK;
+ }
+
+ private:
+ DiagnosticSink* m_sink = nullptr;
};
namespace Diagnostics
diff --git a/source/slang/dxc-support.cpp b/source/slang/dxc-support.cpp
index 7d2d6dc0c..603a59ea7 100644
--- a/source/slang/dxc-support.cpp
+++ b/source/slang/dxc-support.cpp
@@ -30,8 +30,11 @@ namespace Slang
{
String GetHLSLProfileName(Profile profile);
String emitHLSLForEntryPoint(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq);
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq);
static UnownedStringSlice _getSlice(IDxcBlob* blob)
{
@@ -46,19 +49,22 @@ namespace Slang
}
SlangResult emitDXILForEntryPointUsingDXC(
- EntryPointRequest* entryPoint,
- TargetRequest* targetReq,
- List<uint8_t>& outCode)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ Int entryPointIndex,
+ TargetRequest* targetReq,
+ EndToEndCompileRequest* endToEndReq,
+ List<uint8_t>& outCode)
{
- auto compileRequest = entryPoint->compileRequest;
- auto session = compileRequest->mSession;
+ auto session = compileRequest->getSession();
+ auto sink = compileRequest->getSink();
// First deal with all the rigamarole of loading
// the `dxcompiler` library, and creating the
// top-level COM objects that will be used to
// compile things.
- auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink);
+ auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink);
if (!dxcCreateInstance)
{
return SLANG_FAIL;
@@ -69,9 +75,7 @@ namespace Slang
{
// If can't load dxil - dxc will not be able to sign output
// Output a suitable warning to the user
- auto& sink = entryPoint->compileRequest->mSink;
-
- sink.diagnose(SourceLoc(), Diagnostics::dxilNotFound);
+ sink->diagnose(SourceLoc(), Diagnostics::dxilNotFound);
}
}
@@ -89,8 +93,13 @@ namespace Slang
// Now let's go ahead and generate HLSL for the entry
// point, since we'll need that to feed into dxc.
- auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq);
- maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL);
+ auto hlslCode = emitHLSLForEntryPoint(
+ compileRequest,
+ entryPoint,
+ entryPointIndex,
+ targetReq,
+ endToEndReq);
+ maybeDumpIntermediate(compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL);
// Wrap the
@@ -122,7 +131,7 @@ namespace Slang
break;
}
- switch( targetReq->floatingPointMode )
+ switch( targetReq->getFloatingPointMode() )
{
default:
break;
@@ -149,7 +158,7 @@ namespace Slang
//
args[argCount++] = L"-no-warnings";
- String entryPointName = getText(entryPoint->name);
+ String entryPointName = getText(entryPoint->getName());
OSString wideEntryPointName = entryPointName.ToWString();
auto profile = getEffectiveProfile(entryPoint, targetReq);
@@ -172,7 +181,7 @@ namespace Slang
args[argCount++] = L"-enable-16bit-types";
}
- const String sourcePath = calcTranslationUnitSourcePath(entryPoint->getTranslationUnit());
+ const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex);
ComPtr<IDxcOperationResult> dxcResult;
SLANG_RETURN_ON_FAIL(dxcCompiler->Compile(dxcSourceBlob,
@@ -208,7 +217,7 @@ namespace Slang
// into a string for safety.
//
- reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), &entryPoint->compileRequest->mSink);
+ reportExternalCompileError("dxc", resultCode, _getSlice(dxcErrorBlob), compileRequest->getSink());
return resultCode;
}
@@ -225,20 +234,21 @@ namespace Slang
}
SlangResult dissassembleDXILUsingDXC(
- CompileRequest* compileRequest,
- void const* data,
- size_t size,
- String& stringOut)
+ BackEndCompileRequest* compileRequest,
+ void const* data,
+ size_t size,
+ String& stringOut)
{
stringOut = String();
- auto session = compileRequest->mSession;
+ auto session = compileRequest->getSession();
+ auto sink = compileRequest->getSink();
// First deal with all the rigamarole of loading
// the `dxcompiler` library, and creating the
// top-level COM objects that will be used to
// compile things.
- auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, &compileRequest->mSink);
+ auto dxcCreateInstance = (DxcCreateInstanceProc)session->getSharedLibraryFunc(Session::SharedLibraryFuncType::Dxc_DxcCreateInstance, sink);
if (!dxcCreateInstance)
{
return SLANG_FAIL;
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index c0e5e4296..2603df11e 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -79,8 +79,10 @@ void requireGLSLVersionImpl(
// Shared state for an entire emit session
struct SharedEmitContext
{
+ BackEndCompileRequest* compileRequest = nullptr;
+
// The entry point we are being asked to compile
- EntryPointRequest* entryPoint;
+ EntryPoint* entryPoint;
// The layout for the entry point
EntryPointLayout* entryPointLayout;
@@ -153,7 +155,7 @@ struct SharedEmitContext
// to use for it when emitting code.
Dictionary<IRInst*, String> mapInstToName;
- DiagnosticSink* getSink() { return &entryPoint->compileRequest->mSink; }
+ DiagnosticSink* getSink() { return compileRequest->getSink(); }
Dictionary<IRInst*, UInt> mapIRValueToRayPayloadLocation;
Dictionary<IRInst*, UInt> mapIRValueToCallablePayloadLocation;
@@ -165,6 +167,10 @@ struct EmitContext
SharedEmitContext* shared;
DiagnosticSink* getSink() { return shared->getSink(); }
+
+ LineDirectiveMode getLineDirectiveMode() { return shared->compileRequest->getLineDirectiveMode(); }
+ SourceManager* getSourceManager() { return shared->compileRequest->getSourceManager(); }
+ void noteInternalErrorLoc(SourceLoc loc) { return getSink()->noteInternalErrorLoc(loc); }
};
//
@@ -334,11 +340,6 @@ struct EmitVisitor
: context(context)
{}
- Session* getSession()
- {
- return context->shared->entryPoint->compileRequest->mSession;
- }
-
// Low-level emit logic
void emitRawTextSpan(char const* textBegin, char const* textEnd)
@@ -556,7 +557,7 @@ struct EmitVisitor
bool shouldUseGLSLStyleLineDirective = false;
- auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode;
+ auto mode = context->getLineDirectiveMode();
switch (mode)
{
case LineDirectiveMode::None:
@@ -664,7 +665,7 @@ struct EmitVisitor
{
// Don't do any of this work if the user has requested that we
// not emit line directives.
- auto mode = context->shared->entryPoint->compileRequest->lineDirectiveMode;
+ auto mode = context->getLineDirectiveMode();
switch(mode)
{
case LineDirectiveMode::None:
@@ -723,7 +724,7 @@ struct EmitVisitor
SourceManager* getSourceManager()
{
- return context->shared->entryPoint->compileRequest->getSourceManager();
+ return context->getSourceManager();
}
void advanceToSourceLocation(
@@ -747,7 +748,7 @@ struct EmitVisitor
DiagnosticSink* getSink()
{
- return &context->shared->entryPoint->compileRequest->mSink;
+ return context->getSink();
}
//
@@ -1875,8 +1876,7 @@ struct EmitVisitor
}
}
- void emitGLSLVersionDirective(
- ModuleDecl* /*program*/)
+ void emitGLSLVersionDirective()
{
auto effectiveProfile = context->shared->effectiveProfile;
if(effectiveProfile.getFamily() == ProfileFamily::GLSL)
@@ -1931,8 +1931,7 @@ struct EmitVisitor
Emit("#version 420\n");
}
- void emitGLSLPreprocessorDirectives(
- RefPtr<ModuleDecl> program)
+ void emitGLSLPreprocessorDirectives()
{
switch(context->shared->target)
{
@@ -1944,24 +1943,7 @@ struct EmitVisitor
break;
}
- emitGLSLVersionDirective(program);
-
-
- // TODO: when cross-compiling we may need to output additional `#extension` directives
- // based on the features that we have used.
-
- for( auto extensionDirective : program->GetModifiersOfType<GLSLExtensionDirective>() )
- {
- // TODO(tfoley): Emit an appropriate `#line` directive...
-
- Emit("#extension ");
- emit(extensionDirective->extensionNameToken.Content);
- Emit(" : ");
- emit(extensionDirective->dispositionToken.Content);
- Emit("\n");
- }
-
- // TODO: handle other cases...
+ emitGLSLVersionDirective();
}
/// Emit directives to control overall layout computation for the emitted code.
@@ -4115,7 +4097,7 @@ struct EmitVisitor
catch(AbortCompilationException&) { throw; }
catch(...)
{
- ctx->shared->entryPoint->compileRequest->noteInternalErrorLoc(inst->sourceLoc);
+ ctx->noteInternalErrorLoc(inst->sourceLoc);
throw;
}
}
@@ -6583,26 +6565,26 @@ struct EmitVisitor
//
EntryPointLayout* findEntryPointLayout(
- ProgramLayout* programLayout,
- EntryPointRequest* entryPointRequest)
+ ProgramLayout* programLayout,
+ EntryPoint* entryPoint)
{
for( auto entryPointLayout : programLayout->entryPoints )
{
- if(entryPointLayout->entryPoint->getName() != entryPointRequest->name)
+ if(entryPointLayout->entryPoint->getName() != entryPoint->getName())
continue;
// TODO: We need to be careful about this check, since it relies on
// the profile information in the layout matching that in the request.
//
// What we really seem to want here is some dictionary mapping the
- // `EntryPointRequest` directly to the `EntryPointLayout`, and maybe
+ // `EntryPoint` directly to the `EntryPointLayout`, and maybe
// that is precisely what we should build...
//
- if(entryPointLayout->profile != entryPointRequest->profile)
+ if(entryPointLayout->profile != entryPoint->getProfile())
continue;
// TODO: can't easily filter on translation unit here...
- // Ideally the `EntryPointRequest` should get filled in with a pointer
+ // Ideally the `EntryPoint` should get filled in with a pointer
// the specific function declaration that represents the entry point.
return entryPointLayout.Ptr();
@@ -6662,13 +6644,14 @@ void legalizeTypes(
IRModule* module);
static void dumpIRIfEnabled(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
IRModule* irModule,
char const* label = nullptr)
{
if(compileRequest->shouldDumpIR)
{
- WriterHelper writer(compileRequest->getWriter(WriterChannel::StdError));
+ DiagnosticSinkWriter writerImpl(compileRequest->getSink());
+ WriterHelper writer(&writerImpl);
if(label)
{
@@ -6687,16 +6670,22 @@ static void dumpIRIfEnabled(
}
String emitEntryPoint(
- EntryPointRequest* entryPoint,
- ProgramLayout* programLayout,
- CodeGenTarget target,
- TargetRequest* targetRequest)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ CodeGenTarget target,
+ TargetRequest* targetRequest)
{
- auto translationUnit = entryPoint->getTranslationUnit();
+ auto sink = compileRequest->getSink();
+ auto program = compileRequest->getProgram();
+ auto targetProgram = program->getTargetProgram(targetRequest);
+ auto programLayout = targetProgram->getOrCreateLayout(sink);
+
+// auto translationUnit = entryPoint->getTranslationUnit();
SharedEmitContext sharedContext;
+ sharedContext.compileRequest = compileRequest;
sharedContext.target = target;
- sharedContext.finalTarget = targetRequest->target;
+ sharedContext.finalTarget = targetRequest->getTarget();
sharedContext.entryPoint = entryPoint;
sharedContext.effectiveProfile = getEffectiveProfile(entryPoint, targetRequest);
@@ -6715,16 +6704,13 @@ String emitEntryPoint(
StructTypeLayout* globalStructLayout = getGlobalStructLayout(programLayout);
sharedContext.globalStructLayout = globalStructLayout;
- auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr();
-
EmitContext context;
context.shared = &sharedContext;
EmitVisitor visitor(&context);
{
- auto compileRequest = translationUnit->compileRequest;
- auto session = compileRequest->mSession;
+ auto session = targetRequest->getSession();
// We start out by performing "linking" at the level of the IR.
// This step will create a fresh IR module to be used for
@@ -6735,6 +6721,7 @@ String emitEntryPoint(
// any "profile-overloaded" symbols.
//
auto linkedIR = linkIR(
+ compileRequest,
entryPoint,
programLayout,
target,
@@ -6880,7 +6867,7 @@ String emitEntryPoint(
session,
irModule,
irEntryPoint,
- &compileRequest->mSink,
+ compileRequest->getSink(),
&sharedContext.extensionUsageTracker);
}
break;
@@ -6916,10 +6903,6 @@ String emitEntryPoint(
// TODO: do we want to emit directly from IR, or translate the
// IR back into AST for emission?
visitor.emitIRModule(&context, irModule);
-
- // retain the specialized ir module, because the current
- // GlobalGenericParamSubstitution implementation may reference ir objects
- targetRequest->compileRequest->compiledModules.Add(irModule);
}
// Deal with cases where a particular stage requires certain GLSL versions
@@ -6950,7 +6933,7 @@ String emitEntryPoint(
// it is time to stich together the final output.
// There may be global-scope modifiers that we should emit now
- visitor.emitGLSLPreprocessorDirectives(translationUnitSyntax);
+ visitor.emitGLSLPreprocessorDirectives();
visitor.emitLayoutDirectives(targetRequest);
diff --git a/source/slang/emit.h b/source/slang/emit.h
index 98845f9c6..317afcf6b 100644
--- a/source/slang/emit.h
+++ b/source/slang/emit.h
@@ -8,7 +8,7 @@
namespace Slang
{
- class EntryPointRequest;
+ class EntryPoint;
class ProgramLayout;
class TranslationUnitRequest;
@@ -20,13 +20,13 @@ namespace Slang
// Emit code for a single entry point, based on
// the input translation unit.
String emitEntryPoint(
- EntryPointRequest* entryPoint,
- ProgramLayout* programLayout,
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
// The target language to generate code in (e.g., HLSL/GLSL)
- CodeGenTarget target,
+ CodeGenTarget target,
// The full target request
- TargetRequest* targetRequest);
+ TargetRequest* targetRequest);
}
#endif
diff --git a/source/slang/ir-dce.cpp b/source/slang/ir-dce.cpp
index 0f037bfe5..ba6d7adb9 100644
--- a/source/slang/ir-dce.cpp
+++ b/source/slang/ir-dce.cpp
@@ -16,8 +16,8 @@ struct DeadCodeEliminationContext
// the parameters that were passed to the top-level
// `eliminateDeadCode` function.
//
- CompileRequest* compileRequest;
- IRModule* module;
+ BackEndCompileRequest* compileRequest;
+ IRModule* module;
// Our overall process is going to be to determine
// which instructions in the module are "live"
@@ -235,9 +235,9 @@ struct DeadCodeEliminationContext
// we'll just go ahead and eliminate every single function/type
// in a module. There needs to be a way to identify the
// functions we want to keep around, and for right now
- // that is handled with the `[entryPoint]` decoration.
+ // that is handled with the `[keepAlive]` decoration.
//
- if(inst->findDecorationImpl(kIROp_EntryPointDecoration))
+ if(inst->findDecorationImpl(kIROp_KeepAliveDecoration))
return true;
//
// TODO: Eventually it would make sense to consider everything
@@ -312,7 +312,7 @@ struct DeadCodeEliminationContext
// and then defer to it for the real work.
//
void eliminateDeadCode(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
IRModule* module)
{
DeadCodeEliminationContext context;
diff --git a/source/slang/ir-dce.h b/source/slang/ir-dce.h
index fd56616d9..6089b404a 100644
--- a/source/slang/ir-dce.h
+++ b/source/slang/ir-dce.h
@@ -3,7 +3,7 @@
namespace Slang
{
- class CompileRequest;
+ class BackEndCompileRequest;
struct IRModule;
/// Eliminate "dead" code from the given IR module.
@@ -14,6 +14,6 @@ namespace Slang
/// etc.
///
void eliminateDeadCode(
- CompileRequest* compileRequest,
- IRModule* module);
+ BackEndCompileRequest* compileRequest,
+ IRModule* module);
}
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index b6f8ce547..eada52e4d 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -392,6 +392,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// even if it does not otherwise reference it.
INST(DependsOnDecoration, dependsOn, 1, 0)
+ /// A `[keepAlive]` decoration marks an instruction that should not be eliminated.
+ INST(KeepAliveDecoration, keepAlive, 0, 0)
+
+
/* LinkageDecoration */
INST(ImportDecoration, import, 1, 0)
INST(ExportDecoration, export, 1, 0)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 6b12612ef..b7c1b2744 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -1201,6 +1201,11 @@ struct IRBuilder
addDecoration(value, kIROp_EntryPointDecoration);
}
+ void addKeepAliveDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_KeepAliveDecoration);
+ }
+
/// Add a decoration that indicates that the given `inst` depends on the given `dependency`.
///
/// This decoration can be used to ensure that a value that an instruction
diff --git a/source/slang/ir-link.cpp b/source/slang/ir-link.cpp
index 35e0f46b8..2eef10614 100644
--- a/source/slang/ir-link.cpp
+++ b/source/slang/ir-link.cpp
@@ -14,7 +14,7 @@ namespace Slang
// instead of the input/request layer.
EntryPointLayout* findEntryPointLayout(
ProgramLayout* programLayout,
- EntryPointRequest* entryPointRequest);
+ EntryPoint* EntryPoint);
struct IRSpecSymbol : RefObject
{
@@ -39,9 +39,6 @@ struct IRSharedSpecContext
// The specialized module we are building
RefPtr<IRModule> module;
- // The original, unspecialized module we are copying
- IRModule* originalModule;
-
// A map from mangled symbol names to zero or
// more global IR values that have that name,
// in the *original* module.
@@ -67,8 +64,6 @@ struct IRSpecContextBase
IRModule* getModule() { return getShared()->module; }
- IRModule* getOriginalModule() { return getShared()->originalModule; }
-
IRSharedSpecContext::SymbolDictionary& getSymbols() { return getShared()->symbols; }
// The current specialization environment to use.
@@ -668,8 +663,8 @@ IRInst* specializeGeneric(
IRSpecialize* specializeInst);
IRFunc* specializeIRForEntryPoint(
- IRSpecContext* context,
- EntryPointRequest* entryPointRequest,
+ IRSpecContext* context,
+ EntryPoint* entryPoint,
EntryPointLayout* entryPointLayout)
{
// We start by looking up the IR symbol that
@@ -681,7 +676,7 @@ IRFunc* specializeIRForEntryPoint(
// so that the mangled name of the decl-ref is
// not the same as the mangled name of the decl.
//
- auto mangledName = getMangledName(entryPointRequest->getFuncDeclRef());
+ auto mangledName = getMangledName(entryPoint->getFuncDeclRef());
RefPtr<IRSpecSymbol> sym;
if (!context->getSymbols().TryGetValue(mangledName, sym))
{
@@ -743,9 +738,9 @@ IRFunc* specializeIRForEntryPoint(
return nullptr;
}
- if( !clonedFunc->findDecorationImpl(kIROp_EntryPointDecoration) )
+ if( !clonedFunc->findDecorationImpl(kIROp_KeepAliveDecoration) )
{
- context->builder->addEntryPointDecoration(clonedFunc);
+ context->builder->addKeepAliveDecoration(clonedFunc);
}
// We need to attach the layout information for
@@ -1148,7 +1143,6 @@ void initializeSharedSpecContext(
IRSharedSpecContext* sharedContext,
Session* session,
IRModule* module,
- IRModule* originalModule,
CodeGenTarget target)
{
@@ -1166,19 +1160,15 @@ void initializeSharedSpecContext(
sharedBuilder->module = module;
sharedContext->module = module;
- sharedContext->originalModule = originalModule;
sharedContext->target = target;
- // We will populate a map with all of the IR values
- // that use the same mangled name, to make lookup easier
- // in other steps.
- insertGlobalValueSymbols(sharedContext, originalModule);
}
// implementation provided in parameter-binding.cpp
RefPtr<ProgramLayout> specializeProgramLayout(
TargetRequest * targetReq,
- ProgramLayout* programLayout,
- SubstitutionSet typeSubst);
+ ProgramLayout* programLayout,
+ SubstitutionSet typeSubst,
+ DiagnosticSink* sink);
struct IRSpecializationState
{
@@ -1211,11 +1201,14 @@ struct IRSpecializationState
};
LinkedIR linkIR(
- EntryPointRequest* entryPointRequest,
- ProgramLayout* programLayout,
- CodeGenTarget target,
- TargetRequest* targetReq)
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ ProgramLayout* programLayout,
+ CodeGenTarget target,
+ TargetRequest* targetReq)
{
+ auto sink = compileRequest->getSink();
+
IRSpecializationState stateStorage;
auto state = &stateStorage;
@@ -1223,26 +1216,27 @@ LinkedIR linkIR(
state->target = target;
state->targetReq = targetReq;
-
- auto compileRequest = entryPointRequest->compileRequest;
- auto translationUnit = entryPointRequest->getTranslationUnit();
- auto originalIRModule = translationUnit->irModule;
+ auto program = compileRequest->getProgram();
auto sharedContext = state->getSharedContext();
initializeSharedSpecContext(
sharedContext,
- compileRequest->mSession,
+ compileRequest->getSession(),
nullptr,
- originalIRModule,
target);
state->irModule = sharedContext->module;
- // We also need to attach the IR definitions for symbols from
- // any loaded modules:
- for (auto loadedModule : compileRequest->loadedModulesList)
+ // We need to be able to look up IR definitions for any symbols in
+ // modules that the program depends on (transitively). To
+ // accelerate lookup, we will create a symbol table for looking
+ // up IR definitions by their mangled name.
+ //
+ auto originalProgramIRModule = program->getOrCreateIRModule(sink);
+ insertGlobalValueSymbols(sharedContext, originalProgramIRModule);
+ for (auto module : program->getModuleDependencies())
{
- insertGlobalValueSymbols(sharedContext, loadedModule->irModule);
+ insertGlobalValueSymbols(sharedContext, module->getIRModule());
}
auto context = state->getContext();
@@ -1257,7 +1251,8 @@ LinkedIR linkIR(
RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(
targetReq,
programLayout,
- SubstitutionSet(entryPointRequest->globalGenericSubst));
+ SubstitutionSet(program->getGlobalGenericSubstitution()),
+ compileRequest->getSink());
// TODO: we need to register the (IR-level) arguments of the global generic parameters as the
// substitutions for the generic parameters in the original IR.
@@ -1267,13 +1262,22 @@ LinkedIR linkIR(
state->newProgramLayout = newProgramLayout;
- // Next, we want to optimize lookup for layout infromation
+ // Next, we want to optimize lookup for layout information
// associated with global declarations, so that we can
// look things up based on the IR values (using mangled names)
+ //
+ // Note: We are scanning over all the key-value pairs for
+ // entries in the global scope, to account for the fact
+ // that the "same" shader parameter could be declared in
+ // multiple translation units, and thus end up with
+ // multiple mangled names (when the unique translation
+ // unit name gets involved).
+ //
auto globalStructLayout = getScopeStructLayout(newProgramLayout);
- for (auto globalVarLayout : globalStructLayout->fields)
+ for(auto entry : globalStructLayout->mapVarToLayout)
{
- auto mangledName = getMangledName(globalVarLayout->varDecl);
+ auto mangledName = getMangledName(entry.Key);
+ auto globalVarLayout = entry.Value;
context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout);
}
@@ -1290,19 +1294,19 @@ LinkedIR linkIR(
cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue);
}
- auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest);
+ auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPoint);
// Next, we make sure to clone the global value for
// the entry point function itself, and rely on
// this step to recursively copy over anything else
// it might reference.
- auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout);
+ auto irEntryPoint = specializeIRForEntryPoint(context, entryPoint, entryPointLayout);
// HACK: right now the bindings for global generic parameters are coming in
// as part of the original IR module, and we need to make sure these get
// copied over, even if they aren't referenced.
//
- for(auto inst : originalIRModule->getGlobalInsts())
+ for(auto inst : originalProgramIRModule->getGlobalInsts())
{
auto bindInst = as<IRBindGlobalGenericParam>(inst);
if(!bindInst)
diff --git a/source/slang/ir-link.h b/source/slang/ir-link.h
index 4fcdb4618..dba3ccc97 100644
--- a/source/slang/ir-link.h
+++ b/source/slang/ir-link.h
@@ -19,8 +19,9 @@ namespace Slang
// used.
//
LinkedIR linkIR(
- EntryPointRequest* entryPointRequest,
- ProgramLayout* programLayout,
- CodeGenTarget target,
- TargetRequest* targetReq);
+ BackEndCompileRequest* compileRequest,
+ EntryPoint* entryPoint,
+ ProgramLayout* programLayout,
+ CodeGenTarget target,
+ TargetRequest* targetReq);
}
diff --git a/source/slang/ir-specialize-resources.cpp b/source/slang/ir-specialize-resources.cpp
index 0108a91f8..e974ffdb7 100644
--- a/source/slang/ir-specialize-resources.cpp
+++ b/source/slang/ir-specialize-resources.cpp
@@ -18,7 +18,7 @@ struct ResourceParameterSpecializationContext
// the parameters that were passed to the top-level
// `specializeResourceParameters` function.
//
- CompileRequest* compileRequest;
+ BackEndCompileRequest* compileRequest;
TargetRequest* targetRequest;
IRModule* module;
@@ -372,7 +372,7 @@ struct ResourceParameterSpecializationContext
// If we didn't find a pre-existing specialized
// function, then we will go ahead and create one.
//
- // We start by gathering the infromation from the call
+ // We start by gathering the information from the call
// site that is relevant to generating a specialized
// callee function, which we avoided doing earlier
// because it might have been throwaway work.
@@ -850,7 +850,7 @@ struct ResourceParameterSpecializationContext
// and then defer to it for the real work.
//
void specializeResourceParameters(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module)
{
diff --git a/source/slang/ir-specialize-resources.h b/source/slang/ir-specialize-resources.h
index 3d6ead130..0e636318c 100644
--- a/source/slang/ir-specialize-resources.h
+++ b/source/slang/ir-specialize-resources.h
@@ -3,7 +3,7 @@
namespace Slang
{
- class CompileRequest;
+ class BackEndCompileRequest;
class TargetRequest;
struct IRModule;
@@ -18,7 +18,7 @@ namespace Slang
/// global shader parameters directly).
///
void specializeResourceParameters(
- CompileRequest* compileRequest,
+ BackEndCompileRequest* compileRequest,
TargetRequest* targetRequest,
IRModule* module);
}
diff --git a/source/slang/ir-validate.cpp b/source/slang/ir-validate.cpp
index 924ec71b3..9564873b1 100644
--- a/source/slang/ir-validate.cpp
+++ b/source/slang/ir-validate.cpp
@@ -195,13 +195,13 @@ namespace Slang
}
void validateIRModuleIfEnabled(
- CompileRequest* compileRequest,
- IRModule* module)
+ CompileRequestBase* compileRequest,
+ IRModule* module)
{
if (!compileRequest->shouldValidateIR)
return;
- auto sink = &compileRequest->mSink;
+ auto sink = compileRequest->getSink();
validateIRModule(module, sink);
}
}
diff --git a/source/slang/ir-validate.h b/source/slang/ir-validate.h
index 0ebc69019..1cb30961d 100644
--- a/source/slang/ir-validate.h
+++ b/source/slang/ir-validate.h
@@ -3,7 +3,7 @@
namespace Slang
{
- class CompileRequest;
+ class CompileRequestBase;
class DiagnosticSink;
struct IRModule;
@@ -30,6 +30,6 @@ namespace Slang
// A wrapper that calls `validateIRModule` only when IR validation is enabled
// for the given compile request.
void validateIRModuleIfEnabled(
- CompileRequest* compileRequest,
- IRModule* module);
+ CompileRequestBase* compileRequest,
+ IRModule* module);
}
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 0d9427b08..b53bb8ebb 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -300,8 +300,18 @@ struct IRGenEnv
struct SharedIRGenContext
{
- CompileRequest* compileRequest;
- ModuleDecl* mainModuleDecl;
+ SharedIRGenContext(
+ Session* session,
+ DiagnosticSink* sink,
+ ModuleDecl* mainModuleDecl = nullptr)
+ : m_session(session)
+ , m_sink(sink)
+ , m_mainModuleDecl(mainModuleDecl)
+ {}
+
+ Session* m_session = nullptr;
+ DiagnosticSink* m_sink = nullptr;
+ ModuleDecl* m_mainModuleDecl = nullptr;
// The "global" environment for mapping declarations to their IR values.
IRGenEnv globalEnv;
@@ -356,17 +366,17 @@ struct IRGenContext
Session* getSession()
{
- return shared->compileRequest->mSession;
+ return shared->m_session;
}
- CompileRequest* getCompileRequest()
+ DiagnosticSink* getSink()
{
- return shared->compileRequest;
+ return shared->m_sink;
}
- DiagnosticSink* getSink()
+ ModuleDecl* getMainModuleDecl()
{
- return &getCompileRequest()->mSink;
+ return shared->m_mainModuleDecl;
}
};
@@ -422,7 +432,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl)
if (isFromStdLib(decl))
return false;
- if (moduleDecl != context->shared->mainModuleDecl)
+ if (moduleDecl != context->getMainModuleDecl())
return true;
return false;
@@ -1735,16 +1745,31 @@ static String getNameForNameHint(
if(auto genericParentDecl = as<GenericDecl>(parentDecl))
parentDecl = genericParentDecl->ParentDecl;
+ // A `ModuleDecl` can have a name too, but in the common case
+ // we don't want to generate name hints that include the module
+ // name, simply because they would lead to every global symbol
+ // getting a much longer name.
+ //
+ // TODO: We should probably include the module name for symbols
+ // being `import`ed, and not for symbols being compiled directly
+ // (those coming from a module that had no name given to it).
+ //
+ // For now we skip past a `ModuleDecl` parent.
+ //
+ if(auto moduleParentDecl = as<ModuleDecl>(parentDecl))
+ parentDecl = moduleParentDecl->ParentDecl;
+
+ if(!parentDecl)
+ {
+ return leafName->text;
+ }
+
auto parentName = getNameForNameHint(context, parentDecl);
if(parentName.Length() == 0)
{
return leafName->text;
}
- // TODO: at some point we will start giving `ModuleDecl`s names,
- // and in that case we need to think carefully about whether to
- // include their names here or not.
-
// We will now construct a new `Name` to use as the hint,
// combining the name of the parent and the leaf declaration.
@@ -3603,7 +3628,7 @@ void lowerStmt(
catch(AbortCompilationException&) { throw; }
catch(...)
{
- context->getCompileRequest()->noteInternalErrorLoc(stmt->loc);
+ context->getSink()->noteInternalErrorLoc(stmt->loc);
throw;
}
}
@@ -5877,7 +5902,7 @@ LoweredValInfo lowerDecl(
catch(AbortCompilationException&) { throw; }
catch(...)
{
- context->getCompileRequest()->noteInternalErrorLoc(decl->loc);
+ context->getSink()->noteInternalErrorLoc(decl->loc);
throw;
}
}
@@ -6108,56 +6133,59 @@ LoweredValInfo emitDeclRef(
type);
}
-static void lowerEntryPointToIR(
- IRGenContext* context,
- EntryPointRequest* entryPointRequest)
+static void lowerFrontEndEntryPointToIR(
+ IRGenContext* context,
+ EntryPoint* entryPoint)
{
- // First, lower the entry point like an ordinary function
+ // TODO: We should emit an entry point as a dedicated IR function
+ // (distinct from the IR function used if it were called normally),
+ // with a mangled name based on the original function name plus
+ // the stage for which it is being compiled as an entry point (so
+ // that entry points for distinct stages always have distinct names).
+ //
+ // For now we just have an (implicit) constraint that a given
+ // function should only be used as an entry point for one stage,
+ // and any such function should *not* be used as an ordinary function.
- auto session = context->getSession();
- auto entryPointFuncDeclRef = entryPointRequest->getFuncDeclRef();
- auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef));
+ auto entryPointFuncDecl = entryPoint->getFuncDecl();
auto builder = context->irBuilder;
builder->setInsertInto(builder->getModule()->getModuleInst());
auto loweredEntryPointFunc = getSimpleVal(context,
- emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType));
+ ensureDecl(context, entryPointFuncDecl));
// Attach a marker decoration so that we recognize
// this as an entry point.
//
- builder->addEntryPointDecoration(loweredEntryPointFunc);
-
- //
- if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>())
+ IRInst* instToDecorate = loweredEntryPointFunc;
+ if(auto irGeneric = as<IRGeneric>(instToDecorate))
{
- builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice());
+ instToDecorate = findGenericReturnVal(irGeneric);
}
+ builder->addEntryPointDecoration(instToDecorate);
+}
- // Now lower all the arguments supplied for global generic
- // type parameters.
- //
- for (RefPtr<Substitutions> subst = entryPointRequest->globalGenericSubst; subst; subst = subst->outer)
- {
- auto gSubst = subst.as<GlobalGenericParamSubstitution>();
- if(!gSubst)
- continue;
+static void lowerProgramEntryPointToIR(
+ IRGenContext* context,
+ EntryPoint* entryPoint)
+{
+ // First, lower the entry point like an ordinary function
- IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl));
- IRType* typeVal = lowerType(context, gSubst->actualType);
+ auto session = context->getSession();
+ auto entryPointFuncDeclRef = entryPoint->getFuncDeclRef();
+ auto entryPointFuncType = lowerType(context, getFuncType(session, entryPointFuncDeclRef));
- // bind `typeParam` to `typeVal`
- builder->emitBindGlobalGenericParam(typeParam, typeVal);
+ auto builder = context->irBuilder;
+ builder->setInsertInto(builder->getModule()->getModuleInst());
- for (auto& constraintArg : gSubst->constraintArgs)
- {
- IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl));
- IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val);
+ auto loweredEntryPointFunc = getSimpleVal(context,
+ emitDeclRef(context, entryPointFuncDeclRef, entryPointFuncType));
- // bind `constraintParam` to `constraintVal`
- builder->emitBindGlobalGenericParam(constraintParam, constraintVal);
- }
+ //
+ if(!loweredEntryPointFunc->findDecoration<IRLinkageDecoration>())
+ {
+ builder->addExportDecoration(loweredEntryPointFunc, getMangledName(entryPointFuncDeclRef).getUnownedSlice());
}
}
@@ -6191,19 +6219,19 @@ IRModule* generateIRForTranslationUnit(
{
auto compileRequest = translationUnit->compileRequest;
- SharedIRGenContext sharedContextStorage;
+ SharedIRGenContext sharedContextStorage(
+ translationUnit->getSession(),
+ translationUnit->compileRequest->getSink(),
+ translationUnit->getModuleDecl());
SharedIRGenContext* sharedContext = &sharedContextStorage;
- sharedContext->compileRequest = compileRequest;
- sharedContext->mainModuleDecl = translationUnit->SyntaxNode;
-
IRGenContext contextStorage(sharedContext);
IRGenContext* context = &contextStorage;
SharedIRBuilder sharedBuilderStorage;
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->module = nullptr;
- sharedBuilder->session = compileRequest->mSession;
+ sharedBuilder->session = compileRequest->getSession();
IRBuilder builderStorage;
IRBuilder* builder = &builderStorage;
@@ -6224,12 +6252,13 @@ IRModule* generateIRForTranslationUnit(
// in case they require special handling.
for (auto entryPoint : translationUnit->entryPoints)
{
- lowerEntryPointToIR(context, entryPoint);
+ lowerFrontEndEntryPointToIR(context, entryPoint);
}
+
//
// Next, ensure that all other global declarations have
// been emitted.
- for (auto decl : translationUnit->SyntaxNode->Members)
+ for (auto decl : translationUnit->getModuleDecl()->Members)
{
ensureAllDeclsRec(context, decl);
}
@@ -6271,12 +6300,12 @@ IRModule* generateIRForTranslationUnit(
// Propagate `constexpr`-ness through the dataflow graph (and the
// call graph) based on constraints imposed by different instructions.
- propagateConstExpr(module, &compileRequest->mSink);
+ propagateConstExpr(module, compileRequest->getSink());
// TODO: give error messages if any `undefined` or
// `unreachable` instructions remain.
- checkForMissingReturns(module, &compileRequest->mSink);
+ checkForMissingReturns(module, compileRequest->getSink());
// TODO: consider doing some more aggressive optimizations
// (in particular specialization of generics) here, so
@@ -6293,28 +6322,82 @@ IRModule* generateIRForTranslationUnit(
// then we can dump the initial IR for the module here.
if(compileRequest->shouldDumpIR)
{
- ISlangWriter* writer = translationUnit->compileRequest->getWriter(WriterChannel::StdError);
-
- dumpIR(module, writer);
+ DiagnosticSinkWriter writer(compileRequest->getSink());
+ dumpIR(module, &writer);
}
return module;
}
-#if 0
-String emitSlangIRAssemblyForEntryPoint(
- EntryPointRequest* entryPoint)
+RefPtr<IRModule> generateIRForProgram(
+ Session* session,
+ Program* program,
+ DiagnosticSink* sink)
{
- auto compileRequest = entryPoint->compileRequest;
- auto irModule = lowerEntryPointToIR(
- entryPoint,
- compileRequest->layout.Ptr(),
- // TODO: we need to pick the target more carefully here
- CodeGenTarget::HLSL);
-
- return getSlangIRAssembly(irModule);
-}
-#endif
+// auto compileRequest = translationUnit->compileRequest;
+
+ SharedIRGenContext sharedContextStorage(
+ session,
+ sink);
+ SharedIRGenContext* sharedContext = &sharedContextStorage;
+
+ IRGenContext contextStorage(sharedContext);
+ IRGenContext* context = &contextStorage;
+
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = session;
+
+ IRBuilder builderStorage;
+ IRBuilder* builder = &builderStorage;
+ builder->sharedBuilder = sharedBuilder;
+
+ RefPtr<IRModule> module = builder->createModule();
+ sharedBuilder->module = module;
+
+ context->irBuilder = builder;
+
+ // We need to emit symbols for all of the entry
+ // points in the program; this is especially
+ // important in the case where a generic entry
+ // point is being specialized.
+ //
+ for(auto entryPoint : program->getEntryPoints())
+ {
+ lowerProgramEntryPointToIR(context, entryPoint);
+ }
+
+ // Now lower all the arguments supplied for global generic
+ // type parameters.
+ //
+ for (RefPtr<Substitutions> subst = program->getGlobalGenericSubstitution(); subst; subst = subst->outer)
+ {
+ auto gSubst = subst.as<GlobalGenericParamSubstitution>();
+ if(!gSubst)
+ continue;
+
+ IRInst* typeParam = getSimpleVal(context, ensureDecl(context, gSubst->paramDecl));
+ IRType* typeVal = lowerType(context, gSubst->actualType);
+
+ // bind `typeParam` to `typeVal`
+ builder->emitBindGlobalGenericParam(typeParam, typeVal);
+
+ for (auto& constraintArg : gSubst->constraintArgs)
+ {
+ IRInst* constraintParam = getSimpleVal(context, ensureDecl(context, constraintArg.decl));
+ IRInst* constraintVal = lowerSimpleVal(context, constraintArg.val);
+
+ // bind `constraintParam` to `constraintVal`
+ builder->emitBindGlobalGenericParam(constraintParam, constraintVal);
+ }
+ }
+
+ // TODO: Should we apply any of the validation or
+ // mandatory optimization passes here?
+
+ return module;
+}
} // namespace Slang
diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h
index bd878d6fa..f607e852c 100644
--- a/source/slang/lower-to-ir.h
+++ b/source/slang/lower-to-ir.h
@@ -14,7 +14,7 @@
namespace Slang
{
class CompileRequest;
- class EntryPointRequest;
+ class EntryPoint;
class ProgramLayout;
class TranslationUnitRequest;
@@ -22,5 +22,10 @@ namespace Slang
IRModule* generateIRForTranslationUnit(
TranslationUnitRequest* translationUnit);
+
+ RefPtr<IRModule> generateIRForProgram(
+ Session* session,
+ Program* program,
+ DiagnosticSink* sink);
}
#endif
diff --git a/source/slang/options.cpp b/source/slang/options.cpp
index 8a4cf35b7..65fbfe068 100644
--- a/source/slang/options.cpp
+++ b/source/slang/options.cpp
@@ -41,7 +41,7 @@ struct OptionsParser
SlangSession* session = nullptr;
SlangCompileRequest* compileRequest = nullptr;
- Slang::CompileRequest* requestImpl = nullptr;
+ Slang::EndToEndCompileRequest* requestImpl = nullptr;
Slang::RefPtr<Slang::ConfigurableSharedLibraryLoader> sharedLibraryLoader;
@@ -313,7 +313,7 @@ struct OptionsParser
if (sourceLanguage == SLANG_SOURCE_LANGUAGE_UNKNOWN)
{
- requestImpl->mSink.diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath);
+ requestImpl->getSink()->diagnose(SourceLoc(), Diagnostics::cannotDeduceSourceLanguage, inPath);
return SLANG_FAIL;
}
@@ -425,9 +425,9 @@ struct OptionsParser
{
// Copy some state out of the current request, in case we've been called
// after some other initialization has been performed.
- flags = requestImpl->compileFlags;
+ flags = requestImpl->getFrontEndReq()->compileFlags;
- DiagnosticSink* sink = &requestImpl->mSink;
+ DiagnosticSink* sink = requestImpl->getSink();
SlangMatrixLayoutMode defaultMatrixLayoutMode = SLANG_MATRIX_LAYOUT_MODE_UNKNOWN;
@@ -450,23 +450,24 @@ struct OptionsParser
}
else if(argStr == "-dump-ir" )
{
- requestImpl->shouldDumpIR = true;
+ requestImpl->getFrontEndReq()->shouldDumpIR = true;
+ requestImpl->getBackEndReq()->shouldDumpIR = true;
}
else if (argStr == "-serial-ir")
{
- requestImpl->useSerialIRBottleneck = true;
+ requestImpl->getFrontEndReq()->useSerialIRBottleneck = true;
}
else if (argStr == "-verbose-paths")
{
- requestImpl->mSink.flags |= DiagnosticSink::Flag::VerbosePath;
+ requestImpl->getSink()->flags |= DiagnosticSink::Flag::VerbosePath;
}
else if (argStr == "-verify-debug-serial-ir")
{
- requestImpl->verifyDebugSerialization = true;
+ requestImpl->getFrontEndReq()->verifyDebugSerialization = true;
}
else if(argStr == "-validate-ir" )
{
- requestImpl->shouldValidateIR = true;
+ requestImpl->getFrontEndReq()->shouldValidateIR = true;
}
else if(argStr == "-skip-codegen" )
{
@@ -1222,18 +1223,7 @@ struct OptionsParser
// Now that we've diagnosed the output paths, we can add them
// to the compile request at the appropriate locations.
//
- // We start by allocating the arrays for per-entry-point output
- // paths on each of the requested targets.
- //
- for(auto rawTarget : rawTargets)
- {
- auto targetID = rawTarget.targetID;
- auto targetReq = requestImpl->targets[targetID];
-
- targetReq->entryPointOutputPaths.SetSize(rawEntryPoints.Count());
- }
-
- // Consider the output files specified via `-o` and try to figure
+ // We will consider the output files specified via `-o` and try to figure
// out how to deal with them.
//
for(auto& rawOutput : rawOutputs)
@@ -1242,18 +1232,26 @@ struct OptionsParser
if(rawOutput.entryPointIndex == -1) continue;
auto targetID = rawTargets[rawOutput.targetIndex].targetID;
- auto entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID;
+ Int entryPointID = rawEntryPoints[rawOutput.entryPointIndex].entryPointID;
+
+ auto target = requestImpl->getLinkage()->targets[targetID];
+ auto entryPointReq = requestImpl->getFrontEndReq()->getEntryPointReqs()[entryPointID];
- auto targetReq = requestImpl->targets[targetID];
+ RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo;
+ if( !requestImpl->targetInfos.TryGetValue(target, targetInfo) )
+ {
+ targetInfo = new EndToEndCompileRequest::TargetInfo();
+ requestImpl->targetInfos[target] = targetInfo;
+ }
- if(targetReq->entryPointOutputPaths[entryPointID].Length())
+ String outputPath;
+ if( targetInfo->entryPointOutputPaths.ContainsKey(entryPointID) )
{
- auto entryPointReq = requestImpl->entryPoints[entryPointID];
- sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->name, targetReq->target);
+ sink->diagnose(SourceLoc(), Diagnostics::duplicateOutputPathsForEntryPointAndTarget, entryPointReq->getName(), target->getTarget());
}
else
{
- targetReq->entryPointOutputPaths[entryPointID] = rawOutput.path;
+ targetInfo->entryPointOutputPaths[entryPointID] = rawOutput.path;
}
}
@@ -1272,16 +1270,16 @@ SlangResult parseOptions(
int argc,
char const* const* argv)
{
- Slang::CompileRequest* compileRequest = (Slang::CompileRequest*) compileRequestIn;
+ Slang::EndToEndCompileRequest* compileRequest = (Slang::EndToEndCompileRequest*) compileRequestIn;
OptionsParser parser;
parser.compileRequest = compileRequestIn;
parser.requestImpl = compileRequest;
- parser.session = (SlangSession*)compileRequest->mSession;
+ parser.session = (SlangSession*)compileRequest->getSession();
Result res = parser.parse(argc, argv);
- DiagnosticSink* sink = &compileRequest->mSink;
+ DiagnosticSink* sink = compileRequest->getSink();
if (sink->GetErrorCount() > 0)
{
// Put the errors in the diagnostic
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index f0abfe31c..56c5d7c1d 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -309,9 +309,6 @@ struct ParameterInfo : RefObject
ParameterBindingInfo bindingInfo[kLayoutResourceKindCount];
- // The next parameter that has the same name...
- ParameterInfo* nextOfSameName;
-
// The translation unit this parameter is specific to, if any
TranslationUnitRequest* translationUnit = nullptr;
@@ -335,8 +332,22 @@ struct EntryPointParameterBindingContext
// across all translation units
struct SharedParameterBindingContext
{
- // The base compile request
- CompileRequest* compileRequest;
+ SharedParameterBindingContext(
+ LayoutRulesFamilyImpl* defaultLayoutRules,
+ ProgramLayout* programLayout,
+ TargetRequest* targetReq,
+ DiagnosticSink* sink)
+ : defaultLayoutRules(defaultLayoutRules)
+ , programLayout(programLayout)
+ , targetRequest(targetReq)
+ , m_sink(sink)
+ {
+ }
+
+ DiagnosticSink* m_sink = nullptr;
+
+ // The program that we are laying out
+// Program* program = nullptr;
// The target request that is triggering layout
//
@@ -365,20 +376,18 @@ struct SharedParameterBindingContext
UInt defaultSpace = 0;
TargetRequest* getTargetRequest() { return targetRequest; }
+ DiagnosticSink* getSink() { return m_sink; }
};
static DiagnosticSink* getSink(SharedParameterBindingContext* shared)
{
- return &shared->compileRequest->mSink;
+ return shared->getSink();
}
// State that might be specific to a single translation unit
// or event to an entry point.
struct ParameterBindingContext
{
- // The translation unit we are processing right now
- TranslationUnitRequest* translationUnit;
-
// All the shared state needs to be available
SharedParameterBindingContext* shared;
@@ -386,7 +395,7 @@ struct ParameterBindingContext
// the resource usage of shader parameters.
TypeLayoutContext layoutContext;
- // A dictionary to accellerate looking up parameters by name
+ // A dictionary to accelerate looking up parameters by name
Dictionary<Name*, ParameterInfo*> mapNameToParameterInfo;
// What stage (if any) are we compiling for?
@@ -395,9 +404,6 @@ struct ParameterBindingContext
// The entry point that is being processed right now.
EntryPointLayout* entryPointLayout = nullptr;
- // The source language we are trying to use
- SourceLanguage sourceLanguage;
-
TargetRequest* getTargetRequest() { return shared->getTargetRequest(); }
LayoutRulesFamilyImpl* getRulesFamily() { return layoutContext.getRulesFamily(); }
};
@@ -1217,6 +1223,10 @@ static void collectGlobalScopeParameter(
// If that is the case, we want to re-use the same `VarLayout`
// across both parameters.
//
+ // TODO: This logic currently detects *any* global-scope parameters
+ // with matching names, but it should eventually be narrowly
+ // scoped so that it only applies to parameters from unnamed modules.
+ //
// First we look for an existing entry matching the name
// of this parameter:
auto parameterName = getReflectionName(varDecl);
@@ -2477,7 +2487,7 @@ static ParameterBindingAndKindInfo maybeAllocateConstantBufferBinding(
///
static void collectEntryPointParameters(
ParameterBindingContext* context,
- EntryPointRequest* entryPoint,
+ EntryPoint* entryPoint,
SubstitutionSet typeSubst)
{
DeclRef<FuncDecl> entryPointFuncDeclRef = entryPoint->getFuncDeclRef();
@@ -2486,7 +2496,7 @@ static void collectEntryPointParameters(
// the `EntryPointLayout` object here.
//
RefPtr<EntryPointLayout> entryPointLayout = new EntryPointLayout();
- entryPointLayout->profile = entryPoint->profile;
+ entryPointLayout->profile = entryPoint->getProfile();
entryPointLayout->entryPoint = entryPointFuncDeclRef.getDecl();
// The entry point layout must be added to the output
@@ -2501,10 +2511,10 @@ static void collectEntryPointParameters(
// Note: this isn't really the best place for this logic to sit,
// but it is the simplest place where we have a direct correspondence
- // between a single `EntryPointRequest` and its matching `EntryPointLayout`,
+ // between a single `EntryPoint` and its matching `EntryPointLayout`,
// so we'll use it.
//
- for( auto taggedUnionType : entryPoint->taggedUnionTypes )
+ for( auto taggedUnionType : entryPoint->getTaggedUnionTypes() )
{
SLANG_ASSERT(taggedUnionType);
auto substType = taggedUnionType->Substitute(typeSubst).as<Type>();
@@ -2645,59 +2655,9 @@ static void collectEntryPointParameters(
}
}
-// When doing parameter binding for global-scope stuff in GLSL,
-// we may need to know what stage we are compiling for, so that
-// we can handle special cases appropriately (e.g., "arrayed"
-// inputs and outputs).
-static Stage
-inferStageForTranslationUnit(
- TranslationUnitRequest* translationUnit)
-{
- // In the specific case where we are compiling GLSL input,
- // and have only a single entry point, use the stage
- // of the entry point.
- //
- // TODO: now that we've dropped official GLSL support,
- // we probably should drop this as well.
- //
- if( translationUnit->sourceLanguage == SourceLanguage::GLSL )
- {
- if( translationUnit->entryPoints.Count() == 1 )
- {
- return translationUnit->entryPoints[0]->getStage();
- }
- }
-
- return Stage::Unknown;
-}
-
-static void collectModuleParameters(
- ParameterBindingContext* inContext,
- ModuleDecl* module)
-{
- // Each loaded module provides a separate (logical) namespace for
- // parameters, so that two parameters with the same name, in
- // distinct modules, should yield different bindings.
- //
- ParameterBindingContext contextData = *inContext;
- auto context = &contextData;
-
- context->translationUnit = nullptr;
-
- context->stage = Stage::Unknown;
-
- // All imported modules are implicitly Slang code
- context->sourceLanguage = SourceLanguage::Slang;
-
- // A loaded module cannot define entry points that
- // we'll expose (for now), so we just need to
- // consider global-scope parameters.
- collectGlobalScopeParameters(context, module);
-}
-
static void collectParameters(
ParameterBindingContext* inContext,
- CompileRequest* request)
+ Program* program)
{
// All of the parameters in translation units directly
// referenced in the compile request are part of one
@@ -2707,29 +2667,21 @@ static void collectParameters(
ParameterBindingContext contextData = *inContext;
auto context = &contextData;
- for( auto& translationUnit : request->translationUnits )
+ for(RefPtr<Module> module : program->getModuleDependencies())
{
- context->translationUnit = translationUnit;
- context->stage = inferStageForTranslationUnit(translationUnit.Ptr());
- context->sourceLanguage = translationUnit->sourceLanguage;
+ context->stage = Stage::Unknown;
// First look at global-scope parameters
- collectGlobalScopeParameters(context, translationUnit->SyntaxNode.Ptr());
-
- // Next consider parameters for entry points
- for( auto& entryPoint : translationUnit->entryPoints )
- {
- context->stage = entryPoint->getStage();
- collectEntryPointParameters(context, entryPoint.Ptr(), SubstitutionSet());
- }
- context->entryPointLayout = nullptr;
+ collectGlobalScopeParameters(context, module->getModuleDecl());
}
- // Now collect parameters from loaded modules
- for (auto& loadedModule : request->loadedModulesList)
+ // Next consider parameters for entry points
+ for(auto entryPoint : program->getEntryPoints())
{
- collectModuleParameters(context, loadedModule->moduleDecl.Ptr());
+ context->stage = entryPoint->getStage();
+ collectEntryPointParameters(context, entryPoint, SubstitutionSet());
}
+ context->entryPointLayout = nullptr;
}
/// Emit a diagnostic about a uniform parameter at global scope.
@@ -2770,41 +2722,40 @@ static int _calcTotalNumUsedRegistersForLayoutResourceKind(ParameterBindingConte
return numUsed;
}
-void generateParameterBindings(
- TargetRequest* targetReq)
+RefPtr<ProgramLayout> generateParameterBindings(
+ TargetProgram* targetProgram,
+ DiagnosticSink* sink)
{
- CompileRequest* compileReq = targetReq->compileRequest;
+ auto program = targetProgram->getProgram();
+ auto targetReq = targetProgram->getTargetReq();
+
+ RefPtr<ProgramLayout> programLayout = new ProgramLayout();
+ programLayout->targetProgram = targetProgram;
// Try to find rules based on the selected code-generation target
- auto layoutContext = getInitialLayoutContextForTarget(targetReq);
+ auto layoutContext = getInitialLayoutContextForTarget(targetReq, programLayout);
// If there was no target, or there are no rules for the target,
// then bail out here.
if (!layoutContext.rules)
- return;
-
- RefPtr<ProgramLayout> programLayout = new ProgramLayout();
- programLayout->targetRequest = targetReq;
-
- targetReq->layout = programLayout;
+ return nullptr;
// Create a context to hold shared state during the process
// of generating parameter bindings
- SharedParameterBindingContext sharedContext;
- sharedContext.compileRequest = compileReq;
- sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
- sharedContext.programLayout = programLayout;
- sharedContext.targetRequest = targetReq;
+ SharedParameterBindingContext sharedContext(
+ layoutContext.getRulesFamily(),
+ programLayout,
+ targetReq,
+ sink);
// Create a sub-context to collect parameters that get
// declared into the global scope
ParameterBindingContext context;
context.shared = &sharedContext;
- context.translationUnit = nullptr;
context.layoutContext = layoutContext;
// Walk through AST to discover all the parameters
- collectParameters(&context, compileReq);
+ collectParameters(&context, program);
// Now walk through the parameters to generate initial binding information
for( auto& parameter : sharedContext.parameters )
@@ -2978,17 +2929,35 @@ void generateParameterBindings(
const int numShaderRecordRegs = _calcTotalNumUsedRegistersForLayoutResourceKind(&context, LayoutResourceKind::ShaderRecord);
if (numShaderRecordRegs > 1)
{
- compileReq->mSink.diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs);
- return;
+ sink->diagnose(SourceLoc(), Diagnostics::tooManyShaderRecordConstantBuffers, numShaderRecordRegs);
}
}
+ return programLayout;
+}
+
+ProgramLayout* TargetProgram::getOrCreateLayout(DiagnosticSink* sink)
+{
+ if( !m_layout )
+ {
+ m_layout = generateParameterBindings(this, sink);
+ }
+ return m_layout;
+}
+
+void generateParameterBindings(
+ Program* program,
+ TargetRequest* targetReq,
+ DiagnosticSink* sink)
+{
+ program->getTargetProgram(targetReq)->getOrCreateLayout(sink);
}
RefPtr<ProgramLayout> specializeProgramLayout(
TargetRequest* targetReq,
ProgramLayout* oldProgramLayout,
- SubstitutionSet typeSubst)
+ SubstitutionSet typeSubst,
+ DiagnosticSink* sink)
{
// The goal of the layout specialization step is to take an existing `ProgramLayout`,
// and add a layout to any parameter(s) that could not be laid out previously, because
@@ -3006,7 +2975,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
RefPtr<ProgramLayout> newProgramLayout;
newProgramLayout = new ProgramLayout();
- newProgramLayout->targetRequest = targetReq;
+ newProgramLayout->targetProgram = oldProgramLayout->targetProgram;
newProgramLayout->globalGenericParams = oldProgramLayout->globalGenericParams;
// The basic idea will be to iterate over the parameters in the old layout,
@@ -3020,18 +2989,17 @@ RefPtr<ProgramLayout> specializeProgramLayout(
// We will use the same kind of context type as the original parameter binding
// step did, so we initialize its state here:
- auto layoutContext = getInitialLayoutContextForTarget(targetReq);
+ auto layoutContext = getInitialLayoutContextForTarget(targetReq, newProgramLayout);
SLANG_ASSERT(layoutContext.rules);
- SharedParameterBindingContext sharedContext;
- sharedContext.compileRequest = targetReq->compileRequest;
- sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
- sharedContext.programLayout = newProgramLayout;
- sharedContext.targetRequest = targetReq;
+ SharedParameterBindingContext sharedContext(
+ layoutContext.getRulesFamily(),
+ newProgramLayout,
+ targetReq,
+ sink);
ParameterBindingContext context;
context.shared = &sharedContext;
- context.translationUnit = nullptr;
context.layoutContext = layoutContext;
// We will also need state for laying out any global-scope parameters
@@ -3119,12 +3087,9 @@ RefPtr<ProgramLayout> specializeProgramLayout(
// parameter, the layout of its parameter list strictly follows
// the declaration order.
//
- for (auto & translationUnit : targetReq->compileRequest->translationUnits)
+ for( auto entryPoint : oldProgramLayout->getProgram()->getEntryPoints() )
{
- for (auto & entryPoint : translationUnit->entryPoints)
- {
- collectEntryPointParameters(&context, entryPoint, typeSubst);
- }
+ collectEntryPointParameters(&context, entryPoint, typeSubst);
context.entryPointLayout = nullptr;
}
diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h
index eb093821f..82b114021 100644
--- a/source/slang/parameter-binding.h
+++ b/source/slang/parameter-binding.h
@@ -8,6 +8,7 @@
namespace Slang {
+class Program;
class TargetRequest;
// The parameter-binding interface is responsible for assigning
@@ -24,7 +25,9 @@ class TargetRequest;
// of the program.
void generateParameterBindings(
- TargetRequest* targetReq);
+ Program* program,
+ TargetRequest* targetReq,
+ DiagnosticSink* sink);
}
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index e2085eb7d..3abc47ede 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -8,7 +8,7 @@
namespace Slang
{
- // Pre-declare
+ // pre-declare
static Name* getName(Parser* parser, String const& text);
// Helper class useful to build a list of modifiers.
@@ -79,7 +79,11 @@ namespace Slang
class Parser
{
public:
- TranslationUnitRequest* translationUnit;
+ NamePool* namePool;
+ SourceLanguage sourceLanguage;
+
+ NamePool* getNamePool() { return namePool; }
+ SourceLanguage getSourceLanguage() { return sourceLanguage; }
int anonymousCounter = 0;
@@ -124,27 +128,26 @@ namespace Slang
currentScope = currentScope->parent;
}
Parser(
+ Session* session,
TokenSpan const& _tokens,
DiagnosticSink * sink,
RefPtr<Scope> const& outerScope)
: tokenReader(_tokens)
, sink(sink)
, outerScope(outerScope)
+ , m_session(session)
{}
Parser(const Parser & other) = default;
- Session* getSession()
- {
- return translationUnit->compileRequest->mSession;
- }
- RefPtr<ModuleDecl> Parse();
+ Session* m_session = nullptr;
+ Session* getSession() { return m_session; }
+
Token ReadToken();
Token ReadToken(TokenType type);
Token ReadToken(const char * string);
bool LookAheadToken(TokenType type, int offset = 0);
bool LookAheadToken(const char * string, int offset = 0);
void parseSourceFile(ModuleDecl* program);
- RefPtr<ModuleDecl> ParseProgram();
RefPtr<Decl> ParseStruct();
RefPtr<ClassDecl> ParseClass();
RefPtr<Stmt> ParseStatement();
@@ -578,11 +581,6 @@ namespace Slang
return false;
}
- RefPtr<ModuleDecl> Parser::Parse()
- {
- return ParseProgram();
- }
-
RefPtr<RefObject> ParseTypeDef(Parser* parser, void* /*userData*/)
{
RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl();
@@ -694,7 +692,7 @@ namespace Slang
Token token(TokenType::Identifier, scopedIdentifier, scopedIdSourceLoc);
// Get the name pool
- auto namePool = parser->translationUnit->compileRequest->getNamePool();
+ auto namePool = parser->getNamePool();
// Since it's an Identifier have to set the name.
token.ptrValue = namePool->getName(token.Content);
@@ -910,7 +908,7 @@ namespace Slang
static Name* getName(Parser* parser, String const& text)
{
- return parser->translationUnit->compileRequest->getNamePool()->getName(text);
+ return parser->getNamePool()->getName(text);
}
static NameLoc expectIdentifier(Parser* parser)
@@ -1859,7 +1857,7 @@ namespace Slang
}
// GLSL allows `[]` directly in a type specifier
- if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL)
+ if (parser->getSourceLanguage() == SourceLanguage::GLSL)
{
typeExpr = parsePostfixTypeSuffix(parser, typeExpr);
}
@@ -1929,7 +1927,7 @@ namespace Slang
// Just as a safety net, only apply this logic for
// a file that is being passed in as "true" Slang code.
//
- if(parser->translationUnit->sourceLanguage == SourceLanguage::Slang)
+ if(parser->getSourceLanguage() == SourceLanguage::Slang)
{
if(typeSpec.decl)
{
@@ -2313,171 +2311,6 @@ namespace Slang
return ParseHLSLBufferDecl(parser, "TextureBuffer");
}
- static void removeModifier(
- Modifiers& modifiers,
- RefPtr<Modifier> modifier)
- {
- RefPtr<Modifier>* link = &modifiers.first;
- while (*link)
- {
- if (*link == modifier)
- {
- *link = (*link)->next;
- return;
- }
-
- link = &(*link)->next;
- }
- }
-
- static RefPtr<Decl> parseGLSLBlockDecl(
- Parser* parser,
- Modifiers& modifiers)
- {
- // An GLSL block like this:
- //
- // uniform Foo { int a; float b; } foo;
- //
- // is treated as syntax sugar for a type declaration
- // and then a global variable declaration using that type:
- //
- // struct $anonymous { int a; float b; };
- // Block<$anonymous> foo;
- //
- // where `$anonymous` is a fresh name.
- //
- // If a "local name" like `foo` is not given, then
- // we make the declaration "transparent" so that lookup
- // will see through it to the members inside.
-
-
- SourceLoc pos = parser->tokenReader.PeekLoc();
-
- // The initial name before the `{` is only supposed
- // to be made visible to reflection
- auto reflectionNameToken = parser->ReadToken(TokenType::Identifier);
-
- // Look at the qualifiers present on the block to decide what kind
- // of block we are looking at. Also *remove* those qualifiers so
- // that they don't interfere with downstream work.
- String blockWrapperTypeName;
- if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() )
- {
- removeModifier(modifiers, uniformMod);
- blockWrapperTypeName = "ConstantBuffer";
- }
- else if( auto inMod = modifiers.findModifier<InModifier>() )
- {
- removeModifier(modifiers, inMod);
- blockWrapperTypeName = "__GLSLInputParameterGroup";
- }
- else if( auto outMod = modifiers.findModifier<OutModifier>() )
- {
- removeModifier(modifiers, outMod);
- blockWrapperTypeName = "__GLSLOutputParameterGroup";
- }
- else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() )
- {
- removeModifier(modifiers, bufferMod);
- blockWrapperTypeName = "__GLSLShaderStorageBuffer";
- }
- else
- {
- // Unknown case: just map to a constant buffer and hope for the best
- blockWrapperTypeName = "ConstantBuffer";
- }
-
- // We are going to represent each buffer as a pair of declarations.
- // The first is a type declaration that holds all the members, while
- // the second is a variable declaration that uses the buffer type.
- RefPtr<StructDecl> blockDataTypeDecl = new StructDecl();
- RefPtr<VarDecl> blockVarDecl = new VarDecl();
-
- addModifier(blockDataTypeDecl, new ImplicitParameterGroupElementTypeModifier());
- addModifier(blockVarDecl, new ImplicitParameterGroupVariableModifier());
-
- // Attach the reflection name to the block so we can use it
- auto reflectionNameModifier = new ParameterGroupReflectionName();
- reflectionNameModifier->nameAndLoc = NameLoc(reflectionNameToken);
- addModifier(blockVarDecl, reflectionNameModifier);
-
- // Both declarations will have a location that points to the name
- parser->FillPosition(blockDataTypeDecl.Ptr());
- parser->FillPosition(blockVarDecl.Ptr());
-
- // Generate a unique name for the data type
- blockDataTypeDecl->nameAndLoc.name = generateName(parser, "ParameterGroup_" + String(reflectionNameToken.Content));
-
- // TODO(tfoley): We end up constructing unchecked syntax here that
- // is expected to type check into the right form, but it might be
- // cleaner to have a more explicit desugaring pass where we parse
- // these constructs directly into the AST and *then* desugar them.
-
- // Construct a type expression to reference the buffer data type
- auto blockDataTypeExpr = new VarExpr();
- blockDataTypeExpr->loc = blockDataTypeDecl->loc;
- blockDataTypeExpr->name = blockDataTypeDecl->getName();
- blockDataTypeExpr->scope = parser->currentScope.Ptr();
-
- // Construct a type exrpession to reference the type constructor
- auto blockWrapperTypeExpr = new VarExpr();
- blockWrapperTypeExpr->loc = pos;
- blockWrapperTypeExpr->name = getName(parser, blockWrapperTypeName);
- // Always need to look this up in the outer scope,
- // so that it won't collide with, e.g., a local variable called `ConstantBuffer`
- blockWrapperTypeExpr->scope = parser->outerScope;
-
- // Construct a type expression that represents the type for the variable,
- // which is the wrapper type applied to the data type
- auto blockVarTypeExpr = new GenericAppExpr();
- blockVarTypeExpr->loc = blockVarDecl->loc;
- blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr;
- blockVarTypeExpr->Arguments.Add(blockDataTypeExpr);
-
- blockVarDecl->type.exp = blockVarTypeExpr;
-
- // The declarations in the body belong to the data type.
- parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr());
-
- if( parser->LookAheadToken(TokenType::Identifier) )
- {
- // The user gave an explicit name to the block,
- // so we need to use that as our variable name
- blockVarDecl->nameAndLoc = NameLoc(parser->ReadToken(TokenType::Identifier));
-
- // TODO: in this case we make actually have a more complex
- // declarator, including `[]` brackets.
- }
- else
- {
- // synthesize a dummy name
- blockVarDecl->nameAndLoc.name = generateName(parser, "parameterGroup_" + String(reflectionNameToken.Content));
-
- // Otherwise we have a transparent declaration, similar
- // to an HLSL `cbuffer`
- auto transparentModifier = new TransparentModifier();
- transparentModifier->loc = pos;
- addModifier(blockVarDecl, transparentModifier);
- }
-
- // Expect a trailing `;`
- parser->ReadToken(TokenType::Semicolon);
-
- // Because we are constructing two declarations, we have a thorny
- // issue that were are only supposed to return one.
- // For now we handle this by adding the type declaration to
- // the current scope manually, and then returning the variable
- // declaration.
- //
- // Note: this means that any modifiers that have already been parsed
- // will get attached to the variable declaration, not the type.
- // There might be cases where we need to shuffle things around.
-
- AddMember(parser->currentScope, blockDataTypeDecl);
-
- return blockVarDecl;
- }
-
static void parseOptionalInheritanceClause(Parser* parser, AggTypeDeclBase* decl)
{
if (AdvanceIf(parser, TokenType::Colon))
@@ -3020,27 +2853,8 @@ namespace Slang
//
// - A keyword-based declaration (e.g., `cbuffer ...`)
// - The beginning of a type in a declarator-based declaration (e.g., `int ...`)
- // - A GLSL block declaration (e.g., `uniform Foo { ... }`)
-
- // Let's deal with the GLSL block case first. This is something like:
- //
- // uniform Foo { ... };
- //
- // The `uniform` keyword has already been parsed as a modifier,
- // so the identifier we are looking at is `Foo`. If the token
- // after that is `{`, we assume this is a block.
- //
- // Of course, we only want to allow this syntax when parsing GLSL...
- if (parser->translationUnit->sourceLanguage == SourceLanguage::GLSL)
- {
- if( parser->LookAheadToken(TokenType::LBrace, 1) )
- {
- decl = parseGLSLBlockDecl(parser, modifiers);
- break;
- }
- }
- // Next we will check whether we can use the identifier token
+ // First we will check whether we can use the identifier token
// as a declaration keyword and parse a declaration using
// its associated callback:
RefPtr<Decl> parsedDecl;
@@ -3184,15 +2998,6 @@ namespace Slang
currentScope = nullptr;
}
- RefPtr<ModuleDecl> Parser::ParseProgram()
- {
- RefPtr<ModuleDecl> program = new ModuleDecl();
-
- parseSourceFile(program.Ptr());
-
- return program;
- }
-
RefPtr<Decl> Parser::ParseStruct()
{
RefPtr<StructDecl> rs = new StructDecl();
@@ -3591,7 +3396,7 @@ namespace Slang
// parsing HLSL code.
//
- bool brokenScoping = translationUnit->sourceLanguage == SourceLanguage::HLSL;
+ bool brokenScoping = getSourceLanguage() == SourceLanguage::HLSL;
// We will create a distinct syntax node class for the unscoped
// case, just so that we can correctly handle it in downstream
@@ -4439,14 +4244,18 @@ namespace Slang
return parsePrefixExpr(this);
}
- RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit,
+ RefPtr<Expr> parseTypeFromSourceFile(
+ Session* session,
TokenSpan const& tokens,
DiagnosticSink* sink,
- RefPtr<Scope> const& outerScope)
+ RefPtr<Scope> const& outerScope,
+ NamePool* namePool,
+ SourceLanguage sourceLanguage)
{
- Parser parser(tokens, sink, outerScope);
- parser.translationUnit = translationUnit;
+ Parser parser(session, tokens, sink, outerScope);
parser.currentScope = outerScope;
+ parser.namePool = namePool;
+ parser.sourceLanguage = sourceLanguage;
return parser.ParseType();
}
@@ -4457,12 +4266,11 @@ namespace Slang
DiagnosticSink* sink,
RefPtr<Scope> const& outerScope)
{
- Parser parser(tokens, sink, outerScope);
-
- parser.translationUnit = translationUnit;
-
+ Parser parser(translationUnit->getSession(), tokens, sink, outerScope);
+ parser.namePool = translationUnit->getNamePool();
+ parser.sourceLanguage = translationUnit->sourceLanguage;
- return parser.parseSourceFile(translationUnit->SyntaxNode.Ptr());
+ return parser.parseSourceFile(translationUnit->getModuleDecl());
}
static void addBuiltinSyntaxImpl(
diff --git a/source/slang/parser.h b/source/slang/parser.h
index 785b6e345..abad902da 100644
--- a/source/slang/parser.h
+++ b/source/slang/parser.h
@@ -14,10 +14,13 @@ namespace Slang
DiagnosticSink* sink,
RefPtr<Scope> const& outerScope);
- RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit,
+ RefPtr<Expr> parseTypeFromSourceFile(
+ Session* session,
TokenSpan const& tokens,
DiagnosticSink* sink,
- RefPtr<Scope> const& outerScope);
+ RefPtr<Scope> const& outerScope,
+ NamePool* namePool,
+ SourceLanguage sourceLanguage);
RefPtr<ModuleDecl> populateBaseLanguageModule(
Session* session,
diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp
index c6c438ef6..103db7dcb 100644
--- a/source/slang/preprocessor.cpp
+++ b/source/slang/preprocessor.cpp
@@ -194,27 +194,18 @@ struct Preprocessor
// represent end-of-input situations.
Token endOfFileToken;
- // The translation unit that is being parsed
- TranslationUnitRequest* translationUnit;
+ /// The linkage the provides the context for preprocessing
+ Linkage* linkage = nullptr;
+
+ /// The module, if any, that the preprocessed result will belong to
+ Module* parentModule = nullptr;
// The unique identities of any paths that have issued `#pragma once` directives to
// stop them from being included again.
HashSet<String> pragmaOnceUniqueIdentities;
- TranslationUnitRequest* getTranslationUnit()
- {
- return translationUnit;
- }
-
- ModuleDecl* getSyntax()
- {
- return getTranslationUnit()->SyntaxNode.Ptr();
- }
-
- CompileRequest* getCompileRequest()
- {
- return getTranslationUnit()->compileRequest;
- }
+ NamePool* getNamePool() { return linkage->getNamePool(); }
+ SourceManager* getSourceManager() { return linkage->getSourceManager(); }
};
// Convenience routine to access the diagnostic sink
@@ -255,11 +246,6 @@ static void destroyInputStream(Preprocessor* /*preprocessor*/, PreprocessorInput
delete inputStream;
}
-static NamePool* getNamePool(Preprocessor* preprocessor)
-{
- return preprocessor->translationUnit->compileRequest->getNamePool();
-}
-
// Create an input stream to represent a pre-tokenized input file.
// TODO(tfoley): pre-tokenizing files isn't going to work in the long run.
static PreprocessorInputStream* CreateInputStreamForSource(
@@ -272,7 +258,7 @@ static PreprocessorInputStream* CreateInputStreamForSource(
initializePrimaryInputStream(preprocessor, inputStream);
// initialize the embedded lexer so that it can generate a token stream
- inputStream->lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), memoryArena);
+ inputStream->lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), memoryArena);
inputStream->token = inputStream->lexer.lexToken();
return inputStream;
@@ -836,7 +822,7 @@ top:
// Now re-lex the input
- SourceManager* sourceManager = preprocessor->getCompileRequest()->getSourceManager();
+ SourceManager* sourceManager = preprocessor->getSourceManager();
// We create a dummy file to represent the token-paste operation
PathInfo pathInfo = PathInfo::makeTokenPaste();
@@ -845,7 +831,7 @@ top:
SourceView* sourceView = sourceManager->createSourceView(sourceFile, nullptr);
Lexer lexer;
- lexer.initialize(sourceView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena());
+ lexer.initialize(sourceView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena());
SimpleTokenInputStream* inputStream = new SimpleTokenInputStream();
initializeInputStream(preprocessor, inputStream);
@@ -1564,7 +1550,7 @@ static void HandleEndIfDirective(PreprocessorDirectiveContext* context)
// we expect it.
//
// Most directives do not need to call this directly, since we have
-// a catch-all case in the main `HandleDirective()` funciton.
+// a catch-all case in the main `HandleDirective()` function.
// The `#include` case will call it directly to avoid complications
// when it switches the input stream.
static void expectEndOfDirective(PreprocessorDirectiveContext* context)
@@ -1589,6 +1575,31 @@ static void expectEndOfDirective(PreprocessorDirectiveContext* context)
AdvanceRawToken(context->preprocessor);
}
+ /// Read a file in the context of handling a preprocessor directive
+static SlangResult readFile(
+ PreprocessorDirectiveContext* context,
+ String const& path,
+ ISlangBlob** outBlob)
+{
+ // The actual file loading will be handled by the file system
+ // associated with the parent linkage.
+ //
+ auto linkage = context->preprocessor->linkage;
+ auto fileSystemExt = linkage->getFileSystemExt();
+ SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.Buffer(), outBlob));
+
+ // If we are running the preprocessor as part of compiling a
+ // specific module, then we must keep track of the file we've
+ // read as yet another file that the module will depend on.
+ //
+ if(auto module = context->preprocessor->parentModule)
+ {
+ module->addFilePathDependency(path);
+ }
+
+ return SLANG_OK;
+}
+
// Handle a `#include` directive
static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
{
@@ -1603,7 +1614,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
auto directiveLoc = GetDirectiveLoc(context);
- PathInfo includedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual);
+ PathInfo includedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual);
IncludeHandler* includeHandler = context->preprocessor->includeHandler;
if (!includeHandler)
@@ -1644,7 +1655,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
// Push the new file onto our stack of input streams
// TODO(tfoley): check if we have made our include stack too deep
- auto sourceManager = context->preprocessor->getCompileRequest()->getSourceManager();
+ auto sourceManager = context->preprocessor->getSourceManager();
// See if this an already loaded source file
SourceFile* sourceFile = sourceManager->findSourceFileRecursively(filePathInfo.uniqueIdentity);
@@ -1652,7 +1663,7 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
if (!sourceFile)
{
ComPtr<ISlangBlob> foundSourceBlob;
- if (SLANG_FAILED(includeHandler->readFile(filePathInfo.foundPath, foundSourceBlob.writeRef())))
+ if (SLANG_FAILED(readFile(context, filePathInfo.foundPath, foundSourceBlob.writeRef())))
{
GetSink(context)->diagnose(pathToken.loc, Diagnostics::includeFailed, path);
return;
@@ -1843,7 +1854,7 @@ static void HandleLineDirective(PreprocessorDirectiveContext* context)
return;
}
- auto sourceManager = context->preprocessor->translationUnit->compileRequest->getSourceManager();
+ auto sourceManager = context->preprocessor->getSourceManager();
String file;
if (PeekTokenType(context) == TokenType::EndOfDirective)
@@ -1891,7 +1902,7 @@ SLANG_PRAGMA_DIRECTIVE_CALLBACK(handlePragmaOnceDirective)
// We are using the 'uniqueIdentity' as determined by the ISlangFileSystemEx interface to determine file identities.
auto directiveLoc = GetDirectiveLoc(context);
- auto issuedFromPathInfo = context->preprocessor->translationUnit->compileRequest->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual);
+ auto issuedFromPathInfo = context->preprocessor->getSourceManager()->getPathInfo(directiveLoc, SourceLocType::Actual);
// Must have uniqueIdentity for a #pragma once to work
if (!issuedFromPathInfo.hasUniqueIdentity())
@@ -1962,82 +1973,6 @@ static void HandlePragmaDirective(PreprocessorDirectiveContext* context)
(subDirective->callback)(context, subDirectiveToken);
}
-// Handle a `#version` directive
-static void handleGLSLVersionDirective(PreprocessorDirectiveContext* context)
-{
- Token versionNumberToken;
- if(!ExpectRaw(
- context,
- TokenType::IntegerLiteral,
- Diagnostics::expectedTokenInPreprocessorDirective,
- &versionNumberToken))
- {
- return;
- }
-
- Token glslProfileToken;
- if(PeekTokenType(context) == TokenType::Identifier)
- {
- glslProfileToken = AdvanceToken(context);
- }
-
- // Need to construct a representation taht we can hook into our compilation result
-
- auto modifier = new GLSLVersionDirective();
- modifier->versionNumberToken = versionNumberToken;
- modifier->glslProfileToken = glslProfileToken;
-
- // Attach the modifier to the program we are parsing!
-
- addModifier(
- context->preprocessor->getSyntax(),
- modifier);
-}
-
-// Handle a `#extension` directive, e.g.,
-//
-// #extension some_extension_name : enable
-//
-static void handleGLSLExtensionDirective(PreprocessorDirectiveContext* context)
-{
- Token extensionNameToken;
- if(!ExpectRaw(
- context,
- TokenType::Identifier,
- Diagnostics::expectedTokenInPreprocessorDirective,
- &extensionNameToken))
- {
- return;
- }
-
- if( !ExpectRaw(context, TokenType::Colon, Diagnostics::expectedTokenInPreprocessorDirective) )
- {
- return;
- }
-
- Token dispositionToken;
- if(!ExpectRaw(
- context,
- TokenType::Identifier,
- Diagnostics::expectedTokenInPreprocessorDirective,
- &dispositionToken))
- {
- return;
- }
-
- // Need to construct a representation taht we can hook into our compilation result
-
- auto modifier = new GLSLExtensionDirective();
- modifier->extensionNameToken = extensionNameToken;
- modifier->dispositionToken = dispositionToken;
-
- // Attach the modifier to the program we are parsing!
-
- addModifier(
- context->preprocessor->getSyntax(),
- modifier);
-}
-
// Handle an invalid directive
static void HandleInvalidDirective(PreprocessorDirectiveContext* context)
{
@@ -2092,11 +2027,6 @@ static const PreprocessorDirective kDirectives[] =
{ "line", &HandleLineDirective, 0 },
{ "pragma", &HandlePragmaDirective, 0 },
- // TODO(tfoley): These are specific to GLSL, and probably
- // shouldn't be enabled for HLSL or Slang
- { "version", &handleGLSLVersionDirective, 0 },
- { "extension", &handleGLSLExtensionDirective, 0 },
-
{ nullptr, nullptr, 0 },
};
@@ -2270,7 +2200,7 @@ static void DefineMacro(
PreprocessorMacro* macro = CreateMacro(preprocessor);
- auto sourceManager = preprocessor->translationUnit->compileRequest->getSourceManager();
+ auto sourceManager = preprocessor->getSourceManager();
SourceFile* keyFile = sourceManager->createSourceFileWithString(pathInfo, key);
SourceFile* valueFile = sourceManager->createSourceFileWithString(pathInfo, value);
@@ -2280,10 +2210,10 @@ static void DefineMacro(
// Use existing `Lexer` to generate a token stream.
Lexer lexer;
- lexer.initialize(valueView, GetSink(preprocessor), getNamePool(preprocessor), sourceManager->getMemoryArena());
+ lexer.initialize(valueView, GetSink(preprocessor), preprocessor->getNamePool(), sourceManager->getMemoryArena());
macro->tokens = lexer.lexAllTokens();
- Name* keyName = preprocessor->translationUnit->compileRequest->getNamePool()->getName(key);
+ Name* keyName = preprocessor->getNamePool()->getName(key);
macro->nameAndLoc.name = keyName;
macro->nameAndLoc.loc = keyView->getRange().begin;
@@ -2321,11 +2251,13 @@ TokenList preprocessSource(
DiagnosticSink* sink,
IncludeHandler* includeHandler,
Dictionary<String, String> defines,
- TranslationUnitRequest* translationUnit)
+ Linkage* linkage,
+ Module* parentModule)
{
Preprocessor preprocessor;
InitializePreprocessor(&preprocessor, sink);
- preprocessor.translationUnit = translationUnit;
+ preprocessor.linkage = linkage;
+ preprocessor.parentModule = parentModule;
preprocessor.includeHandler = includeHandler;
for (auto p : defines)
@@ -2333,7 +2265,7 @@ TokenList preprocessSource(
DefineMacro(&preprocessor, p.Key, p.Value);
}
- SourceManager* sourceManager = translationUnit->compileRequest->getSourceManager();
+ SourceManager* sourceManager = linkage->getSourceManager();
SourceView* sourceView = sourceManager->createSourceView(file, nullptr);
diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h
index 4d02cb50b..6e8ac1c69 100644
--- a/source/slang/preprocessor.h
+++ b/source/slang/preprocessor.h
@@ -8,8 +8,9 @@
namespace Slang {
class DiagnosticSink;
+class Linkage;
+class Module;
class ModuleDecl;
-class TranslationUnitRequest;
// Callback interface for the preprocessor to use when looking
// for files in `#include` directives.
@@ -20,9 +21,6 @@ struct IncludeHandler
const String& pathIncludedFrom,
PathInfo& pathInfoOut) = 0;
- virtual SlangResult readFile(const String& path,
- ISlangBlob** blobOut) = 0;
-
virtual String simplifyPath(const String& path) = 0;
};
@@ -32,7 +30,8 @@ TokenList preprocessSource(
DiagnosticSink* sink,
IncludeHandler* includeHandler,
Dictionary<String, String> defines,
- TranslationUnitRequest* translationUnit);
+ Linkage* linkage,
+ Module* parentModule);
} // namespace Slang
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 2b0be98c9..b40900faf 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -585,10 +585,15 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType)
SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name)
{
- auto context = convert(reflection);
- auto compileRequest = context->targetRequest->compileRequest;
+ auto programLayout = convert(reflection);
+ auto program = programLayout->getProgram();
+
+ // TODO: We should extend this API to support getting error messages
+ // when type lookup fails.
+ //
+ Slang::DiagnosticSink sink;
- RefPtr<Type> result = compileRequest->getTypeFromString(name);
+ RefPtr<Type> result = program->getTypeFromString(name, &sink);
return (SlangReflectionType*)result.Ptr();
}
@@ -599,12 +604,13 @@ SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(
{
auto context = convert(reflection);
auto type = convert(inType);
- auto layoutContext = getInitialLayoutContextForTarget(context->targetRequest);
+ auto targetReq = context->getTargetReq();
+ auto layoutContext = getInitialLayoutContextForTarget(targetReq, context);
RefPtr<TypeLayout> result;
- if (context->targetRequest->typeLayouts.TryGetValue(type, result))
+ if (targetReq->getTypeLayouts().TryGetValue(type, result))
return (SlangReflectionTypeLayout*)result.Ptr();
result = CreateTypeLayout(layoutContext, type);
- context->targetRequest->typeLayouts[type] = result;
+ targetReq->getTypeLayouts()[type] = result;
return (SlangReflectionTypeLayout*)result.Ptr();
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 16dfe8618..7a5b58d07 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -60,6 +60,8 @@ Session::Session()
// Make sure our source manager is initialized
builtinSourceManager.initialize(nullptr, nullptr);
+ m_builtinLinkage = new Linkage(this);
+
// Initialize representations of some very basic types:
initializeTypes();
@@ -90,11 +92,12 @@ Session::Session()
struct IncludeHandlerImpl : IncludeHandler
{
- CompileRequest* request;
+ Linkage* linkage;
+ SearchDirectoryList* searchDirectories;
ISlangFileSystemExt* _getFileSystemExt()
{
- return request->fileSystemExt;
+ return linkage->getFileSystemExt();
}
SlangResult _findFile(SlangPathType fromPathType, const String& fromPath, const String& path, PathInfo& pathInfoOut)
@@ -153,18 +156,22 @@ struct IncludeHandlerImpl : IncludeHandler
}
// Search all the searchDirectories
- for (auto & dir : request->searchDirectories)
+ for(auto sd = searchDirectories; sd; sd = sd->parent)
{
- SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut);
- if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND)
+ for(auto& dir : sd->searchDirectories)
{
- return res;
+ SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut);
+ if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND)
+ {
+ return res;
+ }
}
}
return SLANG_E_NOT_FOUND;
}
+#if 0
virtual SlangResult readFile(const String& path,
ISlangBlob** blobOut) override
{
@@ -175,6 +182,7 @@ struct IncludeHandlerImpl : IncludeHandler
return SLANG_OK;
}
+#endif
virtual String simplifyPath(const String& path) override
{
@@ -192,9 +200,9 @@ struct IncludeHandlerImpl : IncludeHandler
//
-Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target)
+Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target)
{
- auto entryPointProfile = entryPoint->profile;
+ auto entryPointProfile = entryPoint->getProfile();
auto targetProfile = target->targetProfile;
// Depending on the target *format* we might have to restrict the
@@ -310,20 +318,13 @@ Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target
//
-CompileRequest::CompileRequest(Session* session)
- : mSession(session)
+Linkage::Linkage(Session* session)
+ : m_session(session)
+ , m_sourceManager(&m_defaultSourceManager)
{
getNamePool()->setRootNamePool(session->getRootNamePool());
- setSourceManager(&sourceManagerStorage);
-
- sourceManager->initialize(session->getBuiltinSourceManager(), nullptr);
-
- // Set all the default writers
- for (int i = 0; i < int(WriterChannel::CountOf); ++i)
- {
- setWriter(WriterChannel(i), nullptr);
- }
+ m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr);
setFileSystem(nullptr);
}
@@ -379,10 +380,61 @@ ComPtr<ISlangBlob> createRawBlob(void const* inData, size_t size)
}
//
+// TargetRequest
+//
+
+Session* TargetRequest::getSession()
+{
+ return linkage->getSession();
+}
MatrixLayoutMode TargetRequest::getDefaultMatrixLayoutMode()
{
- return compileRequest->getDefaultMatrixLayoutMode();
+ return linkage->getDefaultMatrixLayoutMode();
+}
+
+//
+// TranslationUnitRequest
+//
+
+TranslationUnitRequest::TranslationUnitRequest(
+ FrontEndCompileRequest* compileRequest)
+ : compileRequest(compileRequest)
+{
+ module = new Module(compileRequest->getLinkage());
+}
+
+
+Session* TranslationUnitRequest::getSession()
+{
+ return compileRequest->getSession();
+}
+
+NamePool* TranslationUnitRequest::getNamePool()
+{
+ return compileRequest->getNamePool();
+}
+
+SourceManager* TranslationUnitRequest::getSourceManager()
+{
+ return compileRequest->getSourceManager();
+}
+
+void TranslationUnitRequest::addSourceFile(SourceFile* sourceFile)
+{
+ m_sourceFiles.Add(sourceFile);
+
+ // We want to record that the compiled module has a dependency
+ // on the path of the source file, but we also need to account
+ // for cases where the user added a source string/blob without
+ // an associated path (so that the API passes along an empty
+ // string).
+ //
+ auto path = sourceFile->getPathInfo().foundPath;
+ if(path.Length())
+ {
+ getModule()->addFilePathDependency(path);
+ }
}
@@ -407,7 +459,7 @@ static ISlangWriter* _getDefaultWriter(WriterChannel chan)
}
}
-void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer)
+void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer)
{
// If the user passed in null, we will use the default writer on that channel
m_writers[int(chan)] = writer ? writer : _getDefaultWriter(chan);
@@ -415,20 +467,20 @@ void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer)
// For diagnostic output, if the user passes in nullptr, we set on mSink.writer as that enables buffering on DiagnosticSink
if (chan == WriterChannel::Diagnostic)
{
- mSink.writer = writer;
+ m_sink.writer = writer;
}
}
-SlangResult CompileRequest::loadFile(String const& path, ISlangBlob** outBlob)
+SlangResult Linkage::loadFile(String const& path, ISlangBlob** outBlob)
{
return fileSystemExt->loadFile(path.Buffer(), outBlob);
}
-RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope)
+RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope)
{
// Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc will be cleaned up
SourceManager localSourceManager;
- localSourceManager.initialize(sourceManager, nullptr);
+ localSourceManager.initialize(getSourceManager(), nullptr);
Slang::SourceFile* srcFile = localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr);
@@ -440,20 +492,20 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio
// Use RAII - to make sure everything is reset even if an exception is thrown.
struct ScopeReplaceSourceManager
{
- ScopeReplaceSourceManager(CompileRequest* request, SourceManager* replaceManager):
- m_request(request),
- m_originalSourceManager(request->getSourceManager())
+ ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager):
+ m_linkage(linkage),
+ m_originalSourceManager(linkage->getSourceManager())
{
- request->setSourceManager(replaceManager);
+ linkage->setSourceManager(replaceManager);
}
~ScopeReplaceSourceManager()
{
- m_request->setSourceManager(m_originalSourceManager);
+ m_linkage->setSourceManager(m_originalSourceManager);
}
private:
- CompileRequest* m_request;
+ Linkage* m_linkage;
SourceManager* m_originalSourceManager;
};
@@ -465,87 +517,131 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio
&sink,
nullptr,
Dictionary<String,String>(),
- translationUnit);
+ this,
+ nullptr);
- return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope);
+ return parseTypeFromSourceFile(
+ getSession(),
+ tokens, &sink, scope, getNamePool(), SourceLanguage::Slang);
}
-RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp);
-Type* CompileRequest::getTypeFromString(String typeStr)
+RefPtr<Type> checkProperType(
+ Linkage* linkage,
+ TypeExp typeExp,
+ DiagnosticSink* sink);
+
+Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
{
+ // If we've looked up this type name before,
+ // then we can re-use it.
+ //
RefPtr<Type> type;
- if (types.TryGetValue(typeStr, type))
+ if(m_types.TryGetValue(typeStr, type))
return type;
- auto translationUnit = translationUnits.First();
+
+ // Otherwise, we need to start looking in
+ // the modules that were directly or
+ // indirectly referenced.
+ //
+ // TODO: This `scopesToTry` idiom appears
+ // all over the code, and isn't really
+ // how we should be handling this kind of
+ // lookup at all.
+ //
List<RefPtr<Scope>> scopesToTry;
- for (auto tu : translationUnits)
- scopesToTry.Add(tu->SyntaxNode->scope);
- for (auto & module : loadedModulesList)
- scopesToTry.Add(module->moduleDecl->scope);
- // parse type name
- for (auto & s : scopesToTry)
- {
- RefPtr<Expr> typeExpr = parseTypeString(translationUnit,
+ for(auto module : getModuleDependencies())
+ scopesToTry.Add(module->getModuleDecl()->scope);
+
+ auto linkage = getLinkage();
+ for(auto& s : scopesToTry)
+ {
+ RefPtr<Expr> typeExpr = linkage->parseTypeString(
typeStr, s);
- type = checkProperType(translationUnit, TypeExp(typeExpr));
- if (type)
+ type = checkProperType(linkage, TypeExp(typeExpr), sink);
+ if(type)
break;
}
- if (type)
+ if( type )
{
- types[typeStr] = type;
+ m_types[typeStr] = type;
}
- return type.Ptr();
+ return type;
}
-void CompileRequest::parseTranslationUnit(
+CompileRequestBase::CompileRequestBase(
+ Linkage* linkage,
+ DiagnosticSink* sink)
+ : m_linkage(linkage)
+ , m_sink(sink)
+{}
+
+
+FrontEndCompileRequest::FrontEndCompileRequest(
+ Linkage* linkage,
+ DiagnosticSink* sink)
+ : CompileRequestBase(linkage, sink)
+{
+}
+
+void FrontEndCompileRequest::parseTranslationUnit(
TranslationUnitRequest* translationUnit)
{
IncludeHandlerImpl includeHandler;
- includeHandler.request = this;
+ includeHandler.linkage = getLinkage();
+ includeHandler.searchDirectories = &searchDirectories;
RefPtr<Scope> languageScope;
switch (translationUnit->sourceLanguage)
{
case SourceLanguage::HLSL:
- languageScope = mSession->hlslLanguageScope;
+ languageScope = getSession()->hlslLanguageScope;
break;
case SourceLanguage::Slang:
default:
- languageScope = mSession->slangLanguageScope;
+ languageScope = getSession()->slangLanguageScope;
break;
}
Dictionary<String, String> combinedPreprocessorDefinitions;
+ for(auto& def : getLinkage()->preprocessorDefinitions)
+ combinedPreprocessorDefinitions.Add(def.Key, def.Value);
for(auto& def : preprocessorDefinitions)
combinedPreprocessorDefinitions.Add(def.Key, def.Value);
for(auto& def : translationUnit->preprocessorDefinitions)
combinedPreprocessorDefinitions.Add(def.Key, def.Value);
+ auto module = translationUnit->getModule();
RefPtr<ModuleDecl> translationUnitSyntax = new ModuleDecl();
- translationUnit->SyntaxNode = translationUnitSyntax;
+ translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName;
+ translationUnitSyntax->module = module;
+ module->setModuleDecl(translationUnitSyntax);
- for (auto sourceFile : translationUnit->sourceFiles)
+ for (auto sourceFile : translationUnit->getSourceFiles())
{
auto tokens = preprocessSource(
sourceFile,
- &mSink,
+ getSink(),
&includeHandler,
combinedPreprocessorDefinitions,
- translationUnit);
+ getLinkage(),
+ module);
parseSourceFile(
translationUnit,
tokens,
- &mSink,
+ getSink(),
languageScope);
}
}
-void validateEntryPoints(CompileRequest*);
+RefPtr<Program> createUnspecializedProgram(
+ FrontEndCompileRequest* compileRequest);
-void CompileRequest::checkAllTranslationUnits()
+RefPtr<Program> createSpecializedProgram(
+ EndToEndCompileRequest* endToEndReq);
+
+void FrontEndCompileRequest::checkAllTranslationUnits()
{
// Iterate over all translation units and
// apply the semantic checking logic.
@@ -553,12 +649,9 @@ void CompileRequest::checkAllTranslationUnits()
{
checkTranslationUnit(translationUnit.Ptr());
}
-
- // Next, do follow-up validation on any entry points.
- validateEntryPoints(this);
}
-void CompileRequest::generateIR()
+void FrontEndCompileRequest::generateIR()
{
// Our task in this function is to generate IR code
// for all of the declarations in the translation
@@ -581,9 +674,9 @@ void CompileRequest::generateIR()
if (verifyDebugSerialization)
{
// Verify debug information
- if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, mSession, sourceManager, IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo)))
+ if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, getSession(), getSourceManager(), IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo)))
{
- mSink.diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed);
+ getSink()->diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed);
}
}
@@ -593,7 +686,7 @@ void CompileRequest::generateIR()
{
// Write IR out to serialData - copying over SourceLoc information directly
IRSerialWriter writer;
- writer.write(irModule, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData);
+ writer.write(irModule, getSourceManager(), IRSerialWriter::OptionFlag::RawSourceLocation, &serialData);
// Destroy irModule such that memory can be used for newly constructed read irReadModule
irModule = nullptr;
@@ -602,7 +695,7 @@ void CompileRequest::generateIR()
{
// Read IR back from serialData
IRSerialReader reader;
- reader.read(serialData, mSession, nullptr, irReadModule);
+ reader.read(serialData, getSession(), nullptr, irReadModule);
}
// Set irModule to the read module
@@ -610,12 +703,12 @@ void CompileRequest::generateIR()
}
// Set the module on the translation unit
- translationUnit->irModule = irModule;
+ translationUnit->getModule()->setIRModule(irModule);
}
}
// Try to infer a single common source language for a request
-static SourceLanguage inferSourceLanguage(CompileRequest* request)
+static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request)
{
SourceLanguage language = SourceLanguage::Unknown;
for (auto& translationUnit : request->translationUnits)
@@ -639,29 +732,115 @@ static SourceLanguage inferSourceLanguage(CompileRequest* request)
return language;
}
-SlangResult CompileRequest::executeActionsInner()
+SlangResult FrontEndCompileRequest::executeActionsInner()
{
- // Do some cleanup on settings specified by user.
- // In particular, we want to propagate flags from the overall request down to
- // each translation unit.
+ // We currently allow GlSL files on the command line so that we can
+ // drive our "pass-through" mode, but we really want to issue an error
+ // message if the user is seriously asking us to compile them.
for (auto& translationUnit : translationUnits)
{
- translationUnit->compileFlags |= compileFlags;
+ switch(translationUnit->sourceLanguage)
+ {
+ default:
+ break;
+
+ case SourceLanguage::GLSL:
+ getSink()->diagnose(SourceLoc(), Diagnostics::glslIsNotSupported);
+ return SLANG_FAIL;
+ }
+ }
+
+
+ // Parse everything from the input files requested
+ for (auto& translationUnit : translationUnits)
+ {
+ parseTranslationUnit(translationUnit.Ptr());
}
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ // Perform semantic checking on the whole collection
+ checkAllTranslationUnits();
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ // Look up all the entry points that are expected,
+ // and use them to populate the `program` member.
+ //
+ m_program = createUnspecializedProgram(this);
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0)
+ {
+ // Generate initial IR for all the translation
+ // units, if we are in a mode where IR is called for.
+ generateIR();
+ }
+
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ // Do parameter binding generation, for each compilation target.
+ //
+ for(auto targetReq : getLinkage()->targets)
+ {
+ auto targetProgram = m_program->getTargetProgram(targetReq);
+ targetProgram->getOrCreateLayout(getSink());
+ }
+ if (getSink()->GetErrorCount() != 0)
+ return SLANG_FAIL;
+
+ return SLANG_OK;
+}
+
+BackEndCompileRequest::BackEndCompileRequest(
+ Linkage* linkage,
+ DiagnosticSink* sink,
+ Program* program)
+ : CompileRequestBase(linkage, sink)
+ , m_program(program)
+{}
+
+EndToEndCompileRequest::EndToEndCompileRequest(
+ Session* session)
+ : m_session(session)
+{
+ m_linkage = new Linkage(session);
+
+ m_sink.sourceManager = m_linkage->getSourceManager();
+
+ // Set all the default writers
+ for (int i = 0; i < int(WriterChannel::CountOf); ++i)
+ {
+ setWriter(WriterChannel(i), nullptr);
+ }
+
+ m_frontEndReq = new FrontEndCompileRequest(getLinkage(), getSink());
+
+ m_backEndReq = new BackEndCompileRequest(getLinkage(), getSink());
+}
+
+SlangResult EndToEndCompileRequest::executeActionsInner()
+{
// If no code-generation target was specified, then try to infer one from the source language,
// just to make sure we can do something reasonable when invoked from the command line.
- if (targets.Count() == 0)
+ //
+ // TODO: This logic should be moved into `options.cpp` or somewhere else
+ // specific to the command-line tool.
+ //
+ if (getLinkage()->targets.Count() == 0)
{
- auto language = inferSourceLanguage(this);
+ auto language = inferSourceLanguage(getFrontEndReq());
switch (language)
{
case SourceLanguage::HLSL:
- addTarget(CodeGenTarget::DXBytecode);
+ getLinkage()->addTarget(CodeGenTarget::DXBytecode);
break;
case SourceLanguage::GLSL:
- addTarget(CodeGenTarget::SPIRV);
+ getLinkage()->addTarget(CodeGenTarget::SPIRV);
break;
default:
@@ -672,105 +851,117 @@ SlangResult CompileRequest::executeActionsInner()
// We only do parsing and semantic checking if we *aren't* doing
// a pass-through compilation.
//
- // Note that we *do* perform output generation as normal in pass-through mode.
if (passThrough == PassThroughMode::None)
{
- // We currently allow GlSL files on the command line so that we can
- // drive our "pass-through" mode, but we really want to issue an error
- // message if the user is seriously asking us to compile them.
- for (auto& translationUnit : translationUnits)
- {
- switch(translationUnit->sourceLanguage)
- {
- default:
- break;
-
- case SourceLanguage::GLSL:
- mSink.diagnose(SourceLoc(), Diagnostics::glslIsNotSupported);
- return SLANG_FAIL;
- }
- }
+ SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner());
+ }
+ // If command line specifies to skip codegen, we exit here.
+ // Note: this is a debugging option.
+ //
+ if (shouldSkipCodegen ||
+ ((getFrontEndReq()->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0))
+ {
+ // We will use the program (and matching layout information)
+ // that was computed in the front-end for all subsequent
+ // reflection queries, etc.
+ //
+ m_specializedProgram = getUnspecializedProgram();
- // Parse everything from the input files requested
- for (auto& translationUnit : translationUnits)
- {
- parseTranslationUnit(translationUnit.Ptr());
- }
- if (mSink.GetErrorCount() != 0)
- return SLANG_FAIL;
+ return SLANG_OK;
+ }
- // Perform semantic checking on the whole collection
- checkAllTranslationUnits();
- if (mSink.GetErrorCount() != 0)
+ // If codegen is enabled, we need to move along to
+ // apply any generic specialization that the user asked for.
+ //
+ if (passThrough == PassThroughMode::None)
+ {
+ m_specializedProgram = createSpecializedProgram(this);
+ if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
- if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0)
+ // For each code generation target, we will generate specialized
+ // parameter binding information (taking global generic
+ // arguments into account at this time).
+ //
+ for (auto targetReq : getLinkage()->targets)
{
- // Generate initial IR for all the translation
- // units, if we are in a mode where IR is called for.
- generateIR();
+ auto targetProgram = m_specializedProgram->getTargetProgram(targetReq);
+ targetProgram->getOrCreateLayout(getSink());
}
-
- if (mSink.GetErrorCount() != 0)
+ if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
-
- // For each code generation target generate
- // parameter binding information.
- // This step is done globally, because all translation
- // units and entry points need to agree on where
- // parameters are allocated.
- for (auto targetReq : targets)
+ }
+ else
+ {
+ // We need to create dummy `EntryPoint` objects
+ // to make sure that the logic in `generateOutput`
+ // sees something worth processing.
+ //
+ auto specializedProgram = new Program(getLinkage());
+ m_specializedProgram = specializedProgram;
+ for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs())
{
- generateParameterBindings(targetReq);
- if (mSink.GetErrorCount() != 0)
- return SLANG_FAIL;
+ RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough(
+ entryPointReq->getName(),
+ entryPointReq->getProfile());
+
+ specializedProgram->addEntryPoint(entryPoint);
}
}
- // If command line specifies to skip codegen, we exit here.
- // Note: this is a debugging option.
- if (shouldSkipCodegen ||
- ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0))
- return SLANG_OK;
-
// Generate output code, in whatever format was requested
+ getBackEndReq()->setProgram(getSpecializedProgram());
generateOutput(this);
- if (mSink.GetErrorCount() != 0)
+ if (getSink()->GetErrorCount() != 0)
return SLANG_FAIL;
return SLANG_OK;
}
// Act as expected of the API-based compiler
-SlangResult CompileRequest::executeActions()
+SlangResult EndToEndCompileRequest::executeActions()
{
SlangResult res = executeActionsInner();
- mDiagnosticOutput = mSink.outputBuffer.ProduceString();
+ mDiagnosticOutput = getSink()->outputBuffer.ProduceString();
return res;
}
-int CompileRequest::addTranslationUnit(SourceLanguage language, String const&)
+int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName)
{
UInt result = translationUnits.Count();
- RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest();
+ RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this);
translationUnit->compileRequest = this;
translationUnit->sourceLanguage = SourceLanguage(language);
+ translationUnit->moduleName = moduleName;
+
translationUnits.Add(translationUnit);
return (int) result;
}
-void CompileRequest::addTranslationUnitSourceFile(
+int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language)
+{
+ // We want to ensure that symbols defined in different translation
+ // units get unique mangled names, so that we can, e.g., tell apart
+ // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`,
+ // even when they are being compiled together.
+ //
+ String generatedName = "tu";
+ generatedName.append(translationUnits.Count());
+ return addTranslationUnit(language, getNamePool()->getName(generatedName));
+}
+
+void FrontEndCompileRequest::addTranslationUnitSourceFile(
int translationUnitIndex,
SourceFile* sourceFile)
{
- translationUnits[translationUnitIndex]->sourceFiles.Add(sourceFile);
+ translationUnits[translationUnitIndex]->addSourceFile(sourceFile);
}
-void CompileRequest::addTranslationUnitSourceBlob(
+void FrontEndCompileRequest::addTranslationUnitSourceBlob(
int translationUnitIndex,
String const& path,
ISlangBlob* sourceBlob)
@@ -781,7 +972,7 @@ void CompileRequest::addTranslationUnitSourceBlob(
addTranslationUnitSourceFile(translationUnitIndex, sourceFile);
}
-void CompileRequest::addTranslationUnitSourceString(
+void FrontEndCompileRequest::addTranslationUnitSourceString(
int translationUnitIndex,
String const& path,
String const& source)
@@ -792,7 +983,7 @@ void CompileRequest::addTranslationUnitSourceString(
addTranslationUnitSourceFile(translationUnitIndex, sourceFile);
}
-void CompileRequest::addTranslationUnitSourceFile(
+void FrontEndCompileRequest::addTranslationUnitSourceFile(
int translationUnitIndex,
String const& path)
{
@@ -809,7 +1000,7 @@ void CompileRequest::addTranslationUnitSourceFile(
if(SLANG_FAILED(result))
{
// Emit a diagnostic!
- mSink.diagnose(
+ getSink()->diagnose(
SourceLoc(),
Diagnostics::cannotOpenFile,
path);
@@ -820,36 +1011,51 @@ void CompileRequest::addTranslationUnitSourceFile(
translationUnitIndex,
path,
sourceBlob);
+}
+
+int FrontEndCompileRequest::addEntryPoint(
+ int translationUnitIndex,
+ String const& name,
+ Profile entryPointProfile)
+{
+ auto translationUnitReq = translationUnits[translationUnitIndex];
+
+ UInt result = m_entryPointReqs.Count();
+
+ RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest(
+ this,
+ translationUnitIndex,
+ getNamePool()->getName(name),
+ entryPointProfile);
+
+ m_entryPointReqs.Add(entryPointReq);
+// translationUnitReq->entryPoints.Add(entryPointReq);
- mDependencyFilePaths.Add(path);
+ return int(result);
}
-int CompileRequest::addEntryPoint(
+int EndToEndCompileRequest::addEntryPoint(
int translationUnitIndex,
String const& name,
Profile entryPointProfile,
List<String> const & genericTypeNames)
{
- RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest();
- entryPoint->compileRequest = this;
- entryPoint->name = getNamePool()->getName(name);
- entryPoint->profile = entryPointProfile;
- entryPoint->translationUnitIndex = translationUnitIndex;
+ getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile);
+
+ EntryPointInfo entryPointInfo;
for (auto typeName : genericTypeNames)
- entryPoint->genericArgStrings.Add(typeName);
- auto translationUnit = translationUnits[translationUnitIndex].Ptr();
- translationUnit->entryPoints.Add(entryPoint);
+ entryPointInfo.genericArgStrings.Add(typeName);
UInt result = entryPoints.Count();
- entryPoints.Add(entryPoint);
+ entryPoints.Add(_Move(entryPointInfo));
return (int) result;
}
-UInt CompileRequest::addTarget(
+UInt Linkage::addTarget(
CodeGenTarget target)
{
RefPtr<TargetRequest> targetReq = new TargetRequest();
- targetReq->compileRequest = this;
+ targetReq->linkage = this;
targetReq->target = target;
UInt result = targets.Count();
@@ -857,15 +1063,16 @@ UInt CompileRequest::addTarget(
return (int) result;
}
-void CompileRequest::loadParsedModule(
- RefPtr<TranslationUnitRequest> const& translationUnit,
- Name* name,
- const PathInfo& pathInfo)
+void Linkage::loadParsedModule(
+ RefPtr<TranslationUnitRequest> translationUnit,
+ Name* name,
+ const PathInfo& pathInfo)
{
// Note: we add the loaded module to our name->module listing
// before doing semantic checking, so that if it tries to
// recursively `import` itself, we can detect it.
- RefPtr<LoadedModule> loadedModule = new LoadedModule();
+ //
+ RefPtr<Module> loadedModule = translationUnit->getModule();
// Get a path
String mostUniqueIdentity = pathInfo.getMostUniqueIdentity();
@@ -874,12 +1081,11 @@ void CompileRequest::loadParsedModule(
mapPathToLoadedModule.Add(mostUniqueIdentity, loadedModule);
mapNameToLoadedModules.Add(name, loadedModule);
- int errorCountBefore = mSink.GetErrorCount();
- checkTranslationUnit(translationUnit.Ptr());
- int errorCountAfter = mSink.GetErrorCount();
+ auto sink = translationUnit->compileRequest->getSink();
- RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode;
- loadedModule->moduleDecl = moduleDecl;
+ int errorCountBefore = sink->GetErrorCount();
+ checkTranslationUnit(translationUnit.Ptr());
+ int errorCountAfter = sink->GetErrorCount();
if (errorCountAfter != errorCountBefore)
{
@@ -890,39 +1096,56 @@ void CompileRequest::loadParsedModule(
// If we didn't run into any errors, then try to generate
// IR code for the imported module.
SLANG_ASSERT(errorCountAfter == 0);
- loadedModule->irModule = generateIRForTranslationUnit(translationUnit);
+ loadedModule->setIRModule(generateIRForTranslationUnit(translationUnit));
}
loadedModulesList.Add(loadedModule);
}
-RefPtr<ModuleDecl> CompileRequest::loadModule(
+Module* Linkage::loadModule(String const& name)
+{
+ // TODO: We either need to have a diagnostics sink
+ // get passed into this operation, or associate
+ // one with the linkage.
+ //
+ DiagnosticSink* sink = nullptr;
+ return findOrImportModule(
+ getNamePool()->getName(name),
+ SourceLoc(),
+ sink);
+}
+
+
+RefPtr<Module> Linkage::loadModule(
Name* name,
const PathInfo& filePathInfo,
ISlangBlob* sourceBlob,
- SourceLoc const& srcLoc)
+ SourceLoc const& srcLoc,
+ DiagnosticSink* sink)
{
- RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest();
- translationUnit->compileRequest = this;
+ RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, sink);
- // We don't want to use the same options that the user specified
- // for loading modules on-demand. In particular, we always want
- // semantic checking to be enabled.
- //
- // TODO: decide which options, if any, should be inherited.
- translationUnit->compileFlags = 0;
+ RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq);
+ translationUnit->compileRequest = frontEndReq;
+ translationUnit->moduleName = name;
+
+ auto module = translationUnit->getModule();
+
+ ModuleBeingImportedRAII moduleBeingImported(
+ this,
+ module);
// Create with the 'friendly' name
SourceFile* sourceFile = getSourceManager()->createSourceFileWithBlob(filePathInfo, sourceBlob);
- translationUnit->sourceFiles.Add(sourceFile);
+ translationUnit->addSourceFile(sourceFile);
- int errorCountBefore = mSink.GetErrorCount();
- parseTranslationUnit(translationUnit.Ptr());
- int errorCountAfter = mSink.GetErrorCount();
+ int errorCountBefore = sink->GetErrorCount();
+ frontEndReq->parseTranslationUnit(translationUnit);
+ int errorCountAfter = sink->GetErrorCount();
if( errorCountAfter != errorCountBefore )
{
- mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule);
+ sink->diagnose(srcLoc, Diagnostics::errorInImportedModule);
}
if (errorCountAfter)
{
@@ -935,38 +1158,57 @@ RefPtr<ModuleDecl> CompileRequest::loadModule(
name,
filePathInfo);
- errorCountAfter = mSink.GetErrorCount();
+ errorCountAfter = sink->GetErrorCount();
if (errorCountAfter != errorCountBefore)
{
- mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule);
+ sink->diagnose(srcLoc, Diagnostics::errorInImportedModule);
// Something went wrong during the parsing, so we should bail out.
return nullptr;
}
- return translationUnit->SyntaxNode;
+ return module;
+}
+
+bool Linkage::isBeingImported(Module* module)
+{
+ for(auto ii = m_modulesBeingImported; ii; ii = ii->next)
+ {
+ if(module == ii->module)
+ return true;
+ }
+ return false;
}
-RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
+RefPtr<Module> Linkage::findOrImportModule(
Name* name,
- SourceLoc const& loc)
+ SourceLoc const& loc,
+ DiagnosticSink* sink)
{
// Have we already loaded a module matching this name?
- // If so, return it.
+ //
RefPtr<LoadedModule> loadedModule;
if (mapNameToLoadedModules.TryGetValue(name, loadedModule))
{
+ // If the map shows a null module having been loaded,
+ // then that means there was a prior load attempt,
+ // but it failed, so we won't bother trying again.
+ //
if (!loadedModule)
return nullptr;
- if (!loadedModule->moduleDecl)
+ // If state shows us that the module is already being
+ // imported deeper on the call stack, then we've
+ // hit a recursive case, and that is an error.
+ //
+ if(isBeingImported(loadedModule))
{
// We seem to be in the middle of loading this module
- mSink.diagnose(loc, Diagnostics::recursiveModuleImport, name);
+ sink->diagnose(loc, Diagnostics::recursiveModuleImport, name);
return nullptr;
}
- return loadedModule->moduleDecl;
+ return loadedModule;
}
// Derive a file name for the module, by taking the given
@@ -991,7 +1233,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
// using our ordinary include-handling logic.
IncludeHandlerImpl includeHandler;
- includeHandler.request = this;
+ includeHandler.linkage = this;
+ includeHandler.searchDirectories = &searchDirectories;
// Get the original path info
PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual);
@@ -1000,20 +1243,20 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
// We have to load via the found path - as that is how file was originally loaded
if (SLANG_FAILED(includeHandler.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo)))
{
- this->mSink.diagnose(loc, Diagnostics::cannotFindFile, fileName);
+ sink->diagnose(loc, Diagnostics::cannotFindFile, fileName);
mapNameToLoadedModules[name] = nullptr;
return nullptr;
}
// Maybe this was loaded previously at a different relative name?
if (mapPathToLoadedModule.TryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule))
- return loadedModule->moduleDecl;
+ return loadedModule;
// Try to load it
ComPtr<ISlangBlob> fileContents;
- if (SLANG_FAILED(includeHandler.readFile(filePathInfo.foundPath, fileContents.writeRef())))
+ if(SLANG_FAILED(getFileSystemExt()->loadFile(filePathInfo.foundPath.Buffer(), fileContents.writeRef())))
{
- this->mSink.diagnose(loc, Diagnostics::cannotOpenFile, fileName);
+ sink->diagnose(loc, Diagnostics::cannotOpenFile, fileName);
mapNameToLoadedModules[name] = nullptr;
return nullptr;
}
@@ -1024,26 +1267,159 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
name,
filePathInfo,
fileContents,
- loc);
+ loc,
+ sink);
}
-Decl * CompileRequest::lookupGlobalDecl(Name * name)
+//
+// ModuleDependencyList
+//
+
+void ModuleDependencyList::addDependency(Module* module)
{
- Decl* resultDecl = nullptr;
- for (auto module : loadedModulesList)
+ // If we depend on a module, then we depend on everything it depends on.
+ //
+ // Note: We are processing these sub-depenencies before adding the
+ // `module` itself, so that in the common case a module will always
+ // appear *after* everything it depends on.
+ //
+ // However, this rule is being violated in the compiler right now because
+ // the modules for hte top-level translation units of a compile request
+ // will be added to the list first (using `addLeafDependency`) to
+ // maintain compatibility with old behavior. This may be fixed later.
+ //
+ for(auto subDependency : module->getModuleDependencyList())
{
- if (module->moduleDecl->memberDictionary.TryGetValue(name, resultDecl))
- break;
+ _addDependency(subDependency);
+ }
+ _addDependency(module);
+}
+
+void ModuleDependencyList::addLeafDependency(Module* module)
+{
+ _addDependency(module);
+}
+
+void ModuleDependencyList::_addDependency(Module* module)
+{
+ if(m_moduleSet.Contains(module))
+ return;
+
+ m_moduleList.Add(module);
+ m_moduleSet.Add(module);
+}
+
+//
+// FilePathDependencyList
+//
+
+void FilePathDependencyList::addDependency(String const& path)
+{
+ if(m_filePathSet.Contains(path))
+ return;
+
+ m_filePathList.Add(path);
+ m_filePathSet.Add(path);
+}
+
+void FilePathDependencyList::addDependency(Module* module)
+{
+ for(auto& path : module->getFilePathDependencyList())
+ {
+ addDependency(path);
}
- for (auto transUnit : translationUnits)
+}
+
+
+
+//
+// Module
+//
+
+Module::Module(Linkage* linkage)
+ : m_linkage(linkage)
+{}
+
+
+void Module::addModuleDependency(Module* module)
+{
+ m_moduleDependencyList.addDependency(module);
+ m_filePathDependencyList.addDependency(module);
+}
+
+void Module::addFilePathDependency(String const& path)
+{
+ m_filePathDependencyList.addDependency(path);
+}
+
+// Program
+
+Program::Program(Linkage* linkage)
+ : m_linkage(linkage)
+{}
+
+void Program::addReferencedModule(Module* module)
+{
+ m_moduleDependencyList.addDependency(module);
+ m_filePathDependencyList.addDependency(module);
+}
+
+void Program::addReferencedLeafModule(Module* module)
+{
+ m_moduleDependencyList.addLeafDependency(module);
+ m_filePathDependencyList.addDependency(module);
+}
+
+void Program::addEntryPoint(EntryPoint* entryPoint)
+{
+ m_entryPoints.Add(entryPoint);
+
+ for(auto module : entryPoint->getModuleDependencies())
{
- if (transUnit->SyntaxNode->memberDictionary.TryGetValue(name, resultDecl))
- break;
+ addReferencedModule(module);
}
- return resultDecl;
}
-void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc)
+RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink)
+{
+ if(!m_irModule)
+ {
+ m_irModule = generateIRForProgram(
+ m_linkage->getSession(),
+ this,
+ sink);
+ }
+ return m_irModule;
+}
+
+
+TargetProgram* Program::getTargetProgram(TargetRequest* target)
+{
+ RefPtr<TargetProgram> targetProgram;
+ if(!m_targetPrograms.TryGetValue(target, targetProgram))
+ {
+ targetProgram = new TargetProgram(this, target);
+ m_targetPrograms[target] = targetProgram;
+ }
+ return targetProgram;
+}
+
+//
+// TargetProgram
+//
+
+TargetProgram::TargetProgram(
+ Program* program,
+ TargetRequest* targetReq)
+ : m_program(program)
+ , m_targetReq(targetReq)
+{
+ m_entryPointResults.SetSize(program->getEntryPoints().Count());
+}
+
+//
+
+void DiagnosticSink::noteInternalErrorLoc(SourceLoc const& loc)
{
// Don't consider invalid source locations.
if(!loc.isValid())
@@ -1054,14 +1430,19 @@ void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc)
// code might have confused the compiler.
if(internalErrorLocsNoted == 0)
{
- mSink.diagnose(loc, Diagnostics::noteLocationOfInternalError);
+ diagnose(loc, Diagnostics::noteLocationOfInternalError);
}
internalErrorLocsNoted++;
}
+Session* CompileRequestBase::getSession()
+{
+ return getLinkage()->getSession();
+}
+
static const Slang::Guid IID_ISlangFileSystemExt = SLANG_UUID_ISlangFileSystemExt;
-void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem)
+void Linkage::setFileSystem(ISlangFileSystem* inFileSystem)
{
// Set the fileSystem
fileSystem = inFileSystem;
@@ -1085,15 +1466,16 @@ void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem)
}
// Set the file system used on the source manager
- sourceManager->setFileSystemExt(fileSystemExt);
+ getSourceManager()->setFileSystemExt(fileSystemExt);
}
-RefPtr<ModuleDecl> findOrImportModule(
- CompileRequest* request,
+RefPtr<Module> findOrImportModule(
+ Linkage* linkage,
Name* name,
- SourceLoc const& loc)
+ SourceLoc const& loc,
+ DiagnosticSink* sink)
{
- return request->findOrImportModule(name, loc);
+ return linkage->findOrImportModule(name, loc, sink);
}
void Session::addBuiltinSource(
@@ -1101,30 +1483,34 @@ void Session::addBuiltinSource(
String const& path,
String const& source)
{
- RefPtr<CompileRequest> compileRequest = new CompileRequest(this);
- compileRequest->setSourceManager(getBuiltinSourceManager());
+ DiagnosticSink sink;
+ RefPtr<FrontEndCompileRequest> compileRequest = new FrontEndCompileRequest(
+ m_builtinLinkage,
+ &sink);
- auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, path);
+ Name* moduleName = getNamePool()->getName(path);
+ auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName);
compileRequest->addTranslationUnitSourceString(
translationUnitIndex,
path,
source);
- SlangResult res = compileRequest->executeActions();
+ SlangResult res = compileRequest->executeActionsInner();
if (SLANG_FAILED(res))
{
- fprintf(stderr, "%s", compileRequest->mDiagnosticOutput.Buffer());
+ char const* diagnostics = sink.outputBuffer.Buffer();
+ fprintf(stderr, "%s", diagnostics);
#ifdef _WIN32
- OutputDebugStringA(compileRequest->mDiagnosticOutput.Buffer());
+ OutputDebugStringA(diagnostics);
#endif
SLANG_UNEXPECTED("error in Slang standard library");
}
// Extract the AST for the code we just parsed
- auto syntax = compileRequest->translationUnits[translationUnitIndex]->SyntaxNode;
+ auto syntax = compileRequest->translationUnits[translationUnitIndex]->getModuleDecl();
// HACK(tfoley): mark all declarations in the "stdlib" so
// that we can detect them later (e.g., so we don't emit them)
@@ -1176,19 +1562,37 @@ Session::~Session()
// implementation of C interface
-#define SESSION(x) reinterpret_cast<Slang::Session *>(x)
-#define REQ(x) reinterpret_cast<Slang::CompileRequest*>(x)
+static SlangSession* convert(Slang::Session* session)
+{ return reinterpret_cast<SlangSession*>(session); }
+
+static Slang::Session* convert(SlangSession* session)
+{ return reinterpret_cast<Slang::Session*>(session); }
+
+static SlangCompileRequest* convert(Slang::EndToEndCompileRequest* request)
+{ return reinterpret_cast<SlangCompileRequest*>(request); }
+
+static Slang::EndToEndCompileRequest* convert(SlangCompileRequest* request)
+{ return reinterpret_cast<Slang::EndToEndCompileRequest*>(request); }
+
+static SlangLinkage* convert(Slang::Linkage* linkage)
+{ return reinterpret_cast<SlangLinkage*>(linkage); }
+
+static Slang::Linkage* convert(SlangLinkage* linkage)
+{ return reinterpret_cast<Slang::Linkage*>(linkage); }
+
+static SlangModule* convert(Slang::Module* module)
+{ return reinterpret_cast<SlangModule*>(module); }
SLANG_API SlangSession* spCreateSession(const char*)
{
- return reinterpret_cast<SlangSession *>(new Slang::Session());
+ return convert(new Slang::Session());
}
SLANG_API void spDestroySession(
SlangSession* session)
{
if(!session) return;
- delete SESSION(session);
+ delete convert(session);
}
SLANG_API void spAddBuiltins(
@@ -1196,7 +1600,7 @@ SLANG_API void spAddBuiltins(
char const* sourcePath,
char const* sourceString)
{
- auto s = SESSION(session);
+ auto s = convert(session);
s->addBuiltinSource(
// TODO(tfoley): Add ability to directly new builtins to the approriate scope
@@ -1210,7 +1614,7 @@ SLANG_API void spSessionSetSharedLibraryLoader(
SlangSession* session,
ISlangSharedLibraryLoader* loader)
{
- auto s = SESSION(session);
+ auto s = convert(session);
if (!loader)
{
@@ -1237,7 +1641,7 @@ SLANG_API void spSessionSetSharedLibraryLoader(
SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader(
SlangSession* session)
{
- auto s = SESSION(session);
+ auto s = convert(session);
return (s->sharedLibraryLoader == Slang::DefaultSharedLibraryLoader::getSingleton()) ? nullptr : s->sharedLibraryLoader.get();
}
@@ -1245,7 +1649,7 @@ SLANG_API SlangResult spSessionCheckCompileTargetSupport(
SlangSession* session,
SlangCompileTarget target)
{
- auto s = SESSION(session);
+ auto s = convert(session);
return Slang::checkCompileTargetSupport(s, Slang::CodeGenTarget(target));
}
@@ -1253,16 +1657,45 @@ SLANG_API SlangResult spSessionCheckPassThroughSupport(
SlangSession* session,
SlangPassThrough passThrough)
{
- auto s = SESSION(session);
+ auto s = convert(session);
return Slang::checkExternalCompilerSupport(s, Slang::PassThroughMode(passThrough));
}
+
+SLANG_API SlangLinkage* spCreateLinkage(
+ SlangSession* session)
+{
+ auto s = convert(session);
+ auto linkage = new Slang::Linkage(s);
+ return convert(linkage);
+}
+
+SLANG_API void spDestroyLinkage(
+ SlangLinkage* linkage)
+{
+ if(!linkage) return;
+ auto lnk = convert(linkage);
+ delete lnk;
+}
+
+SLANG_API SlangModule* spLoadModule(
+ SlangLinkage* linkage,
+ char const* moduleName)
+{
+ if(!linkage) return nullptr;
+ auto lnk = convert(linkage);
+
+ auto mod = lnk->loadModule(moduleName);
+ return convert(mod);
+}
+
+
SLANG_API SlangCompileRequest* spCreateCompileRequest(
SlangSession* session)
{
- auto s = SESSION(session);
- auto req = new Slang::CompileRequest(s);
- return reinterpret_cast<SlangCompileRequest*>(req);
+ auto s = convert(session);
+ auto req = new Slang::EndToEndCompileRequest(s);
+ return convert(req);
}
/*!
@@ -1272,7 +1705,7 @@ SLANG_API void spDestroyCompileRequest(
SlangCompileRequest* request)
{
if(!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
delete req;
}
@@ -1281,21 +1714,21 @@ SLANG_API void spSetFileSystem(
ISlangFileSystem* fileSystem)
{
if(!request) return;
- REQ(request)->setFileSystem(fileSystem);
+ convert(request)->getLinkage()->setFileSystem(fileSystem);
}
SLANG_API void spSetCompileFlags(
SlangCompileRequest* request,
SlangCompileFlags flags)
{
- REQ(request)->compileFlags = flags;
+ convert(request)->getFrontEndReq()->compileFlags = flags;
}
SLANG_API void spSetDumpIntermediates(
SlangCompileRequest* request,
int enable)
{
- REQ(request)->shouldDumpIntermediates = enable != 0;
+ convert(request)->getBackEndReq()->shouldDumpIntermediates = enable != 0;
}
SLANG_API void spSetLineDirectiveMode(
@@ -1304,13 +1737,13 @@ SLANG_API void spSetLineDirectiveMode(
{
// TODO: validation
- REQ(request)->lineDirectiveMode = Slang::LineDirectiveMode(mode);
+ convert(request)->getBackEndReq()->lineDirectiveMode = Slang::LineDirectiveMode(mode);
}
SLANG_API void spSetCommandLineCompilerMode(
SlangCompileRequest* request)
{
- REQ(request)->isCommandLineCompile = true;
+ convert(request)->isCommandLineCompile = true;
}
@@ -1318,17 +1751,19 @@ SLANG_API void spSetCodeGenTarget(
SlangCompileRequest* request,
SlangCompileTarget target)
{
- auto req = REQ(request);
- req->targets.Clear();
- req->addTarget(Slang::CodeGenTarget(target));
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->targets.Clear();
+ linkage->addTarget(Slang::CodeGenTarget(target));
}
SLANG_API int spAddCodeGenTarget(
SlangCompileRequest* request,
SlangCompileTarget target)
{
- auto req = REQ(request);
- return (int) req->addTarget(Slang::CodeGenTarget(target));
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ return (int) linkage->addTarget(Slang::CodeGenTarget(target));
}
SLANG_API void spSetTargetProfile(
@@ -1336,8 +1771,9 @@ SLANG_API void spSetTargetProfile(
int targetIndex,
SlangProfileID profile)
{
- auto req = REQ(request);
- req->targets[targetIndex]->targetProfile = Slang::Profile(profile);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->targets[targetIndex]->targetProfile = Slang::Profile(profile);
}
SLANG_API void spSetTargetFlags(
@@ -1345,8 +1781,9 @@ SLANG_API void spSetTargetFlags(
int targetIndex,
SlangTargetFlags flags)
{
- auto req = REQ(request);
- req->targets[targetIndex]->targetFlags = flags;
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->targets[targetIndex]->targetFlags = flags;
}
SLANG_API void spSetTargetFloatingPointMode(
@@ -1354,16 +1791,18 @@ SLANG_API void spSetTargetFloatingPointMode(
int targetIndex,
SlangFloatingPointMode mode)
{
- auto req = REQ(request);
- req->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode);
}
SLANG_API void spSetMatrixLayoutMode(
SlangCompileRequest* request,
SlangMatrixLayoutMode mode)
{
- auto req = REQ(request);
- req->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode);
}
SLANG_API void spSetTargetMatrixLayoutMode(
@@ -1380,7 +1819,7 @@ SLANG_API void spSetOutputContainerFormat(
SlangCompileRequest* request,
SlangContainerFormat format)
{
- auto req = REQ(request);
+ auto req = convert(request);
req->containerFormat = Slang::ContainerFormat(format);
}
@@ -1389,7 +1828,7 @@ SLANG_API void spSetPassThrough(
SlangCompileRequest* request,
SlangPassThrough passThrough)
{
- REQ(request)->passThrough = Slang::PassThroughMode(passThrough);
+ convert(request)->passThrough = Slang::PassThroughMode(passThrough);
}
SLANG_API void spSetDiagnosticCallback(
@@ -1400,7 +1839,7 @@ SLANG_API void spSetDiagnosticCallback(
using namespace Slang;
if(!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole));
req->setWriter(WriterChannel::Diagnostic, writer);
@@ -1412,7 +1851,7 @@ SLANG_API void spSetWriter(
ISlangWriter* writer)
{
if (!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
req->setWriter(Slang::WriterChannel(chan), writer);
}
@@ -1422,15 +1861,17 @@ SLANG_API ISlangWriter* spGetWriter(
SlangWriterChannel chan)
{
if (!request) return nullptr;
- auto req = REQ(request);
+ auto req = convert(request);
return req->getWriter(Slang::WriterChannel(chan));
}
SLANG_API void spAddSearchPath(
- SlangCompileRequest* request,
- const char* path)
+ SlangCompileRequest* request,
+ const char* path)
{
- REQ(request)->searchDirectories.Add(Slang::SearchDirectory(path));
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->searchDirectories.searchDirectories.Add(Slang::SearchDirectory(path));
}
SLANG_API void spAddPreprocessorDefine(
@@ -1438,25 +1879,27 @@ SLANG_API void spAddPreprocessorDefine(
const char* key,
const char* value)
{
- REQ(request)->preprocessorDefinitions[key] = value;
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ linkage->preprocessorDefinitions[key] = value;
}
SLANG_API char const* spGetDiagnosticOutput(
SlangCompileRequest* request)
{
if(!request) return 0;
- auto req = REQ(request);
+ auto req = convert(request);
return req->mDiagnosticOutput.begin();
}
SLANG_API SlangResult spGetDiagnosticOutputBlob(
- SlangCompileRequest* request,
- ISlangBlob** outBlob)
+ SlangCompileRequest* request,
+ ISlangBlob** outBlob)
{
if(!request) return SLANG_ERROR_INVALID_PARAMETER;
if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER;
- auto req = REQ(request);
+ auto req = convert(request);
if(!req->diagnosticOutputBlob)
{
@@ -1475,11 +1918,13 @@ SLANG_API int spAddTranslationUnit(
SlangSourceLanguage language,
char const* name)
{
- auto req = REQ(request);
+ SLANG_UNUSED(name);
+
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
- return req->addTranslationUnit(
- Slang::SourceLanguage(language),
- name ? name : "");
+ return frontEndReq->addTranslationUnit(
+ Slang::SourceLanguage(language));
}
SLANG_API void spTranslationUnit_addPreprocessorDefine(
@@ -1488,10 +1933,10 @@ SLANG_API void spTranslationUnit_addPreprocessorDefine(
const char* key,
const char* value)
{
- auto req = REQ(request);
-
- req->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value;
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
+ frontEndReq->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value;
}
SLANG_API void spAddTranslationUnitSourceFile(
@@ -1500,12 +1945,13 @@ SLANG_API void spAddTranslationUnitSourceFile(
char const* path)
{
if(!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
if(!path) return;
if(translationUnitIndex < 0) return;
- if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return;
+ if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return;
- req->addTranslationUnitSourceFile(
+ frontEndReq->addTranslationUnitSourceFile(
translationUnitIndex,
path);
}
@@ -1533,14 +1979,15 @@ SLANG_API void spAddTranslationUnitSourceStringSpan(
char const* sourceEnd)
{
if(!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
if(!sourceBegin) return;
if(translationUnitIndex < 0) return;
- if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return;
+ if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return;
if(!path) path = "";
- req->addTranslationUnitSourceString(
+ frontEndReq->addTranslationUnitSourceString(
translationUnitIndex,
path,
Slang::UnownedStringSlice(sourceBegin, sourceEnd));
@@ -1553,14 +2000,15 @@ SLANG_API void spAddTranslationUnitSourceBlob(
ISlangBlob* sourceBlob)
{
if(!request) return;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
if(!sourceBlob) return;
if(translationUnitIndex < 0) return;
- if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return;
+ if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return;
if(!path) path = "";
- req->addTranslationUnitSourceBlob(
+ frontEndReq->addTranslationUnitSourceBlob(
translationUnitIndex,
path,
sourceBlob);
@@ -1584,17 +2032,13 @@ SLANG_API int spAddEntryPoint(
char const* name,
SlangStage stage)
{
- if(!request) return -1;
- auto req = REQ(request);
- if(!name) return -1;
- if(translationUnitIndex < 0) return -1;
- if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1;
-
- return req->addEntryPoint(
+ return spAddEntryPointEx(
+ request,
translationUnitIndex,
name,
- Slang::Profile(Slang::Stage(stage)),
- Slang::List<Slang::String>());
+ stage,
+ 0,
+ nullptr);
}
SLANG_API int spAddEntryPointEx(
@@ -1606,10 +2050,11 @@ SLANG_API int spAddEntryPointEx(
char const ** genericParamTypeNames)
{
if (!request) return -1;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
if (!name) return -1;
if (translationUnitIndex < 0) return -1;
- if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1;
+ if (Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return -1;
Slang::List<Slang::String> typeNames;
for (int i = 0; i < genericParamTypeNameCount; i++)
typeNames.Add(genericParamTypeNames[i]);
@@ -1620,12 +2065,28 @@ SLANG_API int spAddEntryPointEx(
typeNames);
}
+SLANG_API SlangResult spSetGlobalGenericArgs(
+ SlangCompileRequest* request,
+ int genericArgCount,
+ char const** genericArgs)
+{
+ if (!request) return SLANG_FAIL;
+ auto req = convert(request);
+
+ auto& genericArgStrings = req->globalGenericArgStrings;
+ genericArgStrings.Clear();
+ for (int i = 0; i < genericArgCount; i++)
+ genericArgStrings.Add(genericArgs[i]);
+
+ return SLANG_OK;
+}
+
// Compile in a context that already has its translation units specified
SLANG_API SlangResult spCompile(
SlangCompileRequest* request)
{
- auto req = REQ(request);
+ auto req = convert(request);
#if !defined(SLANG_DEBUG_INTERNAL_ERROR)
// By default we'd like to catch as many internal errors as possible,
@@ -1654,7 +2115,7 @@ SLANG_API SlangResult spCompile(
// We will print out information on the exception to help out the user
// in either filing a bug, or locating what in their code created
// a problem.
- req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message);
+ req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message);
}
catch (...)
{
@@ -1662,9 +2123,9 @@ SLANG_API SlangResult spCompile(
// `Exception`, so something really fishy is going on. We want to
// let the user know that we messed up, so they know to blame Slang
// and not some other component in their system.
- req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted);
+ req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted);
}
- req->mDiagnosticOutput = req->mSink.outputBuffer.ProduceString();
+ req->mDiagnosticOutput = req->getSink()->outputBuffer.ProduceString();
return res;
#else
// When debugging, we probably don't want to filter out any errors, since
@@ -1680,8 +2141,10 @@ spGetDependencyFileCount(
SlangCompileRequest* request)
{
if(!request) return 0;
- auto req = REQ(request);
- return (int) req->mDependencyFilePaths.Count();
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
+ auto program = frontEndReq->getProgram();
+ return (int) program->getFilePathDependencies().Count();
}
/** Get the path to a file this compilation dependend on.
@@ -1692,16 +2155,19 @@ spGetDependencyFilePath(
int index)
{
if(!request) return 0;
- auto req = REQ(request);
- return req->mDependencyFilePaths[index].begin();
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
+ auto program = frontEndReq->getProgram();
+ return program->getFilePathDependencies()[index].begin();
}
SLANG_API int
spGetTranslationUnitCount(
SlangCompileRequest* request)
{
- auto req = REQ(request);
- return (int) req->translationUnits.Count();
+ auto req = convert(request);
+ auto frontEndReq = req->getFrontEndReq();
+ return (int) frontEndReq->translationUnits.Count();
}
// Get the output code associated with a specific translation unit
@@ -1718,15 +2184,26 @@ SLANG_API void const* spGetEntryPointCode(
int entryPointIndex,
size_t* outSize)
{
- auto req = REQ(request);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ auto program = req->getSpecializedProgram();
// TODO: We should really accept a target index in this API
- auto targetCount = req->targets.Count();
- if (targetCount == 0)
+ Slang::UInt targetIndex = 0;
+ auto targetCount = linkage->targets.Count();
+ if (targetIndex >= targetCount)
return nullptr;
- auto targetReq = req->targets[0];
+ auto targetReq = linkage->targets[targetIndex];
- Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex];
+
+ if(entryPointIndex < 0) return nullptr;
+ if(Slang::UInt(entryPointIndex) >= req->entryPoints.Count()) return nullptr;
+ auto entryPoint = program->getEntryPoint(entryPointIndex);
+
+ auto targetProgram = program->getTargetProgram(targetReq);
+ if(!targetProgram)
+ return nullptr;
+ Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex);
void const* data = nullptr;
size_t size = 0;
@@ -1761,21 +2238,29 @@ SLANG_API SlangResult spGetEntryPointCodeBlob(
if(!request) return SLANG_ERROR_INVALID_PARAMETER;
if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ auto program = req->getSpecializedProgram();
- int targetCount = (int) req->targets.Count();
+ int targetCount = (int) linkage->targets.Count();
if((targetIndex < 0) || (targetIndex >= targetCount))
{
return SLANG_ERROR_INVALID_PARAMETER;
}
- auto targetReq = req->targets[targetIndex];
+ auto targetReq = linkage->targets[targetIndex];
int entryPointCount = (int) req->entryPoints.Count();
if((entryPointIndex < 0) || (entryPointIndex >= entryPointCount))
{
return SLANG_ERROR_INVALID_PARAMETER;
}
- Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex];
+ auto entryPointReq = program->getEntryPoint(entryPointIndex);
+
+
+ auto targetProgram = program->getTargetProgram(targetReq);
+ if(!targetProgram)
+ return SLANG_FAIL;
+ Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex);
auto blob = result.getBlob();
*outBlob = blob.detach();
@@ -1793,13 +2278,9 @@ SLANG_API void const* spGetCompileRequestCode(
SlangCompileRequest* request,
size_t* outSize)
{
- auto req = REQ(request);
-
- void const* data = req->generatedBytecode.Buffer();
- size_t size = req->generatedBytecode.Count();
-
- if(outSize) *outSize = size;
- return data;
+ SLANG_UNUSED(request);
+ SLANG_UNUSED(outSize);
+ return nullptr;
}
// Reflection API
@@ -1808,7 +2289,9 @@ SLANG_API SlangReflection* spGetReflection(
SlangCompileRequest* request)
{
if( !request ) return 0;
- auto req = REQ(request);
+ auto req = convert(request);
+ auto linkage = req->getLinkage();
+ auto program = req->getSpecializedProgram();
// Note(tfoley): The API signature doesn't let the client
// specify which target they want to access reflection
@@ -1818,12 +2301,16 @@ SLANG_API SlangReflection* spGetReflection(
// so that we can do this better, and make it clear that
// `spGetReflection()` is shorthand for `targetIndex == 0`.
//
- auto targetCount = req->targets.Count();
- if (targetCount == 0)
- return 0;
- auto targetReq = req->targets[0];
+ Slang::UInt targetIndex = 0;
+ auto targetCount = linkage->targets.Count();
+ if (targetIndex >= targetCount)
+ return nullptr;
+
+ auto targetReq = linkage->targets[targetIndex];
+ auto targetProgram = program->getTargetProgram(targetReq);
+ auto programLayout = targetProgram->getExistingLayout();
- return (SlangReflection*) targetReq->layout.Ptr();
+ return (SlangReflection*) programLayout;
}
// ... rest of reflection API implementation is in `Reflection.cpp`
diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h
index 9644deae1..3fca323e8 100644
--- a/source/slang/syntax-visitors.h
+++ b/source/slang/syntax-visitors.h
@@ -6,8 +6,10 @@
namespace Slang
{
- class CompileRequest;
- class EntryPointRequest;
+ class DiagnosticSink;
+ class EntryPoint;
+ class Linkage;
+ class Module;
class ShaderCompiler;
class ShaderLinkInfo;
class ShaderSymbol;
@@ -24,10 +26,11 @@ namespace Slang
// Needed by import declaration checking.
//
// TODO: need a better location to declare this.
- RefPtr<ModuleDecl> findOrImportModule(
- CompileRequest* request,
+ RefPtr<Module> findOrImportModule(
+ Linkage* linkage,
Name* name,
- SourceLoc const& loc);
+ SourceLoc const& loc,
+ DiagnosticSink* sink);
}
#endif \ No newline at end of file
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 709206278..b1b9f6d80 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -2713,4 +2713,15 @@ RefPtr<Val> TaggedUnionSubtypeWitness::SubstituteImpl(SubstitutionSet subst, int
return substWitness;
}
+Module* getModule(Decl* decl)
+{
+ for( auto dd = decl; dd; dd = dd->ParentDecl )
+ {
+ if(auto moduleDecl = as<ModuleDecl>(dd))
+ return moduleDecl->module;
+ }
+ return nullptr;
+}
+
+
} // namespace Slang
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 6a404214e..5198a44b2 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -12,6 +12,7 @@
namespace Slang
{
+ class Module;
class Name;
class Session;
class Substitutions;
@@ -1360,6 +1361,10 @@ namespace Slang
Function = 4,
All = 7
};
+
+ /// Get the module that a declaration is associated with, if any.
+ Module* getModule(Decl* decl);
+
} // namespace Slang
#endif
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index 0bc676cf9..95a92ee2c 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -802,7 +802,7 @@ LayoutRulesImpl* GetLayoutRulesImpl(LayoutRule rule)
LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targetReq)
{
- switch (targetReq->target)
+ switch (targetReq->getTarget())
{
case CodeGenTarget::HLSL:
case CodeGenTarget::DXBytecode:
@@ -821,12 +821,13 @@ LayoutRulesFamilyImpl* getDefaultLayoutRulesFamilyForTarget(TargetRequest* targe
}
}
-TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq)
+TypeLayoutContext getInitialLayoutContextForTarget(TargetRequest* targetReq, ProgramLayout* programLayout)
{
LayoutRulesFamilyImpl* rulesFamily = getDefaultLayoutRulesFamilyForTarget(targetReq);
TypeLayoutContext context;
context.targetReq = targetReq;
+ context.programLayout = programLayout;
context.rules = nullptr;
context.matrixLayoutMode = targetReq->getDefaultMatrixLayoutMode();
@@ -962,7 +963,7 @@ static bool isOpenGLTarget(TargetRequest*)
bool isD3DTarget(TargetRequest* targetReq)
{
- switch( targetReq->target )
+ switch( targetReq->getTarget() )
{
case CodeGenTarget::HLSL:
case CodeGenTarget::DXBytecode:
@@ -978,7 +979,7 @@ bool isD3DTarget(TargetRequest* targetReq)
bool isKhronosTarget(TargetRequest* targetReq)
{
- switch( targetReq->target )
+ switch( targetReq->getTarget() )
{
default:
return false;
@@ -1008,7 +1009,7 @@ static bool isSM5OrEarlier(TargetRequest* targetReq)
if(!isD3DTarget(targetReq))
return false;
- auto profile = targetReq->targetProfile;
+ auto profile = targetReq->getTargetProfile();
if(profile.getFamily() == ProfileFamily::DX)
{
@@ -1024,7 +1025,7 @@ static bool isSM5_1OrLater(TargetRequest* targetReq)
if(!isD3DTarget(targetReq))
return false;
- auto profile = targetReq->targetProfile;
+ auto profile = targetReq->getTargetProfile();
if(profile.getFamily() == ProfileFamily::DX)
{
@@ -2102,7 +2103,7 @@ SimpleLayoutInfo GetLayoutImpl(
//
// The `maybeAdjustLayoutForArrayElementType` computes an "adjusted"
// type layout for the element type which takes the array stride into
- // acount. If it returns the same type layout that was passed in,
+ // account. If it returns the same type layout that was passed in,
// then that means no adjustement took place.
//
// The `additionalSpacesNeededForAdjustedElementType` variable counts
@@ -2327,13 +2328,35 @@ SimpleLayoutInfo GetLayoutImpl(
// we should have already populated ProgramLayout::genericEntryPointParams list at this point,
// so we can find the index of this generic param decl in the list
genParamTypeLayout->type = type;
- genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl());
+ genParamTypeLayout->paramIndex = findGenericParam(context.programLayout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl());
genParamTypeLayout->rules = rules;
genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count += 1;
*outTypeLayout = genParamTypeLayout;
}
return info;
}
+ else if( auto simpleGenericParam = declRef.as<GenericTypeParamDecl>() )
+ {
+ // A bare generic type parameter can come up during layout
+ // of a generic entry point (or an entry point nested in
+ // a generic type). For now we will just pretend like
+ // the fields of generic parameter type take no space,
+ // since there is no reasonable way to account for them
+ // in the resulting layout.
+ //
+ // TODO: It might be better to completely ignore generic
+ // entry points during initial layout, but doing so would
+ // mean that users couldn't get layout information on
+ // any parameters, even those that don't depend on
+ // generics.
+ //
+ SimpleLayoutInfo info;
+ return GetSimpleLayoutImpl(
+ info,
+ type,
+ rules,
+ outTypeLayout);
+ }
}
else if (auto errorType = as<ErrorType>(type))
{
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index e20db7f56..1d939b18f 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -649,6 +649,14 @@ public:
RefPtr<VarLayout> globalScopeLayout;
*/
+ /// The target and program for which layout was computed
+ TargetProgram* targetProgram;
+
+ TargetProgram* getTargetProgram() { return targetProgram; }
+ TargetRequest* getTargetReq() { return targetProgram->getTargetReq(); }
+ Program* getProgram() { return targetProgram->getProgram(); }
+
+
// We catalog the requested entry points here,
// and any entry-point-specific parameter data
// will (eventually) belong there...
@@ -656,8 +664,6 @@ public:
List<RefPtr<GenericParamLayout>> globalGenericParams;
Dictionary<String, GenericParamLayout*> globalGenericParamsMap;
-
- TargetRequest* targetRequest = nullptr;
};
StructTypeLayout* getGlobalStructLayout(
@@ -804,6 +810,8 @@ struct LayoutRulesFamilyImpl
virtual LayoutRulesImpl* getShaderRecordConstantBufferRules() = 0;
};
+typedef List<RefPtr<GenericParamLayout>> GenericParamLayouts;
+
struct TypeLayoutContext
{
// The layout rules to use (e.g., we compute
@@ -812,7 +820,12 @@ struct TypeLayoutContext
LayoutRulesImpl* rules;
// The target request that is triggering layout
- TargetRequest* targetReq;
+ TargetRequest* targetReq;
+
+ // A parent program layout that will establish the ordering
+ // of all global generic type parameters.
+ //
+ ProgramLayout* programLayout;
// Whether to lay out matrices column-major
// or row-major.
@@ -840,8 +853,13 @@ struct TypeLayoutContext
// Get an appropriate set of layout rules (packaged up
// as a `TypeLayoutContext`) to perform type layout
// for the given target.
+//
+// The provided `programLayout` is used to establish
+// the ordering of all global generic type paramters.
+//
TypeLayoutContext getInitialLayoutContextForTarget(
- TargetRequest* targetReq);
+ TargetRequest* targetReq,
+ ProgramLayout* programLayout);
// Get the "simple" layout for a type according to a given set of layout
// rules. Note that a "simple" layout can only consume one `LayoutResourceKind`,