diff options
Diffstat (limited to 'source/slang/slang-compiler.cpp')
| -rw-r--r-- | source/slang/slang-compiler.cpp | 641 |
1 files changed, 286 insertions, 355 deletions
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 14912d719..1b20a869d 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -20,7 +20,6 @@ #include "slang-parser.h" #include "slang-preprocessor.h" #include "slang-type-layout.h" -#include "slang-emit.h" #include "slang-glsl-extension-tracker.h" #include "slang-emit-cuda.h" @@ -587,18 +586,29 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return PassThroughMode::None; } - bool isPassThroughEnabled( - EndToEndCompileRequest* endToEndReq) + EndToEndCompileRequest* CodeGenContext::isPassThroughEnabled() { + auto endToEndReq = isEndToEndCompile(); + // If there isn't an end-to-end compile going on, // there can be no pass-through. // - if (!endToEndReq) return false; + if (!endToEndReq) + return nullptr; - // And if pass-through isn't set, we don't need - // access to the translation unit. - return endToEndReq->m_passThrough != PassThroughMode::None; + // And if pass-through isn't set on that end-to-end compile, + // then we clearly areb't doing a pass-through compile. + // + if(endToEndReq->m_passThrough == PassThroughMode::None) + return nullptr; + + // If we have confirmed that pass-through compilation is going on, + // we return the end-to-end request, because it has all the + // relevant state that we need to implement pass-through mode. + // + return endToEndReq; } + /// If there is a pass-through compile going on, find the translation unit for the given entry point. /// Assumes isPassThroughEnabled has already been called TranslationUnitRequest* getPassThroughTranslationUnit( @@ -613,11 +623,10 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return translationUnit; } - TranslationUnitRequest* findPassThroughTranslationUnit( - EndToEndCompileRequest* endToEndReq, + TranslationUnitRequest* CodeGenContext::findPassThroughTranslationUnit( Int entryPointIndex) { - if (isPassThroughEnabled(endToEndReq)) + if (auto endToEndReq = isPassThroughEnabled()) return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); return nullptr; } @@ -652,20 +661,14 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } } - SlangResult emitEntryPointsSource( - BackEndCompileRequest* compileRequest, - const List<Int>& entryPointIndices, - TargetRequest* targetReq, - CodeGenTarget target, - EndToEndCompileRequest* endToEndReq, - ExtensionTracker* extensionTracker, + SlangResult CodeGenContext::emitEntryPointsSource( String& outSource) { outSource = String(); - if(isPassThroughEnabled(endToEndReq)) + if(auto endToEndReq = isPassThroughEnabled()) { - for (auto entryPointIndex : entryPointIndices) + for (auto entryPointIndex : getEntryPointIndices()) { auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndex); SLANG_ASSERT(translationUnit); @@ -676,7 +679,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // mode. StringBuilder codeBuilder; - if (target == CodeGenTarget::GLSL) + if (getTargetFormat() == CodeGenTarget::GLSL) { // Special case GLSL int translationUnitCounter = 0; @@ -711,30 +714,10 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) else { return emitEntryPointsSourceFromIR( - compileRequest, - entryPointIndices, - target, - targetReq, - extensionTracker, outSource); } } - SlangResult emitEntryPointSource( - BackEndCompileRequest* compileRequest, - Int entryPointIndex, - TargetRequest* targetReq, - CodeGenTarget target, - EndToEndCompileRequest* endToEndReq, - ExtensionTracker* extensionTracker, - String& outSource) - { - List<Int> entryPointIndices; - entryPointIndices.add(entryPointIndex); - return emitEntryPointsSource(compileRequest, entryPointIndices, targetReq, - target, endToEndReq, extensionTracker, outSource); - } - String GetHLSLProfileName(Profile profile) { switch( profile.getFamily() ) @@ -868,21 +851,19 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } } - String calcSourcePathForEntryPoints( - EndToEndCompileRequest* endToEndReq, - const List<Int>& entryPointIndices) + String CodeGenContext::calcSourcePathForEntryPoints() { String failureMode = "slang-generated"; - if (entryPointIndices.getCount() != 1) + if (getEntryPointCount() != 1) return failureMode; - auto entryPointIndex = entryPointIndices[0]; - auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); + auto entryPointIndex = getSingleEntryPointIndex(); + auto translationUnitRequest = findPassThroughTranslationUnit(entryPointIndex); if (!translationUnitRequest) return failureMode; const auto& sourceFiles = translationUnitRequest->getSourceFiles(); - auto sink = endToEndReq->getSink(); + auto sink = getSink(); const Index numSourceFiles = sourceFiles.getCount(); @@ -903,15 +884,6 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } } - String calcSourcePathForEntryPoint( - EndToEndCompileRequest* endToEndReq, - Int entryPointIndex) - { - List<Int> entryPointIndices; - entryPointIndices.add(entryPointIndex); - return calcSourcePathForEntryPoints(endToEndReq, entryPointIndices); - } - // Helper function for cases where we can assume a single entry point Int assertSingleEntryPoint(List<Int> const& entryPointIndices) { SLANG_ASSERT(entryPointIndices.getCount() == 1); @@ -1012,31 +984,29 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } } - SlangResult emitWithDownstreamForEntryPoints( - ComponentType* program, - BackEndCompileRequest* slangRequest, - const List<Int>& entryPointIndices, - TargetRequest* targetReq, - CodeGenTarget target, - EndToEndCompileRequest* endToEndReq, + SlangResult CodeGenContext::emitWithDownstreamForEntryPoints( RefPtr<DownstreamCompileResult>& outResult) { outResult.setNull(); - auto sink = slangRequest->getSink(); - - auto session = slangRequest->getSession(); - + auto sink = getSink(); + auto session = getSession(); CodeGenTarget sourceTarget = CodeGenTarget::None; SourceLanguage sourceLanguage = SourceLanguage::Unknown; + auto target = getTargetFormat(); RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); - PassThroughMode compilerType = endToEndReq ? endToEndReq->m_passThrough : PassThroughMode::None; + PassThroughMode compilerType; - // If we are not in pass through, lookup the default compiler for the emitted source type - if (compilerType == PassThroughMode::None) + if (auto endToEndReq = isPassThroughEnabled()) + { + compilerType = endToEndReq->m_passThrough; + } + else { + // If we are not in pass through, lookup the default compiler for the emitted source type + // Get the default source codegen type for a given target sourceTarget = _getDefaultSourceForTarget(target); compilerType = (PassThroughMode)session->getDownstreamCompilerForTransition((SlangCompileTarget)sourceTarget, (SlangCompileTarget)target); @@ -1070,7 +1040,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // Set compiler specific args { - auto linkage = targetReq->getLinkage(); + auto linkage = getLinkage(); auto name = TypeTextUtil::getPassThroughName((SlangPassThrough)compilerType); const Index nameIndex = linkage->m_downstreamArgs.findName(name); @@ -1087,17 +1057,15 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we would ideally like to use the original file. We want to do this because we want includes relative to the source file to work, and for that to work most easily we want to use the original file, if there is one */ - if (isPassThroughEnabled(endToEndReq)) + if (auto endToEndReq = isPassThroughEnabled()) { // If we are pass through, we may need to set extension tracker state. if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker)) { - trackGLSLTargetCaps(glslTracker, targetReq->getTargetCaps()); + trackGLSLTargetCaps(glslTracker, getTargetCaps()); } - // TODO(DG): Review this assertion later - SLANG_ASSERT(entryPointIndices.getCount() == 1); - auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndices[0]); + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, getSingleEntryPointIndex()); // We are just passing thru, so it's whatever it originally was sourceLanguage = translationUnit->sourceLanguage; @@ -1118,7 +1086,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) preprocessorDefinitions.Add(define.Key, define.Value); } { - auto linkage = targetReq->getLinkage(); + auto linkage = getLinkage(); for (auto& define : linkage->preprocessorDefinitions) { preprocessorDefinitions.Add(define.Key, define.Value); @@ -1134,7 +1102,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) On invoking DXC for example include paths do not appear to be set at all (even with pass-through). */ - auto linkage = targetReq->getLinkage(); + auto linkage = getLinkage(); // Add all the search paths @@ -1156,8 +1124,10 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) { // If it's not file based we can set an appropriate path name, and it doesn't matter if it doesn't // exist on the file system - options.sourceContentsPath = calcSourcePathForEntryPoints(endToEndReq, entryPointIndices); - SLANG_RETURN_ON_FAIL(emitEntryPointsSource(slangRequest, entryPointIndices, targetReq, sourceTarget, endToEndReq, extensionTracker, options.sourceContents)); + options.sourceContentsPath = calcSourcePathForEntryPoints(); + + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(options.sourceContents)); } else { @@ -1173,8 +1143,9 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } else { - SLANG_RETURN_ON_FAIL(emitEntryPointsSource(slangRequest, entryPointIndices, targetReq, sourceTarget, endToEndReq, extensionTracker, options.sourceContents)); - maybeDumpIntermediate(slangRequest, options.sourceContents.getBuffer(), sourceTarget); + CodeGenContext sourceCodeGenContext(this, sourceTarget, extensionTracker); + SLANG_RETURN_ON_FAIL(sourceCodeGenContext.emitEntryPointsSource(options.sourceContents)); + sourceCodeGenContext.maybeDumpIntermediate(options.sourceContents.getBuffer()); sourceLanguage = (SourceLanguage)TypeConvertUtil::getSourceLanguageFromTarget((SlangCompileTarget)sourceTarget); } @@ -1213,8 +1184,8 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } // Set the file sytem and source manager, as *may* be used by downstream compiler - options.fileSystemExt = slangRequest->getFileSystemExt(); - options.sourceManager = slangRequest->getSourceManager(); + options.fileSystemExt = getFileSystemExt(); + options.sourceManager = getSourceManager(); // Set the source type options.sourceLanguage = SlangSourceLanguage(sourceLanguage); @@ -1226,23 +1197,16 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) compilerType == PassThroughMode::Dxc || compilerType == PassThroughMode::Glslang) { - if (entryPointIndices.getCount() != 1) - { - // We only support a single entry point on this target - SLANG_ASSERT(!"Can only compile with a single entry point on this target"); - return SLANG_FAIL; - } + auto entryPointIndex = getSingleEntryPointIndex(); - const Index entryPointIndex = entryPointIndices[0]; - - auto entryPoint = program->getEntryPoint(entryPointIndex); - auto profile = getEffectiveProfile(entryPoint, targetReq); + auto entryPoint = getEntryPoint(entryPointIndex); + auto profile = getEffectiveProfile(entryPoint, getTargetReq()); options.stage = SlangStage(profile.getStage()); // Set the entry point name options.entryPointName = getText(entryPoint->getName()); - auto entryPointNameOverride = program->getEntryPointNameOverride(entryPointIndex); + auto entryPointNameOverride = getProgram()->getEntryPointNameOverride(entryPointIndex); if (entryPointNameOverride.getLength() != 0) { options.entryPointName = entryPointNameOverride; @@ -1272,7 +1236,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } // Set the matrix layout - options.matrixLayout = targetReq->getDefaultMatrixLayoutMode(); + options.matrixLayout = getTargetReq()->getDefaultMatrixLayoutMode(); } else if (compilerType == PassThroughMode::Fxc) { @@ -1287,7 +1251,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) target = CodeGenTarget::ShaderSharedLibrary; } - if (!isPassThroughEnabled(endToEndReq)) + if (!isPassThroughEnabled()) { if (_isCPUHostTarget(target)) { @@ -1301,7 +1265,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // Need to configure for the compilation { - auto linkage = targetReq->getLinkage(); + auto linkage = getLinkage(); switch (linkage->optimizationLevel) { @@ -1322,7 +1286,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) default: SLANG_ASSERT(!"Unhandled debug level"); break; } - switch( targetReq->getFloatingPointMode() ) + switch( getTargetReq()->getFloatingPointMode()) { case FloatingPointMode::Default: options.floatingPointMode = DownstreamCompiler::FloatingPointMode::Default; break; case FloatingPointMode::Precise: options.floatingPointMode = DownstreamCompiler::FloatingPointMode::Precise; break; @@ -1346,10 +1310,10 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // because we always perform code generation on a single // entry point at a time. // - Index entryPointCount = slangRequest->getProgram()->getEntryPointCount(); + Index entryPointCount = getEntryPointCount(); for(Index ee = 0; ee < entryPointCount; ++ee) { - auto stage = slangRequest->getProgram()->getEntryPoint(ee)->getStage(); + auto stage = getEntryPoint(ee)->getStage(); switch(stage) { default: @@ -1401,7 +1365,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) SLANG_RETURN_ON_FAIL(compiler->compile(options, downstreamCompileResult)); auto downstreamElapsedTime = (std::chrono::high_resolution_clock::now() - downstreamStartTime).count() * 0.000000001; - slangRequest->getSession()->addDownstreamCompileTime(downstreamElapsedTime); + getSession()->addDownstreamCompileTime(downstreamElapsedTime); const auto& diagnostics = downstreamCompileResult->getDiagnostics(); @@ -1461,15 +1425,14 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return SLANG_OK; } - SlangResult dissassembleWithDownstream( - BackEndCompileRequest* slangRequest, - CodeGenTarget target, + SlangResult CodeGenContext::dissassembleWithDownstream( const void* data, size_t dataSizeInBytes, ISlangBlob** outBlob) { - auto session = slangRequest->getSession(); - auto sink = slangRequest->getSink(); + auto session = getSession(); + auto sink = getSink(); + auto target = getTargetFormat(); // Get the downstream compiler that can be used for this target @@ -1495,23 +1458,19 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return SLANG_OK; } - SlangResult dissassembleWithDownstream( - BackEndCompileRequest* slangRequest, - CodeGenTarget target, + SlangResult CodeGenContext::dissassembleWithDownstream( DownstreamCompileResult* downstreamResult, ISlangBlob** outBlob) { ComPtr<ISlangBlob> codeBlob; SLANG_RETURN_ON_FAIL(downstreamResult->getBinary(codeBlob)); - return dissassembleWithDownstream(slangRequest, target, codeBlob->getBufferPointer(), codeBlob->getBufferSize(), outBlob); + return dissassembleWithDownstream(codeBlob->getBufferPointer(), codeBlob->getBufferSize(), outBlob); } SlangResult emitSPIRVForEntryPointsDirectly( - BackEndCompileRequest* compileRequest, - const List<Int>& entryPointIndices, - TargetRequest* targetReq, - List<uint8_t>& spirvOut); + CodeGenContext* codeGenContext, + List<uint8_t>& spirvOut); static CodeGenTarget _getIntermediateTarget(CodeGenTarget target) { @@ -1525,68 +1484,50 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } /// Function to simplify the logic around emitting, and dissassembling - static SlangResult _emitEntryPoints( - ComponentType* program, - BackEndCompileRequest* compileRequest, - const List<Int>& entryPointIndices, - TargetRequest* targetReq, - CodeGenTarget target, - EndToEndCompileRequest* endToEndReq, + SlangResult CodeGenContext::_emitEntryPoints( RefPtr<DownstreamCompileResult>& outDownstreamResult) { + auto target = getTargetFormat(); switch (target) { case CodeGenTarget::SPIRVAssembly: case CodeGenTarget::DXBytecodeAssembly: case CodeGenTarget::DXILAssembly: { - RefPtr<DownstreamCompileResult> code; - - // Compile the intermediate target + // First compile to an intermediate target for the corresponding binary format. const CodeGenTarget intermediateTarget = _getIntermediateTarget(target); - SLANG_RETURN_ON_FAIL(_emitEntryPoints(program, compileRequest, entryPointIndices, targetReq, intermediateTarget, endToEndReq, code)); + CodeGenContext intermediateContext(this, intermediateTarget); - maybeDumpIntermediate(compileRequest, code, intermediateTarget); + RefPtr<DownstreamCompileResult> code; + SLANG_RETURN_ON_FAIL(intermediateContext._emitEntryPoints(code)); + intermediateContext.maybeDumpIntermediate(code); + // Then disassemble the intermediate binary result to get the desired output // Output the disassembly ComPtr<ISlangBlob> disassemblyBlob; - SLANG_RETURN_ON_FAIL(dissassembleWithDownstream(compileRequest, intermediateTarget, code, disassemblyBlob.writeRef())); + SLANG_RETURN_ON_FAIL(intermediateContext.dissassembleWithDownstream(code, disassemblyBlob.writeRef())); outDownstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), disassemblyBlob); return SLANG_OK; } case CodeGenTarget::SPIRV: + if (getTargetReq()->shouldEmitSPIRVDirectly()) + { + List<uint8_t> spirv; + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(this, spirv)); + auto spirvBlob = ListBlob::moveCreate(spirv); + outDownstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), spirvBlob); + return SLANG_OK; + } + /* fall through to: */ case CodeGenTarget::DXIL: case CodeGenTarget::DXBytecode: case CodeGenTarget::PTX: case CodeGenTarget::ShaderHostCallable: case CodeGenTarget::ShaderSharedLibrary: case CodeGenTarget::HostExecutable: - { - RefPtr<DownstreamCompileResult> downstreamResult; - - if (target == CodeGenTarget::SPIRV && targetReq->shouldEmitSPIRVDirectly()) - { - List<uint8_t> spirv; - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPointsDirectly(compileRequest, entryPointIndices, targetReq, spirv)); - auto spirvBlob = ListBlob::moveCreate(spirv); - downstreamResult = new BlobDownstreamCompileResult(DownstreamDiagnostics(), spirvBlob); - } - else - { - SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints( - program, - compileRequest, - entryPointIndices, - targetReq, - target, - endToEndReq, - downstreamResult)); - } - - outDownstreamResult = downstreamResult; + SLANG_RETURN_ON_FAIL(emitWithDownstreamForEntryPoints(outDownstreamResult)); return SLANG_OK; - } default: break; } @@ -1595,16 +1536,11 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } // Do emit logic for a zero or more entry points - CompileResult emitEntryPoints( - ComponentType* program, - BackEndCompileRequest* compileRequest, - const List<Int>& entryPointIndices, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) + CompileResult CodeGenContext::emitEntryPoints() { CompileResult result; - auto target = targetReq->getTarget(); + auto target = getTargetFormat(); switch (target) { @@ -1621,15 +1557,9 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) { RefPtr<DownstreamCompileResult> downstreamResult; - if (SLANG_SUCCEEDED(_emitEntryPoints(program, - compileRequest, - entryPointIndices, - targetReq, - target, - endToEndReq, - downstreamResult))) + if (SLANG_SUCCEEDED(_emitEntryPoints(downstreamResult))) { - maybeDumpIntermediate(compileRequest, downstreamResult, target); + maybeDumpIntermediate(downstreamResult); result = CompileResult(downstreamResult); } } @@ -1643,19 +1573,15 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) { RefPtr<ExtensionTracker> extensionTracker = _newExtensionTracker(target); + CodeGenContext subContext(this, target, extensionTracker); + String code; - if (SLANG_FAILED(emitEntryPointsSource(compileRequest, - entryPointIndices, - targetReq, - target, - endToEndReq, - extensionTracker, - code))) + if (SLANG_FAILED(subContext.emitEntryPointsSource(code))) { return result; } - maybeDumpIntermediate(compileRequest, code.getBuffer(), target); + subContext.maybeDumpIntermediate(code.getBuffer()); result = CompileResult(code); } break; @@ -1676,19 +1602,6 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return result; } - // Do emit logic for a single entry point - CompileResult emitEntryPoint( - ComponentType* program, - BackEndCompileRequest* compileRequest, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) - { - List<Int> entryPointIndices; - entryPointIndices.add(entryPointIndex); - return emitEntryPoints(program, compileRequest, entryPointIndices, targetReq, endToEndReq); - } - enum class OutputFileKind { Text, @@ -1696,7 +1609,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) }; static void writeOutputFile( - BackEndCompileRequest* compileRequest, + CodeGenContext* context, FILE* file, String const& path, void const* data, @@ -1705,7 +1618,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) size_t count = fwrite(data, size, 1, file); if (count != 1) { - compileRequest->getSink()->diagnose( + context->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -1713,7 +1626,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } static void writeOutputFile( - BackEndCompileRequest* compileRequest, + CodeGenContext* context, ISlangWriter* writer, String const& path, void const* data, @@ -1722,7 +1635,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) if (SLANG_FAILED(writer->write((const char*)data, size))) { - compileRequest->getSink()->diagnose( + context->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); @@ -1730,7 +1643,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } static void writeOutputFile( - BackEndCompileRequest* compileRequest, + CodeGenContext* context, String const& path, void const* data, size_t size, @@ -1741,19 +1654,19 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) kind == OutputFileKind::Binary ? "wb" : "w"); if (!file) { - compileRequest->getSink()->diagnose( + context->getSink()->diagnose( SourceLoc(), Diagnostics::cannotWriteOutputFile, path); return; } - writeOutputFile(compileRequest, file, path, data, size); + writeOutputFile(context, file, path, data, size); fclose(file); } static void writeCompileResultToFile( - BackEndCompileRequest* compileRequest, + CodeGenContext* context, String const& outputPath, CompileResult const& result) { @@ -1762,7 +1675,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) case ResultFormat::Text: { auto text = result.outputString; - writeOutputFile(compileRequest, + writeOutputFile(context, outputPath, text.begin(), text.end() - text.begin(), @@ -1778,7 +1691,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) SLANG_UNEXPECTED("No blob to emit"); return; } - writeOutputFile(compileRequest, + writeOutputFile(context, outputPath, blob->getBufferPointer(), blob->getBufferSize(), @@ -1793,16 +1706,6 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } - static void writeEntryPointResultToFile( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - String const& outputPath, - CompileResult const& result) - { - SLANG_UNUSED(entryPoint); - writeCompileResultToFile(compileRequest, outputPath, result); - } - static void writeOutputToConsole( ISlangWriter* writer, String const& text) @@ -1810,13 +1713,13 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) writer->write(text.getBuffer(), text.getLength()); } - static void writeCompileRequestToStandardOutput( - EndToEndCompileRequest* compileRequest, - TargetRequest* targetReq, - CompileResult const& result) + static void writeCompileResultToStandardOutput( + CodeGenContext* codeGenContext, + EndToEndCompileRequest* endToEndReq, + CompileResult const& result) { - ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); - auto backEndReq = compileRequest->getBackEndReq(); + auto targetReq = codeGenContext->getTargetReq(); + ISlangWriter* writer = endToEndReq->getWriter(WriterChannel::StdOutput); switch (result.format) { @@ -1863,7 +1766,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) { ComPtr<ISlangBlob> disassemblyBlob; - if (SLANG_SUCCEEDED(dissassembleWithDownstream(backEndReq, targetReq->getTarget(), blobData, blobSize, disassemblyBlob.writeRef()))) + if (SLANG_SUCCEEDED(codeGenContext->dissassembleWithDownstream(blobData, blobSize, disassemblyBlob.writeRef()))) { const UnownedStringSlice disassembly = StringUtil::getSlice(disassemblyBlob); writeOutputToConsole(writer, disassembly); @@ -1891,7 +1794,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) writer->setMode(SLANG_WRITER_MODE_BINARY); writeOutputFile( - backEndReq, + codeGenContext, writer, "stdout", blobData, @@ -1907,23 +1810,11 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } - static void writeEntryPointResultToStandardOutput( - EndToEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - TargetRequest* targetReq, - CompileResult const& result) + void EndToEndCompileRequest::writeWholeProgramResult( + TargetRequest* targetReq) { - SLANG_UNUSED(entryPoint); - writeCompileRequestToStandardOutput(compileRequest, targetReq, result); - } - - static void writeWholeProgramResult( - EndToEndCompileRequest* compileRequest, - TargetRequest* targetReq) - { - auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); auto targetProgram = program->getTargetProgram(targetReq); - auto backEndReq = compileRequest->getBackEndReq(); auto& result = targetProgram->getExistingWholeProgramResult(); @@ -1931,34 +1822,37 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) if (result.format == ResultFormat::None) return; + CodeGenContext::EntryPointIndices entryPointIndices; + for (Index i = 0; i < program->getEntryPointCount(); ++i) + entryPointIndices.add(i); + CodeGenContext::Shared sharedCodeGenContext(targetProgram, entryPointIndices, getSink(), this); + CodeGenContext codeGenContext(&sharedCodeGenContext); + // It is possible that we are dynamically discovering entry // points (using `[shader(...)]` attributes), so that there // might be entry points added to the program that did not // get paths specified via command-line options. // RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - if (compileRequest->m_targetInfos.TryGetValue(targetReq, targetInfo)) + if (m_targetInfos.TryGetValue(targetReq, targetInfo)) { String outputPath = targetInfo->wholeTargetOutputPath; if (outputPath != "") { - writeCompileResultToFile(backEndReq, outputPath, result); + writeCompileResultToFile(&codeGenContext, outputPath, result); return; } } - writeCompileRequestToStandardOutput(compileRequest, targetReq, result); + writeCompileResultToStandardOutput(&codeGenContext, this, result); } - static void writeEntryPointResult( - ComponentType* currentProgram, - EndToEndCompileRequest* compileRequest, - Int entryPointIndex, - TargetRequest* targetReq) + void EndToEndCompileRequest::writeEntryPointResult( + TargetRequest* targetReq, + Int entryPointIndex) { - auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); auto targetProgram = program->getTargetProgram(targetReq); - auto backEndReq = compileRequest->getBackEndReq(); auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); @@ -1966,29 +1860,35 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) if (result.format == ResultFormat::None) return; + CodeGenContext::EntryPointIndices entryPointIndices; + entryPointIndices.add(entryPointIndex); + + CodeGenContext::Shared sharedCodeGenContext(targetProgram, entryPointIndices, getSink(), this); + CodeGenContext codeGenContext(&sharedCodeGenContext); + // It is possible that we are dynamically discovering entry // points (using `[shader(...)]` attributes), so that there // might be entry points added to the program that did not // get paths specified via command-line options. // RefPtr<EndToEndCompileRequest::TargetInfo> targetInfo; - auto entryPoint = currentProgram->getEntryPoint(entryPointIndex); - if(compileRequest->m_targetInfos.TryGetValue(targetReq, targetInfo)) + auto entryPoint = program->getEntryPoint(entryPointIndex); + if(m_targetInfos.TryGetValue(targetReq, targetInfo)) { String outputPath; if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) { - writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); + writeCompileResultToFile(&codeGenContext, outputPath, result); return; } } - writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); + writeCompileResultToStandardOutput(&codeGenContext, this, result); } CompileResult& TargetProgram::_createWholeProgramResult( - BackEndCompileRequest* backEndRequest, - EndToEndCompileRequest* endToEndRequest) + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) { // We want to call `emitEntryPoints` function to generate code that contains // all the entrypoints defined in `m_program`. @@ -2001,20 +1901,19 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) entryPointIndices[i] = i; auto& result = m_wholeProgramResult; - result = emitEntryPoints( - m_program, - backEndRequest, - entryPointIndices, - m_targetReq, - endToEndRequest); + + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); + + result = codeGenContext.emitEntryPoints(); return result; } CompileResult& TargetProgram::_createEntryPointResult( Int entryPointIndex, - BackEndCompileRequest* backEndRequest, - EndToEndCompileRequest* endToEndRequest) + DiagnosticSink* sink, + EndToEndCompileRequest* endToEndReq) { // It is possible that entry points goot added to the `Program` // *after* we created this `TargetProgram`, so there might be @@ -2028,15 +1927,15 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) m_entryPointResults.setCount(entryPointIndex+1); auto& result = m_entryPointResults[entryPointIndex]; - result = emitEntryPoint( - m_program, - backEndRequest, - entryPointIndex, - m_targetReq, - endToEndRequest); - return result; + CodeGenContext::EntryPointIndices entryPointIndices; + entryPointIndices.add(entryPointIndex); + CodeGenContext::Shared sharedCodeGenContext(this, entryPointIndices, sink, endToEndReq); + CodeGenContext codeGenContext(&sharedCodeGenContext); + result = codeGenContext.emitEntryPoints(); + + return result; } CompileResult& TargetProgram::getOrCreateWholeProgramResult( @@ -2055,17 +1954,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return result; } - RefPtr<BackEndCompileRequest> backEndRequest = new BackEndCompileRequest( - m_program->getLinkage(), - sink, - m_program); - - backEndRequest->shouldDumpIR = - (m_targetReq->getTargetFlags() & SLANG_TARGET_FLAG_DUMP_IR) != 0; - - return _createWholeProgramResult( - backEndRequest, - nullptr); + return _createWholeProgramResult(sink); } CompileResult& TargetProgram::getOrCreateEntryPointResult( @@ -2088,36 +1977,23 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return result; } - RefPtr<BackEndCompileRequest> backEndRequest = new BackEndCompileRequest( - m_program->getLinkage(), - sink, - m_program); - backEndRequest->shouldDumpIR = - (m_targetReq->getTargetFlags() & SLANG_TARGET_FLAG_DUMP_IR) != 0; - backEndRequest->shouldDumpIntermediates = m_targetReq->shouldDumpIntermediates(); - return _createEntryPointResult( entryPointIndex, - backEndRequest, - nullptr); + sink); } - void generateOutputForTarget( - BackEndCompileRequest* compileReq, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) + void EndToEndCompileRequest::generateOutput( + TargetProgram* targetProgram) { - auto program = compileReq->getProgram(); - auto targetProgram = program->getTargetProgram(targetReq); + auto program = targetProgram->getProgram(); + auto targetReq = targetProgram->getTargetReq(); // Generate target code any entry points that // have been requested for compilation. auto entryPointCount = program->getEntryPointCount(); if (targetReq->isWholeProgramRequest()) { - targetProgram->_createWholeProgramResult( - compileReq, - endToEndReq); + targetProgram->_createWholeProgramResult(getSink(), this); } else { @@ -2125,8 +2001,8 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) { targetProgram->_createEntryPointResult( ii, - compileReq, - endToEndReq); + getSink(), + this); } } } @@ -2308,20 +2184,18 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) } - static void _generateOutput( - BackEndCompileRequest* compileRequest, - EndToEndCompileRequest* endToEndReq) + void EndToEndCompileRequest::generateOutput( + ComponentType* program) { // When dynamic dispatch is disabled, the program must // be fully specialized by now. So we check if we still // have unspecialized generic/existential parameters, // and report them as an error. // - auto program = compileRequest->getProgram(); auto specializationParamCount = program->getSpecializationParamCount(); - if (compileRequest->disableDynamicDispatch && specializationParamCount != 0) + if (disableDynamicDispatch && specializationParamCount != 0) { - auto sink = compileRequest->getSink(); + auto sink = getSink(); for( Index ii = 0; ii < specializationParamCount; ++ii ) { @@ -2347,56 +2221,48 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // Go through the code-generation targets that the user // has specified, and generate code for each of them. // - auto linkage = compileRequest->getLinkage(); + auto linkage = getLinkage(); for (auto targetReq : linkage->targets) { - generateOutputForTarget(compileRequest, targetReq, endToEndReq); + auto targetProgram = program->getTargetProgram(targetReq); + generateOutput(targetProgram); } } - void generateOutput( - BackEndCompileRequest* compileRequest) + void EndToEndCompileRequest::generateOutput() { - _generateOutput(compileRequest, nullptr); - } - - void generateOutput( - EndToEndCompileRequest* compileRequest) - { - _generateOutput(compileRequest->getBackEndReq(), compileRequest); + generateOutput(getSpecializedGlobalAndEntryPointsComponentType()); // If we are in command-line mode, we might be expected to actually // write output to one or more files here. - if (compileRequest->m_isCommandLineCompile) + if (m_isCommandLineCompile) { - auto linkage = compileRequest->getLinkage(); - auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + auto linkage = getLinkage(); + auto program = getSpecializedGlobalAndEntryPointsComponentType(); for (auto targetReq : linkage->targets) { - Index entryPointCount = program->getEntryPointCount(); - if (targetReq->isWholeProgramRequest()) { + if (targetReq->isWholeProgramRequest()) + { writeWholeProgramResult( - compileRequest, targetReq); } else { + Index entryPointCount = program->getEntryPointCount(); for (Index ee = 0; ee < entryPointCount; ++ee) { writeEntryPointResult( - program, - compileRequest, - ee, - targetReq); + targetReq, + ee); } } } - compileRequest->maybeCreateContainer(); - compileRequest->maybeWriteContainer(compileRequest->m_containerOutputPath); + maybeCreateContainer(); + maybeWriteContainer(m_containerOutputPath); - _writeDependencyFile(compileRequest); + _writeDependencyFile(this); } } @@ -2404,8 +2270,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) // - void dumpIntermediate( - BackEndCompileRequest* request, + void CodeGenContext::dumpIntermediate( void const* data, size_t size, char const* ext, @@ -2427,7 +2292,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) #endif String path; - path.append(request->m_dumpIntermediatePrefix); + path.append(getIntermediateDumpPrefix()); path.append(id); path.append(ext); @@ -2438,36 +2303,32 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) fclose(file); } - void dumpIntermediateText( - BackEndCompileRequest* compileRequest, + void CodeGenContext::dumpIntermediateText( void const* data, size_t size, char const* ext) { - dumpIntermediate(compileRequest, data, size, ext, false); + dumpIntermediate(data, size, ext, false); } - void dumpIntermediateBinary( - BackEndCompileRequest* compileRequest, + void CodeGenContext::dumpIntermediateBinary( void const* data, size_t size, char const* ext) { - dumpIntermediate(compileRequest, data, size, ext, true); + dumpIntermediate(data, size, ext, true); } - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - DownstreamCompileResult* compileResult, - CodeGenTarget target) + void CodeGenContext::maybeDumpIntermediate( + DownstreamCompileResult* compileResult) { - if (!compileRequest->shouldDumpIntermediates) + if (!shouldDumpIntermediates()) return; ComPtr<ISlangBlob> blob; if (SLANG_SUCCEEDED(compileResult->getBinary(blob))) { - maybeDumpIntermediate(compileRequest, blob->getBufferPointer(), blob->getBufferSize(), target); + maybeDumpIntermediate(blob->getBufferPointer(), blob->getBufferSize()); } } @@ -2497,15 +2358,14 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) return nullptr; } - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, + void CodeGenContext::maybeDumpIntermediate( void const* data, - size_t size, - CodeGenTarget target) + size_t size) { - if (!compileRequest->shouldDumpIntermediates) + if (!shouldDumpIntermediates()) return; + auto target = getTargetFormat(); switch (target) { case CodeGenTarget::CPPSource: @@ -2518,7 +2378,7 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) case CodeGenTarget::GLSL: case CodeGenTarget::HLSL: { - dumpIntermediateText(compileRequest, data, size, _getTargetExtension(target)); + dumpIntermediateText(data, size, _getTargetExtension(target)); break; } @@ -2537,13 +2397,13 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) const char* ext = _getTargetExtension(target); SLANG_ASSERT(ext); - dumpIntermediateBinary(compileRequest, data, size, ext); + dumpIntermediateBinary(data, size, ext); ComPtr<ISlangBlob> disassemblyBlob; - if (SLANG_SUCCEEDED(dissassembleWithDownstream(compileRequest, target, data, size, disassemblyBlob.writeRef()))) + if (SLANG_SUCCEEDED(dissassembleWithDownstream(data, size, disassemblyBlob.writeRef()))) { StringBuilder buf; buf << ext << ".asm"; - dumpIntermediateText(compileRequest, disassemblyBlob->getBufferPointer(), disassemblyBlob->getBufferSize(), buf.getBuffer()); + dumpIntermediateText(disassemblyBlob->getBufferPointer(), disassemblyBlob->getBufferSize(), buf.getBuffer()); } break; } @@ -2552,21 +2412,92 @@ void printDiagnosticArg(StringBuilder& sb, CodeGenTarget val) case CodeGenTarget::ShaderSharedLibrary: case CodeGenTarget::HostExecutable: { - dumpIntermediateBinary(compileRequest, data, size, _getTargetExtension(target)); + dumpIntermediateBinary(data, size, _getTargetExtension(target)); break; } default: break; } } - void maybeDumpIntermediate( - BackEndCompileRequest* compileRequest, - char const* text, - CodeGenTarget target) + void CodeGenContext::maybeDumpIntermediate( + char const* text) { - if (!compileRequest->shouldDumpIntermediates) + if (!shouldDumpIntermediates()) return; - maybeDumpIntermediate(compileRequest, text, strlen(text), target); + maybeDumpIntermediate(text, strlen(text)); + } + + IRDumpOptions CodeGenContext::getIRDumpOptions() + { + if (auto endToEndReq = isEndToEndCompile()) + { + return endToEndReq->getFrontEndReq()->m_irDumpOptions; + } + return IRDumpOptions(); + } + + bool CodeGenContext::shouldValidateIR() + { + if (auto endToEndReq = isEndToEndCompile()) + { + if (endToEndReq->getFrontEndReq()->shouldValidateIR) + return true; + } + + return false; + } + + bool CodeGenContext::shouldDumpIR() + { + if (getTargetReq()->getTargetFlags() & SLANG_TARGET_FLAG_DUMP_IR) + return true; + + if (auto endToEndReq = isEndToEndCompile()) + { + if (endToEndReq->getFrontEndReq()->shouldDumpIR) + return true; + } + + return false; + } + + bool CodeGenContext::shouldDumpIntermediates() + { + if (getTargetReq()->shouldDumpIntermediates()) + return true; + if (auto endToEndReq = isEndToEndCompile()) + { + if (endToEndReq->shouldDumpIntermediates) + return true; + } + return false; + } + + String CodeGenContext::getIntermediateDumpPrefix() + { + if (auto endToEndReq = isEndToEndCompile()) + { + return endToEndReq->m_dumpIntermediatePrefix; + } + return String(); + } + + bool CodeGenContext::getUseUnknownImageFormatAsDefault() + { + if (auto endToEndReq = isEndToEndCompile()) + { + return endToEndReq->useUnknownImageFormatAsDefault; + } + return false; + } + + bool CodeGenContext::isSpecializationDisabled() + { + if (auto endToEndReq = isEndToEndCompile()) + { + return endToEndReq->disableSpecialization; + } + return false; } } |
