diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2021-01-26 12:15:08 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-01-26 09:15:08 -0800 |
| commit | 798d7731eca286df456bc2ec56c0695ba006b472 (patch) | |
| tree | 37ced2db457a08aa8cfc81f19f18daf9ca26d3f2 /source/core/slang-nvrtc-compiler.cpp | |
| parent | 00fad59d49d31538270b811903aeb449c97ca152 (diff) | |
Improved NVRTC location finding (#1674)
* #include an absolute path didn't work - because paths were taken to always be relative.
* WIP more sophisticated mechanism to find NVRTC.
* Improve nvrtc searching to include PATH.
* Make getting an extension able to differentiate between no extension, and just a .
* Add comment.
* Add support for searching instance path.
* Small improvements around scope and finding NVRTC.
* Improve documentation around NVRTC loading.
Diffstat (limited to 'source/core/slang-nvrtc-compiler.cpp')
| -rw-r--r-- | source/core/slang-nvrtc-compiler.cpp | 315 |
1 files changed, 262 insertions, 53 deletions
diff --git a/source/core/slang-nvrtc-compiler.cpp b/source/core/slang-nvrtc-compiler.cpp index 1bdb4dfa7..78afc618f 100644 --- a/source/core/slang-nvrtc-compiler.cpp +++ b/source/core/slang-nvrtc-compiler.cpp @@ -7,11 +7,13 @@ #include "../core/slang-blob.h" #include "slang-string-util.h" +#include "slang-string-slice-pool.h" #include "slang-io.h" #include "slang-shared-library.h" #include "slang-semantic-version.h" + namespace nvrtc { @@ -484,73 +486,280 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP return SLANG_OK; } +/* An implementation of Path::Visitor that can be used for finding NVRTC shared library installations. */ +struct NVRTCPathVisitor : Path::Visitor +{ + struct Candidate + { + typedef Candidate ThisType; + + bool operator==(const ThisType& rhs) const { return path == rhs.path && version == rhs.version; } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + static Candidate make(const String& path, const SemanticVersion& version) + { + Candidate can; + can.version = version; + can.path = path; + return can; + } + String path; + SemanticVersion version; + }; + + Index findVersion(const SemanticVersion& version) const + { + const Index count = m_candidates.getCount(); + for (Index i = 0; i < count; ++i) + { + if (m_candidates[i].version == version) + { + return i; + } + } + return -1; + } + + static bool _orderCandiate(const Candidate& a, const Candidate& b) { return a.version < b.version; } + void sortCandidates() { m_candidates.sort(_orderCandiate); } + + void accept(Path::Type type, const UnownedStringSlice& filename) SLANG_OVERRIDE + { + // Lets make sure it start's with nvrtc64, but not worry about case + if (type == Path::Type::File) + { + // If there is a defined extension, make sure it has it + if (m_postfix.getLength() && filename.getLength() >= m_postfix.getLength()) + { + // We test without case - really for windows + UnownedStringSlice filenamePostfix = filename.tail(filename.getLength() - m_postfix.getLength()); + if (!filenamePostfix.caseInsensitiveEquals(m_postfix.getUnownedSlice())) + { + return; + } + } + + if (filename.getLength() >= m_prefix.getLength() && + filename.subString(0, m_prefix.getLength()).caseInsensitiveEquals(m_prefix.getUnownedSlice())) + { + // Versions are typically (on windows) of the form + // nvrtc64_110_2.dll + // 11 - Major + // 0 Minor + // 2 Patch + Index endIndex = filename.indexOf('.'); + endIndex = (endIndex < 0) ? filename.getLength() : endIndex; + + UnownedStringSlice versionSlice = UnownedStringSlice(filename.begin() + m_prefix.getLength(), filename.begin() + endIndex); + + Int patch = 0; + UnownedStringSlice majorMinorSlice; + { + List<UnownedStringSlice> slices; + StringUtil::split(versionSlice, '_', slices); + if (slices.getCount() >= 2) + { + // We don't bother checking for error here, if it's not parsable, it will be 0 + StringUtil::parseInt(slices[1], patch); + } + majorMinorSlice = slices[0]; + } + + if (majorMinorSlice.getLength() < 2) + { + // Must be a major and minor + return; + } + + UnownedStringSlice majorSlice = majorMinorSlice.head(majorMinorSlice.getLength() - 1); + UnownedStringSlice minorSlice = majorMinorSlice.subString(majorMinorSlice.getLength() - 1, 1); + + Int major; + Int minor; + + if (SLANG_FAILED(StringUtil::parseInt(majorSlice, major)) || + SLANG_FAILED(StringUtil::parseInt(minorSlice, minor))) + { + return; + } + + const SemanticVersion version = SemanticVersion(int(major), int(minor), int(patch)); + + // We may want to add multiple versions, if they are in different locations - as there may be multiple entries + // in the PATH, and only one works. We'll only know which works by loading +#if 0 + // We already found this version, so let's not add it again + if (findVersion(version) >= 0) + { + return; + } +#endif + + // Strip to make a shared library name + UnownedStringSlice sharedLibraryName = filename.tail(m_prefix.getLength() - m_sharedLibraryStem.getLength()); + sharedLibraryName = filename.head(filename.getLength() - m_postfix.getLength()); + + auto candidate = Candidate::make(Path::combine(m_basePath, sharedLibraryName), version); + + // If we already have this candidate, then skip + if (m_candidates.indexOf(candidate) >= 0) + { + return; + } + + // Add to the list of candidates + m_candidates.add(candidate); + } + } + } + + SlangResult findInDirectory(const String& path) + { + m_basePath = path; + return Path::find(path, nullptr, this); + } + + bool hasCandidates() const { return m_candidates.getCount() > 0; } + + NVRTCPathVisitor(const UnownedStringSlice& sharedLibraryStem): + m_sharedLibraryStem(sharedLibraryStem) + { + // Work out the prefix and postfix of the shader + StringBuilder buf; + SharedLibrary::appendPlatformFileName(sharedLibraryStem, buf); + const Index index = buf.indexOf(sharedLibraryStem); + SLANG_ASSERT(index >= 0); + + m_prefix = buf.getUnownedSlice().head(index + sharedLibraryStem.getLength()); + m_postfix = buf.getUnownedSlice().tail(index + sharedLibraryStem.getLength()); + } + + String m_prefix; + String m_postfix; + String m_basePath; + String m_sharedLibraryStem; + + List<Candidate> m_candidates; +}; + +static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<ISlangSharedLibrary>& outLibrary) +{ +#if SLANG_WINDOWS_FAMILY + // We only need to search 64 bit versions on windows + NVRTCPathVisitor visitor(UnownedStringSlice::fromLiteral("nvrtc64_")); + + // First try the instance path (if supported on platform) + { + StringBuilder instancePath; + if (SLANG_SUCCEEDED(PlatformUtil::getInstancePath(instancePath))) + { + visitor.findInDirectory(instancePath); + } + } + + // If we don't have a candidate try CUDA_PATH + if (!visitor.hasCandidates()) + { + StringBuilder buf; + if (!SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), buf))) + { + // Look for candidates in the directory + visitor.findInDirectory(Path::combine(buf, "bin")); + } + } + + // If we haven't we go searching through PATH + if (!visitor.hasCandidates()) + { + List<UnownedStringSlice> splitPath; + + StringBuilder buf; + if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("PATH"), buf))) + { + // Split so we get individual paths + List<UnownedStringSlice> paths; + StringUtil::split(buf.getUnownedSlice(), ';', paths); + + // We use a pool to make sure we only check each path once + StringSlicePool pool(StringSlicePool::Style::Empty); + + // We are going to search the paths in order + for (const auto path : paths) + { + // PATH can have the same path multiple times. If we have already searched this path, we don't need to again + if (!pool.has(path)) + { + pool.add(path); + + Path::split(path, splitPath); + + // We could search every path, but here we restrict to paths that look like CUDA installations. + // It's a path that contains a CUDA directory and has bin + if (splitPath.indexOf("CUDA") >= 0 && splitPath[splitPath.getCount() - 1].caseInsensitiveEquals(UnownedStringSlice::fromLiteral("bin"))) + { + // Okay lets search it + visitor.findInDirectory(path); + } + } + } + } + } + + // Put into version order with oldest first. + visitor.sortCandidates(); + + // We want to start with the newest version... + for (Index i = visitor.m_candidates.getCount() - 1; i >= 0; --i) + { + const auto& candidate = visitor.m_candidates[i]; + if (SLANG_SUCCEEDED(loader->loadSharedLibrary(candidate.path.getBuffer(), outLibrary.writeRef()))) + { + return SLANG_OK; + } + } +#endif + // This is an official-ish list of versions is here: + // https://developer.nvidia.com/cuda-toolkit-archive + + // Filenames for NVRTC + // https://docs.nvidia.com/cuda/nvrtc/index.html + // + // From this it appears on platforms other than windows the SharedLibrary name + // should be nvrtc which is already tried, so we can give up now. + return SLANG_E_NOT_FOUND; +} + /* static */SlangResult NVRTCDownstreamCompilerUtil::locateCompilers(const String& path, ISlangSharedLibraryLoader* loader, DownstreamCompilerSet* set) { ComPtr<ISlangSharedLibrary> library; + // If the user supplies a path to their preferred version of NVRTC, + // we just use this. if (path.getLength() != 0) { SLANG_RETURN_ON_FAIL(loader->loadSharedLibrary(path.getBuffer(), library.writeRef())); } else { - // If the user doesn't supply a path to their preferred version of NVRTC, - // we will search for a suitable library version, proceeding from more - // recent versions to less recent ones. - // - // TODO: The list here was cobbled together from what NRTC releases I - // could easily identify. It would be good to ver this against some - // kind of official list. + // As a catch-all for non-Windows platforms, we search for + // a library simply named `nvrtc` (well, `libnvrtc`) which + // is expected to match whatever the user has installed. // - // It would probably be good to support 32- and 64-bit here, and also - // to deal with any variation in the shared library name across platforms - // - static const char* kNVRTCLibraryNames[] - { - // As a catch-all for non-Windows platforms, we search for - // a library simply named `nvrtc` (well, `libnvrtc`) which - // is expected to match whatever the user has installed. - // - - // A list of versions is here - // https://developer.nvidia.com/cuda-toolkit-archive - - "nvrtc", - - "nvrtc64_112_0", - - "nvrtc64_111_1", - "nvrtc64_111_0", - - "nvrtc64_110_0", - "nvrtc64_102_0", - "nvrtc64_101_0", - "nvrtc64_100_0", - "nvrtc64_92", - "nvrtc64_91", - "nvrtc64_90", - "nvrtc64_80", - "nvrtc64_75", - }; - - SlangResult result = SLANG_FAIL; - for( auto libraryName : kNVRTCLibraryNames ) + // On Windows an installation could place the version of nvrtc it uses in the same directory + // as the slang binary, such that it's loaded. + // Using this name also allows a ISlangSharedLibraryLoader to easily identify what is required + // and perhaps load a specific version + if (SLANG_FAILED(loader->loadSharedLibrary("nvrtc", library.writeRef()))) { - // If we succeed at loading one of the library versions - // from our list, we will not continue to search; this - // approach assumes that the `kNVRTCLibraryNames` array - // has been sorted so that earlier entries are preferable. - // - result = loader->loadSharedLibrary(libraryName, library.writeRef()); - if(!SLANG_FAILED(result)) - break; + // Try something more sophisticated to locate NVRTC + SLANG_RETURN_ON_FAIL(_findAndLoadNVRTC(loader, library)); } + } - // If we tried to load all of the candidate versions and none - // was successful, then we report back a failure. - // - if(SLANG_FAILED(result)) - return result; + SLANG_ASSERT(library); + if (!library) + { + return SLANG_FAIL; } RefPtr<NVRTCDownstreamCompiler> compiler(new NVRTCDownstreamCompiler); |
