diff options
Diffstat (limited to 'source/core')
| -rw-r--r-- | source/core/slang-io.cpp | 19 | ||||
| -rw-r--r-- | source/core/slang-io.h | 11 | ||||
| -rw-r--r-- | source/core/slang-nvrtc-compiler.cpp | 315 | ||||
| -rw-r--r-- | source/core/slang-platform.cpp | 22 | ||||
| -rw-r--r-- | source/core/slang-platform.h | 5 | ||||
| -rw-r--r-- | source/core/slang-string.cpp | 62 | ||||
| -rw-r--r-- | source/core/slang-string.h | 56 |
7 files changed, 393 insertions, 97 deletions
diff --git a/source/core/slang-io.cpp b/source/core/slang-io.cpp index d4f603109..0755a6a2f 100644 --- a/source/core/slang-io.cpp +++ b/source/core/slang-io.cpp @@ -175,9 +175,9 @@ namespace Slang return sb.ProduceString(); } - /* static */ Index Path::findLastSeparatorIndex(String const& path) + /* static */ Index Path::findLastSeparatorIndex(UnownedStringSlice const& path) { - const char* chars = path.getBuffer(); + const char* chars = path.begin(); for (Index i = path.getLength() - 1; i >= 0; --i) { const char c = chars[i]; @@ -189,7 +189,7 @@ namespace Slang return -1; } - /* static */Index Path::findExtIndex(String const& path) + /* static */Index Path::findExtIndex(UnownedStringSlice const& path) { const Index sepIndex = findLastSeparatorIndex(path); @@ -238,13 +238,22 @@ namespace Slang return path; } - String Path::getPathExt(const String& path) + UnownedStringSlice Path::getPathExt(const UnownedStringSlice& path) { const Index dotPos = findExtIndex(path); if (dotPos >= 0) + { return path.subString(dotPos + 1, path.getLength() - dotPos - 1); + } else - return ""; + { + // Note that the caller can identify if path has no extension or just a . + // as if it's a dot a zero length slice is returned in path + // If it's not then a default slice is returned (which doesn't point into path). + // + // Granted this is a little obscure and perhaps should be improved. + return UnownedStringSlice(); + } } String Path::getParentDirectory(const String& path) diff --git a/source/core/slang-io.h b/source/core/slang-io.h index ac3156b8a..a077d416b 100644 --- a/source/core/slang-io.h +++ b/source/core/slang-io.h @@ -68,14 +68,19 @@ namespace Slang static SlangResult find(const String& directoryPath, const char* pattern, Visitor* visitor); /// Returns -1 if no separator is found - static Index findLastSeparatorIndex(String const& path); + static Index findLastSeparatorIndex(String const& path) { return findLastSeparatorIndex(path.getUnownedSlice()); } + static Index findLastSeparatorIndex(UnownedStringSlice const& path); /// Finds the index of the last dot in a path, else returns -1 - static Index findExtIndex(String const& path); + static Index findExtIndex(String const& path) { return findExtIndex(path.getUnownedSlice()); } + static Index findExtIndex(UnownedStringSlice const& path); static String replaceExt(const String& path, const char* newExt); static String getFileName(const String& path); static String getPathWithoutExt(const String& path); - static String getPathExt(const String& path); + + static String getPathExt(const String& path) { return getPathExt(path.getUnownedSlice()); } + static UnownedStringSlice getPathExt(const UnownedStringSlice& path); + static String getParentDirectory(const String& path); static String getFileNameWithoutExt(const String& path); diff --git a/source/core/slang-nvrtc-compiler.cpp b/source/core/slang-nvrtc-compiler.cpp index 1bdb4dfa7..78afc618f 100644 --- a/source/core/slang-nvrtc-compiler.cpp +++ b/source/core/slang-nvrtc-compiler.cpp @@ -7,11 +7,13 @@ #include "../core/slang-blob.h" #include "slang-string-util.h" +#include "slang-string-slice-pool.h" #include "slang-io.h" #include "slang-shared-library.h" #include "slang-semantic-version.h" + namespace nvrtc { @@ -484,73 +486,280 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP return SLANG_OK; } +/* An implementation of Path::Visitor that can be used for finding NVRTC shared library installations. */ +struct NVRTCPathVisitor : Path::Visitor +{ + struct Candidate + { + typedef Candidate ThisType; + + bool operator==(const ThisType& rhs) const { return path == rhs.path && version == rhs.version; } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + + static Candidate make(const String& path, const SemanticVersion& version) + { + Candidate can; + can.version = version; + can.path = path; + return can; + } + String path; + SemanticVersion version; + }; + + Index findVersion(const SemanticVersion& version) const + { + const Index count = m_candidates.getCount(); + for (Index i = 0; i < count; ++i) + { + if (m_candidates[i].version == version) + { + return i; + } + } + return -1; + } + + static bool _orderCandiate(const Candidate& a, const Candidate& b) { return a.version < b.version; } + void sortCandidates() { m_candidates.sort(_orderCandiate); } + + void accept(Path::Type type, const UnownedStringSlice& filename) SLANG_OVERRIDE + { + // Lets make sure it start's with nvrtc64, but not worry about case + if (type == Path::Type::File) + { + // If there is a defined extension, make sure it has it + if (m_postfix.getLength() && filename.getLength() >= m_postfix.getLength()) + { + // We test without case - really for windows + UnownedStringSlice filenamePostfix = filename.tail(filename.getLength() - m_postfix.getLength()); + if (!filenamePostfix.caseInsensitiveEquals(m_postfix.getUnownedSlice())) + { + return; + } + } + + if (filename.getLength() >= m_prefix.getLength() && + filename.subString(0, m_prefix.getLength()).caseInsensitiveEquals(m_prefix.getUnownedSlice())) + { + // Versions are typically (on windows) of the form + // nvrtc64_110_2.dll + // 11 - Major + // 0 Minor + // 2 Patch + Index endIndex = filename.indexOf('.'); + endIndex = (endIndex < 0) ? filename.getLength() : endIndex; + + UnownedStringSlice versionSlice = UnownedStringSlice(filename.begin() + m_prefix.getLength(), filename.begin() + endIndex); + + Int patch = 0; + UnownedStringSlice majorMinorSlice; + { + List<UnownedStringSlice> slices; + StringUtil::split(versionSlice, '_', slices); + if (slices.getCount() >= 2) + { + // We don't bother checking for error here, if it's not parsable, it will be 0 + StringUtil::parseInt(slices[1], patch); + } + majorMinorSlice = slices[0]; + } + + if (majorMinorSlice.getLength() < 2) + { + // Must be a major and minor + return; + } + + UnownedStringSlice majorSlice = majorMinorSlice.head(majorMinorSlice.getLength() - 1); + UnownedStringSlice minorSlice = majorMinorSlice.subString(majorMinorSlice.getLength() - 1, 1); + + Int major; + Int minor; + + if (SLANG_FAILED(StringUtil::parseInt(majorSlice, major)) || + SLANG_FAILED(StringUtil::parseInt(minorSlice, minor))) + { + return; + } + + const SemanticVersion version = SemanticVersion(int(major), int(minor), int(patch)); + + // We may want to add multiple versions, if they are in different locations - as there may be multiple entries + // in the PATH, and only one works. We'll only know which works by loading +#if 0 + // We already found this version, so let's not add it again + if (findVersion(version) >= 0) + { + return; + } +#endif + + // Strip to make a shared library name + UnownedStringSlice sharedLibraryName = filename.tail(m_prefix.getLength() - m_sharedLibraryStem.getLength()); + sharedLibraryName = filename.head(filename.getLength() - m_postfix.getLength()); + + auto candidate = Candidate::make(Path::combine(m_basePath, sharedLibraryName), version); + + // If we already have this candidate, then skip + if (m_candidates.indexOf(candidate) >= 0) + { + return; + } + + // Add to the list of candidates + m_candidates.add(candidate); + } + } + } + + SlangResult findInDirectory(const String& path) + { + m_basePath = path; + return Path::find(path, nullptr, this); + } + + bool hasCandidates() const { return m_candidates.getCount() > 0; } + + NVRTCPathVisitor(const UnownedStringSlice& sharedLibraryStem): + m_sharedLibraryStem(sharedLibraryStem) + { + // Work out the prefix and postfix of the shader + StringBuilder buf; + SharedLibrary::appendPlatformFileName(sharedLibraryStem, buf); + const Index index = buf.indexOf(sharedLibraryStem); + SLANG_ASSERT(index >= 0); + + m_prefix = buf.getUnownedSlice().head(index + sharedLibraryStem.getLength()); + m_postfix = buf.getUnownedSlice().tail(index + sharedLibraryStem.getLength()); + } + + String m_prefix; + String m_postfix; + String m_basePath; + String m_sharedLibraryStem; + + List<Candidate> m_candidates; +}; + +static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<ISlangSharedLibrary>& outLibrary) +{ +#if SLANG_WINDOWS_FAMILY + // We only need to search 64 bit versions on windows + NVRTCPathVisitor visitor(UnownedStringSlice::fromLiteral("nvrtc64_")); + + // First try the instance path (if supported on platform) + { + StringBuilder instancePath; + if (SLANG_SUCCEEDED(PlatformUtil::getInstancePath(instancePath))) + { + visitor.findInDirectory(instancePath); + } + } + + // If we don't have a candidate try CUDA_PATH + if (!visitor.hasCandidates()) + { + StringBuilder buf; + if (!SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), buf))) + { + // Look for candidates in the directory + visitor.findInDirectory(Path::combine(buf, "bin")); + } + } + + // If we haven't we go searching through PATH + if (!visitor.hasCandidates()) + { + List<UnownedStringSlice> splitPath; + + StringBuilder buf; + if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("PATH"), buf))) + { + // Split so we get individual paths + List<UnownedStringSlice> paths; + StringUtil::split(buf.getUnownedSlice(), ';', paths); + + // We use a pool to make sure we only check each path once + StringSlicePool pool(StringSlicePool::Style::Empty); + + // We are going to search the paths in order + for (const auto path : paths) + { + // PATH can have the same path multiple times. If we have already searched this path, we don't need to again + if (!pool.has(path)) + { + pool.add(path); + + Path::split(path, splitPath); + + // We could search every path, but here we restrict to paths that look like CUDA installations. + // It's a path that contains a CUDA directory and has bin + if (splitPath.indexOf("CUDA") >= 0 && splitPath[splitPath.getCount() - 1].caseInsensitiveEquals(UnownedStringSlice::fromLiteral("bin"))) + { + // Okay lets search it + visitor.findInDirectory(path); + } + } + } + } + } + + // Put into version order with oldest first. + visitor.sortCandidates(); + + // We want to start with the newest version... + for (Index i = visitor.m_candidates.getCount() - 1; i >= 0; --i) + { + const auto& candidate = visitor.m_candidates[i]; + if (SLANG_SUCCEEDED(loader->loadSharedLibrary(candidate.path.getBuffer(), outLibrary.writeRef()))) + { + return SLANG_OK; + } + } +#endif + // This is an official-ish list of versions is here: + // https://developer.nvidia.com/cuda-toolkit-archive + + // Filenames for NVRTC + // https://docs.nvidia.com/cuda/nvrtc/index.html + // + // From this it appears on platforms other than windows the SharedLibrary name + // should be nvrtc which is already tried, so we can give up now. + return SLANG_E_NOT_FOUND; +} + /* static */SlangResult NVRTCDownstreamCompilerUtil::locateCompilers(const String& path, ISlangSharedLibraryLoader* loader, DownstreamCompilerSet* set) { ComPtr<ISlangSharedLibrary> library; + // If the user supplies a path to their preferred version of NVRTC, + // we just use this. if (path.getLength() != 0) { SLANG_RETURN_ON_FAIL(loader->loadSharedLibrary(path.getBuffer(), library.writeRef())); } else { - // If the user doesn't supply a path to their preferred version of NVRTC, - // we will search for a suitable library version, proceeding from more - // recent versions to less recent ones. - // - // TODO: The list here was cobbled together from what NRTC releases I - // could easily identify. It would be good to ver this against some - // kind of official list. + // As a catch-all for non-Windows platforms, we search for + // a library simply named `nvrtc` (well, `libnvrtc`) which + // is expected to match whatever the user has installed. // - // It would probably be good to support 32- and 64-bit here, and also - // to deal with any variation in the shared library name across platforms - // - static const char* kNVRTCLibraryNames[] - { - // As a catch-all for non-Windows platforms, we search for - // a library simply named `nvrtc` (well, `libnvrtc`) which - // is expected to match whatever the user has installed. - // - - // A list of versions is here - // https://developer.nvidia.com/cuda-toolkit-archive - - "nvrtc", - - "nvrtc64_112_0", - - "nvrtc64_111_1", - "nvrtc64_111_0", - - "nvrtc64_110_0", - "nvrtc64_102_0", - "nvrtc64_101_0", - "nvrtc64_100_0", - "nvrtc64_92", - "nvrtc64_91", - "nvrtc64_90", - "nvrtc64_80", - "nvrtc64_75", - }; - - SlangResult result = SLANG_FAIL; - for( auto libraryName : kNVRTCLibraryNames ) + // On Windows an installation could place the version of nvrtc it uses in the same directory + // as the slang binary, such that it's loaded. + // Using this name also allows a ISlangSharedLibraryLoader to easily identify what is required + // and perhaps load a specific version + if (SLANG_FAILED(loader->loadSharedLibrary("nvrtc", library.writeRef()))) { - // If we succeed at loading one of the library versions - // from our list, we will not continue to search; this - // approach assumes that the `kNVRTCLibraryNames` array - // has been sorted so that earlier entries are preferable. - // - result = loader->loadSharedLibrary(libraryName, library.writeRef()); - if(!SLANG_FAILED(result)) - break; + // Try something more sophisticated to locate NVRTC + SLANG_RETURN_ON_FAIL(_findAndLoadNVRTC(loader, library)); } + } - // If we tried to load all of the candidate versions and none - // was successful, then we report back a failure. - // - if(SLANG_FAILED(result)) - return result; + SLANG_ASSERT(library); + if (!library) + { + return SLANG_FAIL; } RefPtr<NVRTCDownstreamCompiler> compiler(new NVRTCDownstreamCompiler); diff --git a/source/core/slang-platform.cpp b/source/core/slang-platform.cpp index 10b9b576b..20dfed5a8 100644 --- a/source/core/slang-platform.cpp +++ b/source/core/slang-platform.cpp @@ -66,6 +66,19 @@ SLANG_COMPILE_TIME_ASSERT(E_NOTIMPL == SLANG_E_NOT_IMPLEMENTED); SLANG_COMPILE_TIME_ASSERT(E_INVALIDARG == SLANG_E_INVALID_ARG); SLANG_COMPILE_TIME_ASSERT(E_OUTOFMEMORY == SLANG_E_OUT_OF_MEMORY); +/* static */SlangResult PlatformUtil::getInstancePath(StringBuilder& out) +{ + wchar_t path[_MAX_PATH]; + ::GetModuleFileName(::GetModuleHandle(NULL), path, SLANG_COUNT_OF(path)); + String pathString = String::fromWString(path); + + // We don't want the instance name, just the path to it + out.Clear(); + out.append(Path::getParentDirectory(pathString)); + + return out.getLength() > 0 ? SLANG_OK : SLANG_FAIL; +} + /* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) { if (SLANG_FAILED(res) && res != SLANG_FAIL) @@ -142,6 +155,13 @@ SLANG_COMPILE_TIME_ASSERT(E_OUTOFMEMORY == SLANG_E_OUT_OF_MEMORY); #else // _WIN32 +/* static */SlangResult PlatformUtil::getInstancePath(StringBuilder& out) +{ + // On non Windows it's typically hard to get the instance path, so we'll say not implemented. + // The meaning is also somewhat more ambiguous - is it the exe or the shared library path? + return SLANG_E_NOT_IMPLEMENTED; +} + /* static */SlangResult PlatformUtil::appendResult(SlangResult res, StringBuilder& builderOut) { return SLANG_E_NOT_IMPLEMENTED; @@ -260,4 +280,6 @@ static const PlatformFlags s_familyFlags[int(PlatformFamily::CountOf)] = return s_familyFlags[int(family)]; } + + } diff --git a/source/core/slang-platform.h b/source/core/slang-platform.h index 804e2b773..ff7f1ccd9 100644 --- a/source/core/slang-platform.h +++ b/source/core/slang-platform.h @@ -106,6 +106,7 @@ namespace Slang static String calcPlatformPath(const UnownedStringSlice& path); static void calcPlatformPath(const UnownedStringSlice& path, StringBuilder& outBuilder); + private: /// Not constructible! SharedLibrary(); @@ -131,6 +132,10 @@ namespace Slang /// Given an environment name returns the set system variable. /// Will return SLANG_E_NOT_FOUND if the variable is not set static SlangResult getEnvironmentVariable(const UnownedStringSlice& name, StringBuilder& out); + + /// Get the path to this instance (the path to the dll/executable/shared library the call is in) + /// NOTE! This is not supported on all platforms, and will return SLANG_E_NOT_IMPLEMENTED in that scenario + static SlangResult getInstancePath(StringBuilder& out); }; #ifndef _MSC_VER diff --git a/source/core/slang-string.cpp b/source/core/slang-string.cpp index 975c83315..6d06eb1c6 100644 --- a/source/core/slang-string.cpp +++ b/source/core/slang-string.cpp @@ -492,4 +492,66 @@ namespace Slang return -1; } + UnownedStringSlice UnownedStringSlice::subString(Index idx, Index len) const + { + const Index totalLen = getLength(); + SLANG_ASSERT(idx >= 0 && len >= 0 && idx <= totalLen); + + // If too large, we truncate + len = (idx + len > totalLen) ? (totalLen - idx) : len; + + // Return the substring + return UnownedStringSlice(m_begin + idx, m_begin + idx + len); + } + + bool UnownedStringSlice::operator==(ThisType const& other) const + { + // Note that memcmp is undefined when passed in null ptrs, so if we want to handle + // we need to cover that case. + // Can only be nullptr if size is 0. + auto thisSize = getLength(); + auto otherSize = other.getLength(); + + if (thisSize != otherSize) + { + return false; + } + + const char*const thisChars = begin(); + const char*const otherChars = other.begin(); + if (thisChars == otherChars || thisSize == 0) + { + return true; + } + SLANG_ASSERT(thisChars && otherChars); + return memcmp(thisChars, otherChars, thisSize) == 0; + } + + bool UnownedStringSlice::caseInsensitiveEquals(const ThisType& rhs) const + { + const auto length = getLength(); + if (length != rhs.getLength()) + { + return false; + } + + const char* a = m_begin; + const char* b = rhs.m_begin; + + // Assuming this is a faster test + if (memcmp(a, b, length) != 0) + { + // They aren't identical so compare character by character + for (Index i = 0; i < length; ++i) + { + if (CharUtil::toLower(a[i]) != CharUtil::toLower(b[i])) + { + return false; + } + } + } + + return true; + } + } diff --git a/source/core/slang-string.h b/source/core/slang-string.h index cec4b0a09..e57718d40 100644 --- a/source/core/slang-string.h +++ b/source/core/slang-string.h @@ -65,6 +65,8 @@ namespace Slang struct UnownedStringSlice { public: + typedef UnownedStringSlice ThisType; + UnownedStringSlice() : m_begin(nullptr) , m_end(nullptr) @@ -109,6 +111,20 @@ namespace Slang /// Find first index of slice. If not found returns -1 Index indexOf(const UnownedStringSlice& slice) const; + /// Returns a substring. idx is the start index, and len + /// is the amount of characters. + /// The returned length might be truncated, if len extends beyond slice. + UnownedStringSlice subString(Index idx, Index len) const; + + /// Return a head of the slice - everything up to the index + SLANG_FORCE_INLINE UnownedStringSlice head(Index idx) const { SLANG_ASSERT(idx >= 0 && idx <= getLength()); return UnownedStringSlice(m_begin, idx); } + /// Return a tail of the slice - everything from the index to the end of the slice + SLANG_FORCE_INLINE UnownedStringSlice tail(Index idx) const { SLANG_ASSERT(idx >= 0 && idx <= getLength()); return UnownedStringSlice(m_begin + idx, m_end); } + + /// True if rhs and this are equal without having to take into account case + /// Note 'case' here is *not* locale specific - it is only A-Z and a-z + bool caseInsensitiveEquals(const ThisType& rhs) const; + Index lastIndexOf(char c) const { const Index size = Index(m_end - m_begin); @@ -128,43 +144,11 @@ namespace Slang return m_begin[i]; } - bool operator==(UnownedStringSlice const& other) const - { - // Note that memcmp is undefined when passed in null ptrs, so if we want to handle - // we need to cover that case. - // Can only be nullptr if size is 0. - auto thisSize = getLength(); - auto otherSize = other.getLength(); + bool operator==(ThisType const& other) const; + bool operator!=(UnownedStringSlice const& other) const { return !(*this == other); } - if (thisSize != otherSize) - { - return false; - } - - const char*const thisChars = begin(); - const char*const otherChars = other.begin(); - if (thisChars == otherChars || thisSize == 0) - { - return true; - } - SLANG_ASSERT(thisChars && otherChars); - return memcmp(thisChars, otherChars, thisSize) == 0; - } - - bool operator==(char const* str) const - { - return (*this) == UnownedStringSlice(str, str + ::strlen(str)); - } - - bool operator!=(UnownedStringSlice const& other) const - { - return !(*this == other); - } - - bool operator!=(char const* str) const - { - return (*this) != UnownedStringSlice(str, str + ::strlen(str)); - } + bool operator==(char const* str) const { return (*this) == UnownedStringSlice(str); } + bool operator!=(char const* str) const { return !(*this == str); } /// True if contents is a single char of c SLANG_FORCE_INLINE bool isChar(char c) const { return getLength() == 1 && m_begin[0] == c; } |
