summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/ast-legalize.cpp6
-rw-r--r--source/slang/check.cpp3
-rw-r--r--source/slang/compiler.h24
-rw-r--r--source/slang/ir.cpp58
-rw-r--r--source/slang/lower-to-ir.cpp98
-rw-r--r--source/slang/parameter-binding.cpp4
-rw-r--r--source/slang/slang.cpp60
7 files changed, 198 insertions, 55 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp
index 5b6c5ae38..8e4da2717 100644
--- a/source/slang/ast-legalize.cpp
+++ b/source/slang/ast-legalize.cpp
@@ -3540,9 +3540,9 @@ struct LoweringVisitor
return translationUnit->sourceLanguage;
}
- for (auto loadedModuleDecl : shared->compileRequest->loadedModulesList)
+ for (auto loadedModule : shared->compileRequest->loadedModulesList)
{
- if (moduleDecl == loadedModuleDecl)
+ if (moduleDecl == loadedModule->moduleDecl)
return SourceLanguage::Slang;
}
@@ -4696,7 +4696,7 @@ LoweredEntryPoint lowerEntryPoint(
for (auto rr : entryPoint->compileRequest->loadedModulesList)
{
sharedContext.loweredDecls.Add(
- rr,
+ rr->moduleDecl,
LoweredDecl(loweredProgram));
}
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 4b8f4f4c1..c1893423a 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -6648,8 +6648,9 @@ namespace Slang
globalGenericParams.Add(p);
}
// add imported modules
- for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList)
+ for (auto loadedModule : entryPoint->compileRequest->loadedModulesList)
{
+ auto moduleDecl = loadedModule->moduleDecl;
auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>();
for (auto p : globalGenParams)
globalGenericParams.Add(p);
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index a48ad2287..0e85a1088 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -211,6 +211,19 @@ namespace Slang
String path;
};
+ // Represents a module that has been loaded through the front-end
+ // (up through IR generation).
+ //
+ class LoadedModule : public RefObject
+ {
+ public:
+ // The AST for the module
+ RefPtr<ModuleDecl> moduleDecl;
+
+ // The IR for the module
+ IRModule* irModule = nullptr;
+ };
+
class Session;
class CompileRequest : public RefObject
@@ -285,13 +298,13 @@ namespace Slang
// Modules that have been dynamically loaded via `import`
//
// This is a list of unique modules loaded, in the order they were encountered.
- List<RefPtr<ModuleDecl> > loadedModulesList;
+ List<RefPtr<LoadedModule> > loadedModulesList;
// Map from the path of a module file to its definition
- Dictionary<String, RefPtr<ModuleDecl>> mapPathToLoadedModule;
+ Dictionary<String, RefPtr<LoadedModule>> mapPathToLoadedModule;
// Map from the logical name of a module to its definition
- Dictionary<Name*, RefPtr<ModuleDecl>> mapNameToLoadedModules;
+ Dictionary<Name*, RefPtr<LoadedModule>> mapNameToLoadedModules;
CompileRequest(Session* session);
@@ -344,6 +357,11 @@ namespace Slang
String const& path,
TokenList const& tokens);
+ void loadParsedModule(
+ RefPtr<TranslationUnitRequest> const& translationUnit,
+ Name* name,
+ String const& path);
+
RefPtr<ModuleDecl> findOrImportModule(
Name* name,
SourceLoc const& loc);
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index eb78d144e..ae7b71172 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3613,6 +3613,23 @@ namespace Slang
return result;
}
+ bool isDefinition(
+ IRGlobalValue* val)
+ {
+ switch (val->op)
+ {
+ case kIROp_witness_table:
+ return ((IRWitnessTable*)val)->entries.first != nullptr;
+
+ case kIROp_global_var:
+ case kIROp_Func:
+ return ((IRGlobalValueWithCode*)val)->firstBlock != nullptr;
+
+ default:
+ return false;
+ }
+ }
+
// Is `newVal` marked as being a better match for our
// chosen code-generation target?
//
@@ -3657,7 +3674,17 @@ namespace Slang
auto newLevel = getTargetSpecialiationLevel(newVal, targetName);
auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName);
- return UInt(newLevel) > UInt(oldLevel);
+ if(newLevel != oldLevel)
+ return UInt(newLevel) > UInt(oldLevel);
+
+ // All other factors being equal, a definition is
+ // better than a declaration.
+ auto newIsDef = isDefinition(newVal);
+ auto oldIsDef = isDefinition(oldVal);
+ if (newIsDef != oldIsDef)
+ return newIsDef;
+
+ return false;
}
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc)
@@ -3743,6 +3770,19 @@ namespace Slang
}
}
+ void insertGlobalValueSymbols(
+ IRSharedSpecContext* sharedContext,
+ IRModule* originalModule)
+ {
+ if (!originalModule)
+ return;
+
+ for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue)
+ {
+ insertGlobalValueSymbol(sharedContext, gv);
+ }
+ }
+
void initializeSharedSpecContext(
IRSharedSpecContext* sharedContext,
Session* session,
@@ -3766,13 +3806,10 @@ namespace Slang
sharedContext->module = module;
sharedContext->originalModule = originalModule;
- // First, we will populate a map with all of the IR values
+ // We will populate a map with all of the IR values
// that use the same mangled name, to make lookup easier
// in other steps.
- for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue)
- {
- insertGlobalValueSymbol(sharedContext, gv);
- }
+ insertGlobalValueSymbols(sharedContext, originalModule);
}
// implementation provided in parameter-binding.cpp
@@ -3867,6 +3904,15 @@ namespace Slang
nullptr,
originalIRModule);
+ // We also need to attach the IR definitions for symbols from
+ // any loaded modules:
+ for (auto loadedModule : compileRequest->loadedModulesList)
+ {
+ insertGlobalValueSymbols(&sharedContextStorage, loadedModule->irModule);
+ }
+ // any loaded modules
+
+
IRSpecContext contextStorage;
IRSpecContext* context = &contextStorage;
context->shared = &sharedContextStorage;
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index c01059c61..5ec175668 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -269,6 +269,7 @@ LoweredValInfo LoweredValInfo::swizzledLValue(
struct SharedIRGenContext
{
CompileRequest* compileRequest;
+ ModuleDecl* mainModuleDecl;
Dictionary<Decl*, LoweredValInfo> declValues;
@@ -2534,6 +2535,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
SLANG_UNIMPLEMENTED_X("decl catch-all");
}
+ LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/)
+ {
+ return LoweredValInfo();
+ }
+
LoweredValInfo visitTypeDefDecl(TypeDefDecl * decl)
{
return LoweredValInfo::simple(context->irBuilder->getTypeVal(decl->type.type));
@@ -2691,7 +2697,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
context->shared->declValues[
DeclRef<VarDeclBase>(decl, nullptr)] = globalVal;
- if( auto initExpr = decl->initExpr )
+ if (isImportedDecl(decl))
+ {
+ // Always emit imported declarations as declarations,
+ // and not definitions.
+ }
+ else if( auto initExpr = decl->initExpr )
{
IRBuilder subBuilderStorage = *getBuilder();
IRBuilder* subBuilder = &subBuilderStorage;
@@ -3131,6 +3142,49 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irFunc->mangledName = mangledName;
}
+ ModuleDecl* findModuleDecl(Decl* decl)
+ {
+ for (auto dd = decl; dd; dd = dd->ParentDecl)
+ {
+ if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd))
+ return moduleDecl;
+ }
+ return nullptr;
+ }
+
+ bool isFromStdLib(Decl* decl)
+ {
+ for (auto dd = decl; dd; dd = dd->ParentDecl)
+ {
+ if (dd->HasModifier<FromStdLibModifier>())
+ return true;
+ }
+ return false;
+ }
+
+ bool isImportedDecl(Decl* decl)
+ {
+ ModuleDecl* moduleDecl = findModuleDecl(decl);
+ if (!moduleDecl)
+ return false;
+
+ // HACK: don't treat standard library code as
+ // being imported for right now, just because
+ // we don't load its IR in the same way as
+ // for other imports.
+ //
+ // TODO: Fix this the right way, by having standard
+ // library declarations have IR modules that we link
+ // in via the normal means.
+ if (isFromStdLib(decl))
+ return false;
+
+ if (moduleDecl != this->context->shared->mainModuleDecl)
+ return true;
+
+ return false;
+ }
+
LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
{
// Collect the parameter lists we will use for our new function.
@@ -3248,18 +3302,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irResultType);
irFunc->type = irFuncType;
- if (!decl->Body)
+ if (isImportedDecl(decl))
+ {
+ // Always emit imported declarations as declarations,
+ // and not definitions.
+ }
+ else if (!decl->Body)
{
// This is a function declaration without a body.
// In Slang we currently try not to support forward declarations
- // (although we might have to give in eventually), so the
- // only case where this arises is for a function that
- // needs to be imported from another module.
-
- // TODO: we may need to attach something to the declaration,
- // so that later passes don't get confused by it not having
- // a body.
-
+ // (although we might have to give in eventually), so
+ // this case should really only occur for builtin declarations.
}
else
{
@@ -3594,6 +3647,10 @@ LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context,
if (!hasGenericSubstitutions(declRef.substitutions))
return loweredDecl;
+ // There's no reason to specialize something that maps to a NULL pointer.
+ if (loweredDecl.flavor == LoweredValInfo::Flavor::None)
+ return loweredDecl;
+
auto val = getSimpleVal(context, loweredDecl);
// We have the "raw" substitutions from the AST, but we may
@@ -3635,7 +3692,7 @@ static void lowerEntryPointToIR(
// we need to lower all global type arguments as well
for (auto arg : entryPointRequest->genericParameterTypes)
lowerType(context, arg);
- auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl);
+ auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl);
}
#if 0
@@ -3682,12 +3739,18 @@ IRModule* lowerEntryPointToIR(
IRModule* generateIRForTranslationUnit(
TranslationUnitRequest* translationUnit)
{
+ // If the user did not opt into IR usage, then don't compile IR
+ // for the translation unit.
+ if (!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_USE_IR))
+ return nullptr;
+
auto compileRequest = translationUnit->compileRequest;
SharedIRGenContext sharedContextStorage;
SharedIRGenContext* sharedContext = &sharedContextStorage;
sharedContext->compileRequest = compileRequest;
+ sharedContext->mainModuleDecl = translationUnit->SyntaxNode;
IRGenContext contextStorage;
IRGenContext* context = &contextStorage;
@@ -3710,10 +3773,23 @@ IRModule* generateIRForTranslationUnit(
// We need to emit IR for all public/exported symbols
// in the translation unit.
+ //
+ // For now, we will assume that *all* global-scope declarations
+ // represent public/exported symbols.
+
+ // First, ensure that all entry points have been emitted,
+ // in case they require special handling.
for (auto entryPoint : translationUnit->entryPoints)
{
lowerEntryPointToIR(context, entryPoint);
}
+ //
+ // Next, ensure that all other global declarations have
+ // been emitted.
+ for (auto decl : translationUnit->SyntaxNode->Members)
+ {
+ ensureDecl(context, decl);
+ }
// If we are being sked to dump IR during compilation,
// then we can dump the initial IR for the module here.
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index 0daa2abc7..b97c174fc 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -1648,9 +1648,9 @@ static void collectParameters(
}
// Now collect parameters from loaded modules
- for (auto& module : request->loadedModulesList)
+ for (auto& loadedModule : request->loadedModulesList)
{
- collectModuleParameters(context, module.Ptr());
+ collectModuleParameters(context, loadedModule->moduleDecl.Ptr());
}
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 6a103fc2d..204313e84 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -430,6 +430,23 @@ UInt CompileRequest::addTarget(
return (int) result;
}
+void CompileRequest::loadParsedModule(
+ RefPtr<TranslationUnitRequest> const& translationUnit,
+ Name* name,
+ String const& path)
+{
+ checkTranslationUnit(translationUnit.Ptr());
+
+ RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode;
+
+ RefPtr<LoadedModule> loadedModule = new LoadedModule();
+ loadedModule->moduleDecl = moduleDecl;
+ loadedModule->irModule = generateIRForTranslationUnit(translationUnit);
+
+ mapPathToLoadedModule.Add(path, loadedModule);
+ mapNameToLoadedModules.Add(name, loadedModule);
+ loadedModulesList.Add(loadedModule);
+}
RefPtr<ModuleDecl> CompileRequest::loadModule(
Name* name,
@@ -454,20 +471,12 @@ RefPtr<ModuleDecl> CompileRequest::loadModule(
// TODO: handle errors
- checkTranslationUnit(translationUnit.Ptr());
-
- // Skip code generation
-
- //
-
- RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode;
-
- mapPathToLoadedModule.Add(path, moduleDecl);
- mapNameToLoadedModules.Add(name, moduleDecl);
- loadedModulesList.Add(moduleDecl);
-
- return moduleDecl;
+ loadParsedModule(
+ translationUnit,
+ name,
+ path);
+ return translationUnit->SyntaxNode;
}
void CompileRequest::handlePoundImport(
@@ -491,14 +500,6 @@ void CompileRequest::handlePoundImport(
// TODO: handle errors
- checkTranslationUnit(translationUnit.Ptr());
-
- // Skip code generation
-
- //
-
- RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode;
-
// TODO: It is a bit broken here that we use the module path,
// as the "name" when registering things, but this saves
// us the trouble of trying to special-case things when
@@ -508,10 +509,11 @@ void CompileRequest::handlePoundImport(
// running the name->path logic in reverse (e.g., replacing
// `-` with `_` and `/` with `.`).
Name* name = getNamePool()->getName(path);
- mapNameToLoadedModules.Add(name, moduleDecl);
- mapPathToLoadedModule.Add(path, moduleDecl);
- loadedModulesList.Add(moduleDecl);
+ loadParsedModule(
+ translationUnit,
+ name,
+ path);
}
RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
@@ -520,9 +522,9 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
{
// Have we already loaded a module matching this name?
// If so, return it.
- RefPtr<ModuleDecl> moduleDecl;
- if (mapNameToLoadedModules.TryGetValue(name, moduleDecl))
- return moduleDecl;
+ RefPtr<LoadedModule> loadedModule;
+ if (mapNameToLoadedModules.TryGetValue(name, loadedModule))
+ return loadedModule->moduleDecl;
// Derive a file name for the module, by taking the given
// identifier, replacing all occurences of `_` with `-`,
@@ -572,8 +574,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule(
}
// Maybe this was loaded previously via `#import`
- if (mapPathToLoadedModule.TryGetValue(foundPath, moduleDecl))
- return moduleDecl;
+ if (mapPathToLoadedModule.TryGetValue(foundPath, loadedModule))
+ return loadedModule->moduleDecl;
// We've found a file that we can load for the given module, so