summaryrefslogtreecommitdiff
path: root/source/core/slang-nvrtc-compiler.cpp
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2021-01-26 12:15:08 -0500
committerGitHub <noreply@github.com>2021-01-26 09:15:08 -0800
commit798d7731eca286df456bc2ec56c0695ba006b472 (patch)
tree37ced2db457a08aa8cfc81f19f18daf9ca26d3f2 /source/core/slang-nvrtc-compiler.cpp
parent00fad59d49d31538270b811903aeb449c97ca152 (diff)
Improved NVRTC location finding (#1674)
* #include an absolute path didn't work - because paths were taken to always be relative. * WIP more sophisticated mechanism to find NVRTC. * Improve nvrtc searching to include PATH. * Make getting an extension able to differentiate between no extension, and just a . * Add comment. * Add support for searching instance path. * Small improvements around scope and finding NVRTC. * Improve documentation around NVRTC loading.
Diffstat (limited to 'source/core/slang-nvrtc-compiler.cpp')
-rw-r--r--source/core/slang-nvrtc-compiler.cpp315
1 files changed, 262 insertions, 53 deletions
diff --git a/source/core/slang-nvrtc-compiler.cpp b/source/core/slang-nvrtc-compiler.cpp
index 1bdb4dfa7..78afc618f 100644
--- a/source/core/slang-nvrtc-compiler.cpp
+++ b/source/core/slang-nvrtc-compiler.cpp
@@ -7,11 +7,13 @@
#include "../core/slang-blob.h"
#include "slang-string-util.h"
+#include "slang-string-slice-pool.h"
#include "slang-io.h"
#include "slang-shared-library.h"
#include "slang-semantic-version.h"
+
namespace nvrtc
{
@@ -484,73 +486,280 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP
return SLANG_OK;
}
+/* An implementation of Path::Visitor that can be used for finding NVRTC shared library installations.
+
+The visitor is constructed with a shared library 'stem' (for example `nvrtc64_`). From the stem it derives
+the platform specific filename prefix and postfix via SharedLibrary::appendPlatformFileName. Every file
+visited whose name has that prefix/postfix (compared ignoring case) and a parsable version is recorded in
+m_candidates, together with the version parsed from the filename.
+
+Typical usage is to call findInDirectory for one or more directories, then sortCandidates, and then try
+loading the candidates in the preferred order. */
+struct NVRTCPathVisitor : Path::Visitor
+{
+    /* A shared library found during a search - the path that can be used to load it, and the version
+    that was parsed from its filename. */
+    struct Candidate
+    {
+        typedef Candidate ThisType;
+
+        /// Candidates are equal only if both path and version match
+        bool operator==(const ThisType& rhs) const { return path == rhs.path && version == rhs.version; }
+        bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
+
+        /// Make a candidate from a path and a version
+        static Candidate make(const String& path, const SemanticVersion& version)
+        {
+            Candidate can;
+            can.version = version;
+            can.path = path;
+            return can;
+        }
+        String path;                ///< Path (without platform prefix/postfix on the filename) used to load the library
+        SemanticVersion version;    ///< Version parsed from the filename
+    };
+
+    /// Returns the index of the first candidate with the specified version, or -1 if there isn't one
+    Index findVersion(const SemanticVersion& version) const
+    {
+        const Index count = m_candidates.getCount();
+        for (Index i = 0; i < count; ++i)
+        {
+            if (m_candidates[i].version == version)
+            {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    /// Ordering predicate - a comes before b if it has the lower version
+    static bool _orderCandiate(const Candidate& a, const Candidate& b) { return a.version < b.version; }
+    /// Sort the candidates by version using _orderCandiate
+    void sortCandidates() { m_candidates.sort(_orderCandiate); }
+
+    /// Called for each entry found in a directory. Adds a candidate if the entry is a file
+    /// that looks like a versioned NVRTC shared library.
+    void accept(Path::Type type, const UnownedStringSlice& filename) SLANG_OVERRIDE
+    {
+        // Let's make sure it starts with the prefix (say nvrtc64_), but not worry about case
+        if (type == Path::Type::File)
+        {
+            // If there is a defined extension, make sure it has it
+            if (m_postfix.getLength())
+            {
+                // A filename shorter than the postfix cannot possibly end with it
+                if (filename.getLength() < m_postfix.getLength())
+                {
+                    return;
+                }
+
+                // We test without case - really for windows
+                UnownedStringSlice filenamePostfix = filename.tail(filename.getLength() - m_postfix.getLength());
+                if (!filenamePostfix.caseInsensitiveEquals(m_postfix.getUnownedSlice()))
+                {
+                    return;
+                }
+            }
+
+            if (filename.getLength() >= m_prefix.getLength() &&
+                filename.subString(0, m_prefix.getLength()).caseInsensitiveEquals(m_prefix.getUnownedSlice()))
+            {
+                // Versions are typically (on windows) of the form
+                // nvrtc64_110_2.dll
+                // 11 - Major
+                // 0 Minor
+                // 2 Patch
+                Index endIndex = filename.indexOf('.');
+                endIndex = (endIndex < 0) ? filename.getLength() : endIndex;
+
+                UnownedStringSlice versionSlice = UnownedStringSlice(filename.begin() + m_prefix.getLength(), filename.begin() + endIndex);
+
+                Int patch = 0;
+                UnownedStringSlice majorMinorSlice;
+                {
+                    List<UnownedStringSlice> slices;
+                    StringUtil::split(versionSlice, '_', slices);
+                    if (slices.getCount() >= 2)
+                    {
+                        // We don't bother checking for error here, if it's not parsable, it will be 0
+                        StringUtil::parseInt(slices[1], patch);
+                    }
+                    // Guard the access - an empty version string may produce no slices at all.
+                    // If so majorMinorSlice remains empty and is rejected below.
+                    if (slices.getCount() > 0)
+                    {
+                        majorMinorSlice = slices[0];
+                    }
+                }
+
+                if (majorMinorSlice.getLength() < 2)
+                {
+                    // Must be a major and minor
+                    return;
+                }
+
+                // The last digit is the minor version, everything before it is the major version
+                UnownedStringSlice majorSlice = majorMinorSlice.head(majorMinorSlice.getLength() - 1);
+                UnownedStringSlice minorSlice = majorMinorSlice.subString(majorMinorSlice.getLength() - 1, 1);
+
+                Int major = 0;
+                Int minor = 0;
+
+                if (SLANG_FAILED(StringUtil::parseInt(majorSlice, major)) ||
+                    SLANG_FAILED(StringUtil::parseInt(minorSlice, minor)))
+                {
+                    return;
+                }
+
+                const SemanticVersion version = SemanticVersion(int(major), int(minor), int(patch));
+
+                // We may want to add multiple versions, if they are in different locations - as there may be multiple entries
+                // in the PATH, and only one works. We'll only know which works by loading
+#if 0
+                // We already found this version, so let's not add it again
+                if (findVersion(version) >= 0)
+                {
+                    return;
+                }
+#endif
+
+                // Strip to make a shared library name.
+                // First remove any platform specific prefix that comes before the stem (for example `lib`)...
+                UnownedStringSlice sharedLibraryName = filename.tail(m_prefix.getLength() - m_sharedLibraryStem.getLength());
+                // ... and then remove the platform postfix (for example `.dll`) from the *stripped* name.
+                sharedLibraryName = sharedLibraryName.head(sharedLibraryName.getLength() - m_postfix.getLength());
+
+                auto candidate = Candidate::make(Path::combine(m_basePath, sharedLibraryName), version);
+
+                // If we already have this candidate, then skip
+                if (m_candidates.indexOf(candidate) >= 0)
+                {
+                    return;
+                }
+
+                // Add to the list of candidates
+                m_candidates.add(candidate);
+            }
+        }
+    }
+
+    /// Search the directory `path` for candidates. Candidates found are added to m_candidates
+    /// with paths relative to `path`. Returns the result of Path::find.
+    SlangResult findInDirectory(const String& path)
+    {
+        m_basePath = path;
+        return Path::find(path, nullptr, this);
+    }
+
+    /// True if any candidates have been found
+    bool hasCandidates() const { return m_candidates.getCount() > 0; }
+
+    /// Ctor. sharedLibraryStem is the start of the shared library name without any platform
+    /// specific decoration (for example `nvrtc64_`)
+    NVRTCPathVisitor(const UnownedStringSlice& sharedLibraryStem):
+        m_sharedLibraryStem(sharedLibraryStem)
+    {
+        // Work out the prefix and postfix of the shared library, by seeing how the platform
+        // decorates the stem
+        StringBuilder buf;
+        SharedLibrary::appendPlatformFileName(sharedLibraryStem, buf);
+        const Index index = buf.indexOf(sharedLibraryStem);
+        SLANG_ASSERT(index >= 0);
+
+        // The prefix is everything up to and including the stem, the postfix everything after it
+        m_prefix = buf.getUnownedSlice().head(index + sharedLibraryStem.getLength());
+        m_postfix = buf.getUnownedSlice().tail(index + sharedLibraryStem.getLength());
+    }
+
+    String m_prefix;                ///< Platform filename prefix including the stem
+    String m_postfix;               ///< Platform filename postfix (typically the extension)
+    String m_basePath;              ///< The directory currently being searched
+    String m_sharedLibraryStem;     ///< The stem passed to the ctor
+
+    List<Candidate> m_candidates;   ///< All of the candidates found so far
+};
+
+/* Try to locate and load an NVRTC shared library by searching likely locations.
+
+On the Windows family of targets the search order is
+ 1) The path returned by PlatformUtil::getInstancePath (if available)
+ 2) The `bin` directory in the path held in the CUDA_PATH environment variable
+ 3) Directories in the PATH environment variable that contain a `CUDA` element and end in `bin`
+
+A later location is only searched if no candidates have been found so far. The candidates found are
+then tried from the newest version to the oldest, and the first that loads is used.
+
+On other targets no search is performed.
+
+@param loader The loader used to try and load a candidate
+@param outLibrary On success holds the loaded library
+@return SLANG_OK if a library was loaded, otherwise SLANG_E_NOT_FOUND */
+static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<ISlangSharedLibrary>& outLibrary)
+{
+#if SLANG_WINDOWS_FAMILY
+    // We only need to search 64 bit versions on windows
+    NVRTCPathVisitor visitor(UnownedStringSlice::fromLiteral("nvrtc64_"));
+
+    // First try the instance path (if supported on platform)
+    {
+        StringBuilder instancePath;
+        if (SLANG_SUCCEEDED(PlatformUtil::getInstancePath(instancePath)))
+        {
+            visitor.findInDirectory(instancePath);
+        }
+    }
+
+    // If we don't have a candidate try CUDA_PATH
+    if (!visitor.hasCandidates())
+    {
+        StringBuilder buf;
+        // Only search if the environment variable could actually be read
+        if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), buf)))
+        {
+            // Look for candidates in the directory
+            visitor.findInDirectory(Path::combine(buf, "bin"));
+        }
+    }
+
+    // If we still haven't found a candidate, we go searching through PATH
+    if (!visitor.hasCandidates())
+    {
+        List<UnownedStringSlice> splitPath;
+
+        StringBuilder buf;
+        if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("PATH"), buf)))
+        {
+            // Split so we get individual paths
+            List<UnownedStringSlice> paths;
+            StringUtil::split(buf.getUnownedSlice(), ';', paths);
+
+            // We use a pool to make sure we only check each path once
+            StringSlicePool pool(StringSlicePool::Style::Empty);
+
+            // We are going to search the paths in order
+            for (const auto& path : paths)
+            {
+                // PATH can have the same path multiple times. If we have already searched this path, we don't need to again
+                if (!pool.has(path))
+                {
+                    pool.add(path);
+
+                    Path::split(path, splitPath);
+
+                    // We could search every path, but here we restrict to paths that look like CUDA installations.
+                    // It's a path that contains a CUDA directory and has bin as its last element.
+                    // NOTE: The indexOf test must come first - it fails if splitPath is empty, such that
+                    // the last element is only accessed when there is one.
+                    if (splitPath.indexOf("CUDA") >= 0 && splitPath[splitPath.getCount() - 1].caseInsensitiveEquals(UnownedStringSlice::fromLiteral("bin")))
+                    {
+                        // Okay lets search it
+                        visitor.findInDirectory(path);
+                    }
+                }
+            }
+        }
+    }
+
+    // Put into version order with oldest first.
+    visitor.sortCandidates();
+
+    // We want to start with the newest version...
+    for (Index i = visitor.m_candidates.getCount() - 1; i >= 0; --i)
+    {
+        const auto& candidate = visitor.m_candidates[i];
+        if (SLANG_SUCCEEDED(loader->loadSharedLibrary(candidate.path.getBuffer(), outLibrary.writeRef())))
+        {
+            return SLANG_OK;
+        }
+    }
+#endif
+    // An official-ish list of versions is here:
+    // https://developer.nvidia.com/cuda-toolkit-archive
+
+    // Filenames for NVRTC
+    // https://docs.nvidia.com/cuda/nvrtc/index.html
+    //
+    // From this it appears on platforms other than windows the SharedLibrary name
+    // should be nvrtc which is already tried, so we can give up now.
+    return SLANG_E_NOT_FOUND;
+}
+
/* static */SlangResult NVRTCDownstreamCompilerUtil::locateCompilers(const String& path, ISlangSharedLibraryLoader* loader, DownstreamCompilerSet* set)
{
ComPtr<ISlangSharedLibrary> library;
+ // If the user supplies a path to their preferred version of NVRTC,
+ // we just use this.
if (path.getLength() != 0)
{
SLANG_RETURN_ON_FAIL(loader->loadSharedLibrary(path.getBuffer(), library.writeRef()));
}
else
{
- // If the user doesn't supply a path to their preferred version of NVRTC,
- // we will search for a suitable library version, proceeding from more
- // recent versions to less recent ones.
- //
- // TODO: The list here was cobbled together from what NRTC releases I
- // could easily identify. It would be good to ver this against some
- // kind of official list.
+ // As a catch-all for non-Windows platforms, we search for
+ // a library simply named `nvrtc` (well, `libnvrtc`) which
+ // is expected to match whatever the user has installed.
//
- // It would probably be good to support 32- and 64-bit here, and also
- // to deal with any variation in the shared library name across platforms
- //
- static const char* kNVRTCLibraryNames[]
- {
- // As a catch-all for non-Windows platforms, we search for
- // a library simply named `nvrtc` (well, `libnvrtc`) which
- // is expected to match whatever the user has installed.
- //
-
- // A list of versions is here
- // https://developer.nvidia.com/cuda-toolkit-archive
-
- "nvrtc",
-
- "nvrtc64_112_0",
-
- "nvrtc64_111_1",
- "nvrtc64_111_0",
-
- "nvrtc64_110_0",
- "nvrtc64_102_0",
- "nvrtc64_101_0",
- "nvrtc64_100_0",
- "nvrtc64_92",
- "nvrtc64_91",
- "nvrtc64_90",
- "nvrtc64_80",
- "nvrtc64_75",
- };
-
- SlangResult result = SLANG_FAIL;
- for( auto libraryName : kNVRTCLibraryNames )
+ // On Windows an installation could place the version of nvrtc it uses in the same directory
+ // as the slang binary, such that it's loaded.
+ // Using this name also allows a ISlangSharedLibraryLoader to easily identify what is required
+ // and perhaps load a specific version
+ if (SLANG_FAILED(loader->loadSharedLibrary("nvrtc", library.writeRef())))
{
- // If we succeed at loading one of the library versions
- // from our list, we will not continue to search; this
- // approach assumes that the `kNVRTCLibraryNames` array
- // has been sorted so that earlier entries are preferable.
- //
- result = loader->loadSharedLibrary(libraryName, library.writeRef());
- if(!SLANG_FAILED(result))
- break;
+ // Try something more sophisticated to locate NVRTC
+ SLANG_RETURN_ON_FAIL(_findAndLoadNVRTC(loader, library));
}
+ }
- // If we tried to load all of the candidate versions and none
- // was successful, then we report back a failure.
- //
- if(SLANG_FAILED(result))
- return result;
+ SLANG_ASSERT(library);
+ if (!library)
+ {
+ return SLANG_FAIL;
}
RefPtr<NVRTCDownstreamCompiler> compiler(new NVRTCDownstreamCompiler);