From f0cd62b37c5dfbbdb3fb205f1be2b8beba0dfed4 Mon Sep 17 00:00:00 2001 From: lucy96chen <47800040+lucy96chen@users.noreply.github.com> Date: Wed, 12 Oct 2022 09:55:09 -0700 Subject: Shader caching (#2432) * Changed all getEntryPointCode calls to use RendererBase::getEntryPointCodeFromShaderCache * Hashing hooked up, tests pass but need to add more to fully test functionality * checkpoint * Checkpoint: File system creation seems functional, saving is broken * checkpoint: Fixed filename generation from MD5 hash, shader blob might be going missing ahead of pipeline state creation * Fixed a lot of bugs related to hash code generation, shader cache is likely working but needs further testing * Added workaround for module loading by re-creating the test device, shader cache test functional * Vulkan shader caching bug fixed, checkpoint commit before more refinement * pre-ToT merge checkpoint * checkpoint commit, improving cache keys * Significantly expanded items included in the dependency hash for Module; Added dependency hash functions to SpecializedComponentType and RenamedEntryPointComponentType * Temporarily disable shader cache test * Mid cleanup changes, solution successfully builds * Added several helper update functions to slang-md5 to help simplify usage; Added a function under ISession to compute a hash for all linkage-related items; Function renames and cleaned up some comments * Ran premake.bat; Renamed getASTBasedHashCode to computeASTBasedHash * Added slang unit tests for Checksum and MD5; Extended gfx shader cache test to test with multiple shader files and one shader file with multiple entry points * Solution builds and shader cache tests pass, but at least a couple other tests now failing * ran premake.bat * More cleanup changes * Added shaderCachePath field to IDevice desc in gfx.slang, gfx-smoke.slang should be functional * ran premake * cleanup changes; Adding test printf to getEntryPointCodeFromShaderCache to see if output can be seen in CI * Removed 
debugging printfs; Added handling for getEntryPointCode() failing * Cleanup changes; Jonathan's fixes to SerialWriter to zero initialize otherwise uninitialized memory; Change to SwizzleExpr creation to zero initialize elementCount * Changed enable_if_t to enable_if * Fixed enable_if * Added test for import vs include and changes to included and imported files; Fixed build errors in CUDA; Renamed shader cache statistics fields * cleanup changes * Readd removed file * Restructured computeDependencyBasedHash calls, added computeDependencyBasedHashImpl to all classes derived from ComponentType * Applied same restructuring to the AST hash functions * Cleanup changes; Moved HashBuilder out to slang-digest.h and added some helper functions to streamline the process of adding items to a hash * Cleanup; Fixed incorrect expected results for shader import and include test --- tools/gfx/cuda/cuda-device.cpp | 2 +- tools/gfx/d3d11/d3d11-device.cpp | 2 +- tools/gfx/d3d12/d3d12-pipeline-state.cpp | 4 +- tools/gfx/gfx.slang | 2 + tools/gfx/open-gl/render-gl.cpp | 2 +- tools/gfx/renderer-shared.cpp | 176 ++++++++++++++++++++++++++++++- tools/gfx/renderer-shared.h | 31 +++++- tools/gfx/vulkan/vk-pipeline-state.cpp | 6 +- 8 files changed, 212 insertions(+), 13 deletions(-) (limited to 'tools/gfx') diff --git a/tools/gfx/cuda/cuda-device.cpp b/tools/gfx/cuda/cuda-device.cpp index be5dbbc96..32454109b 100644 --- a/tools/gfx/cuda/cuda-device.cpp +++ b/tools/gfx/cuda/cuda-device.cpp @@ -913,7 +913,7 @@ SLANG_NO_THROW Result SLANG_MCALL DeviceImpl::createProgram( ComPtr kernelCode; ComPtr diagnostics; - auto compileResult = desc.slangGlobalScope->getEntryPointCode( + auto compileResult = getEntryPointCodeFromShaderCache(desc.slangGlobalScope, (SlangInt)0, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) { diff --git a/tools/gfx/d3d11/d3d11-device.cpp b/tools/gfx/d3d11/d3d11-device.cpp index aa665ebd4..969eb7d1b 100644 --- a/tools/gfx/d3d11/d3d11-device.cpp +++ 
b/tools/gfx/d3d11/d3d11-device.cpp @@ -1239,7 +1239,7 @@ Result DeviceImpl::createProgram( ComPtr kernelCode; ComPtr diagnostics; - auto compileResult = slangGlobalScope->getEntryPointCode( + auto compileResult = getEntryPointCodeFromShaderCache(slangGlobalScope, (SlangInt)i, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) diff --git a/tools/gfx/d3d12/d3d12-pipeline-state.cpp b/tools/gfx/d3d12/d3d12-pipeline-state.cpp index ec073bf44..adfdcd518 100644 --- a/tools/gfx/d3d12/d3d12-pipeline-state.cpp +++ b/tools/gfx/d3d12/d3d12-pipeline-state.cpp @@ -50,7 +50,7 @@ Result PipelineStateImpl::ensureAPIPipelineStateCreated() auto programImpl = static_cast(m_program.Ptr()); if (programImpl->m_shaders.getCount() == 0) { - SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device)); } if (desc.type == PipelineType::Graphics) { @@ -356,7 +356,7 @@ Result RayTracingPipelineStateImpl::ensureAPIPipelineStateCreated() SlangInt entryPointIndex) { ComPtr codeBlob; - auto compileResult = component->getEntryPointCode( + auto compileResult = m_device->getEntryPointCodeFromShaderCache(component, entryPointIndex, 0, codeBlob.writeRef(), diagnostics.writeRef()); if (diagnostics.get()) { diff --git a/tools/gfx/gfx.slang b/tools/gfx/gfx.slang index 17a38e28d..c1296b472 100644 --- a/tools/gfx/gfx.slang +++ b/tools/gfx/gfx.slang @@ -1733,6 +1733,8 @@ struct DeviceDesc void *apiCommandDispatcher = nullptr; // The slot (typically UAV) used to identify NVAPI intrinsics. If >=0 NVAPI is required. GfxIndex nvapiExtnSlot = -1; + // The root directory for the shader cache. + NativeString shaderCachePath; // The file system for loading cached shader kernels. The layer does not maintain a strong reference to the object, // instead the user is responsible for holding the object alive during the lifetime of an `IDevice`. 
void *shaderCacheFileSystem = nullptr; diff --git a/tools/gfx/open-gl/render-gl.cpp b/tools/gfx/open-gl/render-gl.cpp index 7daf577ef..fdec875f0 100644 --- a/tools/gfx/open-gl/render-gl.cpp +++ b/tools/gfx/open-gl/render-gl.cpp @@ -2791,7 +2791,7 @@ Result GLDevice::createProgram( { ComPtr kernelCode; ComPtr diagnostics; - auto compileResult = desc.slangGlobalScope->getEntryPointCode( + auto compileResult = getEntryPointCodeFromShaderCache(desc.slangGlobalScope, i, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) { diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 43629a0f2..8ed776776 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -3,6 +3,11 @@ #include "core/slang-io.h" #include "core/slang-token-reader.h" +#include "../../source/core/slang-file-system.h" + +#include "../../slang.h" +#include "../../source/slang/slang-hash-utils.h" + using namespace Slang; namespace gfx @@ -23,6 +28,7 @@ const Slang::Guid GfxGUID::IID_IResource = SLANG_UUID_IResource; const Slang::Guid GfxGUID::IID_IBufferResource = SLANG_UUID_IBufferResource; const Slang::Guid GfxGUID::IID_ITextureResource = SLANG_UUID_ITextureResource; const Slang::Guid GfxGUID::IID_IDevice = SLANG_UUID_IDevice; +const Slang::Guid GfxGUID::IID_IShaderCacheStatistics = SLANG_UUID_IShaderCacheStatistics; const Slang::Guid GfxGUID::IID_IShaderObject = SLANG_UUID_IShaderObject; const Slang::Guid GfxGUID::IID_IRenderPassLayout = SLANG_UUID_IRenderPassLayout; @@ -325,6 +331,129 @@ void PipelineStateBase::initializeBase(const PipelineStateDesc& inDesc) } } +void updateCacheEntry(ISlangMutableFileSystem* fileSystem, slang::IBlob* compiledCode, String shaderFilename, slang::Digest ASTHash) +{ + auto hashSize = sizeof(slang::Digest); + + auto bufferSize = hashSize + compiledCode->getBufferSize(); + List contents; + contents.setCount(bufferSize); + uint8_t* buffer = contents.begin(); + memcpy(buffer, &ASTHash, hashSize); + memcpy(buffer + 
hashSize, (void*)compiledCode->getBufferPointer(), compiledCode->getBufferSize()); + fileSystem->saveFile(shaderFilename.getBuffer(), buffer, bufferSize); +} + +Result RendererBase::getEntryPointCodeFromShaderCache( + slang::IComponentType* program, + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics) +{ + // TODO: Need a way in filesystem to query both file size and file creation time, if cache size exceeds + // specified maximum size (in bytes or files) then delete oldest files - cache eviction policy + + // Immediately call getEntryPointCode if no shader cache was provided on initialization + if (!shaderCacheFileSystem) + { + return program->getEntryPointCode(entryPointIndex, targetIndex, outCode, outDiagnostics); + } + + // Produce a string which we can use to query the shader cache by combining two separate hashes which + // together comprise all the compilation arguments for this program. + ComPtr session; + getSlangSession(session.writeRef()); + + slang::Digest shaderKeyHash; + program->computeDependencyBasedHash(entryPointIndex, targetIndex, &shaderKeyHash); + + StringBuilder shaderKey = hashToString(shaderKeyHash); + + // Produce a hash using the AST for this program - This is needed to check whether a cache entry is effectively dirty, + // or to save along with the compiled code into an entry so the entry can be checked if fetched later on. + slang::Digest ASTHash; + program->computeASTBasedHash(&ASTHash); + + ComPtr codeBlob; + + // Query shaderCacheFileSystem for an entry whose key matches shaderFilename + // - If we find it, then copy the file contents into memory and return in outCode + auto result = shaderCacheFileSystem->loadFile(shaderKey.getBuffer(), codeBlob.writeRef()); + + if (SLANG_FAILED(result)) + { + // If we didn't find it, call program->getEntryPointCode() to get and return the code. We also + // make sure to save a new entry in the shader cache. 
+ if (SLANG_SUCCEEDED(program->getEntryPointCode(entryPointIndex, targetIndex, codeBlob.writeRef(), outDiagnostics))) + { + if (mutableShaderCacheFileSystem) + { + updateCacheEntry(mutableShaderCacheFileSystem, codeBlob, shaderKey, ASTHash); + } + + shaderCacheMissCount++; + } + else + { + // If getEntryPointCode() failed to fetch the code, we return SLANG_FAIL along with the diagnostics output + // in outDiagnostics. + return SLANG_FAIL; + } + } + else + { + // If the entry exists, we need to check that the entry isn't effectively dirty. Since we stored + // the AST hash with the compiled code, we can determine this by comparing the stored hash with the + // AST hash generated earlier. + auto entryContents = codeBlob->getBufferPointer(); + auto hashSize = sizeof(slang::Digest); + if (memcmp(ASTHash.values, entryContents, hashSize) != 0) + { + // The AST hash stored in the entry does not match the AST hash generated earlier, indicating + // that the shader code has changed and the entry needs to be updated. + if (SLANG_SUCCEEDED(program->getEntryPointCode(entryPointIndex, targetIndex, codeBlob.writeRef(), outDiagnostics))) + { + if (mutableShaderCacheFileSystem) + { + updateCacheEntry(mutableShaderCacheFileSystem, codeBlob, shaderKey, ASTHash); + } + + shaderCacheEntryDirtyCount++; + } + else + { + // If getEntryPointCode() failed to fetch the code, we return SLANG_FAIL along with the diagnostics output + // in outDiagnostics. 
+ return SLANG_FAIL; + } + } + else + { + auto compiledCode = RawBlob::create((uint8_t*)codeBlob->getBufferPointer() + hashSize, codeBlob->getBufferSize() - hashSize); + codeBlob = compiledCode; + + shaderCacheHitCount++; + } + } + + *outCode = codeBlob.detach(); + return SLANG_OK; +} + +SlangResult RendererBase::queryInterface(SlangUUID const& uuid, void** outObject) +{ + if (uuid == GfxGUID::IID_IShaderCacheStatistics) + { + *outObject = static_cast(this); + addRef(); + return SLANG_OK; + } + + *outObject = getInterface(uuid); + return SLANG_OK; +} + IDevice* gfx::RendererBase::getInterface(const Guid& guid) { return (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IDevice) @@ -334,6 +463,28 @@ IDevice* gfx::RendererBase::getInterface(const Guid& guid) SLANG_NO_THROW Result SLANG_MCALL RendererBase::initialize(const Desc& desc) { + // If a shader cache file system was provided, use the provided system. + if (desc.shaderCacheFileSystem) + { + shaderCacheFileSystem = desc.shaderCacheFileSystem; + } + if (desc.shaderCachePath) + { + // Only a path was provided, create a RelativeFileSystem using the path + if (!shaderCacheFileSystem) + { + shaderCacheFileSystem = OSFileSystem::getMutableSingleton(); + } + shaderCacheFileSystem = new RelativeFileSystem(shaderCacheFileSystem, desc.shaderCachePath); + } + + // If we initialized a file system for the shader cache, check if it's mutable. If so, store a pointer + // to the mutable version in order to save new entries later. 
+ if (shaderCacheFileSystem) + { + shaderCacheFileSystem->queryInterface(ISlangMutableFileSystem::getTypeGuid(), (void**)mutableShaderCacheFileSystem.writeRef()); + } + if (desc.apiCommandDispatcher) { desc.apiCommandDispatcher->queryInterface( @@ -664,7 +815,28 @@ Result RendererBase::getShaderObjectLayout( return SLANG_OK; } +GfxCount RendererBase::getCacheMissCount() +{ + return shaderCacheMissCount; +} +GfxCount RendererBase::getCacheHitCount() +{ + return shaderCacheHitCount; +} + +GfxCount RendererBase::getCacheEntryDirtyCount() +{ + return shaderCacheEntryDirtyCount; +} + +Result RendererBase::resetCacheStatistics() +{ + shaderCacheMissCount = 0; + shaderCacheHitCount = 0; + shaderCacheEntryDirtyCount = 0; + return SLANG_OK; +} ShaderComponentID ShaderCache::getComponentId(slang::TypeReflection* type) { @@ -908,7 +1080,7 @@ void ShaderProgramBase::init(const IShaderProgram::Desc& inDesc) } } -Result ShaderProgramBase::compileShaders() +Result ShaderProgramBase::compileShaders(RendererBase* device) { // For a fully specialized program, read and store its kernel code in `shaderProgram`. 
auto compileShader = [&](slang::EntryPointReflection* entryPointInfo, @@ -918,7 +1090,7 @@ Result ShaderProgramBase::compileShaders() auto stage = entryPointInfo->getStage(); ComPtr kernelCode; ComPtr diagnostics; - auto compileResult = entryPointComponent->getEntryPointCode( + auto compileResult = device->getEntryPointCodeFromShaderCache(entryPointComponent, entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) { diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index d0e4b52fb..a753fe017 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -26,6 +26,7 @@ struct GfxGUID static const Slang::Guid IID_ITextureResource; static const Slang::Guid IID_IInputLayout; static const Slang::Guid IID_IDevice; + static const Slang::Guid IID_IShaderCacheStatistics; static const Slang::Guid IID_IShaderObjectLayout; static const Slang::Guid IID_IShaderObject; static const Slang::Guid IID_IRenderPassLayout; @@ -857,7 +858,7 @@ public: return false; } - Slang::Result compileShaders(); + Slang::Result compileShaders(RendererBase* device); virtual Slang::Result createShaderModule( slang::EntryPointReflection* entryPointInfo, Slang::ComPtr kernelCode); @@ -1211,11 +1212,12 @@ public: // Renderer implementation shared by all platforms. // Responsible for shader compilation, specialization and caching. 
-class RendererBase : public IDevice, public Slang::ComObject +class RendererBase : public IDevice, public IShaderCacheStatistics, public Slang::ComObject { friend class ShaderObjectBase; public: - SLANG_COM_OBJECT_IUNKNOWN_ALL + SLANG_COM_OBJECT_IUNKNOWN_ADD_REF + SLANG_COM_OBJECT_IUNKNOWN_RELEASE virtual SLANG_NO_THROW Result SLANG_MCALL getNativeDeviceHandles(InteropHandles* outHandles) SLANG_OVERRIDE; virtual SLANG_NO_THROW Result SLANG_MCALL getFeatures( @@ -1224,6 +1226,8 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getFormatSupportedResourceStates(Format format, ResourceStateSet* outStates) override; virtual SLANG_NO_THROW Result SLANG_MCALL getSlangSession(slang::ISession** outSlangSession) SLANG_OVERRIDE; + virtual SLANG_NO_THROW SlangResult SLANG_MCALL + queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE; IDevice* getInterface(const Slang::Guid& guid); virtual SLANG_NO_THROW Result SLANG_MCALL createTextureFromNativeHandle( @@ -1309,6 +1313,13 @@ public: // Provides a default implementation that returns SLANG_E_NOT_AVAILABLE. 
virtual SLANG_NO_THROW Result SLANG_MCALL getTextureRowAlignment(size_t* outAlignment) override; + Result getEntryPointCodeFromShaderCache( + slang::IComponentType* program, + SlangInt entryPointIndex, + SlangInt targetIndex, + slang::IBlob** outCode, + slang::IBlob** outDiagnostics = nullptr); + Result getShaderObjectLayout( slang::TypeReflection* type, ShaderObjectContainerType container, @@ -1346,10 +1357,24 @@ protected: protected: Slang::List m_features; +public: + virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheMissCount() override; + virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheHitCount() override; + virtual SLANG_NO_THROW GfxCount SLANG_MCALL getCacheEntryDirtyCount() override; + virtual SLANG_NO_THROW Result SLANG_MCALL resetCacheStatistics() override; + +protected: + GfxCount shaderCacheMissCount = 0; + GfxCount shaderCacheHitCount = 0; + GfxCount shaderCacheEntryDirtyCount = 0; + public: SlangContext slangContext; ShaderCache shaderCache; + ISlangFileSystem* shaderCacheFileSystem = nullptr; + ComPtr mutableShaderCacheFileSystem = nullptr; + Slang::Dictionary> m_shaderObjectLayoutCache; Slang::ComPtr m_pipelineCreationAPIDispatcher; }; diff --git a/tools/gfx/vulkan/vk-pipeline-state.cpp b/tools/gfx/vulkan/vk-pipeline-state.cpp index 710cbdaef..06bd13197 100644 --- a/tools/gfx/vulkan/vk-pipeline-state.cpp +++ b/tools/gfx/vulkan/vk-pipeline-state.cpp @@ -251,7 +251,7 @@ Result PipelineStateImpl::createVKGraphicsPipelineState() auto programImpl = static_cast(m_program.Ptr()); if (programImpl->m_stageCreateInfos.getCount() == 0) { - SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device)); } pipelineInfo.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO; @@ -281,7 +281,7 @@ Result PipelineStateImpl::createVKComputePipelineState() auto programImpl = static_cast(m_program.Ptr()); if (programImpl->m_stageCreateInfos.getCount() == 0) { - SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); 
+ SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device)); } VkPipelineCache pipelineCache = VK_NULL_HANDLE; @@ -340,7 +340,7 @@ Result RayTracingPipelineStateImpl::createVKRayTracingPipelineState() auto programImpl = static_cast(m_program.Ptr()); if (programImpl->m_stageCreateInfos.getCount() == 0) { - SLANG_RETURN_ON_FAIL(programImpl->compileShaders()); + SLANG_RETURN_ON_FAIL(programImpl->compileShaders(m_device)); } VkRayTracingPipelineCreateInfoKHR raytracingPipelineInfo = { -- cgit v1.2.3