From c2dc1a86ed2f5e160749fe9f99b70db6c3e4d7a6 Mon Sep 17 00:00:00 2001 From: skallweitNV <64953474+skallweitNV@users.noreply.github.com> Date: Mon, 12 Dec 2022 19:25:48 +0100 Subject: Refactor shader cache (#2558) * Fix a bug in Path::find * Fix code formatting * Fix LockFile and add LockFileGuard * Add PersistentCache and unit test * Replace file path dependency list with source file dependency list * Add note on ordering in Module/FileDependencyList * Remove old shader cache code * Refactor shader cache implementation * Temporarily skip unit tests reading/writing files * Fix warning * Reenable lock file test * Rename shader cache tests and disable crashing test * Testing * Stop using Path::getCanonical * Fix persistent cache lock and test * Fix threading issues * Move adding file dependency hashes to getEntryPointHash() * Fix handling of #include files * Allow specifying additional search paths for gfx testing device * Work on shader cache tests * Update project files * Revive shader cache graphics tests * Split graphics pipeline test * Fix compilation --- build/visual-studio/core/core.vcxproj | 2 + build/visual-studio/core/core.vcxproj.filters | 6 + .../gfx-unit-test-tool/gfx-unit-test-tool.vcxproj | 7 +- .../gfx-unit-test-tool.vcxproj.filters | 13 +- build/visual-studio/gfx/gfx.vcxproj | 2 - build/visual-studio/gfx/gfx.vcxproj.filters | 6 - build/visual-studio/slang-rt/slang-rt.vcxproj | 2 + .../slang-rt/slang-rt.vcxproj.filters | 6 + .../slang-unit-test-tool.vcxproj | 1 + .../slang-unit-test-tool.vcxproj.filters | 3 + slang-gfx.h | 52 +- slang.h | 36 +- source/core/slang-crypto.cpp | 14 +- source/core/slang-io.cpp | 2 +- source/core/slang-io.h | 27 +- source/core/slang-persistent-cache.cpp | 289 ++++ source/core/slang-persistent-cache.h | 91 ++ source/slang/slang-compiler.cpp | 29 +- source/slang/slang-compiler.h | 158 +-- source/slang/slang-preprocessor.cpp | 23 +- source/slang/slang-preprocessor.h | 2 +- source/slang/slang.cpp | 229 +--- tools/gfx-unit-test/gfx-test-util.cpp | 12 +- tools/gfx-unit-test/gfx-test-util.h | 4 +- .../multiple-entry-point-shader-cache-shader.slang | 28 - .../shader-cache-graphics-fragment.slang | 24 + .../shader-cache-graphics-vertex.slang | 36 + .../shader-cache-multiple-entry-points.slang | 31 + .../shader-cache-specialization.slang | 68 + tools/gfx-unit-test/shader-cache-tests.cpp | 1449 ++++++++------------ tools/gfx-unit-test/split-graphics-fragment.slang | 24 - tools/gfx-unit-test/split-graphics-vertex.slang | 36 - tools/gfx/gfx.slang | 24 +- tools/gfx/persistent-shader-cache.cpp | 316 ----- tools/gfx/persistent-shader-cache.h | 99 -- tools/gfx/renderer-shared.cpp | 94 +- tools/gfx/renderer-shared.h | 27 +- tools/slang-unit-test/unit-test-lock-file.cpp | 4 +- .../slang-unit-test/unit-test-persistent-cache.cpp | 629 +++++++++ 39 files changed, 2048 insertions(+), 1857 deletions(-) create mode 100644 source/core/slang-persistent-cache.cpp create mode 100644 source/core/slang-persistent-cache.h delete mode 100644 tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang create mode 100644 tools/gfx-unit-test/shader-cache-graphics-fragment.slang create mode 100644 tools/gfx-unit-test/shader-cache-graphics-vertex.slang create mode 100644 tools/gfx-unit-test/shader-cache-multiple-entry-points.slang create mode 100644 tools/gfx-unit-test/shader-cache-specialization.slang delete mode 100644 tools/gfx-unit-test/split-graphics-fragment.slang delete mode 100644 tools/gfx-unit-test/split-graphics-vertex.slang delete mode 100644 tools/gfx/persistent-shader-cache.cpp delete mode 100644 tools/gfx/persistent-shader-cache.h create mode 100644 tools/slang-unit-test/unit-test-persistent-cache.cpp diff --git a/build/visual-studio/core/core.vcxproj b/build/visual-studio/core/core.vcxproj index e171a7d68..ecd6b54fe 100644 --- a/build/visual-studio/core/core.vcxproj +++ b/build/visual-studio/core/core.vcxproj @@ -297,6 +297,7 @@ + @@ -353,6 +354,7 @@ + diff --git a/build/visual-studio/core/core.vcxproj.filters b/build/visual-studio/core/core.vcxproj.filters index 06d151349..144c5259f 100644 --- a/build/visual-studio/core/core.vcxproj.filters +++ b/build/visual-studio/core/core.vcxproj.filters @@ -123,6 +123,9 @@ Header Files + + Header Files + Header Files @@ -287,6 +290,9 @@ Source Files + + Source Files + Source Files diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj index b0879ccfe..869b0fd27 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj @@ -317,16 +317,17 @@ - + + - - + + diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters index 00fd1de16..58045ccee 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters @@ -121,9 +121,6 @@ Source Files - - Source Files - Source Files @@ -142,13 +139,19 @@ Source Files + + Source Files + + + Source Files + Source Files - + Source Files - + Source Files diff --git a/build/visual-studio/gfx/gfx.vcxproj b/build/visual-studio/gfx/gfx.vcxproj index 8f200be8e..476f6a808 100644 --- a/build/visual-studio/gfx/gfx.vcxproj +++ b/build/visual-studio/gfx/gfx.vcxproj @@ -419,7 +419,6 @@ IF EXIST "$(SolutionDir)tools\gfx\slang.slang"\ (xcopy /Q /E /Y /I "$(SolutionDi - @@ -530,7 +529,6 @@ IF EXIST "$(SolutionDir)tools\gfx\slang.slang"\ (xcopy /Q /E /Y /I "$(SolutionDi - diff --git a/build/visual-studio/gfx/gfx.vcxproj.filters b/build/visual-studio/gfx/gfx.vcxproj.filters index c708450d5..1a7ed3d03 100644 --- a/build/visual-studio/gfx/gfx.vcxproj.filters +++ b/build/visual-studio/gfx/gfx.vcxproj.filters @@ -306,9 +306,6 @@ Header Files - - Header Files - Header Files @@ -635,9 +632,6 @@ Source Files - - Source Files - Source Files diff --git a/build/visual-studio/slang-rt/slang-rt.vcxproj b/build/visual-studio/slang-rt/slang-rt.vcxproj index 03d3852fd..92191667c 100644 --- a/build/visual-studio/slang-rt/slang-rt.vcxproj +++ b/build/visual-studio/slang-rt/slang-rt.vcxproj @@ -309,6 +309,7 @@ + @@ -366,6 +367,7 @@ + diff --git a/build/visual-studio/slang-rt/slang-rt.vcxproj.filters b/build/visual-studio/slang-rt/slang-rt.vcxproj.filters index df9f8c2ca..b99c077ae 100644 --- a/build/visual-studio/slang-rt/slang-rt.vcxproj.filters +++ b/build/visual-studio/slang-rt/slang-rt.vcxproj.filters @@ -123,6 +123,9 @@ Header Files + + Header Files + Header Files @@ -290,6 +293,9 @@ Source Files + + Source Files + Source Files diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj index deab210ee..5eec9ec82 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj @@ -296,6 +296,7 @@ + diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters index a33dc44cc..c350b6f24 100644 --- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters @@ -62,6 +62,9 @@ Source Files + + Source Files + Source Files diff --git a/slang-gfx.h b/slang-gfx.h index 590d7d5c7..07fd0d58a 100644 --- a/slang-gfx.h +++ b/slang-gfx.h @@ -2052,25 +2052,6 @@ public: 0xbe91ba6c, 0x784, 0x4308, { 0xa1, 0x0, 0x19, 0xc3, 0x66, 0x83, 0x44, 0xb2 } \ } -// These are exclusively used to track hit/miss counts for shader cache entries. Entry hit and -// miss counts specifically indicate if the file containing relevant shader code was found in -// the cache, while the general hit and miss counts indicate whether the file was both found and -// up-to-date. -class IShaderCacheStatistics : public ISlangUnknown -{ -public: - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheMissCount() = 0; - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheHitCount() = 0; - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheEntryDirtyCount() = 0; - - virtual SLANG_NO_THROW Result SLANG_MCALL resetCacheStatistics() = 0; -}; - -#define SLANG_UUID_IShaderCacheStatistics \ - { \ - 0x8eccc8ec, 0x5c04, 0x4a51, { 0x99, 0x75, 0x13, 0xf8, 0xfe, 0xa1, 0x59, 0xf3 } \ - } - struct AdapterLUID { uint8_t luid[16]; @@ -2225,15 +2206,10 @@ public: struct ShaderCacheDesc { - // The filename for the file the cache's state should be saved to or loaded from. - const char* cacheFilename = "cache.txt"; - // The root directory for the shader cache. + // The root directory for the shader cache. If not set, shader cache is disabled. const char* shaderCachePath = nullptr; - // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object, - // instead the user is responsible for holding the object alive during the lifetime of an `IDevice`. - ISlangFileSystem* shaderCacheFileSystem = nullptr; // The maximum number of entries stored in the cache. By default, there is no limit. - GfxCount entryCountLimit = 0; + GfxCount maxEntryCount = 0; }; struct InteropHandles @@ -2597,6 +2573,30 @@ public: 0x715bdf26, 0x5135, 0x11eb, { 0xAE, 0x93, 0x02, 0x42, 0xAC, 0x13, 0x00, 0x02 } \ } +struct ShaderCacheStats +{ + GfxCount hitCount; + GfxCount missCount; + GfxCount entryCount; +}; + +// These are exclusively used to track hit/miss counts for shader cache entries. Entry hit and +// miss counts specifically indicate if the file containing relevant shader code was found in +// the cache, while the general hit and miss counts indicate whether the file was both found and +// up-to-date. +class IShaderCache : public ISlangUnknown +{ +public: + virtual SLANG_NO_THROW Result SLANG_MCALL clearShaderCache() = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL getShaderCacheStats(ShaderCacheStats* outStats) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL resetShaderCacheStats() = 0; +}; + +#define SLANG_UUID_IShaderCache \ + { \ + 0x8eccc8ec, 0x5c04, 0x4a51, { 0x99, 0x75, 0x13, 0xf8, 0xfe, 0xa1, 0x59, 0xf3 } \ + } + class IPipelineCreationAPIDispatcher : public ISlangUnknown { public: diff --git a/slang.h b/slang.h index 784a3e763..4aa30b049 100644 --- a/slang.h +++ b/slang.h @@ -4425,34 +4425,16 @@ namespace slang IBlob** outCode, IBlob** outDiagnostics = nullptr) = 0; - /** Compute the hash code of all dependencies for this component type. This generally means file path - dependencies but can also include the component's name or sub-components. The dependency-based - hash effectively represents all the files that may be included/imported by a component type along with - any non-code-specific that helps define a component. This can be useful to simply check for a component - type without needing to inspect the code. For example, a shader cache might key its entries using the - dependency-based hash in order to determine at a glance if a particular shader is present, with no - regard for the shader's contents. - - This function should only have a meaningful implementation in ComponentType. All other types derived from - ComponentType that also inherit from IComponentType should do nothing. - */ - virtual SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash( - SlangInt entryPointIndex, - SlangInt targetIndex, - IBlob** outHash) = 0; - - /** Compute the hash code of this component type's contents as indicated by the file dependencies. - This hash is ideal when we need to confirm whether shader code changes have occurred. For example, - a shader cache needs to be able to check when a cache entry contains out-of-date code, which can be - easily detected by comparing the contents-based hashes since they will directly reflect any change - to the shader's code. - - This function should only have a meaningful implementation in ComponentType. All other types derived - from ComponentType that also inherit from IComponentType should do nothing. However, the only component - type that should ever be hashing its contents is Module as it represents all the code in a given - translation unit. + /** Compute a hash for the entry point at `entryPointIndex` for the chosen `targetIndex`. + + This computes a hash based on all the dependencies for this component type as well as the + target settings affecting the compiler backend. The computed hash is used as a key for caching + the output of the compiler backend to implement shader caching. */ - virtual SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(IBlob** outHash) = 0; + virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( + SlangInt entryPointIndex, + SlangInt targetIndex, + IBlob** outHash) = 0; /** Specialize the component by binding its specialization parameters to concrete arguments. diff --git a/source/core/slang-crypto.cpp b/source/core/slang-crypto.cpp index ece7b01e9..dfe246c7c 100644 --- a/source/core/slang-crypto.cpp +++ b/source/core/slang-crypto.cpp @@ -82,15 +82,19 @@ void MD5::update(const void* data, SlangInt size) saved_lo = m_lo; if ((m_lo = (saved_lo + size) & 0x1fffffff) < saved_lo) + { m_hi++; + } m_hi += (uint32_t)size >> 29; used = saved_lo & 0x3f; - if (used) { + if (used) + { available = 64 - used; - if (size < available) { + if (size < available) + { ::memcpy(&m_buffer[used], data, size); return; } @@ -101,7 +105,8 @@ void MD5::update(const void* data, SlangInt size) processBlock(m_buffer, 64); } - if (size >= 64) { + if (size >= 64) + { data = processBlock(data, size & ~(SlangInt)0x3f); size &= 0x3f; } @@ -119,7 +124,8 @@ MD5::Digest MD5::finalize() available = 64 - used; - if (available < 8) { + if (available < 8) + { ::memset(&m_buffer[used], 0, available); processBlock(m_buffer, 64); used = 0; diff --git a/source/core/slang-io.cpp b/source/core/slang-io.cpp index d8ef48ee3..f2be44f4f 100644 --- a/source/core/slang-io.cpp +++ b/source/core/slang-io.cpp @@ -682,7 +682,7 @@ namespace Slang WIN32_FIND_DATAW fileData; HANDLE findHandle = FindFirstFileW(searchPath.toWString(), &fileData); - if (!findHandle) + if (findHandle == INVALID_HANDLE_VALUE) { return SLANG_E_NOT_FOUND; } diff --git a/source/core/slang-io.h b/source/core/slang-io.h index fc5cbfa9d..e766359ff 100644 --- a/source/core/slang-io.h +++ b/source/core/slang-io.h @@ -266,9 +266,9 @@ namespace Slang private: LockFile(const LockFile&) = delete; - LockFile(LockFile&) = delete; + LockFile(LockFile&&) = delete; LockFile& operator=(const LockFile&) = delete; - LockFile& operator=(const LockFile&&) = delete; + LockFile& operator=(LockFile&&) = delete; #if SLANG_WINDOWS_FAMILY void* m_fileHandle; @@ -278,6 +278,29 @@ namespace Slang bool m_isOpen; }; + class LockFileGuard + { + public: + LockFileGuard(LockFile& lockFile, LockFile::LockType lockType = LockFile::LockType::Exclusive) + : m_lockFile(lockFile) + { + m_lockFile.lock(lockType); + } + + ~LockFileGuard() + { + m_lockFile.unlock(); + } + + private: + LockFileGuard(const LockFileGuard&) = delete; + LockFileGuard(LockFileGuard&&) = delete; + LockFileGuard& operator=(const LockFileGuard&) = delete; + LockFileGuard& operator=(LockFileGuard&&) = delete; + + LockFile& m_lockFile; + }; + } #endif diff --git a/source/core/slang-persistent-cache.cpp b/source/core/slang-persistent-cache.cpp new file mode 100644 index 000000000..2b4113e16 --- /dev/null +++ b/source/core/slang-persistent-cache.cpp @@ -0,0 +1,289 @@ +#include "slang-persistent-cache.h" + +#include "../core/slang-io.h" +#include "../core/slang-stream.h" +#include "../core/slang-string-util.h" +#include "../core/slang-blob.h" + +namespace Slang +{ + +PersistentCache::PersistentCache(const Desc& desc) +{ + m_cacheDirectory = Path::simplify(desc.directory); + Path::createDirectory(m_cacheDirectory); + + m_lockFileName = Path::simplify(m_cacheDirectory + "/lock"); + m_indexFileName = Path::simplify(m_cacheDirectory + "/index"); + + m_lockFile.open(m_lockFileName); + + m_maxEntryCount = desc.maxEntryCount; + + resetStats(); + + initialize(); +} + +PersistentCache::~PersistentCache() +{ +} + +SlangResult PersistentCache::clear() +{ + if (!m_lockFile.isOpen()) + { + return SLANG_E_CANNOT_OPEN; + } + + // Acquire the exclusive lock. + std::lock_guard mutexLock(m_mutex); + LockFileGuard fileLock(m_lockFile); + + struct Visitor : Path::Visitor + { + const String& directory; + const String& lockFileName; + + Visitor(const String& directory, const String& lockFileName) + : directory(directory) + , lockFileName(lockFileName) + {} + + void accept(Path::Type type, const UnownedStringSlice& fileName) SLANG_OVERRIDE + { + String fullPath = Path::simplify(directory + "/" + fileName);; + if (type == Path::Type::File && lockFileName != fullPath) + { + Path::remove(fullPath); + } + } + }; + + Visitor visitor(m_cacheDirectory, m_lockFileName); + Path::find(m_cacheDirectory, nullptr, &visitor); + + m_stats.entryCount = 0; + + return SLANG_OK; +} + +void PersistentCache::resetStats() +{ + m_stats.entryCount = 0; + m_stats.hitCount = 0; + m_stats.missCount = 0; +} + +SlangResult PersistentCache::readEntry(const Key& key, ISlangBlob** outData) +{ + // Be pessimistic and assume we have a cache miss. + ++m_stats.missCount; + + if (!m_lockFile.isOpen()) + { + return SLANG_E_CANNOT_OPEN; + } + + // Acquire the exclusive lock. + std::lock_guard mutexLock(m_mutex); + LockFileGuard fileLock(m_lockFile); + + // Return if index does not exist. + if (!File::exists(m_indexFileName)) + { + return SLANG_E_NOT_FOUND; + } + + // Read the cache index. + CacheIndex cacheIndex; + SLANG_RETURN_ON_FAIL(readIndex(m_indexFileName, cacheIndex)); + + // Increase the age of all entries in the cache. + for (auto& entry : cacheIndex) + { + ++entry.age; + } + + // Find the entry. + Index entryIndex = cacheIndex.findFirstIndex([&key] (const CacheEntry& entry) { return entry.key == key; }); + if (entryIndex == -1) + { + return SLANG_E_NOT_FOUND; + } + + // Read the entry. + String entryFileName = getEntryFileName(key); + ScopedAllocation data; + SlangResult result = File::readAllBytes(entryFileName, data); + if (result == SLANG_OK) + { + --m_stats.missCount; + ++m_stats.hitCount; + cacheIndex[entryIndex].age = 0; + auto blob = RawBlob::moveCreate(data); + *outData = blob.detach(); + } + else + { + cacheIndex.removeAt(entryIndex); + } + + // Write the cache index. + SLANG_RETURN_ON_FAIL(writeIndex(m_indexFileName, cacheIndex)); + m_stats.entryCount = (Count)cacheIndex.getCount(); + + return result; +} + +SlangResult PersistentCache::writeEntry(const Key& key, ISlangBlob* data) +{ + SLANG_ASSERT(data); + + if (!m_lockFile.isOpen()) + { + return SLANG_E_CANNOT_OPEN; + } + + // Acquire the exclusive lock. + std::lock_guard mutexLock(m_mutex); + LockFileGuard fileLock(m_lockFile); + + // Read the cache index. + // We ignore any errors when reading the index and just write a new one. + CacheIndex cacheIndex; + readIndex(m_indexFileName, cacheIndex); + + // Increase the age of all entries in the cache and get the index of + // the oldest entry. + Index oldestEntryIndex = -1; + uint32_t oldestEntryAge = 0; + for (Index entryIndex = 0; entryIndex < cacheIndex.getCount(); ++entryIndex) + { + auto& entry = cacheIndex[entryIndex]; + ++entry.age; + if (entry.age > oldestEntryAge) + { + oldestEntryIndex = entryIndex; + oldestEntryAge = entry.age; + } + } + + // Write the cache entry. + String entryFileName = getEntryFileName(key); + SLANG_RETURN_ON_FAIL(File::writeAllBytes(entryFileName, data->getBufferPointer(), data->getBufferSize())); + + // Update the index. + if (m_maxEntryCount > 0 && cacheIndex.getCount() >= m_maxEntryCount) + { + // Replace oldest entry. + SLANG_ASSERT(oldestEntryIndex >= 0); + File::remove(getEntryFileName(cacheIndex[oldestEntryIndex].key)); + cacheIndex[oldestEntryIndex] = CacheEntry{ key, 0 }; + } + else + { + // Add new entry. + cacheIndex.add(CacheEntry{ key, 0 }); + } + + // Write the cache index. + SlangResult result = writeIndex(m_indexFileName, cacheIndex); + if (result == SLANG_OK) + { + m_stats.entryCount = (Count)cacheIndex.getCount(); + } + else + { + // If writing the index failed, remove the entry file to avoid growing the cache. + Path::remove(entryFileName); + } + + return result; +} + +SlangResult PersistentCache::initialize() +{ + if (!m_lockFile.isOpen()) + { + return SLANG_E_CANNOT_OPEN; + } + + // Acquire the exclusive lock. + std::lock_guard mutexLock(m_mutex); + LockFileGuard fileLock(m_lockFile); + + CacheIndex cacheIndex; + if (SLANG_SUCCEEDED(readIndex(m_indexFileName, cacheIndex))) + { + m_stats.entryCount = (Count)cacheIndex.getCount(); + } + + return SLANG_OK; +} + +String PersistentCache::getEntryFileName(const Key& key) +{ + StringBuilder str; + str << m_cacheDirectory << "/" << key.toString(); + return str; +} + +struct CacheIndexHeader +{ + char magic[4]; + uint32_t version; + uint32_t count; + uint32_t reserved; +}; + +static const char* kMagic = "SLS$"; +static const uint32_t kVersion = 1; + +SlangResult PersistentCache::readIndex(const String& fileName, CacheIndex& outIndex) +{ + FileStream fs; + SLANG_RETURN_ON_FAIL(fs.init(fileName, FileMode::Open)); + + // Get file size. + SLANG_RETURN_ON_FAIL(fs.seek(SeekOrigin::End, 0)); + size_t fileSize = (size_t)fs.getPosition(); + SLANG_RETURN_ON_FAIL(fs.seek(SeekOrigin::Start, 0)); + + CacheIndexHeader header; + SLANG_RETURN_ON_FAIL(fs.readExactly(&header, sizeof(header))); + if (::memcmp(header.magic, kMagic, 4) != 0 || header.version != kVersion) + { + return SLANG_E_INTERNAL_FAIL; + } + + // Return if payload does not have the right size. + if (header.count * sizeof(CacheEntry) != fileSize - sizeof(header)) + { + return SLANG_E_INTERNAL_FAIL; + } + + outIndex.setCount(header.count); + SLANG_RETURN_ON_FAIL(fs.readExactly(outIndex.getBuffer(), header.count * sizeof(CacheEntry))); + + return SLANG_OK; +} + +SlangResult PersistentCache::writeIndex(const String& fileName, const CacheIndex& index) +{ + FileStream fs; + SLANG_RETURN_ON_FAIL(fs.init(fileName, FileMode::Create)); + + CacheIndexHeader header; + ::memcpy(header.magic, kMagic, 4); + header.version = kVersion; + header.count = (uint32_t)index.getCount(); + header.reserved = 0; + SLANG_RETURN_ON_FAIL(fs.write(&header, sizeof(header))); + + SLANG_RETURN_ON_FAIL(fs.write(index.getBuffer(), index.getCount() * sizeof(CacheEntry))); + + return SLANG_OK; +} + +} diff --git a/source/core/slang-persistent-cache.h b/source/core/slang-persistent-cache.h new file mode 100644 index 000000000..db5ef0a2e --- /dev/null +++ b/source/core/slang-persistent-cache.h @@ -0,0 +1,91 @@ +#pragma once +#include "../../slang.h" +#include "../core/slang-crypto.h" +#include "../core/slang-io.h" +#include "../core/slang-string.h" + +#include + +namespace Slang +{ + +/// Implements a simple persistent cache on the filesystem for storing key/value pairs. +/// Keys are SHA1 hashes and values are arbitrary blobs of data. +/// The cache is save for concurrent access from multiple threads/processes by using +/// a lock file within the cache directory. Furthermore, the cache implements a LRU +/// eviction policy. +class PersistentCache : public RefObject +{ +public: + struct Desc + { + // The root directory for the cache. + const char* directory = nullptr; + // The maximum number of entries stored in the cache. By default, there is no limit. + Count maxEntryCount = 0; + }; + + struct Stats + { + // Number of cache hits since last resetting the stats. + Count hitCount; + // Number of cache misses since last resetting the stats. + Count missCount; + // Current number of entries in the cache. + Count entryCount; + }; + + using Key = SHA1::Digest; + + PersistentCache(const Desc& desc); + ~PersistentCache(); + + /// Clear the contents of the cache by removing the cache index and all entry files. + SlangResult clear(); + + const Stats& getStats() const { return m_stats; } + void resetStats(); + + /// Read an entry from the cache. + /// Returns SLANG_OK if successful, SLANG_E_NOT_FOUND if the entry is not in the cache. + SlangResult readEntry(const Key& key, ISlangBlob** outData); + + /// Write an entry to the cache. + /// Returns SLANG_OK if successful. + SlangResult writeEntry(const Key& key, ISlangBlob* data); + +private: + struct CacheEntry + { + Key key; + uint32_t age; + }; + + using CacheIndex = List; + + SlangResult initialize(); + + String getEntryFileName(const Key& key); + + SlangResult readIndex(const String& fileName, CacheIndex& outIndex); + SlangResult writeIndex(const String& fileName, const CacheIndex& index); + + String m_cacheDirectory; + String m_lockFileName; + String m_indexFileName; + + // For exclusive locking we need both a mutex (acquired first) + // followed by a a file lock. The mutex is needed because on Linux + // the file lock is only locking between processes, not threads. + std::mutex m_mutex; + Slang::LockFile m_lockFile; + + Count m_maxEntryCount; + + Stats m_stats; + + // Used for unit tests. + friend struct PersistentCacheTest; +}; + +} diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index ce0ae4085..f2ebefb9d 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -223,14 +223,9 @@ namespace Slang visitor->visitEntryPoint(this, as(specializationInfo)); } - void EntryPoint::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) + void EntryPoint::buildHash(DigestBuilder& builder) { - // CompositeComponentType will have already hashed the relevant entry point's name - // and file path dependencies, so we immediately return. SLANG_UNUSED(builder); - SLANG_UNUSED(entryPointIndex); } List const& EntryPoint::getModuleDependencies() @@ -242,12 +237,12 @@ namespace Slang return empty; } - List const& EntryPoint::getFilePathDependencies() + List const& EntryPoint::getFileDependencies() { if(auto module = getModule()) - return getModule()->getFilePathDependencies(); + return getModule()->getFileDependencies(); - static List empty; + static List empty; return empty; } @@ -269,8 +264,8 @@ namespace Slang if (auto declaredWitness = as(witness)) { auto declModule = getModule(declaredWitness->declRef.getDecl()); - m_moduleDependency.addDependency(declModule); - m_pathDependency.addDependency(declModule); + m_moduleDependencyList.addDependency(declModule); + m_fileDependencyList.addDependency(declModule); if (m_requirementSet.Add(declModule)) { m_requirements.add(declModule); @@ -301,12 +296,8 @@ namespace Slang return Super::getInterface(guid); } - void TypeConformance::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) + void TypeConformance::buildHash(DigestBuilder& builder) { - SLANG_UNUSED(entryPointIndex); - //TODO: Implement some kind of hashInto for Val then replace this auto subtypeWitness = m_subtypeWitness->toString(); @@ -316,12 +307,12 @@ namespace Slang List const& TypeConformance::getModuleDependencies() { - return m_moduleDependency.getModuleList(); + return m_moduleDependencyList.getModuleList(); } - List const& TypeConformance::getFilePathDependencies() + List const& TypeConformance::getFileDependencies() { - return m_pathDependency.getFilePathList(); + return m_fileDependencyList.getFileList(); } Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); } diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 692fefbe2..f377d4882 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -197,6 +197,7 @@ namespace Slang }; /// Tracks an ordered list of modules that something depends on. + /// TODO: Shader caching currently relies on this being in well defined order. struct ModuleDependencyList { public: @@ -216,15 +217,16 @@ namespace Slang HashSet m_moduleSet; }; - /// Tracks an unordered list of filesystem paths that something depends on - struct FilePathDependencyList + /// Tracks an unordered list of source files that something depends on + /// TODO: Shader caching currently relies on this being in well defined order. + struct FileDependencyList { public: - /// Get the list of paths that are depended on. - List const& getFilePathList() { return m_filePathList; } + /// Get the list of files that are depended on. + List const& getFileList() { return m_fileList; } - /// Add a path to the list, if it is not already present - void addDependency(String const& path); + /// Add a file to the list, if it is not already present + void addDependency(SourceFile* sourceFile); /// Add all of the paths that `module` depends on to the list void addDependency(Module* module); @@ -236,11 +238,11 @@ namespace Slang // multiple times from `getFilePathList`, but because // order isn't important, we could potentially do better // in terms of memory (at some cost in performance) by - // just sorting the `m_filePathList` every once in + // just sorting the `m_fileList` every once in // a while and then deduplicating. - List m_filePathList; - HashSet m_filePathSet; + List m_fileList; + HashSet m_fileSet; }; class EntryPoint; @@ -292,13 +294,12 @@ namespace Slang slang::IBlob** outDiagnostics) SLANG_OVERRIDE; /// ComponentType is the only class inheriting from IComponentType that provides a - /// meaningful implementation for these two functions. All others should forward these - /// and implement updateDependencyBasedHash and updateASTBasedHash instead. - SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash( + /// meaningful implementation for this function. All others should forward these and + /// implement `buildHash`. + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( SlangInt entryPointIndex, SlangInt targetIndex, slang::IBlob** outHash) SLANG_OVERRIDE; - SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE; /// Get the linkage (aka "session" in the public API) for this component type. Linkage* getLinkage() { return m_linkage; } @@ -309,14 +310,7 @@ namespace Slang TargetProgram* getTargetProgram(TargetRequest* target); /// Update the hash builder with the dependencies for this component type. - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) = 0; - - /// Update the hash builder with the source contents for this component type. - /// Module should be the only derived ComponentType class which has a meaningful - /// implementation; all others should do nothing. - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) = 0; + virtual void buildHash(DigestBuilder& builder) = 0; /// Get the number of entry points linked into this component type. virtual Index getEntryPointCount() = 0; @@ -371,9 +365,9 @@ namespace Slang /// virtual List const& getModuleDependencies() = 0; - /// Get the full list of filesystem paths this component type depends on. + /// Get the full list of source files this component type depends on. /// - virtual List const& getFilePathDependencies() = 0; + virtual List const& getFileDependencies() = 0; /// Callback for use with `enumerateIRModules` typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData); @@ -515,11 +509,7 @@ namespace Slang Linkage* linkage, List> const& childComponents); - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override; + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; List> const& getChildComponents() { return m_childComponents; }; Index getChildComponentCount() { return m_childComponents.getCount(); } @@ -540,7 +530,7 @@ namespace Slang RefPtr getRequirement(Index index) SLANG_OVERRIDE; List const& getModuleDependencies() SLANG_OVERRIDE; - List const& getFilePathDependencies() SLANG_OVERRIDE; + List const& getFileDependencies() SLANG_OVERRIDE; class CompositeSpecializationInfo : public SpecializationInfo { @@ -584,7 +574,7 @@ namespace Slang List m_requirements; ModuleDependencyList m_moduleDependencyList; - FilePathDependencyList m_filePathDependencyList; + FileDependencyList m_fileDependencyList; }; /// A component type created by specializing another component type. @@ -597,14 +587,7 @@ namespace Slang List const& specializationArgs, DiagnosticSink* sink); - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override - { - SLANG_UNUSED(hashBuilder); - } + virtual void buildHash(DigestBuilder& builer) SLANG_OVERRIDE; /// Get the base (unspecialized) component type that is being specialized. RefPtr getBaseComponentType() { return m_base; } @@ -638,7 +621,7 @@ namespace Slang RefPtr getRequirement(Index index) SLANG_OVERRIDE; List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; } - List const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencies; } + List const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; } /// Get a list of tagged-union types referenced by the specialization parameters. List const& getTaggedUnionTypes() { return m_taggedUnionTypes; } @@ -673,7 +656,7 @@ namespace Slang List m_taggedUnionTypes; List m_moduleDependencies; - List m_filePathDependencies; + List m_fileDependencies; List> m_requirements; }; @@ -748,9 +731,9 @@ namespace Slang { return m_base->getModuleDependencies(); } - List const& getFilePathDependencies() SLANG_OVERRIDE + List const& getFileDependencies() SLANG_OVERRIDE { - return m_base->getFilePathDependencies(); + return m_base->getFileDependencies(); } SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE @@ -790,14 +773,7 @@ namespace Slang void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override - { - SLANG_UNUSED(hashBuilder); - } + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; private: RefPtr m_base; @@ -891,27 +867,15 @@ namespace Slang return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); } - SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash( + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( SlangInt entryPointIndex, SlangInt targetIndex, slang::IBlob** outHash) SLANG_OVERRIDE { - return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash); - } - - SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::computeContentsBasedHash(outHash); + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); } - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override - { - SLANG_UNUSED(hashBuilder); - } + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; /// Create an entry point that refers to the given function. static RefPtr create( @@ -948,7 +912,7 @@ namespace Slang /// but may also include modules that are required by its generic type arguments. /// List const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); } - List const& getFilePathDependencies() SLANG_OVERRIDE; // { return getModule()->getFilePathDependencies(); } + List const& getFileDependencies() SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); } /// Create a dummy `EntryPoint` that is only usable for pass-through compilation. static RefPtr createDummyForPassThrough( @@ -1118,30 +1082,18 @@ namespace Slang entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics); } - SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash( + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( SlangInt entryPointIndex, SlangInt targetIndex, slang::IBlob** outHash) SLANG_OVERRIDE { - return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash); - } - - SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::computeContentsBasedHash(outHash); + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); } - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override - { - SLANG_UNUSED(hashBuilder); - } + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; List const& getModuleDependencies() SLANG_OVERRIDE; - List const& getFilePathDependencies() SLANG_OVERRIDE; + List const& getFileDependencies() SLANG_OVERRIDE; SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; } @@ -1182,8 +1134,8 @@ namespace Slang DiagnosticSink* sink) SLANG_OVERRIDE; private: SubtypeWitness* m_subtypeWitness; - ModuleDependencyList m_moduleDependency; - FilePathDependencyList m_pathDependency; + ModuleDependencyList m_moduleDependencyList; + FileDependencyList m_fileDependencyList; List> m_requirements; HashSet m_requirementSet; RefPtr m_irModule; @@ -1314,24 +1266,15 @@ namespace Slang // - SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash( + SLANG_NO_THROW void SLANG_MCALL getEntryPointHash( SlangInt entryPointIndex, SlangInt targetIndex, slang::IBlob** outHash) SLANG_OVERRIDE { - return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash); - } - - SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE - { - return Super::computeContentsBasedHash(outHash); + return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash); } - virtual void updateDependencyBasedHash( - DigestBuilder& hashBuilder, - SlangInt entryPointIndex) override; - - virtual void updateContentsBasedHash(DigestBuilder& hashBuilder) override; + virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; /// Create a module (initially empty). Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr); @@ -1345,14 +1288,14 @@ namespace Slang /// Get the list of other modules this module depends on List const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); } - /// Get the list of filesystem paths this module depends on - List const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); } + /// Get the list of files this module depends on + List const& getFileDependencyList() { return m_fileDependencyList.getFileList(); } /// Register a module that this module depends on void addModuleDependency(Module* module); - /// Register a filesystem path that this module depends on - void addFilePathDependency(String const& path); + /// Register a source file that this module depends on + void addFileDependency(SourceFile* sourceFile); /// Set the AST for this module. /// @@ -1381,7 +1324,7 @@ namespace Slang RefPtr getRequirement(Index index) SLANG_OVERRIDE; List const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); } - List const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); } + List const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencyList.getFileList(); } /// Given a mangled name finds the exported NodeBase associated with this module. /// If not found returns nullptr. @@ -1443,8 +1386,8 @@ namespace Slang // List of modules this module depends on ModuleDependencyList m_moduleDependencyList; - // List of filesystem paths this module depends on - FilePathDependencyList m_filePathDependencyList; + // List of source files this module depends on + FileDependencyList m_fileDependencyList; // Entry points that were defined in thsi module // @@ -1474,9 +1417,6 @@ namespace Slang // and m_mangledExportSymbols holds the NodeBase* values for each index. StringSlicePool m_mangledExportPool; List m_mangledExportSymbols; - - MD5::Digest lastModifiedDigest; - MD5::Digest contentsDigest; }; typedef Module LoadedModule; @@ -1768,12 +1708,10 @@ namespace Slang SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest( SlangCompileRequest** outCompileRequest) override; - // Updates the supplied has builder with linkage-related information, which includes preprocessor + // Updates the supplied builder with linkage-related information, which includes preprocessor // defines, the compiler version, and other compiler options. This is then merged with the hash // produced for the program to produce a key that can be used with the shader cache. - void updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt targetIndex); + void buildHash(DigestBuilder& builder, SlangInt targetIndex); void addTarget( slang::TargetDesc const& desc); diff --git a/source/slang/slang-preprocessor.cpp b/source/slang/slang-preprocessor.cpp index fca8f5029..341e75ea4 100644 --- a/source/slang/slang-preprocessor.cpp +++ b/source/slang/slang-preprocessor.cpp @@ -32,9 +32,9 @@ void PreprocessorHandler::handleEndOfTranslationUnit(Preprocessor* preprocessor) SLANG_UNUSED(preprocessor); } -void PreprocessorHandler::handleFileDependency(String const& path) +void PreprocessorHandler::handleFileDependency(SourceFile* sourceFile) { - SLANG_UNUSED(path); + SLANG_UNUSED(sourceFile); } // In order to simplify the naming scheme, we will nest the implementaiton of the @@ -2966,15 +2966,6 @@ static SlangResult readFile( auto fileSystemExt = context->m_preprocessor->fileSystem; SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.getBuffer(), outBlob)); - // If we are running the preprocessor as part of compiling a - // specific module, then we must keep track of the file we've - // read as yet another file that the module will depend on. - // - if( auto handler = context->m_preprocessor->handler ) - { - handler->handleFileDependency(path); - } - return SLANG_OK; } @@ -3056,10 +3047,18 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context) } sourceFile = sourceManager->createSourceFileWithBlob(filePathInfo, foundSourceBlob); - sourceManager->addSourceFile(filePathInfo.uniqueIdentity, sourceFile); } + // If we are running the preprocessor as part of compiling a + // specific module, then we must keep track of the file we've + // read as yet another file that the module will depend on. + // + if (auto handler = context->m_preprocessor->handler) + { + handler->handleFileDependency(sourceFile); + } + // This is a new parse (even if it's a pre-existing source file), so create a new SourceView SourceView* sourceView = sourceManager->createSourceView(sourceFile, &filePathInfo, directiveLoc); diff --git a/source/slang/slang-preprocessor.h b/source/slang/slang-preprocessor.h index 4d7721d31..c37fd1607 100644 --- a/source/slang/slang-preprocessor.h +++ b/source/slang/slang-preprocessor.h @@ -28,7 +28,7 @@ using preprocessor::Preprocessor; struct PreprocessorHandler { virtual void handleEndOfTranslationUnit(Preprocessor* preprocessor); - virtual void handleFileDependency(String const& path); + virtual void handleFileDependency(SourceFile* sourceFile); }; /// Description of a preprocessor options/dependencies diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 3b2175fad..4f057bb21 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1320,9 +1320,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest( return SLANG_OK; } -void Linkage::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt targetIndex) +void Linkage::buildHash(DigestBuilder& builder, SlangInt targetIndex) { // Add the Slang compiler version to the hash auto version = String(getBuildTagString()); @@ -1346,29 +1344,39 @@ void Linkage::updateDependencyBasedHash( // Add the target specified by targetIndex auto targetReq = targets[targetIndex]; builder.append(targetReq->getTarget()); + builder.append(targetReq->getTargetProfile().raw); builder.append(targetReq->getTargetFlags()); builder.append(targetReq->getFloatingPointMode()); builder.append(targetReq->getLineDirectiveMode()); - builder.append(targetReq->shouldDumpIntermediates()); builder.append(targetReq->getForceGLSLScalarBufferLayout()); + builder.append(targetReq->getDefaultMatrixLayoutMode()); + builder.append(targetReq->shouldDumpIntermediates()); builder.append(targetReq->shouldTrackLiveness()); - auto targetProfile = targetReq->getTargetProfile(); - builder.append(targetProfile.getStage()); - builder.append(targetProfile.getVersion()); - builder.append(targetProfile.getFamily()); - - auto targetProfileName = String(targetProfile.getName()); - builder.append(targetProfileName); - auto cookedCapabilities = targetReq->getTargetCaps().getExpandedAtoms(); for (auto& capability : cookedCapabilities) { builder.append(capability); } + const PassThroughMode passThroughMode = getDownstreamCompilerRequiredForTarget(targetReq->getTarget()); + const SourceLanguage sourceLanguage = getDefaultSourceLanguageForDownstreamCompiler(passThroughMode); + + // Add prelude for the given downstream compiler. + ComPtr prelude; + getGlobalSession()->getLanguagePrelude((SlangSourceLanguage)sourceLanguage, prelude.writeRef()); + if (prelude) + { + builder.append(prelude); + } + + // TODO: Downstream compilers (specifically dxc) can currently #include additional dependencies. + // This is currently the case for NVAPI headers included in the prelude. + // These dependencies are currently not picked up by the shader cache which is a significant issue. + // This can only be fixed by running the preprocessor in the slang compiler so dxc (or any other + // downstream compiler for that matter) isn't resolving any includes implicitly. + // Add the downstream compiler version (if it exists) to the hash - auto passThroughMode = getDownstreamCompilerRequiredForTarget(targetReq->getTarget()); auto downstreamCompiler = getSessionImpl()->getOrLoadDownstreamCompiler(passThroughMode, nullptr); if (downstreamCompiler) { @@ -1674,25 +1682,7 @@ void TranslationUnitRequest::_addSourceFile(SourceFile* sourceFile) { m_sourceFiles.add(sourceFile); - // We want to record that the compiled module has a dependency - // on the path of the source file, but we also need to account - // for cases where the user added a source string/blob without - // an associated path and/or wasn't from a file. - - auto pathInfo = sourceFile->getPathInfo(); - if (pathInfo.hasFoundPath()) - { - getModule()->addFilePathDependency(pathInfo.foundPath); - } - else - { - // No path exists for this source, so we generate a new string to use as a - // fake path in the list of file path dependencies. This is needed to account - // for non-file-based dependencies later when shader files are being hashed for - // the shader cache. - auto sourceHash = MD5::compute(sourceFile->getContent().begin(), sourceFile->getContent().getLength()); - getModule()->addFilePathDependency(sourceHash.toString()); - } + getModule()->addFileDependency(sourceFile); } List const& TranslationUnitRequest::getSourceFiles() @@ -1896,9 +1886,9 @@ protected: // by applications to decide when they need to "hot reload" // their shader code. // - void handleFileDependency(String const& path) SLANG_OVERRIDE + void handleFileDependency(SourceFile* sourceFile) SLANG_OVERRIDE { - m_module->addFilePathDependency(path); + m_module->addFileDependency(sourceFile); } // The second task that this handler deals with is detecting @@ -3200,23 +3190,23 @@ void ModuleDependencyList::_addDependency(Module* module) } // -// FilePathDependencyList +// FileDependencyList // -void FilePathDependencyList::addDependency(String const& path) +void FileDependencyList::addDependency(SourceFile* sourceFile) { - if(m_filePathSet.Contains(path)) + if(m_fileSet.Contains(sourceFile)) return; - m_filePathList.add(path); - m_filePathSet.Add(path); + m_fileList.add(sourceFile); + m_fileSet.Add(sourceFile); } -void FilePathDependencyList::addDependency(Module* module) +void FileDependencyList::addDependency(Module* module) { - for(auto& path : module->getFilePathDependencyList()) + for(SourceFile* sourceFile : module->getFileDependencyList()) { - addDependency(path); + addDependency(sourceFile); } } @@ -3247,75 +3237,20 @@ ISlangUnknown* Module::getInterface(const Guid& guid) return Super::getInterface(guid); } -void Module::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) +void Module::buildHash(DigestBuilder& builder) { - // CompositeComponentType will have already hashed this Module's file - // dependencies. SLANG_UNUSED(builder); - SLANG_UNUSED(entryPointIndex); -} - -void Module::updateContentsBasedHash(DigestBuilder& builder) -{ - auto filePathDependencies = getFilePathDependencies(); - - DigestBuilder lastModifiedBuilder; - auto statFailed = false; - for (auto file : filePathDependencies) - { - struct stat fileStatus; - auto res = stat(file.getBuffer(), &fileStatus); - if (res != 0) - { - statFailed = true; - break; - } - lastModifiedBuilder.append(fileStatus.st_mtime); - } - - MD5::Digest temp = lastModifiedBuilder.finalize(); - if (statFailed || temp != lastModifiedDigest) - { - // Either a stat() call failed, or changes were made to at least one of the file dependencies, - // so we will need to re-generate the contents digest and save the new digest. - DigestBuilder contentsBuilder; - for (auto file : filePathDependencies) - { - List fileContents; - if (SLANG_FAILED(File::readAllBytes(file, fileContents))) - { - // Failure to read the file means this is a digest for the contents of a source - // file which does not live on disk. - contentsBuilder.append(file); - } - else - { - contentsBuilder.append(fileContents); - } - } - contentsDigest = contentsBuilder.finalize(); - if (!statFailed) - { - // If no stat() calls failed, then we have a valid last modified digest and should - // update the one we have saved. - lastModifiedDigest = temp; - } - } - - builder.append(contentsDigest); } void Module::addModuleDependency(Module* module) { m_moduleDependencyList.addDependency(module); - m_filePathDependencyList.addDependency(module); + m_fileDependencyList.addDependency(module); } -void Module::addFilePathDependency(String const& path) +void Module::addFileDependency(SourceFile* sourceFile) { - m_filePathDependencyList.addDependency(path); + m_fileDependencyList.addDependency(sourceFile); } void Module::setModuleDecl(ModuleDecl* moduleDecl) @@ -3505,12 +3440,12 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode( return artifact->loadBlob(ArtifactKeep::Yes, outCode); } -SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash( +SLANG_NO_THROW void SLANG_MCALL ComponentType::getEntryPointHash( SlangInt entryPointIndex, SlangInt targetIndex, slang::IBlob** outHash) { - DigestBuilder builder; + DigestBuilder builder; // A note on enums that may be hashed in as part of the following two function calls: // @@ -3518,19 +3453,19 @@ SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash( // the compiler, part of hashing the linkage is hashing in the compiler version. // Consequently, any encoding differences as a result of different compiler versions // will already be reflected in the resulting hash. - getLinkage()->updateDependencyBasedHash(builder, targetIndex); - updateDependencyBasedHash(builder, entryPointIndex); + getLinkage()->buildHash(builder, targetIndex); - // Add file path dependencies to the hash - all child components - // will have file path dependencies that are a subset of this list. - auto fileDeps = getFilePathDependencies(); - for (auto& file : fileDeps) + // Enumerate all file dependencies and add them to the hash. + for (SourceFile* sourceFile : getFileDependencies()) { - builder.append(file); + // TODO: We want to lazily evaluate & cache the source file digest + SHA1::Digest digest = SHA1::compute(sourceFile->getContent().begin(), sourceFile->getContent().getLength()); + builder.append(digest); } - // Add the name and name override for the specified entry point - // to the hash. + buildHash(builder); + + // Add the name and name override for the specified entry point to the hash. auto entryPointName = getEntryPoint(entryPointIndex)->getName()->text; builder.append(entryPointName); auto entryPointMangledName = getEntryPointMangledName(entryPointIndex); @@ -3542,14 +3477,6 @@ SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash( *outHash = hash.detach(); } -SLANG_NO_THROW void SLANG_MCALL ComponentType::computeContentsBasedHash(slang::IBlob** outHash) -{ - DigestBuilder builder; - updateContentsBasedHash(builder); - auto hash = builder.finalize().toBlob(); - *outHash = hash.detach(); -} - SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointHostCallable( int entryPointIndex, int targetIndex, @@ -3864,9 +3791,9 @@ CompositeComponentType::CompositeComponentType( { m_moduleDependencyList.addDependency(module); } - for(auto filePath : child->getFilePathDependencies()) + for(auto sourceFile : child->getFileDependencies()) { - m_filePathDependencyList.addDependency(filePath); + m_fileDependencyList.addDependency(sourceFile); } auto childRequirementCount = child->getRequirementCount(); @@ -3882,25 +3809,13 @@ CompositeComponentType::CompositeComponentType( } } -void CompositeComponentType::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) -{ - auto componentCount = getChildComponentCount(); - - for (Index i = 0; i < componentCount; ++i) - { - getChildComponent(i)->updateDependencyBasedHash(builder, entryPointIndex); - } -} - -void CompositeComponentType::updateContentsBasedHash(DigestBuilder& builder) +void CompositeComponentType::buildHash(DigestBuilder& builder) { auto componentCount = getChildComponentCount(); for (Index i = 0; i < componentCount; ++i) { - getChildComponent(i)->updateContentsBasedHash(builder); + getChildComponent(i)->buildHash(builder); } } @@ -3959,9 +3874,9 @@ List const& CompositeComponentType::getModuleDependencies() return m_moduleDependencyList.getModuleList(); } -List const& CompositeComponentType::getFilePathDependencies() +List const& CompositeComponentType::getFileDependencies() { - return m_filePathDependencyList.getFilePathList(); + return m_fileDependencyList.getFileList(); } void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) @@ -4226,7 +4141,7 @@ SpecializedComponentType::SpecializedComponentType( // The starting point for our lists comes from the base component type. // m_moduleDependencies = base->getModuleDependencies(); - m_filePathDependencies = base->getFilePathDependencies(); + m_fileDependencies = base->getFileDependencies(); Index baseRequirementCount = base->getRequirementCount(); for( Index r = 0; r < baseRequirementCount; r++ ) @@ -4238,11 +4153,11 @@ SpecializedComponentType::SpecializedComponentType( // dependencies and requirements based on the modules that // were collected when looking at the specialization arguments. - // We want to avoid adding the same file path dependency more than once. + // We want to avoid adding the same file dependency more than once. // - HashSet filePathDependencySet; - for(auto path : m_filePathDependencies) - filePathDependencySet.Add(path); + HashSet fileDependencySet; + for(SourceFile* sourceFile : m_fileDependencies) + fileDependencySet.Add(sourceFile); for(auto module : moduleCollector.m_modulesList) { @@ -4259,7 +4174,7 @@ SpecializedComponentType::SpecializedComponentType( m_requirements.add(module); // The speciialized component type will also have a dependency - // on all the file paths that any of the modules involved in + // on all the files that any of the modules involved in // it depend on (including those that are required but not // yet linked in). // @@ -4268,12 +4183,12 @@ SpecializedComponentType::SpecializedComponentType( // source files, so we want to include anything that could // affect the validity of generated code. // - for(auto path : module->getFilePathDependencies()) + for(SourceFile* sourceFile : module->getFileDependencies()) { - if(filePathDependencySet.Contains(path)) + if(fileDependencySet.Contains(sourceFile)) continue; - filePathDependencySet.Add(path); - m_filePathDependencies.add(path); + fileDependencySet.Add(sourceFile); + m_fileDependencies.add(sourceFile); } // Finalyl we also add the module for the specialization arguments @@ -4383,9 +4298,7 @@ SpecializedComponentType::SpecializedComponentType( collector.visitSpecialized(this); } -void SpecializedComponentType::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) +void SpecializedComponentType::buildHash(DigestBuilder& builder) { auto specializationArgCount = getSpecializationArgCount(); for (Index i = 0; i < specializationArgCount; ++i) @@ -4395,7 +4308,7 @@ void SpecializedComponentType::updateDependencyBasedHash( builder.append(argString); } - getBaseComponentType()->updateDependencyBasedHash(builder, entryPointIndex); + getBaseComponentType()->buildHash(builder); } void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) @@ -4442,13 +4355,8 @@ void RenamedEntryPointComponentType::acceptVisitor( this, as(specializationInfo)); } -void RenamedEntryPointComponentType::updateDependencyBasedHash( - DigestBuilder& builder, - SlangInt entryPointIndex) +void RenamedEntryPointComponentType::buildHash(DigestBuilder& builder) { - // CompositeComponentType will have already hashed the name override and file - // dependencies for this entry point. - SLANG_UNUSED(entryPointIndex); SLANG_UNUSED(builder); } @@ -5144,14 +5052,15 @@ int EndToEndCompileRequest::getDependencyFileCount() { auto frontEndReq = getFrontEndReq(); auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); - return (int)program->getFilePathDependencies().getCount(); + return (int)program->getFileDependencies().getCount(); } char const* EndToEndCompileRequest::getDependencyFilePath(int index) { auto frontEndReq = getFrontEndReq(); auto program = frontEndReq->getGlobalAndEntryPointsComponentType(); - return program->getFilePathDependencies()[index].begin(); + SourceFile* sourceFile = program->getFileDependencies()[index]; + return sourceFile->getPathInfo().hasFileFoundPath() ? sourceFile->getPathInfo().foundPath.getBuffer() : "unknown"; } int EndToEndCompileRequest::getTranslationUnitCount() diff --git a/tools/gfx-unit-test/gfx-test-util.cpp b/tools/gfx-unit-test/gfx-test-util.cpp index 116b4222a..5cbb30e71 100644 --- a/tools/gfx-unit-test/gfx-test-util.cpp +++ b/tools/gfx-unit-test/gfx-test-util.cpp @@ -194,6 +194,7 @@ namespace gfx_test Slang::ComPtr createTestingDevice( UnitTestContext* context, Slang::RenderApiFlag::Enum api, + Slang::List additionalSearchPaths, gfx::IDevice::ShaderCacheDesc shaderCache) { Slang::ComPtr device; @@ -222,10 +223,13 @@ namespace gfx_test SLANG_IGNORE_TEST } deviceDesc.slang.slangGlobalSession = context->slangGlobalSession; - const char* searchPaths[] = { "", "../../tools/gfx-unit-test", "tools/gfx-unit-test" }; - deviceDesc.slang.searchPathCount = (SlangInt)SLANG_COUNT_OF(searchPaths); - deviceDesc.slang.searchPaths = searchPaths; - + Slang::List searchPaths; + searchPaths.add(""); + searchPaths.add("../../tools/gfx-unit-test"); + searchPaths.add("tools/gfx-unit-test"); + searchPaths.addRange(additionalSearchPaths); + deviceDesc.slang.searchPaths = searchPaths.getBuffer(); + deviceDesc.slang.searchPathCount = (gfx::GfxCount)searchPaths.getCount(); deviceDesc.shaderCache = shaderCache; gfx::D3D12DeviceExtendedDesc extDesc = {}; diff --git a/tools/gfx-unit-test/gfx-test-util.h b/tools/gfx-unit-test/gfx-test-util.h index d11d5623c..f829d6d12 100644 --- a/tools/gfx-unit-test/gfx-test-util.h +++ b/tools/gfx-unit-test/gfx-test-util.h @@ -77,6 +77,7 @@ namespace gfx_test Slang::ComPtr createTestingDevice( UnitTestContext* context, Slang::RenderApiFlag::Enum api, + Slang::List additionalSearchPaths = {}, gfx::IDevice::ShaderCacheDesc shaderCache = {}); void initializeRenderDoc(); @@ -88,13 +89,14 @@ namespace gfx_test const ImplFunc& f, UnitTestContext* context, Slang::RenderApiFlag::Enum api, + Slang::List searchPaths = {}, gfx::IDevice::ShaderCacheDesc shaderCache = {}) { if ((api & context->enabledApis) == 0) { SLANG_IGNORE_TEST } - auto device = createTestingDevice(context, api, shaderCache); + auto device = createTestingDevice(context, api, searchPaths, shaderCache); if (!device) { SLANG_IGNORE_TEST diff --git a/tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang b/tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang deleted file mode 100644 index 9287b62ea..000000000 --- a/tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang +++ /dev/null @@ -1,28 +0,0 @@ -uniform RWStructuredBuffer buffer; - -[shader("compute")] -[numthreads(4, 1, 1)] -void computeA( -uint3 sv_dispatchThreadID : SV_DispatchThreadID) -{ - var input = buffer[sv_dispatchThreadID.x]; - buffer[sv_dispatchThreadID.x] = input + 1.0f; -} - -[shader("compute")] -[numthreads(4, 1, 1)] -void computeB( -uint3 sv_dispatchThreadID : SV_DispatchThreadID) -{ - var input = buffer[sv_dispatchThreadID.x]; - buffer[sv_dispatchThreadID.x] = input + 2.0f; -} - -[shader("compute")] -[numthreads(4, 1, 1)] -void computeC( -uint3 sv_dispatchThreadID : SV_DispatchThreadID) -{ - var input = buffer[sv_dispatchThreadID.x]; - buffer[sv_dispatchThreadID.x] = input + 3.0f; -} diff --git a/tools/gfx-unit-test/shader-cache-graphics-fragment.slang b/tools/gfx-unit-test/shader-cache-graphics-fragment.slang new file mode 100644 index 000000000..392aa15ba --- /dev/null +++ b/tools/gfx-unit-test/shader-cache-graphics-fragment.slang @@ -0,0 +1,24 @@ +// shader-cache-graphics-fragment.slang + +// Output of the vertex shader, and input to the fragment shader. +struct CoarseVertex +{ + float3 color; +}; + +// Output of the fragment shader +struct Fragment +{ + float4 color; +}; + +// Fragment Shader + +[shader("fragment")] +float4 main( + CoarseVertex coarseVertex : CoarseVertex) : SV_Target +{ + float3 color = coarseVertex.color; + + return float4(color, 1.0); +} diff --git a/tools/gfx-unit-test/shader-cache-graphics-vertex.slang b/tools/gfx-unit-test/shader-cache-graphics-vertex.slang new file mode 100644 index 000000000..a86f8bcf1 --- /dev/null +++ b/tools/gfx-unit-test/shader-cache-graphics-vertex.slang @@ -0,0 +1,36 @@ +// shader-cache-graphics-vertex.slang + +// Per-vertex attributes to be assembled from bound vertex buffers. +struct AssembledVertex +{ + float3 position : POSITION; +}; + +// Output of the vertex shader, and input to the fragment shader. +struct CoarseVertex +{ + float3 color; +}; + +// Vertex Shader + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +[shader("vertex")] +VertexStageOutput main( + AssembledVertex assembledVertex) +{ + VertexStageOutput output; + + float3 position = assembledVertex.position; + float3 color = float3(1.0, 0.0, 0.0); + + output.coarseVertex.color = color; + output.sv_position = float4(position, 1.0); + + return output; +} diff --git a/tools/gfx-unit-test/shader-cache-multiple-entry-points.slang b/tools/gfx-unit-test/shader-cache-multiple-entry-points.slang new file mode 100644 index 000000000..a0015b83c --- /dev/null +++ b/tools/gfx-unit-test/shader-cache-multiple-entry-points.slang @@ -0,0 +1,31 @@ +// shader-cache-multiple-entry-points.slang + +[shader("compute")] +[numthreads(4, 1, 1)] +void computeA( + uint3 sv_dispatchThreadID: SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = input + 1.0f; +} + +[shader("compute")] +[numthreads(4, 1, 1)] +void computeB( + uint3 sv_dispatchThreadID: SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = input + 2.0f; +} + +[shader("compute")] +[numthreads(4, 1, 1)] +void computeC( + uint3 sv_dispatchThreadID: SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = input + 3.0f; +} diff --git a/tools/gfx-unit-test/shader-cache-specialization.slang b/tools/gfx-unit-test/shader-cache-specialization.slang new file mode 100644 index 000000000..63994aee8 --- /dev/null +++ b/tools/gfx-unit-test/shader-cache-specialization.slang @@ -0,0 +1,68 @@ +// shader-cache-specialization.slang + +// This is a copy of `shader-object.slang` in `shader-object` example +// for use by compute-smoke gfx unit test. + +// This file implements a simple compute shader that transforms +// input floating point numbers stored in a `RWStructuredBuffer`. +// Specifically, for each number x from input buffer, compute +// f(x) and store the result back in the same buffer. + +// The compute shader supports multiple transformation functions, +// such add(x, c) which returns x+c, or mul(x, c) which returns x*c. +// This functions are implemented as types that conforms to the +// `ITransformer` interface. + +// The main entry point function takes a parameter of `ITransformer` +// type, and applies the transformation to numbers in the input +// buffer. By defining the shader parameter using interfaces, +// we enable the flexiblity to generate either specialized compute +// kernels that performs specific transformation or a general +// kernel that can perform any transformations encoded by the +// parameter at run-time, without changing any shader code or +// host-application logic for setting and preparing shader parameters. + +// Defines the transformer interface, which implements a single +// `transform` operation. +interface ITransformer +{ + float transform(float x); +} + +// Represents a transform function f(x) = x + c. +struct AddTransformer : ITransformer +{ + float c; + float transform(float x) { return x + c + 10.0f; } +}; + +// Represents a transform function f(x) = x * c. +struct MulTransformer : ITransformer +{ + float c; + float transform(float x) { return x * c; } +}; + +// Represents a composite function f(x) = f0(f1(x)); +struct CompositeTransformer : ITransformer +{ + ITransformer func0; + ITransformer func1; + float transform(float x) + { + return func0.transform(func1.transform(x)); + } +}; + +// Main entry-point. Applies the transformation encoded by `transformer` +// to all elements in `buffer`. +[shader("compute")] +[numthreads(4,1,1)] +void computeMain( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer, + uniform ITransformer transformer) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = transformer.transform(input); +} diff --git a/tools/gfx-unit-test/shader-cache-tests.cpp b/tools/gfx-unit-test/shader-cache-tests.cpp index 486b59cda..4cccc726f 100644 --- a/tools/gfx-unit-test/shader-cache-tests.cpp +++ b/tools/gfx-unit-test/shader-cache-tests.cpp @@ -5,8 +5,7 @@ #include "tools/gfx-util/shader-cursor.h" #include "source/core/slang-basic.h" #include "source/core/slang-string-util.h" - -#include "source/core/slang-memory-file-system.h" +#include "source/core/slang-io.h" #include "source/core/slang-file-system.h" #include "gfx-test-texture-util.h" @@ -17,69 +16,136 @@ using namespace Slang; namespace gfx_test { - struct BaseShaderCacheTest + struct ShaderCacheTest { UnitTestContext* context; - RenderApiFlag::Enum api; + Slang::RenderApiFlag::Enum api; + + String testDirectory; + String cacheDirectory; + + ComPtr diskFileSystem; + + IDevice::ShaderCacheDesc shaderCacheDesc = {}; ComPtr device; - ComPtr shaderCacheStats; + ComPtr shaderCache; ComPtr pipelineState; ComPtr bufferView; - IDevice::ShaderCacheDesc shaderCache = {}; - - // Two file systems in order to get around problems posed by the testing framework. - // - // - diskFileSystem - Used to save any files that must exist on disk for subsequent - // save/load function calls (most prominently loadComputeProgram()) to pick up. - // This is also used to test the file stream implementation for the cache. - // - memoryFileSystem - Used to test the fallback path for the cache in the case physical - // file paths cannot be obtained, which prevents usage of file streams. - ComPtr diskFileSystem; - ComPtr memoryFileSystem; - - // Simple compute shaders we can pipe to our individual shader files for cache testing - String contentsA = String( + String computeShaderA = String( R"( - uniform RWStructuredBuffer buffer; - [shader("compute")] [numthreads(4, 1, 1)] - void computeMain( - uint3 sv_dispatchThreadID : SV_DispatchThreadID) + void main( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) { var input = buffer[sv_dispatchThreadID.x]; buffer[sv_dispatchThreadID.x] = input + 1.0f; - })"); + } + )"); - String contentsB = String( + String computeShaderB = String( R"( - uniform RWStructuredBuffer buffer; - [shader("compute")] [numthreads(4, 1, 1)] - void computeMain( - uint3 sv_dispatchThreadID : SV_DispatchThreadID) + void main( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) { var input = buffer[sv_dispatchThreadID.x]; buffer[sv_dispatchThreadID.x] = input + 2.0f; - })"); + } + )"); - String contentsC = String( + String computeShaderC = String( R"( - uniform RWStructuredBuffer buffer; - [shader("compute")] [numthreads(4, 1, 1)] - void computeMain( - uint3 sv_dispatchThreadID : SV_DispatchThreadID) + void main( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) { var input = buffer[sv_dispatchThreadID.x]; buffer[sv_dispatchThreadID.x] = input + 3.0f; - })"); + } + )"); + + + void removeDirectory(const String& directory) + { + auto osFileSystem = OSFileSystem::getMutableSingleton(); + + struct Context + { + ISlangMutableFileSystem *fileSystem; + const String& directory; + } context { osFileSystem, directory }; + + osFileSystem->enumeratePathContents( + directory.getBuffer(), + [](SlangPathType pathType, const char* fileName, void* userData) + { + struct Context* context = static_cast(userData); + if (pathType == SlangPathType::SLANG_PATH_TYPE_FILE) + { + String path = Path::simplify(context->directory + "/" + fileName); + context->fileSystem->remove(path.getBuffer()); + } + }, + &context); + + osFileSystem->remove(directory.getBuffer()); + } + + void writeShader(const String& source, const String& fileName) + { + diskFileSystem->saveFile(fileName.getBuffer(), source.getBuffer(), source.getLength()); + } + + void init(UnitTestContext* context, Slang::RenderApiFlag::Enum api) + { + this->context = context; + this->api = api; + + testDirectory = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/shader-cache-test"); + cacheDirectory = Path::simplify(testDirectory + "/cache"); + + // Cleanup if there are stale files from a previously aborted test. + removeDirectory(cacheDirectory); + removeDirectory(testDirectory); + + Path::createDirectory(testDirectory); + diskFileSystem = new RelativeFileSystem(OSFileSystem::getMutableSingleton(), testDirectory); + shaderCacheDesc.shaderCachePath = cacheDirectory.getBuffer(); + } + + void cleanup() + { + removeDirectory(cacheDirectory); + removeDirectory(testDirectory); + } + + template + void runStep(Func func) + { + List additionalSearchPaths; + additionalSearchPaths.add(testDirectory.getBuffer()); + + runTestImpl( + [this, func] (IDevice* device, UnitTestContext* ctx) + { + this->device = device; + device->queryInterface(SLANG_UUID_IShaderCache, (void**)this->shaderCache.writeRef()); + func(); + this->device = nullptr; + this->shaderCache = nullptr; + }, + context, api, additionalSearchPaths, shaderCacheDesc); + } - void createRequiredResources() + void createComputeResources() { const int numberCount = 4; float initialData[] = { 0.0f, 1.0f, 2.0f, 3.0f }; @@ -108,57 +174,25 @@ namespace gfx_test device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef())); } - void freeOldResources() + void freeComputeResources() { bufferView = nullptr; pipelineState = nullptr; - device = nullptr; - shaderCacheStats = nullptr; } - // TODO: This should be removed at some point. Currently exists as a workaround for module loading - // seemingly not accounting for updated shader code under the same module name with the same entry point. - void generateNewDevice() + void createComputePipeline(const char* moduleName, const char* entryPointName) { - freeOldResources(); - device = createTestingDevice(context, api, shaderCache); - } - - void init(ComPtr device, UnitTestContext* context) - { - this->device = device; - this->context = context; - switch (device->getDeviceInfo().deviceType) - { - case DeviceType::DirectX11: - api = RenderApiFlag::D3D11; - break; - case DeviceType::DirectX12: - api = RenderApiFlag::D3D12; - break; - case DeviceType::Vulkan: - api = RenderApiFlag::Vulkan; - break; - case DeviceType::CPU: - api = RenderApiFlag::CPU; - break; - case DeviceType::CUDA: - api = RenderApiFlag::CUDA; - break; - case DeviceType::OpenGl: - api = RenderApiFlag::OpenGl; - break; - default: - SLANG_IGNORE_TEST - } + ComPtr shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, moduleName, entryPointName, slangReflection)); - memoryFileSystem = new MemoryFileSystem(); - diskFileSystem = OSFileSystem::getMutableSingleton(); - diskFileSystem->createDirectory("tools/gfx-unit-test/shader-cache-test"); - diskFileSystem = new RelativeFileSystem(diskFileSystem, "tools/gfx-unit-test/shader-cache-test"); + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); } - void submitGPUWork() + void dispatchComputePipeline() { ComPtr transientHeap; ITransientResourceHeap::Desc transientHeapDesc = {}; @@ -174,437 +208,284 @@ namespace gfx_test auto rootObject = encoder->bindPipeline(pipelineState); - ShaderCursor rootCursor(rootObject); + ShaderCursor entryPointCursor(rootObject->getEntryPoint(0)); + entryPointCursor.getPath("buffer").setResource(bufferView); + + // ShaderCursor rootCursor(rootObject); // Bind buffer view to the entry point. - rootCursor.getPath("buffer").setResource(bufferView); + // rootCursor.getPath("buffer").setResource(bufferView); - encoder->dispatchCompute(1, 1, 1); + encoder->dispatchCompute(4, 1, 1); encoder->endEncoding(); commandBuffer->close(); queue->executeCommandBuffer(commandBuffer); queue->waitOnHost(); - } + } - void cleanUpFiles() + void runComputePipeline(const char* moduleName, const char* entryPointName) { - freeOldResources(); - - List filePaths; - diskFileSystem->enumeratePathContents( - ".", - [](SlangPathType pathType, const char* name, void* userData) - { - if (pathType == SlangPathType::SLANG_PATH_TYPE_FILE) - { - List& out = *(List*)userData; - out.add(String(name)); - } - }, - &filePaths); + createComputeResources(); + createComputePipeline(moduleName, entryPointName); + dispatchComputePipeline(); + freeComputeResources(); + } - for (auto file : filePaths) - { - diskFileSystem->remove(file.getBuffer()); - } - // Get a mutable singleton so we can delete the folder. - auto fileSystem = OSFileSystem::getMutableSingleton(); - fileSystem->remove("tools/gfx-unit-test/shader-cache-test"); + ShaderCacheStats getStats() + { + SLANG_ASSERT(shaderCache); + ShaderCacheStats stats; + shaderCache->getShaderCacheStats(&stats); + return stats; } - void run() + void run(UnitTestContext* context, Slang::RenderApiFlag::Enum api) { - shaderCache.shaderCacheFileSystem = diskFileSystem; - runTests(); - shaderCache.shaderCacheFileSystem = memoryFileSystem; + init(context, api); runTests(); - - cleanUpFiles(); + cleanup(); } virtual void runTests() = 0; }; - // Due to needing a workaround to prevent loading old, outdated modules, we need to - // recreate the device between each segment of the test for all tests. However, we need to maintain the - // same cache filesystem for the same duration, so the device is immediately recreated - // to ensure we can pass the filesystem all the way through. - // - // General TODO: Remove the repeated generateNewDevice() and createRequiredResources() calls once - // a solution exists that allows source code changes under the same module name to be picked - // up on load. - - // One shader file on disk, all modifications are done to the same file - struct SingleEntryShaderCache : BaseShaderCacheTest + // Basic shader cache test using 3 different shader files stored on disk. + struct ShaderCacheTestBasic : ShaderCacheTest { - void generateNewPipelineState(Slang::String shaderContents) + void runTests() { - diskFileSystem->saveFile("test-tmp-single-entry.slang", shaderContents.getBuffer(), shaderContents.getLength()); + // Write shader source files. + writeShader(computeShaderA, "shader-cache-tmp-a.slang"); + writeShader(computeShaderB, "shader-cache-tmp-b.slang"); + writeShader(computeShaderC, "shader-cache-tmp-c.slang"); - ComPtr shaderProgram; - slang::ProgramLayout* slangReflection; - GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "shader-cache-test/test-tmp-single-entry", "computeMain", slangReflection)); + // Cache is cold and we expect 3 misses. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-a", "main"); + runComputePipeline("shader-cache-tmp-b", "main"); + runComputePipeline("shader-cache-tmp-c", "main"); - ComputePipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - GFX_CHECK_CALL_ABORT( - device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); - } + SLANG_CHECK(getStats().missCount == 3); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 3); + } + ); - void runTests() - { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsA); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsA); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsC); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1); - } - }; + // Cache is hot and we expect 3 hits. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-a", "main"); + runComputePipeline("shader-cache-tmp-b", "main"); + runComputePipeline("shader-cache-tmp-c", "main"); - // Several shader files on disk, modifications may be done to any file - struct MultipleEntryShaderCache : BaseShaderCacheTest - { - void modifyShaderA(String shaderContents) - { - diskFileSystem->saveFile("test-tmp-multi-entry-A.slang", shaderContents.getBuffer(), shaderContents.getLength()); - } + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 3); + SLANG_CHECK(getStats().entryCount == 3); + } + ); - void modifyShaderB(String shaderContents) - { - diskFileSystem->saveFile("test-tmp-multi-entry-B.slang", shaderContents.getBuffer(), shaderContents.getLength()); - } + // Write shader source files, all rotated by one. + writeShader(computeShaderA, "shader-cache-tmp-b.slang"); + writeShader(computeShaderB, "shader-cache-tmp-c.slang"); + writeShader(computeShaderC, "shader-cache-tmp-a.slang"); - void modifyShaderC(String shaderContents) - { - diskFileSystem->saveFile("test-tmp-multi-entry-C.slang", shaderContents.getBuffer(), shaderContents.getLength()); - } + // Cache is cold again and we expect 3 misses. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-a", "main"); + runComputePipeline("shader-cache-tmp-b", "main"); + runComputePipeline("shader-cache-tmp-c", "main"); - void generateNewPipelineState(GfxIndex shaderIndex) - { - ComPtr shaderProgram; - slang::ProgramLayout* slangReflection; - const char* shaderFilename; - switch (shaderIndex) - { - case 0: - shaderFilename = "shader-cache-test/test-tmp-multi-entry-A"; - break; - case 1: - shaderFilename = "shader-cache-test/test-tmp-multi-entry-B"; - break; - case 2: - shaderFilename = "shader-cache-test/test-tmp-multi-entry-C"; - break; - default: - // Should never reach this point since we wrote the test - SLANG_IGNORE_TEST; - } - GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, shaderFilename, "computeMain", slangReflection)); - - ComputePipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - GFX_CHECK_CALL_ABORT( - device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); - } + SLANG_CHECK(getStats().missCount == 3); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 6); + } + ); - void checkAllCacheEntries() - { - generateNewPipelineState(0); - submitGPUWork(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(2); - submitGPUWork(); - } + // Cache is hot again and we expect 3 hits. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-a", "main"); + runComputePipeline("shader-cache-tmp-b", "main"); + runComputePipeline("shader-cache-tmp-c", "main"); - void runTests() - { - generateNewDevice(); - createRequiredResources(); - modifyShaderA(contentsA); - modifyShaderB(contentsB); - modifyShaderC(contentsC); - checkAllCacheEntries(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - checkAllCacheEntries(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 3); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - modifyShaderA(contentsB); - checkAllCacheEntries(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1); - - generateNewDevice(); - createRequiredResources(); - modifyShaderA(contentsC); - modifyShaderB(contentsA); - modifyShaderC(contentsB); - checkAllCacheEntries(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 3); + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 3); + SLANG_CHECK(getStats().entryCount == 6); + } + ); } }; - // One shader file on disk containing several entry points, no modifications are made to the file - struct MultipleEntryPointShader : BaseShaderCacheTest + // Test one shader file on disk with multiple entry points. + struct ShaderCacheTestEntryPoint : ShaderCacheTest { - void generateNewPipelineState(GfxIndex shaderIndex) + void runTests() { - ComPtr shaderProgram; - slang::ProgramLayout* slangReflection; - const char* entryPointName; - switch (shaderIndex) - { - case 0: - entryPointName = "computeA"; - break; - case 1: - entryPointName = "computeB"; - break; - case 2: - entryPointName = "computeC"; - break; - default: - // Should never reach this point since we wrote the test - SLANG_IGNORE_TEST; - } - GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "multiple-entry-point-shader-cache-shader", entryPointName, slangReflection)); + // Cache is cold and we expect 3 misses, one for each entry point. + runStep( + [this]() + { + runComputePipeline("shader-cache-multiple-entry-points", "computeA"); + runComputePipeline("shader-cache-multiple-entry-points", "computeB"); + runComputePipeline("shader-cache-multiple-entry-points", "computeC"); - ComputePipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - GFX_CHECK_CALL_ABORT( - device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); - } + SLANG_CHECK(getStats().missCount == 3); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 3); + } + ); - void runTests() - { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(0); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(0); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(2); - submitGPUWork(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(0); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); + // Cache is hot and we expect 3 hits. + runStep( + [this]() + { + runComputePipeline("shader-cache-multiple-entry-points", "computeA"); + runComputePipeline("shader-cache-multiple-entry-points", "computeB"); + runComputePipeline("shader-cache-multiple-entry-points", "computeC"); + + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 3); + SLANG_CHECK(getStats().entryCount == 3); + } + ); } }; - // One shader file contains an import/include, direct code modifications are made to the imported file - // This test specifically checks four cases: - // 1. import w/o changes in the imported file - // 2. import w/ changes in the imported file - // 3. #include w/o changes in the included file (the included file is the same as the imported file in the prior step) - // 4. #include w/ changes in the included file - struct ShaderFileImportsShaderCache : BaseShaderCacheTest + // Test cache invalidation due to an import/include file being changed on disk. + struct ShaderCacheTestImportInclude : ShaderCacheTest { String importedContentsA = String( R"( - struct TestFunction + void processElement(RWStructuredBuffer buffer, uint index) { - void simpleElementAdd(RWStructuredBuffer buffer, uint index) - { - var input = buffer[index]; - buffer[index] = input + 1.0f; - } - };)"); + var input = buffer[index]; + buffer[index] = input + 1.0f; + } + )"); String importedContentsB = String( R"( - struct TestFunction + void processElement(RWStructuredBuffer buffer, uint index) { - void simpleElementAdd(RWStructuredBuffer buffer, uint index) - { - var input = buffer[index]; - buffer[index] = input + 2.0f; - } - };)"); + var input = buffer[index]; + buffer[index] = input + 2.0f; + } + )"); String importFile = String( R"( - import test_tmp_imported; - - uniform RWStructuredBuffer buffer; + import shader_cache_tmp_imported; [shader("compute")] [numthreads(4, 1, 1)] - void computeMain( - uint3 sv_dispatchThreadID : SV_DispatchThreadID) + void main( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) { - TestFunction test; - for (uint i = 0; i < 4; ++i) - { - test.simpleElementAdd(buffer, i); - } - })"); + processElement(buffer, sv_dispatchThreadID.x); + } + )"); String includeFile = String( R"( - #include "test-tmp-imported.slang" + #include "shader-cache-tmp-imported.slang" - uniform RWStructuredBuffer buffer; - [shader("compute")] [numthreads(4, 1, 1)] - void computeMain( - uint3 sv_dispatchThreadID : SV_DispatchThreadID) + void main( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer buffer) { - TestFunction test; - for (uint i = 0; i < 4; ++i) - { - test.simpleElementAdd(buffer, i); - } + processElement(buffer, sv_dispatchThreadID.x); })"); - void initializeFiles() + void runTests() { - diskFileSystem->saveFile("test-tmp-imported.slang", importedContentsA.getBuffer(), importedContentsA.getLength()); - diskFileSystem->saveFile("test-tmp-importing.slang", importFile.getBuffer(), importFile.getLength()); - } + // Write shader source files. + writeShader(importedContentsA, "shader-cache-tmp-imported.slang"); + writeShader(importFile, "shader-cache-tmp-import.slang"); + writeShader(includeFile, "shader-cache-tmp-include.slang"); - void modifyImportedFile(String importedContents) - { - diskFileSystem->saveFile("test-tmp-imported.slang", importedContents.getBuffer(), importedContents.getLength()); - } + // Cache is cold and we expect 2 misses. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-import", "main"); + runComputePipeline("shader-cache-tmp-include", "main"); - void changeImportToInclude() - { - diskFileSystem->saveFile("test-tmp-importing.slang", includeFile.getBuffer(), includeFile.getLength()); - } + SLANG_CHECK(getStats().missCount == 2); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 2); + } + ); - void generateNewPipelineState() - { - ComPtr shaderProgram; - slang::ProgramLayout* slangReflection; - GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "shader-cache-test/test-tmp-importing", "computeMain", slangReflection)); + // Cache is hot and we expect 2 hits. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-import", "main"); + runComputePipeline("shader-cache-tmp-include", "main"); - ComputePipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - GFX_CHECK_CALL_ABORT( - device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); - } + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 2); + SLANG_CHECK(getStats().entryCount == 2); + } + ); - void runTests() - { - generateNewDevice(); - createRequiredResources(); - initializeFiles(); - generateNewPipelineState(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - modifyImportedFile(importedContentsB); - generateNewPipelineState(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1); - - generateNewDevice(); - createRequiredResources(); - changeImportToInclude(); - generateNewPipelineState(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1); - - generateNewDevice(); - createRequiredResources(); - modifyImportedFile(importedContentsA); - generateNewPipelineState(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1); + // Change content of imported/included shader file. + writeShader(importedContentsB, "shader-cache-tmp-imported.slang"); + + // Cache is cold and we expect 2 misses. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-import", "main"); + runComputePipeline("shader-cache-tmp-include", "main"); + + SLANG_CHECK(getStats().missCount == 2); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 4); + } + ); + + // Cache is hot and we expect 2 hits. + runStep( + [this]() + { + runComputePipeline("shader-cache-tmp-import", "main"); + runComputePipeline("shader-cache-tmp-include", "main"); + + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 2); + SLANG_CHECK(getStats().entryCount == 4); + } + ); } }; // One shader featuring multiple kinds of shader objects that can be bound. - struct SpecializationArgsEntries : BaseShaderCacheTest + struct ShaderCacheTestSpecialization : ShaderCacheTest { slang::ProgramLayout* slangReflection; + void createComputePipeline() + { + ComPtr shaderProgram; + + GFX_CHECK_CALL_ABORT( + loadComputeProgram(device, shaderProgram, "shader-cache-specialization", "computeMain", slangReflection)); + + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + } + void createAddTransformer(IShaderObject** transformer) { slang::TypeReflection* addTransformerType = @@ -627,7 +508,7 @@ namespace gfx_test ShaderCursor(*transformer).getPath("c").setData(&c, sizeof(float)); } - void submitGPUWork(GfxIndex transformerType) + void dispatchComputePipeline(const char* transformerTypeName) { Slang::ComPtr transientHeap; ITransientResourceHeap::Desc transientHeapDesc = {}; @@ -643,23 +524,16 @@ namespace gfx_test auto rootObject = encoder->bindPipeline(pipelineState); - ComPtr transformer; - switch (transformerType) - { - case 0: - createAddTransformer(transformer.writeRef()); - break; - case 1: - createMulTransformer(transformer.writeRef()); - break; - default: - /* Should not get here */ - SLANG_IGNORE_TEST; - } + Slang::ComPtr transformer; + slang::TypeReflection* transformerType = slangReflection->findTypeByName(transformerTypeName); + GFX_CHECK_CALL_ABORT(device->createShaderObject( + transformerType, ShaderObjectContainerType::None, transformer.writeRef())); + + float c = 1.0f; + ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float)); ShaderCursor entryPointCursor(rootObject->getEntryPoint(0)); entryPointCursor.getPath("buffer").setResource(bufferView); - entryPointCursor.getPath("transformer").setObject(transformer); encoder->dispatchCompute(1, 1, 1); @@ -669,78 +543,78 @@ namespace gfx_test queue->waitOnHost(); } - void generateNewPipelineState() + void runComputePipeline(const char* transformerTypeName) { - ComPtr shaderProgram; - - GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "compute-smoke", "computeMain", slangReflection)); - - ComputePipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - GFX_CHECK_CALL_ABORT( - device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); - } + createComputeResources(); + createComputePipeline(); + dispatchComputePipeline(transformerTypeName); + freeComputeResources(); + } void runTests() { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(); - submitGPUWork(0); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(); - submitGPUWork(1); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - } - }; + // Cache is cold and we expect 2 misses. + runStep( + [this]() + { + runComputePipeline("AddTransformer"); + runComputePipeline("MulTransformer"); - // Same gist as the multiple entry point compute shader but with a graphics - // shader file containing a vertex and fragment shader - struct Vertex - { - float position[3]; - }; + SLANG_CHECK(getStats().missCount == 2); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 2); + } + ); - static const int kVertexCount = 3; - static const Vertex kVertexData[kVertexCount] = - { - { 0, 0, 0.5 }, - { 1, 0, 0.5 }, - { 0, 1, 0.5 }, + // Cache is hot and we expect 2 hits. + runStep( + [this]() + { + runComputePipeline("AddTransformer"); + runComputePipeline("MulTransformer"); + + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 2); + SLANG_CHECK(getStats().entryCount == 2); + } + ); + } }; - struct GraphicsShaderCache : BaseShaderCacheTest + // Same gist as the multiple entry point compute shader but with a graphics + // shader file containing a vertex and fragment shader. + struct ShaderCacheTestGraphics : ShaderCacheTest { - const int kWidth = 256; - const int kHeight = 256; - const Format format = Format::R32G32B32A32_FLOAT; + struct Vertex + { + float position[3]; + }; - ComPtr shaderProgram; - ComPtr renderPass; - ComPtr framebuffer; + static const int kWidth = 256; + static const int kHeight = 256; + static const Format format = Format::R32G32B32A32_FLOAT; ComPtr vertexBuffer; ComPtr colorBuffer; + ComPtr inputLayout; + ComPtr framebufferLayout; + ComPtr renderPass; + ComPtr framebuffer; ComPtr createVertexBuffer(IDevice* device) { + const Vertex vertices[] = { + { 0, 0, 0.5 }, + { 1, 0, 0.5 }, + { 0, 1, 0.5 }, + }; + IBufferResource::Desc vertexBufferDesc; vertexBufferDesc.type = IResource::Type::Buffer; - vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex); + vertexBufferDesc.sizeInBytes = sizeof(vertices); vertexBufferDesc.defaultState = ResourceState::VertexBuffer; vertexBufferDesc.allowedStates = ResourceState::VertexBuffer; - ComPtr vertexBuffer = device->createBufferResource(vertexBufferDesc, &kVertexData[0]); + ComPtr vertexBuffer = device->createBufferResource(vertexBufferDesc, vertices); SLANG_CHECK_ABORT(vertexBuffer != nullptr); return vertexBuffer; } @@ -761,13 +635,7 @@ namespace gfx_test return colorBuffer; } - void createShaderProgram() - { - slang::ProgramLayout* slangReflection; - GFX_CHECK_CALL_ABORT(loadGraphicsProgram(device, shaderProgram, "shader-cache-graphics", "vertexMain", "fragmentMain", slangReflection)); - } - - void createRequiredResources() + void createGraphicsResources() { VertexStreamDesc vertexStreams[] = { { sizeof(Vertex), InputSlotClass::PerVertex, 0 }, @@ -782,7 +650,7 @@ namespace gfx_test inputLayoutDesc.inputElements = inputElements; inputLayoutDesc.vertexStreamCount = SLANG_COUNT_OF(vertexStreams); inputLayoutDesc.vertexStreams = vertexStreams; - auto inputLayout = device->createInputLayout(inputLayoutDesc); + inputLayout = device->createInputLayout(inputLayoutDesc); SLANG_CHECK_ABORT(inputLayout != nullptr); vertexBuffer = createVertexBuffer(device); @@ -795,18 +663,9 @@ namespace gfx_test IFramebufferLayout::Desc framebufferLayoutDesc; framebufferLayoutDesc.renderTargetCount = 1; framebufferLayoutDesc.renderTargets = &targetLayout; - ComPtr framebufferLayout = device->createFramebufferLayout(framebufferLayoutDesc); + framebufferLayout = device->createFramebufferLayout(framebufferLayoutDesc); SLANG_CHECK_ABORT(framebufferLayout != nullptr); - GraphicsPipelineStateDesc pipelineDesc = {}; - pipelineDesc.program = shaderProgram.get(); - pipelineDesc.inputLayout = inputLayout; - pipelineDesc.framebufferLayout = framebufferLayout; - pipelineDesc.depthStencil.depthTestEnable = false; - pipelineDesc.depthStencil.depthWriteEnable = false; - GFX_CHECK_CALL_ABORT( - device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef())); - IRenderPassLayout::Desc renderPassDesc = {}; renderPassDesc.framebufferLayout = framebufferLayout; renderPassDesc.renderTargetCount = 1; @@ -833,7 +692,35 @@ namespace gfx_test GFX_CHECK_CALL_ABORT(device->createFramebuffer(framebufferDesc, framebuffer.writeRef())); } - void submitGPUWork() + void freeGraphicsResources() + { + inputLayout = nullptr; + framebufferLayout = nullptr; + renderPass = nullptr; + framebuffer = nullptr; + vertexBuffer = nullptr; + colorBuffer = nullptr; + pipelineState = nullptr; + } + + void createGraphicsPipeline() + { + ComPtr shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT( + loadGraphicsProgram(device, shaderProgram, "shader-cache-graphics", "vertexMain", "fragmentMain", slangReflection)); + + GraphicsPipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + pipelineDesc.inputLayout = inputLayout; + pipelineDesc.framebufferLayout = framebufferLayout; + pipelineDesc.depthStencil.depthTestEnable = false; + pipelineDesc.depthStencil.depthWriteEnable = false; + GFX_CHECK_CALL_ABORT( + device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef())); + } + + void dispatchGraphicsPipeline() { ComPtr transientHeap; ITransientResourceHeap::Desc transientHeapDesc = {}; @@ -857,28 +744,50 @@ namespace gfx_test encoder->setVertexBuffer(0, vertexBuffer); encoder->setPrimitiveTopology(PrimitiveTopology::TriangleList); - encoder->draw(kVertexCount); + encoder->draw(3); encoder->endEncoding(); commandBuffer->close(); queue->executeCommandBuffer(commandBuffer); queue->waitOnHost(); } + void runGraphicsPipeline() + { + createGraphicsResources(); + createGraphicsPipeline(); + dispatchGraphicsPipeline(); + freeGraphicsResources(); + } + void runTests() { - generateNewDevice(); - createShaderProgram(); - createRequiredResources(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 2); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); + // Cache is cold and we expect 2 misses (2 entry points). + runStep( + [this]() + { + runGraphicsPipeline(); + + SLANG_CHECK(getStats().missCount == 2); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 2); + } + ); + + // Cache is hot and we expect 2 hits. + runStep( + [this]() + { + runGraphicsPipeline(); + + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 2); + SLANG_CHECK(getStats().entryCount == 2); + } + ); } }; - // Same as GraphicsShaderCache, but instead of having a singular file containing both a vertex and fragment shader, we + // Same as ShaderCacheTestGraphics, but instead of having a singular file containing both a vertex and fragment shader, we // now have two separate shader files, one containing the vertex shader and the other the fragment with the same // names, with the expectation that we should record cache misses for both fetches. // @@ -890,54 +799,38 @@ namespace gfx_test // // We do not actively test geometry shaders here, but it is simply an extension of this test and should be expected // to behave similarly. - struct SplitGraphicsShader : GraphicsShaderCache + struct ShaderCacheTestGraphicsSplit : ShaderCacheTestGraphics { - void createShaderProgram() - { - slang::ProgramLayout* slangReflection; - const char* moduleNames[] = { "split-graphics-vertex", "split-graphics-fragment" }; - GFX_CHECK_CALL_ABORT(loadSplitGraphicsProgram(device, shaderProgram, moduleNames, "main", "main", slangReflection)); - } - - Result loadSplitGraphicsProgram( - IDevice* device, - ComPtr& outShaderProgram, - const char** shaderModuleNames, - const char* vertexEntryPointName, - const char* fragmentEntryPointName, - slang::ProgramLayout*& slangReflection) + void createGraphicsPipeline() { ComPtr slangSession; - SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef())); + GFX_CHECK_CALL_ABORT(device->getSlangSession(slangSession.writeRef())); - ComPtr diagnosticsBlob; - slang::IModule* vertexModule = slangSession->loadModule(shaderModuleNames[0], diagnosticsBlob.writeRef()); - if (!vertexModule) - return SLANG_FAIL; - slang::IModule* fragmentModule = slangSession->loadModule(shaderModuleNames[1], diagnosticsBlob.writeRef()); - if (!fragmentModule) - return SLANG_FAIL; + slang::IModule* vertexModule = slangSession->loadModule("shader-cache-graphics-vertex"); + SLANG_CHECK_ABORT(vertexModule); + slang::IModule* fragmentModule = slangSession->loadModule("shader-cache-graphics-fragment"); + SLANG_CHECK_ABORT(fragmentModule); ComPtr vertexEntryPoint; - SLANG_RETURN_ON_FAIL( - vertexModule->findEntryPointByName(vertexEntryPointName, vertexEntryPoint.writeRef())); + GFX_CHECK_CALL_ABORT( + vertexModule->findEntryPointByName("main", vertexEntryPoint.writeRef())); ComPtr fragmentEntryPoint; - SLANG_RETURN_ON_FAIL( - fragmentModule->findEntryPointByName(fragmentEntryPointName, fragmentEntryPoint.writeRef())); + GFX_CHECK_CALL_ABORT( + fragmentModule->findEntryPointByName("main", fragmentEntryPoint.writeRef())); Slang::List componentTypes; componentTypes.add(vertexModule); componentTypes.add(fragmentModule); Slang::ComPtr composedProgram; - SlangResult result = slangSession->createCompositeComponentType( - componentTypes.getBuffer(), - componentTypes.getCount(), - composedProgram.writeRef(), - diagnosticsBlob.writeRef()); - SLANG_RETURN_ON_FAIL(result); - slangReflection = composedProgram->getLayout(); + GFX_CHECK_CALL_ABORT( + slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef())); + + slang::ProgramLayout* slangReflection = composedProgram->getLayout(); Slang::List entryPoints; entryPoints.add(vertexEntryPoint); @@ -949,263 +842,60 @@ namespace gfx_test programDesc.entryPointCount = 2; programDesc.slangEntryPoints = entryPoints.getBuffer(); - auto shaderProgram = device->createProgram(programDesc); - - outShaderProgram = shaderProgram; - return SLANG_OK; - } - - void runTests() - { - generateNewDevice(); - createShaderProgram(); - createRequiredResources(); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 2); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - } - }; - - // Same as MultipleEntryShaderCache, but we now set the maximum entry count limit, so the cache - // should remove entries as needed when it reaches capacity. - // - // This test does not modify shaders as other tests already test this, instead focusing on checking - // that entries are correctly removed as cache limits are reached and that entries are always in - // the right order. - // - // As opening multiple streams to the same file is dependent on the OS, this test is run on the - // in-memory file system. Cache eviction policy with an on-disk file system will need to be inspected - // manually. - struct CacheWithMaxEntryLimit : MultipleEntryShaderCache - { - List test0Lines; // C -> B -> A - List test1Lines; // C -> B - List test2Lines; // A -> B - List test3Lines; // A -> C - List test4Lines; // C -> B -> A - List entryKeys; // C, B, A - - void getCacheFile(List& lines) - { - ComPtr contentsBlob; - memoryFileSystem->loadFile(shaderCache.cacheFilename, contentsBlob.writeRef()); - List temp; - StringUtil::calcLines(UnownedStringSlice((char*)contentsBlob->getBufferPointer()), temp); - for (auto line : temp) - { - if (line.trim().getLength() != 0) - lines.add(line); - } - } - - // Check the correctness of the cache's entries by comparing the order of entries in the - // current state of the cache with what we expect. - void checkCacheFiles() - { - // Check that shader A appears where we expect it to. - SLANG_CHECK(test2Lines[0] == test3Lines[0]); - SLANG_CHECK(test2Lines[0] == test4Lines[2]); - - // Check that shader B appears where we expect it to. - SLANG_CHECK(test1Lines[1] == test2Lines[1]); - SLANG_CHECK(test1Lines[1] == test4Lines[1]); - - // Check that shader C appears where we expect it to. - SLANG_CHECK(test1Lines[0] == test3Lines[1]); - SLANG_CHECK(test1Lines[0] == test4Lines[0]); - } - - // Cache limit 3, three unique shaders - void runTest0() - { - shaderCache.entryCountLimit = 3; - generateNewDevice(); - createRequiredResources(); - modifyShaderA(contentsA); - modifyShaderB(contentsB); - modifyShaderC(contentsC); - generateNewPipelineState(0); - submitGPUWork(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(2); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - // This needs to be called in order to force the cache file to be updated, otherwise we will - // be unable to perform the necessary checks. - freeOldResources(); - - getCacheFile(test0Lines); - SLANG_CHECK(test0Lines.getCount() == 3); - - // This segment also doubles as the point where we fetch the keys for all three shaders - // to use in later checks. - for (auto line : test0Lines) - { - List digests; - StringUtil::split(line.getUnownedSlice(), ' ', digests); - if (digests.getCount() != 2) - continue; - entryKeys.add(digests[0]); - } - - ComPtr unused; - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef()))); - } - - // Cache limit 2, access shaders A then B then C - void runTest1() - { - shaderCache.entryCountLimit = 2; - generateNewDevice(); - createRequiredResources(); - modifyShaderA(contentsA); - modifyShaderB(contentsB); - modifyShaderC(contentsC); - generateNewPipelineState(0); - submitGPUWork(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(2); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - freeOldResources(); - - getCacheFile(test1Lines); - SLANG_CHECK(test1Lines.getCount() == 2); - - ComPtr unused; - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef()))); - } - - // Cache limit 2, access shaders B and then A - void runTest2() - { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(0); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - freeOldResources(); - - getCacheFile(test2Lines); - SLANG_CHECK(test2Lines.getCount() == 2); - - ComPtr unused; - SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef()))); - } + ComPtr shaderProgram = device->createProgram(programDesc); - // Cache limit 2, access shaders C and then A - void runTest3() - { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(2); - submitGPUWork(); - generateNewPipelineState(0); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - freeOldResources(); - - getCacheFile(test3Lines); - SLANG_CHECK(test3Lines.getCount() == 2); - - ComPtr unused; - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef()))); + GraphicsPipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + pipelineDesc.inputLayout = inputLayout; + pipelineDesc.framebufferLayout = framebufferLayout; + pipelineDesc.depthStencil.depthTestEnable = false; + pipelineDesc.depthStencil.depthWriteEnable = false; + GFX_CHECK_CALL_ABORT( + device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef())); } - // Cache limit 3, access shaders A then B then C - void runTest4() + void runGraphicsPipeline() { - shaderCache.entryCountLimit = 3; - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(0); - submitGPUWork(); - generateNewPipelineState(1); - submitGPUWork(); - generateNewPipelineState(2); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - freeOldResources(); - - getCacheFile(test4Lines); - SLANG_CHECK(test4Lines.getCount() == 3); - - ComPtr unused; - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef()))); - SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef()))); - } + createGraphicsResources(); + createGraphicsPipeline(); + dispatchGraphicsPipeline(); + freeGraphicsResources(); + } void runTests() { - runTest0(); - runTest1(); - runTest2(); - runTest3(); - runTest4(); + // Cache is cold and we expect 2 misses (2 entry points). + runStep( + [this]() + { + runGraphicsPipeline(); - checkCacheFiles(); - } + SLANG_CHECK(getStats().missCount == 2); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 2); + } + ); - void run() - { - shaderCache.shaderCacheFileSystem = memoryFileSystem; - runTests(); + // Cache is hot and we expect 2 hits. + runStep( + [this]() + { + runGraphicsPipeline(); - cleanUpFiles(); - } + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 2); + SLANG_CHECK(getStats().entryCount == 2); + } + ); } }; - // This test is specifically for source files which live entirely in memory. The key difference between - // these and physical source files is such files have their contents hash added to the file dependencies - // list instead of a file path, meaning any given specific set of shader contents will be treated as a - // wholly unique module. - struct NonPhysicalFileDependencyEntry : BaseShaderCacheTest + // Test caching of shaders that are compiled from source strings instead of files. + struct ShaderCacheTestSourceString : ShaderCacheTest { - void generateNewPipelineState(Slang::String shaderContents) + void createComputePipeline(Slang::String shaderSource) { ComPtr shaderProgram; - GFX_CHECK_CALL_ABORT(loadComputeProgramFromSource(device, shaderProgram, shaderContents)); + GFX_CHECK_CALL_ABORT(loadComputeProgramFromSource(device, shaderProgram, shaderSource)); ComputePipelineStateDesc pipelineDesc = {}; pipelineDesc.program = shaderProgram.get(); @@ -1213,135 +903,120 @@ namespace gfx_test device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); } - void runTests() + void runComputePipeline(Slang::String shaderSource) { - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsA); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsA); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); - - generateNewDevice(); - createRequiredResources(); - generateNewPipelineState(contentsC); - submitGPUWork(); - - device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef()); - SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1); - SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0); - SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0); + createComputeResources(); + createComputePipeline(shaderSource); + dispatchComputePipeline(); + freeComputeResources(); } - }; - template - void shaderCacheTestImpl(ComPtr device, UnitTestContext* context) - { - T test; - test.init(device, context); - test.run(); - } + void runTests() + { + // Cache is cold and we expect 3 misses. + runStep( + [this]() + { + runComputePipeline(computeShaderA); + runComputePipeline(computeShaderB); + runComputePipeline(computeShaderC); - SLANG_UNIT_TEST(singleEntryShaderCacheD3D12) - { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); - } + SLANG_CHECK(getStats().missCount == 3); + SLANG_CHECK(getStats().hitCount == 0); + SLANG_CHECK(getStats().entryCount == 3); + } + ); - SLANG_UNIT_TEST(singleEntryShaderCacheVulkan) - { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); - } + // Cache is hot and we expect 3 hits. + runStep( + [this]() + { + runComputePipeline(computeShaderA); + runComputePipeline(computeShaderB); + runComputePipeline(computeShaderC); - SLANG_UNIT_TEST(multipleEntryShaderCacheD3D12) - { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); - } + SLANG_CHECK(getStats().missCount == 0); + SLANG_CHECK(getStats().hitCount == 3); + SLANG_CHECK(getStats().entryCount == 3); + } + ); + } + }; - SLANG_UNIT_TEST(multipleEntryShaderCacheVulkan) + template + void runTest(UnitTestContext* context, Slang::RenderApiFlag::Enum api) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + T test; + test.run(context, api); } - SLANG_UNIT_TEST(multipleEntryPointShaderCacheD3D12) + SLANG_UNIT_TEST(shaderCacheBasicD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(multipleEntryPointShaderCacheVulkan) + SLANG_UNIT_TEST(shaderCacheBasicVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(shaderFileImportsShaderCacheD3D12) + SLANG_UNIT_TEST(shaderCacheEntryPointD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(shaderFileImportsShaderCacheVulkan) + SLANG_UNIT_TEST(shaderCacheEntryPointVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(specializationArgsShaderCacheD3D12) + SLANG_UNIT_TEST(shaderCacheImportIncludeD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(specializationArgsShaderCacheVulkan) + SLANG_UNIT_TEST(shaderCacheImportIncludeVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(cacheEvictionPolicyD3D12) + SLANG_UNIT_TEST(shaderCacheSpecializationD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(cacheEvictionPolicyVulkan) + SLANG_UNIT_TEST(shaderCacheSpecializationVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(graphicsShaderCacheD3D12) + SLANG_UNIT_TEST(shaderCacheGraphicsD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(graphicsShaderCacheVulkan) + SLANG_UNIT_TEST(shaderCacheGraphicsVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(splitGraphicsShaderCacheD3D12) + SLANG_UNIT_TEST(shaderCacheGraphicsSplitD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(splitGraphicsShaderCacheVulkan) + SLANG_UNIT_TEST(shaderCacheGraphicsSplitVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } - SLANG_UNIT_TEST(nonPhysicalFileDependenciesCacheEntryD3D12) + SLANG_UNIT_TEST(shaderCacheSourceStringD3D12) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + runTest(unitTestContext, Slang::RenderApiFlag::D3D12); } - SLANG_UNIT_TEST(nonPhysicalFileDependenciesCacheEntryVulkan) + SLANG_UNIT_TEST(shaderCacheSourceStringVulkan) { - runTestImpl(shaderCacheTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + runTest(unitTestContext, Slang::RenderApiFlag::Vulkan); } } diff --git a/tools/gfx-unit-test/split-graphics-fragment.slang b/tools/gfx-unit-test/split-graphics-fragment.slang deleted file mode 100644 index db515a957..000000000 --- a/tools/gfx-unit-test/split-graphics-fragment.slang +++ /dev/null @@ -1,24 +0,0 @@ -// split-graphics-fragment.slang - -// Output of the vertex shader, and input to the fragment shader. -struct CoarseVertex -{ - float3 color; -}; - -// Output of the fragment shader -struct Fragment -{ - float4 color; -}; - -// Fragment Shader - -[shader("fragment")] -float4 main( - CoarseVertex coarseVertex : CoarseVertex) : SV_Target -{ - float3 color = coarseVertex.color; - - return float4(color, 1.0); -} diff --git a/tools/gfx-unit-test/split-graphics-vertex.slang b/tools/gfx-unit-test/split-graphics-vertex.slang deleted file mode 100644 index 615686a90..000000000 --- a/tools/gfx-unit-test/split-graphics-vertex.slang +++ /dev/null @@ -1,36 +0,0 @@ -// split-graphics-vertex.slang - -// Per-vertex attributes to be assembled from bound vertex buffers. -struct AssembledVertex -{ - float3 position : POSITION; -}; - -// Output of the vertex shader, and input to the fragment shader. -struct CoarseVertex -{ - float3 color; -}; - -// Vertex Shader - -struct VertexStageOutput -{ - CoarseVertex coarseVertex : CoarseVertex; - float4 sv_position : SV_Position; -}; - -[shader("vertex")] -VertexStageOutput main( - AssembledVertex assembledVertex) -{ - VertexStageOutput output; - - float3 position = assembledVertex.position; - float3 color = float3(1.0, 0.0, 0.0); - - output.coarseVertex.color = color; - output.sv_position = float4(position, 1.0); - - return output; -} diff --git a/tools/gfx/gfx.slang b/tools/gfx/gfx.slang index b4bd76470..3d75e3b40 100644 --- a/tools/gfx/gfx.slang +++ b/tools/gfx/gfx.slang @@ -1712,15 +1712,10 @@ struct SlangDesc struct ShaderCacheDesc { - // The filename for the file the cache's state should be saved to or loaded from. - NativeString cacheFilename = "cache.txt"; - // The root directory for the shader cache. + // The root directory for the shader cache. If not set, shader cache is disabled. NativeString shaderCachePath; - // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object, - // instead the user is responsible for holding the object alive during the lifetime of an `IDevice`. - void* shaderCacheFileSystem = nullptr; // The maximum number of entries stored in the cache. - GfxCount entryCountLimit = 0; + GfxCount maxEntryCount = 0; }; struct DeviceInteropHandles @@ -1934,6 +1929,21 @@ interface IDevice Result getTextureRowAlignment(out Size outAlignment); }; +struct ShaderCacheStats +{ + GfxCount hitCount; + GfxCount missCount; + GfxCount entryCount; +}; + +[COM("715bdf26-5135-11eb-AE93-02-42-AC-13-00-02")] +interface IShaderCache +{ + Result clearShaderCache(); + Result getShaderCacheStats(out ShaderCacheStats outStats); + Result resetShaderCacheStats(); +}; + #define SLANG_GFX_IMPORT [DllImport("gfx")] /// Checks if format is compressed SLANG_GFX_IMPORT bool gfxIsCompressedFormat(Format format); diff --git a/tools/gfx/persistent-shader-cache.cpp b/tools/gfx/persistent-shader-cache.cpp deleted file mode 100644 index 7dc64632b..000000000 --- a/tools/gfx/persistent-shader-cache.cpp +++ /dev/null @@ -1,316 +0,0 @@ -// slang-shader-cache-index.cpp -#include "persistent-shader-cache.h" - -#include "../../source/core/slang-io.h" -#include "../../source/core/slang-string-util.h" -#include "../../source/core/slang-file-system.h" - -#include "../../source/core/slang-char-util.h" - -#include - -namespace gfx -{ - -using namespace std::chrono; - -PersistentShaderCache::PersistentShaderCache(const IDevice::ShaderCacheDesc& inDesc) -{ - desc = inDesc; - - // If a path is provided, we will want our underlying file system to be initialized using that path. - if (desc.shaderCachePath) - { - if (!desc.shaderCacheFileSystem) - { - // Only a path was provided, so we get a mutable file system - // using OSFileSystem::getMutableSingleton. - desc.shaderCacheFileSystem = OSFileSystem::getMutableSingleton(); - } - desc.shaderCacheFileSystem = new RelativeFileSystem(desc.shaderCacheFileSystem, desc.shaderCachePath); - } - - // If our shader cache has an underlying file system, check if it's mutable. If so, store a pointer - // to the mutable version for operations which require writing to disk. - if (desc.shaderCacheFileSystem) - { - desc.shaderCacheFileSystem->queryInterface(ISlangMutableFileSystem::getTypeGuid(), (void**)mutableShaderCacheFileSystem.writeRef()); - } - - loadCacheFromFile(); -} - -PersistentShaderCache::~PersistentShaderCache() -{ - if (isMemoryFileSystem) - { - saveCacheToMemory(); - } -} - -// Load a previous cache index saved to disk. If not found, create a new cache index -// and save it to disk as filename. -void PersistentShaderCache::loadCacheFromFile() -{ - // We will need to combine the filename with the cache path in order to have the correct - // file path for initializing the stream. This needs to be done separately because there - // is no guarantee that the underlying file system is mutable. - String filePath; - if (mutableShaderCacheFileSystem) - { - ComPtr fullPath; - if (SLANG_FAILED(mutableShaderCacheFileSystem->getPath(PathKind::OperatingSystem, desc.cacheFilename, fullPath.writeRef()))) - { - // If we fail to obtain a physical file path, then this must be a MemoryFileSystem. In this case, file streams - // will not work as they require the file to be on disk, so we will rely on a fall back implementation. - isMemoryFileSystem = true; - loadCacheFromMemory(); - return; - } - filePath = String((char*)fullPath->getBufferPointer()); - } - else - { - filePath = Path::combine(String(desc.shaderCachePath), String(desc.cacheFilename)); - } - - if (SLANG_FAILED(indexStream.init(filePath, FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite))) - { - // If we failed to open a stream to the file, then the file does not yet exist on disk. - // We will create the index file if our underlying file system is mutable. - if (mutableShaderCacheFileSystem) - { - indexStream.init(filePath, FileMode::Create, FileAccess::ReadWrite, FileShare::ReadWrite); - } - return; - } - else - { - const auto start = indexStream.getPosition(); - indexStream.seek(SeekOrigin::End, 0); - const auto end = indexStream.getPosition(); - indexStream.seek(SeekOrigin::Start, 0); - const Index numEntries = (Index)(end - start) / sizeof(ShaderCacheEntry); - - if (desc.entryCountLimit > 0 && numEntries > desc.entryCountLimit) - { - // If the size limit for the current cache is smaller than the cache that produced the file we're trying to - // load, re-create the entire file. - // - // FileStream does not currently have any methods for truncating an existing file, so in this case, our cache - // index would no longer accurately reflect the state of our cache due to the extra now-garbage lines present. - // While this has no impact on cache operation, it could be problematic for debugging purposes, etc. - indexStream.close(); - indexStream.init(filePath, FileMode::Create, FileAccess::ReadWrite, FileShare::ReadWrite); - return; - } - else - { - // The cache index is not guaranteed to be ordered by most recent access, so we need a temporary list to store - // all the entries in order to sort them before filling in our linked list. - List tempEntries; - tempEntries.setCount(numEntries); - size_t bytesRead; - indexStream.read(tempEntries.getBuffer(), sizeof(ShaderCacheEntry) * numEntries, bytesRead); - - // We will need to sort tempEntries by last accessed time before we can add entries to our linked list. - tempEntries.quickSort(tempEntries.getBuffer(), 0, tempEntries.getCount() - 1, [](ShaderCacheEntry a, ShaderCacheEntry b) { return a.lastAccessedTime > b.lastAccessedTime; }); - for (auto& entry : tempEntries) - { - // If we reach this point, then the current cache is at least the same size in entries as the cache - // that produced the index we're reading in, so we don't need to check if we're exceeding capacity. - auto entryIndexNode = orderedEntries.AddLast(entries.getCount()); - entries.add(entry); - keyToEntry.Add(entry.dependencyBasedDigest, entryIndexNode); - } - } - } -} - -ShaderCacheEntry* PersistentShaderCache::findEntry(const DigestType& key, ISlangBlob** outCompiledCode) -{ - LinkedNode* entryIndexNode; - if (!keyToEntry.TryGetValue(key, entryIndexNode)) - { - // The key was not found in the cache, so we return nullptr. - *outCompiledCode = nullptr; - return nullptr; - } - - // If the key is found, load the stored contents from disk. We then move the corresponding - // entry to the front of the linked list and update the cache file on disk - desc.shaderCacheFileSystem->loadFile(key.toString().getBuffer(), outCompiledCode); - auto index = entryIndexNode->Value; - entries[index].lastAccessedTime = (double)high_resolution_clock::now().time_since_epoch().count(); - if (orderedEntries.FirstNode() != entryIndexNode) - { - orderedEntries.RemoveFromList(entryIndexNode); - orderedEntries.AddFirst(entryIndexNode); - if (mutableShaderCacheFileSystem && !isMemoryFileSystem) - { - auto offset = index * sizeof(ShaderCacheEntry); - indexStream.seek(SeekOrigin::Start, offset + 2 * sizeof(DigestType)); - indexStream.write(&entries[index].lastAccessedTime, sizeof(double)); - indexStream.flush(); - } - } - return &entries[index]; -} - -void PersistentShaderCache::addEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* compiledCode) -{ - if (!mutableShaderCacheFileSystem) - { - // Should not save new entries if the underlying file system isn't mutable. - return; - } - - // Check that we do not exceed the cache's size limit by adding another entry. If so, - // remove the least recently used entry first. - // - // In theory, the cache could be more than just one entry over the entry count limit. - // However, this is impossible in practice because we fully re-create the entry list - // and cache index file if the size of the current cache is smaller than the cache - // that generated the index file we loaded. In any case, the initial number of entries - // in the cache will always be fewer than the size limit and this check will be hit - // on the first entry added that exceeds the cache's size. - Index index = entries.getCount(); - if (desc.entryCountLimit > 0 && orderedEntries.Count() >= desc.entryCountLimit) - { - index = deleteLRUEntry(); - } - - auto lastAccessedTime = (double)high_resolution_clock::now().time_since_epoch().count(); - - ShaderCacheEntry entry = { dependencyDigest, contentsDigest, lastAccessedTime }; - auto entryNode = orderedEntries.AddFirst(index); - if (index == entries.getCount()) - { - // No entries were removed, so we can tack this entry on at the end. - entries.add(entry); - } - else - { - // An entry was deleted, so we overwrite that slot with the new entry. - entries[index] = entry; - } - keyToEntry.Add(dependencyDigest, entryNode); - - mutableShaderCacheFileSystem->saveFileBlob(dependencyDigest.toString().getBuffer(), compiledCode); - - if (!isMemoryFileSystem) - { - indexStream.seek(SeekOrigin::End, 0); - indexStream.write(&entry, sizeof(ShaderCacheEntry)); - indexStream.flush(); - } -} - -void PersistentShaderCache::updateEntry( - const DigestType& dependencyDigest, - const DigestType& contentsDigest, - ISlangBlob* updatedCode) -{ - if (!mutableShaderCacheFileSystem) - { - // Updating entries requires saving to disk in order to overwrite the old shader file - // on disk, so we return if the underlying file system isn't mutable. - return; - } - - // Unlike in addEntry(), we only update the contents digest here because the last accessed time will have already - // been updated while finding the entry. - auto entryIndexNode = *keyToEntry.TryGetValue(dependencyDigest); - auto index = entryIndexNode->Value; - entries[index].contentsBasedDigest = contentsDigest; - mutableShaderCacheFileSystem->saveFileBlob(dependencyDigest.toString().getBuffer(), updatedCode); - - if (!isMemoryFileSystem) - { - auto offset = index * sizeof(ShaderCacheEntry); - indexStream.seek(SeekOrigin::Start, offset + sizeof(DigestType)); - indexStream.write(&contentsDigest, sizeof(DigestType)); - indexStream.flush(); - } -} - -Index PersistentShaderCache::deleteLRUEntry() -{ - if (!mutableShaderCacheFileSystem) - { - // This is here as a safety precaution but should never be hit as - // addEntry() and its memory-based equivalent are the only functions - // that should call this. - return -1; - } - - auto lruEntry = orderedEntries.LastNode(); - auto index = lruEntry->Value; - auto shaderKey = entries[index].dependencyBasedDigest; - - keyToEntry.Remove(shaderKey); - mutableShaderCacheFileSystem->remove(shaderKey.toString().getBuffer()); - - orderedEntries.Delete(lruEntry); - return index; -} - -// An in-memory file system cannot utilize file streaming to update the index file in place. -// Consequently, the cache index file is updated once on exit and is guaranteed to maintain the -// correct order of entries from most to least recently used. However, any kind of interruption -// in program execution that results in the cache destructor not being called will result in an -// inaccurate cache index. -// -// These currently assume that the underlying file system must be a MemoryFileSystem as this is the -// only in-memory file system that currently exists in Slang, which is guaranteed to be mutable. -// Mutability checks will need to be added if this changes in the future. -void PersistentShaderCache::loadCacheFromMemory() -{ - ComPtr indexBlob; - if (SLANG_FAILED(mutableShaderCacheFileSystem->loadFile(desc.cacheFilename, indexBlob.writeRef()))) - { - mutableShaderCacheFileSystem->saveFile(desc.cacheFilename, nullptr, 0); - return; - } - - auto indexString = UnownedStringSlice((char*)indexBlob->getBufferPointer()); - - List lines; - StringUtil::calcLines(indexString, lines); - for (auto line : lines) - { - List entryFields; - StringUtil::split(line, ' ', entryFields); - if (entryFields.getCount() != 2) - continue; - - ShaderCacheEntry entry; - entry.dependencyBasedDigest = DigestType(entryFields[0]); - entry.contentsBasedDigest = DigestType(entryFields[1]); - entry.lastAccessedTime = 0; - - auto entryNode = orderedEntries.AddLast(entries.getCount()); - entries.add(entry); - keyToEntry.Add(entry.dependencyBasedDigest, entryNode); - - if (desc.entryCountLimit > 0 && orderedEntries.Count() == desc.entryCountLimit) - break; - } -} - -void PersistentShaderCache::saveCacheToMemory() -{ - StringBuilder indexSb; - for (auto& entryIndex : orderedEntries) - { - auto entry = entries[entryIndex]; - indexSb << entry.dependencyBasedDigest.toString(); - indexSb << " "; - indexSb << entry.contentsBasedDigest.toString(); - indexSb << "\n"; - } - - mutableShaderCacheFileSystem->saveFile(desc.cacheFilename, indexSb.getBuffer(), indexSb.getLength()); -} - -} diff --git a/tools/gfx/persistent-shader-cache.h b/tools/gfx/persistent-shader-cache.h deleted file mode 100644 index 530d50a58..000000000 --- a/tools/gfx/persistent-shader-cache.h +++ /dev/null @@ -1,99 +0,0 @@ -// slang-shader-cache-index.h -#pragma once -#include "../../slang.h" -#include "../../slang-gfx.h" -#include "../../slang-com-ptr.h" - -#include "../../source/core/slang-string.h" -#include "../../source/core/slang-dictionary.h" -#include "../../source/core/slang-linked-list.h" -#include "../../source/core/slang-stream.h" -#include "../../source/core/slang-crypto.h" - -namespace gfx -{ - -using namespace Slang; - -using DigestType = MD5::Digest; - -struct ShaderCacheEntry -{ - DigestType dependencyBasedDigest; - DigestType contentsBasedDigest; - double lastAccessedTime; - - bool operator==(const ShaderCacheEntry& rhs) - { - return dependencyBasedDigest == rhs.dependencyBasedDigest - && contentsBasedDigest == rhs.contentsBasedDigest - && lastAccessedTime == rhs.lastAccessedTime; - } - - uint32_t getHashCode() - { - return dependencyBasedDigest.getHashCode(); - } -}; - -class PersistentShaderCache : public RefObject -{ -public: - PersistentShaderCache(const IDevice::ShaderCacheDesc& inDesc); - ~PersistentShaderCache(); - - // Fetch the cache entry corresponding to the provided key. If found, move the entry to - // the front of entries and return the entry and the corresponding compiled code in - // outCompiledCode. Else, return nullptr. - ShaderCacheEntry* findEntry(const DigestType& key, ISlangBlob** outCompiledCode); - - // Add an entry to the cache with the provided key and contents hashes. If - // adding an entry causes the cache to exceed size limitations, this will also - // delete the least recently used entry. - void addEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* compiledCode); - - // Update the contents hash for the specified entry in the cache and update the - // corresponding file on disk. - void updateEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* updatedCode); - -private: - // Load a previous cache index saved to disk. If not found, create a new cache index - // and save it to disk as filename. - void loadCacheFromFile(); - - // Delete the last entry (the least recently used) from entries, remove its key/value pair - // from keyToEntry, and remove the corresponding file on disk. Returns the index in 'entries' - // of the removed entry so addEntry() can overwrite the corresponding entry in 'entries' - // with the new entry. This should only be called by addEntry() when the cache reaches maximum capacity. - Index deleteLRUEntry(); - - // Without access to a physical file path, in-memory file systems cannot leverage file streams and - // need to fall back on a different implementation for loading and saving the cache to memory. - void loadCacheFromMemory(); - void saveCacheToMemory(); - - // The shader cache's description. - IDevice::ShaderCacheDesc desc; - - // The underlying file system used for the shader cache. - ComPtr mutableShaderCacheFileSystem = nullptr; - bool isMemoryFileSystem = false; - - // A file stream to the index file opened during cache load. This will only - // exist for a cache that exists on-disk. - FileStream indexStream; - - // Dictionary mapping each shader's key to its corresponding node (entry) in the - // linked list 'orderedEntries'. - Dictionary*> keyToEntry; - - // Linked list containing the corresponding indices in 'entries' for entries in the - // shader cache ordered from most to least recently used. - LinkedList orderedEntries; - - // List of entries in the shader cache. This list is not guaranteed to be in order of recency - // as the main and fall back implementations handle outputting to the file differently. - List entries; -}; - -} diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 3397b325e..4a8fd04b6 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -27,7 +27,7 @@ const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource; const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource; const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource; const Slang::Guid GfxGUID::IID_IDevice = SLANG_UUID_IDevice; -const Slang::Guid GfxGUID::IID_IShaderCacheStatistics = SLANG_UUID_IShaderCacheStatistics; +const Slang::Guid GfxGUID::IID_IShaderCache = SLANG_UUID_IShaderCache; const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject; const Slang::Guid GfxGUID::IID_IRenderPassLayout = SLANG_UUID_IRenderPassLayout; @@ -343,48 +343,19 @@ Result RendererBase::getEntryPointCodeFromShaderCache( return program->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); } - // Produce a string which we can use to query the shader cache by combining two separate hashes which - // together comprise all the compilation arguments for this program. - ComPtr session; - getSlangSession(session.writeRef()); - - ComPtr shaderKeyBlob; - program->computeDependencyBasedHash(entryPointIndex, targetIndex, shaderKeyBlob.writeRef()); - DigestType shaderKey(shaderKeyBlob); - - // Produce a hash using the AST for this program - This is needed to check whether a cache entry is effectively dirty, - // or to save along with the compiled code into an entry so the entry can be checked if fetched later on. - ComPtr contentsHashBlob; - program->computeContentsBasedHash(contentsHashBlob.writeRef()); - DigestType contentsHash(contentsHashBlob); + // Hash all relevant state for generating the entry point shader code to use as a key + // for the shader cache. + ComPtr hashBlob; + program->getEntryPointHash(entryPointIndex, targetIndex, hashBlob.writeRef()); + PersistentCache::Key cacheKey(hashBlob); + // Query the shader cache. ComPtr codeBlob; - - // Query the shader cache index for an entry with shaderKey as its key. - auto entry = persistentShaderCache->findEntry(shaderKey, codeBlob.writeRef()); - if (entry && contentsHash == entry->contentsBasedDigest) + if (persistentShaderCache->readEntry(cacheKey, codeBlob.writeRef()) != SLANG_OK) { - // We found the entry in the cache, and the entry's contents are up-to-date. Nothing else needs to be done. - shaderCacheHitCount++; - } - else - { - // There are two possibilities: the entry does not exist in the cache, or the entry's contents are out-of-date. - // Both will require calling getEntryPointCode() in order to fetch the correct compiled code, so we'll do that now. + // No cached entry found. Generate the code and add it to the cache. SLANG_RETURN_ON_FAIL(program->getEntryPointCode(entryPointIndex, targetIndex, codeBlob.writeRef(), outDiagnostics)); - - // If the entry was not found in the cache, let's add it. Otherwise, the entry's contents were out-of-date, so let's - // update the entry with the updated contents. - if (!entry) - { - persistentShaderCache->addEntry(shaderKey, contentsHash, codeBlob); - shaderCacheMissCount++; - } - else - { - persistentShaderCache->updateEntry(shaderKey, contentsHash, codeBlob); - shaderCacheEntryDirtyCount++; - } + persistentShaderCache->writeEntry(cacheKey, codeBlob); } *outCode = codeBlob.detach(); @@ -393,9 +364,10 @@ Result RendererBase::getEntryPointCodeFromShaderCache( SlangResult RendererBase::queryInterface(SlangUUID const& uuid, void** outObject) { - if (uuid == GfxGUID::IID_IShaderCacheStatistics) + // Only return the shader cache interface if it is enabled. + if (uuid == GfxGUID::IID_IShaderCache && persistentShaderCache) { - *outObject = static_cast(this); + *outObject = static_cast(this); addRef(); return SLANG_OK; } @@ -413,12 +385,13 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid) SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc) { - auto cacheDesc = desc.shaderCache; - // We only want to initialize the shader cache if either a shader cache path or file system - // was provided. - if (cacheDesc.shaderCachePath || cacheDesc.shaderCacheFileSystem) + // We only want to initialize the shader cache if a shader cache path was provided. + if (desc.shaderCache.shaderCachePath) { - persistentShaderCache = new PersistentShaderCache(desc.shaderCache); + PersistentCache::Desc cacheDesc; + cacheDesc.directory = desc.shaderCache.shaderCachePath; + cacheDesc.maxEntryCount = desc.shaderCache.maxEntryCount; + persistentShaderCache = new PersistentCache(cacheDesc); } if (desc.apiCommandDispatcher) @@ -751,26 +724,31 @@ Result RendererBase::getShaderObjectLayout( return SLANG_OK; } -GfxCount RendererBase::getCacheMissCount() +Result RendererBase::clearShaderCache() { - return shaderCacheMissCount; + SLANG_ASSERT(persistentShaderCache); + return persistentShaderCache->clear(); } -GfxCount RendererBase::getCacheHitCount() +Result RendererBase::getShaderCacheStats(ShaderCacheStats* outStats) { - return shaderCacheHitCount; -} + SLANG_ASSERT(persistentShaderCache); + if (!outStats) + { + return SLANG_E_INVALID_ARG; + } -GfxCount RendererBase::getCacheEntryDirtyCount() -{ - return shaderCacheEntryDirtyCount; + const auto& stats = persistentShaderCache->getStats(); + outStats->entryCount = (GfxCount)stats.entryCount; + outStats->hitCount = (GfxCount)stats.hitCount; + outStats->missCount = (GfxCount)stats.missCount; + return SLANG_OK; } -Result RendererBase::resetCacheStatistics() +Result RendererBase::resetShaderCacheStats() { - shaderCacheMissCount = 0; - shaderCacheHitCount = 0; - shaderCacheEntryDirtyCount = 0; + SLANG_ASSERT(persistentShaderCache); + persistentShaderCache->resetStats(); return SLANG_OK; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 01111e292..c7137f0fa 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -4,8 +4,7 @@ #include "slang-context.h" #include "core/slang-basic.h" #include "core/slang-com-object.h" - -#include "persistent-shader-cache.h" +#include "core/slang-persistent-cache.h" #include "resource-desc-utils.h" @@ -28,7 +27,7 @@ struct GfxGUID static const Slang::Guid IID_ITextureResource; static const Slang::Guid IID_IInputLayout; static const Slang::Guid IID_IDevice; - static const Slang::Guid IID_IShaderCacheStatistics; + static const Slang::Guid IID_IShaderCache; static const Slang::Guid IID_IShaderObjectLayout; static const Slang::Guid IID_IShaderObject; static const Slang::Guid IID_IRenderPassLayout; @@ -1214,7 +1213,7 @@ public: // Renderer implementation shared by all platforms. // Responsible for shader compilation, specialization and caching. -class RendererBase : public IDevice, public IShaderCacheStatistics, public Slang::ComObject +class RendererBase : public IDevice, public IShaderCache, public Slang::ComObject { friend class ShaderObjectBase; public: @@ -1354,27 +1353,21 @@ public: ShaderObjectLayoutBase* layout, IShaderObject** outObject) = 0; + public: + // IShaderCache interface + virtual SLANG_NO_THROW Result SLANG_MCALL clearShaderCache() SLANG_OVERRIDE; + virtual SLANG_NO_THROW Result SLANG_MCALL getShaderCacheStats(ShaderCacheStats* outStats) SLANG_OVERRIDE; + virtual SLANG_NO_THROW Result SLANG_MCALL resetShaderCacheStats() SLANG_OVERRIDE; + protected: virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc); protected: Slang::List m_features; - -public: - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheMissCount() override; - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheHitCount() override; - virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheEntryDirtyCount() override; - virtual SLANG_NO_THROW Result SLANG_MCALL resetCacheStatistics() override; - -protected: - GfxCount shaderCacheMissCount = 0; - GfxCount shaderCacheHitCount = 0; - GfxCount shaderCacheEntryDirtyCount = 0; - public: SlangContext slangContext; ShaderCache shaderCache; - RefPtr persistentShaderCache = nullptr; + Slang::RefPtr persistentShaderCache; Slang::Dictionary> m_shaderObjectLayoutCache; Slang::ComPtr m_pipelineCreationAPIDispatcher; diff --git a/tools/slang-unit-test/unit-test-lock-file.cpp b/tools/slang-unit-test/unit-test-lock-file.cpp index c5709242d..33e787a1d 100644 --- a/tools/slang-unit-test/unit-test-lock-file.cpp +++ b/tools/slang-unit-test/unit-test-lock-file.cpp @@ -12,13 +12,13 @@ using namespace Slang; SLANG_UNIT_TEST(lockFile) { - static const String fileName = "test_lock_file"; + static String fileName = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/test_lock_file"); // Open/close lock file. { LockFile file; SLANG_CHECK(file.isOpen() == false); - SLANG_CHECK(file.open(fileName) == SLANG_OK); + SLANG_CHECK_ABORT(file.open(fileName) == SLANG_OK); SLANG_CHECK(file.isOpen() == true); SLANG_CHECK(File::exists(fileName) == true); file.close(); diff --git a/tools/slang-unit-test/unit-test-persistent-cache.cpp b/tools/slang-unit-test/unit-test-persistent-cache.cpp new file mode 100644 index 000000000..55c358d77 --- /dev/null +++ b/tools/slang-unit-test/unit-test-persistent-cache.cpp @@ -0,0 +1,629 @@ +// unit-test-persistent-cache.cpp +#include "tools/unit-test/slang-unit-test.h" + +#include "../../source/core/slang-persistent-cache.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-file-system.h" +#include "../../source/core/slang-random-generator.h" + +#include +#include +#include +#include +#include +#include + +using namespace Slang; + +static DefaultRandomGenerator rng(0xdeadbeef); + +inline ComPtr createRandomBlob(size_t size) +{ + ScopedAllocation alloc; + alloc.allocate(size); + rng.nextData(alloc.getData(), size); + return RawBlob::moveCreate(alloc); +} + +inline bool isBlobEqual(ISlangBlob* a, ISlangBlob* b) +{ + return + a->getBufferSize() == b->getBufferSize() && + ::memcmp(a->getBufferPointer(), b->getBufferPointer(), a->getBufferSize()) == 0; +} + +class Barrier +{ +public: + Barrier(size_t threadCount, std::function completionFunc = nullptr) + : m_threadCount(threadCount) + , m_waitCount(threadCount) + , m_completionFunc(completionFunc) + {} + + Barrier(const Barrier& barrier) = delete; + Barrier& operator=(const Barrier& barrier) = delete; + + void wait() + { + std::unique_lock lock(m_mutex); + + auto generation = m_generation; + + if (--m_waitCount == 0) + { + if (m_completionFunc) m_completionFunc(); + ++m_generation; + m_waitCount = m_threadCount; + m_condition.notify_all(); + } + else + { + m_condition.wait(lock, [this, generation] () { return generation != m_generation; }); + } + } + +private: + size_t m_threadCount; + size_t m_waitCount; + size_t m_generation = 0; + std::function m_completionFunc; + std::mutex m_mutex; + std::condition_variable m_condition; +}; + +namespace Slang +{ + +/// Helper class for performing tests on the persistent cache. +/// This class is a friend class of PersistentCache and can access its internals. +struct PersistentCacheTest +{ + ISlangMutableFileSystem* osFileSystem; + String cacheDirectory; + RefPtr cache; + + PersistentCacheTest(Count maxEntryCount = 0) + { + osFileSystem = OSFileSystem::getMutableSingleton(); + cacheDirectory = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/persistent-cache-test"); + + removeCacheFiles(); + + PersistentCache::Desc desc; + desc.directory = cacheDirectory.getBuffer(); + desc.maxEntryCount = maxEntryCount; + cache = new PersistentCache(desc); + } + + virtual ~PersistentCacheTest() + { + cache = nullptr; + + removeCacheFiles(); + } + + void removeCacheFiles() + { + // Remove all files the cache created. + osFileSystem->enumeratePathContents( + cacheDirectory.getBuffer(), + [](SlangPathType pathType, const char* fileName, void* userData) + { + PersistentCacheTest* self = static_cast(userData); + String path = self->cacheDirectory + "/" + fileName; + self->osFileSystem->remove(path.getBuffer()); + }, + this); + + // Also remove the cache directory. + osFileSystem->remove(cacheDirectory.getBuffer()); + } + + // Entry (key, data) for testing. + struct Entry + { + PersistentCache::Key key; + ComPtr data; + }; + + // Helper to write an entry to the cache. + void writeEntry(const Entry& entry) + { + SLANG_CHECK(cache->writeEntry(entry.key, entry.data) == SLANG_OK); + } + + // Helper to read an entry from the cache and discard the data. + // Returns true if the entry was found, false otherwise. + bool readEntry(const Entry& entry) + { + ComPtr data; + SlangResult result = cache->readEntry(entry.key, data.writeRef()); + SLANG_CHECK(result == SLANG_OK || result == SLANG_E_NOT_FOUND); + if (result == SLANG_OK) + { + SLANG_CHECK(isBlobEqual(data, entry.data)); + } + if (result == SLANG_E_NOT_FOUND) + { + SLANG_CHECK(data == nullptr); + } + return result == SLANG_OK; + } + + // Get the absolute filename for a cache entry file. + String getEntryFileName(const Entry& entry) + { + return cache->getEntryFileName(entry.key); + } + + // Get the absolute filename of the cache index file. + String getIndexFilename() + { + return cache->m_indexFileName; + } +}; + +} // namespace Slang + +// Performs basic tests on the cache. +// - write/read entries +// - check for correct cache stats +// - clearing the cache +// - resetting stats +struct BasicTest : public PersistentCacheTest +{ + BasicTest() : PersistentCacheTest() {} + + void run() + { + // Check that cache is empty. + SLANG_CHECK(cache->getStats().entryCount == 0); + SLANG_CHECK(cache->getStats().hitCount == 0); + SLANG_CHECK(cache->getStats().missCount == 0); + + // Setup a list of entries to store in the cache. + List entries; + for (size_t i = 0; i < 10; ++i) + { + auto data = createRandomBlob(i * 1024); + auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize()); + entries.add(Entry{ key, data }); + } + + for (size_t i = 0; i < 10; ++i) + { + const auto& entry = entries[i]; + ComPtr data; + + // Try to read an entry. Check that its not found and counts as a miss. + SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_E_NOT_FOUND); + SLANG_CHECK(cache->getStats().missCount == i + 1); + + // Write the entry. Check that it gets added. + SLANG_CHECK(cache->writeEntry(entry.key, entry.data) == SLANG_OK); + SLANG_CHECK(cache->getStats().entryCount == i + 1); + } + + SLANG_CHECK(cache->getStats().entryCount == 10); + SLANG_CHECK(cache->getStats().hitCount == 0); + SLANG_CHECK(cache->getStats().missCount == 10); + + for (size_t i = 0; i < 10; ++i) + { + const auto& entry = entries[i]; + ComPtr data; + + // Read entries. Check that these are cache hits and return the correct data. + SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_OK); + SLANG_CHECK(cache->getStats().hitCount == i + 1); + SLANG_CHECK(isBlobEqual(data, entry.data)); + } + + SLANG_CHECK(cache->getStats().entryCount == 10); + SLANG_CHECK(cache->getStats().hitCount == 10); + SLANG_CHECK(cache->getStats().missCount == 10); + + // Clear the cache. Check that entry count is reset. + SLANG_CHECK(cache->clear() == SLANG_OK); + SLANG_CHECK(cache->getStats().entryCount == 0); + SLANG_CHECK(cache->getStats().hitCount == 10); + SLANG_CHECK(cache->getStats().missCount == 10); + + // Reset stats. + cache->resetStats(); + SLANG_CHECK(cache->getStats().entryCount == 0); + SLANG_CHECK(cache->getStats().hitCount == 0); + SLANG_CHECK(cache->getStats().missCount == 0); + + // Check that cache is empty. + for (size_t i = 0; i < 10; ++i) + { + const auto& entry = entries[i]; + ComPtr data; + SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_E_NOT_FOUND); + } + SLANG_CHECK(cache->getStats().missCount == 10); + } +}; + +// Tests the least-recently-used cache eviction policy. +struct EvictionTest : public PersistentCacheTest +{ + EvictionTest() : PersistentCacheTest(3) {} + + void run() + { + // Setup a list of entries to store in the cache. + List entries; + for (size_t i = 0; i < 10; ++i) + { + auto data = createRandomBlob(4096); + auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize()); + entries.add(Entry{ key, data }); + } + + writeEntry(entries[0]); + writeEntry(entries[1]); + writeEntry(entries[2]); + + SLANG_CHECK(readEntry(entries[0]) == true); + SLANG_CHECK(readEntry(entries[1]) == true); + SLANG_CHECK(readEntry(entries[2]) == true); + + // Evict LRU entry 0. + writeEntry(entries[3]); + SLANG_CHECK(readEntry(entries[0]) == false); + SLANG_CHECK(readEntry(entries[1]) == true); + SLANG_CHECK(readEntry(entries[2]) == true); + SLANG_CHECK(readEntry(entries[3]) == true); + + // Evict LRU entry 1. + writeEntry(entries[4]); + SLANG_CHECK(readEntry(entries[1]) == false); + SLANG_CHECK(readEntry(entries[2]) == true); + SLANG_CHECK(readEntry(entries[3]) == true); + SLANG_CHECK(readEntry(entries[4]) == true); + + // Evict LRU entry 2. + writeEntry(entries[5]); + SLANG_CHECK(readEntry(entries[2]) == false); + SLANG_CHECK(readEntry(entries[3]) == true); + SLANG_CHECK(readEntry(entries[4]) == true); + SLANG_CHECK(readEntry(entries[5]) == true); + + // Evict LRU entry 4. + SLANG_CHECK(readEntry(entries[3]) == true); + writeEntry(entries[6]); + SLANG_CHECK(readEntry(entries[3]) == true); + SLANG_CHECK(readEntry(entries[4]) == false); + SLANG_CHECK(readEntry(entries[5]) == true); + SLANG_CHECK(readEntry(entries[6]) == true); + } +}; + + +// Tests the cache to be robust against various corruptions. +// These can happen if the cache files are manipulated externally. +// The cache might also be corrupted if the application is terminated while writing. +struct CorruptionTest : public PersistentCacheTest +{ + List entries; + + template + void testIndexCorruption(Func func, SlangResult expectedReadResult) + { + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + func(); + // We expect a SLANG_E_NOT_FOUND because the cache has an empty index now. + ComPtr data; + SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == expectedReadResult); + + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + func(); + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + } + + void run() + { + // Setup a list of entries to store in the cache. + for (size_t i = 0; i < 10; ++i) + { + auto data = createRandomBlob(4096); + auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize()); + entries.add(Entry{ key, data }); + } + + // Test behavior when a cached entry file is removed externally before reading. + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + osFileSystem->remove(getEntryFileName(entries[0]).getBuffer()); + ComPtr data; + // First time we read the entry, we expect a SLANG_E_CANNOT_OPEN because the file is gone. + SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_CANNOT_OPEN); + // The next time we read the entry, we expect a SLANG_E_NOT_FOUND because the entry has + // been removed from the cache index. + SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_NOT_FOUND); + + // Test behavior when a cached entry file is removed externally before writing. + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + osFileSystem->remove(getEntryFileName(entries[0]).getBuffer()); + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + + // Test behavior when the index file is removed before reading. + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + osFileSystem->remove(getIndexFilename().getBuffer()); + // We expect a SLANG_E_NOT_FOUND because the cache has an empty index now. + SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_NOT_FOUND); + + // Test behavior when the index file is removed before writing. + writeEntry(entries[0]); + SLANG_CHECK(readEntry(entries[0]) == true); + osFileSystem->remove(getIndexFilename().getBuffer()); + writeEntry(entries[1]); + SLANG_CHECK(readEntry(entries[1]) == true); + + // Test different corruptions of the index file. + testIndexCorruption( + [this]() + { + osFileSystem->remove(getIndexFilename().getBuffer()); + }, + SLANG_E_NOT_FOUND); + + testIndexCorruption( + [this]() + { + FileStream fs; + fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite); + fs.write("x", 1); + }, + SLANG_E_INTERNAL_FAIL); + + testIndexCorruption( + [this]() + { + FileStream fs; + fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite); + fs.seek(SeekOrigin::Start, 4); + uint32_t version = 0xffffffff; + fs.write(&version, sizeof(version)); + }, + SLANG_E_INTERNAL_FAIL); + + testIndexCorruption( + [this]() + { + FileStream fs; + fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite); + fs.seek(SeekOrigin::Start, 8); + uint32_t count = 0x7fffffff; + fs.write(&count, sizeof(count)); + }, + SLANG_E_INTERNAL_FAIL); + + testIndexCorruption( + [this]() + { + FileStream fs; + fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite); + fs.seek(SeekOrigin::Start, 8); + uint32_t count = 0; + fs.write(&count, sizeof(count)); + }, + SLANG_E_INTERNAL_FAIL); + + testIndexCorruption( + [this]() + { + FileStream fs; + fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite); + fs.seek(SeekOrigin::End, 0); + fs.write("x", 1); + }, + SLANG_E_INTERNAL_FAIL); + } +}; + +struct MultiThreadingTest : public PersistentCacheTest +{ + void run() + { + } +}; + + +#undef ENABLE_LOGGING +#undef ENABLE_WRITE_TEST + +#ifdef ENABLE_LOGGING +#define LOG(fmt, ...) printf(fmt, ##__VA_ARGS__); fflush(stdout); +#else +#define LOG(fmt, ...) +#endif + +// Stress testing. +// This test spawns a number of threads to do concurrent access to the cache. +// For now this is fairly simple: +// - spawn a number of threads +// - write random entries to the cache concurrenctly (slightly oversubscribe) +// - synchronize +// - read entries from the cache concurretly (test that we get the expected number of hits/misses) +// - synchronize +// - repeat for a number of iterations +struct StressTest : public PersistentCacheTest +{ + // Number of entries to write/read per iteration. + static const uint32_t kEntryCount = 100; + // Number of entries the cache is short for storing one iteration. + static const uint32_t kEntryShortageCount = 10; + // Number of parallel threads to write/read. + static const uint32_t kThreadCount = 4; + // Number of entries to write/read per thread per iteration. + static const uint32_t kBatchCount = kEntryCount / kThreadCount; + // Total number of iterations. + static const uint32_t kIterationCount = 4; + + static_assert(kEntryCount % kThreadCount == 0, "kEntryCount must be divisible by kThreadCount"); + + List entries; + + std::atomic iteration{0}; + std::atomic entriesWritten{0}; + std::atomic bytesWritten{0}; + std::atomic entriesRead{0}; + std::atomic bytesRead{0}; + std::atomic readSuccess{0}; + std::thread threads[kThreadCount]; + + Barrier *read_barrier; + Barrier *write_barrier; + + std::mutex mutex; + std::condition_variable conditionVariable; + uint32_t generation{0}; + + StressTest() : PersistentCacheTest(kEntryCount - kEntryShortageCount) {} + + void run() + { + // Setup a list of entries to store in the cache. + for (size_t i = 0; i < kEntryCount * 2; ++i) + { + size_t size = rng.nextInt32InRange(256, 64 * 1024); + auto data = createRandomBlob(size); + auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize()); + entries.add(Entry{ key, data }); + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + Barrier read_barrier_( + kThreadCount, + []() + { + LOG("Read synchronized\n"); + }); + Barrier write_barrier_( + kThreadCount, + [this](){ + LOG("Write synchronized\n"); +#ifndef ENABLE_WRITE_TEST + SLANG_CHECK(readSuccess == kEntryCount - kEntryShortageCount); + readSuccess.store(0); +#endif + iteration += 1; + }); + + read_barrier = &read_barrier_; + write_barrier = &write_barrier_; + + for (uint32_t threadIndex = 0; threadIndex < kThreadCount; ++threadIndex) + { + threads[threadIndex] = std::thread( + [](StressTest* self, uint32_t threadIndex) + { + LOG("Thread %u: starting\n", threadIndex); + + while (true) + { + // Write to cache. + size_t startIndex = (self->iteration * kEntryCount + (threadIndex * kBatchCount)) % (kEntryCount * 2); + for (size_t i = 0; i < kBatchCount; ++i) + { + const Entry& entry = self->entries[startIndex + i]; +#ifdef ENABLE_WRITE_TEST + self->osFileSystem->saveFileBlob(self->getEntryFileName(entry).getBuffer(), entry.data); +#else + self->writeEntry(entry); +#endif + self->entriesWritten.fetch_add(1); + self->bytesWritten.fetch_add((uint32_t)entry.data->getBufferSize()); + } + + LOG("Thread %u: ended writing (iteration=%u)\n", threadIndex, self->iteration.load()); + + // Synchronize. + self->read_barrier->wait(); + + // Read from cache. + for (size_t i = 0; i < kBatchCount; ++i) + { + const Entry& entry = self->entries[startIndex + i]; +#ifndef ENABLE_WRITE_TEST + if (self->readEntry(entry)) + { + self->readSuccess.fetch_add(1); + self->bytesRead.fetch_add((uint32_t)entry.data->getBufferSize()); + } +#endif + self->entriesRead.fetch_add(1); + } + + LOG("Thread %u: ended reading (iteration=%u)\n", threadIndex, self->iteration.load()); + + // Synchronize. + self->write_barrier->wait(); + + // Terminate. + if (self->iteration >= kIterationCount) + { + LOG("Thread %u: terminates\n", threadIndex); + return; + } + } + }, + this, threadIndex); + } + + for (auto& thread : threads) + { + thread.join(); + } + + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = endTime - startTime; + auto seconds = std::chrono::duration_cast(duration).count() / 1000.0; + + LOG("Total time: %.3fs\n", seconds); + LOG("Total bytes written: %d\n", bytesWritten.load()); + LOG("Write througput: %.3fMB/s\n", (bytesWritten.load() / (1024.0 * 1024.0)) / seconds); + LOG("Total bytes read: %d\n", bytesRead.load()); + } +}; + +SLANG_UNIT_TEST(persistentCacheBasic) +{ + BasicTest test; + test.run(); +} + +SLANG_UNIT_TEST(persistentCacheEviction) +{ + EvictionTest test; + test.run(); +} + +SLANG_UNIT_TEST(persistentCacheCorruption) +{ + CorruptionTest test; + test.run(); +} + +SLANG_UNIT_TEST(persistentCacheMultiThreading) +{ + MultiThreadingTest test; + test.run(); +} + +SLANG_UNIT_TEST(persistentCacheStress) +{ + StressTest test; + test.run(); +} -- cgit v1.2.3