diff options
| -rw-r--r-- | build/visual-studio/shader-object/shader-object.vcxproj | 190 | ||||
| -rw-r--r-- | build/visual-studio/shader-object/shader-object.vcxproj.filters | 18 | ||||
| -rw-r--r-- | examples/shader-object/README.md | 5 | ||||
| -rw-r--r-- | examples/shader-object/main.cpp | 230 | ||||
| -rw-r--r-- | examples/shader-object/shader-object.slang | 65 | ||||
| -rw-r--r-- | premake5.lua | 3 | ||||
| -rw-r--r-- | slang.sln | 11 | ||||
| -rw-r--r-- | tools/gfx/cuda/render-cuda.cpp | 69 | ||||
| -rw-r--r-- | tools/gfx/d3d11/render-d3d11.cpp | 14 | ||||
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 82 | ||||
| -rw-r--r-- | tools/gfx/open-gl/render-gl.cpp | 9 | ||||
| -rw-r--r-- | tools/gfx/render-graphics-common.cpp | 15 | ||||
| -rw-r--r-- | tools/gfx/render.h | 4 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 21 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 14 | ||||
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 9 |
16 files changed, 660 insertions, 99 deletions
diff --git a/build/visual-studio/shader-object/shader-object.vcxproj b/build/visual-studio/shader-object/shader-object.vcxproj new file mode 100644 index 000000000..1af67af5b --- /dev/null +++ b/build/visual-studio/shader-object/shader-object.vcxproj @@ -0,0 +1,190 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" ToolsVersion="14.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|Win32"> + <Configuration>Debug</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|Win32"> + <Configuration>Release</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <ProjectGuid>{25512BFB-1138-EDF2-BA88-5310A64E6659}</ProjectGuid> + <IgnoreWarnCompileDuplicatedFilename>true</IgnoreWarnCompileDuplicatedFilename> + <Keyword>Win32Proj</Keyword> + <RootNamespace>shader-object</RootNamespace> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <CharacterSet>Unicode</CharacterSet> + <PlatformToolset>v140</PlatformToolset> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <CharacterSet>Unicode</CharacterSet> + <PlatformToolset>v140</PlatformToolset> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <CharacterSet>Unicode</CharacterSet> + <PlatformToolset>v140</PlatformToolset> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <CharacterSet>Unicode</CharacterSet> + <PlatformToolset>v140</PlatformToolset> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <LinkIncremental>true</LinkIncremental> + <OutDir>..\..\..\bin\windows-x86\debug\</OutDir> + <IntDir>..\..\..\intermediate\windows-x86\debug\shader-object\</IntDir> + <TargetName>shader-object</TargetName> + <TargetExt>.exe</TargetExt> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <LinkIncremental>true</LinkIncremental> + <OutDir>..\..\..\bin\windows-x64\debug\</OutDir> + <IntDir>..\..\..\intermediate\windows-x64\debug\shader-object\</IntDir> + <TargetName>shader-object</TargetName> + <TargetExt>.exe</TargetExt> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <LinkIncremental>false</LinkIncremental> + <OutDir>..\..\..\bin\windows-x86\release\</OutDir> + <IntDir>..\..\..\intermediate\windows-x86\release\shader-object\</IntDir> + <TargetName>shader-object</TargetName> + <TargetExt>.exe</TargetExt> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <LinkIncremental>false</LinkIncremental> + <OutDir>..\..\..\bin\windows-x64\release\</OutDir> + <IntDir>..\..\..\intermediate\windows-x64\release\shader-object\</IntDir> + <TargetName>shader-object</TargetName> + <TargetExt>.exe</TargetExt> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <ClCompile> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <DebugInformationFormat>EditAndContinue</DebugInformationFormat> + <Optimization>Disabled</Optimization> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <DebugInformationFormat>EditAndContinue</DebugInformationFormat> + <Optimization>Disabled</Optimization> + <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <ClCompile> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <Optimization>Full</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <MinimalRebuild>false</MinimalRebuild> + <StringPooling>true</StringPooling> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <PrecompiledHeader>NotUsing</PrecompiledHeader> + <WarningLevel>Level3</WarningLevel> + <PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories> + <Optimization>Full</Optimization> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <MinimalRebuild>false</MinimalRebuild> + <StringPooling>true</StringPooling> + <RuntimeLibrary>MultiThreaded</RuntimeLibrary> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ClCompile Include="..\..\..\examples\shader-object\main.cpp" /> + </ItemGroup> + <ItemGroup> + <None Include="..\..\..\examples\shader-object\shader-object.slang" /> + </ItemGroup> + <ItemGroup> + <ProjectReference Include="..\slang\slang.vcxproj"> + <Project>{DB00DA62-0533-4AFD-B59F-A67D5B3A0808}</Project> + </ProjectReference> + <ProjectReference Include="..\core\core.vcxproj"> + <Project>{F9BE7957-8399-899E-0C49-E714FDDD4B65}</Project> + </ProjectReference> + <ProjectReference Include="..\gfx\gfx.vcxproj"> + <Project>{222F7498-B40C-4F3F-A704-DDEB91A4484A}</Project> + </ProjectReference> + <ProjectReference Include="..\gfx-util\gfx-util.vcxproj"> + <Project>{F5ADB74E-02A7-44FB-AA3B-FC02F8AC7A4B}</Project> + </ProjectReference> + <ProjectReference Include="..\graphics-app-framework\graphics-app-framework.vcxproj"> + <Project>{3565FE5E-4FA3-11EB-AE93-0242AC130002}</Project> + </ProjectReference> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/build/visual-studio/shader-object/shader-object.vcxproj.filters b/build/visual-studio/shader-object/shader-object.vcxproj.filters new file mode 100644 index 000000000..da1b2e417 --- /dev/null +++ b/build/visual-studio/shader-object/shader-object.vcxproj.filters @@ -0,0 +1,18 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <Filter Include="Source Files"> + <UniqueIdentifier>{E9C7FDCE-D52A-8D73-7EB0-C5296AF258F6}</UniqueIdentifier> + </Filter> + </ItemGroup> + <ItemGroup> + <ClCompile Include="..\..\..\examples\shader-object\main.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + </ItemGroup> + <ItemGroup> + <None Include="..\..\..\examples\shader-object\shader-object.slang"> + <Filter>Source Files</Filter> + </None> + </ItemGroup> +</Project>
\ No newline at end of file diff --git a/examples/shader-object/README.md b/examples/shader-object/README.md new file mode 100644 index 000000000..cd72a4f61 --- /dev/null +++ b/examples/shader-object/README.md @@ -0,0 +1,5 @@ +Slang "Shader Object" Example +========================== + +This example shows how to use the Shader Object model implemented in the `slang-gfx` layer to +manage shader parameter binding and shader specialization. diff --git a/examples/shader-object/main.cpp b/examples/shader-object/main.cpp new file mode 100644 index 000000000..8a5bc8373 --- /dev/null +++ b/examples/shader-object/main.cpp @@ -0,0 +1,230 @@ +// main.cpp + +// This file provides the application code for the `shader-object` example. +// + +// This example uses the Slang gfx layer to target different APIs. +// The goal is to demonstrate how the Shader Object model implemented in `gfx` layer +// simplifies shader specialization and parameter binding when using `interface` typed +// shader parameters. +// +#include <slang.h> +#include <slang-com-ptr.h> +using Slang::ComPtr; + +#include "gfx/render.h" +#include "gfx-util/shader-cursor.h" +#include "source/core/slang-basic.h" + +using namespace gfx; + +// Helper function for print out diagnostic messages output by Slang compiler. +void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob) +{ + if (diagnosticsBlob != nullptr) + { + printf("%s", (const char*)diagnosticsBlob->getBufferPointer()); + } +} + +// Loads the shader code defined in `shader-object.slang` for use by the `gfx` layer. +// +Result loadShaderProgram( + gfx::IRenderer* renderer, + ComPtr<gfx::IShaderProgram>& outShaderProgram, + slang::ProgramLayout*& slangReflection) +{ + // We need to obatin a compilation session (`slang::ISession`) that will provide + // a scope to all the compilation and loading of code we do. + // + // Our example application uses the `gfx` graphics API abstraction layer, which already + // creates a Slang compilation session for us, so we just grab and use it here. + ComPtr<slang::ISession> slangSession; + SLANG_RETURN_ON_FAIL(renderer->getSlangSession(slangSession.writeRef())); + + // Once the session has been obtained, we can start loading code into it. + // + // The simplest way to load code is by calling `loadModule` with the name of a Slang + // module. A call to `loadModule("MyStuff")` will behave more or less as if you + // wrote: + // + // import MyStuff; + // + // In a Slang shader file. The compiler will use its search paths to try to locate + // `MyModule.slang`, then compile and load that file. If a matching module had + // already been loaded previously, that would be used directly. + // + // Note: The only interesting wrinkle here is that our file is named `shader-object` with + // a hyphen in it, so the name is not directly usable as an identifier in Slang code. + // Instead, when trying to import this module in the context of Slang code, a user + // needs to replace the hyphens with underscores: + // + // import shader_object; + // + ComPtr<slang::IBlob> diagnosticsBlob; + slang::IModule* module = slangSession->loadModule("shader-object", diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if(!module) + return SLANG_FAIL; + + // Loading the `shader-object` module will compile and check all the shader code in it, + // including the shader entry points we want to use. Now that the module is loaded + // we can look up those entry points by name. + // + // Note: If you are using this `loadModule` approach to load your shader code it is + // important to tag your entry point functions with the `[shader("...")]` attribute + // (e.g., `[shader("vertex")] void vertexMain(...)`). Without that information there + // is no umambiguous way for the compiler to know which functions represent entry + // points when it parses your code via `loadModule()`. + // + char const* computeEntryPointName = "computeMain"; + ComPtr<slang::IEntryPoint> computeEntryPoint; + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName(computeEntryPointName, computeEntryPoint.writeRef())); + + // At this point we have a few different Slang API objects that represent + // pieces of our code: `module`, `vertexEntryPoint`, and `fragmentEntryPoint`. + // + // A single Slang module could contain many different entry points (e.g., + // four vertex entry points, three fragment entry points, and two compute + // shaders), and before we try to generate output code for our target API + // we need to identify which entry points we plan to use together. + // + // Modules and entry points are both examples of *component types* in the + // Slang API. The API also provides a way to build a *composite* out of + // other pieces, and that is what we are going to do with our module + // and entry points. + // + Slang::List<slang::IComponentType*> componentTypes; + componentTypes.add(module); + componentTypes.add(computeEntryPoint); + + // Actually creating the composite component type is a single operation + // on the Slang session, but the operation could potentially fail if + // something about the composite was invalid (e.g., you are trying to + // combine multiple copies of the same module), so we need to deal + // with the possibility of diagnostic output. + // + ComPtr<slang::IComponentType> composedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + composedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + slangReflection = composedProgram->getLayout(); + + // At this point, `composedProgram` represents the shader program + // we want to run, and the compute shader there have been checked. + // We can create a `gfx::IShaderProgram` object from `composedProgram` + // so it may be used by the graphics layer. + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.pipelineType = gfx::PipelineType::Compute; + programDesc.slangProgram = composedProgram.get(); + + auto shaderProgram = renderer->createProgram(programDesc); + + outShaderProgram = shaderProgram; + return SLANG_OK; +} + +// Main body of the example. +int main() +{ + // Creates a `gfx` renderer, which provides the main interface for + // interacting with the graphics API. + Slang::ComPtr<gfx::IRenderer> renderer; + IRenderer::Desc rendererDesc = {}; + rendererDesc.rendererType = RendererType::CUDA; + SLANG_RETURN_ON_FAIL(gfxCreateRenderer(&rendererDesc, nullptr, renderer.writeRef())); + + // Now we can load the shader code. + // A `gfx::IShaderProgram` object for use in the `gfx` layer. + ComPtr<gfx::IShaderProgram> shaderProgram; + // A composed `IComponentType` that gives us reflection info on the shader code. + slang::ProgramLayout* slangReflection; + SLANG_RETURN_ON_FAIL(loadShaderProgram(renderer, shaderProgram, slangReflection)); + + // Create a pipelien state with the loaded shader. + gfx::ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr<gfx::IPipelineState> pipelineState; + SLANG_RETURN_ON_FAIL( + renderer->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + // Create and initiate our input/output buffer. + const int numberCount = 4; + float initialData[] = {0.0f, 1.0f, 2.0f, 3.0f}; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = numberCount * sizeof(float); + bufferDesc.format = gfx::Format::Unknown; + bufferDesc.elementSize = sizeof(float); + bufferDesc.bindFlags = gfx::IResource::BindFlag::NonPixelShaderResource | + gfx::IResource::BindFlag::UnorderedAccess; + bufferDesc.cpuAccessFlags = IResource::AccessFlag::Write | IResource::AccessFlag::Read; + + ComPtr<gfx::IBufferResource> numbersBuffer; + SLANG_RETURN_ON_FAIL(renderer->createBufferResource( + gfx::IResource::Usage::UnorderedAccess, + bufferDesc, + (void*)initialData, + numbersBuffer.writeRef())); + + // Create a resource view for the buffer. + ComPtr<gfx::IResourceView> bufferView; + gfx::IResourceView::Desc viewDesc = {}; + viewDesc.type = gfx::IResourceView::Type::UnorderedAccess; + viewDesc.format = gfx::Format::Unknown; + SLANG_RETURN_ON_FAIL(renderer->createBufferView(numbersBuffer, viewDesc, bufferView.writeRef())); + + // Now comes the interesting part: binding the shader parameter for the + // compute kernel that we about to launch. We would like to construct + // a shader object that represents a `f(x)=x+1` transformation and apply + // it to the numbers in `numbersBuffer`. + // To start, we create a root shader object that represents the root level + // scope of the shader parameters. + ComPtr<gfx::IShaderObject> rootObject; + SLANG_RETURN_ON_FAIL(renderer->createRootShaderObject(shaderProgram, rootObject.writeRef())); + // We can set parameters directly with `rootObject`, but that requires us to use + // the Slang reflection API to obtain the proper offsets into the root object for each parameter. + // We implemented these logic in the `ShaderCursor` helper class, which simplifies the user + // code to find shader parameters. Here we demonstrate how to set parameters with `ShaderCursor`. + gfx::ShaderCursor entryPointCursor(rootObject->getEntryPoint(0)); // get a cursor the the first entry-point. + // Bind buffer view to the entry point. + entryPointCursor.getPath("buffer").setResource(bufferView); + + // Next, we create a shader object that represents the transformer we want to use. + // To do so, we first need to lookup for the `AddTransformer` type defined in the shader code. + ComPtr<gfx::IShaderObject> transformer; + ComPtr<gfx::IShaderObjectLayout> transformerObjectLayout; + slang::TypeLayoutReflection* addTransformerTypeLayout = slangReflection->getTypeLayout( + slangReflection->findTypeByName("AddTransformer")); + + // Now we can use this type to create a shader object that can be bound to the root object. + SLANG_RETURN_ON_FAIL(renderer->createShaderObjectLayout( + addTransformerTypeLayout, transformerObjectLayout.writeRef())); + SLANG_RETURN_ON_FAIL( + renderer->createShaderObject(transformerObjectLayout, transformer.writeRef())); + // Set the `c` field of the `AddTransformer`. + float c = 1.0f; + gfx::ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float)); + + // Now the transformer object is ready, we can bind it to root object. + entryPointCursor.getPath("transformer").setObject(transformer); + + // We have set up all required parameters in entry-point object, now it is time + // to bind the pipeline and root object and launch the kernel. + renderer->setPipelineState(pipelineState); + SLANG_RETURN_ON_FAIL(renderer->bindRootShaderObject(gfx::PipelineType::Compute, rootObject)); + renderer->dispatchCompute(1, 1, 1); + + // Read back the results. + renderer->waitForGpu(); + float* result = (float*)renderer->map(numbersBuffer, gfx::MapFlavor::HostRead); + for (int i = 0; i < numberCount; i++) + printf("%f\n", result[i]); + renderer->unmap(numbersBuffer); + + return SLANG_OK; +} diff --git a/examples/shader-object/shader-object.slang b/examples/shader-object/shader-object.slang new file mode 100644 index 000000000..7cb295f53 --- /dev/null +++ b/examples/shader-object/shader-object.slang @@ -0,0 +1,65 @@ +// shader-object.slang + +// This file implements a simple compute shader that transforms +// input floating point numbers stored in a `RWStructuredBuffer`. +// Specifically, for each number x from input buffer, compute +// f(x) and store the result back in the same buffer. + +// The compute shader supports multiple transformation functions, +// such add(x, c) which returns x+c, or mul(x, c) which returns x*c. +// This functions are implemented as types that conforms to the +// `ITransformer` interface. + +// The main entry point function takes a parameter of `ITransformer` +// type, and applies the transformation to numbers in the input +// buffer. By defining the shader parameter using interfaces, +// we enable the flexiblity to generate either specialized compute +// kernels that performs specific transformation or a general +// kernel that can perform any transformations encoded by the +// parameter at run-time, without changing any shader code or +// host-application logic for setting and preparing shader parameters. + +// Defines the transformer interface, which implements a single +// `transform` operation. +interface ITransformer +{ + float transform(float x); +} + +// Represents a transform function f(x) = x + c. +struct AddTransformer : ITransformer +{ + float c; + float transform(float x) { return x + c + 10.0f; } +}; + +// Represents a transform function f(x) = x * c. +struct MulTransformer : ITransformer +{ + float c; + float transform(float x) { return x * c; } +}; + +// Represents a composite function f(x) = f0(f1(x)); +struct CompositeTransformer : ITransformer +{ + ITransformer func0; + ITransformer func1; + float transform(float x) + { + return func0.transform(func1.transform(x)); + } +}; + +// Main entry-point. Applies the transformation encoded by `transformer` +// to all elements in `buffer`. +[shader("compute")] +[numthreads(4,1,1)] +void computeMain( + uint3 sv_dispatchThreadID : SV_DispatchThreadID, + uniform RWStructuredBuffer<float> buffer, + uniform ITransformer transformer) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = transformer.transform(input); +} diff --git a/premake5.lua b/premake5.lua index 4068219fd..e3ae76fc2 100644 --- a/premake5.lua +++ b/premake5.lua @@ -578,6 +578,9 @@ if isTargetWindows then example "shader-toy" end +example "shader-object" + kind "ConsoleApp" + example "cpu-hello-world" kind "ConsoleApp" @@ -15,6 +15,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "heterogeneous-hello-world", EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "model-viewer", "build\visual-studio\model-viewer\model-viewer.vcxproj", "{2F8724C6-1BC3-2730-84D5-3F277030D04A}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-object", "build\visual-studio\shader-object\shader-object.vcxproj", "{25512BFB-1138-EDF2-BA88-5310A64E6659}" +EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-toy", "build\visual-studio\shader-toy\shader-toy.vcxproj", "{0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}" EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "generator", "generator", "{F3AB4ED5-5F37-BC99-6848-3F8ED452189A}" @@ -115,6 +117,14 @@ Global {2F8724C6-1BC3-2730-84D5-3F277030D04A}.Release|Win32.Build.0 = Release|Win32 {2F8724C6-1BC3-2730-84D5-3F277030D04A}.Release|x64.ActiveCfg = Release|x64 {2F8724C6-1BC3-2730-84D5-3F277030D04A}.Release|x64.Build.0 = Release|x64 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.ActiveCfg = Debug|Win32 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.Build.0 = Debug|Win32 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|x64.ActiveCfg = Debug|x64 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|x64.Build.0 = Debug|x64 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Release|Win32.ActiveCfg = Release|Win32 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Release|Win32.Build.0 = Release|Win32 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Release|x64.ActiveCfg = Release|x64 + {25512BFB-1138-EDF2-BA88-5310A64E6659}.Release|x64.Build.0 = Release|x64 {0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}.Debug|Win32.ActiveCfg = Debug|Win32 {0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}.Debug|Win32.Build.0 = Debug|Win32 {0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}.Debug|x64.ActiveCfg = Debug|x64 @@ -245,6 +255,7 @@ Global {010BE414-ED5B-CF56-16C0-BD18027062C0} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {150CAA5A-0177-6A66-AA92-CFCB96DC2D49} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {2F8724C6-1BC3-2730-84D5-3F277030D04A} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} + {25512BFB-1138-EDF2-BA88-5310A64E6659} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {E145B2B8-CD13-A6BE-B6A7-16E5A2148223} = {F3AB4ED5-5F37-BC99-6848-3F8ED452189A} {61F7EB00-7281-4BF3-9470-7C2EA92620C3} = {57B5AA5E-C340-1823-CC51-9B17385C7423} diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp index 7d7ee8eb9..4f87bdfc9 100644 --- a/tools/gfx/cuda/render-cuda.cpp +++ b/tools/gfx/cuda/render-cuda.cpp @@ -242,36 +242,6 @@ public: RefPtr<TextureCUDAResource> textureResource = nullptr; }; -class CUDAProgramLayout; - -class CUDAShaderProgram : public ShaderProgramBase -{ -public: - CUmodule cudaModule = nullptr; - CUfunction cudaKernel; - String kernelName; - RefPtr<CUDAProgramLayout> layout; - - ~CUDAShaderProgram() - { - if (cudaModule) - cuModuleUnload(cudaModule); - } -}; - -class CUDAPipelineState : public PipelineStateBase -{ -public: - RefPtr<CUDAShaderProgram> shaderProgram; - void init(const ComputePipelineStateDesc& inDesc) - { - PipelineStateDesc pipelineDesc; - pipelineDesc.type = PipelineType::Compute; - pipelineDesc.compute = inDesc; - initializeBase(pipelineDesc); - } -}; - class CUDAShaderObjectLayout : public ShaderObjectLayoutBase { public: @@ -578,8 +548,6 @@ public: // TODO: the logic here is a copy-paste of `GraphicsCommonShaderObject::collectSpecializationArgs`, // consider moving the implementation to `ShaderObjectBase` and share the logic among different implementations. - if (!m_bindingFinalized) - return SLANG_FAIL; auto& subObjectRanges = getLayout()->subObjectRanges; // The following logic is built on the assumption that all fields that involve existential types (and // therefore require specialization) will results in a sub-object range in the type layout. @@ -677,7 +645,42 @@ public: entryPointObjects[index]->addRef(); return SLANG_OK; } + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + SLANG_RETURN_ON_FAIL(CUDAShaderObject::collectSpecializationArgs(args)); + for (auto& entryPoint : entryPointObjects) + { + SLANG_RETURN_ON_FAIL(entryPoint->collectSpecializationArgs(args)); + } + return SLANG_OK; + } +}; +class CUDAShaderProgram : public ShaderProgramBase +{ +public: + CUmodule cudaModule = nullptr; + CUfunction cudaKernel; + String kernelName; + RefPtr<CUDAProgramLayout> layout; + ~CUDAShaderProgram() + { + if (cudaModule) + cuModuleUnload(cudaModule); + } +}; + +class CUDAPipelineState : public PipelineStateBase +{ +public: + RefPtr<CUDAShaderProgram> shaderProgram; + void init(const ComputePipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::Compute; + pipelineDesc.compute = inDesc; + initializeBase(pipelineDesc); + } }; class CUDARenderer : public RendererBase @@ -802,7 +805,6 @@ private: CUcontext m_context = nullptr; RefPtr<CUDAPipelineState> currentPipeline = nullptr; RefPtr<CUDARootShaderObject> currentRootObject = nullptr; - SlangContext slangContext; public: ~CUDARenderer() { @@ -1332,6 +1334,7 @@ private: { RefPtr<CUDAShaderProgram> cudaProgram = new CUDAShaderProgram(); cudaProgram->slangProgram = desc.slangProgram; + cudaProgram->layout = new CUDAProgramLayout(this, desc.slangProgram->getLayout()); *outProgram = cudaProgram.detach(); return SLANG_OK; } diff --git a/tools/gfx/d3d11/render-d3d11.cpp b/tools/gfx/d3d11/render-d3d11.cpp index 079d89a59..49fe101fb 100644 --- a/tools/gfx/d3d11/render-d3d11.cpp +++ b/tools/gfx/d3d11/render-d3d11.cpp @@ -1173,7 +1173,8 @@ Result D3D11Renderer::createBufferResource(IResource::Usage initialUsage, const bufferDesc.Usage = D3D11_USAGE_DEFAULT; // If written by CPU, make it dynamic - if (descIn.cpuAccessFlags & IResource::AccessFlag::Write) + if ((descIn.cpuAccessFlags & IResource::AccessFlag::Write) && + ((descIn.bindFlags & IResource::BindFlag::UnorderedAccess) == 0)) { bufferDesc.Usage = D3D11_USAGE_DYNAMIC; } @@ -1203,7 +1204,7 @@ Result D3D11Renderer::createBufferResource(IResource::Usage initialUsage, const } } - if( bufferDesc.Usage == D3D11_USAGE_DYNAMIC ) + if (srcDesc.cpuAccessFlags & IResource::AccessFlag::Write) { bufferDesc.CPUAccessFlags |= D3D11_CPU_ACCESS_WRITE; } @@ -1791,6 +1792,15 @@ void D3D11Renderer::drawIndexed(UInt indexCount, UInt startIndex, UInt baseVerte Result D3D11Renderer::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) { + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + // For a specializable program, we don't invoke any actual slang compilation yet. + RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl(); + initProgramCommon(shaderProgram, desc); + *outProgram = shaderProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index de7cbd2e2..0ab07c262 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -2960,6 +2960,15 @@ void D3D12Renderer::setDescriptorSet(PipelineType pipelineType, IPipelineLayout* Result D3D12Renderer::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) { + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + // For a specializable program, we don't invoke any actual slang compilation yet. + RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl(); + initProgramCommon(shaderProgram, desc); + *outProgram = shaderProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); @@ -3740,43 +3749,54 @@ Result D3D12Renderer::createComputePipelineState(const ComputePipelineStateDesc& auto pipelineLayoutImpl = (PipelineLayoutImpl*) desc.pipelineLayout; auto programImpl = (ShaderProgramImpl*) desc.program; - // Describe and create the compute pipeline state object - D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {}; - computeDesc.pRootSignature = pipelineLayoutImpl->m_rootSignature; - computeDesc.CS = { programImpl->m_computeShader.getBuffer(), SIZE_T(programImpl->m_computeShader.getCount()) }; - + // Only actually create a D3D12 pipeline state if the pipeline is fully specialized. ComPtr<ID3D12PipelineState> pipelineState; - -#ifdef GFX_NVAPI - if (m_nvapi) + if (!programImpl->slangProgram || programImpl->slangProgram->getSpecializationParamCount() == 0) { - // Also fill the extension structure. - // Use the same UAV slot index and register space that are declared in the shader. - - // For simplicities sake we just use u0 - NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc; - extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER; - extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER; - extensionDesc.uavSlot = 0; - extensionDesc.registerSpace = 0; - - // Put the pointer to the extension into an array - there can be multiple extensions enabled at once. - const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = { &extensionDesc }; + // Describe and create the compute pipeline state object + D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {}; + computeDesc.pRootSignature = pipelineLayoutImpl->m_rootSignature; + computeDesc.CS = { + programImpl->m_computeShader.getBuffer(), + SIZE_T(programImpl->m_computeShader.getCount())}; - // Now create the PSO. - const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState(m_device, &computeDesc, SLANG_COUNT_OF(extensions), extensions, pipelineState.writeRef()); - - if (nvapiStatus != NVAPI_OK) - { - return SLANG_FAIL; +#ifdef GFX_NVAPI + if (m_nvapi) + { + // Also fill the extension structure. + // Use the same UAV slot index and register space that are declared in the shader. + + // For simplicities sake we just use u0 + NVAPI_D3D12_PSO_SET_SHADER_EXTENSION_SLOT_DESC extensionDesc; + extensionDesc.baseVersion = NV_PSO_EXTENSION_DESC_VER; + extensionDesc.version = NV_SET_SHADER_EXTENSION_SLOT_DESC_VER; + extensionDesc.uavSlot = 0; + extensionDesc.registerSpace = 0; + + // Put the pointer to the extension into an array - there can be multiple extensions + // enabled at once. + const NVAPI_D3D12_PSO_EXTENSION_DESC* extensions[] = {&extensionDesc}; + + // Now create the PSO. + const NvAPI_Status nvapiStatus = NvAPI_D3D12_CreateComputePipelineState( + m_device, + &computeDesc, + SLANG_COUNT_OF(extensions), + extensions, + pipelineState.writeRef()); + + if (nvapiStatus != NVAPI_OK) + { + return SLANG_FAIL; + } } - } - else + else #endif - { - SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState(&computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + { + SLANG_RETURN_ON_FAIL(m_device->CreateComputePipelineState( + &computeDesc, IID_PPV_ARGS(pipelineState.writeRef()))); + } } - RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(); pipelineStateImpl->m_pipelineLayout = pipelineLayoutImpl; pipelineStateImpl->m_pipelineState = pipelineState; diff --git a/tools/gfx/open-gl/render-gl.cpp b/tools/gfx/open-gl/render-gl.cpp index b84db44b6..03736cfa4 100644 --- a/tools/gfx/open-gl/render-gl.cpp +++ b/tools/gfx/open-gl/render-gl.cpp @@ -1489,6 +1489,15 @@ SLANG_NO_THROW Result SLANG_MCALL Result GLRenderer::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) { + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + // For a specializable program, we don't invoke any actual slang compilation yet. + RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl(m_weakRenderer, 0); + initProgramCommon(shaderProgram, desc); + *outProgram = shaderProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); diff --git a/tools/gfx/render-graphics-common.cpp b/tools/gfx/render-graphics-common.cpp index fb01867d8..5f083538d 100644 --- a/tools/gfx/render-graphics-common.cpp +++ b/tools/gfx/render-graphics-common.cpp @@ -724,8 +724,6 @@ public: return SLANG_E_INVALID_ARG; auto subObject = static_cast<GraphicsCommonShaderObject*>(object); - if (!subObject->m_bindingFinalized) - return SLANG_E_INVALID_ARG; auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex); @@ -815,8 +813,6 @@ public: // Appends all types that are used to specialize the element type of this shader object in `args` list. virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override { - if (!m_bindingFinalized) - return SLANG_FAIL; auto& subObjectRanges = getLayout()->getSubObjectRanges(); // The following logic is built on the assumption that all fields that involve existential types (and // therefore require specialization) will results in a sub-object range in the type layout. @@ -914,7 +910,7 @@ protected: // In the case where the sub-object range represents an // existential-type leaf field (e.g., an `IBar`), we - // cannot pre-allocate the objet(s) to go into that + // cannot pre-allocate the object(s) to go into that // range, since we can't possibly know what to allocate // at this point. // @@ -1218,6 +1214,15 @@ public: return SLANG_OK; } + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + SLANG_RETURN_ON_FAIL(GraphicsCommonShaderObject::collectSpecializationArgs(args)); + for (auto& entryPoint : m_entryPoints) + { + SLANG_RETURN_ON_FAIL(entryPoint->collectSpecializationArgs(args)); + } + return SLANG_OK; + } protected: virtual Result _bindIntoDescriptorSets(ComPtr<IDescriptorSet>* descriptorSets) override diff --git a/tools/gfx/render.h b/tools/gfx/render.h index 13af56550..e783273ee 100644 --- a/tools/gfx/render.h +++ b/tools/gfx/render.h @@ -282,7 +282,8 @@ public: case Usage::DepthWrite: return BindFlag::DepthStencil; case Usage::UnorderedAccess: - return BindFlag::UnorderedAccess; + return BindFlag::Enum(BindFlag::UnorderedAccess | BindFlag::PixelShaderResource | + BindFlag::NonPixelShaderResource); case Usage::PixelShaderResource: return BindFlag::PixelShaderResource; case Usage::NonPixelShaderResource: @@ -880,7 +881,6 @@ public: setSampler(ShaderOffset const& offset, ISamplerState* sampler) = 0; virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler( ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) = 0; - virtual SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() = 0; }; #define SLANG_UUID_IShaderObject \ { \ diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index a2e1751fd..741b1eba5 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -319,7 +319,7 @@ void ShaderCache::addShaderBinary(ShaderComponentID componentId, ShaderBinary* b shaderBinaries[componentId] = binary; } -void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline) +void ShaderCache::addSpecializedPipeline(PipelineKey key, Slang::ComPtr<IPipelineState> specializedPipeline) { specializedPipelines[key] = specializedPipeline; } @@ -377,21 +377,8 @@ void ShaderObjectLayoutBase::initBase(RendererBase* renderer, slang::TypeLayoutR m_componentID = m_renderer->shaderCache.getComponentId(m_elementTypeLayout->getType()); } -SLANG_NO_THROW Result SLANG_MCALL ShaderObjectBase::finalizeBindings() -{ - m_bindingFinalized = true; - - // With all binding fixed, the shader object's type can be determined by specializing the - // shader object's type with the type of bound sub objects. - // Now obtain a componentID for the specialized shader object type from the shader cache. - SLANG_RETURN_ON_FAIL(getSpecializedShaderObjectType(&shaderObjectType)); - return SLANG_OK; -} - - // Get the final type this shader object represents. If the shader object's type has existential fields, // this function will return a specialized type using the bound sub-objects' type as specialization argument. - Result ShaderObjectBase::getSpecializedShaderObjectType(ExtendedShaderObjectType* outType) { if (shaderObjectType.slangType) @@ -432,7 +419,7 @@ Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject) pipelineKey.specializationArgs.addRange(specializationArgs.componentIDs); pipelineKey.updateHash(); - auto specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey); + ComPtr<gfx::IPipelineState> specializedPipelineState = shaderCache.getSpecializedPipelineState(pipelineKey); // Try to find specialized pipeline from shader cache. if (!specializedPipelineState) { @@ -493,6 +480,7 @@ Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject) IShaderProgram::Desc specializedProgramDesc = {}; specializedProgramDesc.kernelCount = unspecializedProgramLayout->getEntryPointCount(); ShortList<IShaderProgram::KernelDesc> kernelDescs; + kernelDescs.setCount(entryPointBinaries.getCount()); for (Slang::Index i = 0; i < entryPointBinaries.getCount(); i++) { auto entryPoint = unspecializedProgramLayout->getEntryPointByIndex(i);; @@ -507,7 +495,6 @@ Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject) SLANG_RETURN_ON_FAIL(createProgram(specializedProgramDesc, specializedProgram.writeRef())); // Create specialized pipeline state. - ComPtr<IPipelineState> specializedPipelineState; switch (pipelineType) { case PipelineType::Compute: @@ -529,7 +516,7 @@ Result RendererBase::maybeSpecializePipeline(ShaderObjectBase* rootObject) } auto specializedPipelineStateBase = static_cast<PipelineStateBase*>(specializedPipelineState.get()); specializedPipelineStateBase->unspecializedPipelineState = currentPipeline; - shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineStateBase); + shaderCache.addSpecializedPipeline(pipelineKey, specializedPipelineState); } setPipelineState(specializedPipelineState); } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index c4a000e87..26ed5b040 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -161,9 +161,6 @@ protected: // The shader object layout used to create this shader object. Slang::RefPtr<ShaderObjectLayoutBase> m_layout = nullptr; - // Indicates whether all bindings have been finalized. - bool m_bindingFinalized = false; - // The specialized shader object type. ExtendedShaderObjectType shaderObjectType = { nullptr, kInvalidComponentID }; public: @@ -201,8 +198,6 @@ public: return m_layout->getElementTypeLayout(); } - SLANG_NO_THROW Result SLANG_MCALL finalizeBindings() SLANG_OVERRIDE; - virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) = 0; }; @@ -338,20 +333,21 @@ public: void init(ISlangFileSystem* cacheFileSystem); void writeToFileSystem(ISlangMutableFileSystem* outputFileSystem); - Slang::RefPtr<PipelineStateBase> getSpecializedPipelineState(PipelineKey programKey) + Slang::ComPtr<IPipelineState> getSpecializedPipelineState(PipelineKey programKey) { - Slang::RefPtr<PipelineStateBase> result; + Slang::ComPtr<IPipelineState> result; if (specializedPipelines.TryGetValue(programKey, result)) return result; return nullptr; } Slang::RefPtr<ShaderBinary> tryLoadShaderBinary(ShaderComponentID componentId); void addShaderBinary(ShaderComponentID componentId, ShaderBinary* binary); - void addSpecializedPipeline(PipelineKey key, Slang::RefPtr<PipelineStateBase> specializedPipeline); + void addSpecializedPipeline(PipelineKey key, Slang::ComPtr<IPipelineState> specializedPipeline); + protected: Slang::ComPtr<ISlangFileSystem> fileSystem; Slang::OrderedDictionary<OwningComponentKey, ShaderComponentID> componentIds; - Slang::OrderedDictionary<PipelineKey, Slang::RefPtr<PipelineStateBase>> specializedPipelines; + Slang::OrderedDictionary<PipelineKey, Slang::ComPtr<IPipelineState>> specializedPipelines; Slang::OrderedDictionary<ShaderComponentID, Slang::RefPtr<ShaderBinary>> shaderBinaries; }; diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 565a8e96f..f58a81b8d 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -2841,6 +2841,15 @@ void VKRenderer::setDescriptorSet(PipelineType pipelineType, IPipelineLayout* la Result VKRenderer::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) { + if (desc.slangProgram && desc.slangProgram->getSpecializationParamCount() != 0) + { + // For a specializable program, we don't invoke any actual slang compilation yet. + RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl(m_api, desc.pipelineType); + initProgramCommon(shaderProgram, desc); + *outProgram = shaderProgram.detach(); + return SLANG_OK; + } + if( desc.kernelCount == 0 ) { return createProgramFromSlang(this, desc, outProgram); |
