diff options
| -rw-r--r-- | source/slang/ast-legalize.cpp | 6 | ||||
| -rw-r--r-- | source/slang/check.cpp | 3 | ||||
| -rw-r--r-- | source/slang/compiler.h | 24 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 58 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 98 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 60 |
7 files changed, 198 insertions, 55 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 5b6c5ae38..8e4da2717 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -3540,9 +3540,9 @@ struct LoweringVisitor return translationUnit->sourceLanguage; } - for (auto loadedModuleDecl : shared->compileRequest->loadedModulesList) + for (auto loadedModule : shared->compileRequest->loadedModulesList) { - if (moduleDecl == loadedModuleDecl) + if (moduleDecl == loadedModule->moduleDecl) return SourceLanguage::Slang; } @@ -4696,7 +4696,7 @@ LoweredEntryPoint lowerEntryPoint( for (auto rr : entryPoint->compileRequest->loadedModulesList) { sharedContext.loweredDecls.Add( - rr, + rr->moduleDecl, LoweredDecl(loweredProgram)); } diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 4b8f4f4c1..c1893423a 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -6648,8 +6648,9 @@ namespace Slang globalGenericParams.Add(p); } // add imported modules - for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList) + for (auto loadedModule : entryPoint->compileRequest->loadedModulesList) { + auto moduleDecl = loadedModule->moduleDecl; auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); for (auto p : globalGenParams) globalGenericParams.Add(p); diff --git a/source/slang/compiler.h b/source/slang/compiler.h index a48ad2287..0e85a1088 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -211,6 +211,19 @@ namespace Slang String path; }; + // Represents a module that has been loaded through the front-end + // (up through IR generation). + // + class LoadedModule : public RefObject + { + public: + // The AST for the module + RefPtr<ModuleDecl> moduleDecl; + + // The IR for the module + IRModule* irModule = nullptr; + }; + class Session; class CompileRequest : public RefObject @@ -285,13 +298,13 @@ namespace Slang // Modules that have been dynamically loaded via `import` // // This is a list of unique modules loaded, in the order they were encountered. - List<RefPtr<ModuleDecl> > loadedModulesList; + List<RefPtr<LoadedModule> > loadedModulesList; // Map from the path of a module file to its definition - Dictionary<String, RefPtr<ModuleDecl>> mapPathToLoadedModule; + Dictionary<String, RefPtr<LoadedModule>> mapPathToLoadedModule; // Map from the logical name of a module to its definition - Dictionary<Name*, RefPtr<ModuleDecl>> mapNameToLoadedModules; + Dictionary<Name*, RefPtr<LoadedModule>> mapNameToLoadedModules; CompileRequest(Session* session); @@ -344,6 +357,11 @@ namespace Slang String const& path, TokenList const& tokens); + void loadParsedModule( + RefPtr<TranslationUnitRequest> const& translationUnit, + Name* name, + String const& path); + RefPtr<ModuleDecl> findOrImportModule( Name* name, SourceLoc const& loc); diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index eb78d144e..ae7b71172 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3613,6 +3613,23 @@ namespace Slang return result; } + bool isDefinition( + IRGlobalValue* val) + { + switch (val->op) + { + case kIROp_witness_table: + return ((IRWitnessTable*)val)->entries.first != nullptr; + + case kIROp_global_var: + case kIROp_Func: + return ((IRGlobalValueWithCode*)val)->firstBlock != nullptr; + + default: + return false; + } + } + // Is `newVal` marked as being a better match for our // chosen code-generation target? // @@ -3657,7 +3674,17 @@ namespace Slang auto newLevel = getTargetSpecialiationLevel(newVal, targetName); auto oldLevel = getTargetSpecialiationLevel(oldVal, targetName); - return UInt(newLevel) > UInt(oldLevel); + if(newLevel != oldLevel) + return UInt(newLevel) > UInt(oldLevel); + + // All other factors being equal, a definition is + // better than a declaration. + auto newIsDef = isDefinition(newVal); + auto oldIsDef = isDefinition(oldVal); + if (newIsDef != oldIsDef) + return newIsDef; + + return false; } IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc) @@ -3743,6 +3770,19 @@ namespace Slang } } + void insertGlobalValueSymbols( + IRSharedSpecContext* sharedContext, + IRModule* originalModule) + { + if (!originalModule) + return; + + for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue) + { + insertGlobalValueSymbol(sharedContext, gv); + } + } + void initializeSharedSpecContext( IRSharedSpecContext* sharedContext, Session* session, @@ -3766,13 +3806,10 @@ namespace Slang sharedContext->module = module; sharedContext->originalModule = originalModule; - // First, we will populate a map with all of the IR values + // We will populate a map with all of the IR values // that use the same mangled name, to make lookup easier // in other steps. - for (auto gv = originalModule->firstGlobalValue; gv; gv = gv->nextGlobalValue) - { - insertGlobalValueSymbol(sharedContext, gv); - } + insertGlobalValueSymbols(sharedContext, originalModule); } // implementation provided in parameter-binding.cpp @@ -3867,6 +3904,15 @@ namespace Slang nullptr, originalIRModule); + // We also need to attach the IR definitions for symbols from + // any loaded modules: + for (auto loadedModule : compileRequest->loadedModulesList) + { + insertGlobalValueSymbols(&sharedContextStorage, loadedModule->irModule); + } + // any loaded modules + + IRSpecContext contextStorage; IRSpecContext* context = &contextStorage; context->shared = &sharedContextStorage; diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index c01059c61..5ec175668 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -269,6 +269,7 @@ LoweredValInfo LoweredValInfo::swizzledLValue( struct SharedIRGenContext { CompileRequest* compileRequest; + ModuleDecl* mainModuleDecl; Dictionary<Decl*, LoweredValInfo> declValues; @@ -2534,6 +2535,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> SLANG_UNIMPLEMENTED_X("decl catch-all"); } + LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/) + { + return LoweredValInfo(); + } + LoweredValInfo visitTypeDefDecl(TypeDefDecl * decl) { return LoweredValInfo::simple(context->irBuilder->getTypeVal(decl->type.type)); @@ -2691,7 +2697,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> context->shared->declValues[ DeclRef<VarDeclBase>(decl, nullptr)] = globalVal; - if( auto initExpr = decl->initExpr ) + if (isImportedDecl(decl)) + { + // Always emit imported declarations as declarations, + // and not definitions. + } + else if( auto initExpr = decl->initExpr ) { IRBuilder subBuilderStorage = *getBuilder(); IRBuilder* subBuilder = &subBuilderStorage; @@ -3131,6 +3142,49 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irFunc->mangledName = mangledName; } + ModuleDecl* findModuleDecl(Decl* decl) + { + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) + return moduleDecl; + } + return nullptr; + } + + bool isFromStdLib(Decl* decl) + { + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (dd->HasModifier<FromStdLibModifier>()) + return true; + } + return false; + } + + bool isImportedDecl(Decl* decl) + { + ModuleDecl* moduleDecl = findModuleDecl(decl); + if (!moduleDecl) + return false; + + // HACK: don't treat standard library code as + // being imported for right now, just because + // we don't load its IR in the same way as + // for other imports. + // + // TODO: Fix this the right way, by having standard + // library declarations have IR modules that we link + // in via the normal means. + if (isFromStdLib(decl)) + return false; + + if (moduleDecl != this->context->shared->mainModuleDecl) + return true; + + return false; + } + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) { // Collect the parameter lists we will use for our new function. @@ -3248,18 +3302,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irResultType); irFunc->type = irFuncType; - if (!decl->Body) + if (isImportedDecl(decl)) + { + // Always emit imported declarations as declarations, + // and not definitions. + } + else if (!decl->Body) { // This is a function declaration without a body. // In Slang we currently try not to support forward declarations - // (although we might have to give in eventually), so the - // only case where this arises is for a function that - // needs to be imported from another module. - - // TODO: we may need to attach something to the declaration, - // so that later passes don't get confused by it not having - // a body. - + // (although we might have to give in eventually), so + // this case should really only occur for builtin declarations. } else { @@ -3594,6 +3647,10 @@ LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, if (!hasGenericSubstitutions(declRef.substitutions)) return loweredDecl; + // There's no reason to specialize something that maps to a NULL pointer. + if (loweredDecl.flavor == LoweredValInfo::Flavor::None) + return loweredDecl; + auto val = getSimpleVal(context, loweredDecl); // We have the "raw" substitutions from the AST, but we may @@ -3635,7 +3692,7 @@ static void lowerEntryPointToIR( // we need to lower all global type arguments as well for (auto arg : entryPointRequest->genericParameterTypes) lowerType(context, arg); - auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl); + auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); } #if 0 @@ -3682,12 +3739,18 @@ IRModule* lowerEntryPointToIR( IRModule* generateIRForTranslationUnit( TranslationUnitRequest* translationUnit) { + // If the user did not opt into IR usage, then don't compile IR + // for the translation unit. + if (!(translationUnit->compileFlags & SLANG_COMPILE_FLAG_USE_IR)) + return nullptr; + auto compileRequest = translationUnit->compileRequest; SharedIRGenContext sharedContextStorage; SharedIRGenContext* sharedContext = &sharedContextStorage; sharedContext->compileRequest = compileRequest; + sharedContext->mainModuleDecl = translationUnit->SyntaxNode; IRGenContext contextStorage; IRGenContext* context = &contextStorage; @@ -3710,10 +3773,23 @@ IRModule* generateIRForTranslationUnit( // We need to emit IR for all public/exported symbols // in the translation unit. + // + // For now, we will assume that *all* global-scope declarations + // represent public/exported symbols. + + // First, ensure that all entry points have been emitted, + // in case they require special handling. for (auto entryPoint : translationUnit->entryPoints) { lowerEntryPointToIR(context, entryPoint); } + // + // Next, ensure that all other global declarations have + // been emitted. + for (auto decl : translationUnit->SyntaxNode->Members) + { + ensureDecl(context, decl); + } // If we are being sked to dump IR during compilation, // then we can dump the initial IR for the module here. diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 0daa2abc7..b97c174fc 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -1648,9 +1648,9 @@ static void collectParameters( } // Now collect parameters from loaded modules - for (auto& module : request->loadedModulesList) + for (auto& loadedModule : request->loadedModulesList) { - collectModuleParameters(context, module.Ptr()); + collectModuleParameters(context, loadedModule->moduleDecl.Ptr()); } } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 6a103fc2d..204313e84 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -430,6 +430,23 @@ UInt CompileRequest::addTarget( return (int) result; } +void CompileRequest::loadParsedModule( + RefPtr<TranslationUnitRequest> const& translationUnit, + Name* name, + String const& path) +{ + checkTranslationUnit(translationUnit.Ptr()); + + RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; + + RefPtr<LoadedModule> loadedModule = new LoadedModule(); + loadedModule->moduleDecl = moduleDecl; + loadedModule->irModule = generateIRForTranslationUnit(translationUnit); + + mapPathToLoadedModule.Add(path, loadedModule); + mapNameToLoadedModules.Add(name, loadedModule); + loadedModulesList.Add(loadedModule); +} RefPtr<ModuleDecl> CompileRequest::loadModule( Name* name, @@ -454,20 +471,12 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( // TODO: handle errors - checkTranslationUnit(translationUnit.Ptr()); - - // Skip code generation - - // - - RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; - - mapPathToLoadedModule.Add(path, moduleDecl); - mapNameToLoadedModules.Add(name, moduleDecl); - loadedModulesList.Add(moduleDecl); - - return moduleDecl; + loadParsedModule( + translationUnit, + name, + path); + return translationUnit->SyntaxNode; } void CompileRequest::handlePoundImport( @@ -491,14 +500,6 @@ void CompileRequest::handlePoundImport( // TODO: handle errors - checkTranslationUnit(translationUnit.Ptr()); - - // Skip code generation - - // - - RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; - // TODO: It is a bit broken here that we use the module path, // as the "name" when registering things, but this saves // us the trouble of trying to special-case things when @@ -508,10 +509,11 @@ void CompileRequest::handlePoundImport( // running the name->path logic in reverse (e.g., replacing // `-` with `_` and `/` with `.`). Name* name = getNamePool()->getName(path); - mapNameToLoadedModules.Add(name, moduleDecl); - mapPathToLoadedModule.Add(path, moduleDecl); - loadedModulesList.Add(moduleDecl); + loadParsedModule( + translationUnit, + name, + path); } RefPtr<ModuleDecl> CompileRequest::findOrImportModule( @@ -520,9 +522,9 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( { // Have we already loaded a module matching this name? // If so, return it. - RefPtr<ModuleDecl> moduleDecl; - if (mapNameToLoadedModules.TryGetValue(name, moduleDecl)) - return moduleDecl; + RefPtr<LoadedModule> loadedModule; + if (mapNameToLoadedModules.TryGetValue(name, loadedModule)) + return loadedModule->moduleDecl; // Derive a file name for the module, by taking the given // identifier, replacing all occurences of `_` with `-`, @@ -572,8 +574,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( } // Maybe this was loaded previously via `#import` - if (mapPathToLoadedModule.TryGetValue(foundPath, moduleDecl)) - return moduleDecl; + if (mapPathToLoadedModule.TryGetValue(foundPath, loadedModule)) + return loadedModule->moduleDecl; // We've found a file that we can load for the given module, so |
