diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2019-12-12 11:39:19 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2019-12-12 11:39:19 -0500 |
| commit | 6e6a876a6b5ad3d2ef402757d2e20641f5a2b49b (patch) | |
| tree | 29eab69c4982537376d7eaf422d9090c849e95f7 | |
| parent | 79ec0cfdb5f3461c763e0bf712cf42eb87fccb90 (diff) | |
Slang compiles CUDA source via NVRTC (#1151)
* CPPCompiler -> DownstreamCompiler
* Added DownstreamCompileResult to start abstraction such that we don't need files.
* * Split out slang-blob.cpp
* Made CompileResult hold a DownstreamCompileResult - for access to binary or ISlangSharedLibrary
* Keep temporary files in scope.
* Add a hash to the hex dump stream.
* Move all file tracking into DownstreamCompiler.
* WIP support for nvrtc.
* WIP: Adding support for nvrtc compiler.
Adding enum types, wiring up the nvrtc into slang.
* Fix remaining CPPCompiler references.
* Fix order issue on target string matching.
* Use ISlangSharedLibrary for nvrtc.
* Use DownstreamCompiler for nvrtc.
* WIP first pass at compilation win nvrtc.
* Added testing if file is on file system into CommandLineDownstreamCompiler.
Added sourceContentsPath.
* Make test cuda-compile.cu work by just compiling not comparing output.
* Fix warning on clang.
24 files changed, 698 insertions, 111 deletions
@@ -517,6 +517,8 @@ extern "C" SLANG_EXECUTABLE, ///< Executable (for hosting CPU/OS) SLANG_SHARED_LIBRARY, ///< A shared library/Dll (for hosting CPU/OS) SLANG_HOST_CALLABLE, ///< A CPU target that makes the compiled code available to be run immediately + SLANG_CUDA_SOURCE, ///< Cuda source + SLANG_PTX, ///< PTX SLANG_TARGET_COUNT_OF, }; @@ -545,6 +547,7 @@ extern "C" SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio C/C++ compiler SLANG_PASS_THROUGH_GCC, ///< GCC C/C++ compiler SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C or C++ compiler, which is decided by the source type + SLANG_PASS_THROUGH_NVRTC, ///< NVRTC Cuda compiler SLANG_PASS_THROUGH_COUNT_OF, }; @@ -615,6 +618,7 @@ extern "C" SLANG_SOURCE_LANGUAGE_GLSL, SLANG_SOURCE_LANGUAGE_C, SLANG_SOURCE_LANGUAGE_CPP, + SLANG_SOURCE_LANGUAGE_CUDA, SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/core/core.vcxproj b/source/core/core.vcxproj index b0f33f2a2..b10bcc683 100644 --- a/source/core/core.vcxproj +++ b/source/core/core.vcxproj @@ -188,6 +188,7 @@ <ClInclude Include="slang-list.h" /> <ClInclude Include="slang-math.h" /> <ClInclude Include="slang-memory-arena.h" /> + <ClInclude Include="slang-nvrtc-compiler.h" /> <ClInclude Include="slang-object-scope-manager.h" /> <ClInclude Include="slang-offset-container.h" /> <ClInclude Include="slang-platform.h" /> @@ -221,6 +222,7 @@ <ClCompile Include="slang-hex-dump-util.cpp" /> <ClCompile Include="slang-io.cpp" /> <ClCompile Include="slang-memory-arena.cpp" /> + <ClCompile Include="slang-nvrtc-compiler.cpp" /> <ClCompile Include="slang-object-scope-manager.cpp" /> <ClCompile Include="slang-offset-container.cpp" /> <ClCompile Include="slang-platform.cpp" /> diff --git a/source/core/core.vcxproj.filters b/source/core/core.vcxproj.filters index 44d199771..b296b23af 100644 --- a/source/core/core.vcxproj.filters +++ b/source/core/core.vcxproj.filters @@ -63,6 +63,9 @@ <ClInclude Include="slang-memory-arena.h"> <Filter>Header Files</Filter> </ClInclude> + <ClInclude Include="slang-nvrtc-compiler.h"> + <Filter>Header Files</Filter> + </ClInclude> <ClInclude Include="slang-object-scope-manager.h"> <Filter>Header Files</Filter> </ClInclude> @@ -158,6 +161,9 @@ <ClCompile Include="slang-memory-arena.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="slang-nvrtc-compiler.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-object-scope-manager.cpp"> <Filter>Source Files</Filter> </ClCompile> diff --git a/source/core/slang-downstream-compiler.cpp b/source/core/slang-downstream-compiler.cpp index 52ec2fcd7..9532cf17f 100644 --- a/source/core/slang-downstream-compiler.cpp +++ b/source/core/slang-downstream-compiler.cpp @@ -16,6 +16,7 @@ #include "slang-visual-studio-compiler-util.h" #include "slang-gcc-compiler-util.h" +#include "slang-nvrtc-compiler.h" namespace Slang { @@ -57,6 +58,7 @@ void DownstreamCompiler::Desc::appendAsText(StringBuilder& out) const case CompilerType::Clang: return UnownedStringSlice::fromLiteral("Clang"); case CompilerType::SNC: return UnownedStringSlice::fromLiteral("SNC"); case CompilerType::GHS: return UnownedStringSlice::fromLiteral("GHS"); + case CompilerType::NVRTC: return UnownedStringSlice::fromLiteral("NVRTC"); } } @@ -212,6 +214,38 @@ SlangResult CommandLineDownstreamCompileResult::getBinary(ComPtr<ISlangBlob>& ou /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! CommandLineDownstreamCompiler !!!!!!!!!!!!!!!!!!!!!!*/ +static bool _isContentsInFile(const DownstreamCompiler::CompileOptions& options) +{ + if (options.sourceContentsPath.getLength() <= 0) + { + return false; + } + + // We can see if we can load it + if (File::exists(options.sourceContentsPath)) + { + // Here we look for the file on the regular file system (as opposed to using the + // ISlangFileSystem. This is unfortunate but necessary - because when we call out + // to the compiler all it is able to (currently) see are files on the file system. + // + // Note that it could be coincidence that the filesystem has a file that's identical in + // contents/name. That being the case though, any includes wouldn't work for a generated + // file either from some specialized ISlangFileSystem, so this is probably as good as it gets + // until we can integrate directly to a C/C++ compiler through say a shared library where we can control + // file system access. + try + { + String readContents = File::readAllText(options.sourceContentsPath); + // We should see if they are the same + return options.sourceContents == readContents.getUnownedSlice(); + } + catch (const Slang::IOException&) + { + } + } + return false; +} + SlangResult CommandLineDownstreamCompiler::compile(const CompileOptions& inOptions, RefPtr<DownstreamCompileResult>& out) { // Copy the command line options @@ -232,8 +266,12 @@ SlangResult CommandLineDownstreamCompiler::compile(const CompileOptions& inOptio SLANG_RETURN_ON_FAIL(File::generateTemporary(UnownedStringSlice::fromLiteral("slang-generated"), modulePath)); options.modulePath = modulePath; } - - if (options.sourceContents.getLength() != 0) + + if (_isContentsInFile(options)) + { + options.sourceFiles.add(options.sourceContentsPath); + } + else { String compileSourcePath = modulePath; @@ -264,10 +302,11 @@ SlangResult CommandLineDownstreamCompiler::compile(const CompileOptions& inOptio // Add it as a source file options.sourceFiles.add(compileSourcePath); - - // There is no source contents - options.sourceContents = String(); } + + // There is no source contents + options.sourceContents = String(); + options.sourceContentsPath = String(); } // Append command line args to the end of cmdLine using the target specific function for the specified options @@ -488,8 +527,29 @@ static void _addGCCFamilyCompiler(const String& path, const String& inExeName, D _addGCCFamilyCompiler(desc.getPath(CompilerType::Clang), "clang", set); _addGCCFamilyCompiler(desc.getPath(CompilerType::GCC), "g++", set); - // Set the default to the compiler closest to how this source was compiled - set->setDefaultCompiler(findClosestCompiler(set, getCompiledWithDesc())); + { + DownstreamCompiler* cppCompiler = findClosestCompiler(set, getCompiledWithDesc()); + + // Set the default to the compiler closest to how this source was compiled + set->setDefaultCompiler(DownstreamCompiler::SourceType::CPP, cppCompiler); + set->setDefaultCompiler(DownstreamCompiler::SourceType::C, cppCompiler); + } + + // Lets see if we have NVRTC. + { + ISlangSharedLibrary* sharedLibrary = desc.sharedLibraries[int(CompilerType::NVRTC)]; + if (sharedLibrary) + { + RefPtr<DownstreamCompiler> compiler; + if (SLANG_SUCCEEDED(NVRTCDownstreamCompilerUtil::createCompiler(sharedLibrary, compiler))) + { + set->addCompiler(compiler); + + set->setDefaultCompiler(DownstreamCompiler::SourceType::CUDA, compiler); + } + } + } + return SLANG_OK; } diff --git a/source/core/slang-downstream-compiler.h b/source/core/slang-downstream-compiler.h index 12cf54a91..99bcfd29f 100644 --- a/source/core/slang-downstream-compiler.h +++ b/source/core/slang-downstream-compiler.h @@ -94,6 +94,25 @@ protected: DownstreamDiagnostics m_diagnostics; }; + +class BlobDownstreamCompileResult : public DownstreamCompileResult +{ +public: + typedef DownstreamCompileResult Super; + + virtual SlangResult getHostCallableSharedLibrary(ComPtr<ISlangSharedLibrary>& outLibrary) SLANG_OVERRIDE { SLANG_UNUSED(outLibrary); return SLANG_FAIL; } + virtual SlangResult getBinary(ComPtr<ISlangBlob>& outBlob) SLANG_OVERRIDE { outBlob = m_blob; return m_blob ? SLANG_OK : SLANG_FAIL; } + + BlobDownstreamCompileResult(const DownstreamDiagnostics& diags, ISlangBlob* blob): + Super(diags), + m_blob(blob) + { + + } +protected: + ComPtr<ISlangBlob> m_blob; +}; + class DownstreamCompiler: public RefObject { public: @@ -109,12 +128,15 @@ public: Clang, SNC, GHS, + NVRTC, CountOf, }; enum class SourceType { C, ///< C source CPP, ///< C++ source + CUDA, ///< The CUDA language + CountOf, }; struct Desc @@ -205,6 +227,9 @@ public: /// The contents of the source to compile. This can be empty is sourceFiles is set. /// If the compiler is a commandLine file this source will be written to a temporary file. String sourceContents; + /// 'Path' that the contents originated from. NOTE! This is for reporting only and doesn't have to exist on file system + String sourceContentsPath; + /// The names/paths of source to compile. This can be empty if sourceContents is set. List<String> sourceFiles; @@ -249,6 +274,7 @@ protected: DownstreamCompiler(const Desc& desc) : m_desc(desc) {} + DownstreamCompiler() {} Desc m_desc; }; @@ -332,9 +358,9 @@ public: void addCompiler(DownstreamCompiler* compiler); /// Get a default compiler - DownstreamCompiler* getDefaultCompiler() const { return m_defaultCompiler; } + DownstreamCompiler* getDefaultCompiler(DownstreamCompiler::SourceType sourceType) const { return m_defaultCompilers[int(sourceType)]; } /// Set the default compiler - void setDefaultCompiler(DownstreamCompiler* compiler) { m_defaultCompiler = compiler; } + void setDefaultCompiler(DownstreamCompiler::SourceType sourceType, DownstreamCompiler* compiler) { m_defaultCompilers[int(sourceType)] = compiler; } /// True if has a compiler of the specified type bool hasCompiler(DownstreamCompiler::CompilerType compilerType) const; @@ -343,7 +369,7 @@ protected: Index _findIndex(const DownstreamCompiler::Desc& desc) const; - RefPtr<DownstreamCompiler> m_defaultCompiler; + RefPtr<DownstreamCompiler> m_defaultCompilers[int(DownstreamCompiler::SourceType::CountOf)]; // This could be a dictionary/map - but doing a linear search is going to be fine and it makes // somethings easier. List<RefPtr<DownstreamCompiler>> m_compilers; @@ -380,7 +406,10 @@ struct DownstreamCompilerUtil: public DownstreamCompilerBaseUtil const String& getPath(CompilerType type) const { return paths[int(type)]; } void setPath(CompilerType type, const String& path) { paths[int(type)] = path; } + InitializeSetDesc() { memset(sharedLibraries, 0, sizeof(sharedLibraries)); } + String paths[int(DownstreamCompiler::CompilerType::CountOf)]; + ISlangSharedLibrary* sharedLibraries[int(DownstreamCompiler::CompilerType::CountOf)]; }; /// Find a compiler diff --git a/source/core/slang-hex-dump-util.cpp b/source/core/slang-hex-dump-util.cpp index 1583d8461..b0bd6f923 100644 --- a/source/core/slang-hex-dump-util.cpp +++ b/source/core/slang-hex-dump-util.cpp @@ -75,6 +75,13 @@ static const char s_hex[] = "0123456789abcdef"; *dst++ = s_hex[byte & 0xf]; } + // If not a complete line write spaces + for (size_t i = count; i < size_t(maxBytesPerLine); ++i) + { + *dst++ = ' '; + *dst++ = ' '; + } + *dst++ = ' '; for (size_t i = 0; i < count; ++i) diff --git a/source/core/slang-nvrtc-compiler.cpp b/source/core/slang-nvrtc-compiler.cpp new file mode 100644 index 000000000..e812a2ab9 --- /dev/null +++ b/source/core/slang-nvrtc-compiler.cpp @@ -0,0 +1,369 @@ +// slang-nvrtc-compiler.cpp +#include "slang-nvrtc-compiler.h" + +#include "slang-common.h" +#include "../../slang-com-helper.h" + +#include "../core/slang-blob.h" + +#include "slang-string-util.h" + +#include "slang-io.h" +#include "slang-shared-library.h" + +namespace nvrtc +{ + +typedef enum { + NVRTC_SUCCESS = 0, + NVRTC_ERROR_OUT_OF_MEMORY = 1, + NVRTC_ERROR_PROGRAM_CREATION_FAILURE = 2, + NVRTC_ERROR_INVALID_INPUT = 3, + NVRTC_ERROR_INVALID_PROGRAM = 4, + NVRTC_ERROR_INVALID_OPTION = 5, + NVRTC_ERROR_COMPILATION = 6, + NVRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7, + NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8, + NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9, + NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10, + NVRTC_ERROR_INTERNAL_ERROR = 11 +} nvrtcResult; + +typedef struct _nvrtcProgram *nvrtcProgram; + +#define SLANG_NVRTC_FUNCS(x) \ + x(const char*, nvrtcGetErrorString, (nvrtcResult result)) \ + x(nvrtcResult, nvrtcVersion, (int *major, int *minor)) \ + x(nvrtcResult, nvrtcCreateProgram, (nvrtcProgram *prog, const char *src, const char *name, int numHeaders, const char * const *headers, const char * const *includeNames)) \ + x(nvrtcResult, nvrtcDestroyProgram, (nvrtcProgram *prog)) \ + x(nvrtcResult, nvrtcCompileProgram, (nvrtcProgram prog, int numOptions, const char * const *options)) \ + x(nvrtcResult, nvrtcGetPTXSize, (nvrtcProgram prog, size_t *ptxSizeRet)) \ + x(nvrtcResult, nvrtcGetPTX, (nvrtcProgram prog, char *ptx)) \ + x(nvrtcResult, nvrtcGetProgramLogSize, (nvrtcProgram prog, size_t *logSizeRet)) \ + x(nvrtcResult, nvrtcGetProgramLog, (nvrtcProgram prog, char *log))\ + x(nvrtcResult, nvrtcAddNameExpression, (nvrtcProgram prog, const char * const name_expression)) \ + x(nvrtcResult, nvrtcGetLoweredName, (nvrtcProgram prog, const char *const name_expression, const char** lowered_name)) + +} // namespace nvrtc + +namespace Slang +{ +using namespace nvrtc; + +static SlangResult _asResult(nvrtcResult res) +{ + switch (res) + { + case NVRTC_SUCCESS: + { + return SLANG_OK; + } + case NVRTC_ERROR_OUT_OF_MEMORY: + { + return SLANG_E_OUT_OF_MEMORY; + } + case NVRTC_ERROR_PROGRAM_CREATION_FAILURE: + case NVRTC_ERROR_INVALID_INPUT: + case NVRTC_ERROR_INVALID_PROGRAM: + { + return SLANG_FAIL; + } + case NVRTC_ERROR_INVALID_OPTION: + { + return SLANG_E_INVALID_ARG; + } + case NVRTC_ERROR_COMPILATION: + case NVRTC_ERROR_BUILTIN_OPERATION_FAILURE: + case NVRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION: + case NVRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION: + case NVRTC_ERROR_NAME_EXPRESSION_NOT_VALID: + { + return SLANG_FAIL; + } + case NVRTC_ERROR_INTERNAL_ERROR: + { + return SLANG_E_INTERNAL_FAIL; + } + default: return SLANG_FAIL; + } +} + +class NVRTCDownstreamCompiler : public DownstreamCompiler +{ +public: + typedef DownstreamCompiler Super; + + // DownstreamCompiler + virtual SlangResult compile(const CompileOptions& options, RefPtr<DownstreamCompileResult>& outResult) SLANG_OVERRIDE; + + /// Must be called before use + SlangResult init(ISlangSharedLibrary* library); + + NVRTCDownstreamCompiler() {} + +protected: + + struct ScopeProgram + { + ScopeProgram(NVRTCDownstreamCompiler* compiler, nvrtcProgram program): + m_compiler(compiler), + m_program(program) + { + } + ~ScopeProgram() + { + m_compiler->m_nvrtcDestroyProgram(&m_program); + } + NVRTCDownstreamCompiler* m_compiler; + nvrtcProgram m_program; + }; + + +#define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) \ + ret (*m_##name) params; + + SLANG_NVRTC_FUNCS(SLANG_NVTRC_MEMBER_FUNCS); + + ComPtr<ISlangSharedLibrary> m_sharedLibrary; +}; + +#define SLANG_NVRTC_RETURN_ON_FAIL(x) { nvrtcResult _res = x; if (_res != NVRTC_SUCCESS) return _asResult(_res); } + +SlangResult NVRTCDownstreamCompiler::init(ISlangSharedLibrary* library) +{ +#define SLANG_NVTRC_GET_FUNC(ret, name, params) \ + m_##name = (ret (*) params)library->findFuncByName(#name); \ + if (m_##name == nullptr) return SLANG_FAIL; + + SLANG_NVRTC_FUNCS(SLANG_NVTRC_GET_FUNC) + + m_sharedLibrary = library; + + m_desc.type = CompilerType::NVRTC; + + int major, minor; + m_nvrtcVersion(&major, &minor); + m_desc.majorVersion = major; + m_desc.minorVersion = minor; + + return SLANG_OK; +} + +static SlangResult _parseLocation(const UnownedStringSlice& in, DownstreamDiagnostic& outDiagnostic) +{ + const Index startIndex = in.indexOf('('); + + if (startIndex >= 0) + { + outDiagnostic.filePath = UnownedStringSlice(in.begin(), in.begin() + startIndex); + UnownedStringSlice remaining(in.begin() + startIndex + 1, in.end()); + const Int endIndex = remaining.indexOf(')'); + + UnownedStringSlice lineText = UnownedStringSlice(remaining.begin(), remaining.begin() + endIndex); + + Int line; + SLANG_RETURN_ON_FAIL(StringUtil::parseInt(lineText, line)); + outDiagnostic.fileLine = line; + } + else + { + outDiagnostic.fileLine = 0; + outDiagnostic.filePath = in; + } + return SLANG_OK; +} + +static SlangResult _parseNVRTCLine(const UnownedStringSlice& line, DownstreamDiagnostic& outDiagnostic) +{ + typedef DownstreamDiagnostic Diagnostic; + typedef Diagnostic::Type Type; + + outDiagnostic.stage = Diagnostic::Stage::Compile; + + List<UnownedStringSlice> split; + StringUtil::split(line, ':', split); + + if (split.getCount() == 3) + { + // tests/cuda/cuda-compile.cu(7): warning: variable "c" is used before its value is set + + const auto split1 = split[1].trim(); + + if (split1 == "error") + { + outDiagnostic.type = Type::Error; + } + else if (split1 == "warning") + { + outDiagnostic.type = Type::Warning; + } + outDiagnostic.text = split[2].trim(); + + SLANG_RETURN_ON_FAIL(_parseLocation(split[0], outDiagnostic)); + return SLANG_OK; + } + + return SLANG_E_NOT_FOUND; +} + +SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefPtr<DownstreamCompileResult>& outResult) +{ + // This compiler doesn't read files, they should be read externally and stored in sourceContents/sourceContentsPath + if (options.sourceFiles.getCount() > 0) + { + return SLANG_FAIL; + } + + CommandLine cmdLine; + + switch (options.debugInfoType) + { + case DebugInfoType::None: + { + break; + } + default: + { + cmdLine.addArg("--device-debug"); + break; + } + case DebugInfoType::Maximal: + { + cmdLine.addArg("--device-debug"); + cmdLine.addArg("--generate-line-info"); + break; + } + } + + // Don't seem to have such a control, so ignore for now + //switch (options.optimizationLevel) + //{ + // default: break; + //} + + switch (options.floatingPointMode) + { + case FloatingPointMode::Default: break; + case FloatingPointMode::Precise: + { + break; + } + case FloatingPointMode::Fast: + { + cmdLine.addArg("--use_fast_math"); + break; + } + } + + // Add defines + for (const auto& define : options.defines) + { + StringBuilder builder; + builder << "-D"; + builder << define.nameWithSig; + if (define.value.getLength()) + { + builder << "=" << define.value; + } + + cmdLine.addArg(builder); + } + + // Add includes + for (const auto& include : options.includePaths) + { + cmdLine.addArg("-I"); + cmdLine.addArg(include); + } + + + nvrtcProgram program = nullptr; + nvrtcResult res = m_nvrtcCreateProgram(&program, options.sourceContents.getBuffer(), options.sourceContentsPath.getBuffer(), 0, nullptr, nullptr); + if (res != NVRTC_SUCCESS) + { + return _asResult(res); + } + ScopeProgram scope(this, program); + + List<const char*> dstOptions; + dstOptions.setCount(cmdLine.m_args.getCount()); + for (Index i = 0; i < cmdLine.m_args.getCount(); ++i) + { + dstOptions[i] = cmdLine.m_args[i].value.getBuffer(); + } + + res = m_nvrtcCompileProgram(program, int(dstOptions.getCount()), dstOptions.getBuffer()); + + RefPtr<ListBlob> blob; + DownstreamDiagnostics diagnostics; + + diagnostics.result = _asResult(res); + + { + String rawDiagnostics; + + size_t logSize = 0; + SLANG_NVRTC_RETURN_ON_FAIL(m_nvrtcGetProgramLogSize(program, &logSize)); + + if (logSize) + { + char* dst = rawDiagnostics.prepareForAppend(Index(logSize)); + SLANG_NVRTC_RETURN_ON_FAIL(m_nvrtcGetProgramLog(program, dst)); + rawDiagnostics.appendInPlace(dst, Index(logSize)); + + diagnostics.rawDiagnostics = rawDiagnostics; + } + + // Parse the diagnostics here + for (auto line : LineParser(diagnostics.rawDiagnostics.getUnownedSlice())) + { + DownstreamDiagnostic diagnostic; + SlangResult lineRes = _parseNVRTCLine(line, diagnostic); + + if (SLANG_SUCCEEDED(lineRes)) + { + diagnostics.diagnostics.add(diagnostic); + } + else if (lineRes != SLANG_E_NOT_FOUND) + { + return lineRes; + } + } + + // if it has a compilation error.. set on output + if (diagnostics.has(DownstreamDiagnostic::Type::Error)) + { + diagnostics.result = SLANG_FAIL; + } + } + + if (res == nvrtc::NVRTC_SUCCESS) + { + // We should parse the log to set up the diagnostics + size_t ptxSize; + SLANG_NVRTC_RETURN_ON_FAIL(m_nvrtcGetPTXSize(program, &ptxSize)); + + List<uint8_t> ptx; + ptx.setCount(Index(ptxSize)); + + SLANG_NVRTC_RETURN_ON_FAIL(m_nvrtcGetPTX(program, (char*)ptx.getBuffer())); + + blob = ListBlob::moveCreate(ptx); + } + + outResult = new BlobDownstreamCompileResult(diagnostics, blob); + + return SLANG_OK; +} + +/* static */SlangResult NVRTCDownstreamCompilerUtil::createCompiler(ISlangSharedLibrary* library, RefPtr<DownstreamCompiler>& outCompiler) +{ + RefPtr<NVRTCDownstreamCompiler> compiler(new NVRTCDownstreamCompiler); + + SLANG_RETURN_ON_FAIL(compiler->init(library)); + + outCompiler = compiler; + return SLANG_OK; +} + +} diff --git a/source/core/slang-nvrtc-compiler.h b/source/core/slang-nvrtc-compiler.h new file mode 100644 index 000000000..91cd92b8c --- /dev/null +++ b/source/core/slang-nvrtc-compiler.h @@ -0,0 +1,20 @@ +#ifndef SLANG_NVRTC_COMPILER_UTIL_H +#define SLANG_NVRTC_COMPILER_UTIL_H + +#include "slang-downstream-compiler.h" + +#include "../core/slang-platform.h" + +namespace Slang +{ + + +struct NVRTCDownstreamCompilerUtil +{ + /// Create a NVRTC downstream compiler. Note on success the created compiler will own the shared library handle. + static SlangResult createCompiler(ISlangSharedLibrary* library, RefPtr<DownstreamCompiler>& outCompiler); +}; + +} + +#endif diff --git a/source/core/slang-shared-library.cpp b/source/core/slang-shared-library.cpp index fd61ba0a0..2b18ad6aa 100644 --- a/source/core/slang-shared-library.cpp +++ b/source/core/slang-shared-library.cpp @@ -22,6 +22,7 @@ static const Guid IID_ISlangSharedLibraryLoader = SLANG_UUID_ISlangSharedLibrary "d3dcompiler_47", // SharedLibraryType::Fxc "slang-glslang", // SharedLibraryType::Glslang "dxil", // SharedLibraryType::Dxil + "nvrtc64_102_0", // SharedLibraryType::NVRTC }; /* static */DefaultSharedLibraryLoader DefaultSharedLibraryLoader::s_singleton; @@ -44,13 +45,23 @@ ISlangUnknown* DefaultSharedLibraryLoader::getInterface(const Guid& guid) return (guid == IID_ISlangUnknown || guid == IID_ISlangSharedLibraryLoader) ? static_cast<ISlangSharedLibraryLoader*>(this) : nullptr; } -SlangResult DefaultSharedLibraryLoader::loadSharedLibrary(const char* path, ISlangSharedLibrary** sharedLibraryOut) +SlangResult DefaultSharedLibraryLoader::loadSharedLibrary(const char* path, ISlangSharedLibrary** outSharedLibrary) { - *sharedLibraryOut = nullptr; + *outSharedLibrary = nullptr; // Try loading SharedLibrary::Handle handle; SLANG_RETURN_ON_FAIL(SharedLibrary::load(path, handle)); - *sharedLibraryOut = ComPtr<ISlangSharedLibrary>(new DefaultSharedLibrary(handle)).detach(); + *outSharedLibrary = ComPtr<ISlangSharedLibrary>(new DefaultSharedLibrary(handle)).detach(); + return SLANG_OK; +} + +SlangResult DefaultSharedLibraryLoader::loadPlatformSharedLibrary(const char* path, ISlangSharedLibrary** outSharedLibrary) +{ + *outSharedLibrary = nullptr; + // Try loading + SharedLibrary::Handle handle; + SLANG_RETURN_ON_FAIL(SharedLibrary::loadWithPlatformPath(path, handle)); + *outSharedLibrary = ComPtr<ISlangSharedLibrary>(new DefaultSharedLibrary(handle)).detach(); return SLANG_OK; } diff --git a/source/core/slang-shared-library.h b/source/core/slang-shared-library.h index 517292908..9c58d4e10 100644 --- a/source/core/slang-shared-library.h +++ b/source/core/slang-shared-library.h @@ -21,6 +21,7 @@ enum class SharedLibraryType Fxc, ///< Fxc compiler Glslang, ///< Slang specific glslang compiler Dxil, ///< Dxil is used with dxc + NVRTC, ///< Nvrtc compiler CountOf, }; @@ -36,7 +37,9 @@ public: // ISlangSharedLibraryLoader virtual SLANG_NO_THROW SlangResult SLANG_MCALL loadSharedLibrary(const char* path, - ISlangSharedLibrary** sharedLibraryOut) SLANG_OVERRIDE; + ISlangSharedLibrary** outSharedLibrary) SLANG_OVERRIDE; + + SlangResult loadPlatformSharedLibrary(const char* path, ISlangSharedLibrary** outSharedLibrary); /// Get the singleton static DefaultSharedLibraryLoader* getSingleton() { return &s_singleton; } diff --git a/source/core/slang-visual-studio-compiler-util.cpp b/source/core/slang-visual-studio-compiler-util.cpp index 8797c3384..356f21a25 100644 --- a/source/core/slang-visual-studio-compiler-util.cpp +++ b/source/core/slang-visual-studio-compiler-util.cpp @@ -119,6 +119,10 @@ namespace Slang cmdLine.addArg("/MD"); break; } + case DebugInfoType::None: + { + break; + } case DebugInfoType::Maximal: { // Multithreaded statically linked *debug* runtime library diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 3586fcf25..db69a3155 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -121,11 +121,11 @@ namespace Slang return func; } - DownstreamCompilerSet* Session::requireCPPCompilerSet() + DownstreamCompilerSet* Session::requireDownstreamCompilerSet() { - if (cppCompilerSet == nullptr) + if (downstreamCompilerSet == nullptr) { - cppCompilerSet = new DownstreamCompilerSet; + downstreamCompilerSet = new DownstreamCompilerSet; typedef DownstreamCompiler::CompilerType CompilerType; DownstreamCompilerUtil::InitializeSetDesc desc; @@ -134,10 +134,12 @@ namespace Slang desc.paths[int(CompilerType::Clang)] = m_downstreamCompilerPaths[int(PassThroughMode::Clang)]; desc.paths[int(CompilerType::VisualStudio)] = m_downstreamCompilerPaths[int(PassThroughMode::VisualStudio)]; - DownstreamCompilerUtil::initializeSet(desc, cppCompilerSet); + desc.sharedLibraries[int(CompilerType::NVRTC)] = getOrLoadSharedLibrary(SharedLibraryType::NVRTC, nullptr); + + DownstreamCompilerUtil::initializeSet(desc, downstreamCompilerSet); } - SLANG_ASSERT(cppCompilerSet); - return cppCompilerSet; + SLANG_ASSERT(downstreamCompilerSet); + return downstreamCompilerSet; } TypeCheckingCache* Session::getTypeCheckingCache() diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 9ac49f60c..cba25d7df 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -76,6 +76,7 @@ namespace Slang { +// NOTE! These must be in the same order as the SlangCompileTarget enum #define SLANG_CODE_GEN_TARGETS(x) \ x("unknown", Unknown) \ x("none", None) \ @@ -93,7 +94,9 @@ namespace Slang x("cpp", CPPSource) \ x("exe,executable", Executable) \ x("sharedlib,sharedlibrary,dll", SharedLibrary) \ - x("callable,host-callable", HostCallable) + x("callable,host-callable", HostCallable) \ + x("cu,cuda", CUDASource) \ + x("ptx", PTX) #define SLANG_CODE_GEN_INFO(names, e) \ { CodeGenTarget::e, UnownedStringSlice::fromLiteral(names) }, @@ -115,6 +118,7 @@ namespace Slang { const auto& info = s_codeGenTargetInfos[i]; + // If this assert fails, then the SLANG_CODE_GEN_TARGETS macro has the wrong order SLANG_ASSERT(i == int(info.target)); if (StringUtil::indexOfInSplit(info.names, ',', name) >= 0) @@ -472,23 +476,27 @@ namespace Slang } case PassThroughMode::Clang: { - return session->requireCPPCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::Clang) ? SLANG_OK: SLANG_E_NOT_FOUND; + return session->requireDownstreamCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::Clang) ? SLANG_OK: SLANG_E_NOT_FOUND; } case PassThroughMode::VisualStudio: { - return session->requireCPPCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::VisualStudio) ? SLANG_OK: SLANG_E_NOT_FOUND; + return session->requireDownstreamCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::VisualStudio) ? SLANG_OK: SLANG_E_NOT_FOUND; } case PassThroughMode::Gcc: { - return session->requireCPPCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::GCC) ? SLANG_OK: SLANG_E_NOT_FOUND; + return session->requireDownstreamCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::GCC) ? SLANG_OK: SLANG_E_NOT_FOUND; } case PassThroughMode::GenericCCpp: { List<DownstreamCompiler::Desc> descs; - session->requireCPPCompilerSet()->getCompilerDescs(descs); + session->requireDownstreamCompilerSet()->getCompilerDescs(descs); return descs.getCount() ? SLANG_OK: SLANG_E_NOT_FOUND; } + case PassThroughMode::NVRTC: + { + return session->requireDownstreamCompilerSet()->hasCompiler(DownstreamCompiler::CompilerType::NVRTC) ? SLANG_OK: SLANG_E_NOT_FOUND; + } } return SLANG_E_NOT_IMPLEMENTED; } @@ -541,6 +549,10 @@ namespace Slang // We need some C/C++ compiler return PassThroughMode::GenericCCpp; } + case CodeGenTarget::PTX: + { + return PassThroughMode::NVRTC; + } default: break; } @@ -549,7 +561,7 @@ namespace Slang return PassThroughMode::None; } - PassThroughMode getPassThroughModeForCPPCompiler(DownstreamCompiler::CompilerType type) + PassThroughMode getPassThroughModeForDownstreamCompiler(DownstreamCompiler::CompilerType type) { typedef DownstreamCompiler::CompilerType CompilerType; @@ -558,6 +570,7 @@ namespace Slang case CompilerType::VisualStudio: return PassThroughMode::VisualStudio; case CompilerType::GCC: return PassThroughMode::Gcc; case CompilerType::Clang: return PassThroughMode::Clang; + case CompilerType::NVRTC: return PassThroughMode::NVRTC; default: return PassThroughMode::None; } } @@ -1277,7 +1290,7 @@ SlangResult dissassembleDXILUsingDXC( return SLANG_OK; } - SlangResult emitCPUBinaryForEntryPoint( + SlangResult emitDownstreamForEntryPoint( BackEndCompileRequest* slangRequest, Int entryPointIndex, TargetRequest* targetReq, @@ -1297,7 +1310,24 @@ SlangResult dissassembleDXILUsingDXC( // If we are not in pass through, lookup the default compiler for the emitted source type if (downstreamCompiler == PassThroughMode::None) { - downstreamCompiler = PassThroughMode(session->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP)); + auto target = targetReq->target; + + switch (target) + { + case CodeGenTarget::PTX: + { + downstreamCompiler = PassThroughMode(session->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CUDA)); + break; + } + case CodeGenTarget::HostCallable: + case CodeGenTarget::SharedLibrary: + case CodeGenTarget::Executable: + { + downstreamCompiler = PassThroughMode(session->getDefaultDownstreamCompiler(SLANG_SOURCE_LANGUAGE_CPP)); + break; + } + default: break; + } } // Get the required downstream CPP compiler @@ -1382,39 +1412,13 @@ SlangResult dissassembleDXILUsingDXC( const PathInfo& pathInfo = sourceFile->getPathInfo(); if (pathInfo.type == PathInfo::Type::FoundPath || pathInfo.type == PathInfo::Type::Normal) { - String compileSourcePath = pathInfo.foundPath; - // We can see if we can load it - if (File::exists(compileSourcePath)) - { - // Here we look for the file on the regular file system (as opposed to using the - // ISlangFileSystem. This is unfortunate but necessary - because when we call out - // to the CPP compiler all it is able to (currently) see are files on the file system. - // - // Note that it could be coincidence that the filesystem has a file that's identical in - // contents/name. That being the case though, any includes wouldn't work for a generated - // file either from some specialized ISlangFileSystem, so this is probably as good as it gets - // until we can integrate directly to a C/C++ compiler through say a shared library where we can control - // file system access. - try - { - String readContents = File::readAllText(compileSourcePath); - // We should see if they are the same - if ((sourceFile->getContent() == readContents.getUnownedSlice())) - { - // We just say use this file - options.sourceFiles.add(compileSourcePath); - } - } - catch (const Slang::IOException&) - { - } - } + options.sourceContentsPath = pathInfo.foundPath; } + options.sourceContents = sourceFile->getContent(); } - - // If can't just use file, concat together and make - if (options.sourceFiles.getCount() == 0) + else { + // If can't just use file, concat together and make StringBuilder codeBuilder; for (auto sourceFile : translationUnit->getSourceFiles()) { @@ -1437,8 +1441,18 @@ SlangResult dissassembleDXILUsingDXC( } // Set the source type - options.sourceType = (rawSourceLanguage == SourceLanguage::C) ? DownstreamCompiler::SourceType::C : DownstreamCompiler::SourceType::CPP; - + switch (rawSourceLanguage) + { + case SourceLanguage::C: options.sourceType = DownstreamCompiler::SourceType::C; break; + case SourceLanguage::CPP: options.sourceType = DownstreamCompiler::SourceType::CPP; break; + case SourceLanguage::CUDA: options.sourceType = DownstreamCompiler::SourceType::CUDA; break; + default: + { + SLANG_ASSERT(!"Unhandled source language"); + return SLANG_FAIL; + } + } + // Disable exceptions and security checks options.flags &= ~(CompileOptions::Flag::EnableExceptionHandling | CompileOptions::Flag::EnableSecurityChecks); @@ -1456,6 +1470,13 @@ SlangResult dissassembleDXILUsingDXC( options.targetType = DownstreamCompiler::TargetType::Executable; break; } + case CodeGenTarget::PTX: + { + // TODO(JS): Not clear what to do here. + // For example should 'Kernel' be distinct from 'Executable'. For now just use executable. + options.targetType = DownstreamCompiler::TargetType::Executable; + break; + } default: break; } @@ -1488,7 +1509,7 @@ SlangResult dissassembleDXILUsingDXC( case FloatingPointMode::Default: options.floatingPointMode = DownstreamCompiler::FloatingPointMode::Default; break; case FloatingPointMode::Precise: options.floatingPointMode = DownstreamCompiler::FloatingPointMode::Precise; break; case FloatingPointMode::Fast: options.floatingPointMode = DownstreamCompiler::FloatingPointMode::Fast; break; - default: SLANG_ASSERT(!"Unhanlde floating point mode"); + default: SLANG_ASSERT(!"Unhandled floating point mode"); } // Add all the search paths (as calculated earlier - they will only be set if this is a pass through else will be empty) @@ -1686,13 +1707,14 @@ SlangResult dissassembleDXILUsingDXC( switch (target) { + case CodeGenTarget::PTX: case CodeGenTarget::HostCallable: case CodeGenTarget::SharedLibrary: case CodeGenTarget::Executable: { RefPtr<DownstreamCompileResult> downstreamResult; - if (SLANG_SUCCEEDED(emitCPUBinaryForEntryPoint( + if (SLANG_SUCCEEDED(emitDownstreamForEntryPoint( compileRequest, entryPointIndex, targetReq, @@ -2012,7 +2034,6 @@ SlangResult dissassembleDXILUsingDXC( const void* blobData = blob->getBufferPointer(); size_t blobSize = blob->getBufferSize(); - if (writer->isConsole()) { // Writing to console, so we need to generate text output. @@ -2047,12 +2068,16 @@ SlangResult dissassembleDXILUsingDXC( } break; + case CodeGenTarget::PTX: + // For now we just dump PTX out as hex + case CodeGenTarget::HostCallable: case CodeGenTarget::SharedLibrary: case CodeGenTarget::Executable: HexDumpUtil::dumpWithMarkers((const uint8_t*)blobData, blobSize, 24, writer); break; + default: SLANG_UNEXPECTED("unhandled output format"); return; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index c33b2cf28..be5251d14 100644 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -67,6 +67,8 @@ namespace Slang Executable = SLANG_EXECUTABLE, SharedLibrary = SLANG_SHARED_LIBRARY, HostCallable = SLANG_HOST_CALLABLE, + CUDASource = SLANG_CUDA_SOURCE, + PTX = SLANG_PTX, CountOf = SLANG_TARGET_COUNT_OF, }; @@ -777,6 +779,7 @@ namespace Slang VisualStudio = SLANG_PASS_THROUGH_VISUAL_STUDIO, ///< Visual studio compiler Gcc = SLANG_PASS_THROUGH_GCC, ///< Gcc compiler GenericCCpp = SLANG_PASS_THROUGH_GENERIC_C_CPP, ///< Generic C/C++ compiler + NVRTC = SLANG_PASS_THROUGH_NVRTC, CountOf = SLANG_PASS_THROUGH_COUNT_OF, }; @@ -1074,7 +1077,7 @@ namespace Slang /// Given a target returns the required downstream compiler PassThroughMode getDownstreamCompilerRequiredForTarget(CodeGenTarget target); - PassThroughMode getPassThroughModeForCPPCompiler(DownstreamCompiler::CompilerType type); + PassThroughMode getPassThroughModeForDownstreamCompiler(DownstreamCompiler::CompilerType type); /// A context for loading and re-using code modules. @@ -1836,7 +1839,7 @@ namespace Slang /// Get the specified compiler DownstreamCompiler* getDownstreamCompiler(PassThroughMode downstreamCompiler); /// Get the default cpp compiler for a language - DownstreamCompiler* getDefaultCPPCompiler(SourceLanguage sourceLanguage); + DownstreamCompiler* getDefaultDownstreamCompiler(SourceLanguage sourceLanguage); enum class SharedLibraryFuncType { @@ -1892,7 +1895,7 @@ namespace Slang RefPtr<Type> stringType; RefPtr<Type> enumTypeType; - RefPtr<DownstreamCompilerSet> cppCompilerSet; ///< Information about available C/C++ compilers. null unless information is requested (because slow) + RefPtr<DownstreamCompilerSet> downstreamCompilerSet; ///< Information about available C/C++ compilers. null unless information is requested (because slow) ComPtr<ISlangSharedLibraryLoader> sharedLibraryLoader; ///< The shared library loader (never null) ComPtr<ISlangSharedLibrary> sharedLibraries[int(SharedLibraryType::CountOf)]; ///< The loaded shared libraries @@ -1974,7 +1977,7 @@ namespace Slang const String& getDownstreamCompilerPrelude(PassThroughMode mode) { return m_downstreamCompilerPreludes[int(mode)]; } /// Finds out what compilers are present and caches the result - DownstreamCompilerSet* requireCPPCompilerSet(); + DownstreamCompilerSet* requireDownstreamCompilerSet(); Session(); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 755a005d1..45eef5148 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -583,10 +583,10 @@ String emitEntryPointSource( { const SourceLanguage sourceLanguage = (sourceStyle == SourceStyle::C) ? SourceLanguage::C : SourceLanguage::CPP; // Get the compiler used for the language - DownstreamCompiler* compiler = session->getDefaultCPPCompiler(sourceLanguage); + DownstreamCompiler* compiler = session->getDefaultDownstreamCompiler(sourceLanguage); if (compiler) { - passThru = getPassThroughModeForCPPCompiler(compiler->getDesc().type); + passThru = getPassThroughModeForDownstreamCompiler(compiler->getDesc().type); } } diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index c7ede8d93..310042b8b 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -55,7 +55,8 @@ SlangResult tryReadCommandLineArgument(DiagnosticSink* sink, char const* option, x(clang, CLANG) \ x(gcc, GCC) \ x(c, GENERIC_C_CPP) \ - x(cpp, GENERIC_C_CPP) + x(cpp, GENERIC_C_CPP) \ + x(nvrtc, NVRTC) static SlangResult _parsePassThrough(const UnownedStringSlice& name, SlangPassThrough& outPassThrough) { @@ -87,6 +88,10 @@ static SlangSourceLanguage _findSourceLanguage(const UnownedStringSlice& text) { return SLANG_SOURCE_LANGUAGE_HLSL; } + else if (text == "cu" || text == "cuda") + { + return SLANG_SOURCE_LANGUAGE_CUDA; + } return SLANG_SOURCE_LANGUAGE_UNKNOWN; } @@ -334,6 +339,8 @@ struct OptionsParser { ".c", SLANG_SOURCE_LANGUAGE_C, SLANG_STAGE_NONE }, { ".cpp", SLANG_SOURCE_LANGUAGE_CPP, SLANG_STAGE_NONE }, + { ".cu", SLANG_SOURCE_LANGUAGE_CUDA, SLANG_STAGE_NONE } + }; for (int i = 0; i < SLANG_COUNT_OF(entries); ++i) diff --git a/source/slang/slang-profile.h b/source/slang/slang-profile.h index b174245ba..2996c7040 100644 --- a/source/slang/slang-profile.h +++ b/source/slang/slang-profile.h @@ -15,6 +15,7 @@ namespace Slang GLSL = SLANG_SOURCE_LANGUAGE_GLSL, C = SLANG_SOURCE_LANGUAGE_C, CPP = SLANG_SOURCE_LANGUAGE_CPP, + CUDA = SLANG_SOURCE_LANGUAGE_CUDA, CountOf = SLANG_SOURCE_LANGUAGE_COUNT_OF, }; diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index faf748098..0319d1f7d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -107,6 +107,7 @@ Session::Session() } m_defaultDownstreamCompilers[Index(SourceLanguage::C)] = PassThroughMode::GenericCCpp; m_defaultDownstreamCompilers[Index(SourceLanguage::CPP)] = PassThroughMode::GenericCCpp; + m_defaultDownstreamCompilers[Index(SourceLanguage::CUDA)] = PassThroughMode::NVRTC; } } @@ -193,7 +194,14 @@ SLANG_NO_THROW void SLANG_MCALL Session::setDownstreamCompilerPath( case PassThroughMode::GenericCCpp: { // If any compiler path set changed, require all to be refreshed - cppCompilerSet.setNull(); + downstreamCompilerSet.setNull(); + break; + } + case PassThroughMode::NVRTC: + { + // TODO(JS): We need a way to set the NVRTC path. + // We want to unload... and try again... + downstreamCompilerSet.setNull(); break; } default: break; @@ -249,6 +257,10 @@ static bool _canCompile(PassThroughMode compiler, SourceLanguage sourceLanguage) { return sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP; } + case PassThroughMode::NVRTC: + { + return sourceLanguage == SourceLanguage::CUDA; + } default: break; } return false; @@ -259,13 +271,10 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Session::setDefaultDownstreamCompiler(Sla auto sourceLanguage = SourceLanguage(inSourceLanguage); auto compiler = PassThroughMode(defaultCompiler); - if (sourceLanguage == SourceLanguage::C || sourceLanguage == SourceLanguage::CPP) + if (_canCompile(compiler, sourceLanguage)) { - if (_canCompile(compiler, sourceLanguage)) - { - m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler; - return SLANG_OK; - } + m_defaultDownstreamCompilers[int(sourceLanguage)] = compiler; + return SLANG_OK; } return SLANG_FAIL; @@ -280,19 +289,20 @@ SlangPassThrough SLANG_MCALL Session::getDefaultDownstreamCompiler(SlangSourceLa DownstreamCompiler* Session::getDownstreamCompiler(PassThroughMode compiler) { - DownstreamCompilerSet* compilerSet = requireCPPCompilerSet(); + DownstreamCompilerSet* compilerSet = requireDownstreamCompilerSet(); switch (compiler) { - case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler(); + case PassThroughMode::GenericCCpp: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CPP); case PassThroughMode::Clang: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::Clang)); case PassThroughMode::VisualStudio: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::VisualStudio)); case PassThroughMode::Gcc: return DownstreamCompilerUtil::findCompiler(compilerSet, DownstreamCompilerUtil::MatchType::Newest, DownstreamCompiler::Desc(DownstreamCompiler::CompilerType::GCC)); + case PassThroughMode::NVRTC: return compilerSet->getDefaultCompiler(DownstreamCompiler::SourceType::CUDA); default: break; } return nullptr; } -DownstreamCompiler* Session::getDefaultCPPCompiler(SourceLanguage sourceLanguage) +DownstreamCompiler* Session::getDefaultDownstreamCompiler(SourceLanguage sourceLanguage) { return getDownstreamCompiler(m_defaultDownstreamCompilers[int(sourceLanguage)]); } diff --git a/tests/cuda/cuda-compile.cu b/tests/cuda/cuda-compile.cu new file mode 100644 index 000000000..35387cd20 --- /dev/null +++ b/tests/cuda/cuda-compile.cu @@ -0,0 +1,7 @@ +//TEST(smoke):COMPILE: -pass-through nvrtc -target ptx -entry hello tests/cuda/cuda-compile.cu + +__global__ +void hello(char *a, int *b) +{ + a[threadIdx.x] += b[threadIdx.x]; +} diff --git a/tools/render-test/options.cpp b/tools/render-test/options.cpp index fec934afc..7d3fd27b5 100644 --- a/tools/render-test/options.cpp +++ b/tools/render-test/options.cpp @@ -66,6 +66,10 @@ static SlangSourceLanguage _findSourceLanguage(const UnownedStringSlice& text) { return SLANG_SOURCE_LANGUAGE_HLSL; } + else if (text == "cu" || text == "cuda") + { + return SLANG_SOURCE_LANGUAGE_CUDA; + } return SLANG_SOURCE_LANGUAGE_UNKNOWN; } diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 37c9b1036..5afcc6d24 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -62,6 +62,9 @@ static const char computeEntryPointName[] = "computeMain"; case SLANG_SOURCE_LANGUAGE_CPP: spAddPreprocessorDefine(slangRequest, "__CPP__", "1"); break; + case SLANG_SOURCE_LANGUAGE_CUDA: + spAddPreprocessorDefine(slangRequest, "__CUDA__", "1"); + break; default: assert(!"unexpected"); diff --git a/tools/slang-test/slang-test-main.cpp b/tools/slang-test/slang-test-main.cpp index 51c3ecea9..44cc73b4e 100644 --- a/tools/slang-test/slang-test-main.cpp +++ b/tools/slang-test/slang-test-main.cpp @@ -21,6 +21,8 @@ using namespace Slang; #include "../../source/core/slang-downstream-compiler.h" +#include "../../source/core/slang-nvrtc-compiler.h" + #include "../../source/core/slang-process-util.h" #define STB_IMAGE_IMPLEMENTATION @@ -523,6 +525,10 @@ static SlangPassThrough _toPassThroughType(const UnownedStringSlice& slice) { return SLANG_PASS_THROUGH_VISUAL_STUDIO; } + else if (slice == "nvrtc") + { + return SLANG_PASS_THROUGH_NVRTC; + } return SLANG_PASS_THROUGH_NONE; } @@ -562,6 +568,10 @@ static PassThroughFlags _getPassThroughFlagsForTarget(SlangCompileTarget target) { return PassThroughFlag::Generic_C_CPP; } + case SLANG_PTX: + { + return PassThroughFlag::NVRTC; + } default: { @@ -593,6 +603,8 @@ static SlangCompileTarget _getCompileTarget(const UnownedStringSlice& name) CASE("dll", SHARED_LIBRARY) CASE("callable", HOST_CALLABLE) CASE("host-callable", HOST_CALLABLE) + CASE("ptx", PTX) + CASE("cuda", CUDA_SOURCE) #undef CASE return SLANG_TARGET_UNKNOWN; @@ -1088,7 +1100,6 @@ TestResult runSimpleTest(TestContext* context, TestInput& input) TestResult runCompile(TestContext* context, TestInput& input) { - // need to execute the stand-alone Slang compiler on the file, and compare its output to what we expect auto outputStem = input.outputStem; CommandLine cmdLine; @@ -1280,9 +1291,9 @@ static String _calcModulePath(const TestInput& input) return Path::combine(directory, moduleName); } -static TestResult runCompilerCompile(TestContext* context, TestInput& input) +static TestResult runCPPCompilerCompile(TestContext* context, TestInput& input) { - DownstreamCompiler* compiler = context->getDefaultCompiler(); + DownstreamCompiler* compiler = context->getDefaultCompiler(DownstreamCompiler::SourceType::CPP); if (!compiler) { return TestResult::Ignored; @@ -1322,9 +1333,9 @@ static TestResult runCompilerCompile(TestContext* context, TestInput& input) return TestResult::Pass; } -static TestResult runCompilerSharedLibrary(TestContext* context, TestInput& input) +static TestResult runCPPCompilerSharedLibrary(TestContext* context, TestInput& input) { - DownstreamCompiler* compiler = context->getDefaultCompiler(); + DownstreamCompiler* compiler = context->getDefaultCompiler(DownstreamCompiler::SourceType::CPP); if (!compiler) { return TestResult::Ignored; @@ -1440,9 +1451,9 @@ static TestResult runCompilerSharedLibrary(TestContext* context, TestInput& inpu return TestResult::Pass; } -static TestResult runCompilerExecute(TestContext* context, TestInput& input) +static TestResult runCPPCompilerExecute(TestContext* context, TestInput& input) { - DownstreamCompiler* compiler = context->getDefaultCompiler(); + DownstreamCompiler* compiler = context->getDefaultCompiler(DownstreamCompiler::SourceType::CPP); if (!compiler) { return TestResult::Ignored; @@ -2442,9 +2453,9 @@ static const TestCommandInfo s_testCommandInfos[] = { "COMPARE_RENDER_COMPUTE", &runSlangRenderComputeComparisonTest}, { "COMPARE_GLSL", &runGLSLComparisonTest}, { "CROSS_COMPILE", &runCrossCompilerTest}, - { "CPP_COMPILER_EXECUTE", &runCompilerExecute}, - { "CPP_COMPILER_SHARED_LIBRARY", &runCompilerSharedLibrary}, - { "CPP_COMPILER_COMPILE", &runCompilerCompile}, + { "CPP_COMPILER_EXECUTE", &runCPPCompilerExecute}, + { "CPP_COMPILER_SHARED_LIBRARY", &runCPPCompilerSharedLibrary}, + { "CPP_COMPILER_COMPILE", &runCPPCompilerCompile}, { "PERFORMANCE_PROFILE", &runPerformanceProfile}, { "COMPILE", &runCompile}, }; @@ -2777,6 +2788,7 @@ static bool endsWithAllowedExtension( ".rgen", ".c", ".cpp", + ".cu", }; for( auto allowedExtension : allowedExtensions) @@ -2871,6 +2883,7 @@ SlangResult innerMain(int argc, char** argv) auto unixCatagory = categorySet.add("unix", fullTestCategory); #endif + // An un-categorized test will always belong to the `full` category categorySet.defaultCategory = fullTestCategory; diff --git a/tools/slang-test/test-context.cpp b/tools/slang-test/test-context.cpp index c37261b61..d333dbcaa 100644 --- a/tools/slang-test/test-context.cpp +++ b/tools/slang-test/test-context.cpp @@ -3,6 +3,7 @@ #include "../../source/core/slang-io.h" #include "../../source/core/slang-string-util.h" +#include "../../source/core/slang-shared-library.h" #include <stdio.h> #include <stdlib.h> @@ -27,15 +28,6 @@ Result TestContext::init() TestContext::~TestContext() { - for (auto& pair : m_sharedLibTools) - { - const auto& tool = pair.Value; - if (tool.m_sharedLibrary) - { - SharedLibrary::unload(tool.m_sharedLibrary); - } - } - if (m_session) { spDestroySession(m_session); @@ -60,11 +52,13 @@ TestContext::InnerMainFunc TestContext::getInnerMainFunc(const String& dirPath, SharedLibrary::appendPlatformFileName(sharedLibToolBuilder.getUnownedSlice(), builder); String path = Path::combine(dirPath, builder); + DefaultSharedLibraryLoader* loader = DefaultSharedLibraryLoader::getSingleton(); + SharedLibraryTool tool = {}; - if (SLANG_SUCCEEDED(SharedLibrary::loadWithPlatformPath(path.begin(), tool.m_sharedLibrary))) + if (SLANG_SUCCEEDED(loader->loadPlatformSharedLibrary(path.begin(), tool.m_sharedLibrary.writeRef()))) { - tool.m_func = (InnerMainFunc)SharedLibrary::findFuncByName(tool.m_sharedLibrary, "innerMain"); + tool.m_func = (InnerMainFunc)tool.m_sharedLibrary->findFuncByName("innerMain"); } m_sharedLibTools.Add(name, tool); @@ -76,12 +70,7 @@ void TestContext::setInnerMainFunc(const String& name, InnerMainFunc func) SharedLibraryTool* tool = m_sharedLibTools.TryGetValue(name); if (tool) { - if (tool->m_sharedLibrary) - { - SharedLibrary::unload(tool->m_sharedLibrary); - tool->m_sharedLibrary = nullptr; - } - + tool->m_sharedLibrary.setNull(); tool->m_func = func; } else @@ -99,14 +88,19 @@ DownstreamCompilerSet* TestContext::getCompilerSet() compilerSet = new DownstreamCompilerSet; DownstreamCompilerUtil::InitializeSetDesc desc; + + ComPtr<ISlangSharedLibrary> nvrtcSharedLibrary; + DefaultSharedLibraryLoader::getSingleton()->loadSharedLibrary(DefaultSharedLibraryLoader::getSharedLibraryNameFromType(SharedLibraryType::NVRTC), nvrtcSharedLibrary.writeRef()); + desc.sharedLibraries[int(DownstreamCompiler::CompilerType::NVRTC)] = nvrtcSharedLibrary; + DownstreamCompilerUtil::initializeSet(desc, compilerSet); } return compilerSet; } -Slang::DownstreamCompiler* TestContext::getDefaultCompiler() +Slang::DownstreamCompiler* TestContext::getDefaultCompiler(DownstreamCompiler::SourceType sourceType) { DownstreamCompilerSet* set = getCompilerSet(); - return set ? set->getDefaultCompiler() : nullptr; + return set ? set->getDefaultCompiler(sourceType) : nullptr; } diff --git a/tools/slang-test/test-context.h b/tools/slang-test/test-context.h index aa81fc72a..520e0bf1d 100644 --- a/tools/slang-test/test-context.h +++ b/tools/slang-test/test-context.h @@ -11,6 +11,8 @@ #include "../../source/core/slang-render-api-util.h" #include "../../source/core/slang-downstream-compiler.h" +#include "../../slang-com-ptr.h" + #include "options.h" typedef uint32_t PassThroughFlags; @@ -25,6 +27,7 @@ struct PassThroughFlag GCC = 1 << int(SLANG_PASS_THROUGH_GCC), Clang = 1 << int(SLANG_PASS_THROUGH_CLANG), Generic_C_CPP = 1 << int(SLANG_PASS_THROUGH_GENERIC_C_CPP), + NVRTC = 1 << int(SLANG_PASS_THROUGH_NVRTC) }; }; @@ -94,7 +97,7 @@ class TestContext /// Get compiler set Slang::DownstreamCompilerSet* getCompilerSet(); - Slang::DownstreamCompiler* getDefaultCompiler(); + Slang::DownstreamCompiler* getDefaultCompiler(Slang::DownstreamCompiler::SourceType sourceType); /// Ctor TestContext(); @@ -117,7 +120,7 @@ class TestContext protected: struct SharedLibraryTool { - Slang::SharedLibrary::Handle m_sharedLibrary; + Slang::ComPtr<ISlangSharedLibrary> m_sharedLibrary; InnerMainFunc m_func; }; |
