From 47b43f8b15ef35c520b9b287fd17ff25e36bfe95 Mon Sep 17 00:00:00 2001 From: Dietrich Geisler Date: Mon, 29 Jun 2020 17:42:12 -0400 Subject: Backend for Multiple Entry Points (#1411) * Backend for Multiple Entry Points Introduces the basic backend on the compiler for zero or more entry points. Entry points have been extended to lists for several functions, with loopFunctions have been extended to take in entry points and indices as appropriate, to allow for multiple entry points once the frontend is expanded. Several functions are currently being assumed to have a single entry point for simplicity and provide a work in progress commit. * Progress on debugging fixes * Tests passing * Refactored emitEntryPoints * Updated lists to be by constant reference * Fixes to formatting * Refactoring updates for the compiler * Fix for compilation errors * Reformatting * More reformatting * Moved struct around to help with compilation Co-authored-by: Tim Foley --- source/slang/slang-compiler.cpp | 560 +++++++++++++++++++++++++++------------- 1 file changed, 384 insertions(+), 176 deletions(-) (limited to 'source/slang/slang-compiler.cpp') diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index fe0f7d69c..ca7e5fb83 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -508,29 +508,42 @@ namespace Slang SLANG_OK; } - // - - /// If there is a pass-through compile going on, find the translation unit for the given entry point. - TranslationUnitRequest* findPassThroughTranslationUnit( - EndToEndCompileRequest* endToEndReq, - Int entryPointIndex) - { - // If there isn't an end-to-end compile going on, + bool isPassThroughEnabled( + EndToEndCompileRequest* endToEndReq) + { // If there isn't an end-to-end compile going on, // there can be no pass-through. // - if(!endToEndReq) return nullptr; + if (!endToEndReq) return false; // And if pass-through isn't set, we don't need // access to the translation unit. // - if(endToEndReq->passThrough == PassThroughMode::None) return nullptr; - + if(endToEndReq->passThrough == PassThroughMode::None) return false; + return true; + } + /// If there is a pass-through compile going on, find the translation unit for the given entry point. + /// Assumes isPassThroughEnabled has already been called + TranslationUnitRequest* getPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + SLANG_ASSERT(endToEndReq); + SLANG_ASSERT(endToEndReq->passThrough != PassThroughMode::None); auto frontEndReq = endToEndReq->getFrontEndReq(); auto entryPointReq = frontEndReq->getEntryPointReq(entryPointIndex); auto translationUnit = entryPointReq->getTranslationUnit(); return translationUnit; } + TranslationUnitRequest* findPassThroughTranslationUnit( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + if (isPassThroughEnabled(endToEndReq)) + return getPassThroughTranslationUnit(endToEndReq, entryPointIndex); + return nullptr; + } + static void _appendEscapedPath(const UnownedStringSlice& path, StringBuilder& outBuilder) { for (auto c : path) @@ -552,67 +565,86 @@ namespace Slang outCodeBuilder << fileContent << "\n"; } - SlangResult emitEntryPointSource( + SlangResult emitEntryPointsSource( BackEndCompileRequest* compileRequest, - Int entryPointIndex, + List entryPointIndices, TargetRequest* targetReq, CodeGenTarget target, EndToEndCompileRequest* endToEndReq, - SourceResult& outSource) + SourceResult& outSource) { outSource.reset(); - if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + if(isPassThroughEnabled(endToEndReq)) { - // Generate a string that includes the content of - // the source file(s), along with a line directive - // to ensure that we get reasonable messages - // from the downstream compiler when in pass-through - // mode. - - StringBuilder codeBuilder; - if (target == CodeGenTarget::GLSL) + for (auto entryPointIndex = entryPointIndices.begin(); entryPointIndex != entryPointIndices.end(); entryPointIndex++) { - // Special case GLSL - int translationUnitCounter = 0; - for (auto sourceFile : translationUnit->getSourceFiles()) + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, *entryPointIndex); + SLANG_ASSERT(translationUnit); + // Generate a string that includes the content of + // the source file(s), along with a line directive + // to ensure that we get reasonable messages + // from the downstream compiler when in pass-through + // mode. + + StringBuilder codeBuilder; + if (target == CodeGenTarget::GLSL) { - int translationUnitIndex = translationUnitCounter++; - - // We want to output `#line` directives, but we need - // to skip this for the first file, since otherwise - // some GLSL implementations will get tripped up by - // not having the `#version` directive be the first - // thing in the file. - if (translationUnitIndex != 0) + // Special case GLSL + int translationUnitCounter = 0; + for (auto sourceFile : translationUnit->getSourceFiles()) { - codeBuilder << "#line 1 " << translationUnitIndex << "\n"; + int translationUnitIndex = translationUnitCounter++; + + // We want to output `#line` directives, but we need + // to skip this for the first file, since otherwise + // some GLSL implementations will get tripped up by + // not having the `#version` directive be the first + // thing in the file. + if (translationUnitIndex != 0) + { + codeBuilder << "#line 1 " << translationUnitIndex << "\n"; + } + codeBuilder << sourceFile->getContent() << "\n"; } - codeBuilder << sourceFile->getContent() << "\n"; } - } - else - { - for(auto sourceFile : translationUnit->getSourceFiles()) + else { - _appendCodeWithPath(sourceFile->getPathInfo().foundPath.getUnownedSlice(), sourceFile->getContent(), codeBuilder); + for (auto sourceFile : translationUnit->getSourceFiles()) + { + _appendCodeWithPath(sourceFile->getPathInfo().foundPath.getUnownedSlice(), sourceFile->getContent(), codeBuilder); + } } - } - outSource.source = codeBuilder.ProduceString(); + outSource.source = codeBuilder.ProduceString(); + } return SLANG_OK; } else { return emitEntryPointSourceFromIR( compileRequest, - entryPointIndex, + entryPointIndices, target, targetReq, outSource); } } + SlangResult emitEntryPointSource( + BackEndCompileRequest* compileRequest, + Int entryPointIndex, + TargetRequest* targetReq, + CodeGenTarget target, + EndToEndCompileRequest* endToEndReq, + SourceResult& outSource) + { + List entryPointIndices; + entryPointIndices.add(entryPointIndex); + return emitEntryPointsSource(compileRequest, entryPointIndices, targetReq, + target, endToEndReq, outSource); + } + String GetHLSLProfileName(Profile profile) { switch( profile.getFamily() ) @@ -745,13 +777,17 @@ namespace Slang } } - String calcSourcePathForEntryPoint( + String calcSourcePathForEntryPoints( EndToEndCompileRequest* endToEndReq, - UInt entryPointIndex) + List entryPointIndices) { + String failureMode = "slang-generated"; + if (entryPointIndices.getCount() != 1) + return failureMode; + auto entryPointIndex = entryPointIndices[0]; auto translationUnitRequest = findPassThroughTranslationUnit(endToEndReq, entryPointIndex); - if(!translationUnitRequest) - return "slang-generated"; + if (!translationUnitRequest) + return failureMode; const auto& sourceFiles = translationUnitRequest->getSourceFiles(); @@ -776,6 +812,49 @@ namespace Slang } } + String calcSourcePathForEntryPoint( + EndToEndCompileRequest* endToEndReq, + Int entryPointIndex) + { + List entryPointIndices; + entryPointIndices.add(entryPointIndex); + return calcSourcePathForEntryPoints(endToEndReq, entryPointIndices); + } + + struct EntryPointAndIndex { + EntryPoint* entryPoint; + Int index; + EntryPointAndIndex(); + EntryPointAndIndex(EntryPoint* entryPoint, Int index); + }; + + EntryPointAndIndex::EntryPointAndIndex() + { + entryPoint = NULL; + index = -1; + } + + EntryPointAndIndex::EntryPointAndIndex(EntryPoint* ep, Int i) + { + entryPoint = ep; + index = i; + } + + // Helper function for recovering the entry point code indices from a list of entry points + List getEntryPointIndices(List const& entryPoints) { + List result; + for (auto entryPoint = entryPoints.begin(); entryPoint != entryPoints.end(); entryPoint++) { + result.add(entryPoint->index); + } + return result; + } + + // Helper function for cases where we can assume a single entry point + EntryPointAndIndex assertSingleEntryPoint(List const& entryPoints) { + SLANG_ASSERT(entryPoints.getCount() == 1); + return *entryPoints.begin(); + } + #if SLANG_ENABLE_DXBC_SUPPORT static UnownedStringSlice _getSlice(ID3DBlob* blob) @@ -879,8 +958,7 @@ namespace Slang SlangResult emitDXBytecodeForEntryPoint( BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, + EntryPointAndIndex entryPoint, TargetRequest* targetReq, EndToEndCompileRequest* endToEndReq, List& byteCodeOut) @@ -897,12 +975,12 @@ namespace Slang } SourceResult source; - SLANG_RETURN_ON_FAIL(emitEntryPointSource(compileRequest, entryPointIndex, targetReq, CodeGenTarget::HLSL, endToEndReq, source)); + SLANG_RETURN_ON_FAIL(emitEntryPointSource(compileRequest, entryPoint.index, targetReq, CodeGenTarget::HLSL, endToEndReq, source)); const auto& hlslCode = source.source; maybeDumpIntermediate(compileRequest, hlslCode.getBuffer(), CodeGenTarget::HLSL); - auto profile = getEffectiveProfile(entryPoint, targetReq); + auto profile = getEffectiveProfile(entryPoint.entryPoint, targetReq); auto linkage = compileRequest->getLinkage(); @@ -922,7 +1000,7 @@ namespace Slang FxcIncludeHandler fxcIncludeHandlerStorage; FxcIncludeHandler* fxcIncludeHandler = nullptr; - if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + if(auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPoint.index)) { for( auto& define : translationUnit->compileRequest->preprocessorDefinitions ) { @@ -990,13 +1068,13 @@ namespace Slang { case DebugInfoLevel::None: break; - + default: flags |= D3DCOMPILE_DEBUG; break; } - const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPoint.index); ComPtr codeBlob; ComPtr diagnosticsBlob; @@ -1006,7 +1084,7 @@ namespace Slang sourcePath.getBuffer(), dxMacros, fxcIncludeHandler, - getText(entryPoint->getName()).begin(), + getText(entryPoint.entryPoint->getName()).begin(), GetHLSLProfileName(profile).getBuffer(), flags, 0, // unused: effect flags @@ -1066,8 +1144,7 @@ namespace Slang SlangResult emitDXBytecodeAssemblyForEntryPoint( BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, + EntryPointAndIndex entryPoint, TargetRequest* targetReq, EndToEndCompileRequest* endToEndReq, String& assemOut) @@ -1077,7 +1154,6 @@ namespace Slang SLANG_RETURN_ON_FAIL(emitDXBytecodeForEntryPoint( compileRequest, entryPoint, - entryPointIndex, targetReq, endToEndReq, dxbc)); @@ -1198,9 +1274,9 @@ SlangResult dissassembleDXILUsingDXC( return SLANG_OK; } - SlangResult emitWithDownstreamForEntryPoint( + SlangResult emitWithDownstreamForEntryPoints( BackEndCompileRequest* slangRequest, - Int entryPointIndex, + List entryPointIndices, TargetRequest* targetReq, EndToEndCompileRequest* endToEndReq, RefPtr& outResult) @@ -1211,7 +1287,7 @@ SlangResult dissassembleDXILUsingDXC( auto session = slangRequest->getSession(); - const String originalSourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + const String originalSourcePath = calcSourcePathForEntryPoints(endToEndReq, entryPointIndices); CodeGenTarget sourceTarget = CodeGenTarget::None; SourceLanguage sourceLanguage = SourceLanguage::Unknown; @@ -1270,8 +1346,11 @@ SlangResult dissassembleDXILUsingDXC( /* This is more convoluted than the other scenarios, because when we invoke C/C++ compiler we would ideally like to use the original file. We want to do this because we want includes relative to the source file to work, and for that to work most easily we want to use the original file, if there is one */ - if (auto translationUnit = findPassThroughTranslationUnit(endToEndReq, entryPointIndex)) + // Note also that we require there to be only one entry point to use a translation unit + // TODO(DG): Review this assertion later + if (isPassThroughEnabled(endToEndReq) && entryPointIndices.getCount() == 1) { + auto translationUnit = getPassThroughTranslationUnit(endToEndReq, entryPointIndices[0]); // If it's pass through we accumulate the preprocessor definitions. for (auto& define : translationUnit->compileRequest->preprocessorDefinitions) { @@ -1333,7 +1412,7 @@ SlangResult dissassembleDXILUsingDXC( else { SourceResult source; - SLANG_RETURN_ON_FAIL(emitEntryPointSource(slangRequest, entryPointIndex, targetReq, sourceTarget, endToEndReq, source)); + SLANG_RETURN_ON_FAIL(emitEntryPointsSource(slangRequest, entryPointIndices, targetReq, sourceTarget, endToEndReq, source)); options.sourceContents = source.source; } @@ -1341,7 +1420,7 @@ SlangResult dissassembleDXILUsingDXC( else { SourceResult source; - SLANG_RETURN_ON_FAIL(emitEntryPointSource(slangRequest, entryPointIndex, targetReq, sourceTarget, endToEndReq, source)); + SLANG_RETURN_ON_FAIL(emitEntryPointsSource(slangRequest, entryPointIndices, targetReq, sourceTarget, endToEndReq, source)); // Look for the version if (auto cudaTracker = as(source.extensionTracker)) @@ -1557,25 +1636,24 @@ SlangResult dissassembleDXILUsingDXC( return SLANG_OK; } - SlangResult emitSPIRVForEntryPointDirectly( + SlangResult emitSPIRVForEntryPointsDirectly( BackEndCompileRequest* compileRequest, - Int entryPointIndex, + List entryPointIndices, TargetRequest* targetReq, List& spirvOut); - SlangResult emitSPIRVForEntryPointViaGLSL( - BackEndCompileRequest* slangRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& spirvOut) + SlangResult emitSPIRVForEntryPointsViaGLSL( + BackEndCompileRequest* slangRequest, + List const& entryPoints, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& spirvOut) { spirvOut.clear(); SourceResult source; - SLANG_RETURN_ON_FAIL(emitEntryPointSource(slangRequest, entryPointIndex, targetReq, CodeGenTarget::GLSL, endToEndReq, source)); + SLANG_RETURN_ON_FAIL(emitEntryPointsSource(slangRequest, getEntryPointIndices(entryPoints), targetReq, CodeGenTarget::GLSL, endToEndReq, source)); const auto& rawGLSL = source.source; @@ -1586,7 +1664,9 @@ SlangResult dissassembleDXILUsingDXC( ((List*)userData)->addRange((uint8_t*)data, size); }; - const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPointIndex); + SLANG_ASSERT(entryPoints.getCount() == 1); + auto entryPoint = entryPoints[0]; + const String sourcePath = calcSourcePathForEntryPoint(endToEndReq, entryPoint.index); glslang_CompileRequest_1_1 request; memset(&request, 0, sizeof(request)); @@ -1594,10 +1674,10 @@ SlangResult dissassembleDXILUsingDXC( request.action = GLSLANG_ACTION_COMPILE_GLSL_TO_SPIRV; request.sourcePath = sourcePath.getBuffer(); - request.slangStage = (SlangStage)entryPoint->getStage(); + request.slangStage = (SlangStage)entryPoint.entryPoint->getStage(); - request.inputBegin = rawGLSL.begin(); - request.inputEnd = rawGLSL.end(); + request.inputBegin = rawGLSL.begin(); + request.inputEnd = rawGLSL.end(); if (GLSLExtensionTracker* tracker = as(source.extensionTracker.Ptr())) { @@ -1616,47 +1696,43 @@ SlangResult dissassembleDXILUsingDXC( return SLANG_OK; } - SlangResult emitSPIRVForEntryPoint( - BackEndCompileRequest* slangRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - List& spirvOut) + SlangResult emitSPIRVForEntryPoints( + BackEndCompileRequest* slangRequest, + List const& entryPoints, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + List& spirvOut) { if( slangRequest->shouldEmitSPIRVDirectly ) { - return emitSPIRVForEntryPointDirectly( + return emitSPIRVForEntryPointsDirectly( slangRequest, - entryPointIndex, + getEntryPointIndices(entryPoints), targetReq, spirvOut); } else { - return emitSPIRVForEntryPointViaGLSL( + return emitSPIRVForEntryPointsViaGLSL( slangRequest, - entryPoint, - entryPointIndex, + entryPoints, targetReq, endToEndReq, spirvOut); } } - SlangResult emitSPIRVAssemblyForEntryPoint( - BackEndCompileRequest* slangRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq, - String& assemblyOut) + SlangResult emitSPIRVAssemblyForEntryPoints( + BackEndCompileRequest* slangRequest, + List const& entryPoints, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq, + String& assemblyOut) { List spirv; - SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoint( + SLANG_RETURN_ON_FAIL(emitSPIRVForEntryPoints( slangRequest, - entryPoint, - entryPointIndex, + entryPoints, targetReq, endToEndReq, spirv)); @@ -1668,13 +1744,12 @@ SlangResult dissassembleDXILUsingDXC( } #endif - // Do emit logic for a single entry point - CompileResult emitEntryPoint( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - Int entryPointIndex, - TargetRequest* targetReq, - EndToEndCompileRequest* endToEndReq) + // Do emit logic for a zero or more entry points + CompileResult emitEntryPoints( + BackEndCompileRequest* compileRequest, + List const& entryPoints, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) { CompileResult result; @@ -1689,9 +1764,9 @@ SlangResult dissassembleDXILUsingDXC( { RefPtr downstreamResult; - if (SLANG_SUCCEEDED(emitWithDownstreamForEntryPoint( + if (SLANG_SUCCEEDED(emitWithDownstreamForEntryPoints( compileRequest, - entryPointIndex, + getEntryPointIndices(entryPoints), targetReq, endToEndReq, downstreamResult))) @@ -1709,7 +1784,8 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::CSource: { SourceResult source; - if (SLANG_FAILED(emitEntryPointSource(compileRequest, entryPointIndex, targetReq, target, endToEndReq, source))) + if (SLANG_FAILED(emitEntryPointsSource(compileRequest, getEntryPointIndices(entryPoints), + targetReq, target, endToEndReq, source))) { return result; } @@ -1723,11 +1799,11 @@ SlangResult dissassembleDXILUsingDXC( #if SLANG_ENABLE_DXBC_SUPPORT case CodeGenTarget::DXBytecode: { + // Assert only one entry point case -- move out of this function List code; if (SLANG_SUCCEEDED(emitDXBytecodeForEntryPoint( compileRequest, - entryPoint, - entryPointIndex, + assertSingleEntryPoint(entryPoints), targetReq, endToEndReq, code))) @@ -1741,11 +1817,11 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXBytecodeAssembly: { + // Assert only one entry point case String code; if (SLANG_SUCCEEDED(emitDXBytecodeAssemblyForEntryPoint( compileRequest, - entryPoint, - entryPointIndex, + assertSingleEntryPoint(entryPoints), targetReq, endToEndReq, code))) @@ -1761,10 +1837,11 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXIL: { List code; + EntryPointAndIndex entryPoint = assertSingleEntryPoint(entryPoints); if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( compileRequest, - entryPoint, - entryPointIndex, + entryPoint.entryPoint, + entryPoint.index, targetReq, endToEndReq, code))) @@ -1778,19 +1855,20 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::DXILAssembly: { List code; + EntryPointAndIndex entryPoint = assertSingleEntryPoint(entryPoints); if (SLANG_SUCCEEDED(emitDXILForEntryPointUsingDXC( compileRequest, - entryPoint, - entryPointIndex, + entryPoint.entryPoint, + entryPoint.index, targetReq, endToEndReq, code))) { - String assembly; + String assembly; dissassembleDXILUsingDXC( compileRequest, code.getBuffer(), - code.getCount(), + code.getCount(), assembly); maybeDumpIntermediate(compileRequest, assembly.getBuffer(), target); @@ -1803,10 +1881,9 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRV: { List code; - if (SLANG_SUCCEEDED(emitSPIRVForEntryPoint( + if (SLANG_SUCCEEDED(emitSPIRVForEntryPoints( compileRequest, - entryPoint, - entryPointIndex, + entryPoints, targetReq, endToEndReq, code))) @@ -1820,10 +1897,9 @@ SlangResult dissassembleDXILUsingDXC( case CodeGenTarget::SPIRVAssembly: { String code; - if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoint( + if (SLANG_SUCCEEDED(emitSPIRVAssemblyForEntryPoints( compileRequest, - entryPoint, - entryPointIndex, + entryPoints, targetReq, endToEndReq, code))) @@ -1838,7 +1914,7 @@ SlangResult dissassembleDXILUsingDXC( // The user requested no output break; - // Note(tfoley): We currently hit this case when compiling the stdlib + // Note(tfoley): We currently hit this case when compiling the stdlib case CodeGenTarget::Unknown: break; @@ -1850,6 +1926,18 @@ SlangResult dissassembleDXILUsingDXC( return result; } + // Do emit logic for a single entry point + CompileResult emitEntryPoint( + BackEndCompileRequest* compileRequest, + EntryPointAndIndex entryPoint, + TargetRequest* targetReq, + EndToEndCompileRequest* endToEndReq) + { + List entryPoints; + entryPoints.add(entryPoint); + return emitEntryPoints(compileRequest, entryPoints, targetReq, endToEndReq); + } + enum class OutputFileKind { Text, @@ -1913,14 +2001,11 @@ SlangResult dissassembleDXILUsingDXC( fclose(file); } - static void writeEntryPointResultToFile( - BackEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - String const& outputPath, - CompileResult const& result) + static void writeCompileResultToFile( + BackEndCompileRequest* compileRequest, + String const& outputPath, + CompileResult const& result) { - SLANG_UNUSED(entryPoint); - switch (result.format) { case ResultFormat::Text: @@ -1953,6 +2038,16 @@ SlangResult dissassembleDXILUsingDXC( } + static void writeEntryPointResultToFile( + BackEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + String const& outputPath, + CompileResult const& result) + { + SLANG_UNUSED(entryPoint); + writeCompileResultToFile(compileRequest, outputPath, result); + } + static void writeOutputToConsole( ISlangWriter* writer, String const& text) @@ -1960,14 +2055,11 @@ SlangResult dissassembleDXILUsingDXC( writer->write(text.getBuffer(), text.getLength()); } - static void writeEntryPointResultToStandardOutput( - EndToEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - TargetRequest* targetReq, - CompileResult const& result) + static void writeCompileRequestToStandardOutput( + EndToEndCompileRequest* compileRequest, + TargetRequest* targetReq, + CompileResult const& result) { - SLANG_UNUSED(entryPoint); - ISlangWriter* writer = compileRequest->getWriter(WriterChannel::StdOutput); auto backEndReq = compileRequest->getBackEndReq(); @@ -2007,7 +2099,7 @@ SlangResult dissassembleDXILUsingDXC( #if SLANG_ENABLE_DXIL_SUPPORT case CodeGenTarget::DXIL: { - String assembly; + String assembly; dissassembleDXILUsingDXC(backEndReq, blobData, blobSize, assembly); writeOutputToConsole(writer, assembly); } @@ -2035,22 +2127,22 @@ SlangResult dissassembleDXILUsingDXC( default: SLANG_UNEXPECTED("unhandled output format"); return; - } - } - else - { - // Redirecting stdout to a file, so do the usual thing - writer->setMode(SLANG_WRITER_MODE_BINARY); - - writeOutputFile( - backEndReq, - writer, - "stdout", - blobData, - blobSize); + } } + else + { + // Redirecting stdout to a file, so do the usual thing + writer->setMode(SLANG_WRITER_MODE_BINARY); + + writeOutputFile( + backEndReq, + writer, + "stdout", + blobData, + blobSize); } - break; + } + break; default: SLANG_UNEXPECTED("unhandled output format"); @@ -2059,17 +2151,59 @@ SlangResult dissassembleDXILUsingDXC( } + static void writeEntryPointResultToStandardOutput( + EndToEndCompileRequest* compileRequest, + EntryPoint* entryPoint, + TargetRequest* targetReq, + CompileResult const& result) + { + SLANG_UNUSED(entryPoint); + writeCompileRequestToStandardOutput(compileRequest, targetReq, result); + } + + static void writeWholeProgramResult( + EndToEndCompileRequest* compileRequest, + TargetRequest* targetReq) + { + auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); + auto targetProgram = program->getTargetProgram(targetReq); + auto backEndReq = compileRequest->getBackEndReq(); + + auto& result = targetProgram->getExistingWholeProgramResult(); + + // Skip the case with no output + if (result.format == ResultFormat::None) + return; + + // It is possible that we are dynamically discovering entry + // points (using `[shader(...)]` attributes), so that there + // might be entry points added to the program that did not + // get paths specified via command-line options. + // + RefPtr targetInfo; + if (compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) + { + String outputPath = targetInfo->wholeTargetOutputPath; + if (outputPath != "") + { + writeCompileResultToFile(backEndReq, outputPath, result); + return; + } + } + + writeCompileRequestToStandardOutput(compileRequest, targetReq, result); + } + static void writeEntryPointResult( EndToEndCompileRequest* compileRequest, - EntryPoint* entryPoint, - TargetRequest* targetReq, - Int entryPointIndex) + EntryPointAndIndex entryPoint, + TargetRequest* targetReq) { auto program = compileRequest->getSpecializedGlobalAndEntryPointsComponentType(); auto targetProgram = program->getTargetProgram(targetReq); auto backEndReq = compileRequest->getBackEndReq(); - auto& result = targetProgram->getExistingEntryPointResult(entryPointIndex); + auto& result = targetProgram->getExistingEntryPointResult(entryPoint.index); // Skip the case with no output if (result.format == ResultFormat::None) @@ -2084,14 +2218,46 @@ SlangResult dissassembleDXILUsingDXC( if(compileRequest->targetInfos.TryGetValue(targetReq, targetInfo)) { String outputPath; - if(targetInfo->entryPointOutputPaths.TryGetValue(entryPointIndex, outputPath)) + if(targetInfo->entryPointOutputPaths.TryGetValue(entryPoint.index, outputPath)) { - writeEntryPointResultToFile(backEndReq, entryPoint, outputPath, result); + writeEntryPointResultToFile(backEndReq, entryPoint.entryPoint, outputPath, result); return; } } - writeEntryPointResultToStandardOutput(compileRequest, entryPoint, targetReq, result); + writeEntryPointResultToStandardOutput(compileRequest, entryPoint.entryPoint, targetReq, result); + } + + CompileResult& TargetProgram::_createWholeProgramResult( + List entryPointIndices, + BackEndCompileRequest* backEndRequest, + EndToEndCompileRequest* endToEndRequest) + { + List entryPoints; + for (auto entryPointIndex = entryPointIndices.begin(); entryPointIndex != entryPointIndices.end(); entryPointIndex++) { + if (*entryPointIndex >= m_entryPointResults.getCount()) + m_entryPointResults.setCount(*entryPointIndex + 1); + + // It is possible that entry points goot added to the `Program` + // *after* we created this `TargetProgram`, so there might be + // a request for an entry point that we didn't allocate space for. + // + // TODO: Change the construction logic so that a `Program` is + // constructed all at once rather than incrementally, to avoid + // this problem. + // + auto entryPoint = m_program->getEntryPoint(*entryPointIndex); + entryPoints.add(EntryPointAndIndex(entryPoint, *entryPointIndex)); + } + auto& result = m_wholeProgramResult; + result = emitEntryPoints( + backEndRequest, + entryPoints, + m_targetReq, + endToEndRequest); + + return result; + } CompileResult& TargetProgram::_createEntryPointResult( @@ -2115,8 +2281,7 @@ SlangResult dissassembleDXILUsingDXC( auto& result = m_entryPointResults[entryPointIndex]; result = emitEntryPoint( backEndRequest, - entryPoint, - entryPointIndex, + EntryPointAndIndex(entryPoint, entryPointIndex), m_targetReq, endToEndRequest); @@ -2124,6 +2289,34 @@ SlangResult dissassembleDXILUsingDXC( } + CompileResult& TargetProgram::getOrCreateWholeProgramResult( + List entryPointIndices, + DiagnosticSink* sink) + { + auto& result = m_wholeProgramResult; + if (result.format != ResultFormat::None) + return result; + + // If we haven't yet computed a layout for this target + // program, we need to make sure that is done before + // code generation. + // + if (!getOrCreateIRModuleForLayout(sink)) + { + return result; + } + + RefPtr backEndRequest = new BackEndCompileRequest( + m_program->getLinkage(), + sink, + m_program); + + return _createWholeProgramResult( + entryPointIndices, + backEndRequest, + nullptr); + } + CompileResult& TargetProgram::getOrCreateEntryPointResult( Int entryPointIndex, DiagnosticSink* sink) @@ -2166,13 +2359,23 @@ SlangResult dissassembleDXILUsingDXC( // Generate target code any entry points that // have been requested for compilation. auto entryPointCount = program->getEntryPointCount(); - for(Index ii = 0; ii < entryPointCount; ++ii) + if (targetReq->isWholeProgramRequest) { - targetProgram->_createEntryPointResult( - ii, + targetProgram->_createWholeProgramResult( + List(), compileReq, endToEndReq); } + else + { + for (Index ii = 0; ii < entryPointCount; ++ii) + { + targetProgram->_createEntryPointResult( + ii, + compileReq, + endToEndReq); + } + } } @@ -2377,13 +2580,19 @@ SlangResult dissassembleDXILUsingDXC( for (auto targetReq : linkage->targets) { Index entryPointCount = program->getEntryPointCount(); - for (Index ee = 0; ee < entryPointCount; ++ee) - { - writeEntryPointResult( + if (targetReq->isWholeProgramRequest) { + writeWholeProgramResult( compileRequest, - program->getEntryPoint(ee), - targetReq, - ee); + targetReq); + } + else { + for (Index ee = 0; ee < entryPointCount; ++ee) + { + writeEntryPointResult( + compileRequest, + EntryPointAndIndex(program->getEntryPoint(ee), ee), + targetReq); + } } } @@ -2563,5 +2772,4 @@ SlangResult dissassembleDXILUsingDXC( maybeDumpIntermediate(compileRequest, text, strlen(text), target); } - } -- cgit v1.2.3