summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorskallweitNV <64953474+skallweitNV@users.noreply.github.com>2022-12-12 19:25:48 +0100
committerGitHub <noreply@github.com>2022-12-12 10:25:48 -0800
commitc2dc1a86ed2f5e160749fe9f99b70db6c3e4d7a6 (patch)
treeea65b9635d892917a2420688a27c38537c4758be
parent8d359fc6133fa49d2d3b7f8bb4b37916e719c344 (diff)
Refactor shader cache (#2558)
* Fix a bug in Path::find * Fix code formatting * Fix LockFile and add LockFileGuard * Add PersistentCache and unit test * Replace file path dependency list with source file dependency list * Add note on ordering in Module/FileDependencyList * Remove old shader cache code * Refactor shader cache implementation * Temporarily skip unit tests reading/writing files * Fix warning * Reenable lock file test * Rename shader cache tests and disable crashing test * Testing * Stop using Path::getCanonical * Fix persistent cache lock and test * Fix threading issues * Move adding file dependency hashes to getEntryPointHash() * Fix handling of #include files * Allow specifying additional search paths for gfx testing device * Work on shader cache tests * Update project files * Revive shader cache graphics tests * Split graphics pipeline test * Fix compilation
-rw-r--r--build/visual-studio/core/core.vcxproj2
-rw-r--r--build/visual-studio/core/core.vcxproj.filters6
-rw-r--r--build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj7
-rw-r--r--build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters13
-rw-r--r--build/visual-studio/gfx/gfx.vcxproj2
-rw-r--r--build/visual-studio/gfx/gfx.vcxproj.filters6
-rw-r--r--build/visual-studio/slang-rt/slang-rt.vcxproj2
-rw-r--r--build/visual-studio/slang-rt/slang-rt.vcxproj.filters6
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj1
-rw-r--r--build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters3
-rw-r--r--slang-gfx.h52
-rw-r--r--slang.h36
-rw-r--r--source/core/slang-crypto.cpp14
-rw-r--r--source/core/slang-io.cpp2
-rw-r--r--source/core/slang-io.h27
-rw-r--r--source/core/slang-persistent-cache.cpp289
-rw-r--r--source/core/slang-persistent-cache.h91
-rw-r--r--source/slang/slang-compiler.cpp29
-rwxr-xr-xsource/slang/slang-compiler.h158
-rw-r--r--source/slang/slang-preprocessor.cpp23
-rw-r--r--source/slang/slang-preprocessor.h2
-rw-r--r--source/slang/slang.cpp229
-rw-r--r--tools/gfx-unit-test/gfx-test-util.cpp12
-rw-r--r--tools/gfx-unit-test/gfx-test-util.h4
-rw-r--r--tools/gfx-unit-test/shader-cache-graphics-fragment.slang (renamed from tools/gfx-unit-test/split-graphics-fragment.slang)2
-rw-r--r--tools/gfx-unit-test/shader-cache-graphics-vertex.slang (renamed from tools/gfx-unit-test/split-graphics-vertex.slang)2
-rw-r--r--tools/gfx-unit-test/shader-cache-multiple-entry-points.slang (renamed from tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang)13
-rw-r--r--tools/gfx-unit-test/shader-cache-specialization.slang68
-rw-r--r--tools/gfx-unit-test/shader-cache-tests.cpp1449
-rw-r--r--tools/gfx/gfx.slang24
-rw-r--r--tools/gfx/persistent-shader-cache.cpp316
-rw-r--r--tools/gfx/persistent-shader-cache.h99
-rw-r--r--tools/gfx/renderer-shared.cpp94
-rw-r--r--tools/gfx/renderer-shared.h27
-rw-r--r--tools/slang-unit-test/unit-test-lock-file.cpp4
-rw-r--r--tools/slang-unit-test/unit-test-persistent-cache.cpp629
36 files changed, 1967 insertions, 1776 deletions
diff --git a/build/visual-studio/core/core.vcxproj b/build/visual-studio/core/core.vcxproj
index e171a7d68..ecd6b54fe 100644
--- a/build/visual-studio/core/core.vcxproj
+++ b/build/visual-studio/core/core.vcxproj
@@ -297,6 +297,7 @@
<ClInclude Include="..\..\..\source\core\slang-memory-arena.h" />
<ClInclude Include="..\..\..\source\core\slang-memory-file-system.h" />
<ClInclude Include="..\..\..\source\core\slang-offset-container.h" />
+ <ClInclude Include="..\..\..\source\core\slang-persistent-cache.h" />
<ClInclude Include="..\..\..\source\core\slang-platform.h" />
<ClInclude Include="..\..\..\source\core\slang-process-util.h" />
<ClInclude Include="..\..\..\source\core\slang-process.h" />
@@ -353,6 +354,7 @@
<ClCompile Include="..\..\..\source\core\slang-memory-arena.cpp" />
<ClCompile Include="..\..\..\source\core\slang-memory-file-system.cpp" />
<ClCompile Include="..\..\..\source\core\slang-offset-container.cpp" />
+ <ClCompile Include="..\..\..\source\core\slang-persistent-cache.cpp" />
<ClCompile Include="..\..\..\source\core\slang-platform.cpp" />
<ClCompile Include="..\..\..\source\core\slang-process-util.cpp" />
<ClCompile Include="..\..\..\source\core\slang-random-generator.cpp" />
diff --git a/build/visual-studio/core/core.vcxproj.filters b/build/visual-studio/core/core.vcxproj.filters
index 06d151349..144c5259f 100644
--- a/build/visual-studio/core/core.vcxproj.filters
+++ b/build/visual-studio/core/core.vcxproj.filters
@@ -123,6 +123,9 @@
<ClInclude Include="..\..\..\source\core\slang-offset-container.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\core\slang-persistent-cache.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\core\slang-platform.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -287,6 +290,9 @@
<ClCompile Include="..\..\..\source\core\slang-offset-container.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\core\slang-persistent-cache.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\core\slang-platform.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
index b0879ccfe..869b0fd27 100644
--- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
+++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj
@@ -317,16 +317,17 @@
<None Include="..\..\..\tools\gfx-unit-test\compute-trivial.slang" />
<None Include="..\..\..\tools\gfx-unit-test\format-test-shaders.slang" />
<None Include="..\..\..\tools\gfx-unit-test\graphics-smoke.slang" />
- <None Include="..\..\..\tools\gfx-unit-test\multiple-entry-point-shader-cache-shader.slang" />
<None Include="..\..\..\tools\gfx-unit-test\mutable-shader-object.slang" />
<None Include="..\..\..\tools\gfx-unit-test\nested-parameter-block.slang" />
<None Include="..\..\..\tools\gfx-unit-test\ray-tracing-test-shaders.slang" />
<None Include="..\..\..\tools\gfx-unit-test\resolve-resource-shader.slang" />
<None Include="..\..\..\tools\gfx-unit-test\root-shader-parameter.slang" />
<None Include="..\..\..\tools\gfx-unit-test\sampler-array.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics-fragment.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics-vertex.slang" />
<None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics.slang" />
- <None Include="..\..\..\tools\gfx-unit-test\split-graphics-fragment.slang" />
- <None Include="..\..\..\tools\gfx-unit-test\split-graphics-vertex.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-multiple-entry-points.slang" />
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-specialization.slang" />
<None Include="..\..\..\tools\gfx-unit-test\swapchain-shader.slang" />
<None Include="..\..\..\tools\gfx-unit-test\trivial-copy-textures.slang" />
<None Include="..\..\..\tools\gfx-unit-test\trivial-copy.slang" />
diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
index 00fd1de16..58045ccee 100644
--- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
+++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters
@@ -121,9 +121,6 @@
<None Include="..\..\..\tools\gfx-unit-test\graphics-smoke.slang">
<Filter>Source Files</Filter>
</None>
- <None Include="..\..\..\tools\gfx-unit-test\multiple-entry-point-shader-cache-shader.slang">
- <Filter>Source Files</Filter>
- </None>
<None Include="..\..\..\tools\gfx-unit-test\mutable-shader-object.slang">
<Filter>Source Files</Filter>
</None>
@@ -142,13 +139,19 @@
<None Include="..\..\..\tools\gfx-unit-test\sampler-array.slang">
<Filter>Source Files</Filter>
</None>
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics-fragment.slang">
+ <Filter>Source Files</Filter>
+ </None>
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics-vertex.slang">
+ <Filter>Source Files</Filter>
+ </None>
<None Include="..\..\..\tools\gfx-unit-test\shader-cache-graphics.slang">
<Filter>Source Files</Filter>
</None>
- <None Include="..\..\..\tools\gfx-unit-test\split-graphics-fragment.slang">
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-multiple-entry-points.slang">
<Filter>Source Files</Filter>
</None>
- <None Include="..\..\..\tools\gfx-unit-test\split-graphics-vertex.slang">
+ <None Include="..\..\..\tools\gfx-unit-test\shader-cache-specialization.slang">
<Filter>Source Files</Filter>
</None>
<None Include="..\..\..\tools\gfx-unit-test\swapchain-shader.slang">
diff --git a/build/visual-studio/gfx/gfx.vcxproj b/build/visual-studio/gfx/gfx.vcxproj
index 8f200be8e..476f6a808 100644
--- a/build/visual-studio/gfx/gfx.vcxproj
+++ b/build/visual-studio/gfx/gfx.vcxproj
@@ -419,7 +419,6 @@ IF EXIST "$(SolutionDir)tools\gfx\slang.slang"\ (xcopy /Q /E /Y /I "$(SolutionDi
<ClInclude Include="..\..\..\tools\gfx\nvapi\nvapi-include.h" />
<ClInclude Include="..\..\..\tools\gfx\nvapi\nvapi-util.h" />
<ClInclude Include="..\..\..\tools\gfx\open-gl\render-gl.h" />
- <ClInclude Include="..\..\..\tools\gfx\persistent-shader-cache.h" />
<ClInclude Include="..\..\..\tools\gfx\renderer-shared.h" />
<ClInclude Include="..\..\..\tools\gfx\resource-desc-utils.h" />
<ClInclude Include="..\..\..\tools\gfx\simple-render-pass-layout.h" />
@@ -530,7 +529,6 @@ IF EXIST "$(SolutionDir)tools\gfx\slang.slang"\ (xcopy /Q /E /Y /I "$(SolutionDi
<ClCompile Include="..\..\..\tools\gfx\immediate-renderer-base.cpp" />
<ClCompile Include="..\..\..\tools\gfx\nvapi\nvapi-util.cpp" />
<ClCompile Include="..\..\..\tools\gfx\open-gl\render-gl.cpp" />
- <ClCompile Include="..\..\..\tools\gfx\persistent-shader-cache.cpp" />
<ClCompile Include="..\..\..\tools\gfx\render.cpp" />
<ClCompile Include="..\..\..\tools\gfx\renderer-shared.cpp" />
<ClCompile Include="..\..\..\tools\gfx\resource-desc-utils.cpp" />
diff --git a/build/visual-studio/gfx/gfx.vcxproj.filters b/build/visual-studio/gfx/gfx.vcxproj.filters
index c708450d5..1a7ed3d03 100644
--- a/build/visual-studio/gfx/gfx.vcxproj.filters
+++ b/build/visual-studio/gfx/gfx.vcxproj.filters
@@ -306,9 +306,6 @@
<ClInclude Include="..\..\..\tools\gfx\open-gl\render-gl.h">
<Filter>Header Files</Filter>
</ClInclude>
- <ClInclude Include="..\..\..\tools\gfx\persistent-shader-cache.h">
- <Filter>Header Files</Filter>
- </ClInclude>
<ClInclude Include="..\..\..\tools\gfx\renderer-shared.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -635,9 +632,6 @@
<ClCompile Include="..\..\..\tools\gfx\open-gl\render-gl.cpp">
<Filter>Source Files</Filter>
</ClCompile>
- <ClCompile Include="..\..\..\tools\gfx\persistent-shader-cache.cpp">
- <Filter>Source Files</Filter>
- </ClCompile>
<ClCompile Include="..\..\..\tools\gfx\render.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/build/visual-studio/slang-rt/slang-rt.vcxproj b/build/visual-studio/slang-rt/slang-rt.vcxproj
index 03d3852fd..92191667c 100644
--- a/build/visual-studio/slang-rt/slang-rt.vcxproj
+++ b/build/visual-studio/slang-rt/slang-rt.vcxproj
@@ -309,6 +309,7 @@
<ClInclude Include="..\..\..\source\core\slang-memory-arena.h" />
<ClInclude Include="..\..\..\source\core\slang-memory-file-system.h" />
<ClInclude Include="..\..\..\source\core\slang-offset-container.h" />
+ <ClInclude Include="..\..\..\source\core\slang-persistent-cache.h" />
<ClInclude Include="..\..\..\source\core\slang-platform.h" />
<ClInclude Include="..\..\..\source\core\slang-process-util.h" />
<ClInclude Include="..\..\..\source\core\slang-process.h" />
@@ -366,6 +367,7 @@
<ClCompile Include="..\..\..\source\core\slang-memory-arena.cpp" />
<ClCompile Include="..\..\..\source\core\slang-memory-file-system.cpp" />
<ClCompile Include="..\..\..\source\core\slang-offset-container.cpp" />
+ <ClCompile Include="..\..\..\source\core\slang-persistent-cache.cpp" />
<ClCompile Include="..\..\..\source\core\slang-platform.cpp" />
<ClCompile Include="..\..\..\source\core\slang-process-util.cpp" />
<ClCompile Include="..\..\..\source\core\slang-random-generator.cpp" />
diff --git a/build/visual-studio/slang-rt/slang-rt.vcxproj.filters b/build/visual-studio/slang-rt/slang-rt.vcxproj.filters
index df9f8c2ca..b99c077ae 100644
--- a/build/visual-studio/slang-rt/slang-rt.vcxproj.filters
+++ b/build/visual-studio/slang-rt/slang-rt.vcxproj.filters
@@ -123,6 +123,9 @@
<ClInclude Include="..\..\..\source\core\slang-offset-container.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\core\slang-persistent-cache.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\core\slang-platform.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -290,6 +293,9 @@
<ClCompile Include="..\..\..\source\core\slang-offset-container.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\core\slang-persistent-cache.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\core\slang-platform.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
index deab210ee..5eec9ec82 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj
@@ -296,6 +296,7 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-memory-arena.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-offset-container.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-path.cpp" />
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-persistent-cache.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-process.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-riff.cpp" />
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-rtti.cpp" />
diff --git a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
index a33dc44cc..c350b6f24 100644
--- a/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
+++ b/build/visual-studio/slang-unit-test-tool/slang-unit-test-tool.vcxproj.filters
@@ -62,6 +62,9 @@
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-path.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-persistent-cache.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\tools\slang-unit-test\unit-test-process.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/slang-gfx.h b/slang-gfx.h
index 590d7d5c7..07fd0d58a 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -2052,25 +2052,6 @@ public:
0xbe91ba6c, 0x784, 0x4308, { 0xa1, 0x0, 0x19, 0xc3, 0x66, 0x83, 0x44, 0xb2 } \
}
-// These are exclusively used to track hit/miss counts for shader cache entries. Entry hit and
-// miss counts specifically indicate if the file containing relevant shader code was found in
-// the cache, while the general hit and miss counts indicate whether the file was both found and
-// up-to-date.
-class IShaderCacheStatistics : public ISlangUnknown
-{
-public:
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheMissCount() = 0;
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheHitCount() = 0;
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheEntryDirtyCount() = 0;
-
- virtual SLANG_NO_THROW Result SLANG_MCALL resetCacheStatistics() = 0;
-};
-
-#define SLANG_UUID_IShaderCacheStatistics \
- { \
- 0x8eccc8ec, 0x5c04, 0x4a51, { 0x99, 0x75, 0x13, 0xf8, 0xfe, 0xa1, 0x59, 0xf3 } \
- }
-
struct AdapterLUID
{
uint8_t luid[16];
@@ -2225,15 +2206,10 @@ public:
struct ShaderCacheDesc
{
- // The filename for the file the cache's state should be saved to or loaded from.
- const char* cacheFilename = "cache.txt";
- // The root directory for the shader cache.
+ // The root directory for the shader cache. If not set, shader cache is disabled.
const char* shaderCachePath = nullptr;
- // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object,
- // instead the user is responsible for holding the object alive during the lifetime of an `IDevice`.
- ISlangFileSystem* shaderCacheFileSystem = nullptr;
// The maximum number of entries stored in the cache. By default, there is no limit.
- GfxCount entryCountLimit = 0;
+ GfxCount maxEntryCount = 0;
};
struct InteropHandles
@@ -2597,6 +2573,30 @@ public:
0x715bdf26, 0x5135, 0x11eb, { 0xAE, 0x93, 0x02, 0x42, 0xAC, 0x13, 0x00, 0x02 } \
}
+struct ShaderCacheStats
+{
+ GfxCount hitCount;
+ GfxCount missCount;
+ GfxCount entryCount;
+};
+
+// These are exclusively used to track hit/miss counts for shader cache entries. Entry hit and
+// miss counts specifically indicate if the file containing relevant shader code was found in
+// the cache, while the general hit and miss counts indicate whether the file was both found and
+// up-to-date.
+class IShaderCache : public ISlangUnknown
+{
+public:
+ virtual SLANG_NO_THROW Result SLANG_MCALL clearShaderCache() = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL getShaderCacheStats(ShaderCacheStats* outStats) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL resetShaderCacheStats() = 0;
+};
+
+#define SLANG_UUID_IShaderCache \
+ { \
+ 0x8eccc8ec, 0x5c04, 0x4a51, { 0x99, 0x75, 0x13, 0xf8, 0xfe, 0xa1, 0x59, 0xf3 } \
+ }
+
class IPipelineCreationAPIDispatcher : public ISlangUnknown
{
public:
diff --git a/slang.h b/slang.h
index 784a3e763..4aa30b049 100644
--- a/slang.h
+++ b/slang.h
@@ -4425,34 +4425,16 @@ namespace slang
IBlob** outCode,
IBlob** outDiagnostics = nullptr) = 0;
- /** Compute the hash code of all dependencies for this component type. This generally means file path
- dependencies but can also include the component's name or sub-components. The dependency-based
- hash effectively represents all the files that may be included/imported by a component type along with
- any non-code-specific that helps define a component. This can be useful to simply check for a component
- type without needing to inspect the code. For example, a shader cache might key its entries using the
- dependency-based hash in order to determine at a glance if a particular shader is present, with no
- regard for the shader's contents.
-
- This function should only have a meaningful implementation in ComponentType. All other types derived from
- ComponentType that also inherit from IComponentType should do nothing.
- */
- virtual SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash(
- SlangInt entryPointIndex,
- SlangInt targetIndex,
- IBlob** outHash) = 0;
-
- /** Compute the hash code of this component type's contents as indicated by the file dependencies.
- This hash is ideal when we need to confirm whether shader code changes have occurred. For example,
- a shader cache needs to be able to check when a cache entry contains out-of-date code, which can be
- easily detected by comparing the contents-based hashes since they will directly reflect any change
- to the shader's code.
-
- This function should only have a meaningful implementation in ComponentType. All other types derived
- from ComponentType that also inherit from IComponentType should do nothing. However, the only component
- type that should ever be hashing its contents is Module as it represents all the code in a given
- translation unit.
+ /** Compute a hash for the entry point at `entryPointIndex` for the chosen `targetIndex`.
+
+ This computes a hash based on all the dependencies for this component type as well as the
+ target settings affecting the compiler backend. The computed hash is used as a key for caching
+ the output of the compiler backend to implement shader caching.
*/
- virtual SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(IBlob** outHash) = 0;
+ virtual SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
+ SlangInt entryPointIndex,
+ SlangInt targetIndex,
+ IBlob** outHash) = 0;
/** Specialize the component by binding its specialization parameters to concrete arguments.
diff --git a/source/core/slang-crypto.cpp b/source/core/slang-crypto.cpp
index ece7b01e9..dfe246c7c 100644
--- a/source/core/slang-crypto.cpp
+++ b/source/core/slang-crypto.cpp
@@ -82,15 +82,19 @@ void MD5::update(const void* data, SlangInt size)
saved_lo = m_lo;
if ((m_lo = (saved_lo + size) & 0x1fffffff) < saved_lo)
+ {
m_hi++;
+ }
m_hi += (uint32_t)size >> 29;
used = saved_lo & 0x3f;
- if (used) {
+ if (used)
+ {
available = 64 - used;
- if (size < available) {
+ if (size < available)
+ {
::memcpy(&m_buffer[used], data, size);
return;
}
@@ -101,7 +105,8 @@ void MD5::update(const void* data, SlangInt size)
processBlock(m_buffer, 64);
}
- if (size >= 64) {
+ if (size >= 64)
+ {
data = processBlock(data, size & ~(SlangInt)0x3f);
size &= 0x3f;
}
@@ -119,7 +124,8 @@ MD5::Digest MD5::finalize()
available = 64 - used;
- if (available < 8) {
+ if (available < 8)
+ {
::memset(&m_buffer[used], 0, available);
processBlock(m_buffer, 64);
used = 0;
diff --git a/source/core/slang-io.cpp b/source/core/slang-io.cpp
index d8ef48ee3..f2be44f4f 100644
--- a/source/core/slang-io.cpp
+++ b/source/core/slang-io.cpp
@@ -682,7 +682,7 @@ namespace Slang
WIN32_FIND_DATAW fileData;
HANDLE findHandle = FindFirstFileW(searchPath.toWString(), &fileData);
- if (!findHandle)
+ if (findHandle == INVALID_HANDLE_VALUE)
{
return SLANG_E_NOT_FOUND;
}
diff --git a/source/core/slang-io.h b/source/core/slang-io.h
index fc5cbfa9d..e766359ff 100644
--- a/source/core/slang-io.h
+++ b/source/core/slang-io.h
@@ -266,9 +266,9 @@ namespace Slang
private:
LockFile(const LockFile&) = delete;
- LockFile(LockFile&) = delete;
+ LockFile(LockFile&&) = delete;
LockFile& operator=(const LockFile&) = delete;
- LockFile& operator=(const LockFile&&) = delete;
+ LockFile& operator=(LockFile&&) = delete;
#if SLANG_WINDOWS_FAMILY
void* m_fileHandle;
@@ -278,6 +278,29 @@ namespace Slang
bool m_isOpen;
};
+ class LockFileGuard
+ {
+ public:
+ LockFileGuard(LockFile& lockFile, LockFile::LockType lockType = LockFile::LockType::Exclusive)
+ : m_lockFile(lockFile)
+ {
+ m_lockFile.lock(lockType);
+ }
+
+ ~LockFileGuard()
+ {
+ m_lockFile.unlock();
+ }
+
+ private:
+ LockFileGuard(const LockFileGuard&) = delete;
+ LockFileGuard(LockFileGuard&&) = delete;
+ LockFileGuard& operator=(const LockFileGuard&) = delete;
+ LockFileGuard& operator=(LockFileGuard&&) = delete;
+
+ LockFile& m_lockFile;
+ };
+
}
#endif
diff --git a/source/core/slang-persistent-cache.cpp b/source/core/slang-persistent-cache.cpp
new file mode 100644
index 000000000..2b4113e16
--- /dev/null
+++ b/source/core/slang-persistent-cache.cpp
@@ -0,0 +1,289 @@
+#include "slang-persistent-cache.h"
+
+#include "../core/slang-io.h"
+#include "../core/slang-stream.h"
+#include "../core/slang-string-util.h"
+#include "../core/slang-blob.h"
+
+namespace Slang
+{
+
+PersistentCache::PersistentCache(const Desc& desc)
+{
+ m_cacheDirectory = Path::simplify(desc.directory);
+ Path::createDirectory(m_cacheDirectory);
+
+ m_lockFileName = Path::simplify(m_cacheDirectory + "/lock");
+ m_indexFileName = Path::simplify(m_cacheDirectory + "/index");
+
+ m_lockFile.open(m_lockFileName);
+
+ m_maxEntryCount = desc.maxEntryCount;
+
+ resetStats();
+
+ initialize();
+}
+
+PersistentCache::~PersistentCache()
+{
+}
+
+SlangResult PersistentCache::clear()
+{
+ if (!m_lockFile.isOpen())
+ {
+ return SLANG_E_CANNOT_OPEN;
+ }
+
+ // Acquire the exclusive lock.
+ std::lock_guard<std::mutex> mutexLock(m_mutex);
+ LockFileGuard fileLock(m_lockFile);
+
+ struct Visitor : Path::Visitor
+ {
+ const String& directory;
+ const String& lockFileName;
+
+ Visitor(const String& directory, const String& lockFileName)
+ : directory(directory)
+ , lockFileName(lockFileName)
+ {}
+
+ void accept(Path::Type type, const UnownedStringSlice& fileName) SLANG_OVERRIDE
+ {
+ String fullPath = Path::simplify(directory + "/" + fileName);;
+ if (type == Path::Type::File && lockFileName != fullPath)
+ {
+ Path::remove(fullPath);
+ }
+ }
+ };
+
+ Visitor visitor(m_cacheDirectory, m_lockFileName);
+ Path::find(m_cacheDirectory, nullptr, &visitor);
+
+ m_stats.entryCount = 0;
+
+ return SLANG_OK;
+}
+
+void PersistentCache::resetStats()
+{
+ m_stats.entryCount = 0;
+ m_stats.hitCount = 0;
+ m_stats.missCount = 0;
+}
+
+SlangResult PersistentCache::readEntry(const Key& key, ISlangBlob** outData)
+{
+ // Be pessimistic and assume we have a cache miss.
+ ++m_stats.missCount;
+
+ if (!m_lockFile.isOpen())
+ {
+ return SLANG_E_CANNOT_OPEN;
+ }
+
+ // Acquire the exclusive lock.
+ std::lock_guard<std::mutex> mutexLock(m_mutex);
+ LockFileGuard fileLock(m_lockFile);
+
+ // Return if index does not exist.
+ if (!File::exists(m_indexFileName))
+ {
+ return SLANG_E_NOT_FOUND;
+ }
+
+ // Read the cache index.
+ CacheIndex cacheIndex;
+ SLANG_RETURN_ON_FAIL(readIndex(m_indexFileName, cacheIndex));
+
+ // Increase the age of all entries in the cache.
+ for (auto& entry : cacheIndex)
+ {
+ ++entry.age;
+ }
+
+ // Find the entry.
+ Index entryIndex = cacheIndex.findFirstIndex([&key] (const CacheEntry& entry) { return entry.key == key; });
+ if (entryIndex == -1)
+ {
+ return SLANG_E_NOT_FOUND;
+ }
+
+ // Read the entry.
+ String entryFileName = getEntryFileName(key);
+ ScopedAllocation data;
+ SlangResult result = File::readAllBytes(entryFileName, data);
+ if (result == SLANG_OK)
+ {
+ --m_stats.missCount;
+ ++m_stats.hitCount;
+ cacheIndex[entryIndex].age = 0;
+ auto blob = RawBlob::moveCreate(data);
+ *outData = blob.detach();
+ }
+ else
+ {
+ cacheIndex.removeAt(entryIndex);
+ }
+
+ // Write the cache index.
+ SLANG_RETURN_ON_FAIL(writeIndex(m_indexFileName, cacheIndex));
+ m_stats.entryCount = (Count)cacheIndex.getCount();
+
+ return result;
+}
+
+SlangResult PersistentCache::writeEntry(const Key& key, ISlangBlob* data)
+{
+ SLANG_ASSERT(data);
+
+ if (!m_lockFile.isOpen())
+ {
+ return SLANG_E_CANNOT_OPEN;
+ }
+
+ // Acquire the exclusive lock.
+ std::lock_guard<std::mutex> mutexLock(m_mutex);
+ LockFileGuard fileLock(m_lockFile);
+
+ // Read the cache index.
+ // We ignore any errors when reading the index and just write a new one.
+ CacheIndex cacheIndex;
+ readIndex(m_indexFileName, cacheIndex);
+
+ // Increase the age of all entries in the cache and get the index of
+ // the oldest entry.
+ Index oldestEntryIndex = -1;
+ uint32_t oldestEntryAge = 0;
+ for (Index entryIndex = 0; entryIndex < cacheIndex.getCount(); ++entryIndex)
+ {
+ auto& entry = cacheIndex[entryIndex];
+ ++entry.age;
+ if (entry.age > oldestEntryAge)
+ {
+ oldestEntryIndex = entryIndex;
+ oldestEntryAge = entry.age;
+ }
+ }
+
+ // Write the cache entry.
+ String entryFileName = getEntryFileName(key);
+ SLANG_RETURN_ON_FAIL(File::writeAllBytes(entryFileName, data->getBufferPointer(), data->getBufferSize()));
+
+ // Update the index.
+ if (m_maxEntryCount > 0 && cacheIndex.getCount() >= m_maxEntryCount)
+ {
+ // Replace oldest entry.
+ SLANG_ASSERT(oldestEntryIndex >= 0);
+ File::remove(getEntryFileName(cacheIndex[oldestEntryIndex].key));
+ cacheIndex[oldestEntryIndex] = CacheEntry{ key, 0 };
+ }
+ else
+ {
+ // Add new entry.
+ cacheIndex.add(CacheEntry{ key, 0 });
+ }
+
+ // Write the cache index.
+ SlangResult result = writeIndex(m_indexFileName, cacheIndex);
+ if (result == SLANG_OK)
+ {
+ m_stats.entryCount = (Count)cacheIndex.getCount();
+ }
+ else
+ {
+ // If writing the index failed, remove the entry file to avoid growing the cache.
+ Path::remove(entryFileName);
+ }
+
+ return result;
+}
+
+SlangResult PersistentCache::initialize()
+{
+ if (!m_lockFile.isOpen())
+ {
+ return SLANG_E_CANNOT_OPEN;
+ }
+
+ // Acquire the exclusive lock.
+ std::lock_guard<std::mutex> mutexLock(m_mutex);
+ LockFileGuard fileLock(m_lockFile);
+
+ CacheIndex cacheIndex;
+ if (SLANG_SUCCEEDED(readIndex(m_indexFileName, cacheIndex)))
+ {
+ m_stats.entryCount = (Count)cacheIndex.getCount();
+ }
+
+ return SLANG_OK;
+}
+
+String PersistentCache::getEntryFileName(const Key& key)
+{
+ StringBuilder str;
+ str << m_cacheDirectory << "/" << key.toString();
+ return str;
+}
+
+struct CacheIndexHeader
+{
+ char magic[4];
+ uint32_t version;
+ uint32_t count;
+ uint32_t reserved;
+};
+
+static const char* kMagic = "SLS$";
+static const uint32_t kVersion = 1;
+
+SlangResult PersistentCache::readIndex(const String& fileName, CacheIndex& outIndex)
+{
+ FileStream fs;
+ SLANG_RETURN_ON_FAIL(fs.init(fileName, FileMode::Open));
+
+ // Get file size.
+ SLANG_RETURN_ON_FAIL(fs.seek(SeekOrigin::End, 0));
+ size_t fileSize = (size_t)fs.getPosition();
+ SLANG_RETURN_ON_FAIL(fs.seek(SeekOrigin::Start, 0));
+
+ CacheIndexHeader header;
+ SLANG_RETURN_ON_FAIL(fs.readExactly(&header, sizeof(header)));
+ if (::memcmp(header.magic, kMagic, 4) != 0 || header.version != kVersion)
+ {
+ return SLANG_E_INTERNAL_FAIL;
+ }
+
+ // Return if payload does not have the right size.
+ if (header.count * sizeof(CacheEntry) != fileSize - sizeof(header))
+ {
+ return SLANG_E_INTERNAL_FAIL;
+ }
+
+ outIndex.setCount(header.count);
+ SLANG_RETURN_ON_FAIL(fs.readExactly(outIndex.getBuffer(), header.count * sizeof(CacheEntry)));
+
+ return SLANG_OK;
+}
+
+SlangResult PersistentCache::writeIndex(const String& fileName, const CacheIndex& index)
+{
+ FileStream fs;
+ SLANG_RETURN_ON_FAIL(fs.init(fileName, FileMode::Create));
+
+ CacheIndexHeader header;
+ ::memcpy(header.magic, kMagic, 4);
+ header.version = kVersion;
+ header.count = (uint32_t)index.getCount();
+ header.reserved = 0;
+ SLANG_RETURN_ON_FAIL(fs.write(&header, sizeof(header)));
+
+ SLANG_RETURN_ON_FAIL(fs.write(index.getBuffer(), index.getCount() * sizeof(CacheEntry)));
+
+ return SLANG_OK;
+}
+
+}
diff --git a/source/core/slang-persistent-cache.h b/source/core/slang-persistent-cache.h
new file mode 100644
index 000000000..db5ef0a2e
--- /dev/null
+++ b/source/core/slang-persistent-cache.h
@@ -0,0 +1,91 @@
+#pragma once
+#include "../../slang.h"
+#include "../core/slang-crypto.h"
+#include "../core/slang-io.h"
+#include "../core/slang-string.h"
+
+#include <mutex>
+
+namespace Slang
+{
+
+/// Implements a simple persistent cache on the filesystem for storing key/value pairs.
+/// Keys are SHA1 hashes and values are arbitrary blobs of data.
+/// The cache is save for concurrent access from multiple threads/processes by using
+/// a lock file within the cache directory. Furthermore, the cache implements a LRU
+/// eviction policy.
+class PersistentCache : public RefObject
+{
+public:
+ struct Desc
+ {
+ // The root directory for the cache.
+ const char* directory = nullptr;
+ // The maximum number of entries stored in the cache. By default, there is no limit.
+ Count maxEntryCount = 0;
+ };
+
+ struct Stats
+ {
+ // Number of cache hits since last resetting the stats.
+ Count hitCount;
+ // Number of cache misses since last resetting the stats.
+ Count missCount;
+ // Current number of entries in the cache.
+ Count entryCount;
+ };
+
+ using Key = SHA1::Digest;
+
+ PersistentCache(const Desc& desc);
+ ~PersistentCache();
+
+ /// Clear the contents of the cache by removing the cache index and all entry files.
+ SlangResult clear();
+
+ const Stats& getStats() const { return m_stats; }
+ void resetStats();
+
+ /// Read an entry from the cache.
+ /// Returns SLANG_OK if successful, SLANG_E_NOT_FOUND if the entry is not in the cache.
+ SlangResult readEntry(const Key& key, ISlangBlob** outData);
+
+ /// Write an entry to the cache.
+ /// Returns SLANG_OK if successful.
+ SlangResult writeEntry(const Key& key, ISlangBlob* data);
+
+private:
+ struct CacheEntry
+ {
+ Key key;
+ uint32_t age;
+ };
+
+ using CacheIndex = List<CacheEntry>;
+
+ SlangResult initialize();
+
+ String getEntryFileName(const Key& key);
+
+ SlangResult readIndex(const String& fileName, CacheIndex& outIndex);
+ SlangResult writeIndex(const String& fileName, const CacheIndex& index);
+
+ String m_cacheDirectory;
+ String m_lockFileName;
+ String m_indexFileName;
+
+ // For exclusive locking we need both a mutex (acquired first)
+ // followed by a a file lock. The mutex is needed because on Linux
+ // the file lock is only locking between processes, not threads.
+ std::mutex m_mutex;
+ Slang::LockFile m_lockFile;
+
+ Count m_maxEntryCount;
+
+ Stats m_stats;
+
+ // Used for unit tests.
+ friend struct PersistentCacheTest;
+};
+
+}
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index ce0ae4085..f2ebefb9d 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -223,14 +223,9 @@ namespace Slang
visitor->visitEntryPoint(this, as<EntryPointSpecializationInfo>(specializationInfo));
}
- void EntryPoint::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
+ void EntryPoint::buildHash(DigestBuilder<SHA1>& builder)
{
- // CompositeComponentType will have already hashed the relevant entry point's name
- // and file path dependencies, so we immediately return.
SLANG_UNUSED(builder);
- SLANG_UNUSED(entryPointIndex);
}
List<Module*> const& EntryPoint::getModuleDependencies()
@@ -242,12 +237,12 @@ namespace Slang
return empty;
}
- List<String> const& EntryPoint::getFilePathDependencies()
+ List<SourceFile*> const& EntryPoint::getFileDependencies()
{
if(auto module = getModule())
- return getModule()->getFilePathDependencies();
+ return getModule()->getFileDependencies();
- static List<String> empty;
+ static List<SourceFile*> empty;
return empty;
}
@@ -269,8 +264,8 @@ namespace Slang
if (auto declaredWitness = as<DeclaredSubtypeWitness>(witness))
{
auto declModule = getModule(declaredWitness->declRef.getDecl());
- m_moduleDependency.addDependency(declModule);
- m_pathDependency.addDependency(declModule);
+ m_moduleDependencyList.addDependency(declModule);
+ m_fileDependencyList.addDependency(declModule);
if (m_requirementSet.Add(declModule))
{
m_requirements.add(declModule);
@@ -301,12 +296,8 @@ namespace Slang
return Super::getInterface(guid);
}
- void TypeConformance::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
+ void TypeConformance::buildHash(DigestBuilder<SHA1>& builder)
{
- SLANG_UNUSED(entryPointIndex);
-
//TODO: Implement some kind of hashInto for Val then replace this
auto subtypeWitness = m_subtypeWitness->toString();
@@ -316,12 +307,12 @@ namespace Slang
List<Module*> const& TypeConformance::getModuleDependencies()
{
- return m_moduleDependency.getModuleList();
+ return m_moduleDependencyList.getModuleList();
}
- List<String> const& TypeConformance::getFilePathDependencies()
+ List<SourceFile*> const& TypeConformance::getFileDependencies()
{
- return m_pathDependency.getFilePathList();
+ return m_fileDependencyList.getFileList();
}
Index TypeConformance::getRequirementCount() { return m_requirements.getCount(); }
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 692fefbe2..f377d4882 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -197,6 +197,7 @@ namespace Slang
};
/// Tracks an ordered list of modules that something depends on.
+ /// TODO: Shader caching currently relies on this being in well defined order.
struct ModuleDependencyList
{
public:
@@ -216,15 +217,16 @@ namespace Slang
HashSet<Module*> m_moduleSet;
};
- /// Tracks an unordered list of filesystem paths that something depends on
- struct FilePathDependencyList
+ /// Tracks an unordered list of source files that something depends on
+ /// TODO: Shader caching currently relies on this being in well defined order.
+ struct FileDependencyList
{
public:
- /// Get the list of paths that are depended on.
- List<String> const& getFilePathList() { return m_filePathList; }
+ /// Get the list of files that are depended on.
+ List<SourceFile*> const& getFileList() { return m_fileList; }
- /// Add a path to the list, if it is not already present
- void addDependency(String const& path);
+ /// Add a file to the list, if it is not already present
+ void addDependency(SourceFile* sourceFile);
/// Add all of the paths that `module` depends on to the list
void addDependency(Module* module);
@@ -236,11 +238,11 @@ namespace Slang
// multiple times from `getFilePathList`, but because
// order isn't important, we could potentially do better
// in terms of memory (at some cost in performance) by
- // just sorting the `m_filePathList` every once in
+ // just sorting the `m_fileList` every once in
// a while and then deduplicating.
- List<String> m_filePathList;
- HashSet<String> m_filePathSet;
+ List<SourceFile*> m_fileList;
+ HashSet<SourceFile*> m_fileSet;
};
class EntryPoint;
@@ -292,13 +294,12 @@ namespace Slang
slang::IBlob** outDiagnostics) SLANG_OVERRIDE;
/// ComponentType is the only class inheriting from IComponentType that provides a
- /// meaningful implementation for these two functions. All others should forward these
- /// and implement updateDependencyBasedHash and updateASTBasedHash instead.
- SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash(
+ /// meaningful implementation for this function. All others should forward these and
+ /// implement `buildHash`.
+ SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IBlob** outHash) SLANG_OVERRIDE;
- SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE;
/// Get the linkage (aka "session" in the public API) for this component type.
Linkage* getLinkage() { return m_linkage; }
@@ -309,14 +310,7 @@ namespace Slang
TargetProgram* getTargetProgram(TargetRequest* target);
/// Update the hash builder with the dependencies for this component type.
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) = 0;
-
- /// Update the hash builder with the source contents for this component type.
- /// Module should be the only derived ComponentType class which has a meaningful
- /// implementation; all others should do nothing.
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) = 0;
+ virtual void buildHash(DigestBuilder<SHA1>& builder) = 0;
/// Get the number of entry points linked into this component type.
virtual Index getEntryPointCount() = 0;
@@ -371,9 +365,9 @@ namespace Slang
///
virtual List<Module*> const& getModuleDependencies() = 0;
- /// Get the full list of filesystem paths this component type depends on.
+ /// Get the full list of source files this component type depends on.
///
- virtual List<String> const& getFilePathDependencies() = 0;
+ virtual List<SourceFile*> const& getFileDependencies() = 0;
/// Callback for use with `enumerateIRModules`
typedef void (*EnumerateIRModulesCallback)(IRModule* irModule, void* userData);
@@ -515,11 +509,7 @@ namespace Slang
Linkage* linkage,
List<RefPtr<ComponentType>> const& childComponents);
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override;
+ virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
List<RefPtr<ComponentType>> const& getChildComponents() { return m_childComponents; };
Index getChildComponentCount() { return m_childComponents.getCount(); }
@@ -540,7 +530,7 @@ namespace Slang
RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE;
class CompositeSpecializationInfo : public SpecializationInfo
{
@@ -584,7 +574,7 @@ namespace Slang
List<ComponentType*> m_requirements;
ModuleDependencyList m_moduleDependencyList;
- FilePathDependencyList m_filePathDependencyList;
+ FileDependencyList m_fileDependencyList;
};
/// A component type created by specializing another component type.
@@ -597,14 +587,7 @@ namespace Slang
List<SpecializationArg> const& specializationArgs,
DiagnosticSink* sink);
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override
- {
- SLANG_UNUSED(hashBuilder);
- }
+ virtual void buildHash(DigestBuilder<SHA1>& builer) SLANG_OVERRIDE;
/// Get the base (unspecialized) component type that is being specialized.
RefPtr<ComponentType> getBaseComponentType() { return m_base; }
@@ -638,7 +621,7 @@ namespace Slang
RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencies; }
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencies; }
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencies; }
/// Get a list of tagged-union types referenced by the specialization parameters.
List<TaggedUnionType*> const& getTaggedUnionTypes() { return m_taggedUnionTypes; }
@@ -673,7 +656,7 @@ namespace Slang
List<TaggedUnionType*> m_taggedUnionTypes;
List<Module*> m_moduleDependencies;
- List<String> m_filePathDependencies;
+ List<SourceFile*> m_fileDependencies;
List<RefPtr<ComponentType>> m_requirements;
};
@@ -748,9 +731,9 @@ namespace Slang
{
return m_base->getModuleDependencies();
}
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE
{
- return m_base->getFilePathDependencies();
+ return m_base->getFileDependencies();
}
SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE
@@ -790,14 +773,7 @@ namespace Slang
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
SLANG_OVERRIDE;
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override
- {
- SLANG_UNUSED(hashBuilder);
- }
+ virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
private:
RefPtr<ComponentType> m_base;
@@ -891,27 +867,15 @@ namespace Slang
return Super::getEntryPointHostCallable(entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
}
- SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash(
+ SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IBlob** outHash) SLANG_OVERRIDE
{
- return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash);
- }
-
- SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE
- {
- return Super::computeContentsBasedHash(outHash);
+ return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash);
}
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override
- {
- SLANG_UNUSED(hashBuilder);
- }
+ virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
/// Create an entry point that refers to the given function.
static RefPtr<EntryPoint> create(
@@ -948,7 +912,7 @@ namespace Slang
/// but may also include modules that are required by its generic type arguments.
///
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE; // { return getModule()->getModuleDependencies(); }
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE; // { return getModule()->getFilePathDependencies(); }
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE; // { return getModule()->getFileDependencies(); }
/// Create a dummy `EntryPoint` that is only usable for pass-through compilation.
static RefPtr<EntryPoint> createDummyForPassThrough(
@@ -1118,30 +1082,18 @@ namespace Slang
entryPointIndex, targetIndex, outSharedLibrary, outDiagnostics);
}
- SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash(
+ SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IBlob** outHash) SLANG_OVERRIDE
{
- return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash);
- }
-
- SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE
- {
- return Super::computeContentsBasedHash(outHash);
+ return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash);
}
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override
- {
- SLANG_UNUSED(hashBuilder);
- }
+ virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE;
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE;
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE;
SLANG_NO_THROW Index SLANG_MCALL getSpecializationParamCount() SLANG_OVERRIDE { return 0; }
@@ -1182,8 +1134,8 @@ namespace Slang
DiagnosticSink* sink) SLANG_OVERRIDE;
private:
SubtypeWitness* m_subtypeWitness;
- ModuleDependencyList m_moduleDependency;
- FilePathDependencyList m_pathDependency;
+ ModuleDependencyList m_moduleDependencyList;
+ FileDependencyList m_fileDependencyList;
List<RefPtr<Module>> m_requirements;
HashSet<Module*> m_requirementSet;
RefPtr<IRModule> m_irModule;
@@ -1314,24 +1266,15 @@ namespace Slang
//
- SLANG_NO_THROW void SLANG_MCALL computeDependencyBasedHash(
+ SLANG_NO_THROW void SLANG_MCALL getEntryPointHash(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IBlob** outHash) SLANG_OVERRIDE
{
- return Super::computeDependencyBasedHash(entryPointIndex, targetIndex, outHash);
- }
-
- SLANG_NO_THROW void SLANG_MCALL computeContentsBasedHash(slang::IBlob** outHash) SLANG_OVERRIDE
- {
- return Super::computeContentsBasedHash(outHash);
+ return Super::getEntryPointHash(entryPointIndex, targetIndex, outHash);
}
- virtual void updateDependencyBasedHash(
- DigestBuilder<MD5>& hashBuilder,
- SlangInt entryPointIndex) override;
-
- virtual void updateContentsBasedHash(DigestBuilder<MD5>& hashBuilder) override;
+ virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
/// Create a module (initially empty).
Module(Linkage* linkage, ASTBuilder* astBuilder = nullptr);
@@ -1345,14 +1288,14 @@ namespace Slang
/// Get the list of other modules this module depends on
List<Module*> const& getModuleDependencyList() { return m_moduleDependencyList.getModuleList(); }
- /// Get the list of filesystem paths this module depends on
- List<String> const& getFilePathDependencyList() { return m_filePathDependencyList.getFilePathList(); }
+ /// Get the list of files this module depends on
+ List<SourceFile*> const& getFileDependencyList() { return m_fileDependencyList.getFileList(); }
/// Register a module that this module depends on
void addModuleDependency(Module* module);
- /// Register a filesystem path that this module depends on
- void addFilePathDependency(String const& path);
+ /// Register a source file that this module depends on
+ void addFileDependency(SourceFile* sourceFile);
/// Set the AST for this module.
///
@@ -1381,7 +1324,7 @@ namespace Slang
RefPtr<ComponentType> getRequirement(Index index) SLANG_OVERRIDE;
List<Module*> const& getModuleDependencies() SLANG_OVERRIDE { return m_moduleDependencyList.getModuleList(); }
- List<String> const& getFilePathDependencies() SLANG_OVERRIDE { return m_filePathDependencyList.getFilePathList(); }
+ List<SourceFile*> const& getFileDependencies() SLANG_OVERRIDE { return m_fileDependencyList.getFileList(); }
/// Given a mangled name finds the exported NodeBase associated with this module.
/// If not found returns nullptr.
@@ -1443,8 +1386,8 @@ namespace Slang
// List of modules this module depends on
ModuleDependencyList m_moduleDependencyList;
- // List of filesystem paths this module depends on
- FilePathDependencyList m_filePathDependencyList;
+ // List of source files this module depends on
+ FileDependencyList m_fileDependencyList;
// Entry points that were defined in thsi module
//
@@ -1474,9 +1417,6 @@ namespace Slang
// and m_mangledExportSymbols holds the NodeBase* values for each index.
StringSlicePool m_mangledExportPool;
List<NodeBase*> m_mangledExportSymbols;
-
- MD5::Digest lastModifiedDigest;
- MD5::Digest contentsDigest;
};
typedef Module LoadedModule;
@@ -1768,12 +1708,10 @@ namespace Slang
SLANG_NO_THROW SlangResult SLANG_MCALL createCompileRequest(
SlangCompileRequest** outCompileRequest) override;
- // Updates the supplied has builder with linkage-related information, which includes preprocessor
+ // Updates the supplied builder with linkage-related information, which includes preprocessor
// defines, the compiler version, and other compiler options. This is then merged with the hash
// produced for the program to produce a key that can be used with the shader cache.
- void updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt targetIndex);
+ void buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex);
void addTarget(
slang::TargetDesc const& desc);
diff --git a/source/slang/slang-preprocessor.cpp b/source/slang/slang-preprocessor.cpp
index fca8f5029..341e75ea4 100644
--- a/source/slang/slang-preprocessor.cpp
+++ b/source/slang/slang-preprocessor.cpp
@@ -32,9 +32,9 @@ void PreprocessorHandler::handleEndOfTranslationUnit(Preprocessor* preprocessor)
SLANG_UNUSED(preprocessor);
}
-void PreprocessorHandler::handleFileDependency(String const& path)
+void PreprocessorHandler::handleFileDependency(SourceFile* sourceFile)
{
- SLANG_UNUSED(path);
+ SLANG_UNUSED(sourceFile);
}
// In order to simplify the naming scheme, we will nest the implementaiton of the
@@ -2966,15 +2966,6 @@ static SlangResult readFile(
auto fileSystemExt = context->m_preprocessor->fileSystem;
SLANG_RETURN_ON_FAIL(fileSystemExt->loadFile(path.getBuffer(), outBlob));
- // If we are running the preprocessor as part of compiling a
- // specific module, then we must keep track of the file we've
- // read as yet another file that the module will depend on.
- //
- if( auto handler = context->m_preprocessor->handler )
- {
- handler->handleFileDependency(path);
- }
-
return SLANG_OK;
}
@@ -3056,10 +3047,18 @@ static void HandleIncludeDirective(PreprocessorDirectiveContext* context)
}
sourceFile = sourceManager->createSourceFileWithBlob(filePathInfo, foundSourceBlob);
-
sourceManager->addSourceFile(filePathInfo.uniqueIdentity, sourceFile);
}
+ // If we are running the preprocessor as part of compiling a
+ // specific module, then we must keep track of the file we've
+ // read as yet another file that the module will depend on.
+ //
+ if (auto handler = context->m_preprocessor->handler)
+ {
+ handler->handleFileDependency(sourceFile);
+ }
+
// This is a new parse (even if it's a pre-existing source file), so create a new SourceView
SourceView* sourceView = sourceManager->createSourceView(sourceFile, &filePathInfo, directiveLoc);
diff --git a/source/slang/slang-preprocessor.h b/source/slang/slang-preprocessor.h
index 4d7721d31..c37fd1607 100644
--- a/source/slang/slang-preprocessor.h
+++ b/source/slang/slang-preprocessor.h
@@ -28,7 +28,7 @@ using preprocessor::Preprocessor;
struct PreprocessorHandler
{
virtual void handleEndOfTranslationUnit(Preprocessor* preprocessor);
- virtual void handleFileDependency(String const& path);
+ virtual void handleFileDependency(SourceFile* sourceFile);
};
/// Description of a preprocessor options/dependencies
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 3b2175fad..4f057bb21 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1320,9 +1320,7 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createCompileRequest(
return SLANG_OK;
}
-void Linkage::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt targetIndex)
+void Linkage::buildHash(DigestBuilder<SHA1>& builder, SlangInt targetIndex)
{
// Add the Slang compiler version to the hash
auto version = String(getBuildTagString());
@@ -1346,29 +1344,39 @@ void Linkage::updateDependencyBasedHash(
// Add the target specified by targetIndex
auto targetReq = targets[targetIndex];
builder.append(targetReq->getTarget());
+ builder.append(targetReq->getTargetProfile().raw);
builder.append(targetReq->getTargetFlags());
builder.append(targetReq->getFloatingPointMode());
builder.append(targetReq->getLineDirectiveMode());
- builder.append(targetReq->shouldDumpIntermediates());
builder.append(targetReq->getForceGLSLScalarBufferLayout());
+ builder.append(targetReq->getDefaultMatrixLayoutMode());
+ builder.append(targetReq->shouldDumpIntermediates());
builder.append(targetReq->shouldTrackLiveness());
- auto targetProfile = targetReq->getTargetProfile();
- builder.append(targetProfile.getStage());
- builder.append(targetProfile.getVersion());
- builder.append(targetProfile.getFamily());
-
- auto targetProfileName = String(targetProfile.getName());
- builder.append(targetProfileName);
-
auto cookedCapabilities = targetReq->getTargetCaps().getExpandedAtoms();
for (auto& capability : cookedCapabilities)
{
builder.append(capability);
}
+ const PassThroughMode passThroughMode = getDownstreamCompilerRequiredForTarget(targetReq->getTarget());
+ const SourceLanguage sourceLanguage = getDefaultSourceLanguageForDownstreamCompiler(passThroughMode);
+
+ // Add prelude for the given downstream compiler.
+ ComPtr<ISlangBlob> prelude;
+ getGlobalSession()->getLanguagePrelude((SlangSourceLanguage)sourceLanguage, prelude.writeRef());
+ if (prelude)
+ {
+ builder.append(prelude);
+ }
+
+ // TODO: Downstream compilers (specifically dxc) can currently #include additional dependencies.
+ // This is currently the case for NVAPI headers included in the prelude.
+ // These dependencies are currently not picked up by the shader cache which is a significant issue.
+ // This can only be fixed by running the preprocessor in the slang compiler so dxc (or any other
+ // downstream compiler for that matter) isn't resolving any includes implicitly.
+
// Add the downstream compiler version (if it exists) to the hash
- auto passThroughMode = getDownstreamCompilerRequiredForTarget(targetReq->getTarget());
auto downstreamCompiler = getSessionImpl()->getOrLoadDownstreamCompiler(passThroughMode, nullptr);
if (downstreamCompiler)
{
@@ -1674,25 +1682,7 @@ void TranslationUnitRequest::_addSourceFile(SourceFile* sourceFile)
{
m_sourceFiles.add(sourceFile);
- // We want to record that the compiled module has a dependency
- // on the path of the source file, but we also need to account
- // for cases where the user added a source string/blob without
- // an associated path and/or wasn't from a file.
-
- auto pathInfo = sourceFile->getPathInfo();
- if (pathInfo.hasFoundPath())
- {
- getModule()->addFilePathDependency(pathInfo.foundPath);
- }
- else
- {
- // No path exists for this source, so we generate a new string to use as a
- // fake path in the list of file path dependencies. This is needed to account
- // for non-file-based dependencies later when shader files are being hashed for
- // the shader cache.
- auto sourceHash = MD5::compute(sourceFile->getContent().begin(), sourceFile->getContent().getLength());
- getModule()->addFilePathDependency(sourceHash.toString());
- }
+ getModule()->addFileDependency(sourceFile);
}
List<SourceFile*> const& TranslationUnitRequest::getSourceFiles()
@@ -1896,9 +1886,9 @@ protected:
// by applications to decide when they need to "hot reload"
// their shader code.
//
- void handleFileDependency(String const& path) SLANG_OVERRIDE
+ void handleFileDependency(SourceFile* sourceFile) SLANG_OVERRIDE
{
- m_module->addFilePathDependency(path);
+ m_module->addFileDependency(sourceFile);
}
// The second task that this handler deals with is detecting
@@ -3200,23 +3190,23 @@ void ModuleDependencyList::_addDependency(Module* module)
}
//
-// FilePathDependencyList
+// FileDependencyList
//
-void FilePathDependencyList::addDependency(String const& path)
+void FileDependencyList::addDependency(SourceFile* sourceFile)
{
- if(m_filePathSet.Contains(path))
+ if(m_fileSet.Contains(sourceFile))
return;
- m_filePathList.add(path);
- m_filePathSet.Add(path);
+ m_fileList.add(sourceFile);
+ m_fileSet.Add(sourceFile);
}
-void FilePathDependencyList::addDependency(Module* module)
+void FileDependencyList::addDependency(Module* module)
{
- for(auto& path : module->getFilePathDependencyList())
+ for(SourceFile* sourceFile : module->getFileDependencyList())
{
- addDependency(path);
+ addDependency(sourceFile);
}
}
@@ -3247,75 +3237,20 @@ ISlangUnknown* Module::getInterface(const Guid& guid)
return Super::getInterface(guid);
}
-void Module::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
+void Module::buildHash(DigestBuilder<SHA1>& builder)
{
- // CompositeComponentType will have already hashed this Module's file
- // dependencies.
SLANG_UNUSED(builder);
- SLANG_UNUSED(entryPointIndex);
-}
-
-void Module::updateContentsBasedHash(DigestBuilder<MD5>& builder)
-{
- auto filePathDependencies = getFilePathDependencies();
-
- DigestBuilder<MD5> lastModifiedBuilder;
- auto statFailed = false;
- for (auto file : filePathDependencies)
- {
- struct stat fileStatus;
- auto res = stat(file.getBuffer(), &fileStatus);
- if (res != 0)
- {
- statFailed = true;
- break;
- }
- lastModifiedBuilder.append(fileStatus.st_mtime);
- }
-
- MD5::Digest temp = lastModifiedBuilder.finalize();
- if (statFailed || temp != lastModifiedDigest)
- {
- // Either a stat() call failed, or changes were made to at least one of the file dependencies,
- // so we will need to re-generate the contents digest and save the new digest.
- DigestBuilder<MD5> contentsBuilder;
- for (auto file : filePathDependencies)
- {
- List<uint8_t> fileContents;
- if (SLANG_FAILED(File::readAllBytes(file, fileContents)))
- {
- // Failure to read the file means this is a digest for the contents of a source
- // file which does not live on disk.
- contentsBuilder.append(file);
- }
- else
- {
- contentsBuilder.append(fileContents);
- }
- }
- contentsDigest = contentsBuilder.finalize();
- if (!statFailed)
- {
- // If no stat() calls failed, then we have a valid last modified digest and should
- // update the one we have saved.
- lastModifiedDigest = temp;
- }
- }
-
- builder.append(contentsDigest);
}
void Module::addModuleDependency(Module* module)
{
m_moduleDependencyList.addDependency(module);
- m_filePathDependencyList.addDependency(module);
+ m_fileDependencyList.addDependency(module);
}
-void Module::addFilePathDependency(String const& path)
+void Module::addFileDependency(SourceFile* sourceFile)
{
- m_filePathDependencyList.addDependency(path);
+ m_fileDependencyList.addDependency(sourceFile);
}
void Module::setModuleDecl(ModuleDecl* moduleDecl)
@@ -3505,12 +3440,12 @@ SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointCode(
return artifact->loadBlob(ArtifactKeep::Yes, outCode);
}
-SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash(
+SLANG_NO_THROW void SLANG_MCALL ComponentType::getEntryPointHash(
SlangInt entryPointIndex,
SlangInt targetIndex,
slang::IBlob** outHash)
{
- DigestBuilder<MD5> builder;
+ DigestBuilder<SHA1> builder;
// A note on enums that may be hashed in as part of the following two function calls:
//
@@ -3518,19 +3453,19 @@ SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash(
// the compiler, part of hashing the linkage is hashing in the compiler version.
// Consequently, any encoding differences as a result of different compiler versions
// will already be reflected in the resulting hash.
- getLinkage()->updateDependencyBasedHash(builder, targetIndex);
- updateDependencyBasedHash(builder, entryPointIndex);
+ getLinkage()->buildHash(builder, targetIndex);
- // Add file path dependencies to the hash - all child components
- // will have file path dependencies that are a subset of this list.
- auto fileDeps = getFilePathDependencies();
- for (auto& file : fileDeps)
+ // Enumerate all file dependencies and add them to the hash.
+ for (SourceFile* sourceFile : getFileDependencies())
{
- builder.append(file);
+ // TODO: We want to lazily evaluate & cache the source file digest
+ SHA1::Digest digest = SHA1::compute(sourceFile->getContent().begin(), sourceFile->getContent().getLength());
+ builder.append(digest);
}
- // Add the name and name override for the specified entry point
- // to the hash.
+ buildHash(builder);
+
+ // Add the name and name override for the specified entry point to the hash.
auto entryPointName = getEntryPoint(entryPointIndex)->getName()->text;
builder.append(entryPointName);
auto entryPointMangledName = getEntryPointMangledName(entryPointIndex);
@@ -3542,14 +3477,6 @@ SLANG_NO_THROW void SLANG_MCALL ComponentType::computeDependencyBasedHash(
*outHash = hash.detach();
}
-SLANG_NO_THROW void SLANG_MCALL ComponentType::computeContentsBasedHash(slang::IBlob** outHash)
-{
- DigestBuilder<MD5> builder;
- updateContentsBasedHash(builder);
- auto hash = builder.finalize().toBlob();
- *outHash = hash.detach();
-}
-
SLANG_NO_THROW SlangResult SLANG_MCALL ComponentType::getEntryPointHostCallable(
int entryPointIndex,
int targetIndex,
@@ -3864,9 +3791,9 @@ CompositeComponentType::CompositeComponentType(
{
m_moduleDependencyList.addDependency(module);
}
- for(auto filePath : child->getFilePathDependencies())
+ for(auto sourceFile : child->getFileDependencies())
{
- m_filePathDependencyList.addDependency(filePath);
+ m_fileDependencyList.addDependency(sourceFile);
}
auto childRequirementCount = child->getRequirementCount();
@@ -3882,25 +3809,13 @@ CompositeComponentType::CompositeComponentType(
}
}
-void CompositeComponentType::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
-{
- auto componentCount = getChildComponentCount();
-
- for (Index i = 0; i < componentCount; ++i)
- {
- getChildComponent(i)->updateDependencyBasedHash(builder, entryPointIndex);
- }
-}
-
-void CompositeComponentType::updateContentsBasedHash(DigestBuilder<MD5>& builder)
+void CompositeComponentType::buildHash(DigestBuilder<SHA1>& builder)
{
auto componentCount = getChildComponentCount();
for (Index i = 0; i < componentCount; ++i)
{
- getChildComponent(i)->updateContentsBasedHash(builder);
+ getChildComponent(i)->buildHash(builder);
}
}
@@ -3959,9 +3874,9 @@ List<Module*> const& CompositeComponentType::getModuleDependencies()
return m_moduleDependencyList.getModuleList();
}
-List<String> const& CompositeComponentType::getFilePathDependencies()
+List<SourceFile*> const& CompositeComponentType::getFileDependencies()
{
- return m_filePathDependencyList.getFilePathList();
+ return m_fileDependencyList.getFileList();
}
void CompositeComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
@@ -4226,7 +4141,7 @@ SpecializedComponentType::SpecializedComponentType(
// The starting point for our lists comes from the base component type.
//
m_moduleDependencies = base->getModuleDependencies();
- m_filePathDependencies = base->getFilePathDependencies();
+ m_fileDependencies = base->getFileDependencies();
Index baseRequirementCount = base->getRequirementCount();
for( Index r = 0; r < baseRequirementCount; r++ )
@@ -4238,11 +4153,11 @@ SpecializedComponentType::SpecializedComponentType(
// dependencies and requirements based on the modules that
// were collected when looking at the specialization arguments.
- // We want to avoid adding the same file path dependency more than once.
+ // We want to avoid adding the same file dependency more than once.
//
- HashSet<String> filePathDependencySet;
- for(auto path : m_filePathDependencies)
- filePathDependencySet.Add(path);
+ HashSet<SourceFile*> fileDependencySet;
+ for(SourceFile* sourceFile : m_fileDependencies)
+ fileDependencySet.Add(sourceFile);
for(auto module : moduleCollector.m_modulesList)
{
@@ -4259,7 +4174,7 @@ SpecializedComponentType::SpecializedComponentType(
m_requirements.add(module);
// The speciialized component type will also have a dependency
- // on all the file paths that any of the modules involved in
+ // on all the files that any of the modules involved in
// it depend on (including those that are required but not
// yet linked in).
//
@@ -4268,12 +4183,12 @@ SpecializedComponentType::SpecializedComponentType(
// source files, so we want to include anything that could
// affect the validity of generated code.
//
- for(auto path : module->getFilePathDependencies())
+ for(SourceFile* sourceFile : module->getFileDependencies())
{
- if(filePathDependencySet.Contains(path))
+ if(fileDependencySet.Contains(sourceFile))
continue;
- filePathDependencySet.Add(path);
- m_filePathDependencies.add(path);
+ fileDependencySet.Add(sourceFile);
+ m_fileDependencies.add(sourceFile);
}
// Finalyl we also add the module for the specialization arguments
@@ -4383,9 +4298,7 @@ SpecializedComponentType::SpecializedComponentType(
collector.visitSpecialized(this);
}
-void SpecializedComponentType::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
+void SpecializedComponentType::buildHash(DigestBuilder<SHA1>& builder)
{
auto specializationArgCount = getSpecializationArgCount();
for (Index i = 0; i < specializationArgCount; ++i)
@@ -4395,7 +4308,7 @@ void SpecializedComponentType::updateDependencyBasedHash(
builder.append(argString);
}
- getBaseComponentType()->updateDependencyBasedHash(builder, entryPointIndex);
+ getBaseComponentType()->buildHash(builder);
}
void SpecializedComponentType::acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo)
@@ -4442,13 +4355,8 @@ void RenamedEntryPointComponentType::acceptVisitor(
this, as<EntryPoint::EntryPointSpecializationInfo>(specializationInfo));
}
-void RenamedEntryPointComponentType::updateDependencyBasedHash(
- DigestBuilder<MD5>& builder,
- SlangInt entryPointIndex)
+void RenamedEntryPointComponentType::buildHash(DigestBuilder<SHA1>& builder)
{
- // CompositeComponentType will have already hashed the name override and file
- // dependencies for this entry point.
- SLANG_UNUSED(entryPointIndex);
SLANG_UNUSED(builder);
}
@@ -5144,14 +5052,15 @@ int EndToEndCompileRequest::getDependencyFileCount()
{
auto frontEndReq = getFrontEndReq();
auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
- return (int)program->getFilePathDependencies().getCount();
+ return (int)program->getFileDependencies().getCount();
}
char const* EndToEndCompileRequest::getDependencyFilePath(int index)
{
auto frontEndReq = getFrontEndReq();
auto program = frontEndReq->getGlobalAndEntryPointsComponentType();
- return program->getFilePathDependencies()[index].begin();
+ SourceFile* sourceFile = program->getFileDependencies()[index];
+ return sourceFile->getPathInfo().hasFileFoundPath() ? sourceFile->getPathInfo().foundPath.getBuffer() : "unknown";
}
int EndToEndCompileRequest::getTranslationUnitCount()
diff --git a/tools/gfx-unit-test/gfx-test-util.cpp b/tools/gfx-unit-test/gfx-test-util.cpp
index 116b4222a..5cbb30e71 100644
--- a/tools/gfx-unit-test/gfx-test-util.cpp
+++ b/tools/gfx-unit-test/gfx-test-util.cpp
@@ -194,6 +194,7 @@ namespace gfx_test
Slang::ComPtr<gfx::IDevice> createTestingDevice(
UnitTestContext* context,
Slang::RenderApiFlag::Enum api,
+ Slang::List<const char*> additionalSearchPaths,
gfx::IDevice::ShaderCacheDesc shaderCache)
{
Slang::ComPtr<gfx::IDevice> device;
@@ -222,10 +223,13 @@ namespace gfx_test
SLANG_IGNORE_TEST
}
deviceDesc.slang.slangGlobalSession = context->slangGlobalSession;
- const char* searchPaths[] = { "", "../../tools/gfx-unit-test", "tools/gfx-unit-test" };
- deviceDesc.slang.searchPathCount = (SlangInt)SLANG_COUNT_OF(searchPaths);
- deviceDesc.slang.searchPaths = searchPaths;
-
+ Slang::List<const char*> searchPaths;
+ searchPaths.add("");
+ searchPaths.add("../../tools/gfx-unit-test");
+ searchPaths.add("tools/gfx-unit-test");
+ searchPaths.addRange(additionalSearchPaths);
+ deviceDesc.slang.searchPaths = searchPaths.getBuffer();
+ deviceDesc.slang.searchPathCount = (gfx::GfxCount)searchPaths.getCount();
deviceDesc.shaderCache = shaderCache;
gfx::D3D12DeviceExtendedDesc extDesc = {};
diff --git a/tools/gfx-unit-test/gfx-test-util.h b/tools/gfx-unit-test/gfx-test-util.h
index d11d5623c..f829d6d12 100644
--- a/tools/gfx-unit-test/gfx-test-util.h
+++ b/tools/gfx-unit-test/gfx-test-util.h
@@ -77,6 +77,7 @@ namespace gfx_test
Slang::ComPtr<gfx::IDevice> createTestingDevice(
UnitTestContext* context,
Slang::RenderApiFlag::Enum api,
+ Slang::List<const char*> additionalSearchPaths = {},
gfx::IDevice::ShaderCacheDesc shaderCache = {});
void initializeRenderDoc();
@@ -88,13 +89,14 @@ namespace gfx_test
const ImplFunc& f,
UnitTestContext* context,
Slang::RenderApiFlag::Enum api,
+ Slang::List<const char*> searchPaths = {},
gfx::IDevice::ShaderCacheDesc shaderCache = {})
{
if ((api & context->enabledApis) == 0)
{
SLANG_IGNORE_TEST
}
- auto device = createTestingDevice(context, api, shaderCache);
+ auto device = createTestingDevice(context, api, searchPaths, shaderCache);
if (!device)
{
SLANG_IGNORE_TEST
diff --git a/tools/gfx-unit-test/split-graphics-fragment.slang b/tools/gfx-unit-test/shader-cache-graphics-fragment.slang
index db515a957..392aa15ba 100644
--- a/tools/gfx-unit-test/split-graphics-fragment.slang
+++ b/tools/gfx-unit-test/shader-cache-graphics-fragment.slang
@@ -1,4 +1,4 @@
-// split-graphics-fragment.slang
+// shader-cache-graphics-fragment.slang
// Output of the vertex shader, and input to the fragment shader.
struct CoarseVertex
diff --git a/tools/gfx-unit-test/split-graphics-vertex.slang b/tools/gfx-unit-test/shader-cache-graphics-vertex.slang
index 615686a90..a86f8bcf1 100644
--- a/tools/gfx-unit-test/split-graphics-vertex.slang
+++ b/tools/gfx-unit-test/shader-cache-graphics-vertex.slang
@@ -1,4 +1,4 @@
-// split-graphics-vertex.slang
+// shader-cache-graphics-vertex.slang
// Per-vertex attributes to be assembled from bound vertex buffers.
struct AssembledVertex
diff --git a/tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang b/tools/gfx-unit-test/shader-cache-multiple-entry-points.slang
index 9287b62ea..a0015b83c 100644
--- a/tools/gfx-unit-test/multiple-entry-point-shader-cache-shader.slang
+++ b/tools/gfx-unit-test/shader-cache-multiple-entry-points.slang
@@ -1,9 +1,10 @@
-uniform RWStructuredBuffer<float> buffer;
-
+// shader-cache-multiple-entry-points.slang
+
[shader("compute")]
[numthreads(4, 1, 1)]
void computeA(
-uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ uint3 sv_dispatchThreadID: SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 1.0f;
@@ -12,7 +13,8 @@ uint3 sv_dispatchThreadID : SV_DispatchThreadID)
[shader("compute")]
[numthreads(4, 1, 1)]
void computeB(
-uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ uint3 sv_dispatchThreadID: SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 2.0f;
@@ -21,7 +23,8 @@ uint3 sv_dispatchThreadID : SV_DispatchThreadID)
[shader("compute")]
[numthreads(4, 1, 1)]
void computeC(
-uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ uint3 sv_dispatchThreadID: SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 3.0f;
diff --git a/tools/gfx-unit-test/shader-cache-specialization.slang b/tools/gfx-unit-test/shader-cache-specialization.slang
new file mode 100644
index 000000000..63994aee8
--- /dev/null
+++ b/tools/gfx-unit-test/shader-cache-specialization.slang
@@ -0,0 +1,68 @@
+// shader-cache-specialization.slang
+
+// This is a copy of `shader-object.slang` in `shader-object` example
+// for use by compute-smoke gfx unit test.
+
+// This file implements a simple compute shader that transforms
+// input floating point numbers stored in a `RWStructuredBuffer`.
+// Specifically, for each number x from input buffer, compute
+// f(x) and store the result back in the same buffer.
+
+// The compute shader supports multiple transformation functions,
+// such add(x, c) which returns x+c, or mul(x, c) which returns x*c.
+// This functions are implemented as types that conforms to the
+// `ITransformer` interface.
+
+// The main entry point function takes a parameter of `ITransformer`
+// type, and applies the transformation to numbers in the input
+// buffer. By defining the shader parameter using interfaces,
+// we enable the flexiblity to generate either specialized compute
+// kernels that performs specific transformation or a general
+// kernel that can perform any transformations encoded by the
+// parameter at run-time, without changing any shader code or
+// host-application logic for setting and preparing shader parameters.
+
+// Defines the transformer interface, which implements a single
+// `transform` operation.
+interface ITransformer
+{
+ float transform(float x);
+}
+
+// Represents a transform function f(x) = x + c.
+struct AddTransformer : ITransformer
+{
+ float c;
+ float transform(float x) { return x + c + 10.0f; }
+};
+
+// Represents a transform function f(x) = x * c.
+struct MulTransformer : ITransformer
+{
+ float c;
+ float transform(float x) { return x * c; }
+};
+
+// Represents a composite function f(x) = f0(f1(x));
+struct CompositeTransformer : ITransformer
+{
+ ITransformer func0;
+ ITransformer func1;
+ float transform(float x)
+ {
+ return func0.transform(func1.transform(x));
+ }
+};
+
+// Main entry-point. Applies the transformation encoded by `transformer`
+// to all elements in `buffer`.
+[shader("compute")]
+[numthreads(4,1,1)]
+void computeMain(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer,
+ uniform ITransformer transformer)
+{
+ var input = buffer[sv_dispatchThreadID.x];
+ buffer[sv_dispatchThreadID.x] = transformer.transform(input);
+}
diff --git a/tools/gfx-unit-test/shader-cache-tests.cpp b/tools/gfx-unit-test/shader-cache-tests.cpp
index 486b59cda..4cccc726f 100644
--- a/tools/gfx-unit-test/shader-cache-tests.cpp
+++ b/tools/gfx-unit-test/shader-cache-tests.cpp
@@ -5,8 +5,7 @@
#include "tools/gfx-util/shader-cursor.h"
#include "source/core/slang-basic.h"
#include "source/core/slang-string-util.h"
-
-#include "source/core/slang-memory-file-system.h"
+#include "source/core/slang-io.h"
#include "source/core/slang-file-system.h"
#include "gfx-test-texture-util.h"
@@ -17,69 +16,136 @@ using namespace Slang;
namespace gfx_test
{
- struct BaseShaderCacheTest
+ struct ShaderCacheTest
{
UnitTestContext* context;
- RenderApiFlag::Enum api;
+ Slang::RenderApiFlag::Enum api;
+
+ String testDirectory;
+ String cacheDirectory;
+
+ ComPtr<ISlangMutableFileSystem> diskFileSystem;
+
+ IDevice::ShaderCacheDesc shaderCacheDesc = {};
ComPtr<IDevice> device;
- ComPtr<IShaderCacheStatistics> shaderCacheStats;
+ ComPtr<IShaderCache> shaderCache;
ComPtr<IPipelineState> pipelineState;
ComPtr<IResourceView> bufferView;
- IDevice::ShaderCacheDesc shaderCache = {};
-
- // Two file systems in order to get around problems posed by the testing framework.
- //
- // - diskFileSystem - Used to save any files that must exist on disk for subsequent
- // save/load function calls (most prominently loadComputeProgram()) to pick up.
- // This is also used to test the file stream implementation for the cache.
- // - memoryFileSystem - Used to test the fallback path for the cache in the case physical
- // file paths cannot be obtained, which prevents usage of file streams.
- ComPtr<ISlangMutableFileSystem> diskFileSystem;
- ComPtr<ISlangMutableFileSystem> memoryFileSystem;
-
- // Simple compute shaders we can pipe to our individual shader files for cache testing
- String contentsA = String(
+ String computeShaderA = String(
R"(
- uniform RWStructuredBuffer<float> buffer;
-
[shader("compute")]
[numthreads(4, 1, 1)]
- void computeMain(
- uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ void main(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 1.0f;
- })");
+ }
+ )");
- String contentsB = String(
+ String computeShaderB = String(
R"(
- uniform RWStructuredBuffer<float> buffer;
-
[shader("compute")]
[numthreads(4, 1, 1)]
- void computeMain(
- uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ void main(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 2.0f;
- })");
+ }
+ )");
- String contentsC = String(
+ String computeShaderC = String(
R"(
- uniform RWStructuredBuffer<float> buffer;
-
[shader("compute")]
[numthreads(4, 1, 1)]
- void computeMain(
- uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ void main(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
var input = buffer[sv_dispatchThreadID.x];
buffer[sv_dispatchThreadID.x] = input + 3.0f;
- })");
+ }
+ )");
+
+
+ void removeDirectory(const String& directory)
+ {
+ auto osFileSystem = OSFileSystem::getMutableSingleton();
+
+ struct Context
+ {
+ ISlangMutableFileSystem *fileSystem;
+ const String& directory;
+ } context { osFileSystem, directory };
+
+ osFileSystem->enumeratePathContents(
+ directory.getBuffer(),
+ [](SlangPathType pathType, const char* fileName, void* userData)
+ {
+ struct Context* context = static_cast<Context *>(userData);
+ if (pathType == SlangPathType::SLANG_PATH_TYPE_FILE)
+ {
+ String path = Path::simplify(context->directory + "/" + fileName);
+ context->fileSystem->remove(path.getBuffer());
+ }
+ },
+ &context);
+
+ osFileSystem->remove(directory.getBuffer());
+ }
+
+ void writeShader(const String& source, const String& fileName)
+ {
+ diskFileSystem->saveFile(fileName.getBuffer(), source.getBuffer(), source.getLength());
+ }
+
+ void init(UnitTestContext* context, Slang::RenderApiFlag::Enum api)
+ {
+ this->context = context;
+ this->api = api;
+
+ testDirectory = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/shader-cache-test");
+ cacheDirectory = Path::simplify(testDirectory + "/cache");
+
+ // Cleanup if there are stale files from a previously aborted test.
+ removeDirectory(cacheDirectory);
+ removeDirectory(testDirectory);
+
+ Path::createDirectory(testDirectory);
+ diskFileSystem = new RelativeFileSystem(OSFileSystem::getMutableSingleton(), testDirectory);
+ shaderCacheDesc.shaderCachePath = cacheDirectory.getBuffer();
+ }
+
+ void cleanup()
+ {
+ removeDirectory(cacheDirectory);
+ removeDirectory(testDirectory);
+ }
+
+ template<typename Func>
+ void runStep(Func func)
+ {
+ List<const char*> additionalSearchPaths;
+ additionalSearchPaths.add(testDirectory.getBuffer());
+
+ runTestImpl(
+ [this, func] (IDevice* device, UnitTestContext* ctx)
+ {
+ this->device = device;
+ device->queryInterface(SLANG_UUID_IShaderCache, (void**)this->shaderCache.writeRef());
+ func();
+ this->device = nullptr;
+ this->shaderCache = nullptr;
+ },
+ context, api, additionalSearchPaths, shaderCacheDesc);
+ }
- void createRequiredResources()
+ void createComputeResources()
{
const int numberCount = 4;
float initialData[] = { 0.0f, 1.0f, 2.0f, 3.0f };
@@ -108,57 +174,25 @@ namespace gfx_test
device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef()));
}
- void freeOldResources()
+ void freeComputeResources()
{
bufferView = nullptr;
pipelineState = nullptr;
- device = nullptr;
- shaderCacheStats = nullptr;
}
- // TODO: This should be removed at some point. Currently exists as a workaround for module loading
- // seemingly not accounting for updated shader code under the same module name with the same entry point.
- void generateNewDevice()
+ void createComputePipeline(const char* moduleName, const char* entryPointName)
{
- freeOldResources();
- device = createTestingDevice(context, api, shaderCache);
- }
-
- void init(ComPtr<IDevice> device, UnitTestContext* context)
- {
- this->device = device;
- this->context = context;
- switch (device->getDeviceInfo().deviceType)
- {
- case DeviceType::DirectX11:
- api = RenderApiFlag::D3D11;
- break;
- case DeviceType::DirectX12:
- api = RenderApiFlag::D3D12;
- break;
- case DeviceType::Vulkan:
- api = RenderApiFlag::Vulkan;
- break;
- case DeviceType::CPU:
- api = RenderApiFlag::CPU;
- break;
- case DeviceType::CUDA:
- api = RenderApiFlag::CUDA;
- break;
- case DeviceType::OpenGl:
- api = RenderApiFlag::OpenGl;
- break;
- default:
- SLANG_IGNORE_TEST
- }
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, moduleName, entryPointName, slangReflection));
- memoryFileSystem = new MemoryFileSystem();
- diskFileSystem = OSFileSystem::getMutableSingleton();
- diskFileSystem->createDirectory("tools/gfx-unit-test/shader-cache-test");
- diskFileSystem = new RelativeFileSystem(diskFileSystem, "tools/gfx-unit-test/shader-cache-test");
+ ComputePipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ GFX_CHECK_CALL_ABORT(
+ device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
}
- void submitGPUWork()
+ void dispatchComputePipeline()
{
ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -174,437 +208,284 @@ namespace gfx_test
auto rootObject = encoder->bindPipeline(pipelineState);
- ShaderCursor rootCursor(rootObject);
+ ShaderCursor entryPointCursor(rootObject->getEntryPoint(0));
+ entryPointCursor.getPath("buffer").setResource(bufferView);
+
+ // ShaderCursor rootCursor(rootObject);
// Bind buffer view to the entry point.
- rootCursor.getPath("buffer").setResource(bufferView);
+ // rootCursor.getPath("buffer").setResource(bufferView);
- encoder->dispatchCompute(1, 1, 1);
+ encoder->dispatchCompute(4, 1, 1);
encoder->endEncoding();
commandBuffer->close();
queue->executeCommandBuffer(commandBuffer);
queue->waitOnHost();
- }
+ }
- void cleanUpFiles()
+ void runComputePipeline(const char* moduleName, const char* entryPointName)
{
- freeOldResources();
-
- List<String> filePaths;
- diskFileSystem->enumeratePathContents(
- ".",
- [](SlangPathType pathType, const char* name, void* userData)
- {
- if (pathType == SlangPathType::SLANG_PATH_TYPE_FILE)
- {
- List<String>& out = *(List<String>*)userData;
- out.add(String(name));
- }
- },
- &filePaths);
+ createComputeResources();
+ createComputePipeline(moduleName, entryPointName);
+ dispatchComputePipeline();
+ freeComputeResources();
+ }
- for (auto file : filePaths)
- {
- diskFileSystem->remove(file.getBuffer());
- }
- // Get a mutable singleton so we can delete the folder.
- auto fileSystem = OSFileSystem::getMutableSingleton();
- fileSystem->remove("tools/gfx-unit-test/shader-cache-test");
+ ShaderCacheStats getStats()
+ {
+ SLANG_ASSERT(shaderCache);
+ ShaderCacheStats stats;
+ shaderCache->getShaderCacheStats(&stats);
+ return stats;
}
- void run()
+ void run(UnitTestContext* context, Slang::RenderApiFlag::Enum api)
{
- shaderCache.shaderCacheFileSystem = diskFileSystem;
- runTests();
- shaderCache.shaderCacheFileSystem = memoryFileSystem;
+ init(context, api);
runTests();
-
- cleanUpFiles();
+ cleanup();
}
virtual void runTests() = 0;
};
- // Due to needing a workaround to prevent loading old, outdated modules, we need to
- // recreate the device between each segment of the test for all tests. However, we need to maintain the
- // same cache filesystem for the same duration, so the device is immediately recreated
- // to ensure we can pass the filesystem all the way through.
- //
- // General TODO: Remove the repeated generateNewDevice() and createRequiredResources() calls once
- // a solution exists that allows source code changes under the same module name to be picked
- // up on load.
-
- // One shader file on disk, all modifications are done to the same file
- struct SingleEntryShaderCache : BaseShaderCacheTest
+ // Basic shader cache test using 3 different shader files stored on disk.
+ struct ShaderCacheTestBasic : ShaderCacheTest
{
- void generateNewPipelineState(Slang::String shaderContents)
+ void runTests()
{
- diskFileSystem->saveFile("test-tmp-single-entry.slang", shaderContents.getBuffer(), shaderContents.getLength());
+ // Write shader source files.
+ writeShader(computeShaderA, "shader-cache-tmp-a.slang");
+ writeShader(computeShaderB, "shader-cache-tmp-b.slang");
+ writeShader(computeShaderC, "shader-cache-tmp-c.slang");
- ComPtr<IShaderProgram> shaderProgram;
- slang::ProgramLayout* slangReflection;
- GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "shader-cache-test/test-tmp-single-entry", "computeMain", slangReflection));
+ // Cache is cold and we expect 3 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-a", "main");
+ runComputePipeline("shader-cache-tmp-b", "main");
+ runComputePipeline("shader-cache-tmp-c", "main");
- ComputePipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- GFX_CHECK_CALL_ABORT(
- device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
- }
+ SLANG_CHECK(getStats().missCount == 3);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
- void runTests()
- {
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsA);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsA);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsC);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1);
- }
- };
+ // Cache is hot and we expect 3 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-a", "main");
+ runComputePipeline("shader-cache-tmp-b", "main");
+ runComputePipeline("shader-cache-tmp-c", "main");
- // Several shader files on disk, modifications may be done to any file
- struct MultipleEntryShaderCache : BaseShaderCacheTest
- {
- void modifyShaderA(String shaderContents)
- {
- diskFileSystem->saveFile("test-tmp-multi-entry-A.slang", shaderContents.getBuffer(), shaderContents.getLength());
- }
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 3);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
- void modifyShaderB(String shaderContents)
- {
- diskFileSystem->saveFile("test-tmp-multi-entry-B.slang", shaderContents.getBuffer(), shaderContents.getLength());
- }
+ // Write shader source files, all rotated by one.
+ writeShader(computeShaderA, "shader-cache-tmp-b.slang");
+ writeShader(computeShaderB, "shader-cache-tmp-c.slang");
+ writeShader(computeShaderC, "shader-cache-tmp-a.slang");
- void modifyShaderC(String shaderContents)
- {
- diskFileSystem->saveFile("test-tmp-multi-entry-C.slang", shaderContents.getBuffer(), shaderContents.getLength());
- }
+ // Cache is cold again and we expect 3 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-a", "main");
+ runComputePipeline("shader-cache-tmp-b", "main");
+ runComputePipeline("shader-cache-tmp-c", "main");
- void generateNewPipelineState(GfxIndex shaderIndex)
- {
- ComPtr<IShaderProgram> shaderProgram;
- slang::ProgramLayout* slangReflection;
- const char* shaderFilename;
- switch (shaderIndex)
- {
- case 0:
- shaderFilename = "shader-cache-test/test-tmp-multi-entry-A";
- break;
- case 1:
- shaderFilename = "shader-cache-test/test-tmp-multi-entry-B";
- break;
- case 2:
- shaderFilename = "shader-cache-test/test-tmp-multi-entry-C";
- break;
- default:
- // Should never reach this point since we wrote the test
- SLANG_IGNORE_TEST;
- }
- GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, shaderFilename, "computeMain", slangReflection));
-
- ComputePipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- GFX_CHECK_CALL_ABORT(
- device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
- }
+ SLANG_CHECK(getStats().missCount == 3);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 6);
+ }
+ );
- void checkAllCacheEntries()
- {
- generateNewPipelineState(0);
- submitGPUWork();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(2);
- submitGPUWork();
- }
+ // Cache is hot again and we expect 3 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-a", "main");
+ runComputePipeline("shader-cache-tmp-b", "main");
+ runComputePipeline("shader-cache-tmp-c", "main");
- void runTests()
- {
- generateNewDevice();
- createRequiredResources();
- modifyShaderA(contentsA);
- modifyShaderB(contentsB);
- modifyShaderC(contentsC);
- checkAllCacheEntries();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- checkAllCacheEntries();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 3);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- modifyShaderA(contentsB);
- checkAllCacheEntries();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1);
-
- generateNewDevice();
- createRequiredResources();
- modifyShaderA(contentsC);
- modifyShaderB(contentsA);
- modifyShaderC(contentsB);
- checkAllCacheEntries();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 3);
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 3);
+ SLANG_CHECK(getStats().entryCount == 6);
+ }
+ );
}
};
- // One shader file on disk containing several entry points, no modifications are made to the file
- struct MultipleEntryPointShader : BaseShaderCacheTest
+ // Test one shader file on disk with multiple entry points.
+ struct ShaderCacheTestEntryPoint : ShaderCacheTest
{
- void generateNewPipelineState(GfxIndex shaderIndex)
+ void runTests()
{
- ComPtr<IShaderProgram> shaderProgram;
- slang::ProgramLayout* slangReflection;
- const char* entryPointName;
- switch (shaderIndex)
- {
- case 0:
- entryPointName = "computeA";
- break;
- case 1:
- entryPointName = "computeB";
- break;
- case 2:
- entryPointName = "computeC";
- break;
- default:
- // Should never reach this point since we wrote the test
- SLANG_IGNORE_TEST;
- }
- GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "multiple-entry-point-shader-cache-shader", entryPointName, slangReflection));
+ // Cache is cold and we expect 3 misses, one for each entry point.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-multiple-entry-points", "computeA");
+ runComputePipeline("shader-cache-multiple-entry-points", "computeB");
+ runComputePipeline("shader-cache-multiple-entry-points", "computeC");
- ComputePipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- GFX_CHECK_CALL_ABORT(
- device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
- }
+ SLANG_CHECK(getStats().missCount == 3);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
- void runTests()
- {
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(0);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(0);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(2);
- submitGPUWork();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(0);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
+ // Cache is hot and we expect 3 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-multiple-entry-points", "computeA");
+ runComputePipeline("shader-cache-multiple-entry-points", "computeB");
+ runComputePipeline("shader-cache-multiple-entry-points", "computeC");
+
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 3);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
}
};
- // One shader file contains an import/include, direct code modifications are made to the imported file
- // This test specifically checks four cases:
- // 1. import w/o changes in the imported file
- // 2. import w/ changes in the imported file
- // 3. #include w/o changes in the included file (the included file is the same as the imported file in the prior step)
- // 4. #include w/ changes in the included file
- struct ShaderFileImportsShaderCache : BaseShaderCacheTest
+ // Test cache invalidation due to an import/include file being changed on disk.
+ struct ShaderCacheTestImportInclude : ShaderCacheTest
{
String importedContentsA = String(
R"(
- struct TestFunction
+ void processElement(RWStructuredBuffer<float> buffer, uint index)
{
- void simpleElementAdd(RWStructuredBuffer<float> buffer, uint index)
- {
- var input = buffer[index];
- buffer[index] = input + 1.0f;
- }
- };)");
+ var input = buffer[index];
+ buffer[index] = input + 1.0f;
+ }
+ )");
String importedContentsB = String(
R"(
- struct TestFunction
+ void processElement(RWStructuredBuffer<float> buffer, uint index)
{
- void simpleElementAdd(RWStructuredBuffer<float> buffer, uint index)
- {
- var input = buffer[index];
- buffer[index] = input + 2.0f;
- }
- };)");
+ var input = buffer[index];
+ buffer[index] = input + 2.0f;
+ }
+ )");
String importFile = String(
R"(
- import test_tmp_imported;
-
- uniform RWStructuredBuffer<float> buffer;
+ import shader_cache_tmp_imported;
[shader("compute")]
[numthreads(4, 1, 1)]
- void computeMain(
- uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ void main(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
- TestFunction test;
- for (uint i = 0; i < 4; ++i)
- {
- test.simpleElementAdd(buffer, i);
- }
- })");
+ processElement(buffer, sv_dispatchThreadID.x);
+ }
+ )");
String includeFile = String(
R"(
- #include "test-tmp-imported.slang"
+ #include "shader-cache-tmp-imported.slang"
- uniform RWStructuredBuffer<float> buffer;
-
[shader("compute")]
[numthreads(4, 1, 1)]
- void computeMain(
- uint3 sv_dispatchThreadID : SV_DispatchThreadID)
+ void main(
+ uint3 sv_dispatchThreadID : SV_DispatchThreadID,
+ uniform RWStructuredBuffer<float> buffer)
{
- TestFunction test;
- for (uint i = 0; i < 4; ++i)
- {
- test.simpleElementAdd(buffer, i);
- }
+ processElement(buffer, sv_dispatchThreadID.x);
})");
- void initializeFiles()
+ void runTests()
{
- diskFileSystem->saveFile("test-tmp-imported.slang", importedContentsA.getBuffer(), importedContentsA.getLength());
- diskFileSystem->saveFile("test-tmp-importing.slang", importFile.getBuffer(), importFile.getLength());
- }
+ // Write shader source files.
+ writeShader(importedContentsA, "shader-cache-tmp-imported.slang");
+ writeShader(importFile, "shader-cache-tmp-import.slang");
+ writeShader(includeFile, "shader-cache-tmp-include.slang");
- void modifyImportedFile(String importedContents)
- {
- diskFileSystem->saveFile("test-tmp-imported.slang", importedContents.getBuffer(), importedContents.getLength());
- }
+ // Cache is cold and we expect 2 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-import", "main");
+ runComputePipeline("shader-cache-tmp-include", "main");
- void changeImportToInclude()
- {
- diskFileSystem->saveFile("test-tmp-importing.slang", includeFile.getBuffer(), includeFile.getLength());
- }
+ SLANG_CHECK(getStats().missCount == 2);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
- void generateNewPipelineState()
- {
- ComPtr<IShaderProgram> shaderProgram;
- slang::ProgramLayout* slangReflection;
- GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "shader-cache-test/test-tmp-importing", "computeMain", slangReflection));
+ // Cache is hot and we expect 2 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-import", "main");
+ runComputePipeline("shader-cache-tmp-include", "main");
- ComputePipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- GFX_CHECK_CALL_ABORT(
- device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
- }
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 2);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
- void runTests()
- {
- generateNewDevice();
- createRequiredResources();
- initializeFiles();
- generateNewPipelineState();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- modifyImportedFile(importedContentsB);
- generateNewPipelineState();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1);
-
- generateNewDevice();
- createRequiredResources();
- changeImportToInclude();
- generateNewPipelineState();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1);
-
- generateNewDevice();
- createRequiredResources();
- modifyImportedFile(importedContentsA);
- generateNewPipelineState();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 1);
+ // Change content of imported/included shader file.
+ writeShader(importedContentsB, "shader-cache-tmp-imported.slang");
+
+ // Cache is cold and we expect 2 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-import", "main");
+ runComputePipeline("shader-cache-tmp-include", "main");
+
+ SLANG_CHECK(getStats().missCount == 2);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 4);
+ }
+ );
+
+ // Cache is hot and we expect 2 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("shader-cache-tmp-import", "main");
+ runComputePipeline("shader-cache-tmp-include", "main");
+
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 2);
+ SLANG_CHECK(getStats().entryCount == 4);
+ }
+ );
}
};
// One shader featuring multiple kinds of shader objects that can be bound.
- struct SpecializationArgsEntries : BaseShaderCacheTest
+ struct ShaderCacheTestSpecialization : ShaderCacheTest
{
slang::ProgramLayout* slangReflection;
+ void createComputePipeline()
+ {
+ ComPtr<IShaderProgram> shaderProgram;
+
+ GFX_CHECK_CALL_ABORT(
+ loadComputeProgram(device, shaderProgram, "shader-cache-specialization", "computeMain", slangReflection));
+
+ ComputePipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ GFX_CHECK_CALL_ABORT(
+ device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
+ }
+
void createAddTransformer(IShaderObject** transformer)
{
slang::TypeReflection* addTransformerType =
@@ -627,7 +508,7 @@ namespace gfx_test
ShaderCursor(*transformer).getPath("c").setData(&c, sizeof(float));
}
- void submitGPUWork(GfxIndex transformerType)
+ void dispatchComputePipeline(const char* transformerTypeName)
{
Slang::ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -643,23 +524,16 @@ namespace gfx_test
auto rootObject = encoder->bindPipeline(pipelineState);
- ComPtr<IShaderObject> transformer;
- switch (transformerType)
- {
- case 0:
- createAddTransformer(transformer.writeRef());
- break;
- case 1:
- createMulTransformer(transformer.writeRef());
- break;
- default:
- /* Should not get here */
- SLANG_IGNORE_TEST;
- }
+ Slang::ComPtr<IShaderObject> transformer;
+ slang::TypeReflection* transformerType = slangReflection->findTypeByName(transformerTypeName);
+ GFX_CHECK_CALL_ABORT(device->createShaderObject(
+ transformerType, ShaderObjectContainerType::None, transformer.writeRef()));
+
+ float c = 1.0f;
+ ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float));
ShaderCursor entryPointCursor(rootObject->getEntryPoint(0));
entryPointCursor.getPath("buffer").setResource(bufferView);
-
entryPointCursor.getPath("transformer").setObject(transformer);
encoder->dispatchCompute(1, 1, 1);
@@ -669,78 +543,78 @@ namespace gfx_test
queue->waitOnHost();
}
- void generateNewPipelineState()
+ void runComputePipeline(const char* transformerTypeName)
{
- ComPtr<IShaderProgram> shaderProgram;
-
- GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "compute-smoke", "computeMain", slangReflection));
-
- ComputePipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- GFX_CHECK_CALL_ABORT(
- device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
- }
+ createComputeResources();
+ createComputePipeline();
+ dispatchComputePipeline(transformerTypeName);
+ freeComputeResources();
+ }
void runTests()
{
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState();
- submitGPUWork(0);
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState();
- submitGPUWork(1);
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
- }
- };
+ // Cache is cold and we expect 2 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("AddTransformer");
+ runComputePipeline("MulTransformer");
- // Same gist as the multiple entry point compute shader but with a graphics
- // shader file containing a vertex and fragment shader
- struct Vertex
- {
- float position[3];
- };
+ SLANG_CHECK(getStats().missCount == 2);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
- static const int kVertexCount = 3;
- static const Vertex kVertexData[kVertexCount] =
- {
- { 0, 0, 0.5 },
- { 1, 0, 0.5 },
- { 0, 1, 0.5 },
+ // Cache is hot and we expect 2 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline("AddTransformer");
+ runComputePipeline("MulTransformer");
+
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 2);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
+ }
};
- struct GraphicsShaderCache : BaseShaderCacheTest
+ // Same gist as the multiple entry point compute shader but with a graphics
+ // shader file containing a vertex and fragment shader.
+ struct ShaderCacheTestGraphics : ShaderCacheTest
{
- const int kWidth = 256;
- const int kHeight = 256;
- const Format format = Format::R32G32B32A32_FLOAT;
+ struct Vertex
+ {
+ float position[3];
+ };
- ComPtr<IShaderProgram> shaderProgram;
- ComPtr<IRenderPassLayout> renderPass;
- ComPtr<IFramebuffer> framebuffer;
+ static const int kWidth = 256;
+ static const int kHeight = 256;
+ static const Format format = Format::R32G32B32A32_FLOAT;
ComPtr<IBufferResource> vertexBuffer;
ComPtr<ITextureResource> colorBuffer;
+ ComPtr<IInputLayout> inputLayout;
+ ComPtr<IFramebufferLayout> framebufferLayout;
+ ComPtr<IRenderPassLayout> renderPass;
+ ComPtr<IFramebuffer> framebuffer;
ComPtr<IBufferResource> createVertexBuffer(IDevice* device)
{
+ const Vertex vertices[] = {
+ { 0, 0, 0.5 },
+ { 1, 0, 0.5 },
+ { 0, 1, 0.5 },
+ };
+
IBufferResource::Desc vertexBufferDesc;
vertexBufferDesc.type = IResource::Type::Buffer;
- vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex);
+ vertexBufferDesc.sizeInBytes = sizeof(vertices);
vertexBufferDesc.defaultState = ResourceState::VertexBuffer;
vertexBufferDesc.allowedStates = ResourceState::VertexBuffer;
- ComPtr<IBufferResource> vertexBuffer = device->createBufferResource(vertexBufferDesc, &kVertexData[0]);
+ ComPtr<IBufferResource> vertexBuffer = device->createBufferResource(vertexBufferDesc, vertices);
SLANG_CHECK_ABORT(vertexBuffer != nullptr);
return vertexBuffer;
}
@@ -761,13 +635,7 @@ namespace gfx_test
return colorBuffer;
}
- void createShaderProgram()
- {
- slang::ProgramLayout* slangReflection;
- GFX_CHECK_CALL_ABORT(loadGraphicsProgram(device, shaderProgram, "shader-cache-graphics", "vertexMain", "fragmentMain", slangReflection));
- }
-
- void createRequiredResources()
+ void createGraphicsResources()
{
VertexStreamDesc vertexStreams[] = {
{ sizeof(Vertex), InputSlotClass::PerVertex, 0 },
@@ -782,7 +650,7 @@ namespace gfx_test
inputLayoutDesc.inputElements = inputElements;
inputLayoutDesc.vertexStreamCount = SLANG_COUNT_OF(vertexStreams);
inputLayoutDesc.vertexStreams = vertexStreams;
- auto inputLayout = device->createInputLayout(inputLayoutDesc);
+ inputLayout = device->createInputLayout(inputLayoutDesc);
SLANG_CHECK_ABORT(inputLayout != nullptr);
vertexBuffer = createVertexBuffer(device);
@@ -795,18 +663,9 @@ namespace gfx_test
IFramebufferLayout::Desc framebufferLayoutDesc;
framebufferLayoutDesc.renderTargetCount = 1;
framebufferLayoutDesc.renderTargets = &targetLayout;
- ComPtr<gfx::IFramebufferLayout> framebufferLayout = device->createFramebufferLayout(framebufferLayoutDesc);
+ framebufferLayout = device->createFramebufferLayout(framebufferLayoutDesc);
SLANG_CHECK_ABORT(framebufferLayout != nullptr);
- GraphicsPipelineStateDesc pipelineDesc = {};
- pipelineDesc.program = shaderProgram.get();
- pipelineDesc.inputLayout = inputLayout;
- pipelineDesc.framebufferLayout = framebufferLayout;
- pipelineDesc.depthStencil.depthTestEnable = false;
- pipelineDesc.depthStencil.depthWriteEnable = false;
- GFX_CHECK_CALL_ABORT(
- device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef()));
-
IRenderPassLayout::Desc renderPassDesc = {};
renderPassDesc.framebufferLayout = framebufferLayout;
renderPassDesc.renderTargetCount = 1;
@@ -833,7 +692,35 @@ namespace gfx_test
GFX_CHECK_CALL_ABORT(device->createFramebuffer(framebufferDesc, framebuffer.writeRef()));
}
- void submitGPUWork()
+ void freeGraphicsResources()
+ {
+ inputLayout = nullptr;
+ framebufferLayout = nullptr;
+ renderPass = nullptr;
+ framebuffer = nullptr;
+ vertexBuffer = nullptr;
+ colorBuffer = nullptr;
+ pipelineState = nullptr;
+ }
+
+ void createGraphicsPipeline()
+ {
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(
+ loadGraphicsProgram(device, shaderProgram, "shader-cache-graphics", "vertexMain", "fragmentMain", slangReflection));
+
+ GraphicsPipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ pipelineDesc.inputLayout = inputLayout;
+ pipelineDesc.framebufferLayout = framebufferLayout;
+ pipelineDesc.depthStencil.depthTestEnable = false;
+ pipelineDesc.depthStencil.depthWriteEnable = false;
+ GFX_CHECK_CALL_ABORT(
+ device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef()));
+ }
+
+ void dispatchGraphicsPipeline()
{
ComPtr<ITransientResourceHeap> transientHeap;
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -857,28 +744,50 @@ namespace gfx_test
encoder->setVertexBuffer(0, vertexBuffer);
encoder->setPrimitiveTopology(PrimitiveTopology::TriangleList);
- encoder->draw(kVertexCount);
+ encoder->draw(3);
encoder->endEncoding();
commandBuffer->close();
queue->executeCommandBuffer(commandBuffer);
queue->waitOnHost();
}
+ void runGraphicsPipeline()
+ {
+ createGraphicsResources();
+ createGraphicsPipeline();
+ dispatchGraphicsPipeline();
+ freeGraphicsResources();
+ }
+
void runTests()
{
- generateNewDevice();
- createShaderProgram();
- createRequiredResources();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 2);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
+ // Cache is cold and we expect 2 misses (2 entry points).
+ runStep(
+ [this]()
+ {
+ runGraphicsPipeline();
+
+ SLANG_CHECK(getStats().missCount == 2);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
+
+ // Cache is hot and we expect 2 hits.
+ runStep(
+ [this]()
+ {
+ runGraphicsPipeline();
+
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 2);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
}
};
- // Same as GraphicsShaderCache, but instead of having a singular file containing both a vertex and fragment shader, we
+ // Same as ShaderCacheTestGraphics, but instead of having a singular file containing both a vertex and fragment shader, we
// now have two separate shader files, one containing the vertex shader and the other the fragment with the same
// names, with the expectation that we should record cache misses for both fetches.
//
@@ -890,54 +799,38 @@ namespace gfx_test
//
// We do not actively test geometry shaders here, but it is simply an extension of this test and should be expected
// to behave similarly.
- struct SplitGraphicsShader : GraphicsShaderCache
+ struct ShaderCacheTestGraphicsSplit : ShaderCacheTestGraphics
{
- void createShaderProgram()
- {
- slang::ProgramLayout* slangReflection;
- const char* moduleNames[] = { "split-graphics-vertex", "split-graphics-fragment" };
- GFX_CHECK_CALL_ABORT(loadSplitGraphicsProgram(device, shaderProgram, moduleNames, "main", "main", slangReflection));
- }
-
- Result loadSplitGraphicsProgram(
- IDevice* device,
- ComPtr<IShaderProgram>& outShaderProgram,
- const char** shaderModuleNames,
- const char* vertexEntryPointName,
- const char* fragmentEntryPointName,
- slang::ProgramLayout*& slangReflection)
+ void createGraphicsPipeline()
{
ComPtr<slang::ISession> slangSession;
- SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
+ GFX_CHECK_CALL_ABORT(device->getSlangSession(slangSession.writeRef()));
- ComPtr<slang::IBlob> diagnosticsBlob;
- slang::IModule* vertexModule = slangSession->loadModule(shaderModuleNames[0], diagnosticsBlob.writeRef());
- if (!vertexModule)
- return SLANG_FAIL;
- slang::IModule* fragmentModule = slangSession->loadModule(shaderModuleNames[1], diagnosticsBlob.writeRef());
- if (!fragmentModule)
- return SLANG_FAIL;
+ slang::IModule* vertexModule = slangSession->loadModule("shader-cache-graphics-vertex");
+ SLANG_CHECK_ABORT(vertexModule);
+ slang::IModule* fragmentModule = slangSession->loadModule("shader-cache-graphics-fragment");
+ SLANG_CHECK_ABORT(fragmentModule);
ComPtr<slang::IEntryPoint> vertexEntryPoint;
- SLANG_RETURN_ON_FAIL(
- vertexModule->findEntryPointByName(vertexEntryPointName, vertexEntryPoint.writeRef()));
+ GFX_CHECK_CALL_ABORT(
+ vertexModule->findEntryPointByName("main", vertexEntryPoint.writeRef()));
ComPtr<slang::IEntryPoint> fragmentEntryPoint;
- SLANG_RETURN_ON_FAIL(
- fragmentModule->findEntryPointByName(fragmentEntryPointName, fragmentEntryPoint.writeRef()));
+ GFX_CHECK_CALL_ABORT(
+ fragmentModule->findEntryPointByName("main", fragmentEntryPoint.writeRef()));
Slang::List<slang::IComponentType*> componentTypes;
componentTypes.add(vertexModule);
componentTypes.add(fragmentModule);
Slang::ComPtr<slang::IComponentType> composedProgram;
- SlangResult result = slangSession->createCompositeComponentType(
- componentTypes.getBuffer(),
- componentTypes.getCount(),
- composedProgram.writeRef(),
- diagnosticsBlob.writeRef());
- SLANG_RETURN_ON_FAIL(result);
- slangReflection = composedProgram->getLayout();
+ GFX_CHECK_CALL_ABORT(
+ slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ composedProgram.writeRef()));
+
+ slang::ProgramLayout* slangReflection = composedProgram->getLayout();
Slang::List<slang::IComponentType*> entryPoints;
entryPoints.add(vertexEntryPoint);
@@ -949,263 +842,60 @@ namespace gfx_test
programDesc.entryPointCount = 2;
programDesc.slangEntryPoints = entryPoints.getBuffer();
- auto shaderProgram = device->createProgram(programDesc);
-
- outShaderProgram = shaderProgram;
- return SLANG_OK;
- }
-
- void runTests()
- {
- generateNewDevice();
- createShaderProgram();
- createRequiredResources();
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 2);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
- }
- };
-
- // Same as MultipleEntryShaderCache, but we now set the maximum entry count limit, so the cache
- // should remove entries as needed when it reaches capacity.
- //
- // This test does not modify shaders as other tests already test this, instead focusing on checking
- // that entries are correctly removed as cache limits are reached and that entries are always in
- // the right order.
- //
- // As opening multiple streams to the same file is dependent on the OS, this test is run on the
- // in-memory file system. Cache eviction policy with an on-disk file system will need to be inspected
- // manually.
- struct CacheWithMaxEntryLimit : MultipleEntryShaderCache
- {
- List<String> test0Lines; // C -> B -> A
- List<String> test1Lines; // C -> B
- List<String> test2Lines; // A -> B
- List<String> test3Lines; // A -> C
- List<String> test4Lines; // C -> B -> A
- List<String> entryKeys; // C, B, A
-
- void getCacheFile(List<String>& lines)
- {
- ComPtr<ISlangBlob> contentsBlob;
- memoryFileSystem->loadFile(shaderCache.cacheFilename, contentsBlob.writeRef());
- List<UnownedStringSlice> temp;
- StringUtil::calcLines(UnownedStringSlice((char*)contentsBlob->getBufferPointer()), temp);
- for (auto line : temp)
- {
- if (line.trim().getLength() != 0)
- lines.add(line);
- }
- }
-
- // Check the correctness of the cache's entries by comparing the order of entries in the
- // current state of the cache with what we expect.
- void checkCacheFiles()
- {
- // Check that shader A appears where we expect it to.
- SLANG_CHECK(test2Lines[0] == test3Lines[0]);
- SLANG_CHECK(test2Lines[0] == test4Lines[2]);
-
- // Check that shader B appears where we expect it to.
- SLANG_CHECK(test1Lines[1] == test2Lines[1]);
- SLANG_CHECK(test1Lines[1] == test4Lines[1]);
-
- // Check that shader C appears where we expect it to.
- SLANG_CHECK(test1Lines[0] == test3Lines[1]);
- SLANG_CHECK(test1Lines[0] == test4Lines[0]);
- }
-
- // Cache limit 3, three unique shaders
- void runTest0()
- {
- shaderCache.entryCountLimit = 3;
- generateNewDevice();
- createRequiredResources();
- modifyShaderA(contentsA);
- modifyShaderB(contentsB);
- modifyShaderC(contentsC);
- generateNewPipelineState(0);
- submitGPUWork();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(2);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- // This needs to be called in order to force the cache file to be updated, otherwise we will
- // be unable to perform the necessary checks.
- freeOldResources();
-
- getCacheFile(test0Lines);
- SLANG_CHECK(test0Lines.getCount() == 3);
-
- // This segment also doubles as the point where we fetch the keys for all three shaders
- // to use in later checks.
- for (auto line : test0Lines)
- {
- List<UnownedStringSlice> digests;
- StringUtil::split(line.getUnownedSlice(), ' ', digests);
- if (digests.getCount() != 2)
- continue;
- entryKeys.add(digests[0]);
- }
-
- ComPtr<ISlangBlob> unused;
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef())));
- }
-
- // Cache limit 2, access shaders A then B then C
- void runTest1()
- {
- shaderCache.entryCountLimit = 2;
- generateNewDevice();
- createRequiredResources();
- modifyShaderA(contentsA);
- modifyShaderB(contentsB);
- modifyShaderC(contentsC);
- generateNewPipelineState(0);
- submitGPUWork();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(2);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 3);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- freeOldResources();
-
- getCacheFile(test1Lines);
- SLANG_CHECK(test1Lines.getCount() == 2);
-
- ComPtr<ISlangBlob> unused;
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef())));
- }
-
- // Cache limit 2, access shaders B and then A
- void runTest2()
- {
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(0);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- freeOldResources();
-
- getCacheFile(test2Lines);
- SLANG_CHECK(test2Lines.getCount() == 2);
-
- ComPtr<ISlangBlob> unused;
- SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef())));
- }
+ ComPtr<IShaderProgram> shaderProgram = device->createProgram(programDesc);
- // Cache limit 2, access shaders C and then A
- void runTest3()
- {
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(2);
- submitGPUWork();
- generateNewPipelineState(0);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- freeOldResources();
-
- getCacheFile(test3Lines);
- SLANG_CHECK(test3Lines.getCount() == 2);
-
- ComPtr<ISlangBlob> unused;
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_FAILED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef())));
+ GraphicsPipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ pipelineDesc.inputLayout = inputLayout;
+ pipelineDesc.framebufferLayout = framebufferLayout;
+ pipelineDesc.depthStencil.depthTestEnable = false;
+ pipelineDesc.depthStencil.depthWriteEnable = false;
+ GFX_CHECK_CALL_ABORT(
+ device->createGraphicsPipelineState(pipelineDesc, pipelineState.writeRef()));
}
- // Cache limit 3, access shaders A then B then C
- void runTest4()
+ void runGraphicsPipeline()
{
- shaderCache.entryCountLimit = 3;
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(0);
- submitGPUWork();
- generateNewPipelineState(1);
- submitGPUWork();
- generateNewPipelineState(2);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 2);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- freeOldResources();
-
- getCacheFile(test4Lines);
- SLANG_CHECK(test4Lines.getCount() == 3);
-
- ComPtr<ISlangBlob> unused;
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[0].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[1].getBuffer(), unused.writeRef())));
- SLANG_CHECK(SLANG_SUCCEEDED(memoryFileSystem->loadFile(entryKeys[2].getBuffer(), unused.writeRef())));
- }
+ createGraphicsResources();
+ createGraphicsPipeline();
+ dispatchGraphicsPipeline();
+ freeGraphicsResources();
+ }
void runTests()
{
- runTest0();
- runTest1();
- runTest2();
- runTest3();
- runTest4();
+ // Cache is cold and we expect 2 misses (2 entry points).
+ runStep(
+ [this]()
+ {
+ runGraphicsPipeline();
- checkCacheFiles();
- }
+ SLANG_CHECK(getStats().missCount == 2);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ );
- void run()
- {
- shaderCache.shaderCacheFileSystem = memoryFileSystem;
- runTests();
+ // Cache is hot and we expect 2 hits.
+ runStep(
+ [this]()
+ {
+ runGraphicsPipeline();
- cleanUpFiles();
- }
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 2);
+ SLANG_CHECK(getStats().entryCount == 2);
+ }
+ ); }
};
- // This test is specifically for source files which live entirely in memory. The key difference between
- // these and physical source files is such files have their contents hash added to the file dependencies
- // list instead of a file path, meaning any given specific set of shader contents will be treated as a
- // wholly unique module.
- struct NonPhysicalFileDependencyEntry : BaseShaderCacheTest
+ // Test caching of shaders that are compiled from source strings instead of files.
+ struct ShaderCacheTestSourceString : ShaderCacheTest
{
- void generateNewPipelineState(Slang::String shaderContents)
+ void createComputePipeline(Slang::String shaderSource)
{
ComPtr<IShaderProgram> shaderProgram;
- GFX_CHECK_CALL_ABORT(loadComputeProgramFromSource(device, shaderProgram, shaderContents));
+ GFX_CHECK_CALL_ABORT(loadComputeProgramFromSource(device, shaderProgram, shaderSource));
ComputePipelineStateDesc pipelineDesc = {};
pipelineDesc.program = shaderProgram.get();
@@ -1213,135 +903,120 @@ namespace gfx_test
device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
}
- void runTests()
+ void runComputePipeline(Slang::String shaderSource)
{
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsA);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsA);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
-
- generateNewDevice();
- createRequiredResources();
- generateNewPipelineState(contentsC);
- submitGPUWork();
-
- device->queryInterface(SLANG_UUID_IShaderCacheStatistics, (void**)shaderCacheStats.writeRef());
- SLANG_CHECK(shaderCacheStats->getCacheMissCount() == 1);
- SLANG_CHECK(shaderCacheStats->getCacheHitCount() == 0);
- SLANG_CHECK(shaderCacheStats->getCacheEntryDirtyCount() == 0);
+ createComputeResources();
+ createComputePipeline(shaderSource);
+ dispatchComputePipeline();
+ freeComputeResources();
}
- };
- template <typename T>
- void shaderCacheTestImpl(ComPtr<IDevice> device, UnitTestContext* context)
- {
- T test;
- test.init(device, context);
- test.run();
- }
+ void runTests()
+ {
+ // Cache is cold and we expect 3 misses.
+ runStep(
+ [this]()
+ {
+ runComputePipeline(computeShaderA);
+ runComputePipeline(computeShaderB);
+ runComputePipeline(computeShaderC);
- SLANG_UNIT_TEST(singleEntryShaderCacheD3D12)
- {
- runTestImpl(shaderCacheTestImpl<SingleEntryShaderCache>, unitTestContext, Slang::RenderApiFlag::D3D12);
- }
+ SLANG_CHECK(getStats().missCount == 3);
+ SLANG_CHECK(getStats().hitCount == 0);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
- SLANG_UNIT_TEST(singleEntryShaderCacheVulkan)
- {
- runTestImpl(shaderCacheTestImpl<SingleEntryShaderCache>, unitTestContext, Slang::RenderApiFlag::Vulkan);
- }
+ // Cache is hot and we expect 3 hits.
+ runStep(
+ [this]()
+ {
+ runComputePipeline(computeShaderA);
+ runComputePipeline(computeShaderB);
+ runComputePipeline(computeShaderC);
- SLANG_UNIT_TEST(multipleEntryShaderCacheD3D12)
- {
- runTestImpl(shaderCacheTestImpl<MultipleEntryShaderCache>, unitTestContext, Slang::RenderApiFlag::D3D12);
- }
+ SLANG_CHECK(getStats().missCount == 0);
+ SLANG_CHECK(getStats().hitCount == 3);
+ SLANG_CHECK(getStats().entryCount == 3);
+ }
+ );
+ }
+ };
- SLANG_UNIT_TEST(multipleEntryShaderCacheVulkan)
+ template<typename T>
+ void runTest(UnitTestContext* context, Slang::RenderApiFlag::Enum api)
{
- runTestImpl(shaderCacheTestImpl<MultipleEntryShaderCache>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ T test;
+ test.run(context, api);
}
- SLANG_UNIT_TEST(multipleEntryPointShaderCacheD3D12)
+ SLANG_UNIT_TEST(shaderCacheBasicD3D12)
{
- runTestImpl(shaderCacheTestImpl<MultipleEntryPointShader>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestBasic>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(multipleEntryPointShaderCacheVulkan)
+ SLANG_UNIT_TEST(shaderCacheBasicVulkan)
{
- runTestImpl(shaderCacheTestImpl<MultipleEntryPointShader>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestBasic>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(shaderFileImportsShaderCacheD3D12)
+ SLANG_UNIT_TEST(shaderCacheEntryPointD3D12)
{
- runTestImpl(shaderCacheTestImpl<ShaderFileImportsShaderCache>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestEntryPoint>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(shaderFileImportsShaderCacheVulkan)
+ SLANG_UNIT_TEST(shaderCacheEntryPointVulkan)
{
- runTestImpl(shaderCacheTestImpl<ShaderFileImportsShaderCache>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestEntryPoint>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(specializationArgsShaderCacheD3D12)
+ SLANG_UNIT_TEST(shaderCacheImportIncludeD3D12)
{
- runTestImpl(shaderCacheTestImpl<SpecializationArgsEntries>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestImportInclude>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(specializationArgsShaderCacheVulkan)
+ SLANG_UNIT_TEST(shaderCacheImportIncludeVulkan)
{
- runTestImpl(shaderCacheTestImpl<SpecializationArgsEntries>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestImportInclude>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(cacheEvictionPolicyD3D12)
+ SLANG_UNIT_TEST(shaderCacheSpecializationD3D12)
{
- runTestImpl(shaderCacheTestImpl<CacheWithMaxEntryLimit>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestSpecialization>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(cacheEvictionPolicyVulkan)
+ SLANG_UNIT_TEST(shaderCacheSpecializationVulkan)
{
- runTestImpl(shaderCacheTestImpl<CacheWithMaxEntryLimit>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestSpecialization>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(graphicsShaderCacheD3D12)
+ SLANG_UNIT_TEST(shaderCacheGraphicsD3D12)
{
- runTestImpl(shaderCacheTestImpl<GraphicsShaderCache>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestGraphics>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(graphicsShaderCacheVulkan)
+ SLANG_UNIT_TEST(shaderCacheGraphicsVulkan)
{
- runTestImpl(shaderCacheTestImpl<GraphicsShaderCache>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestGraphics>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(splitGraphicsShaderCacheD3D12)
+ SLANG_UNIT_TEST(shaderCacheGraphicsSplitD3D12)
{
- runTestImpl(shaderCacheTestImpl<SplitGraphicsShader>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestGraphicsSplit>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(splitGraphicsShaderCacheVulkan)
+ SLANG_UNIT_TEST(shaderCacheGraphicsSplitVulkan)
{
- runTestImpl(shaderCacheTestImpl<SplitGraphicsShader>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestGraphicsSplit>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
- SLANG_UNIT_TEST(nonPhysicalFileDependenciesCacheEntryD3D12)
+ SLANG_UNIT_TEST(shaderCacheSourceStringD3D12)
{
- runTestImpl(shaderCacheTestImpl<NonPhysicalFileDependencyEntry>, unitTestContext, Slang::RenderApiFlag::D3D12);
+ runTest<ShaderCacheTestSourceString>(unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(nonPhysicalFileDependenciesCacheEntryVulkan)
+ SLANG_UNIT_TEST(shaderCacheSourceStringVulkan)
{
- runTestImpl(shaderCacheTestImpl<NonPhysicalFileDependencyEntry>, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ runTest<ShaderCacheTestSourceString>(unitTestContext, Slang::RenderApiFlag::Vulkan);
}
}
diff --git a/tools/gfx/gfx.slang b/tools/gfx/gfx.slang
index b4bd76470..3d75e3b40 100644
--- a/tools/gfx/gfx.slang
+++ b/tools/gfx/gfx.slang
@@ -1712,15 +1712,10 @@ struct SlangDesc
struct ShaderCacheDesc
{
- // The filename for the file the cache's state should be saved to or loaded from.
- NativeString cacheFilename = "cache.txt";
- // The root directory for the shader cache.
+ // The root directory for the shader cache. If not set, shader cache is disabled.
NativeString shaderCachePath;
- // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object,
- // instead the user is responsible for holding the object alive during the lifetime of an `IDevice`.
- void* shaderCacheFileSystem = nullptr;
// The maximum number of entries stored in the cache.
- GfxCount entryCountLimit = 0;
+ GfxCount maxEntryCount = 0;
};
struct DeviceInteropHandles
@@ -1934,6 +1929,21 @@ interface IDevice
Result getTextureRowAlignment(out Size outAlignment);
};
+struct ShaderCacheStats
+{
+ GfxCount hitCount;
+ GfxCount missCount;
+ GfxCount entryCount;
+};
+
+[COM("715bdf26-5135-11eb-AE93-02-42-AC-13-00-02")]
+interface IShaderCache
+{
+ Result clearShaderCache();
+ Result getShaderCacheStats(out ShaderCacheStats outStats);
+ Result resetShaderCacheStats();
+};
+
#define SLANG_GFX_IMPORT [DllImport("gfx")]
/// Checks if format is compressed
SLANG_GFX_IMPORT bool gfxIsCompressedFormat(Format format);
diff --git a/tools/gfx/persistent-shader-cache.cpp b/tools/gfx/persistent-shader-cache.cpp
deleted file mode 100644
index 7dc64632b..000000000
--- a/tools/gfx/persistent-shader-cache.cpp
+++ /dev/null
@@ -1,316 +0,0 @@
-// slang-shader-cache-index.cpp
-#include "persistent-shader-cache.h"
-
-#include "../../source/core/slang-io.h"
-#include "../../source/core/slang-string-util.h"
-#include "../../source/core/slang-file-system.h"
-
-#include "../../source/core/slang-char-util.h"
-
-#include <chrono>
-
-namespace gfx
-{
-
-using namespace std::chrono;
-
-PersistentShaderCache::PersistentShaderCache(const IDevice::ShaderCacheDesc& inDesc)
-{
- desc = inDesc;
-
- // If a path is provided, we will want our underlying file system to be initialized using that path.
- if (desc.shaderCachePath)
- {
- if (!desc.shaderCacheFileSystem)
- {
- // Only a path was provided, so we get a mutable file system
- // using OSFileSystem::getMutableSingleton.
- desc.shaderCacheFileSystem = OSFileSystem::getMutableSingleton();
- }
- desc.shaderCacheFileSystem = new RelativeFileSystem(desc.shaderCacheFileSystem, desc.shaderCachePath);
- }
-
- // If our shader cache has an underlying file system, check if it's mutable. If so, store a pointer
- // to the mutable version for operations which require writing to disk.
- if (desc.shaderCacheFileSystem)
- {
- desc.shaderCacheFileSystem->queryInterface(ISlangMutableFileSystem::getTypeGuid(), (void**)mutableShaderCacheFileSystem.writeRef());
- }
-
- loadCacheFromFile();
-}
-
-PersistentShaderCache::~PersistentShaderCache()
-{
- if (isMemoryFileSystem)
- {
- saveCacheToMemory();
- }
-}
-
-// Load a previous cache index saved to disk. If not found, create a new cache index
-// and save it to disk as filename.
-void PersistentShaderCache::loadCacheFromFile()
-{
- // We will need to combine the filename with the cache path in order to have the correct
- // file path for initializing the stream. This needs to be done separately because there
- // is no guarantee that the underlying file system is mutable.
- String filePath;
- if (mutableShaderCacheFileSystem)
- {
- ComPtr<ISlangBlob> fullPath;
- if (SLANG_FAILED(mutableShaderCacheFileSystem->getPath(PathKind::OperatingSystem, desc.cacheFilename, fullPath.writeRef())))
- {
- // If we fail to obtain a physical file path, then this must be a MemoryFileSystem. In this case, file streams
- // will not work as they require the file to be on disk, so we will rely on a fall back implementation.
- isMemoryFileSystem = true;
- loadCacheFromMemory();
- return;
- }
- filePath = String((char*)fullPath->getBufferPointer());
- }
- else
- {
- filePath = Path::combine(String(desc.shaderCachePath), String(desc.cacheFilename));
- }
-
- if (SLANG_FAILED(indexStream.init(filePath, FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite)))
- {
- // If we failed to open a stream to the file, then the file does not yet exist on disk.
- // We will create the index file if our underlying file system is mutable.
- if (mutableShaderCacheFileSystem)
- {
- indexStream.init(filePath, FileMode::Create, FileAccess::ReadWrite, FileShare::ReadWrite);
- }
- return;
- }
- else
- {
- const auto start = indexStream.getPosition();
- indexStream.seek(SeekOrigin::End, 0);
- const auto end = indexStream.getPosition();
- indexStream.seek(SeekOrigin::Start, 0);
- const Index numEntries = (Index)(end - start) / sizeof(ShaderCacheEntry);
-
- if (desc.entryCountLimit > 0 && numEntries > desc.entryCountLimit)
- {
- // If the size limit for the current cache is smaller than the cache that produced the file we're trying to
- // load, re-create the entire file.
- //
- // FileStream does not currently have any methods for truncating an existing file, so in this case, our cache
- // index would no longer accurately reflect the state of our cache due to the extra now-garbage lines present.
- // While this has no impact on cache operation, it could be problematic for debugging purposes, etc.
- indexStream.close();
- indexStream.init(filePath, FileMode::Create, FileAccess::ReadWrite, FileShare::ReadWrite);
- return;
- }
- else
- {
- // The cache index is not guaranteed to be ordered by most recent access, so we need a temporary list to store
- // all the entries in order to sort them before filling in our linked list.
- List<ShaderCacheEntry> tempEntries;
- tempEntries.setCount(numEntries);
- size_t bytesRead;
- indexStream.read(tempEntries.getBuffer(), sizeof(ShaderCacheEntry) * numEntries, bytesRead);
-
- // We will need to sort tempEntries by last accessed time before we can add entries to our linked list.
- tempEntries.quickSort(tempEntries.getBuffer(), 0, tempEntries.getCount() - 1, [](ShaderCacheEntry a, ShaderCacheEntry b) { return a.lastAccessedTime > b.lastAccessedTime; });
- for (auto& entry : tempEntries)
- {
- // If we reach this point, then the current cache is at least the same size in entries as the cache
- // that produced the index we're reading in, so we don't need to check if we're exceeding capacity.
- auto entryIndexNode = orderedEntries.AddLast(entries.getCount());
- entries.add(entry);
- keyToEntry.Add(entry.dependencyBasedDigest, entryIndexNode);
- }
- }
- }
-}
-
-ShaderCacheEntry* PersistentShaderCache::findEntry(const DigestType& key, ISlangBlob** outCompiledCode)
-{
- LinkedNode<Index>* entryIndexNode;
- if (!keyToEntry.TryGetValue(key, entryIndexNode))
- {
- // The key was not found in the cache, so we return nullptr.
- *outCompiledCode = nullptr;
- return nullptr;
- }
-
- // If the key is found, load the stored contents from disk. We then move the corresponding
- // entry to the front of the linked list and update the cache file on disk
- desc.shaderCacheFileSystem->loadFile(key.toString().getBuffer(), outCompiledCode);
- auto index = entryIndexNode->Value;
- entries[index].lastAccessedTime = (double)high_resolution_clock::now().time_since_epoch().count();
- if (orderedEntries.FirstNode() != entryIndexNode)
- {
- orderedEntries.RemoveFromList(entryIndexNode);
- orderedEntries.AddFirst(entryIndexNode);
- if (mutableShaderCacheFileSystem && !isMemoryFileSystem)
- {
- auto offset = index * sizeof(ShaderCacheEntry);
- indexStream.seek(SeekOrigin::Start, offset + 2 * sizeof(DigestType));
- indexStream.write(&entries[index].lastAccessedTime, sizeof(double));
- indexStream.flush();
- }
- }
- return &entries[index];
-}
-
-void PersistentShaderCache::addEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* compiledCode)
-{
- if (!mutableShaderCacheFileSystem)
- {
- // Should not save new entries if the underlying file system isn't mutable.
- return;
- }
-
- // Check that we do not exceed the cache's size limit by adding another entry. If so,
- // remove the least recently used entry first.
- //
- // In theory, the cache could be more than just one entry over the entry count limit.
- // However, this is impossible in practice because we fully re-create the entry list
- // and cache index file if the size of the current cache is smaller than the cache
- // that generated the index file we loaded. In any case, the initial number of entries
- // in the cache will always be fewer than the size limit and this check will be hit
- // on the first entry added that exceeds the cache's size.
- Index index = entries.getCount();
- if (desc.entryCountLimit > 0 && orderedEntries.Count() >= desc.entryCountLimit)
- {
- index = deleteLRUEntry();
- }
-
- auto lastAccessedTime = (double)high_resolution_clock::now().time_since_epoch().count();
-
- ShaderCacheEntry entry = { dependencyDigest, contentsDigest, lastAccessedTime };
- auto entryNode = orderedEntries.AddFirst(index);
- if (index == entries.getCount())
- {
- // No entries were removed, so we can tack this entry on at the end.
- entries.add(entry);
- }
- else
- {
- // An entry was deleted, so we overwrite that slot with the new entry.
- entries[index] = entry;
- }
- keyToEntry.Add(dependencyDigest, entryNode);
-
- mutableShaderCacheFileSystem->saveFileBlob(dependencyDigest.toString().getBuffer(), compiledCode);
-
- if (!isMemoryFileSystem)
- {
- indexStream.seek(SeekOrigin::End, 0);
- indexStream.write(&entry, sizeof(ShaderCacheEntry));
- indexStream.flush();
- }
-}
-
-void PersistentShaderCache::updateEntry(
- const DigestType& dependencyDigest,
- const DigestType& contentsDigest,
- ISlangBlob* updatedCode)
-{
- if (!mutableShaderCacheFileSystem)
- {
- // Updating entries requires saving to disk in order to overwrite the old shader file
- // on disk, so we return if the underlying file system isn't mutable.
- return;
- }
-
- // Unlike in addEntry(), we only update the contents digest here because the last accessed time will have already
- // been updated while finding the entry.
- auto entryIndexNode = *keyToEntry.TryGetValue(dependencyDigest);
- auto index = entryIndexNode->Value;
- entries[index].contentsBasedDigest = contentsDigest;
- mutableShaderCacheFileSystem->saveFileBlob(dependencyDigest.toString().getBuffer(), updatedCode);
-
- if (!isMemoryFileSystem)
- {
- auto offset = index * sizeof(ShaderCacheEntry);
- indexStream.seek(SeekOrigin::Start, offset + sizeof(DigestType));
- indexStream.write(&contentsDigest, sizeof(DigestType));
- indexStream.flush();
- }
-}
-
-Index PersistentShaderCache::deleteLRUEntry()
-{
- if (!mutableShaderCacheFileSystem)
- {
- // This is here as a safety precaution but should never be hit as
- // addEntry() and its memory-based equivalent are the only functions
- // that should call this.
- return -1;
- }
-
- auto lruEntry = orderedEntries.LastNode();
- auto index = lruEntry->Value;
- auto shaderKey = entries[index].dependencyBasedDigest;
-
- keyToEntry.Remove(shaderKey);
- mutableShaderCacheFileSystem->remove(shaderKey.toString().getBuffer());
-
- orderedEntries.Delete(lruEntry);
- return index;
-}
-
-// An in-memory file system cannot utilize file streaming to update the index file in place.
-// Consequently, the cache index file is updated once on exit and is guaranteed to maintain the
-// correct order of entries from most to least recently used. However, any kind of interruption
-// in program execution that results in the cache destructor not being called will result in an
-// inaccurate cache index.
-//
-// These currently assume that the underlying file system must be a MemoryFileSystem as this is the
-// only in-memory file system that currently exists in Slang, which is guaranteed to be mutable.
-// Mutability checks will need to be added if this changes in the future.
-void PersistentShaderCache::loadCacheFromMemory()
-{
- ComPtr<ISlangBlob> indexBlob;
- if (SLANG_FAILED(mutableShaderCacheFileSystem->loadFile(desc.cacheFilename, indexBlob.writeRef())))
- {
- mutableShaderCacheFileSystem->saveFile(desc.cacheFilename, nullptr, 0);
- return;
- }
-
- auto indexString = UnownedStringSlice((char*)indexBlob->getBufferPointer());
-
- List<UnownedStringSlice> lines;
- StringUtil::calcLines(indexString, lines);
- for (auto line : lines)
- {
- List<UnownedStringSlice> entryFields;
- StringUtil::split(line, ' ', entryFields);
- if (entryFields.getCount() != 2)
- continue;
-
- ShaderCacheEntry entry;
- entry.dependencyBasedDigest = DigestType(entryFields[0]);
- entry.contentsBasedDigest = DigestType(entryFields[1]);
- entry.lastAccessedTime = 0;
-
- auto entryNode = orderedEntries.AddLast(entries.getCount());
- entries.add(entry);
- keyToEntry.Add(entry.dependencyBasedDigest, entryNode);
-
- if (desc.entryCountLimit > 0 && orderedEntries.Count() == desc.entryCountLimit)
- break;
- }
-}
-
-void PersistentShaderCache::saveCacheToMemory()
-{
- StringBuilder indexSb;
- for (auto& entryIndex : orderedEntries)
- {
- auto entry = entries[entryIndex];
- indexSb << entry.dependencyBasedDigest.toString();
- indexSb << " ";
- indexSb << entry.contentsBasedDigest.toString();
- indexSb << "\n";
- }
-
- mutableShaderCacheFileSystem->saveFile(desc.cacheFilename, indexSb.getBuffer(), indexSb.getLength());
-}
-
-}
diff --git a/tools/gfx/persistent-shader-cache.h b/tools/gfx/persistent-shader-cache.h
deleted file mode 100644
index 530d50a58..000000000
--- a/tools/gfx/persistent-shader-cache.h
+++ /dev/null
@@ -1,99 +0,0 @@
-// slang-shader-cache-index.h
-#pragma once
-#include "../../slang.h"
-#include "../../slang-gfx.h"
-#include "../../slang-com-ptr.h"
-
-#include "../../source/core/slang-string.h"
-#include "../../source/core/slang-dictionary.h"
-#include "../../source/core/slang-linked-list.h"
-#include "../../source/core/slang-stream.h"
-#include "../../source/core/slang-crypto.h"
-
-namespace gfx
-{
-
-using namespace Slang;
-
-using DigestType = MD5::Digest;
-
-struct ShaderCacheEntry
-{
- DigestType dependencyBasedDigest;
- DigestType contentsBasedDigest;
- double lastAccessedTime;
-
- bool operator==(const ShaderCacheEntry& rhs)
- {
- return dependencyBasedDigest == rhs.dependencyBasedDigest
- && contentsBasedDigest == rhs.contentsBasedDigest
- && lastAccessedTime == rhs.lastAccessedTime;
- }
-
- uint32_t getHashCode()
- {
- return dependencyBasedDigest.getHashCode();
- }
-};
-
-class PersistentShaderCache : public RefObject
-{
-public:
- PersistentShaderCache(const IDevice::ShaderCacheDesc& inDesc);
- ~PersistentShaderCache();
-
- // Fetch the cache entry corresponding to the provided key. If found, move the entry to
- // the front of entries and return the entry and the corresponding compiled code in
- // outCompiledCode. Else, return nullptr.
- ShaderCacheEntry* findEntry(const DigestType& key, ISlangBlob** outCompiledCode);
-
- // Add an entry to the cache with the provided key and contents hashes. If
- // adding an entry causes the cache to exceed size limitations, this will also
- // delete the least recently used entry.
- void addEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* compiledCode);
-
- // Update the contents hash for the specified entry in the cache and update the
- // corresponding file on disk.
- void updateEntry(const DigestType& dependencyDigest, const DigestType& contentsDigest, ISlangBlob* updatedCode);
-
-private:
- // Load a previous cache index saved to disk. If not found, create a new cache index
- // and save it to disk as filename.
- void loadCacheFromFile();
-
- // Delete the last entry (the least recently used) from entries, remove its key/value pair
- // from keyToEntry, and remove the corresponding file on disk. Returns the index in 'entries'
- // of the removed entry so addEntry() can overwrite the corresponding entry in 'entries'
- // with the new entry. This should only be called by addEntry() when the cache reaches maximum capacity.
- Index deleteLRUEntry();
-
- // Without access to a physical file path, in-memory file systems cannot leverage file streams and
- // need to fall back on a different implementation for loading and saving the cache to memory.
- void loadCacheFromMemory();
- void saveCacheToMemory();
-
- // The shader cache's description.
- IDevice::ShaderCacheDesc desc;
-
- // The underlying file system used for the shader cache.
- ComPtr<ISlangMutableFileSystem> mutableShaderCacheFileSystem = nullptr;
- bool isMemoryFileSystem = false;
-
- // A file stream to the index file opened during cache load. This will only
- // exist for a cache that exists on-disk.
- FileStream indexStream;
-
- // Dictionary mapping each shader's key to its corresponding node (entry) in the
- // linked list 'orderedEntries'.
- Dictionary<DigestType, LinkedNode<Index>*> keyToEntry;
-
- // Linked list containing the corresponding indices in 'entries' for entries in the
- // shader cache ordered from most to least recently used.
- LinkedList<Index> orderedEntries;
-
- // List of entries in the shader cache. This list is not guaranteed to be in order of recency
- // as the main and fall back implementations handle outputting to the file differently.
- List<ShaderCacheEntry> entries;
-};
-
-}
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 3397b325e..4a8fd04b6 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -27,7 +27,7 @@ const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource;
const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource;
const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource;
const Slang::Guid GfxGUID::IID_IDevice = SLANG_UUID_IDevice;
-const Slang::Guid GfxGUID::IID_IShaderCacheStatistics = SLANG_UUID_IShaderCacheStatistics;
+const Slang::Guid GfxGUID::IID_IShaderCache = SLANG_UUID_IShaderCache;
const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject;
const Slang::Guid GfxGUID::IID_IRenderPassLayout = SLANG_UUID_IRenderPassLayout;
@@ -343,48 +343,19 @@ Result RendererBase::getEntryPointCodeFromShaderCache(
return program->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics);
}
- // Produce a string which we can use to query the shader cache by combining two separate hashes which
- // together comprise all the compilation arguments for this program.
- ComPtr<slang::ISession> session;
- getSlangSession(session.writeRef());
-
- ComPtr<ISlangBlob> shaderKeyBlob;
- program->computeDependencyBasedHash(entryPointIndex, targetIndex, shaderKeyBlob.writeRef());
- DigestType shaderKey(shaderKeyBlob);
-
- // Produce a hash using the AST for this program - This is needed to check whether a cache entry is effectively dirty,
- // or to save along with the compiled code into an entry so the entry can be checked if fetched later on.
- ComPtr<ISlangBlob> contentsHashBlob;
- program->computeContentsBasedHash(contentsHashBlob.writeRef());
- DigestType contentsHash(contentsHashBlob);
+ // Hash all relevant state for generating the entry point shader code to use as a key
+ // for the shader cache.
+ ComPtr<ISlangBlob> hashBlob;
+ program->getEntryPointHash(entryPointIndex, targetIndex, hashBlob.writeRef());
+ PersistentCache::Key cacheKey(hashBlob);
+ // Query the shader cache.
ComPtr<ISlangBlob> codeBlob;
-
- // Query the shader cache index for an entry with shaderKey as its key.
- auto entry = persistentShaderCache->findEntry(shaderKey, codeBlob.writeRef());
- if (entry && contentsHash == entry->contentsBasedDigest)
+ if (persistentShaderCache->readEntry(cacheKey, codeBlob.writeRef()) != SLANG_OK)
{
- // We found the entry in the cache, and the entry's contents are up-to-date. Nothing else needs to be done.
- shaderCacheHitCount++;
- }
- else
- {
- // There are two possibilities: the entry does not exist in the cache, or the entry's contents are out-of-date.
- // Both will require calling getEntryPointCode() in order to fetch the correct compiled code, so we'll do that now.
+ // No cached entry found. Generate the code and add it to the cache.
SLANG_RETURN_ON_FAIL(program->getEntryPointCode(entryPointIndex, targetIndex, codeBlob.writeRef(), outDiagnostics));
-
- // If the entry was not found in the cache, let's add it. Otherwise, the entry's contents were out-of-date, so let's
- // update the entry with the updated contents.
- if (!entry)
- {
- persistentShaderCache->addEntry(shaderKey, contentsHash, codeBlob);
- shaderCacheMissCount++;
- }
- else
- {
- persistentShaderCache->updateEntry(shaderKey, contentsHash, codeBlob);
- shaderCacheEntryDirtyCount++;
- }
+ persistentShaderCache->writeEntry(cacheKey, codeBlob);
}
*outCode = codeBlob.detach();
@@ -393,9 +364,10 @@ Result RendererBase::getEntryPointCodeFromShaderCache(
SlangResult RendererBase::queryInterface(SlangUUID const& uuid, void** outObject)
{
- if (uuid == GfxGUID::IID_IShaderCacheStatistics)
+ // Only return the shader cache interface if it is enabled.
+ if (uuid == GfxGUID::IID_IShaderCache && persistentShaderCache)
{
- *outObject = static_cast<IShaderCacheStatistics*>(this);
+ *outObject = static_cast<IShaderCache*>(this);
addRef();
return SLANG_OK;
}
@@ -413,12 +385,13 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid)
SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc)
{
- auto cacheDesc = desc.shaderCache;
- // We only want to initialize the shader cache if either a shader cache path or file system
- // was provided.
- if (cacheDesc.shaderCachePath || cacheDesc.shaderCacheFileSystem)
+ // We only want to initialize the shader cache if a shader cache path was provided.
+ if (desc.shaderCache.shaderCachePath)
{
- persistentShaderCache = new PersistentShaderCache(desc.shaderCache);
+ PersistentCache::Desc cacheDesc;
+ cacheDesc.directory = desc.shaderCache.shaderCachePath;
+ cacheDesc.maxEntryCount = desc.shaderCache.maxEntryCount;
+ persistentShaderCache = new PersistentCache(cacheDesc);
}
if (desc.apiCommandDispatcher)
@@ -751,26 +724,31 @@ Result RendererBase::getShaderObjectLayout(
return SLANG_OK;
}
-GfxCount RendererBase::getCacheMissCount()
+Result RendererBase::clearShaderCache()
{
- return shaderCacheMissCount;
+ SLANG_ASSERT(persistentShaderCache);
+ return persistentShaderCache->clear();
}
-GfxCount RendererBase::getCacheHitCount()
+Result RendererBase::getShaderCacheStats(ShaderCacheStats* outStats)
{
- return shaderCacheHitCount;
-}
+ SLANG_ASSERT(persistentShaderCache);
+ if (!outStats)
+ {
+ return SLANG_E_INVALID_ARG;
+ }
-GfxCount RendererBase::getCacheEntryDirtyCount()
-{
- return shaderCacheEntryDirtyCount;
+ const auto& stats = persistentShaderCache->getStats();
+ outStats->entryCount = (GfxCount)stats.entryCount;
+ outStats->hitCount = (GfxCount)stats.hitCount;
+ outStats->missCount = (GfxCount)stats.missCount;
+ return SLANG_OK;
}
-Result RendererBase::resetCacheStatistics()
+Result RendererBase::resetShaderCacheStats()
{
- shaderCacheMissCount = 0;
- shaderCacheHitCount = 0;
- shaderCacheEntryDirtyCount = 0;
+ SLANG_ASSERT(persistentShaderCache);
+ persistentShaderCache->resetStats();
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 01111e292..c7137f0fa 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -4,8 +4,7 @@
#include "slang-context.h"
#include "core/slang-basic.h"
#include "core/slang-com-object.h"
-
-#include "persistent-shader-cache.h"
+#include "core/slang-persistent-cache.h"
#include "resource-desc-utils.h"
@@ -28,7 +27,7 @@ struct GfxGUID
static const Slang::Guid IID_ITextureResource;
static const Slang::Guid IID_IInputLayout;
static const Slang::Guid IID_IDevice;
- static const Slang::Guid IID_IShaderCacheStatistics;
+ static const Slang::Guid IID_IShaderCache;
static const Slang::Guid IID_IShaderObjectLayout;
static const Slang::Guid IID_IShaderObject;
static const Slang::Guid IID_IRenderPassLayout;
@@ -1214,7 +1213,7 @@ public:
// Renderer implementation shared by all platforms.
// Responsible for shader compilation, specialization and caching.
-class RendererBase : public IDevice, public IShaderCacheStatistics, public Slang::ComObject
+class RendererBase : public IDevice, public IShaderCache, public Slang::ComObject
{
friend class ShaderObjectBase;
public:
@@ -1354,27 +1353,21 @@ public:
ShaderObjectLayoutBase* layout,
IShaderObject** outObject) = 0;
+ public:
+ // IShaderCache interface
+ virtual SLANG_NO_THROW Result SLANG_MCALL clearShaderCache() SLANG_OVERRIDE;
+ virtual SLANG_NO_THROW Result SLANG_MCALL getShaderCacheStats(ShaderCacheStats* outStats) SLANG_OVERRIDE;
+ virtual SLANG_NO_THROW Result SLANG_MCALL resetShaderCacheStats() SLANG_OVERRIDE;
+
protected:
virtual SLANG_NO_THROW SlangResult SLANG_MCALL initialize(const Desc& desc);
protected:
Slang::List<Slang::String> m_features;
-
-public:
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheMissCount() override;
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheHitCount() override;
- virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheEntryDirtyCount() override;
- virtual SLANG_NO_THROW Result SLANG_MCALL resetCacheStatistics() override;
-
-protected:
- GfxCount shaderCacheMissCount = 0;
- GfxCount shaderCacheHitCount = 0;
- GfxCount shaderCacheEntryDirtyCount = 0;
-
public:
SlangContext slangContext;
ShaderCache shaderCache;
- RefPtr<PersistentShaderCache> persistentShaderCache = nullptr;
+ Slang::RefPtr<Slang::PersistentCache> persistentShaderCache;
Slang::Dictionary<slang::TypeLayoutReflection*, Slang::RefPtr<ShaderObjectLayoutBase>> m_shaderObjectLayoutCache;
Slang::ComPtr<IPipelineCreationAPIDispatcher> m_pipelineCreationAPIDispatcher;
diff --git a/tools/slang-unit-test/unit-test-lock-file.cpp b/tools/slang-unit-test/unit-test-lock-file.cpp
index c5709242d..33e787a1d 100644
--- a/tools/slang-unit-test/unit-test-lock-file.cpp
+++ b/tools/slang-unit-test/unit-test-lock-file.cpp
@@ -12,13 +12,13 @@ using namespace Slang;
SLANG_UNIT_TEST(lockFile)
{
- static const String fileName = "test_lock_file";
+ static String fileName = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/test_lock_file");
// Open/close lock file.
{
LockFile file;
SLANG_CHECK(file.isOpen() == false);
- SLANG_CHECK(file.open(fileName) == SLANG_OK);
+ SLANG_CHECK_ABORT(file.open(fileName) == SLANG_OK);
SLANG_CHECK(file.isOpen() == true);
SLANG_CHECK(File::exists(fileName) == true);
file.close();
diff --git a/tools/slang-unit-test/unit-test-persistent-cache.cpp b/tools/slang-unit-test/unit-test-persistent-cache.cpp
new file mode 100644
index 000000000..55c358d77
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-persistent-cache.cpp
@@ -0,0 +1,629 @@
+// unit-test-persistent-cache.cpp
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "../../source/core/slang-persistent-cache.h"
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-file-system.h"
+#include "../../source/core/slang-random-generator.h"
+
+#include <chrono>
+#include <thread>
+#include <atomic>
+#include <mutex>
+#include <condition_variable>
+#include <functional>
+
+using namespace Slang;
+
+static DefaultRandomGenerator rng(0xdeadbeef);
+
+inline ComPtr<ISlangBlob> createRandomBlob(size_t size)
+{
+ ScopedAllocation alloc;
+ alloc.allocate(size);
+ rng.nextData(alloc.getData(), size);
+ return RawBlob::moveCreate(alloc);
+}
+
+inline bool isBlobEqual(ISlangBlob* a, ISlangBlob* b)
+{
+ return
+ a->getBufferSize() == b->getBufferSize() &&
+ ::memcmp(a->getBufferPointer(), b->getBufferPointer(), a->getBufferSize()) == 0;
+}
+
+class Barrier
+{
+public:
+ Barrier(size_t threadCount, std::function<void()> completionFunc = nullptr)
+ : m_threadCount(threadCount)
+ , m_waitCount(threadCount)
+ , m_completionFunc(completionFunc)
+ {}
+
+ Barrier(const Barrier& barrier) = delete;
+ Barrier& operator=(const Barrier& barrier) = delete;
+
+ void wait()
+ {
+ std::unique_lock<std::mutex> lock(m_mutex);
+
+ auto generation = m_generation;
+
+ if (--m_waitCount == 0)
+ {
+ if (m_completionFunc) m_completionFunc();
+ ++m_generation;
+ m_waitCount = m_threadCount;
+ m_condition.notify_all();
+ }
+ else
+ {
+ m_condition.wait(lock, [this, generation] () { return generation != m_generation; });
+ }
+ }
+
+private:
+ size_t m_threadCount;
+ size_t m_waitCount;
+ size_t m_generation = 0;
+ std::function<void()> m_completionFunc;
+ std::mutex m_mutex;
+ std::condition_variable m_condition;
+};
+
+namespace Slang
+{
+
+/// Helper class for performing tests on the persistent cache.
+/// This class is a friend class of PersistentCache and can access its internals.
+struct PersistentCacheTest
+{
+ ISlangMutableFileSystem* osFileSystem;
+ String cacheDirectory;
+ RefPtr<PersistentCache> cache;
+
+ PersistentCacheTest(Count maxEntryCount = 0)
+ {
+ osFileSystem = OSFileSystem::getMutableSingleton();
+ cacheDirectory = Path::simplify(Path::getParentDirectory(Path::getExecutablePath()) + "/persistent-cache-test");
+
+ removeCacheFiles();
+
+ PersistentCache::Desc desc;
+ desc.directory = cacheDirectory.getBuffer();
+ desc.maxEntryCount = maxEntryCount;
+ cache = new PersistentCache(desc);
+ }
+
+ virtual ~PersistentCacheTest()
+ {
+ cache = nullptr;
+
+ removeCacheFiles();
+ }
+
+ void removeCacheFiles()
+ {
+ // Remove all files the cache created.
+ osFileSystem->enumeratePathContents(
+ cacheDirectory.getBuffer(),
+ [](SlangPathType pathType, const char* fileName, void* userData)
+ {
+ PersistentCacheTest* self = static_cast<PersistentCacheTest*>(userData);
+ String path = self->cacheDirectory + "/" + fileName;
+ self->osFileSystem->remove(path.getBuffer());
+ },
+ this);
+
+ // Also remove the cache directory.
+ osFileSystem->remove(cacheDirectory.getBuffer());
+ }
+
+ // Entry (key, data) for testing.
+ struct Entry
+ {
+ PersistentCache::Key key;
+ ComPtr<ISlangBlob> data;
+ };
+
+ // Helper to write an entry to the cache.
+ void writeEntry(const Entry& entry)
+ {
+ SLANG_CHECK(cache->writeEntry(entry.key, entry.data) == SLANG_OK);
+ }
+
+ // Helper to read an entry from the cache and discard the data.
+ // Returns true if the entry was found, false otherwise.
+ bool readEntry(const Entry& entry)
+ {
+ ComPtr<ISlangBlob> data;
+ SlangResult result = cache->readEntry(entry.key, data.writeRef());
+ SLANG_CHECK(result == SLANG_OK || result == SLANG_E_NOT_FOUND);
+ if (result == SLANG_OK)
+ {
+ SLANG_CHECK(isBlobEqual(data, entry.data));
+ }
+ if (result == SLANG_E_NOT_FOUND)
+ {
+ SLANG_CHECK(data == nullptr);
+ }
+ return result == SLANG_OK;
+ }
+
+ // Get the absolute filename for a cache entry file.
+ String getEntryFileName(const Entry& entry)
+ {
+ return cache->getEntryFileName(entry.key);
+ }
+
+ // Get the absolute filename of the cache index file.
+ String getIndexFilename()
+ {
+ return cache->m_indexFileName;
+ }
+};
+
+} // namespace Slang
+
+// Performs basic tests on the cache.
+// - write/read entries
+// - check for correct cache stats
+// - clearing the cache
+// - resetting stats
+struct BasicTest : public PersistentCacheTest
+{
+ BasicTest() : PersistentCacheTest() {}
+
+ void run()
+ {
+ // Check that cache is empty.
+ SLANG_CHECK(cache->getStats().entryCount == 0);
+ SLANG_CHECK(cache->getStats().hitCount == 0);
+ SLANG_CHECK(cache->getStats().missCount == 0);
+
+ // Setup a list of entries to store in the cache.
+ List<Entry> entries;
+ for (size_t i = 0; i < 10; ++i)
+ {
+ auto data = createRandomBlob(i * 1024);
+ auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize());
+ entries.add(Entry{ key, data });
+ }
+
+ for (size_t i = 0; i < 10; ++i)
+ {
+ const auto& entry = entries[i];
+ ComPtr<ISlangBlob> data;
+
+ // Try to read an entry. Check that its not found and counts as a miss.
+ SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_E_NOT_FOUND);
+ SLANG_CHECK(cache->getStats().missCount == i + 1);
+
+ // Write the entry. Check that it gets added.
+ SLANG_CHECK(cache->writeEntry(entry.key, entry.data) == SLANG_OK);
+ SLANG_CHECK(cache->getStats().entryCount == i + 1);
+ }
+
+ SLANG_CHECK(cache->getStats().entryCount == 10);
+ SLANG_CHECK(cache->getStats().hitCount == 0);
+ SLANG_CHECK(cache->getStats().missCount == 10);
+
+ for (size_t i = 0; i < 10; ++i)
+ {
+ const auto& entry = entries[i];
+ ComPtr<ISlangBlob> data;
+
+ // Read entries. Check that these are cache hits and return the correct data.
+ SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_OK);
+ SLANG_CHECK(cache->getStats().hitCount == i + 1);
+ SLANG_CHECK(isBlobEqual(data, entry.data));
+ }
+
+ SLANG_CHECK(cache->getStats().entryCount == 10);
+ SLANG_CHECK(cache->getStats().hitCount == 10);
+ SLANG_CHECK(cache->getStats().missCount == 10);
+
+ // Clear the cache. Check that entry count is reset.
+ SLANG_CHECK(cache->clear() == SLANG_OK);
+ SLANG_CHECK(cache->getStats().entryCount == 0);
+ SLANG_CHECK(cache->getStats().hitCount == 10);
+ SLANG_CHECK(cache->getStats().missCount == 10);
+
+ // Reset stats.
+ cache->resetStats();
+ SLANG_CHECK(cache->getStats().entryCount == 0);
+ SLANG_CHECK(cache->getStats().hitCount == 0);
+ SLANG_CHECK(cache->getStats().missCount == 0);
+
+ // Check that cache is empty.
+ for (size_t i = 0; i < 10; ++i)
+ {
+ const auto& entry = entries[i];
+ ComPtr<ISlangBlob> data;
+ SLANG_CHECK(cache->readEntry(entry.key, data.writeRef()) == SLANG_E_NOT_FOUND);
+ }
+ SLANG_CHECK(cache->getStats().missCount == 10);
+ }
+};
+
+// Tests the least-recently-used cache eviction policy.
+struct EvictionTest : public PersistentCacheTest
+{
+ EvictionTest() : PersistentCacheTest(3) {}
+
+ void run()
+ {
+ // Setup a list of entries to store in the cache.
+ List<Entry> entries;
+ for (size_t i = 0; i < 10; ++i)
+ {
+ auto data = createRandomBlob(4096);
+ auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize());
+ entries.add(Entry{ key, data });
+ }
+
+ writeEntry(entries[0]);
+ writeEntry(entries[1]);
+ writeEntry(entries[2]);
+
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ SLANG_CHECK(readEntry(entries[1]) == true);
+ SLANG_CHECK(readEntry(entries[2]) == true);
+
+ // Evict LRU entry 0.
+ writeEntry(entries[3]);
+ SLANG_CHECK(readEntry(entries[0]) == false);
+ SLANG_CHECK(readEntry(entries[1]) == true);
+ SLANG_CHECK(readEntry(entries[2]) == true);
+ SLANG_CHECK(readEntry(entries[3]) == true);
+
+ // Evict LRU entry 1.
+ writeEntry(entries[4]);
+ SLANG_CHECK(readEntry(entries[1]) == false);
+ SLANG_CHECK(readEntry(entries[2]) == true);
+ SLANG_CHECK(readEntry(entries[3]) == true);
+ SLANG_CHECK(readEntry(entries[4]) == true);
+
+ // Evict LRU entry 2.
+ writeEntry(entries[5]);
+ SLANG_CHECK(readEntry(entries[2]) == false);
+ SLANG_CHECK(readEntry(entries[3]) == true);
+ SLANG_CHECK(readEntry(entries[4]) == true);
+ SLANG_CHECK(readEntry(entries[5]) == true);
+
+ // Evict LRU entry 4.
+ SLANG_CHECK(readEntry(entries[3]) == true);
+ writeEntry(entries[6]);
+ SLANG_CHECK(readEntry(entries[3]) == true);
+ SLANG_CHECK(readEntry(entries[4]) == false);
+ SLANG_CHECK(readEntry(entries[5]) == true);
+ SLANG_CHECK(readEntry(entries[6]) == true);
+ }
+};
+
+
+// Tests the cache to be robust against various corruptions.
+// These can happen if the cache files are manipulated externally.
+// The cache might also be corrupted if the application is terminated while writing.
+struct CorruptionTest : public PersistentCacheTest
+{
+ List<Entry> entries;
+
+ template<typename Func>
+ void testIndexCorruption(Func func, SlangResult expectedReadResult)
+ {
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ func();
+ // We expect a SLANG_E_NOT_FOUND because the cache has an empty index now.
+ ComPtr<ISlangBlob> data;
+ SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == expectedReadResult);
+
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ func();
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ }
+
+ void run()
+ {
+ // Setup a list of entries to store in the cache.
+ for (size_t i = 0; i < 10; ++i)
+ {
+ auto data = createRandomBlob(4096);
+ auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize());
+ entries.add(Entry{ key, data });
+ }
+
+ // Test behavior when a cached entry file is removed externally before reading.
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ osFileSystem->remove(getEntryFileName(entries[0]).getBuffer());
+ ComPtr<ISlangBlob> data;
+ // First time we read the entry, we expect a SLANG_E_CANNOT_OPEN because the file is gone.
+ SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_CANNOT_OPEN);
+ // The next time we read the entry, we expect a SLANG_E_NOT_FOUND because the entry has
+ // been removed from the cache index.
+ SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_NOT_FOUND);
+
+ // Test behavior when a cached entry file is removed externally before writing.
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ osFileSystem->remove(getEntryFileName(entries[0]).getBuffer());
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+
+ // Test behavior when the index file is removed before reading.
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ osFileSystem->remove(getIndexFilename().getBuffer());
+ // We expect a SLANG_E_NOT_FOUND because the cache has an empty index now.
+ SLANG_CHECK(cache->readEntry(entries[0].key, data.writeRef()) == SLANG_E_NOT_FOUND);
+
+ // Test behavior when the index file is removed before writing.
+ writeEntry(entries[0]);
+ SLANG_CHECK(readEntry(entries[0]) == true);
+ osFileSystem->remove(getIndexFilename().getBuffer());
+ writeEntry(entries[1]);
+ SLANG_CHECK(readEntry(entries[1]) == true);
+
+ // Test different corruptions of the index file.
+ testIndexCorruption(
+ [this]()
+ {
+ osFileSystem->remove(getIndexFilename().getBuffer());
+ },
+ SLANG_E_NOT_FOUND);
+
+ testIndexCorruption(
+ [this]()
+ {
+ FileStream fs;
+ fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite);
+ fs.write("x", 1);
+ },
+ SLANG_E_INTERNAL_FAIL);
+
+ testIndexCorruption(
+ [this]()
+ {
+ FileStream fs;
+ fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite);
+ fs.seek(SeekOrigin::Start, 4);
+ uint32_t version = 0xffffffff;
+ fs.write(&version, sizeof(version));
+ },
+ SLANG_E_INTERNAL_FAIL);
+
+ testIndexCorruption(
+ [this]()
+ {
+ FileStream fs;
+ fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite);
+ fs.seek(SeekOrigin::Start, 8);
+ uint32_t count = 0x7fffffff;
+ fs.write(&count, sizeof(count));
+ },
+ SLANG_E_INTERNAL_FAIL);
+
+ testIndexCorruption(
+ [this]()
+ {
+ FileStream fs;
+ fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite);
+ fs.seek(SeekOrigin::Start, 8);
+ uint32_t count = 0;
+ fs.write(&count, sizeof(count));
+ },
+ SLANG_E_INTERNAL_FAIL);
+
+ testIndexCorruption(
+ [this]()
+ {
+ FileStream fs;
+ fs.init(getIndexFilename(), FileMode::Open, FileAccess::ReadWrite, FileShare::ReadWrite);
+ fs.seek(SeekOrigin::End, 0);
+ fs.write("x", 1);
+ },
+ SLANG_E_INTERNAL_FAIL);
+ }
+};
+
+struct MultiThreadingTest : public PersistentCacheTest
+{
+ void run()
+ {
+ }
+};
+
+
+#undef ENABLE_LOGGING
+#undef ENABLE_WRITE_TEST
+
+#ifdef ENABLE_LOGGING
+#define LOG(fmt, ...) printf(fmt, ##__VA_ARGS__); fflush(stdout);
+#else
+#define LOG(fmt, ...)
+#endif
+
+// Stress testing.
+// This test spawns a number of threads to do concurrent access to the cache.
+// For now this is fairly simple:
+// - spawn a number of threads
+// - write random entries to the cache concurrenctly (slightly oversubscribe)
+// - synchronize
+// - read entries from the cache concurretly (test that we get the expected number of hits/misses)
+// - synchronize
+// - repeat for a number of iterations
+struct StressTest : public PersistentCacheTest
+{
+ // Number of entries to write/read per iteration.
+ static const uint32_t kEntryCount = 100;
+ // Number of entries the cache is short for storing one iteration.
+ static const uint32_t kEntryShortageCount = 10;
+ // Number of parallel threads to write/read.
+ static const uint32_t kThreadCount = 4;
+ // Number of entries to write/read per thread per iteration.
+ static const uint32_t kBatchCount = kEntryCount / kThreadCount;
+ // Total number of iterations.
+ static const uint32_t kIterationCount = 4;
+
+ static_assert(kEntryCount % kThreadCount == 0, "kEntryCount must be divisible by kThreadCount");
+
+ List<Entry> entries;
+
+ std::atomic<uint32_t> iteration{0};
+ std::atomic<uint32_t> entriesWritten{0};
+ std::atomic<uint32_t> bytesWritten{0};
+ std::atomic<uint32_t> entriesRead{0};
+ std::atomic<uint32_t> bytesRead{0};
+ std::atomic<uint32_t> readSuccess{0};
+ std::thread threads[kThreadCount];
+
+ Barrier *read_barrier;
+ Barrier *write_barrier;
+
+ std::mutex mutex;
+ std::condition_variable conditionVariable;
+ uint32_t generation{0};
+
+ StressTest() : PersistentCacheTest(kEntryCount - kEntryShortageCount) {}
+
+ void run()
+ {
+ // Setup a list of entries to store in the cache.
+ for (size_t i = 0; i < kEntryCount * 2; ++i)
+ {
+ size_t size = rng.nextInt32InRange(256, 64 * 1024);
+ auto data = createRandomBlob(size);
+ auto key = SHA1::compute(data->getBufferPointer(), data->getBufferSize());
+ entries.add(Entry{ key, data });
+ }
+
+ auto startTime = std::chrono::high_resolution_clock::now();
+
+ Barrier read_barrier_(
+ kThreadCount,
+ []()
+ {
+ LOG("Read synchronized\n");
+ });
+ Barrier write_barrier_(
+ kThreadCount,
+ [this](){
+ LOG("Write synchronized\n");
+#ifndef ENABLE_WRITE_TEST
+ SLANG_CHECK(readSuccess == kEntryCount - kEntryShortageCount);
+ readSuccess.store(0);
+#endif
+ iteration += 1;
+ });
+
+ read_barrier = &read_barrier_;
+ write_barrier = &write_barrier_;
+
+ for (uint32_t threadIndex = 0; threadIndex < kThreadCount; ++threadIndex)
+ {
+ threads[threadIndex] = std::thread(
+ [](StressTest* self, uint32_t threadIndex)
+ {
+ LOG("Thread %u: starting\n", threadIndex);
+
+ while (true)
+ {
+ // Write to cache.
+ size_t startIndex = (self->iteration * kEntryCount + (threadIndex * kBatchCount)) % (kEntryCount * 2);
+ for (size_t i = 0; i < kBatchCount; ++i)
+ {
+ const Entry& entry = self->entries[startIndex + i];
+#ifdef ENABLE_WRITE_TEST
+ self->osFileSystem->saveFileBlob(self->getEntryFileName(entry).getBuffer(), entry.data);
+#else
+ self->writeEntry(entry);
+#endif
+ self->entriesWritten.fetch_add(1);
+ self->bytesWritten.fetch_add((uint32_t)entry.data->getBufferSize());
+ }
+
+ LOG("Thread %u: ended writing (iteration=%u)\n", threadIndex, self->iteration.load());
+
+ // Synchronize.
+ self->read_barrier->wait();
+
+ // Read from cache.
+ for (size_t i = 0; i < kBatchCount; ++i)
+ {
+ const Entry& entry = self->entries[startIndex + i];
+#ifndef ENABLE_WRITE_TEST
+ if (self->readEntry(entry))
+ {
+ self->readSuccess.fetch_add(1);
+ self->bytesRead.fetch_add((uint32_t)entry.data->getBufferSize());
+ }
+#endif
+ self->entriesRead.fetch_add(1);
+ }
+
+ LOG("Thread %u: ended reading (iteration=%u)\n", threadIndex, self->iteration.load());
+
+ // Synchronize.
+ self->write_barrier->wait();
+
+ // Terminate.
+ if (self->iteration >= kIterationCount)
+ {
+ LOG("Thread %u: terminates\n", threadIndex);
+ return;
+ }
+ }
+ },
+ this, threadIndex);
+ }
+
+ for (auto& thread : threads)
+ {
+ thread.join();
+ }
+
+ auto endTime = std::chrono::high_resolution_clock::now();
+ auto duration = endTime - startTime;
+ auto seconds = std::chrono::duration_cast<std::chrono::milliseconds>(duration).count() / 1000.0;
+
+ LOG("Total time: %.3fs\n", seconds);
+ LOG("Total bytes written: %d\n", bytesWritten.load());
+ LOG("Write througput: %.3fMB/s\n", (bytesWritten.load() / (1024.0 * 1024.0)) / seconds);
+ LOG("Total bytes read: %d\n", bytesRead.load());
+ }
+};
+
+SLANG_UNIT_TEST(persistentCacheBasic)
+{
+ BasicTest test;
+ test.run();
+}
+
+SLANG_UNIT_TEST(persistentCacheEviction)
+{
+ EvictionTest test;
+ test.run();
+}
+
+SLANG_UNIT_TEST(persistentCacheCorruption)
+{
+ CorruptionTest test;
+ test.run();
+}
+
+SLANG_UNIT_TEST(persistentCacheMultiThreading)
+{
+ MultiThreadingTest test;
+ test.run();
+}
+
+SLANG_UNIT_TEST(persistentCacheStress)
+{
+ StressTest test;
+ test.run();
+}