diff options
Diffstat (limited to 'source/slang/slang.cpp')
| -rw-r--r-- | source/slang/slang.cpp | 1127 |
1 files changed, 807 insertions, 320 deletions
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 16dfe8618..7a5b58d07 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -60,6 +60,8 @@ Session::Session() // Make sure our source manager is initialized builtinSourceManager.initialize(nullptr, nullptr); + m_builtinLinkage = new Linkage(this); + // Initialize representations of some very basic types: initializeTypes(); @@ -90,11 +92,12 @@ Session::Session() struct IncludeHandlerImpl : IncludeHandler { - CompileRequest* request; + Linkage* linkage; + SearchDirectoryList* searchDirectories; ISlangFileSystemExt* _getFileSystemExt() { - return request->fileSystemExt; + return linkage->getFileSystemExt(); } SlangResult _findFile(SlangPathType fromPathType, const String& fromPath, const String& path, PathInfo& pathInfoOut) @@ -153,18 +156,22 @@ struct IncludeHandlerImpl : IncludeHandler } // Search all the searchDirectories - for (auto & dir : request->searchDirectories) + for(auto sd = searchDirectories; sd; sd = sd->parent) { - SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); - if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + for(auto& dir : sd->searchDirectories) { - return res; + SlangResult res = _findFile(SLANG_PATH_TYPE_DIRECTORY, dir.path, pathToInclude, pathInfoOut); + if (SLANG_SUCCEEDED(res) || res != SLANG_E_NOT_FOUND) + { + return res; + } } } return SLANG_E_NOT_FOUND; } +#if 0 virtual SlangResult readFile(const String& path, ISlangBlob** blobOut) override { @@ -175,6 +182,7 @@ struct IncludeHandlerImpl : IncludeHandler return SLANG_OK; } +#endif virtual String simplifyPath(const String& path) override { @@ -192,9 +200,9 @@ struct IncludeHandlerImpl : IncludeHandler // -Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target) +Profile getEffectiveProfile(EntryPoint* entryPoint, TargetRequest* target) { - auto entryPointProfile = entryPoint->profile; + auto entryPointProfile = entryPoint->getProfile(); auto targetProfile = target->targetProfile; // Depending on the target *format* we might have to restrict the @@ -310,20 +318,13 @@ Profile getEffectiveProfile(EntryPointRequest* entryPoint, TargetRequest* target // -CompileRequest::CompileRequest(Session* session) - : mSession(session) +Linkage::Linkage(Session* session) + : m_session(session) + , m_sourceManager(&m_defaultSourceManager) { getNamePool()->setRootNamePool(session->getRootNamePool()); - setSourceManager(&sourceManagerStorage); - - sourceManager->initialize(session->getBuiltinSourceManager(), nullptr); - - // Set all the default writers - for (int i = 0; i < int(WriterChannel::CountOf); ++i) - { - setWriter(WriterChannel(i), nullptr); - } + m_defaultSourceManager.initialize(session->getBuiltinSourceManager(), nullptr); setFileSystem(nullptr); } @@ -379,10 +380,61 @@ ComPtr<ISlangBlob> createRawBlob(void const* inData, size_t size) } // +// TargetRequest +// + +Session* TargetRequest::getSession() +{ + return linkage->getSession(); +} MatrixLayoutMode TargetRequest::getDefaultMatrixLayoutMode() { - return compileRequest->getDefaultMatrixLayoutMode(); + return linkage->getDefaultMatrixLayoutMode(); +} + +// +// TranslationUnitRequest +// + +TranslationUnitRequest::TranslationUnitRequest( + FrontEndCompileRequest* compileRequest) + : compileRequest(compileRequest) +{ + module = new Module(compileRequest->getLinkage()); +} + + +Session* TranslationUnitRequest::getSession() +{ + return compileRequest->getSession(); +} + +NamePool* TranslationUnitRequest::getNamePool() +{ + return compileRequest->getNamePool(); +} + +SourceManager* TranslationUnitRequest::getSourceManager() +{ + return compileRequest->getSourceManager(); +} + +void TranslationUnitRequest::addSourceFile(SourceFile* sourceFile) +{ + m_sourceFiles.Add(sourceFile); + + // We want to record that the compiled module has a dependency + // on the path of the source file, but we also need to account + // for cases where the user added a source string/blob without + // an associated path (so that the API passes along an empty + // string). + // + auto path = sourceFile->getPathInfo().foundPath; + if(path.Length()) + { + getModule()->addFilePathDependency(path); + } } @@ -407,7 +459,7 @@ static ISlangWriter* _getDefaultWriter(WriterChannel chan) } } -void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) +void EndToEndCompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) { // If the user passed in null, we will use the default writer on that channel m_writers[int(chan)] = writer ? writer : _getDefaultWriter(chan); @@ -415,20 +467,20 @@ void CompileRequest::setWriter(WriterChannel chan, ISlangWriter* writer) // For diagnostic output, if the user passes in nullptr, we set on mSink.writer as that enables buffering on DiagnosticSink if (chan == WriterChannel::Diagnostic) { - mSink.writer = writer; + m_sink.writer = writer; } } -SlangResult CompileRequest::loadFile(String const& path, ISlangBlob** outBlob) +SlangResult Linkage::loadFile(String const& path, ISlangBlob** outBlob) { return fileSystemExt->loadFile(path.Buffer(), outBlob); } -RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope) +RefPtr<Expr> Linkage::parseTypeString(String typeStr, RefPtr<Scope> scope) { // Create a SourceManager on the stack, so any allocations for 'SourceFile'/'SourceView' etc will be cleaned up SourceManager localSourceManager; - localSourceManager.initialize(sourceManager, nullptr); + localSourceManager.initialize(getSourceManager(), nullptr); Slang::SourceFile* srcFile = localSourceManager.createSourceFileWithString(PathInfo::makeTypeParse(), typeStr); @@ -440,20 +492,20 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio // Use RAII - to make sure everything is reset even if an exception is thrown. struct ScopeReplaceSourceManager { - ScopeReplaceSourceManager(CompileRequest* request, SourceManager* replaceManager): - m_request(request), - m_originalSourceManager(request->getSourceManager()) + ScopeReplaceSourceManager(Linkage* linkage, SourceManager* replaceManager): + m_linkage(linkage), + m_originalSourceManager(linkage->getSourceManager()) { - request->setSourceManager(replaceManager); + linkage->setSourceManager(replaceManager); } ~ScopeReplaceSourceManager() { - m_request->setSourceManager(m_originalSourceManager); + m_linkage->setSourceManager(m_originalSourceManager); } private: - CompileRequest* m_request; + Linkage* m_linkage; SourceManager* m_originalSourceManager; }; @@ -465,87 +517,131 @@ RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translatio &sink, nullptr, Dictionary<String,String>(), - translationUnit); + this, + nullptr); - return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope); + return parseTypeFromSourceFile( + getSession(), + tokens, &sink, scope, getNamePool(), SourceLanguage::Slang); } -RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp); -Type* CompileRequest::getTypeFromString(String typeStr) +RefPtr<Type> checkProperType( + Linkage* linkage, + TypeExp typeExp, + DiagnosticSink* sink); + +Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink) { + // If we've looked up this type name before, + // then we can re-use it. + // RefPtr<Type> type; - if (types.TryGetValue(typeStr, type)) + if(m_types.TryGetValue(typeStr, type)) return type; - auto translationUnit = translationUnits.First(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + // TODO: This `scopesToTry` idiom appears + // all over the code, and isn't really + // how we should be handling this kind of + // lookup at all. + // List<RefPtr<Scope>> scopesToTry; - for (auto tu : translationUnits) - scopesToTry.Add(tu->SyntaxNode->scope); - for (auto & module : loadedModulesList) - scopesToTry.Add(module->moduleDecl->scope); - // parse type name - for (auto & s : scopesToTry) - { - RefPtr<Expr> typeExpr = parseTypeString(translationUnit, + for(auto module : getModuleDependencies()) + scopesToTry.Add(module->getModuleDecl()->scope); + + auto linkage = getLinkage(); + for(auto& s : scopesToTry) + { + RefPtr<Expr> typeExpr = linkage->parseTypeString( typeStr, s); - type = checkProperType(translationUnit, TypeExp(typeExpr)); - if (type) + type = checkProperType(linkage, TypeExp(typeExpr), sink); + if(type) break; } - if (type) + if( type ) { - types[typeStr] = type; + m_types[typeStr] = type; } - return type.Ptr(); + return type; } -void CompileRequest::parseTranslationUnit( +CompileRequestBase::CompileRequestBase( + Linkage* linkage, + DiagnosticSink* sink) + : m_linkage(linkage) + , m_sink(sink) +{} + + +FrontEndCompileRequest::FrontEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink) + : CompileRequestBase(linkage, sink) +{ +} + +void FrontEndCompileRequest::parseTranslationUnit( TranslationUnitRequest* translationUnit) { IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = getLinkage(); + includeHandler.searchDirectories = &searchDirectories; RefPtr<Scope> languageScope; switch (translationUnit->sourceLanguage) { case SourceLanguage::HLSL: - languageScope = mSession->hlslLanguageScope; + languageScope = getSession()->hlslLanguageScope; break; case SourceLanguage::Slang: default: - languageScope = mSession->slangLanguageScope; + languageScope = getSession()->slangLanguageScope; break; } Dictionary<String, String> combinedPreprocessorDefinitions; + for(auto& def : getLinkage()->preprocessorDefinitions) + combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); for(auto& def : translationUnit->preprocessorDefinitions) combinedPreprocessorDefinitions.Add(def.Key, def.Value); + auto module = translationUnit->getModule(); RefPtr<ModuleDecl> translationUnitSyntax = new ModuleDecl(); - translationUnit->SyntaxNode = translationUnitSyntax; + translationUnitSyntax->nameAndLoc.name = translationUnit->moduleName; + translationUnitSyntax->module = module; + module->setModuleDecl(translationUnitSyntax); - for (auto sourceFile : translationUnit->sourceFiles) + for (auto sourceFile : translationUnit->getSourceFiles()) { auto tokens = preprocessSource( sourceFile, - &mSink, + getSink(), &includeHandler, combinedPreprocessorDefinitions, - translationUnit); + getLinkage(), + module); parseSourceFile( translationUnit, tokens, - &mSink, + getSink(), languageScope); } } -void validateEntryPoints(CompileRequest*); +RefPtr<Program> createUnspecializedProgram( + FrontEndCompileRequest* compileRequest); -void CompileRequest::checkAllTranslationUnits() +RefPtr<Program> createSpecializedProgram( + EndToEndCompileRequest* endToEndReq); + +void FrontEndCompileRequest::checkAllTranslationUnits() { // Iterate over all translation units and // apply the semantic checking logic. @@ -553,12 +649,9 @@ void CompileRequest::checkAllTranslationUnits() { checkTranslationUnit(translationUnit.Ptr()); } - - // Next, do follow-up validation on any entry points. - validateEntryPoints(this); } -void CompileRequest::generateIR() +void FrontEndCompileRequest::generateIR() { // Our task in this function is to generate IR code // for all of the declarations in the translation @@ -581,9 +674,9 @@ void CompileRequest::generateIR() if (verifyDebugSerialization) { // Verify debug information - if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, mSession, sourceManager, IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) + if (SLANG_FAILED(IRSerialUtil::verifySerialize(irModule, getSession(), getSourceManager(), IRSerialBinary::CompressionType::None, IRSerialWriter::OptionFlag::DebugInfo))) { - mSink.diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); + getSink()->diagnose(irModule->moduleInst->sourceLoc, Diagnostics::serialDebugVerificationFailed); } } @@ -593,7 +686,7 @@ void CompileRequest::generateIR() { // Write IR out to serialData - copying over SourceLoc information directly IRSerialWriter writer; - writer.write(irModule, sourceManager, IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); + writer.write(irModule, getSourceManager(), IRSerialWriter::OptionFlag::RawSourceLocation, &serialData); // Destroy irModule such that memory can be used for newly constructed read irReadModule irModule = nullptr; @@ -602,7 +695,7 @@ void CompileRequest::generateIR() { // Read IR back from serialData IRSerialReader reader; - reader.read(serialData, mSession, nullptr, irReadModule); + reader.read(serialData, getSession(), nullptr, irReadModule); } // Set irModule to the read module @@ -610,12 +703,12 @@ void CompileRequest::generateIR() } // Set the module on the translation unit - translationUnit->irModule = irModule; + translationUnit->getModule()->setIRModule(irModule); } } // Try to infer a single common source language for a request -static SourceLanguage inferSourceLanguage(CompileRequest* request) +static SourceLanguage inferSourceLanguage(FrontEndCompileRequest* request) { SourceLanguage language = SourceLanguage::Unknown; for (auto& translationUnit : request->translationUnits) @@ -639,29 +732,115 @@ static SourceLanguage inferSourceLanguage(CompileRequest* request) return language; } -SlangResult CompileRequest::executeActionsInner() +SlangResult FrontEndCompileRequest::executeActionsInner() { - // Do some cleanup on settings specified by user. - // In particular, we want to propagate flags from the overall request down to - // each translation unit. + // We currently allow GlSL files on the command line so that we can + // drive our "pass-through" mode, but we really want to issue an error + // message if the user is seriously asking us to compile them. for (auto& translationUnit : translationUnits) { - translationUnit->compileFlags |= compileFlags; + switch(translationUnit->sourceLanguage) + { + default: + break; + + case SourceLanguage::GLSL: + getSink()->diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); + return SLANG_FAIL; + } + } + + + // Parse everything from the input files requested + for (auto& translationUnit : translationUnits) + { + parseTranslationUnit(translationUnit.Ptr()); } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Perform semantic checking on the whole collection + checkAllTranslationUnits(); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Look up all the entry points that are expected, + // and use them to populate the `program` member. + // + m_program = createUnspecializedProgram(this); + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + { + // Generate initial IR for all the translation + // units, if we are in a mode where IR is called for. + generateIR(); + } + + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + // Do parameter binding generation, for each compilation target. + // + for(auto targetReq : getLinkage()->targets) + { + auto targetProgram = m_program->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); + } + if (getSink()->GetErrorCount() != 0) + return SLANG_FAIL; + + return SLANG_OK; +} + +BackEndCompileRequest::BackEndCompileRequest( + Linkage* linkage, + DiagnosticSink* sink, + Program* program) + : CompileRequestBase(linkage, sink) + , m_program(program) +{} + +EndToEndCompileRequest::EndToEndCompileRequest( + Session* session) + : m_session(session) +{ + m_linkage = new Linkage(session); + + m_sink.sourceManager = m_linkage->getSourceManager(); + + // Set all the default writers + for (int i = 0; i < int(WriterChannel::CountOf); ++i) + { + setWriter(WriterChannel(i), nullptr); + } + + m_frontEndReq = new FrontEndCompileRequest(getLinkage(), getSink()); + + m_backEndReq = new BackEndCompileRequest(getLinkage(), getSink()); +} + +SlangResult EndToEndCompileRequest::executeActionsInner() +{ // If no code-generation target was specified, then try to infer one from the source language, // just to make sure we can do something reasonable when invoked from the command line. - if (targets.Count() == 0) + // + // TODO: This logic should be moved into `options.cpp` or somewhere else + // specific to the command-line tool. + // + if (getLinkage()->targets.Count() == 0) { - auto language = inferSourceLanguage(this); + auto language = inferSourceLanguage(getFrontEndReq()); switch (language) { case SourceLanguage::HLSL: - addTarget(CodeGenTarget::DXBytecode); + getLinkage()->addTarget(CodeGenTarget::DXBytecode); break; case SourceLanguage::GLSL: - addTarget(CodeGenTarget::SPIRV); + getLinkage()->addTarget(CodeGenTarget::SPIRV); break; default: @@ -672,105 +851,117 @@ SlangResult CompileRequest::executeActionsInner() // We only do parsing and semantic checking if we *aren't* doing // a pass-through compilation. // - // Note that we *do* perform output generation as normal in pass-through mode. if (passThrough == PassThroughMode::None) { - // We currently allow GlSL files on the command line so that we can - // drive our "pass-through" mode, but we really want to issue an error - // message if the user is seriously asking us to compile them. - for (auto& translationUnit : translationUnits) - { - switch(translationUnit->sourceLanguage) - { - default: - break; - - case SourceLanguage::GLSL: - mSink.diagnose(SourceLoc(), Diagnostics::glslIsNotSupported); - return SLANG_FAIL; - } - } + SLANG_RETURN_ON_FAIL(getFrontEndReq()->executeActionsInner()); + } + // If command line specifies to skip codegen, we exit here. + // Note: this is a debugging option. + // + if (shouldSkipCodegen || + ((getFrontEndReq()->compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) + { + // We will use the program (and matching layout information) + // that was computed in the front-end for all subsequent + // reflection queries, etc. + // + m_specializedProgram = getUnspecializedProgram(); - // Parse everything from the input files requested - for (auto& translationUnit : translationUnits) - { - parseTranslationUnit(translationUnit.Ptr()); - } - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + return SLANG_OK; + } - // Perform semantic checking on the whole collection - checkAllTranslationUnits(); - if (mSink.GetErrorCount() != 0) + // If codegen is enabled, we need to move along to + // apply any generic specialization that the user asked for. + // + if (passThrough == PassThroughMode::None) + { + m_specializedProgram = createSpecializedProgram(this); + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - if ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) == 0) + // For each code generation target, we will generate specialized + // parameter binding information (taking global generic + // arguments into account at this time). + // + for (auto targetReq : getLinkage()->targets) { - // Generate initial IR for all the translation - // units, if we are in a mode where IR is called for. - generateIR(); + auto targetProgram = m_specializedProgram->getTargetProgram(targetReq); + targetProgram->getOrCreateLayout(getSink()); } - - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; - - // For each code generation target generate - // parameter binding information. - // This step is done globally, because all translation - // units and entry points need to agree on where - // parameters are allocated. - for (auto targetReq : targets) + } + else + { + // We need to create dummy `EntryPoint` objects + // to make sure that the logic in `generateOutput` + // sees something worth processing. + // + auto specializedProgram = new Program(getLinkage()); + m_specializedProgram = specializedProgram; + for(auto entryPointReq : getFrontEndReq()->getEntryPointReqs()) { - generateParameterBindings(targetReq); - if (mSink.GetErrorCount() != 0) - return SLANG_FAIL; + RefPtr<EntryPoint> entryPoint = EntryPoint::createDummyForPassThrough( + entryPointReq->getName(), + entryPointReq->getProfile()); + + specializedProgram->addEntryPoint(entryPoint); } } - // If command line specifies to skip codegen, we exit here. - // Note: this is a debugging option. - if (shouldSkipCodegen || - ((compileFlags & SLANG_COMPILE_FLAG_NO_CODEGEN) != 0)) - return SLANG_OK; - // Generate output code, in whatever format was requested + getBackEndReq()->setProgram(getSpecializedProgram()); generateOutput(this); - if (mSink.GetErrorCount() != 0) + if (getSink()->GetErrorCount() != 0) return SLANG_FAIL; return SLANG_OK; } // Act as expected of the API-based compiler -SlangResult CompileRequest::executeActions() +SlangResult EndToEndCompileRequest::executeActions() { SlangResult res = executeActionsInner(); - mDiagnosticOutput = mSink.outputBuffer.ProduceString(); + mDiagnosticOutput = getSink()->outputBuffer.ProduceString(); return res; } -int CompileRequest::addTranslationUnit(SourceLanguage language, String const&) +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language, Name* moduleName) { UInt result = translationUnits.Count(); - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(this); translationUnit->compileRequest = this; translationUnit->sourceLanguage = SourceLanguage(language); + translationUnit->moduleName = moduleName; + translationUnits.Add(translationUnit); return (int) result; } -void CompileRequest::addTranslationUnitSourceFile( +int FrontEndCompileRequest::addTranslationUnit(SourceLanguage language) +{ + // We want to ensure that symbols defined in different translation + // units get unique mangled names, so that we can, e.g., tell apart + // a `main()` function in `vertex.slang` and a `main()` in `fragment.slang`, + // even when they are being compiled together. + // + String generatedName = "tu"; + generatedName.append(translationUnits.Count()); + return addTranslationUnit(language, getNamePool()->getName(generatedName)); +} + +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, SourceFile* sourceFile) { - translationUnits[translationUnitIndex]->sourceFiles.Add(sourceFile); + translationUnits[translationUnitIndex]->addSourceFile(sourceFile); } -void CompileRequest::addTranslationUnitSourceBlob( +void FrontEndCompileRequest::addTranslationUnitSourceBlob( int translationUnitIndex, String const& path, ISlangBlob* sourceBlob) @@ -781,7 +972,7 @@ void CompileRequest::addTranslationUnitSourceBlob( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceString( +void FrontEndCompileRequest::addTranslationUnitSourceString( int translationUnitIndex, String const& path, String const& source) @@ -792,7 +983,7 @@ void CompileRequest::addTranslationUnitSourceString( addTranslationUnitSourceFile(translationUnitIndex, sourceFile); } -void CompileRequest::addTranslationUnitSourceFile( +void FrontEndCompileRequest::addTranslationUnitSourceFile( int translationUnitIndex, String const& path) { @@ -809,7 +1000,7 @@ void CompileRequest::addTranslationUnitSourceFile( if(SLANG_FAILED(result)) { // Emit a diagnostic! - mSink.diagnose( + getSink()->diagnose( SourceLoc(), Diagnostics::cannotOpenFile, path); @@ -820,36 +1011,51 @@ void CompileRequest::addTranslationUnitSourceFile( translationUnitIndex, path, sourceBlob); +} + +int FrontEndCompileRequest::addEntryPoint( + int translationUnitIndex, + String const& name, + Profile entryPointProfile) +{ + auto translationUnitReq = translationUnits[translationUnitIndex]; + + UInt result = m_entryPointReqs.Count(); + + RefPtr<FrontEndEntryPointRequest> entryPointReq = new FrontEndEntryPointRequest( + this, + translationUnitIndex, + getNamePool()->getName(name), + entryPointProfile); + + m_entryPointReqs.Add(entryPointReq); +// translationUnitReq->entryPoints.Add(entryPointReq); - mDependencyFilePaths.Add(path); + return int(result); } -int CompileRequest::addEntryPoint( +int EndToEndCompileRequest::addEntryPoint( int translationUnitIndex, String const& name, Profile entryPointProfile, List<String> const & genericTypeNames) { - RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest(); - entryPoint->compileRequest = this; - entryPoint->name = getNamePool()->getName(name); - entryPoint->profile = entryPointProfile; - entryPoint->translationUnitIndex = translationUnitIndex; + getFrontEndReq()->addEntryPoint(translationUnitIndex, name, entryPointProfile); + + EntryPointInfo entryPointInfo; for (auto typeName : genericTypeNames) - entryPoint->genericArgStrings.Add(typeName); - auto translationUnit = translationUnits[translationUnitIndex].Ptr(); - translationUnit->entryPoints.Add(entryPoint); + entryPointInfo.genericArgStrings.Add(typeName); UInt result = entryPoints.Count(); - entryPoints.Add(entryPoint); + entryPoints.Add(_Move(entryPointInfo)); return (int) result; } -UInt CompileRequest::addTarget( +UInt Linkage::addTarget( CodeGenTarget target) { RefPtr<TargetRequest> targetReq = new TargetRequest(); - targetReq->compileRequest = this; + targetReq->linkage = this; targetReq->target = target; UInt result = targets.Count(); @@ -857,15 +1063,16 @@ UInt CompileRequest::addTarget( return (int) result; } -void CompileRequest::loadParsedModule( - RefPtr<TranslationUnitRequest> const& translationUnit, - Name* name, - const PathInfo& pathInfo) +void Linkage::loadParsedModule( + RefPtr<TranslationUnitRequest> translationUnit, + Name* name, + const PathInfo& pathInfo) { // Note: we add the loaded module to our name->module listing // before doing semantic checking, so that if it tries to // recursively `import` itself, we can detect it. - RefPtr<LoadedModule> loadedModule = new LoadedModule(); + // + RefPtr<Module> loadedModule = translationUnit->getModule(); // Get a path String mostUniqueIdentity = pathInfo.getMostUniqueIdentity(); @@ -874,12 +1081,11 @@ void CompileRequest::loadParsedModule( mapPathToLoadedModule.Add(mostUniqueIdentity, loadedModule); mapNameToLoadedModules.Add(name, loadedModule); - int errorCountBefore = mSink.GetErrorCount(); - checkTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + auto sink = translationUnit->compileRequest->getSink(); - RefPtr<ModuleDecl> moduleDecl = translationUnit->SyntaxNode; - loadedModule->moduleDecl = moduleDecl; + int errorCountBefore = sink->GetErrorCount(); + checkTranslationUnit(translationUnit.Ptr()); + int errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { @@ -890,39 +1096,56 @@ void CompileRequest::loadParsedModule( // If we didn't run into any errors, then try to generate // IR code for the imported module. SLANG_ASSERT(errorCountAfter == 0); - loadedModule->irModule = generateIRForTranslationUnit(translationUnit); + loadedModule->setIRModule(generateIRForTranslationUnit(translationUnit)); } loadedModulesList.Add(loadedModule); } -RefPtr<ModuleDecl> CompileRequest::loadModule( +Module* Linkage::loadModule(String const& name) +{ + // TODO: We either need to have a diagnostics sink + // get passed into this operation, or associate + // one with the linkage. + // + DiagnosticSink* sink = nullptr; + return findOrImportModule( + getNamePool()->getName(name), + SourceLoc(), + sink); +} + + +RefPtr<Module> Linkage::loadModule( Name* name, const PathInfo& filePathInfo, ISlangBlob* sourceBlob, - SourceLoc const& srcLoc) + SourceLoc const& srcLoc, + DiagnosticSink* sink) { - RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(); - translationUnit->compileRequest = this; + RefPtr<FrontEndCompileRequest> frontEndReq = new FrontEndCompileRequest(this, sink); - // We don't want to use the same options that the user specified - // for loading modules on-demand. In particular, we always want - // semantic checking to be enabled. - // - // TODO: decide which options, if any, should be inherited. - translationUnit->compileFlags = 0; + RefPtr<TranslationUnitRequest> translationUnit = new TranslationUnitRequest(frontEndReq); + translationUnit->compileRequest = frontEndReq; + translationUnit->moduleName = name; + + auto module = translationUnit->getModule(); + + ModuleBeingImportedRAII moduleBeingImported( + this, + module); // Create with the 'friendly' name SourceFile* sourceFile = getSourceManager()->createSourceFileWithBlob(filePathInfo, sourceBlob); - translationUnit->sourceFiles.Add(sourceFile); + translationUnit->addSourceFile(sourceFile); - int errorCountBefore = mSink.GetErrorCount(); - parseTranslationUnit(translationUnit.Ptr()); - int errorCountAfter = mSink.GetErrorCount(); + int errorCountBefore = sink->GetErrorCount(); + frontEndReq->parseTranslationUnit(translationUnit); + int errorCountAfter = sink->GetErrorCount(); if( errorCountAfter != errorCountBefore ) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); } if (errorCountAfter) { @@ -935,38 +1158,57 @@ RefPtr<ModuleDecl> CompileRequest::loadModule( name, filePathInfo); - errorCountAfter = mSink.GetErrorCount(); + errorCountAfter = sink->GetErrorCount(); if (errorCountAfter != errorCountBefore) { - mSink.diagnose(srcLoc, Diagnostics::errorInImportedModule); + sink->diagnose(srcLoc, Diagnostics::errorInImportedModule); // Something went wrong during the parsing, so we should bail out. return nullptr; } - return translationUnit->SyntaxNode; + return module; +} + +bool Linkage::isBeingImported(Module* module) +{ + for(auto ii = m_modulesBeingImported; ii; ii = ii->next) + { + if(module == ii->module) + return true; + } + return false; } -RefPtr<ModuleDecl> CompileRequest::findOrImportModule( +RefPtr<Module> Linkage::findOrImportModule( Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { // Have we already loaded a module matching this name? - // If so, return it. + // RefPtr<LoadedModule> loadedModule; if (mapNameToLoadedModules.TryGetValue(name, loadedModule)) { + // If the map shows a null module having been loaded, + // then that means there was a prior load attempt, + // but it failed, so we won't bother trying again. + // if (!loadedModule) return nullptr; - if (!loadedModule->moduleDecl) + // If state shows us that the module is already being + // imported deeper on the call stack, then we've + // hit a recursive case, and that is an error. + // + if(isBeingImported(loadedModule)) { // We seem to be in the middle of loading this module - mSink.diagnose(loc, Diagnostics::recursiveModuleImport, name); + sink->diagnose(loc, Diagnostics::recursiveModuleImport, name); return nullptr; } - return loadedModule->moduleDecl; + return loadedModule; } // Derive a file name for the module, by taking the given @@ -991,7 +1233,8 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // using our ordinary include-handling logic. IncludeHandlerImpl includeHandler; - includeHandler.request = this; + includeHandler.linkage = this; + includeHandler.searchDirectories = &searchDirectories; // Get the original path info PathInfo pathIncludedFromInfo = getSourceManager()->getPathInfo(loc, SourceLocType::Actual); @@ -1000,20 +1243,20 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( // We have to load via the found path - as that is how file was originally loaded if (SLANG_FAILED(includeHandler.findFile(fileName, pathIncludedFromInfo.foundPath, filePathInfo))) { - this->mSink.diagnose(loc, Diagnostics::cannotFindFile, fileName); + sink->diagnose(loc, Diagnostics::cannotFindFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } // Maybe this was loaded previously at a different relative name? if (mapPathToLoadedModule.TryGetValue(filePathInfo.getMostUniqueIdentity(), loadedModule)) - return loadedModule->moduleDecl; + return loadedModule; // Try to load it ComPtr<ISlangBlob> fileContents; - if (SLANG_FAILED(includeHandler.readFile(filePathInfo.foundPath, fileContents.writeRef()))) + if(SLANG_FAILED(getFileSystemExt()->loadFile(filePathInfo.foundPath.Buffer(), fileContents.writeRef()))) { - this->mSink.diagnose(loc, Diagnostics::cannotOpenFile, fileName); + sink->diagnose(loc, Diagnostics::cannotOpenFile, fileName); mapNameToLoadedModules[name] = nullptr; return nullptr; } @@ -1024,26 +1267,159 @@ RefPtr<ModuleDecl> CompileRequest::findOrImportModule( name, filePathInfo, fileContents, - loc); + loc, + sink); } -Decl * CompileRequest::lookupGlobalDecl(Name * name) +// +// ModuleDependencyList +// + +void ModuleDependencyList::addDependency(Module* module) { - Decl* resultDecl = nullptr; - for (auto module : loadedModulesList) + // If we depend on a module, then we depend on everything it depends on. + // + // Note: We are processing these sub-depenencies before adding the + // `module` itself, so that in the common case a module will always + // appear *after* everything it depends on. + // + // However, this rule is being violated in the compiler right now because + // the modules for hte top-level translation units of a compile request + // will be added to the list first (using `addLeafDependency`) to + // maintain compatibility with old behavior. This may be fixed later. + // + for(auto subDependency : module->getModuleDependencyList()) { - if (module->moduleDecl->memberDictionary.TryGetValue(name, resultDecl)) - break; + _addDependency(subDependency); + } + _addDependency(module); +} + +void ModuleDependencyList::addLeafDependency(Module* module) +{ + _addDependency(module); +} + +void ModuleDependencyList::_addDependency(Module* module) +{ + if(m_moduleSet.Contains(module)) + return; + + m_moduleList.Add(module); + m_moduleSet.Add(module); +} + +// +// FilePathDependencyList +// + +void FilePathDependencyList::addDependency(String const& path) +{ + if(m_filePathSet.Contains(path)) + return; + + m_filePathList.Add(path); + m_filePathSet.Add(path); +} + +void FilePathDependencyList::addDependency(Module* module) +{ + for(auto& path : module->getFilePathDependencyList()) + { + addDependency(path); } - for (auto transUnit : translationUnits) +} + + + +// +// Module +// + +Module::Module(Linkage* linkage) + : m_linkage(linkage) +{} + + +void Module::addModuleDependency(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Module::addFilePathDependency(String const& path) +{ + m_filePathDependencyList.addDependency(path); +} + +// Program + +Program::Program(Linkage* linkage) + : m_linkage(linkage) +{} + +void Program::addReferencedModule(Module* module) +{ + m_moduleDependencyList.addDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addReferencedLeafModule(Module* module) +{ + m_moduleDependencyList.addLeafDependency(module); + m_filePathDependencyList.addDependency(module); +} + +void Program::addEntryPoint(EntryPoint* entryPoint) +{ + m_entryPoints.Add(entryPoint); + + for(auto module : entryPoint->getModuleDependencies()) { - if (transUnit->SyntaxNode->memberDictionary.TryGetValue(name, resultDecl)) - break; + addReferencedModule(module); } - return resultDecl; } -void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) +RefPtr<IRModule> Program::getOrCreateIRModule(DiagnosticSink* sink) +{ + if(!m_irModule) + { + m_irModule = generateIRForProgram( + m_linkage->getSession(), + this, + sink); + } + return m_irModule; +} + + +TargetProgram* Program::getTargetProgram(TargetRequest* target) +{ + RefPtr<TargetProgram> targetProgram; + if(!m_targetPrograms.TryGetValue(target, targetProgram)) + { + targetProgram = new TargetProgram(this, target); + m_targetPrograms[target] = targetProgram; + } + return targetProgram; +} + +// +// TargetProgram +// + +TargetProgram::TargetProgram( + Program* program, + TargetRequest* targetReq) + : m_program(program) + , m_targetReq(targetReq) +{ + m_entryPointResults.SetSize(program->getEntryPoints().Count()); +} + +// + +void DiagnosticSink::noteInternalErrorLoc(SourceLoc const& loc) { // Don't consider invalid source locations. if(!loc.isValid()) @@ -1054,14 +1430,19 @@ void CompileRequest::noteInternalErrorLoc(SourceLoc const& loc) // code might have confused the compiler. if(internalErrorLocsNoted == 0) { - mSink.diagnose(loc, Diagnostics::noteLocationOfInternalError); + diagnose(loc, Diagnostics::noteLocationOfInternalError); } internalErrorLocsNoted++; } +Session* CompileRequestBase::getSession() +{ + return getLinkage()->getSession(); +} + static const Slang::Guid IID_ISlangFileSystemExt = SLANG_UUID_ISlangFileSystemExt; -void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) +void Linkage::setFileSystem(ISlangFileSystem* inFileSystem) { // Set the fileSystem fileSystem = inFileSystem; @@ -1085,15 +1466,16 @@ void CompileRequest::setFileSystem(ISlangFileSystem* inFileSystem) } // Set the file system used on the source manager - sourceManager->setFileSystemExt(fileSystemExt); + getSourceManager()->setFileSystemExt(fileSystemExt); } -RefPtr<ModuleDecl> findOrImportModule( - CompileRequest* request, +RefPtr<Module> findOrImportModule( + Linkage* linkage, Name* name, - SourceLoc const& loc) + SourceLoc const& loc, + DiagnosticSink* sink) { - return request->findOrImportModule(name, loc); + return linkage->findOrImportModule(name, loc, sink); } void Session::addBuiltinSource( @@ -1101,30 +1483,34 @@ void Session::addBuiltinSource( String const& path, String const& source) { - RefPtr<CompileRequest> compileRequest = new CompileRequest(this); - compileRequest->setSourceManager(getBuiltinSourceManager()); + DiagnosticSink sink; + RefPtr<FrontEndCompileRequest> compileRequest = new FrontEndCompileRequest( + m_builtinLinkage, + &sink); - auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, path); + Name* moduleName = getNamePool()->getName(path); + auto translationUnitIndex = compileRequest->addTranslationUnit(SourceLanguage::Slang, moduleName); compileRequest->addTranslationUnitSourceString( translationUnitIndex, path, source); - SlangResult res = compileRequest->executeActions(); + SlangResult res = compileRequest->executeActionsInner(); if (SLANG_FAILED(res)) { - fprintf(stderr, "%s", compileRequest->mDiagnosticOutput.Buffer()); + char const* diagnostics = sink.outputBuffer.Buffer(); + fprintf(stderr, "%s", diagnostics); #ifdef _WIN32 - OutputDebugStringA(compileRequest->mDiagnosticOutput.Buffer()); + OutputDebugStringA(diagnostics); #endif SLANG_UNEXPECTED("error in Slang standard library"); } // Extract the AST for the code we just parsed - auto syntax = compileRequest->translationUnits[translationUnitIndex]->SyntaxNode; + auto syntax = compileRequest->translationUnits[translationUnitIndex]->getModuleDecl(); // HACK(tfoley): mark all declarations in the "stdlib" so // that we can detect them later (e.g., so we don't emit them) @@ -1176,19 +1562,37 @@ Session::~Session() // implementation of C interface -#define SESSION(x) reinterpret_cast<Slang::Session *>(x) -#define REQ(x) reinterpret_cast<Slang::CompileRequest*>(x) +static SlangSession* convert(Slang::Session* session) +{ return reinterpret_cast<SlangSession*>(session); } + +static Slang::Session* convert(SlangSession* session) +{ return reinterpret_cast<Slang::Session*>(session); } + +static SlangCompileRequest* convert(Slang::EndToEndCompileRequest* request) +{ return reinterpret_cast<SlangCompileRequest*>(request); } + +static Slang::EndToEndCompileRequest* convert(SlangCompileRequest* request) +{ return reinterpret_cast<Slang::EndToEndCompileRequest*>(request); } + +static SlangLinkage* convert(Slang::Linkage* linkage) +{ return reinterpret_cast<SlangLinkage*>(linkage); } + +static Slang::Linkage* convert(SlangLinkage* linkage) +{ return reinterpret_cast<Slang::Linkage*>(linkage); } + +static SlangModule* convert(Slang::Module* module) +{ return reinterpret_cast<SlangModule*>(module); } SLANG_API SlangSession* spCreateSession(const char*) { - return reinterpret_cast<SlangSession *>(new Slang::Session()); + return convert(new Slang::Session()); } SLANG_API void spDestroySession( SlangSession* session) { if(!session) return; - delete SESSION(session); + delete convert(session); } SLANG_API void spAddBuiltins( @@ -1196,7 +1600,7 @@ SLANG_API void spAddBuiltins( char const* sourcePath, char const* sourceString) { - auto s = SESSION(session); + auto s = convert(session); s->addBuiltinSource( // TODO(tfoley): Add ability to directly new builtins to the approriate scope @@ -1210,7 +1614,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SlangSession* session, ISlangSharedLibraryLoader* loader) { - auto s = SESSION(session); + auto s = convert(session); if (!loader) { @@ -1237,7 +1641,7 @@ SLANG_API void spSessionSetSharedLibraryLoader( SLANG_API ISlangSharedLibraryLoader* spSessionGetSharedLibraryLoader( SlangSession* session) { - auto s = SESSION(session); + auto s = convert(session); return (s->sharedLibraryLoader == Slang::DefaultSharedLibraryLoader::getSingleton()) ? nullptr : s->sharedLibraryLoader.get(); } @@ -1245,7 +1649,7 @@ SLANG_API SlangResult spSessionCheckCompileTargetSupport( SlangSession* session, SlangCompileTarget target) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkCompileTargetSupport(s, Slang::CodeGenTarget(target)); } @@ -1253,16 +1657,45 @@ SLANG_API SlangResult spSessionCheckPassThroughSupport( SlangSession* session, SlangPassThrough passThrough) { - auto s = SESSION(session); + auto s = convert(session); return Slang::checkExternalCompilerSupport(s, Slang::PassThroughMode(passThrough)); } + +SLANG_API SlangLinkage* spCreateLinkage( + SlangSession* session) +{ + auto s = convert(session); + auto linkage = new Slang::Linkage(s); + return convert(linkage); +} + +SLANG_API void spDestroyLinkage( + SlangLinkage* linkage) +{ + if(!linkage) return; + auto lnk = convert(linkage); + delete lnk; +} + +SLANG_API SlangModule* spLoadModule( + SlangLinkage* linkage, + char const* moduleName) +{ + if(!linkage) return nullptr; + auto lnk = convert(linkage); + + auto mod = lnk->loadModule(moduleName); + return convert(mod); +} + + SLANG_API SlangCompileRequest* spCreateCompileRequest( SlangSession* session) { - auto s = SESSION(session); - auto req = new Slang::CompileRequest(s); - return reinterpret_cast<SlangCompileRequest*>(req); + auto s = convert(session); + auto req = new Slang::EndToEndCompileRequest(s); + return convert(req); } /*! @@ -1272,7 +1705,7 @@ SLANG_API void spDestroyCompileRequest( SlangCompileRequest* request) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); delete req; } @@ -1281,21 +1714,21 @@ SLANG_API void spSetFileSystem( ISlangFileSystem* fileSystem) { if(!request) return; - REQ(request)->setFileSystem(fileSystem); + convert(request)->getLinkage()->setFileSystem(fileSystem); } SLANG_API void spSetCompileFlags( SlangCompileRequest* request, SlangCompileFlags flags) { - REQ(request)->compileFlags = flags; + convert(request)->getFrontEndReq()->compileFlags = flags; } SLANG_API void spSetDumpIntermediates( SlangCompileRequest* request, int enable) { - REQ(request)->shouldDumpIntermediates = enable != 0; + convert(request)->getBackEndReq()->shouldDumpIntermediates = enable != 0; } SLANG_API void spSetLineDirectiveMode( @@ -1304,13 +1737,13 @@ SLANG_API void spSetLineDirectiveMode( { // TODO: validation - REQ(request)->lineDirectiveMode = Slang::LineDirectiveMode(mode); + convert(request)->getBackEndReq()->lineDirectiveMode = Slang::LineDirectiveMode(mode); } SLANG_API void spSetCommandLineCompilerMode( SlangCompileRequest* request) { - REQ(request)->isCommandLineCompile = true; + convert(request)->isCommandLineCompile = true; } @@ -1318,17 +1751,19 @@ SLANG_API void spSetCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - req->targets.Clear(); - req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets.Clear(); + linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API int spAddCodeGenTarget( SlangCompileRequest* request, SlangCompileTarget target) { - auto req = REQ(request); - return (int) req->addTarget(Slang::CodeGenTarget(target)); + auto req = convert(request); + auto linkage = req->getLinkage(); + return (int) linkage->addTarget(Slang::CodeGenTarget(target)); } SLANG_API void spSetTargetProfile( @@ -1336,8 +1771,9 @@ SLANG_API void spSetTargetProfile( int targetIndex, SlangProfileID profile) { - auto req = REQ(request); - req->targets[targetIndex]->targetProfile = Slang::Profile(profile); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetProfile = Slang::Profile(profile); } SLANG_API void spSetTargetFlags( @@ -1345,8 +1781,9 @@ SLANG_API void spSetTargetFlags( int targetIndex, SlangTargetFlags flags) { - auto req = REQ(request); - req->targets[targetIndex]->targetFlags = flags; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->targetFlags = flags; } SLANG_API void spSetTargetFloatingPointMode( @@ -1354,16 +1791,18 @@ SLANG_API void spSetTargetFloatingPointMode( int targetIndex, SlangFloatingPointMode mode) { - auto req = REQ(request); - req->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->targets[targetIndex]->floatingPointMode = Slang::FloatingPointMode(mode); } SLANG_API void spSetMatrixLayoutMode( SlangCompileRequest* request, SlangMatrixLayoutMode mode) { - auto req = REQ(request); - req->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->defaultMatrixLayoutMode = Slang::MatrixLayoutMode(mode); } SLANG_API void spSetTargetMatrixLayoutMode( @@ -1380,7 +1819,7 @@ SLANG_API void spSetOutputContainerFormat( SlangCompileRequest* request, SlangContainerFormat format) { - auto req = REQ(request); + auto req = convert(request); req->containerFormat = Slang::ContainerFormat(format); } @@ -1389,7 +1828,7 @@ SLANG_API void spSetPassThrough( SlangCompileRequest* request, SlangPassThrough passThrough) { - REQ(request)->passThrough = Slang::PassThroughMode(passThrough); + convert(request)->passThrough = Slang::PassThroughMode(passThrough); } SLANG_API void spSetDiagnosticCallback( @@ -1400,7 +1839,7 @@ SLANG_API void spSetDiagnosticCallback( using namespace Slang; if(!request) return; - auto req = REQ(request); + auto req = convert(request); ComPtr<ISlangWriter> writer(new CallbackWriter(callback, userData, WriterFlag::IsConsole)); req->setWriter(WriterChannel::Diagnostic, writer); @@ -1412,7 +1851,7 @@ SLANG_API void spSetWriter( ISlangWriter* writer) { if (!request) return; - auto req = REQ(request); + auto req = convert(request); req->setWriter(Slang::WriterChannel(chan), writer); } @@ -1422,15 +1861,17 @@ SLANG_API ISlangWriter* spGetWriter( SlangWriterChannel chan) { if (!request) return nullptr; - auto req = REQ(request); + auto req = convert(request); return req->getWriter(Slang::WriterChannel(chan)); } SLANG_API void spAddSearchPath( - SlangCompileRequest* request, - const char* path) + SlangCompileRequest* request, + const char* path) { - REQ(request)->searchDirectories.Add(Slang::SearchDirectory(path)); + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->searchDirectories.searchDirectories.Add(Slang::SearchDirectory(path)); } SLANG_API void spAddPreprocessorDefine( @@ -1438,25 +1879,27 @@ SLANG_API void spAddPreprocessorDefine( const char* key, const char* value) { - REQ(request)->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto linkage = req->getLinkage(); + linkage->preprocessorDefinitions[key] = value; } SLANG_API char const* spGetDiagnosticOutput( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); + auto req = convert(request); return req->mDiagnosticOutput.begin(); } SLANG_API SlangResult spGetDiagnosticOutputBlob( - SlangCompileRequest* request, - ISlangBlob** outBlob) + SlangCompileRequest* request, + ISlangBlob** outBlob) { if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); if(!req->diagnosticOutputBlob) { @@ -1475,11 +1918,13 @@ SLANG_API int spAddTranslationUnit( SlangSourceLanguage language, char const* name) { - auto req = REQ(request); + SLANG_UNUSED(name); + + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); - return req->addTranslationUnit( - Slang::SourceLanguage(language), - name ? name : ""); + return frontEndReq->addTranslationUnit( + Slang::SourceLanguage(language)); } SLANG_API void spTranslationUnit_addPreprocessorDefine( @@ -1488,10 +1933,10 @@ SLANG_API void spTranslationUnit_addPreprocessorDefine( const char* key, const char* value) { - auto req = REQ(request); - - req->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + frontEndReq->translationUnits[translationUnitIndex]->preprocessorDefinitions[key] = value; } SLANG_API void spAddTranslationUnitSourceFile( @@ -1500,12 +1945,13 @@ SLANG_API void spAddTranslationUnitSourceFile( char const* path) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!path) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; - req->addTranslationUnitSourceFile( + frontEndReq->addTranslationUnitSourceFile( translationUnitIndex, path); } @@ -1533,14 +1979,15 @@ SLANG_API void spAddTranslationUnitSourceStringSpan( char const* sourceEnd) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBegin) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceString( + frontEndReq->addTranslationUnitSourceString( translationUnitIndex, path, Slang::UnownedStringSlice(sourceBegin, sourceEnd)); @@ -1553,14 +2000,15 @@ SLANG_API void spAddTranslationUnitSourceBlob( ISlangBlob* sourceBlob) { if(!request) return; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if(!sourceBlob) return; if(translationUnitIndex < 0) return; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return; + if(Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return; if(!path) path = ""; - req->addTranslationUnitSourceBlob( + frontEndReq->addTranslationUnitSourceBlob( translationUnitIndex, path, sourceBlob); @@ -1584,17 +2032,13 @@ SLANG_API int spAddEntryPoint( char const* name, SlangStage stage) { - if(!request) return -1; - auto req = REQ(request); - if(!name) return -1; - if(translationUnitIndex < 0) return -1; - if(Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; - - return req->addEntryPoint( + return spAddEntryPointEx( + request, translationUnitIndex, name, - Slang::Profile(Slang::Stage(stage)), - Slang::List<Slang::String>()); + stage, + 0, + nullptr); } SLANG_API int spAddEntryPointEx( @@ -1606,10 +2050,11 @@ SLANG_API int spAddEntryPointEx( char const ** genericParamTypeNames) { if (!request) return -1; - auto req = REQ(request); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); if (!name) return -1; if (translationUnitIndex < 0) return -1; - if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; + if (Slang::UInt(translationUnitIndex) >= frontEndReq->translationUnits.Count()) return -1; Slang::List<Slang::String> typeNames; for (int i = 0; i < genericParamTypeNameCount; i++) typeNames.Add(genericParamTypeNames[i]); @@ -1620,12 +2065,28 @@ SLANG_API int spAddEntryPointEx( typeNames); } +SLANG_API SlangResult spSetGlobalGenericArgs( + SlangCompileRequest* request, + int genericArgCount, + char const** genericArgs) +{ + if (!request) return SLANG_FAIL; + auto req = convert(request); + + auto& genericArgStrings = req->globalGenericArgStrings; + genericArgStrings.Clear(); + for (int i = 0; i < genericArgCount; i++) + genericArgStrings.Add(genericArgs[i]); + + return SLANG_OK; +} + // Compile in a context that already has its translation units specified SLANG_API SlangResult spCompile( SlangCompileRequest* request) { - auto req = REQ(request); + auto req = convert(request); #if !defined(SLANG_DEBUG_INTERNAL_ERROR) // By default we'd like to catch as many internal errors as possible, @@ -1654,7 +2115,7 @@ SLANG_API SlangResult spCompile( // We will print out information on the exception to help out the user // in either filing a bug, or locating what in their code created // a problem. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAbortedDueToException, typeid(e).name(), e.Message); } catch (...) { @@ -1662,9 +2123,9 @@ SLANG_API SlangResult spCompile( // `Exception`, so something really fishy is going on. We want to // let the user know that we messed up, so they know to blame Slang // and not some other component in their system. - req->mSink.diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); + req->getSink()->diagnose(Slang::SourceLoc(), Slang::Diagnostics::compilationAborted); } - req->mDiagnosticOutput = req->mSink.outputBuffer.ProduceString(); + req->mDiagnosticOutput = req->getSink()->outputBuffer.ProduceString(); return res; #else // When debugging, we probably don't want to filter out any errors, since @@ -1680,8 +2141,10 @@ spGetDependencyFileCount( SlangCompileRequest* request) { if(!request) return 0; - auto req = REQ(request); - return (int) req->mDependencyFilePaths.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return (int) program->getFilePathDependencies().Count(); } /** Get the path to a file this compilation dependend on. @@ -1692,16 +2155,19 @@ spGetDependencyFilePath( int index) { if(!request) return 0; - auto req = REQ(request); - return req->mDependencyFilePaths[index].begin(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + auto program = frontEndReq->getProgram(); + return program->getFilePathDependencies()[index].begin(); } SLANG_API int spGetTranslationUnitCount( SlangCompileRequest* request) { - auto req = REQ(request); - return (int) req->translationUnits.Count(); + auto req = convert(request); + auto frontEndReq = req->getFrontEndReq(); + return (int) frontEndReq->translationUnits.Count(); } // Get the output code associated with a specific translation unit @@ -1718,15 +2184,26 @@ SLANG_API void const* spGetEntryPointCode( int entryPointIndex, size_t* outSize) { - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // TODO: We should really accept a target index in this API - auto targetCount = req->targets.Count(); - if (targetCount == 0) + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) return nullptr; - auto targetReq = req->targets[0]; + auto targetReq = linkage->targets[targetIndex]; - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + + if(entryPointIndex < 0) return nullptr; + if(Slang::UInt(entryPointIndex) >= req->entryPoints.Count()) return nullptr; + auto entryPoint = program->getEntryPoint(entryPointIndex); + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return nullptr; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); void const* data = nullptr; size_t size = 0; @@ -1761,21 +2238,29 @@ SLANG_API SlangResult spGetEntryPointCodeBlob( if(!request) return SLANG_ERROR_INVALID_PARAMETER; if(!outBlob) return SLANG_ERROR_INVALID_PARAMETER; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); - int targetCount = (int) req->targets.Count(); + int targetCount = (int) linkage->targets.Count(); if((targetIndex < 0) || (targetIndex >= targetCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - auto targetReq = req->targets[targetIndex]; + auto targetReq = linkage->targets[targetIndex]; int entryPointCount = (int) req->entryPoints.Count(); if((entryPointIndex < 0) || (entryPointIndex >= entryPointCount)) { return SLANG_ERROR_INVALID_PARAMETER; } - Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex]; + auto entryPointReq = program->getEntryPoint(entryPointIndex); + + + auto targetProgram = program->getTargetProgram(targetReq); + if(!targetProgram) + return SLANG_FAIL; + Slang::CompileResult& result = targetProgram->getExistingEntryPointResult(entryPointIndex); auto blob = result.getBlob(); *outBlob = blob.detach(); @@ -1793,13 +2278,9 @@ SLANG_API void const* spGetCompileRequestCode( SlangCompileRequest* request, size_t* outSize) { - auto req = REQ(request); - - void const* data = req->generatedBytecode.Buffer(); - size_t size = req->generatedBytecode.Count(); - - if(outSize) *outSize = size; - return data; + SLANG_UNUSED(request); + SLANG_UNUSED(outSize); + return nullptr; } // Reflection API @@ -1808,7 +2289,9 @@ SLANG_API SlangReflection* spGetReflection( SlangCompileRequest* request) { if( !request ) return 0; - auto req = REQ(request); + auto req = convert(request); + auto linkage = req->getLinkage(); + auto program = req->getSpecializedProgram(); // Note(tfoley): The API signature doesn't let the client // specify which target they want to access reflection @@ -1818,12 +2301,16 @@ SLANG_API SlangReflection* spGetReflection( // so that we can do this better, and make it clear that // `spGetReflection()` is shorthand for `targetIndex == 0`. // - auto targetCount = req->targets.Count(); - if (targetCount == 0) - return 0; - auto targetReq = req->targets[0]; + Slang::UInt targetIndex = 0; + auto targetCount = linkage->targets.Count(); + if (targetIndex >= targetCount) + return nullptr; + + auto targetReq = linkage->targets[targetIndex]; + auto targetProgram = program->getTargetProgram(targetReq); + auto programLayout = targetProgram->getExistingLayout(); - return (SlangReflection*) targetReq->layout.Ptr(); + return (SlangReflection*) programLayout; } // ... rest of reflection API implementation is in `Reflection.cpp` |
