summary | refs | log | tree | commit | diff | stats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-downstream-compiler.h13
-rw-r--r--source/compiler-core/slang-nvrtc-compiler.cpp637
-rw-r--r--source/core/slang-shared-library.cpp16
-rw-r--r--source/core/slang-shared-library.h8
-rw-r--r--source/slang/hlsl.meta.slang46
-rwxr-xr-xsource/slang/slang-compiler.cpp5
-rw-r--r--source/slang/slang-emit-cuda.cpp24
-rw-r--r--source/slang/slang-emit-cuda.h14
8 files changed, 517 insertions, 246 deletions
diff --git a/source/compiler-core/slang-downstream-compiler.h b/source/compiler-core/slang-downstream-compiler.h
index 35acc820e..def13993e 100644
--- a/source/compiler-core/slang-downstream-compiler.h
+++ b/source/compiler-core/slang-downstream-compiler.h
@@ -140,9 +140,9 @@ public:
{
enum Enum : SourceLanguageFlags
{
- Unknown = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_UNKNOWN,
- Slang = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_SLANG,
- HLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_HLSL,
+ Unknown = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_UNKNOWN,
+ Slang = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_SLANG,
+ HLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_HLSL,
GLSL = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_GLSL,
C = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_C,
CPP = SourceLanguageFlags(1) << SLANG_SOURCE_LANGUAGE_CPP,
@@ -247,9 +247,10 @@ public:
{
enum Enum : Flags
{
- EnableExceptionHandling = 0x01,
- Verbose = 0x02,
- EnableSecurityChecks = 0x04,
+ EnableExceptionHandling = 0x01, ///< Enables exception handling support (say as optionally supported by C++)
+ Verbose = 0x02, ///< Give more verbose diagnostics
+ EnableSecurityChecks = 0x04, ///< Enable runtime security checks (such as for buffer overruns) - enabling typically decreases performance
+ EnableFloat16 = 0x08, ///< If set compiles with support for float16/half
};
};
diff --git a/source/compiler-core/slang-nvrtc-compiler.cpp b/source/compiler-core/slang-nvrtc-compiler.cpp
index b37743a32..7a8d5bcfd 100644
--- a/source/compiler-core/slang-nvrtc-compiler.cpp
+++ b/source/compiler-core/slang-nvrtc-compiler.cpp
@@ -13,12 +13,11 @@
#include "../core/slang-shared-library.h"
#include "../core/slang-semantic-version.h"
+#include "../core/slang-shared-library.h"
namespace nvrtc
{
-
-
typedef enum {
NVRTC_SUCCESS = 0,
NVRTC_ERROR_OUT_OF_MEMORY = 1,
@@ -124,12 +123,24 @@ protected:
nvrtcProgram m_program;
};
+ SlangResult _findIncludePath(String& outIncludePath);
+
+ SlangResult _getIncludePath(String& outIncludePath);
+
+ SlangResult _maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine);
#define SLANG_NVTRC_MEMBER_FUNCS(ret, name, params) \
ret (*m_##name) params;
SLANG_NVRTC_FUNCS(SLANG_NVTRC_MEMBER_FUNCS);
+ // Holds list of paths passed in where cuda_fp16.h is found. Does *NOT* include cuda_fp16.h.
+ List<String> m_cudaFp16FoundPaths;
+
+ bool m_includeSearched = false;
+ // Holds location of where include (for cuda_fp16.h) is found.
+ String m_includePath;
+
ComPtr<ISlangSharedLibrary> m_sharedLibrary;
};
@@ -232,6 +243,400 @@ static SlangResult _parseNVRTCLine(const UnownedStringSlice& line, DownstreamDia
return SLANG_E_NOT_FOUND;
}
+/* An implementation of Path::Visitor that can be used for finding NVRTC shared library installations.
+
+The platform file name that SharedLibrary::appendPlatformFileName produces for the shared library
+stem is split around the stem: everything up to and including the stem becomes `m_prefix`, and
+everything after it becomes `m_postfix`. A visited file is a match if it starts with the prefix and
+(when there is one) ends with the postfix, both compared without regard to case.
+
+Every match is stored in `m_candidates` as a path suitable for handing to a shared library loader
+(the platform decoration is stripped) together with the version parsed from the filename. */
+struct NVRTCPathVisitor : Path::Visitor
+{
+    /// A located library: the path to load it from, and the version parsed from its filename.
+    struct Candidate
+    {
+        typedef Candidate ThisType;
+
+        bool operator==(const ThisType& rhs) const { return path == rhs.path && version == rhs.version; }
+        bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
+
+        static Candidate make(const String& path, const SemanticVersion& version)
+        {
+            Candidate can;
+            can.version = version;
+            can.path = path;
+            return can;
+        }
+        String path;
+        SemanticVersion version;
+    };
+
+    /// Returns the index of the first candidate whose version equals `version`, or -1 if there is none.
+    Index findVersion(const SemanticVersion& version) const
+    {
+        const Index count = m_candidates.getCount();
+        for (Index i = 0; i < count; ++i)
+        {
+            if (m_candidates[i].version == version)
+            {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    static bool _orderCandiate(const Candidate& a, const Candidate& b) { return a.version < b.version; }
+    /// Sorts the candidates by version, lowest version first.
+    void sortCandidates() { m_candidates.sort(_orderCandiate); }
+
+#if SLANG_WINDOWS_FAMILY
+    /// Parses the version encoded in `filename`, which is expected to start with `m_prefix`.
+    /// Returns SLANG_E_NOT_FOUND if there is no version text, or a failure if it cannot be parsed.
+    SlangResult getVersion(const UnownedStringSlice& filename, SemanticVersion& outVersion)
+    {
+        // Versions on windows of the form
+        // nvrtc64_110_2.dll
+        // 11 - Major
+        // 0 Minor
+        // 2 Patch
+        Index endIndex = filename.indexOf('.');
+        endIndex = (endIndex < 0) ? filename.getLength() : endIndex;
+
+        // The version text sits between the prefix and the first '.' (or the end of the name)
+        UnownedStringSlice versionSlice = UnownedStringSlice(filename.begin() + m_prefix.getLength(), filename.begin() + endIndex);
+
+        if (versionSlice.getLength() <= 0)
+        {
+            return SLANG_E_NOT_FOUND;
+        }
+
+        Int patch = 0;
+        UnownedStringSlice majorMinorSlice;
+        {
+            List<UnownedStringSlice> slices;
+            StringUtil::split(versionSlice, '_', slices);
+            if (slices.getCount() >= 2)
+            {
+                // We don't bother checking for error here, if it's not parsable, it will be 0
+                StringUtil::parseInt(slices[1], patch);
+            }
+            majorMinorSlice = slices[0];
+        }
+
+        if (majorMinorSlice.getLength() < 2)
+        {
+            // Must be a major and minor
+            return SLANG_FAIL;
+        }
+
+        // The last digit is the minor version, everything before it is the major version
+        UnownedStringSlice majorSlice = majorMinorSlice.head(majorMinorSlice.getLength() - 1);
+        UnownedStringSlice minorSlice = majorMinorSlice.subString(majorMinorSlice.getLength() - 1, 1);
+
+        Int major = 0;
+        Int minor = 0;
+
+        SLANG_RETURN_ON_FAIL(StringUtil::parseInt(majorSlice, major));
+        SLANG_RETURN_ON_FAIL(StringUtil::parseInt(minorSlice, minor));
+
+        outVersion = SemanticVersion(int(major), int(minor), int(patch));
+        return SLANG_OK;
+    }
+#else
+    // How the path is constructed depends on platform
+    // https://docs.nvidia.com/cuda/nvrtc/index.html
+    // TODO(JS): Handle version number depending on the platform - it's different for Windows/OSX/Linux
+    SlangResult getVersion(const UnownedStringSlice& filename, SemanticVersion& outVersion)
+    {
+        SLANG_UNUSED(filename);
+        SLANG_UNUSED(outVersion);
+        return SLANG_E_NOT_IMPLEMENTED;
+    }
+
+#endif
+
+    /// Called for every entry of the directory being searched. Files whose name matches the
+    /// prefix/postfix are added to `m_candidates` (unless an identical candidate is already there).
+    void accept(Path::Type type, const UnownedStringSlice& filename) SLANG_OVERRIDE
+    {
+        if (type != Path::Type::File)
+        {
+            return;
+        }
+
+        // If there is a defined extension, make sure it has it
+        if (m_postfix.getLength())
+        {
+            // A name shorter than the extension cannot end with it. Rejecting it here also
+            // guarantees the slicing below never uses a negative length.
+            if (filename.getLength() < m_postfix.getLength())
+            {
+                return;
+            }
+
+            // We test without case - really for windows
+            UnownedStringSlice filenamePostfix = filename.tail(filename.getLength() - m_postfix.getLength());
+            if (!filenamePostfix.caseInsensitiveEquals(m_postfix.getUnownedSlice()))
+            {
+                return;
+            }
+        }
+
+        // Let's make sure it starts with the prefix, but not worry about case
+        if (filename.getLength() < m_prefix.getLength() ||
+            !filename.subString(0, m_prefix.getLength()).caseInsensitiveEquals(m_prefix.getUnownedSlice()))
+        {
+            return;
+        }
+
+        SemanticVersion version;
+        // If it produces an error, just use 0.0.0
+        if (SLANG_FAILED(getVersion(filename, version)))
+        {
+            version = SemanticVersion();
+        }
+
+        // We may want to add multiple versions, if they are in different locations - as there may be multiple entries
+        // in the PATH, and only one works. We'll only know which works by loading
+
+#if 0
+        // We already found this version, so let's not add it again
+        if (findVersion(version) >= 0)
+        {
+            return;
+        }
+#endif
+
+        // Strip the platform decoration to make a shared library name: first drop whatever the
+        // platform puts in front of the stem, then drop the extension from what remains.
+        // (Previously the second step restarted from `filename`, discarding the first.)
+        UnownedStringSlice sharedLibraryName = filename.tail(m_prefix.getLength() - m_sharedLibraryStem.getLength());
+        sharedLibraryName = sharedLibraryName.head(sharedLibraryName.getLength() - m_postfix.getLength());
+
+        auto candidate = Candidate::make(Path::combine(m_basePath, sharedLibraryName), version);
+
+        // If we already have this candidate, then skip
+        if (m_candidates.indexOf(candidate) >= 0)
+        {
+            return;
+        }
+
+        // Add to the list of candidates
+        m_candidates.add(candidate);
+    }
+
+    /// Visits the entries of directory `path`, adding any matches to the candidates.
+    SlangResult findInDirectory(const String& path)
+    {
+        // Remembered so that accept() can build full paths from the bare filenames it is given
+        m_basePath = path;
+        return Path::find(path, nullptr, this);
+    }
+
+    bool hasCandidates() const { return m_candidates.getCount() > 0; }
+
+    NVRTCPathVisitor(const UnownedStringSlice& sharedLibraryStem) :
+        m_sharedLibraryStem(sharedLibraryStem)
+    {
+        // Work out the prefix and postfix of the shared library file name
+        StringBuilder buf;
+        SharedLibrary::appendPlatformFileName(sharedLibraryStem, buf);
+        const Index index = buf.indexOf(sharedLibraryStem);
+        SLANG_ASSERT(index >= 0);
+
+        m_prefix = buf.getUnownedSlice().head(index + sharedLibraryStem.getLength());
+        m_postfix = buf.getUnownedSlice().tail(index + sharedLibraryStem.getLength());
+    }
+
+    String m_prefix;            ///< Platform file name up to and including the stem
+    String m_postfix;           ///< Platform file name after the stem (the extension), may be empty
+    String m_basePath;          ///< Directory currently being searched
+    String m_sharedLibraryStem; ///< The stem passed to the constructor
+
+    List<Candidate> m_candidates;
+};
+
+// Helper that does nothing except take a reference to `func`. Passing a function to it counts as
+// a use of that function; it is wrapped by SLANG_UNUSED_FUNCTION below.
+template <typename T>
+SLANG_FORCE_INLINE static void _unusedFunction(const T& func)
+{
+ SLANG_UNUSED(func);
+}
+
+// Marks function `x` as used by passing it to _unusedFunction.
+#define SLANG_UNUSED_FUNCTION(x) _unusedFunction(x)
+
+// Returns the stem of the NVRTC shared library name for the platform being built:
+// "nvrtc64_" on 64 bit Windows builds, and "nvrtc" otherwise.
+static UnownedStringSlice _getNVRTCBaseName()
+{
+#if SLANG_WINDOWS_FAMILY && SLANG_PTR_IS_64
+ return UnownedStringSlice::fromLiteral("nvrtc64_");
+#else
+ return UnownedStringSlice::fromLiteral("nvrtc");
+#endif
+}
+
+// Searches for NVRTC shared libraries, accumulating what is found in `visitor.m_candidates`.
+//
+// Locations are tried in the following order, each one only if no candidate has been found so far:
+//   1) the instance path (if the platform can report one)
+//   2) the `bin` directory below the CUDA_PATH environment variable
+//   3) the entries of the PATH environment variable that look like a CUDA installation
+//
+// On return the candidates are ordered from the oldest to newest (in version number).
+// The results of the individual directory searches are ignored, so this always returns SLANG_OK;
+// use `visitor.hasCandidates()` to find out if anything was found.
+static SlangResult _findNVRTC(NVRTCPathVisitor& visitor)
+{
+    // First try the instance path (if supported on platform)
+    {
+        StringBuilder instancePath;
+        if (SLANG_SUCCEEDED(PlatformUtil::getInstancePath(instancePath)))
+        {
+            visitor.findInDirectory(instancePath);
+        }
+    }
+
+    // If we don't have a candidate try CUDA_PATH
+    if (!visitor.hasCandidates())
+    {
+        StringBuilder buf;
+        // Only search when the variable could actually be read. (The test used to be inverted,
+        // which meant `<CUDA_PATH>/bin` was searched only when CUDA_PATH was unavailable.)
+        if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), buf)))
+        {
+            // Look for candidates in the directory
+            visitor.findInDirectory(Path::combine(buf, "bin"));
+        }
+    }
+
+    // If we haven't we go searching through PATH
+    if (!visitor.hasCandidates())
+    {
+        StringBuilder buf;
+        if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("PATH"), buf)))
+        {
+            // Split so we get individual paths
+            // NOTE(review): ';' is the PATH separator used on Windows - other platforms would need a different one
+            List<UnownedStringSlice> paths;
+            StringUtil::split(buf.getUnownedSlice(), ';', paths);
+
+            // We use a pool to make sure we only check each path once
+            StringSlicePool pool(StringSlicePool::Style::Empty);
+
+            List<UnownedStringSlice> splitPath;
+
+            // We are going to search the paths in order
+            for (const auto path : paths)
+            {
+                // PATH can have the same path multiple times. If we have already searched this path, we don't need to again
+                if (pool.has(path))
+                {
+                    continue;
+                }
+                pool.add(path);
+
+                Path::split(path, splitPath);
+
+                // We could search every path, but here we restrict to paths that look like CUDA installations.
+                // It's a path that contains a CUDA directory and has bin
+                if (splitPath.indexOf("CUDA") >= 0 && splitPath[splitPath.getCount() - 1].caseInsensitiveEquals(UnownedStringSlice::fromLiteral("bin")))
+                {
+                    // Okay lets search it
+                    visitor.findInDirectory(path);
+                }
+            }
+        }
+    }
+
+    // Put into version order with oldest first.
+    visitor.sortCandidates();
+
+    return SLANG_OK;
+}
+
+static const UnownedStringSlice g_fp16HeaderName = UnownedStringSlice::fromLiteral("cuda_fp16.h");
+
+// Returns (in outPath) the directory found to contain cuda_fp16.h.
+// The search is performed on the first call only; the outcome - found or not - is cached in
+// m_includePath and reused by every later call.
+// Returns SLANG_E_NOT_FOUND (with outPath empty) if no directory was found.
+SlangResult NVRTCDownstreamCompiler::_getIncludePath(String& outPath)
+{
+    const bool needsSearch = !m_includeSearched;
+    if (needsSearch)
+    {
+        // Flag first, so that a failed search is not repeated
+        m_includeSearched = true;
+
+        SLANG_ASSERT(m_includePath.getLength() == 0);
+
+        // The result is deliberately ignored - on failure m_includePath is just left empty
+        _findIncludePath(m_includePath);
+    }
+
+    outPath = m_includePath;
+
+    if (m_includePath.getLength() == 0)
+    {
+        return SLANG_E_NOT_FOUND;
+    }
+    return SLANG_OK;
+}
+
+// Looks for `filename` in `path` itself, then in `path/include`, then in `path/CUDA/include`.
+// On success outPath is set to the first of those directories that contains the file.
+// Returns SLANG_E_NOT_FOUND (leaving outPath untouched) if none of them does.
+SlangResult _findFileInIncludePath(const String& path, const UnownedStringSlice& filename, String& outPath)
+{
+    // The directory itself
+    if (File::exists(Path::combine(path, filename)))
+    {
+        outPath = path;
+        return SLANG_OK;
+    }
+
+    // Sub directories where the header may be found, in the order they are tried
+    const char* const subDirectories[] = { "include", "CUDA/include" };
+
+    for (const char* subDirectory : subDirectories)
+    {
+        const String candidatePath = Path::combine(path, subDirectory);
+        if (File::exists(Path::combine(candidatePath, filename)))
+        {
+            outPath = candidatePath;
+            return SLANG_OK;
+        }
+    }
+
+    return SLANG_E_NOT_FOUND;
+}
+
+// Searches for the directory that holds cuda_fp16.h.
+//
+// First the location of the shared library that provides nvrtcCreateProgram is tried (the header
+// is looked for relative to the directory above it), then the `include` directory below the
+// CUDA_PATH environment variable.
+//
+// outPath is always reset to empty on entry, and is set to the directory on success.
+// Returns SLANG_E_NOT_FOUND if neither location has the header.
+SlangResult NVRTCDownstreamCompiler::_findIncludePath(String& outPath)
+{
+    outPath = String();
+
+    // Try looking up from a symbol. This will work as long as the nvrtc is loaded somehow from a dll/sharedlibrary
+    // And the header is included from there
+    const String libPath = SharedLibraryUtils::getSharedLibraryFileName((void*)m_nvrtcCreateProgram);
+    if (libPath.getLength() &&
+        SLANG_SUCCEEDED(_findFileInIncludePath(Path::getParentDirectory(libPath), g_fp16HeaderName, outPath)))
+    {
+        return SLANG_OK;
+    }
+
+    // Otherwise fall back to the CUDA_PATH environment variable
+    StringBuilder cudaPath;
+    if (SLANG_FAILED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), cudaPath)))
+    {
+        return SLANG_E_NOT_FOUND;
+    }
+
+    const String includePath = Path::combine(cudaPath, "include");
+    if (!File::exists(Path::combine(includePath, g_fp16HeaderName)))
+    {
+        return SLANG_E_NOT_FOUND;
+    }
+
+    outPath = includePath;
+    return SLANG_OK;
+}
+
+// If `options` requests float16 support, adds the command line arguments needed to enable it.
+//
+// SLANG_CUDA_ENABLE_HALF is defined, and - if none of the include paths in `options` contains
+// cuda_fp16.h - an include path that does is added as well.
+// Does nothing (and succeeds) if float16 is not requested. Fails, without adding any arguments,
+// if float16 is requested but no directory containing cuda_fp16.h can be found.
+SlangResult NVRTCDownstreamCompiler::_maybeAddHalfSupport(const CompileOptions& options, CommandLine& ioCmdLine)
+{
+    const bool wantsHalf = (options.flags & CompileOptions::Flag::EnableFloat16) != 0;
+    if (!wantsHalf)
+    {
+        return SLANG_OK;
+    }
+
+    // Set once we know one of the supplied include paths holds cuda_fp16.h
+    bool hasHeaderPath = false;
+
+    // Cheap check first - is one of the paths already known (from a previous compile) to work?
+    for (const auto& includePath : options.includePaths)
+    {
+        if (m_cudaFp16FoundPaths.indexOf(includePath) >= 0)
+        {
+            hasHeaderPath = true;
+            break;
+        }
+    }
+
+    // If not, hit the file system to see if one of the paths finds cuda_fp16.h
+    if (!hasHeaderPath)
+    {
+        for (const auto& includePath : options.includePaths)
+        {
+            if (File::exists(Path::combine(includePath, g_fp16HeaderName)))
+            {
+                // Remember it, so next time the file system doesn't need to be consulted
+                m_cudaFp16FoundPaths.add(includePath);
+                hasHeaderPath = true;
+                break;
+            }
+        }
+    }
+
+    // Nothing supplied works, so we need to locate (and add) an include path ourselves
+    if (!hasHeaderPath)
+    {
+        String includePath;
+        SLANG_RETURN_ON_FAIL(_getIncludePath(includePath));
+
+        ioCmdLine.addArg("-I");
+        ioCmdLine.addArg(includePath);
+    }
+
+    // With the header reachable, just need to enable HALF in prelude
+    ioCmdLine.addArg("-DSLANG_CUDA_ENABLE_HALF");
+
+    return SLANG_OK;
+}
+
SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefPtr<DownstreamCompileResult>& outResult)
{
// This compiler doesn't read files, they should be read externally and stored in sourceContents/sourceContentsPath
@@ -302,6 +707,8 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP
cmdLine.addArg(include);
}
+ SLANG_RETURN_ON_FAIL(_maybeAddHalfSupport(options, cmdLine));
+
// Neither of these options are strictly required, for general use of nvrtc,
// but are enabled to make use withing Slang work more smoothly
{
@@ -391,7 +798,7 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP
// We will define a dummy `stddef.h` that includes the bare minimum
// lines required to get the OptiX headers to compile without complaint.
//
- // TODO: Confirm that the `LP64` definition herei s actually needed.
+ // TODO: Confirm that the `LP64` definition here is actually needed.
//
headerIncludeNames.add("stddef.h");
headers.add("#pragma once\n" "#define LP64\n");
@@ -488,227 +895,15 @@ SlangResult NVRTCDownstreamCompiler::compile(const CompileOptions& options, RefP
return SLANG_OK;
}
-/* An implementation of Path::Visitor that can be used for finding NVRTC shared library installations. */
-struct NVRTCPathVisitor : Path::Visitor
-{
- struct Candidate
- {
- typedef Candidate ThisType;
-
- bool operator==(const ThisType& rhs) const { return path == rhs.path && version == rhs.version; }
- bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
-
- static Candidate make(const String& path, const SemanticVersion& version)
- {
- Candidate can;
- can.version = version;
- can.path = path;
- return can;
- }
- String path;
- SemanticVersion version;
- };
-
- Index findVersion(const SemanticVersion& version) const
- {
- const Index count = m_candidates.getCount();
- for (Index i = 0; i < count; ++i)
- {
- if (m_candidates[i].version == version)
- {
- return i;
- }
- }
- return -1;
- }
-
- static bool _orderCandiate(const Candidate& a, const Candidate& b) { return a.version < b.version; }
- void sortCandidates() { m_candidates.sort(_orderCandiate); }
-
- void accept(Path::Type type, const UnownedStringSlice& filename) SLANG_OVERRIDE
- {
- // Lets make sure it start's with nvrtc64, but not worry about case
- if (type == Path::Type::File)
- {
- // If there is a defined extension, make sure it has it
- if (m_postfix.getLength() && filename.getLength() >= m_postfix.getLength())
- {
- // We test without case - really for windows
- UnownedStringSlice filenamePostfix = filename.tail(filename.getLength() - m_postfix.getLength());
- if (!filenamePostfix.caseInsensitiveEquals(m_postfix.getUnownedSlice()))
- {
- return;
- }
- }
-
- if (filename.getLength() >= m_prefix.getLength() &&
- filename.subString(0, m_prefix.getLength()).caseInsensitiveEquals(m_prefix.getUnownedSlice()))
- {
- // Versions are typically (on windows) of the form
- // nvrtc64_110_2.dll
- // 11 - Major
- // 0 Minor
- // 2 Patch
- Index endIndex = filename.indexOf('.');
- endIndex = (endIndex < 0) ? filename.getLength() : endIndex;
-
- UnownedStringSlice versionSlice = UnownedStringSlice(filename.begin() + m_prefix.getLength(), filename.begin() + endIndex);
-
- Int patch = 0;
- UnownedStringSlice majorMinorSlice;
- {
- List<UnownedStringSlice> slices;
- StringUtil::split(versionSlice, '_', slices);
- if (slices.getCount() >= 2)
- {
- // We don't bother checking for error here, if it's not parsable, it will be 0
- StringUtil::parseInt(slices[1], patch);
- }
- majorMinorSlice = slices[0];
- }
-
- if (majorMinorSlice.getLength() < 2)
- {
- // Must be a major and minor
- return;
- }
-
- UnownedStringSlice majorSlice = majorMinorSlice.head(majorMinorSlice.getLength() - 1);
- UnownedStringSlice minorSlice = majorMinorSlice.subString(majorMinorSlice.getLength() - 1, 1);
-
- Int major;
- Int minor;
-
- if (SLANG_FAILED(StringUtil::parseInt(majorSlice, major)) ||
- SLANG_FAILED(StringUtil::parseInt(minorSlice, minor)))
- {
- return;
- }
-
- const SemanticVersion version = SemanticVersion(int(major), int(minor), int(patch));
-
- // We may want to add multiple versions, if they are in different locations - as there may be multiple entries
- // in the PATH, and only one works. We'll only know which works by loading
-#if 0
- // We already found this version, so let's not add it again
- if (findVersion(version) >= 0)
- {
- return;
- }
-#endif
-
- // Strip to make a shared library name
- UnownedStringSlice sharedLibraryName = filename.tail(m_prefix.getLength() - m_sharedLibraryStem.getLength());
- sharedLibraryName = filename.head(filename.getLength() - m_postfix.getLength());
-
- auto candidate = Candidate::make(Path::combine(m_basePath, sharedLibraryName), version);
-
- // If we already have this candidate, then skip
- if (m_candidates.indexOf(candidate) >= 0)
- {
- return;
- }
-
- // Add to the list of candidates
- m_candidates.add(candidate);
- }
- }
- }
-
- SlangResult findInDirectory(const String& path)
- {
- m_basePath = path;
- return Path::find(path, nullptr, this);
- }
-
- bool hasCandidates() const { return m_candidates.getCount() > 0; }
-
- NVRTCPathVisitor(const UnownedStringSlice& sharedLibraryStem):
- m_sharedLibraryStem(sharedLibraryStem)
- {
- // Work out the prefix and postfix of the shader
- StringBuilder buf;
- SharedLibrary::appendPlatformFileName(sharedLibraryStem, buf);
- const Index index = buf.indexOf(sharedLibraryStem);
- SLANG_ASSERT(index >= 0);
- m_prefix = buf.getUnownedSlice().head(index + sharedLibraryStem.getLength());
- m_postfix = buf.getUnownedSlice().tail(index + sharedLibraryStem.getLength());
- }
-
- String m_prefix;
- String m_postfix;
- String m_basePath;
- String m_sharedLibraryStem;
-
- List<Candidate> m_candidates;
-};
static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<ISlangSharedLibrary>& outLibrary)
{
#if SLANG_WINDOWS_FAMILY && SLANG_PTR_IS_64
- // We only need to search 64 bit versions on windows
- NVRTCPathVisitor visitor(UnownedStringSlice::fromLiteral("nvrtc64_"));
-
- // First try the instance path (if supported on platform)
- {
- StringBuilder instancePath;
- if (SLANG_SUCCEEDED(PlatformUtil::getInstancePath(instancePath)))
- {
- visitor.findInDirectory(instancePath);
- }
- }
-
- // If we don't have a candidate try CUDA_PATH
- if (!visitor.hasCandidates())
- {
- StringBuilder buf;
- if (!SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("CUDA_PATH"), buf)))
- {
- // Look for candidates in the directory
- visitor.findInDirectory(Path::combine(buf, "bin"));
- }
- }
-
- // If we haven't we go searching through PATH
- if (!visitor.hasCandidates())
- {
- List<UnownedStringSlice> splitPath;
-
- StringBuilder buf;
- if (SLANG_SUCCEEDED(PlatformUtil::getEnvironmentVariable(UnownedStringSlice::fromLiteral("PATH"), buf)))
- {
- // Split so we get individual paths
- List<UnownedStringSlice> paths;
- StringUtil::split(buf.getUnownedSlice(), ';', paths);
-
- // We use a pool to make sure we only check each path once
- StringSlicePool pool(StringSlicePool::Style::Empty);
-
- // We are going to search the paths in order
- for (const auto path : paths)
- {
- // PATH can have the same path multiple times. If we have already searched this path, we don't need to again
- if (!pool.has(path))
- {
- pool.add(path);
-
- Path::split(path, splitPath);
-
- // We could search every path, but here we restrict to paths that look like CUDA installations.
- // It's a path that contains a CUDA directory and has bin
- if (splitPath.indexOf("CUDA") >= 0 && splitPath[splitPath.getCount() - 1].caseInsensitiveEquals(UnownedStringSlice::fromLiteral("bin")))
- {
- // Okay lets search it
- visitor.findInDirectory(path);
- }
- }
- }
- }
- }
- // Put into version order with oldest first.
- visitor.sortCandidates();
+ // We only need to search 64 bit versions on windows
+ NVRTCPathVisitor visitor(_getNVRTCBaseName());
+ SLANG_RETURN_ON_FAIL(_findNVRTC(visitor));
// We want to start with the newest version...
for (Index i = visitor.m_candidates.getCount() - 1; i >= 0; --i)
@@ -719,9 +914,13 @@ static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<I
return SLANG_OK;
}
}
+
#else
SLANG_UNUSED(loader);
SLANG_UNUSED(outLibrary);
+
+ SLANG_UNUSED_FUNCTION(_getNVRTCBaseName);
+ SLANG_UNUSED_FUNCTION(_findNVRTC);
#endif
// This is an official-ish list of versions is here:
@@ -735,10 +934,8 @@ static SlangResult _findAndLoadNVRTC(ISlangSharedLibraryLoader* loader, ComPtr<I
return SLANG_E_NOT_FOUND;
}
-
/* static */SlangResult NVRTCDownstreamCompilerUtil::locateCompilers(const String& path, ISlangSharedLibraryLoader* loader, DownstreamCompilerSet* set)
{
-
ComPtr<ISlangSharedLibrary> library;
// If the user supplies a path to their preferred version of NVRTC,
diff --git a/source/core/slang-shared-library.cpp b/source/core/slang-shared-library.cpp
index 17b881540..844bea8d9 100644
--- a/source/core/slang-shared-library.cpp
+++ b/source/core/slang-shared-library.cpp
@@ -76,9 +76,21 @@ TemporarySharedLibrary::~TemporarySharedLibrary()
/* !!!!!!!!!!!!!!!!!!!!!!!!!! DefaultSharedLibrary !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!*/
-ISlangUnknown* DefaultSharedLibrary::getInterface(const Guid& guid)
+// Looks up the interface identified by `uuid` on this object.
+// On success writes the interface pointer to `*outObject` and returns SLANG_OK.
+// Returns SLANG_E_NO_INTERFACE (without writing to `*outObject`) for any other uuid.
+SLANG_NO_THROW SlangResult SLANG_MCALL DefaultSharedLibrary::queryInterface(SlangUUID const& uuid, void** outObject)
{
- return (guid == ISlangUnknown::getTypeGuid() || guid == ISlangSharedLibrary::getTypeGuid()) ? static_cast<ISlangSharedLibrary*>(this) : nullptr;
+ // A query for the concrete class's own guid hands back the DefaultSharedLibrary itself.
+ // NOTE(review): unlike the interface case below, this path does not call addReference() -
+ // confirm that is intended, and that callers do not release the pointer they get back.
+ if (uuid == DefaultSharedLibrary::getTypeGuid())
+ {
+ *outObject = this;
+ return SLANG_OK;
+ }
+
+ // The interfaces this class implements. A reference is added for the returned pointer.
+ if (uuid == ISlangUnknown::getTypeGuid() || uuid == ISlangSharedLibrary::getTypeGuid())
+ {
+ addReference();
+ *outObject = static_cast<ISlangSharedLibrary*>(this);
+ return SLANG_OK;
+ }
+ return SLANG_E_NO_INTERFACE;
}
DefaultSharedLibrary::~DefaultSharedLibrary()
diff --git a/source/core/slang-shared-library.h b/source/core/slang-shared-library.h
index c33f1f41b..452379d68 100644
--- a/source/core/slang-shared-library.h
+++ b/source/core/slang-shared-library.h
@@ -48,9 +48,13 @@ private:
class DefaultSharedLibrary : public ISlangSharedLibrary, public RefObject
{
public:
- // ISlangUnknown
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_CLASS_GUID(0xe7f2597b, 0xf803, 0x4b6e, { 0xaf, 0x8b, 0xcb, 0xe3, 0xa2, 0x21, 0xfd, 0x5a })
+ // ISlangUnknown
+ virtual SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE;
+ SLANG_REF_OBJECT_IUNKNOWN_ADD_REF
+ SLANG_REF_OBJECT_IUNKNOWN_RELEASE
+
// ISlangSharedLibrary
virtual SLANG_NO_THROW void* SLANG_MCALL findSymbolAddressByName(char const* name) SLANG_OVERRIDE;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 761016866..754b3ac63 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -1066,6 +1066,7 @@ matrix<uint,N,M> asuint(matrix<uint,N,M> x)
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "uint16_t(packHalf2x16(vec2($0, 0.0)))")
+__target_intrinsic(cuda, "__half_as_ushort")
uint16_t asuint16(float16_t value);
vector<uint16_t,N> asuint16<let N : int>(vector<float16_t,N> value)
@@ -1078,6 +1079,7 @@ matrix<uint16_t,R,C> asuint16<let R : int, let C : int>(matrix<float16_t,R,C> va
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "float16_t(unpackHalf2x16($0).x)")
+__target_intrinsic(cuda, "__ushort_as_half")
float16_t asfloat16(uint16_t value);
vector<float16_t,N> asfloat16<let N : int>(vector<uint16_t,N> value)
@@ -1088,11 +1090,16 @@ matrix<float16_t,R,C> asfloat16<let R : int, let C : int>(matrix<uint16_t,R,C> v
// Float<->signed cases:
-__target_intrinsic(hlsl) [__unsafeForceInlineEarly] int16_t asint16(float16_t value) { return asuint16(value); }
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "__half_as_short")
+[__unsafeForceInlineEarly] int16_t asint16(float16_t value) { return asuint16(value); }
__target_intrinsic(hlsl) [__unsafeForceInlineEarly] vector<int16_t,N> asint16<let N : int>(vector<float16_t,N> value) { return asuint16(value); }
__target_intrinsic(hlsl) [__unsafeForceInlineEarly] matrix<int16_t,R,C> asint16<let R : int, let C : int>(matrix<float16_t,R,C> value) { return asuint16(value); }
-__target_intrinsic(hlsl) [__unsafeForceInlineEarly] float16_t asfloat16(int16_t value) { return asfloat16(asuint16(value)); }
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "__short_as_half")
+[__unsafeForceInlineEarly] float16_t asfloat16(int16_t value) { return asfloat16(asuint16(value)); }
+
__target_intrinsic(hlsl) [__unsafeForceInlineEarly] vector<float16_t,N> asfloat16<let N : int>(vector<int16_t,N> value) { return asfloat16(asuint16(value)); }
__target_intrinsic(hlsl) [__unsafeForceInlineEarly] matrix<float16_t,R,C> asfloat16<let R : int, let C : int>(matrix<int16_t,R,C> value) { return asfloat16(asuint16(value)); }
@@ -1593,6 +1600,8 @@ vector<float, N> f16tof32(vector<uint, N> value)
VECTOR_MAP_UNARY(float, N, f16tof32, value);
}
+
+
// Convert to 16-bit float stored in low bits of integer
__target_intrinsic(glsl, "packHalf2x16(vec2($0,0.0))")
__glsl_version(420)
@@ -1606,6 +1615,39 @@ vector<uint, N> f32tof16(vector<float, N> value)
VECTOR_MAP_UNARY(uint, N, f32tof16, value);
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+// The following is Slang specific and NOT part of standard HLSL
+// It's not clear what happens with float16 time in HLSL -> can the float16 coerce to uint for example? If so that would
+// give the wrong result
+
+__target_intrinsic(glsl, "unpackHalf2x16($0).x")
+__target_intrinsic(cuda, "__half2float")
+__glsl_version(420)
+float f16tof32(float16_t value);
+
+__generic<let N : int>
+__target_intrinsic(hlsl)
+__target_intrinsic(cuda, "__half2float")
+vector<float, N> f16tof32(vector<float16_t, N> value)
+{
+ VECTOR_MAP_UNARY(float, N, f16tof32, value);
+}
+
+// Convert to float16_t
+__target_intrinsic(glsl, "packHalf2x16(vec2($0,0.0))")
+__glsl_version(420)
+__target_intrinsic(cuda, "__float2half")
+float16_t f32tof16_(float value);
+
+// Vector version: converts each component of `value` to float16_t using the scalar f32tof16_.
+// (Previously this mapped with the `uint` returning f32tof16, which does not match the
+// float16_t result type declared here.)
+__generic<let N : int>
+__target_intrinsic(cuda, "__float2half")
+vector<float16_t, N> f32tof16_(vector<float, N> value)
+{
+    VECTOR_MAP_UNARY(float16_t, N, f32tof16_, value);
+}
+
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
// Flip surface normal to face forward, if needed
__generic<T : __BuiltinFloatingPointType, let N : int>
__target_intrinsic(hlsl)
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 19a5fddf8..1d416634a 100755
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -1421,6 +1421,11 @@ SlangResult dissassembleDXILUsingDXC(
options.requiredCapabilityVersions.add(version);
}
+
+ if (cudaTracker->isBaseTypeRequired(BaseType::Half))
+ {
+ options.flags |= CompileOptions::Flag::EnableFloat16;
+ }
}
options.sourceContents = source.source;
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 2f5a9917d..a259ea933 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -27,7 +27,7 @@ static bool _isSingleNameBasicType(IROp op)
}
}
-/* static */ UnownedStringSlice CUDASourceEmitter::getBuiltinTypeName(IROp op)
+UnownedStringSlice CUDASourceEmitter::getBuiltinTypeName(IROp op)
{
switch (op)
{
@@ -44,8 +44,11 @@ static bool _isSingleNameBasicType(IROp op)
case kIROp_UIntType: return UnownedStringSlice("uint");
case kIROp_UInt64Type: return UnownedStringSlice("ulonglong");
- // Not clear just yet how we should handle half... we want all processing as float probly, but when reading/writing to memory converting
- case kIROp_HalfType: return UnownedStringSlice("half");
+ case kIROp_HalfType:
+ {
+ m_extensionTracker->requireBaseType(BaseType::Half);
+ return UnownedStringSlice("__half");
+ }
case kIROp_FloatType: return UnownedStringSlice("float");
case kIROp_DoubleType: return UnownedStringSlice("double");
@@ -54,7 +57,7 @@ static bool _isSingleNameBasicType(IROp op)
}
-/* static */ UnownedStringSlice CUDASourceEmitter::getVectorPrefix(IROp op)
+UnownedStringSlice CUDASourceEmitter::getVectorPrefix(IROp op)
{
switch (op)
{
@@ -70,8 +73,11 @@ static bool _isSingleNameBasicType(IROp op)
case kIROp_UIntType: return UnownedStringSlice("uint");
case kIROp_UInt64Type: return UnownedStringSlice("ulonglong");
- // Not clear just yet how we should handle half... we want all processing as float probly, but when reading/writing to memory converting
- case kIROp_HalfType: return UnownedStringSlice("half");
+ case kIROp_HalfType:
+ {
+ m_extensionTracker->requireBaseType(BaseType::Half);
+ return UnownedStringSlice("__half");
+ }
case kIROp_FloatType: return UnownedStringSlice("float");
case kIROp_DoubleType: return UnownedStringSlice("double");
@@ -160,12 +166,6 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target,
switch (type->getOp())
{
- case kIROp_HalfType:
- {
- // Special case half
- out << getBuiltinTypeName(kIROp_FloatType);
- return SLANG_OK;
- }
case kIROp_VectorType:
{
auto vecType = static_cast<IRVectorType*>(type);
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index fefa40a11..a5d227c6b 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -11,7 +11,17 @@ class CUDAExtensionTracker : public RefObject
{
public:
+ typedef uint32_t BaseTypeFlags;
+
SemanticVersion m_smVersion;
+
+ void requireBaseType(BaseType baseType) { m_baseTypeFlags |= _getFlag(baseType); }
+    /// Returns true if `baseType` has been marked as required via requireBaseType.
+    /// Only reads state, so it is const and can be called through a const tracker.
+    bool isBaseTypeRequired(BaseType baseType) const { return (m_baseTypeFlags & _getFlag(baseType)) != 0; }
+
+protected:
+ static BaseTypeFlags _getFlag(BaseType baseType) { return BaseTypeFlags(1) << int(baseType); }
+
+ BaseTypeFlags m_baseTypeFlags = 0;
};
class CUDASourceEmitter : public CPPSourceEmitter
@@ -30,8 +40,8 @@ public:
};
};
- static UnownedStringSlice getBuiltinTypeName(IROp op);
- static UnownedStringSlice getVectorPrefix(IROp op);
+ UnownedStringSlice getBuiltinTypeName(IROp op);
+ UnownedStringSlice getVectorPrefix(IROp op);
virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }
virtual void emitTempModifiers(IRInst* temp) SLANG_OVERRIDE;