summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-07-28 12:24:12 -0700
committerGitHub <noreply@github.com>2021-07-28 12:24:12 -0700
commitc6f6ce12ec522b193b42bcd12d3a2540c7a6ff92 (patch)
treed5f77aa02df88c71ef4f898db40434bf4c1f3010
parent23d406f8a3b325f91fecd9ad52bd510ded5f49a7 (diff)
Experimental DXR1.0 support in gfx. (#1915)
* Experimental DXR1.0 support in gfx. - Add `dispatchRays` command. - Add `createRayTracingPipelineState` method to construct a D3D ray tracing state object from a linked slang program and user specified shader table. Limitations/simplifications: no local root signature support, shader table entries contains only shader identifiers and is specified at pipeline creation time, owned by the pipeline state object. * Root object binding for raytracing pipelines. * `maybeSpecializePipeline` implementation for raytracing pipelines. * Add ray-tracing-pipeline example. * Fixes. * Update README.md * Update comments on the lifespan of specialized pipelines Co-authored-by: Yong He <yhe@nvidia.com> Co-authored-by: jsmall-nvidia <jsmall@nvidia.com>
-rw-r--r--build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj193
-rw-r--r--build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters18
-rw-r--r--examples/ray-tracing-pipeline/README.md9
-rw-r--r--examples/ray-tracing-pipeline/main.cpp665
-rw-r--r--examples/ray-tracing-pipeline/shaders.slang108
-rw-r--r--premake5.lua1
-rw-r--r--slang-gfx.h23
-rw-r--r--slang.sln11
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp381
-rw-r--r--tools/gfx/debug-layer.cpp20
-rw-r--r--tools/gfx/debug-layer.h7
-rw-r--r--tools/gfx/renderer-shared.cpp8
-rw-r--r--tools/gfx/renderer-shared.h18
-rw-r--r--tools/gfx/vulkan/render-vk.cpp27
14 files changed, 1457 insertions, 32 deletions
diff --git a/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj
new file mode 100644
index 000000000..b439eeb84
--- /dev/null
+++ b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj
@@ -0,0 +1,193 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|Win32">
+ <Configuration>Debug</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|Win32">
+ <Configuration>Release</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <ProjectGuid>{17BA8E32-034E-84DA-6C12-DE8E58C5BECC}</ProjectGuid>
+ <IgnoreWarnCompileDuplicatedFilename>true</IgnoreWarnCompileDuplicatedFilename>
+ <Keyword>Win32Proj</Keyword>
+ <RootNamespace>ray-tracing-pipeline</RootNamespace>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <CharacterSet>Unicode</CharacterSet>
+ <PlatformToolset>v142</PlatformToolset>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <CharacterSet>Unicode</CharacterSet>
+ <PlatformToolset>v142</PlatformToolset>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <CharacterSet>Unicode</CharacterSet>
+ <PlatformToolset>v142</PlatformToolset>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <CharacterSet>Unicode</CharacterSet>
+ <PlatformToolset>v142</PlatformToolset>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <LinkIncremental>true</LinkIncremental>
+ <OutDir>..\..\..\bin\windows-x86\debug\</OutDir>
+ <IntDir>..\..\..\intermediate\windows-x86\debug\ray-tracing-pipeline\</IntDir>
+ <TargetName>ray-tracing-pipeline</TargetName>
+ <TargetExt>.exe</TargetExt>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <LinkIncremental>true</LinkIncremental>
+ <OutDir>..\..\..\bin\windows-x64\debug\</OutDir>
+ <IntDir>..\..\..\intermediate\windows-x64\debug\ray-tracing-pipeline\</IntDir>
+ <TargetName>ray-tracing-pipeline</TargetName>
+ <TargetExt>.exe</TargetExt>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <LinkIncremental>false</LinkIncremental>
+ <OutDir>..\..\..\bin\windows-x86\release\</OutDir>
+ <IntDir>..\..\..\intermediate\windows-x86\release\ray-tracing-pipeline\</IntDir>
+ <TargetName>ray-tracing-pipeline</TargetName>
+ <TargetExt>.exe</TargetExt>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <LinkIncremental>false</LinkIncremental>
+ <OutDir>..\..\..\bin\windows-x64\release\</OutDir>
+ <IntDir>..\..\..\intermediate\windows-x64\release\ray-tracing-pipeline\</IntDir>
+ <TargetName>ray-tracing-pipeline</TargetName>
+ <TargetExt>.exe</TargetExt>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <DebugInformationFormat>EditAndContinue</DebugInformationFormat>
+ <Optimization>Disabled</Optimization>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <PreprocessorDefinitions>_DEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <DebugInformationFormat>EditAndContinue</DebugInformationFormat>
+ <Optimization>Disabled</Optimization>
+ <RuntimeLibrary>MultiThreadedDebug</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <Optimization>Full</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <MinimalRebuild>false</MinimalRebuild>
+ <StringPooling>true</StringPooling>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <ClCompile>
+ <PrecompiledHeader>NotUsing</PrecompiledHeader>
+ <WarningLevel>Level3</WarningLevel>
+ <PreprocessorDefinitions>NDEBUG;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <AdditionalIncludeDirectories>..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
+ <Optimization>Full</Optimization>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <MinimalRebuild>false</MinimalRebuild>
+ <StringPooling>true</StringPooling>
+ <RuntimeLibrary>MultiThreaded</RuntimeLibrary>
+ </ClCompile>
+ <Link>
+ <SubSystem>Windows</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="..\..\..\examples\ray-tracing-pipeline\main.cpp" />
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="..\..\..\examples\ray-tracing-pipeline\shaders.slang" />
+ </ItemGroup>
+ <ItemGroup>
+ <ProjectReference Include="..\example-base\example-base.vcxproj">
+ <Project>{37BED5B5-23FA-D81F-8C0C-F1167867813A}</Project>
+ </ProjectReference>
+ <ProjectReference Include="..\slang\slang.vcxproj">
+ <Project>{DB00DA62-0533-4AFD-B59F-A67D5B3A0808}</Project>
+ </ProjectReference>
+ <ProjectReference Include="..\gfx\gfx.vcxproj">
+ <Project>{222F7498-B40C-4F3F-A704-DDEB91A4484A}</Project>
+ </ProjectReference>
+ <ProjectReference Include="..\gfx-util\gfx-util.vcxproj">
+ <Project>{F5ADB74E-02A7-44FB-AA3B-FC02F8AC7A4B}</Project>
+ </ProjectReference>
+ <ProjectReference Include="..\platform\platform.vcxproj">
+ <Project>{3565FE5E-4FA3-11EB-AE93-0242AC130002}</Project>
+ </ProjectReference>
+ <ProjectReference Include="..\core\core.vcxproj">
+ <Project>{F9BE7957-8399-899E-0C49-E714FDDD4B65}</Project>
+ </ProjectReference>
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters
new file mode 100644
index 000000000..650faecbb
--- /dev/null
+++ b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters
@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <Filter Include="Source Files">
+ <UniqueIdentifier>{E9C7FDCE-D52A-8D73-7EB0-C5296AF258F6}</UniqueIdentifier>
+ </Filter>
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="..\..\..\examples\ray-tracing-pipeline\main.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ </ItemGroup>
+ <ItemGroup>
+ <None Include="..\..\..\examples\ray-tracing-pipeline\shaders.slang">
+ <Filter>Source Files</Filter>
+ </None>
+ </ItemGroup>
+</Project> \ No newline at end of file
diff --git a/examples/ray-tracing-pipeline/README.md b/examples/ray-tracing-pipeline/README.md
new file mode 100644
index 000000000..48cec4c18
--- /dev/null
+++ b/examples/ray-tracing-pipeline/README.md
@@ -0,0 +1,9 @@
+Slang "Ray Tracing Pipeline" Example
+======================================
+
+The goal of this example is to demonstrate how to write shaders for ray-tracing pipelines in Slang.
+
+The `shaders.slang` file contains a set of ray-tracing shader entry-points that traces primary rays from camera and shade intersections with basic lighting + ray-traced shadows. The file also defines a vertex and a fragment shader entry point for displaying the ray-traced image produced by the compute shader.
+
+The `main.cpp` file contains the C++ application code, showing how to use the Slang API to load and compile the shader code, and how to use a graphics API abstraction layer implemented in `tools/gfx` to set-up and use ray-tracing pipelines (DXR 1.0 equivalent API).
+Note that this abstraction layer is *not* required in order to work with Slang, and it is just there to help us write example and test applications more conveniently.
diff --git a/examples/ray-tracing-pipeline/main.cpp b/examples/ray-tracing-pipeline/main.cpp
new file mode 100644
index 000000000..3c83447b4
--- /dev/null
+++ b/examples/ray-tracing-pipeline/main.cpp
@@ -0,0 +1,665 @@
+// main.cpp
+
+// This file implements an example of hardware ray-tracing using
+// Slang shaders and the `gfx` graphics API.
+
+#include <slang.h>
+#include "slang-gfx.h"
+#include "gfx-util/shader-cursor.h"
+#include "tools/platform/window.h"
+#include "tools/platform/vector-math.h"
+#include "slang-com-ptr.h"
+#include "source/core/slang-basic.h"
+#include "examples/example-base/example-base.h"
+
+using namespace gfx;
+using namespace Slang;
+
+struct Uniforms
+{
+ float screenWidth, screenHeight;
+ float focalLength = 24.0f, frameHeight = 24.0f;
+ float cameraDir[4];
+ float cameraUp[4];
+ float cameraRight[4];
+ float cameraPosition[4];
+ float lightDir[4];
+};
+
+struct Vertex
+{
+ float position[3];
+};
+
+// Define geometry data for our test scene.
+// The scene contains a floor plane, and a cube placed on top of it at the center.
+static const int kVertexCount = 24;
+static const Vertex kVertexData[kVertexCount] =
+{
+ // Floor plane
+ {{-100.0f, 0, 100.0f}},
+ {{100.0f, 0, 100.0f}},
+ {{100.0f, 0, -100.0f}},
+ {{-100.0f, 0, -100.0f}},
+ // Cube face (+y).
+ {{-1.0f, 2.0, 1.0f}},
+ {{1.0f, 2.0, 1.0f}},
+ {{1.0f, 2.0, -1.0f}},
+ {{-1.0f, 2.0, -1.0f}},
+ // Cube face (+z).
+ {{-1.0f, 0.0, 1.0f}},
+ {{1.0f, 0.0, 1.0f}},
+ {{1.0f, 2.0, 1.0f}},
+ {{-1.0f, 2.0, 1.0f}},
+ // Cube face (-z).
+ {{-1.0f, 0.0, -1.0f}},
+ {{-1.0f, 2.0, -1.0f}},
+ {{1.0f, 2.0, -1.0f}},
+ {{1.0f, 0.0, -1.0f}},
+ // Cube face (-x).
+ {{-1.0f, 0.0, -1.0f}},
+ {{-1.0f, 0.0, 1.0f}},
+ {{-1.0f, 2.0, 1.0f}},
+ {{-1.0f, 2.0, -1.0f}},
+ // Cube face (+x).
+ {{1.0f, 2.0, -1.0f}},
+ {{1.0f, 2.0, 1.0f}},
+ {{1.0f, 0.0, 1.0f}},
+ {{1.0f, 0.0, -1.0f}},
+};
+static const int kIndexCount = 36;
+static const int kIndexData[kIndexCount] =
+{
+ 0, 1, 2, 0, 2, 3,
+ 4, 5, 6, 4, 6, 7,
+ 8, 9, 10, 8, 10, 11,
+ 12, 13, 14, 12, 14, 15,
+ 16, 17, 18, 16, 18, 19,
+ 20, 21, 22, 20, 22, 23
+};
+
+struct Primitive
+{
+ float data[4];
+ float color[4];
+};
+static const int kPrimitiveCount = 12;
+static const Primitive kPrimitiveData[kPrimitiveCount] =
+{
+ {{0.0f, 1.0f, 0.0f, 0.0f}, {0.75f, 0.8f, 0.85f, 1.0f}},
+ {{0.0f, 1.0f, 0.0f, 0.0f}, {0.75f, 0.8f, 0.85f, 1.0f}},
+ {{0.0f, 1.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{0.0f, 1.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{0.0f, 0.0f, 1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{0.0f, 0.0f, 1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{0.0f, 0.0f, -1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{0.0f, 0.0f, -1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{-1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{-1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+ {{1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}},
+};
+
+
+// We need to use a rasterization pipeline to copy the ray-traced image
+// to the swapchain. To do so we need to render a full-screen triangle.
+// We will define a small helper type that defines the data for such a triangle.
+//
+struct FullScreenTriangle
+{
+ struct Vertex
+ {
+ float position[2];
+ };
+
+ enum
+ {
+ kVertexCount = 3
+ };
+
+ static const Vertex kVertices[kVertexCount];
+};
+const FullScreenTriangle::Vertex FullScreenTriangle::kVertices[FullScreenTriangle::kVertexCount] = {
+ {{-1, -1}},
+ {{-1, 3}},
+ {{3, -1}},
+};
+
+// The example application will be implemented as a `struct`, so that
+// we can scope the resources it allocates without using global variables.
+//
+struct RayTracing : public WindowedAppBase
+{
+
+
+Uniforms gUniforms = {};
+
+
+// Many Slang API functions return detailed diagnostic information
+// (error messages, warnings, etc.) as a "blob" of data, or return
+// a null blob pointer instead if there were no issues.
+//
+// For convenience, we define a subroutine that will dump the information
+// in a diagnostic blob if one is produced, and skip it otherwise.
+//
+void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob)
+{
+ if( diagnosticsBlob != nullptr )
+ {
+ printf("%s", (const char*) diagnosticsBlob->getBufferPointer());
+#ifdef _WIN32
+ _Win32OutputDebugString((const char*)diagnosticsBlob->getBufferPointer());
+#endif
+ }
+}
+
+// Load and compile shader code from souce.
+gfx::Result loadShaderProgram(
+ gfx::IDevice* device,
+ gfx::PipelineType pipelineType,
+ gfx::IShaderProgram** outProgram)
+{
+ ComPtr<slang::ISession> slangSession;
+ slangSession = device->getSlangSession();
+
+ ComPtr<slang::IBlob> diagnosticsBlob;
+ slang::IModule* module = slangSession->loadModule("shaders", diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ if(!module)
+ return SLANG_FAIL;
+
+ Slang::List<slang::IComponentType*> componentTypes;
+ componentTypes.add(module);
+ if (pipelineType == PipelineType::RayTracing)
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShader", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShader", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(
+ module->findEntryPointByName("closestHitShader", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(
+ module->findEntryPointByName("shadowRayHitShader", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ }
+ else
+ {
+ ComPtr<slang::IEntryPoint> entryPoint;
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("vertexMain", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ SLANG_RETURN_ON_FAIL(module->findEntryPointByName("fragmentMain", entryPoint.writeRef()));
+ componentTypes.add(entryPoint);
+ }
+
+ ComPtr<slang::IComponentType> linkedProgram;
+ SlangResult result = slangSession->createCompositeComponentType(
+ componentTypes.getBuffer(),
+ componentTypes.getCount(),
+ linkedProgram.writeRef(),
+ diagnosticsBlob.writeRef());
+ diagnoseIfNeeded(diagnosticsBlob);
+ SLANG_RETURN_ON_FAIL(result);
+
+ gfx::IShaderProgram::Desc programDesc = {};
+ programDesc.pipelineType = pipelineType;
+ programDesc.slangProgram = linkedProgram;
+ SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, outProgram));
+
+ return SLANG_OK;
+}
+
+ComPtr<gfx::IPipelineState> gPresentPipelineState;
+ComPtr<gfx::IPipelineState> gRenderPipelineState;
+ComPtr<gfx::IBufferResource> gFullScreenVertexBuffer;
+ComPtr<gfx::IBufferResource> gVertexBuffer;
+ComPtr<gfx::IBufferResource> gIndexBuffer;
+ComPtr<gfx::IBufferResource> gPrimitiveBuffer;
+ComPtr<gfx::IBufferResource> gTransformBuffer;
+ComPtr<gfx::IResourceView> gPrimitiveBufferSRV;
+ComPtr<gfx::IBufferResource> gInstanceBuffer;
+ComPtr<gfx::IBufferResource> gBLASBuffer;
+ComPtr<gfx::IAccelerationStructure> gBLAS;
+ComPtr<gfx::IBufferResource> gTLASBuffer;
+ComPtr<gfx::IAccelerationStructure> gTLAS;
+ComPtr<gfx::ITextureResource> gResultTexture;
+ComPtr<gfx::IResourceView> gResultTextureUAV;
+
+uint64_t lastTime = 0;
+
+// glm::vec3 lightDir = normalize(glm::vec3(10, 10, 10));
+// glm::vec3 lightColor = glm::vec3(1, 1, 1);
+
+glm::vec3 cameraPosition = glm::vec3(-2.53f, 2.72f, 4.3f);
+float cameraOrientationAngles[2] = {-0.475f, -0.35f}; // Spherical angles (theta, phi).
+
+float translationScale = 0.5f;
+float rotationScale = 0.01f;
+
+// In order to control camera movement, we will
+// use good old WASD
+bool wPressed = false;
+bool aPressed = false;
+bool sPressed = false;
+bool dPressed = false;
+
+bool isMouseDown = false;
+float lastMouseX = 0.0f;
+float lastMouseY = 0.0f;
+
+void setKeyState(platform::KeyCode key, bool state)
+{
+ switch (key)
+ {
+ default:
+ break;
+ case platform::KeyCode::W:
+ wPressed = state;
+ break;
+ case platform::KeyCode::A:
+ aPressed = state;
+ break;
+ case platform::KeyCode::S:
+ sPressed = state;
+ break;
+ case platform::KeyCode::D:
+ dPressed = state;
+ break;
+ }
+}
+void onKeyDown(platform::KeyEventArgs args) { setKeyState(args.key, true); }
+void onKeyUp(platform::KeyEventArgs args) { setKeyState(args.key, false); }
+
+void onMouseDown(platform::MouseEventArgs args)
+{
+ isMouseDown = true;
+ lastMouseX = (float)args.x;
+ lastMouseY = (float)args.y;
+}
+
+void onMouseMove(platform::MouseEventArgs args)
+{
+ if (isMouseDown)
+ {
+ float deltaX = args.x - lastMouseX;
+ float deltaY = args.y - lastMouseY;
+
+ cameraOrientationAngles[0] += -deltaX * rotationScale;
+ cameraOrientationAngles[1] += -deltaY * rotationScale;
+ lastMouseX = (float)args.x;
+ lastMouseY = (float)args.y;
+ }
+}
+void onMouseUp(platform::MouseEventArgs args) { isMouseDown = false; }
+
+Slang::Result initialize()
+{
+ initializeBase("Ray Tracing Pipeline", 1024, 768);
+ gWindow->events.mouseMove = [this](const platform::MouseEventArgs& e) { onMouseMove(e); };
+ gWindow->events.mouseUp = [this](const platform::MouseEventArgs& e) { onMouseUp(e); };
+ gWindow->events.mouseDown = [this](const platform::MouseEventArgs& e) { onMouseDown(e); };
+ gWindow->events.keyDown = [this](const platform::KeyEventArgs& e) { onKeyDown(e); };
+ gWindow->events.keyUp = [this](const platform::KeyEventArgs& e) { onKeyUp(e); };
+
+ IBufferResource::Desc vertexBufferDesc;
+ vertexBufferDesc.type = IResource::Type::Buffer;
+ vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex);
+ vertexBufferDesc.defaultState = ResourceState::ShaderResource;
+ gVertexBuffer = gDevice->createBufferResource(vertexBufferDesc, &kVertexData[0]);
+ if(!gVertexBuffer) return SLANG_FAIL;
+
+ IBufferResource::Desc indexBufferDesc;
+ indexBufferDesc.type = IResource::Type::Buffer;
+ indexBufferDesc.sizeInBytes = kIndexCount * sizeof(int32_t);
+ indexBufferDesc.defaultState = ResourceState::ShaderResource;
+ gIndexBuffer = gDevice->createBufferResource(indexBufferDesc, &kIndexData[0]);
+ if (!gIndexBuffer)
+ return SLANG_FAIL;
+
+ IBufferResource::Desc primitiveBufferDesc;
+ primitiveBufferDesc.type = IResource::Type::Buffer;
+ primitiveBufferDesc.sizeInBytes = kPrimitiveCount * sizeof(Primitive);
+ primitiveBufferDesc.defaultState = ResourceState::ShaderResource;
+ gPrimitiveBuffer = gDevice->createBufferResource(primitiveBufferDesc, &kPrimitiveData[0]);
+ if (!gPrimitiveBuffer)
+ return SLANG_FAIL;
+
+ IResourceView::Desc primitiveSRVDesc = {};
+ primitiveSRVDesc.format = Format::Unknown;
+ primitiveSRVDesc.type = IResourceView::Type::ShaderResource;
+ gPrimitiveBufferSRV = gDevice->createBufferView(gPrimitiveBuffer, primitiveSRVDesc);
+
+ IBufferResource::Desc transformBufferDesc;
+ transformBufferDesc.type = IResource::Type::Buffer;
+ transformBufferDesc.sizeInBytes = sizeof(float) * 12;
+ transformBufferDesc.defaultState = ResourceState::ShaderResource;
+ float transformData[12] = {
+ 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f};
+ gTransformBuffer = gDevice->createBufferResource(transformBufferDesc, &transformData);
+ if (!gTransformBuffer)
+ return SLANG_FAIL;
+ // Build bottom level acceleration structure.
+ {
+ IAccelerationStructure::BuildInputs accelerationStructureBuildInputs;
+ IAccelerationStructure::PrebuildInfo accelerationStructurePrebuildInfo;
+ accelerationStructureBuildInputs.descCount = 1;
+ accelerationStructureBuildInputs.kind = IAccelerationStructure::Kind::BottomLevel;
+ accelerationStructureBuildInputs.flags =
+ IAccelerationStructure::BuildFlags::AllowCompaction;
+ IAccelerationStructure::GeometryDesc geomDesc;
+ geomDesc.flags = IAccelerationStructure::GeometryFlags::Opaque;
+ geomDesc.type = IAccelerationStructure::GeometryType::Triangles;
+ geomDesc.content.triangles.indexCount = kIndexCount;
+ geomDesc.content.triangles.indexData = gIndexBuffer->getDeviceAddress();
+ geomDesc.content.triangles.indexFormat = Format::R_UInt32;
+ geomDesc.content.triangles.vertexCount = kVertexCount;
+ geomDesc.content.triangles.vertexData = gVertexBuffer->getDeviceAddress();
+ geomDesc.content.triangles.vertexFormat = Format::RGB_Float32;
+ geomDesc.content.triangles.vertexStride = sizeof(Vertex);
+ geomDesc.content.triangles.transform3x4 = gTransformBuffer->getDeviceAddress();
+ accelerationStructureBuildInputs.geometryDescs = &geomDesc;
+
+ // Query buffer size for acceleration structure build.
+ SLANG_RETURN_ON_FAIL(gDevice->getAccelerationStructurePrebuildInfo(
+ accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo));
+ // Allocate buffers for acceleration structure.
+ IBufferResource::Desc asDraftBufferDesc;
+ asDraftBufferDesc.type = IResource::Type::Buffer;
+ asDraftBufferDesc.defaultState = ResourceState::AccelerationStructure;
+ asDraftBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize;
+ ComPtr<IBufferResource> draftBuffer = gDevice->createBufferResource(asDraftBufferDesc);
+ IBufferResource::Desc scratchBufferDesc;
+ scratchBufferDesc.type = IResource::Type::Buffer;
+ scratchBufferDesc.defaultState = ResourceState::UnorderedAccess;
+ scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize;
+ ComPtr<IBufferResource> scratchBuffer = gDevice->createBufferResource(scratchBufferDesc);
+
+ // Build acceleration structure.
+ ComPtr<IQueryPool> compactedSizeQuery;
+ IQueryPool::Desc queryPoolDesc;
+ queryPoolDesc.count = 1;
+ queryPoolDesc.type = QueryType::AccelerationStructureCompactedSize;
+ SLANG_RETURN_ON_FAIL(
+ gDevice->createQueryPool(queryPoolDesc, compactedSizeQuery.writeRef()));
+
+ ComPtr<IAccelerationStructure> draftAS;
+ IAccelerationStructure::CreateDesc draftCreateDesc;
+ draftCreateDesc.buffer = draftBuffer;
+ draftCreateDesc.kind = IAccelerationStructure::Kind::BottomLevel;
+ draftCreateDesc.offset = 0;
+ draftCreateDesc.size = accelerationStructurePrebuildInfo.resultDataMaxSize;
+ SLANG_RETURN_ON_FAIL(
+ gDevice->createAccelerationStructure(draftCreateDesc, draftAS.writeRef()));
+
+ auto commandBuffer = gTransientHeaps[0]->createCommandBuffer();
+ auto encoder = commandBuffer->encodeRayTracingCommands();
+ IAccelerationStructure::BuildDesc buildDesc = {};
+ buildDesc.dest = draftAS;
+ buildDesc.inputs = accelerationStructureBuildInputs;
+ buildDesc.scratchData = scratchBuffer->getDeviceAddress();
+ AccelerationStructureQueryDesc compactedSizeQueryDesc = {};
+ compactedSizeQueryDesc.queryPool = compactedSizeQuery;
+ compactedSizeQueryDesc.queryType = QueryType::AccelerationStructureCompactedSize;
+ encoder->buildAccelerationStructure(buildDesc, 1, &compactedSizeQueryDesc);
+ encoder->endEncoding();
+ commandBuffer->close();
+ gQueue->executeCommandBuffer(commandBuffer);
+ gQueue->wait();
+
+ uint64_t compactedSize = 0;
+ compactedSizeQuery->getResult(0, 1, &compactedSize);
+ IBufferResource::Desc asBufferDesc;
+ asBufferDesc.type = IResource::Type::Buffer;
+ asBufferDesc.defaultState = ResourceState::AccelerationStructure;
+ asBufferDesc.sizeInBytes = compactedSize;
+ gBLASBuffer = gDevice->createBufferResource(asBufferDesc);
+ IAccelerationStructure::CreateDesc createDesc;
+ createDesc.buffer = gBLASBuffer;
+ createDesc.kind = IAccelerationStructure::Kind::BottomLevel;
+ createDesc.offset = 0;
+ createDesc.size = compactedSize;
+ gDevice->createAccelerationStructure(createDesc, gBLAS.writeRef());
+
+ commandBuffer = gTransientHeaps[0]->createCommandBuffer();
+ encoder = commandBuffer->encodeRayTracingCommands();
+ encoder->copyAccelerationStructure(gBLAS, draftAS, AccelerationStructureCopyMode::Compact);
+ encoder->endEncoding();
+ commandBuffer->close();
+ gQueue->executeCommandBuffer(commandBuffer);
+ gQueue->wait();
+ }
+
+ // Build top level acceleration structure.
+ {
+ List<IAccelerationStructure::InstanceDesc> instanceDescs;
+ instanceDescs.setCount(1);
+ instanceDescs[0].accelerationStructure = gBLAS->getDeviceAddress();
+ instanceDescs[0].flags =
+ IAccelerationStructure::GeometryInstanceFlags::TriangleFacingCullDisable;
+ instanceDescs[0].instanceContributionToHitGroupIndex = 0;
+ instanceDescs[0].instanceID = 0;
+ instanceDescs[0].instanceMask = 0xFF;
+ float transformMatrix[] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f};
+ memcpy(&instanceDescs[0].transform[0][0], transformMatrix, sizeof(float) * 12);
+
+ IBufferResource::Desc instanceBufferDesc;
+ instanceBufferDesc.type = IResource::Type::Buffer;
+ instanceBufferDesc.sizeInBytes =
+ instanceDescs.getCount() * sizeof(IAccelerationStructure::InstanceDesc);
+ instanceBufferDesc.defaultState = ResourceState::ShaderResource;
+ gInstanceBuffer = gDevice->createBufferResource(instanceBufferDesc, instanceDescs.getBuffer());
+ if (!gInstanceBuffer)
+ return SLANG_FAIL;
+
+ IAccelerationStructure::BuildInputs accelerationStructureBuildInputs = {};
+ IAccelerationStructure::PrebuildInfo accelerationStructurePrebuildInfo = {};
+ accelerationStructureBuildInputs.descCount = 1;
+ accelerationStructureBuildInputs.kind = IAccelerationStructure::Kind::TopLevel;
+ accelerationStructureBuildInputs.instanceDescs = gInstanceBuffer->getDeviceAddress();
+
+ // Query buffer size for acceleration structure build.
+ SLANG_RETURN_ON_FAIL(gDevice->getAccelerationStructurePrebuildInfo(
+ accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo));
+
+ IBufferResource::Desc asBufferDesc;
+ asBufferDesc.type = IResource::Type::Buffer;
+ asBufferDesc.defaultState = ResourceState::AccelerationStructure;
+ asBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize;
+ gTLASBuffer = gDevice->createBufferResource(asBufferDesc);
+
+ IBufferResource::Desc scratchBufferDesc;
+ scratchBufferDesc.type = IResource::Type::Buffer;
+ scratchBufferDesc.defaultState = ResourceState::UnorderedAccess;
+ scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize;
+ ComPtr<IBufferResource> scratchBuffer = gDevice->createBufferResource(scratchBufferDesc);
+
+ IAccelerationStructure::CreateDesc createDesc;
+ createDesc.buffer = gTLASBuffer;
+ createDesc.kind = IAccelerationStructure::Kind::TopLevel;
+ createDesc.offset = 0;
+ createDesc.size = accelerationStructurePrebuildInfo.resultDataMaxSize;
+ SLANG_RETURN_ON_FAIL(gDevice->createAccelerationStructure(createDesc, gTLAS.writeRef()));
+
+ auto commandBuffer = gTransientHeaps[0]->createCommandBuffer();
+ auto encoder = commandBuffer->encodeRayTracingCommands();
+ IAccelerationStructure::BuildDesc buildDesc = {};
+ buildDesc.dest = gTLAS;
+ buildDesc.inputs = accelerationStructureBuildInputs;
+ buildDesc.scratchData = scratchBuffer->getDeviceAddress();
+ encoder->buildAccelerationStructure(buildDesc, 0, nullptr);
+ encoder->endEncoding();
+ commandBuffer->close();
+ gQueue->executeCommandBuffer(commandBuffer);
+ gQueue->wait();
+ }
+
+ IBufferResource::Desc fullScreenVertexBufferDesc;
+ fullScreenVertexBufferDesc.type = IResource::Type::Buffer;
+ fullScreenVertexBufferDesc.sizeInBytes =
+ FullScreenTriangle::kVertexCount * sizeof(FullScreenTriangle::Vertex);
+ fullScreenVertexBufferDesc.defaultState = ResourceState::VertexBuffer;
+ gFullScreenVertexBuffer = gDevice->createBufferResource(
+ fullScreenVertexBufferDesc, &FullScreenTriangle::kVertices[0]);
+ if (!gFullScreenVertexBuffer)
+ return SLANG_FAIL;
+
+ InputElementDesc inputElements[] = {
+ {"POSITION", 0, Format::RG_Float32, offsetof(FullScreenTriangle::Vertex, position)},
+ };
+ auto inputLayout = gDevice->createInputLayout(&inputElements[0], SLANG_COUNT_OF(inputElements));
+ if (!inputLayout)
+ return SLANG_FAIL;
+
+ ComPtr<IShaderProgram> shaderProgram;
+ SLANG_RETURN_ON_FAIL(loadShaderProgram(gDevice, PipelineType::Graphics, shaderProgram.writeRef()));
+ GraphicsPipelineStateDesc desc;
+ desc.inputLayout = inputLayout;
+ desc.program = shaderProgram;
+ desc.framebufferLayout = gFramebufferLayout;
+ gPresentPipelineState = gDevice->createGraphicsPipelineState(desc);
+ if (!gPresentPipelineState)
+ return SLANG_FAIL;
+
+ ComPtr<IShaderProgram> rayTracingProgram;
+ SLANG_RETURN_ON_FAIL(
+ loadShaderProgram(gDevice, PipelineType::RayTracing, rayTracingProgram.writeRef()));
+ RayTracingPipelineStateDesc rtpDesc = {};
+ rtpDesc.program = rayTracingProgram;
+ rtpDesc.hitGroupCount = 2;
+ HitGroupDesc hitGroups[2];
+ hitGroups[0].closestHitEntryPoint = "closestHitShader";
+ hitGroups[1].closestHitEntryPoint = "shadowRayHitShader";
+ rtpDesc.hitGroups = hitGroups;
+ rtpDesc.maxRayPayloadSize = 64;
+ rtpDesc.maxRecursion = 2;
+ rtpDesc.shaderTableHitGroupCount = 2;
+ int32_t shaderTable[] = {0, 1};
+ rtpDesc.shaderTableHitGroupIndices = shaderTable;
+ SLANG_RETURN_ON_FAIL(
+ gDevice->createRayTracingPipelineState(rtpDesc, gRenderPipelineState.writeRef()));
+ if (!gRenderPipelineState)
+ return SLANG_FAIL;
+
+ createResultTexture();
+ return SLANG_OK;
+}
+
+void createResultTexture()
+{
+ ITextureResource::Desc resultTextureDesc = {};
+ resultTextureDesc.type = IResource::Type::Texture2D;
+ resultTextureDesc.numMipLevels = 1;
+ resultTextureDesc.size.width = windowWidth;
+ resultTextureDesc.size.height = windowHeight;
+ resultTextureDesc.size.depth = 1;
+ resultTextureDesc.defaultState = ResourceState::UnorderedAccess;
+ resultTextureDesc.format = Format::RGBA_Float16;
+ gResultTexture = gDevice->createTextureResource(resultTextureDesc);
+ IResourceView::Desc resultUAVDesc = {};
+ resultUAVDesc.format = resultTextureDesc.format;
+ resultUAVDesc.type = IResourceView::Type::UnorderedAccess;
+ gResultTextureUAV = gDevice->createTextureView(gResultTexture, resultUAVDesc);
+}
+
+virtual void windowSizeChanged() override
+{
+ WindowedAppBase::windowSizeChanged();
+ createResultTexture();
+}
+
+glm::vec3 getVectorFromSphericalAngles(float theta, float phi)
+{
+ auto sinTheta = sin(theta);
+ auto cosTheta = cos(theta);
+ auto sinPhi = sin(phi);
+ auto cosPhi = cos(phi);
+ return glm::vec3(-sinTheta * cosPhi, sinPhi, -cosTheta * cosPhi);
+}
+void updateUniforms()
+{
+ gUniforms.screenWidth = (float)windowWidth;
+ gUniforms.screenHeight = (float)windowHeight;
+ if (!lastTime)
+ lastTime = getCurrentTime();
+ uint64_t currentTime = getCurrentTime();
+ float deltaTime = float(double(currentTime - lastTime) / double(getTimerFrequency()));
+ lastTime = currentTime;
+
+ auto camDir =
+ getVectorFromSphericalAngles(cameraOrientationAngles[0], cameraOrientationAngles[1]);
+ auto camUp = getVectorFromSphericalAngles(
+ cameraOrientationAngles[0], cameraOrientationAngles[1] + glm::pi<float>() * 0.5f);
+ auto camRight = glm::cross(camDir, camUp);
+
+ glm::vec3 movement = glm::vec3(0);
+ if (wPressed)
+ movement += camDir;
+ if (sPressed)
+ movement -= camDir;
+ if (aPressed)
+ movement -= camRight;
+ if (dPressed)
+ movement += camRight;
+
+ cameraPosition += deltaTime * translationScale * movement;
+
+ memcpy(gUniforms.cameraDir, &camDir, sizeof(float) * 3);
+ memcpy(gUniforms.cameraUp, &camUp, sizeof(float) * 3);
+ memcpy(gUniforms.cameraRight, &camRight, sizeof(float) * 3);
+ memcpy(gUniforms.cameraPosition, &cameraPosition, sizeof(float) * 3);
+ auto lightDir = glm::normalize(glm::vec3(1.0f, 3.0f, 2.0f));
+ memcpy(gUniforms.lightDir, &lightDir, sizeof(float) * 3);
+}
+
+virtual void renderFrame(int frameBufferIndex) override
+{
+ updateUniforms();
+ {
+ ComPtr<ICommandBuffer> renderCommandBuffer =
+ gTransientHeaps[frameBufferIndex]->createCommandBuffer();
+ auto renderEncoder = renderCommandBuffer->encodeRayTracingCommands();
+ IShaderObject* rootObject = nullptr;
+ renderEncoder->bindPipeline(gRenderPipelineState, &rootObject);
+ auto cursor = ShaderCursor(rootObject);
+ cursor["resultTexture"].setResource(gResultTextureUAV);
+ cursor["uniforms"].setData(&gUniforms, sizeof(Uniforms));
+ cursor["sceneBVH"].setResource(gTLAS);
+ cursor["primitiveBuffer"].setResource(gPrimitiveBufferSRV);
+ renderEncoder->dispatchRays(nullptr, windowWidth, windowHeight, 1);
+ renderEncoder->endEncoding();
+ renderCommandBuffer->close();
+ gQueue->executeCommandBuffer(renderCommandBuffer);
+ }
+
+ {
+ ComPtr<ICommandBuffer> presentCommandBuffer =
+ gTransientHeaps[frameBufferIndex]->createCommandBuffer();
+ auto presentEncoder = presentCommandBuffer->encodeRenderCommands(
+ gRenderPass, gFramebuffers[frameBufferIndex]);
+ gfx::Viewport viewport = {};
+ viewport.maxZ = 1.0f;
+ viewport.extentX = (float)windowWidth;
+ viewport.extentY = (float)windowHeight;
+ presentEncoder->setViewportAndScissor(viewport);
+ auto rootObject = presentEncoder->bindPipeline(gPresentPipelineState);
+ auto cursor = ShaderCursor(rootObject->getEntryPoint(1));
+ cursor["t"].setResource(gResultTextureUAV);
+ presentEncoder->setVertexBuffer(
+ 0, gFullScreenVertexBuffer, sizeof(FullScreenTriangle::Vertex));
+ presentEncoder->setPrimitiveTopology(PrimitiveTopology::TriangleList);
+ presentEncoder->draw(3);
+ presentEncoder->endEncoding();
+ presentCommandBuffer->close();
+ gQueue->executeCommandBuffer(presentCommandBuffer);
+ }
+ // With that, we are done drawing for one frame, and ready for the next.
+ //
+ gSwapchain->present();
+}
+
+};
+
+// This macro instantiates an appropriate main function to
+// run the application defined above.
+PLATFORM_UI_MAIN(innerMain<RayTracing>)
diff --git a/examples/ray-tracing-pipeline/shaders.slang b/examples/ray-tracing-pipeline/shaders.slang
new file mode 100644
index 000000000..77193f08e
--- /dev/null
+++ b/examples/ray-tracing-pipeline/shaders.slang
@@ -0,0 +1,108 @@
+// shaders.slang
+
+struct Uniforms
+{
+ float screenWidth, screenHeight;
+ float focalLength, frameHeight;
+ float4 cameraDir;
+ float4 cameraUp;
+ float4 cameraRight;
+ float4 cameraPosition;
+ float4 lightDir;
+};
+
+struct Primitive
+{
+ float4 data0;
+ float4 color;
+ float3 getNormal() { return data0.xyz; }
+ float3 getColor() { return color.xyz; }
+};
+
+struct RayPayload
+{
+ float4 color;
+};
+
+uniform RWTexture2D resultTexture;
+uniform RaytracingAccelerationStructure sceneBVH;
+uniform StructuredBuffer<Primitive> primitiveBuffer;
+uniform Uniforms uniforms;
+
+[shader("raygeneration")]
+void rayGenShader()
+{
+ uint2 threadIdx = DispatchRaysIndex().xy;
+ if (threadIdx.x >= (int)uniforms.screenWidth) return;
+ if (threadIdx.y >= (int)uniforms.screenHeight) return;
+
+ float frameWidth = uniforms.screenWidth / uniforms.screenHeight * uniforms.frameHeight;
+ float imageY = (threadIdx.y / uniforms.screenHeight - 0.5f) * uniforms.frameHeight;
+ float imageX = (threadIdx.x / uniforms.screenWidth - 0.5f) * frameWidth;
+ float imageZ = uniforms.focalLength;
+ float3 rayDir = normalize(uniforms.cameraDir.xyz*imageZ - uniforms.cameraUp.xyz * imageY + uniforms.cameraRight.xyz * imageX);
+
+ // Trace the ray.
+ RayDesc ray;
+ ray.Origin = uniforms.cameraPosition.xyz;
+ ray.Direction = rayDir;
+ ray.TMin = 0.001;
+ ray.TMax = 10000.0;
+ RayPayload payload = { float4(0, 0, 0, 0) };
+ TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload);
+
+ resultTexture[threadIdx.xy] = payload.color;
+}
+
+[shader("miss")]
+void missShader(inout RayPayload payload)
+{
+ payload.color = float4(0, 0, 0, 1);
+}
+
+[shader("closesthit")]
+void closestHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
+{
+ float3 hitLocation = WorldRayOrigin() + WorldRayDirection() * RayTCurrent();
+ float3 shadowRayDir = uniforms.lightDir.xyz;
+
+ RayDesc ray;
+ ray.Origin = hitLocation;
+ ray.Direction = shadowRayDir;
+ ray.TMin = 0.001;
+ ray.TMax = 10000.0;
+ RayPayload shadowPayload = { float4(0, 0, 0, 0) };
+ TraceRay(sceneBVH, RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH, ~0, 1, 0, 0, ray, shadowPayload);
+ float shadow = 1.0 - shadowPayload.color.x;
+
+ let primitiveIndex = PrimitiveIndex();
+ float3 normal = primitiveBuffer[primitiveIndex].getNormal();
+ float3 color = primitiveBuffer[primitiveIndex].getColor();
+ float ndotl = max(0.0, shadow * dot(normal, uniforms.lightDir.xyz));
+ float intensity = ndotl * 0.7 + 0.3;
+ payload.color = float4(color * intensity, 1.0f);
+}
+
+[shader("closesthit")]
+void shadowRayHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
+{
+ payload.color = float4(1.0, 1.0, 1.0, 1.0);
+}
+
+/// Vertex and fragment shader for displaying the final image.
+
+[shader("vertex")]
+float4 vertexMain(float2 position : POSITION)
+ : SV_Position
+{
+ return float4(position, 0.5, 1.0);
+}
+
+[shader("fragment")]
+float4 fragmentMain(
+ float4 sv_position : SV_Position,
+ uniform RWTexture2D t)
+ : SV_Target
+{
+ return t.Load(sv_position.xy);
+}
diff --git a/premake5.lua b/premake5.lua
index 2ed40ba16..71af44d4a 100644
--- a/premake5.lua
+++ b/premake5.lua
@@ -653,6 +653,7 @@ example "hello-world"
example "triangle"
example "ray-tracing"
+example "ray-tracing-pipeline"
example "gpu-printing"
kind "ConsoleApp"
diff --git a/slang-gfx.h b/slang-gfx.h
index f76788b03..f4a70d25c 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -919,10 +919,20 @@ struct RayTracingPipelineFlags
};
};
+struct HitGroupDesc
+{
+ const char* closestHitEntryPoint = nullptr;
+ const char* anyHitEntryPoint = nullptr;
+ const char* intersectionEntryPoint = nullptr;
+};
+
struct RayTracingPipelineStateDesc
{
IShaderProgram* program = nullptr;
-
+ int32_t hitGroupCount;
+ const HitGroupDesc* hitGroups;
+ int32_t shaderTableHitGroupCount;
+ int32_t* shaderTableHitGroupIndices;
int maxRecursion;
int maxRayPayloadSize;
RayTracingPipelineFlags::Enum flags;
@@ -1191,6 +1201,17 @@ public:
IAccelerationStructure* const* structures,
AccessFlag::Enum sourceAccess,
AccessFlag::Enum destAccess) = 0;
+
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* state, IShaderObject** outRootObject) = 0;
+ /// Issues a dispatch command to start ray tracing workload with a ray tracing pipeline.
+ /// `rayGenShaderName` specifies the name of the ray generation shader to launch. Pass nullptr for
+ /// the first ray generation shader defined in `raytracingPipeline`.
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) = 0;
};
#define SLANG_UUID_IRayTracingCommandEncoder \
{ \
diff --git a/slang.sln b/slang.sln
index 1551b7110..c0e80b982 100644
--- a/slang.sln
+++ b/slang.sln
@@ -35,6 +35,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "model-viewer", "build\visua
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ray-tracing", "build\visual-studio\ray-tracing\ray-tracing.vcxproj", "{71AC0F50-5DFD-FA91-8661-E95372118EFB}"
EndProject
+Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ray-tracing-pipeline", "build\visual-studio\ray-tracing-pipeline\ray-tracing-pipeline.vcxproj", "{17BA8E32-034E-84DA-6C12-DE8E58C5BECC}"
+EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-object", "build\visual-studio\shader-object\shader-object.vcxproj", "{25512BFB-1138-EDF2-BA88-5310A64E6659}"
EndProject
Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-toy", "build\visual-studio\shader-toy\shader-toy.vcxproj", "{0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}"
@@ -195,6 +197,14 @@ Global
{71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|Win32.Build.0 = Release|Win32
{71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|x64.ActiveCfg = Release|x64
{71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|x64.Build.0 = Release|x64
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|Win32.ActiveCfg = Debug|Win32
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|Win32.Build.0 = Debug|Win32
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|x64.ActiveCfg = Debug|x64
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|x64.Build.0 = Debug|x64
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|Win32.ActiveCfg = Release|Win32
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|Win32.Build.0 = Release|Win32
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|x64.ActiveCfg = Release|x64
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|x64.Build.0 = Release|x64
{25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.ActiveCfg = Debug|Win32
{25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.Build.0 = Debug|Win32
{25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|x64.ActiveCfg = Debug|x64
@@ -293,6 +303,7 @@ Global
{010BE414-ED5B-CF56-16C0-BD18027062C0} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
{2F8724C6-1BC3-2730-84D5-3F277030D04A} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
{71AC0F50-5DFD-FA91-8661-E95372118EFB} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
+ {17BA8E32-034E-84DA-6C12-DE8E58C5BECC} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
{25512BFB-1138-EDF2-BA88-5310A64E6659} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
{0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
{3BB99068-27C9-3C39-9082-A1577CB12BD2} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231}
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 24a1fd93e..7e436d28d 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -10,6 +10,7 @@
#include "../d3d/d3d-swapchain.h"
#include "core/slang-blob.h"
#include "core/slang-basic.h"
+#include "core/slang-chunked-list.h"
// In order to use the Slang API, we need to include its header
@@ -123,8 +124,6 @@ public:
const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState(
const ComputePipelineStateDesc& desc, IPipelineState** outState) override;
- virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
- const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool(
const IQueryPool::Desc& desc, IQueryPool** outState) override;
@@ -156,6 +155,8 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL createAccelerationStructure(
const IAccelerationStructure::CreateDesc& desc,
IAccelerationStructure** outView) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
+ const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
#endif
public:
@@ -193,6 +194,7 @@ public:
virtual void setRootDescriptorTable(int index, D3D12_GPU_DESCRIPTOR_HANDLE BaseDescriptor) = 0;
virtual void setRootSignature(ID3D12RootSignature* rootSignature) = 0;
virtual void setRootConstants(Index rootParamIndex, Index dstOffsetIn32BitValues, Index countOf32BitValues, void const* srcData) = 0;
+ virtual void setPipelineState(PipelineStateBase* pipelineState) = 0;
};
class BufferResourceImpl: public gfx::BufferResource
@@ -340,6 +342,31 @@ public:
}
};
+#if SLANG_GFX_HAS_DXR_SUPPORT
+ class RayTracingPipelineStateImpl : public PipelineStateBase
+ {
+ public:
+ ComPtr<ID3D12StateObject> m_stateObject;
+ D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {};
+ Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex;
+ // Shader Tables for each ray-tracing stage stored in GPU memory.
+ RefPtr<BufferResourceImpl> m_rayGenShaderTable;
+ RefPtr<BufferResourceImpl> m_hitgroupShaderTable;
+ RefPtr<BufferResourceImpl> m_missShaderTable;
+ void init(const RayTracingPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::RayTracing;
+ pipelineDesc.rayTracing = inDesc;
+ initializeBase(pipelineDesc);
+ }
+ Result createShaderTables(
+ D3D12Device* device,
+ slang::IComponentType* slangProgram,
+ const RayTracingPipelineStateDesc& desc);
+ };
+#endif
+
class QueryPoolImpl : public IQueryPool, public ComObject
{
public:
@@ -461,6 +488,11 @@ public:
{
m_commandList->SetGraphicsRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues));
}
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline);
+ m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get());
+ }
GraphicsSubmitter(ID3D12GraphicsCommandList* commandList):
m_commandList(commandList)
@@ -492,7 +524,11 @@ public:
{
m_commandList->SetComputeRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues));
}
-
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline);
+ m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get());
+ }
ComputeSubmitter(ID3D12GraphicsCommandList* commandList) :
m_commandList(commandList)
{
@@ -568,6 +604,7 @@ public:
{
uint64_t waitValue;
HANDLE fenceEvent;
+ ID3D12Fence* fence = nullptr;
};
ShortList<QueueWaitInfo, 4> m_waitInfos;
@@ -585,7 +622,7 @@ public:
m_waitInfos[i].fenceEvent = CreateEventEx(
nullptr,
false,
- CREATE_EVENT_INITIAL_SET | CREATE_EVENT_MANUAL_RESET,
+ 0,
EVENT_ALL_ACCESS);
}
return m_waitInfos[queueIndex];
@@ -666,7 +703,7 @@ public:
ID3D12GraphicsCommandList* m_d3dCmdList;
ID3D12GraphicsCommandList* m_preCmdList = nullptr;
- RefPtr<PipelineStateImpl> m_currentPipeline;
+ RefPtr<PipelineStateBase> m_currentPipeline;
static int getBindPointIndex(PipelineType type)
{
@@ -690,13 +727,14 @@ public:
m_d3dCmdList = m_commandBuffer->m_cmdList;
m_renderer = commandBuffer->m_renderer;
m_transientHeap = commandBuffer->m_transientHeap;
+ m_device = commandBuffer->m_renderer->m_device;
}
void endEncodingImpl() { m_isOpen = false; }
Result bindPipelineImpl(IPipelineState* pipelineState, IShaderObject** outRootObject)
{
- m_currentPipeline = static_cast<PipelineStateImpl*>(pipelineState);
+ m_currentPipeline = static_cast<PipelineStateBase*>(pipelineState);
auto rootObject = &m_commandBuffer->m_rootShaderObject;
SLANG_RETURN_ON_FAIL(rootObject->reset(
m_renderer,
@@ -707,7 +745,11 @@ public:
return SLANG_OK;
}
- Result _bindRenderState(Submitter* submitter);
+ /// Specializes the pipeline according to current root-object argument values,
+ /// applys the root object bindings and binds the pipeline state.
+ /// The newly specialized pipeline is held alive by the pipeline cache so users of
+ /// `newPipeline` do not need to maintain its lifespan.
+ Result _bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline);
};
struct DescriptorTable
@@ -2956,7 +2998,6 @@ public:
{
PipelineCommandEncoder::init(cmdBuffer);
m_preCmdList = nullptr;
- m_device = renderer->m_device;
m_renderPass = renderPass;
m_framebuffer = framebuffer;
m_transientHeap = transientHeap;
@@ -3174,7 +3215,8 @@ public:
// Submit - setting for graphics
{
GraphicsSubmitter submitter(m_d3dCmdList);
- if(SLANG_FAILED(_bindRenderState(&submitter)))
+ RefPtr<PipelineStateBase> newPipeline;
+ if(SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
{
assert(!"Failed to bind render state");
}
@@ -3314,7 +3356,6 @@ public:
{
PipelineCommandEncoder::init(cmdBuffer);
m_preCmdList = nullptr;
- m_device = renderer->m_device;
m_transientHeap = transientHeap;
m_currentPipeline = nullptr;
}
@@ -3330,7 +3371,8 @@ public:
// Submit binding for compute
{
ComputeSubmitter submitter(m_d3dCmdList);
- if(SLANG_FAILED(_bindRenderState(&submitter)))
+ RefPtr<PipelineStateBase> newPipeline;
+ if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
{
assert(!"Failed to bind render state");
}
@@ -3402,12 +3444,15 @@ public:
}
#if SLANG_GFX_HAS_DXR_SUPPORT
- class RayTracingCommandEncoderImpl : public IRayTracingCommandEncoder
+ class RayTracingCommandEncoderImpl
+ : public IRayTracingCommandEncoder
+ , public PipelineCommandEncoder
{
public:
CommandBufferImpl* m_commandBuffer;
void init(D3D12Device* renderer, CommandBufferImpl* commandBuffer)
{
+ PipelineCommandEncoder::init(commandBuffer);
m_commandBuffer = commandBuffer;
}
virtual SLANG_NO_THROW void SLANG_MCALL buildAccelerationStructure(
@@ -3434,6 +3479,13 @@ public:
IAccelerationStructure* const* structures,
AccessFlag::Enum sourceAccess,
AccessFlag::Enum destAccess) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override;
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() {}
virtual SLANG_NO_THROW void SLANG_MCALL
writeTimestamp(IQueryPool* pool, SlangInt index) override
@@ -3533,8 +3585,7 @@ public:
auto transientHeap = cmdImpl->m_transientHeap;
auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex);
waitInfo.waitValue = m_fenceValue;
- ResetEvent(waitInfo.fenceEvent);
- m_fence->SetEventOnCompletion(m_fenceValue, waitInfo.fenceEvent);
+ waitInfo.fence = m_fence;
}
m_d3dQueue->Signal(m_fence, m_fenceValue);
ResetEvent(globalWaitHandle);
@@ -3722,8 +3773,13 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchr
Array<HANDLE, 16> waitHandles;
for (auto& waitInfo : m_waitInfos)
{
- if (waitInfo.waitValue != 0)
+ if (waitInfo.waitValue == 0)
+ continue;
+ if (waitInfo.fence)
+ {
+ waitInfo.fence->SetEventOnCompletion(waitInfo.waitValue, waitInfo.fenceEvent);
waitHandles.add(waitInfo.fenceEvent);
+ }
}
WaitForMultipleObjects((DWORD)waitHandles.getCount(), waitHandles.getBuffer(), TRUE, INFINITE);
m_viewHeap.deallocateAll();
@@ -3763,16 +3819,15 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe
return SLANG_OK;
}
-Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter)
+Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline)
{
- RefPtr<PipelineStateBase> newPipeline;
RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootShaderObject;
m_renderer->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline);
- PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
+ PipelineStateBase* newPipelineImpl = static_cast<PipelineStateBase*>(newPipeline.Ptr());
auto commandList = m_d3dCmdList;
auto pipelineTypeIndex = (int)newPipelineImpl->desc.type;
auto programImpl = static_cast<ShaderProgramImpl*>(newPipelineImpl->m_program.Ptr());
- commandList->SetPipelineState(newPipelineImpl->m_pipelineState);
+ submitter->setPipelineState(newPipelineImpl);
submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature);
RefPtr<ShaderObjectLayoutImpl> specializedRootLayout;
SLANG_RETURN_ON_FAIL(rootObjectImpl->getSpecializedLayout(specializedRootLayout.writeRef()));
@@ -5469,11 +5524,6 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
return SLANG_OK;
}
-Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
-{
- return SLANG_E_NOT_AVAILABLE;
-}
-
Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device)
{
// Translate query type.
@@ -5801,7 +5851,290 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::memoryBarrier
m_commandBuffer->m_cmdList4->ResourceBarrier((UINT)count, barriers.getArrayView().getBuffer());
}
+void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline(
+ IPipelineState* state, IShaderObject** outRootObject)
+{
+ bindPipelineImpl(state, outRootObject);
+}
+
+void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth)
+{
+ RefPtr<PipelineStateBase> newPipeline;
+ PipelineStateBase* pipeline = m_currentPipeline.Ptr();
+ {
+ struct RayTracingSubmitter : public ComputeSubmitter
+ {
+ ID3D12GraphicsCommandList4* m_cmdList4;
+ RayTracingSubmitter(ID3D12GraphicsCommandList4* cmdList4)
+ : ComputeSubmitter(cmdList4), m_cmdList4(cmdList4)
+ {
+ }
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ m_cmdList4->SetPipelineState1(pipelineImpl->m_stateObject.get());
+ }
+ };
+ RayTracingSubmitter submitter(m_commandBuffer->m_cmdList4);
+ if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
+ {
+ assert(!"Failed to bind render state");
+ }
+ if (newPipeline)
+ pipeline = newPipeline.Ptr();
+ }
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ auto dispatchDesc = pipelineImpl->m_dispatchDesc;
+ int32_t rayGenShaderOffset = 0;
+ if (rayGenShaderName)
+ {
+ rayGenShaderOffset =
+ pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() *
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ }
+ dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset;
+ dispatchDesc.Width = (UINT)width;
+ dispatchDesc.Height = (UINT)height;
+ dispatchDesc.Depth = (UINT)depth;
+ m_commandBuffer->m_cmdList4->DispatchRays(&dispatchDesc);
+}
+
+Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
+{
+ if (!m_device5)
+ {
+ return SLANG_E_NOT_AVAILABLE;
+ }
+
+ RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl();
+ pipelineStateImpl->init(inDesc);
+
+ auto program = static_cast<ShaderProgramImpl*>(inDesc.program);
+ auto slangProgram = program->slangProgram;
+ auto programLayout = slangProgram->getLayout();
+
+ if (!program->m_rootObjectLayout->m_rootSignature)
+ {
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+ }
+ List<D3D12_STATE_SUBOBJECT> subObjects;
+ ChunkedList<D3D12_DXIL_LIBRARY_DESC> dxilLibraries;
+ ChunkedList<D3D12_HIT_GROUP_DESC> hitGroups;
+ ChunkedList<ComPtr<ISlangBlob>> codeBlobs;
+ ComPtr<ISlangBlob> diagnostics;
+ ChunkedList<OSString> stringPool;
+ int32_t rayGenIndex = 0;
+ for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ {
+ ComPtr<ISlangBlob> codeBlob;
+ auto compileResult =
+ slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef());
+ if (diagnostics.get())
+ {
+ getDebugCallback()->handleMessage(
+ compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error,
+ DebugMessageSource::Slang,
+ (char*)diagnostics->getBufferPointer());
+ }
+ SLANG_RETURN_ON_FAIL(compileResult);
+ codeBlobs.add(codeBlob);
+ D3D12_DXIL_LIBRARY_DESC library = {};
+ library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();;
+ library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer();
+
+ D3D12_STATE_SUBOBJECT dxilSubObject = {};
+ dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY;
+ dxilSubObject.pDesc = dxilLibraries.add(library);
+ subObjects.add(dxilSubObject);
+
+ auto entryPointLayout = programLayout->getEntryPointByIndex(i);
+ switch (entryPointLayout->getStage())
+ {
+ case SLANG_STAGE_RAY_GENERATION:
+ pipelineStateImpl
+ ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] =
+ rayGenIndex;
+ rayGenIndex++;
+ break;
+ default:
+ break;
+ }
+ }
+ auto getWStr = [&](const char* name)
+ {
+ String str = String(name);
+ auto wstr = str.toWString();
+ return stringPool.add(wstr)->begin();
+ };
+ for (int i = 0; i < inDesc.hitGroupCount; i++)
+ {
+ auto hitGroup = inDesc.hitGroups[i];
+ D3D12_HIT_GROUP_DESC hitGroupDesc = {};
+ hitGroupDesc.Type = hitGroup.intersectionEntryPoint == nullptr
+ ? D3D12_HIT_GROUP_TYPE_TRIANGLES
+ : D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE;
+
+ if (hitGroup.anyHitEntryPoint)
+ {
+ hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint);
+ }
+ if (hitGroup.closestHitEntryPoint)
+ {
+ hitGroupDesc.ClosestHitShaderImport = getWStr(hitGroup.closestHitEntryPoint);
+ }
+ if (hitGroup.intersectionEntryPoint)
+ {
+ hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint);
+ }
+ StringBuilder hitGroupName;
+ hitGroupName << "hitgroup_" << i;
+ hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer());
+
+ D3D12_STATE_SUBOBJECT hitGroupSubObject = {};
+ hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP;
+ hitGroupSubObject.pDesc = hitGroups.add(hitGroupDesc);
+ subObjects.add(hitGroupSubObject);
+ }
+
+ D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {};
+ // According to DXR spec, fixed function triangle intersections must use float2 as ray attributes
+ // that defines the barycentric coordinates at intersection.
+ shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2;
+ shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize;
+ D3D12_STATE_SUBOBJECT shaderConfigSubObject = {};
+ shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG;
+ shaderConfigSubObject.pDesc = &shaderConfig;
+ subObjects.add(shaderConfigSubObject);
+
+ D3D12_GLOBAL_ROOT_SIGNATURE globalSignatureDesc = {};
+ globalSignatureDesc.pGlobalRootSignature = program->m_rootObjectLayout->m_rootSignature.get();
+ D3D12_STATE_SUBOBJECT globalSignatureSubobject = {};
+ globalSignatureSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE;
+ globalSignatureSubobject.pDesc = &globalSignatureDesc;
+ subObjects.add(globalSignatureSubobject);
+
+ D3D12_RAYTRACING_PIPELINE_CONFIG pipelineConfig = {};
+ pipelineConfig.MaxTraceRecursionDepth = inDesc.maxRecursion;
+ D3D12_STATE_SUBOBJECT pipelineConfigSubobject = {};
+ pipelineConfigSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG;
+ pipelineConfigSubobject.pDesc = &pipelineConfig;
+ subObjects.add(pipelineConfigSubobject);
+
+ D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
+ rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
+ rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
+ rtpsoDesc.pSubobjects = subObjects.getBuffer();
+ SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+
+ SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc));
+
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+}
+
+Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables(
+ D3D12Device* device,
+ slang::IComponentType* slangProgram,
+ const RayTracingPipelineStateDesc& desc)
+{
+ ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
+ m_stateObject->QueryInterface(stateObjectProperties.writeRef());
+ auto programLayout = slangProgram->getLayout();
+ struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; };
+ List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers;
+ for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ {
+ auto entryPointLayout = programLayout->getEntryPointByIndex(i);
+ ShaderIdentifier identifier;
+ switch (entryPointLayout->getStage())
+ {
+ case SLANG_STAGE_RAY_GENERATION:
+ memcpy(
+ &identifier,
+ stateObjectProperties->GetShaderIdentifier(
+ String(entryPointLayout->getName()).toWString().begin()),
+ sizeof(ShaderIdentifier));
+ rayGenIdentifiers.add(identifier);
+ break;
+ case SLANG_STAGE_MISS:
+ memcpy(
+ &identifier,
+ stateObjectProperties->GetShaderIdentifier(
+ String(entryPointLayout->getName()).toWString().begin()),
+ sizeof(ShaderIdentifier));
+ missIdentifiers.add(identifier);
+ break;
+ default:
+ break;
+ }
+ }
+ for (int i = 0; i < desc.shaderTableHitGroupCount; i++)
+ {
+ StringBuilder hitgroupName;
+ hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i];
+ ShaderIdentifier hitgroupIdentifier;
+ memcpy(
+ &hitgroupIdentifier,
+ stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()),
+ sizeof(ShaderIdentifier));
+ hitgroupIdentifiers.add(hitgroupIdentifier);
+ }
+
+ auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content,
+ RefPtr<BufferResourceImpl>& outResource) -> Result
+ {
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.defaultState = ResourceState::ShaderResource;
+ bufferDesc.allowedStates = ResourceStateSet(
+ ResourceState::CopySource,
+ ResourceState::UnorderedAccess,
+ ResourceState::ShaderResource);
+ bufferDesc.elementSize = 0;
+ bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier);
+ bufferDesc.format = Format::Unknown;
+ ComPtr<IBufferResource> shaderTableResource;
+ SLANG_RETURN_ON_FAIL(device->createBufferResource(
+ bufferDesc, content.getBuffer(), shaderTableResource.writeRef()));
+ outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get());
+ return SLANG_OK;
+ };
+
+ if (desc.shaderTableHitGroupCount)
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable));
+ m_dispatchDesc.HitGroupTable.SizeInBytes =
+ (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount;
+ m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress();
+ }
+ if (rayGenIdentifiers.getCount())
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable));
+ m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress();
+ }
+ if (missIdentifiers.getCount())
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable));
+ m_dispatchDesc.MissShaderTable.SizeInBytes =
+ (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount();
+ m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress();
+ }
+ return SLANG_OK;
+}
+
#endif // SLANG_GFX_HAS_DXR_SUPPORT
+
Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* resourceView)
{
if (offset.bindingRangeIndex < 0)
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 067581559..50cacc6c2 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -705,6 +705,7 @@ DebugCommandBuffer::DebugCommandBuffer()
m_renderCommandEncoder.commandBuffer = this;
m_computeCommandEncoder.commandBuffer = this;
m_resourceCommandEncoder.commandBuffer = this;
+ m_rayTracingCommandEncoder.commandBuffer = this;
}
void DebugCommandBuffer::encodeRenderCommands(
@@ -1084,6 +1085,25 @@ void DebugRayTracingCommandEncoder::memoryBarrier(
baseObject->memoryBarrier(count, innerAS.getBuffer(), sourceAccess, destAccess);
}
+void DebugRayTracingCommandEncoder::bindPipeline(
+ IPipelineState* state, IShaderObject** outRootObject)
+{
+ SLANG_GFX_API_FUNC;
+ auto innerPipeline = getInnerObj(state);
+ baseObject->bindPipeline(innerPipeline, commandBuffer->rootObject.baseObject.writeRef());
+ *outRootObject = &commandBuffer->rootObject;
+}
+
+void DebugRayTracingCommandEncoder::dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth)
+{
+ SLANG_GFX_API_FUNC;
+ baseObject->dispatchRays(rayGenShaderName, width, height, depth);
+}
+
const ICommandQueue::Desc& DebugCommandQueue::getDesc()
{
SLANG_GFX_API_FUNC;
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index 7433db966..c7de48149 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -351,6 +351,13 @@ public:
IAccelerationStructure* const* structures,
AccessFlag::Enum sourceAccess,
AccessFlag::Enum destAccess) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override;
public:
DebugCommandBuffer* commandBuffer;
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 2eb19b6e9..bb80c4f53 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -605,6 +605,14 @@ Result RendererBase::maybeSpecializePipeline(
pipelineDesc, specializedPipelineComPtr.writeRef()));
break;
}
+ case PipelineType::RayTracing:
+ {
+ auto pipelineDesc = currentPipeline->desc.rayTracing;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createRayTracingPipelineState(
+ pipelineDesc, specializedPipelineComPtr.writeRef()));
+ break;
+ }
default:
break;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 1f0a3eaab..31a7566a2 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -766,7 +766,7 @@ public:
auto bindingRangeIndex = offset.bindingRangeIndex;
auto bindingRange = layout->getBindingRange(bindingRangeIndex);
- auto objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex;
+ Slang::Index objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex;
if (objectIndex >= m_userProvidedSpecializationArgs.getCount())
m_userProvidedSpecializationArgs.setCount(objectIndex + 1);
if (!m_userProvidedSpecializationArgs[objectIndex])
@@ -816,7 +816,7 @@ public:
subObjectIndexInRange++)
{
ExtendedShaderObjectTypeList typeArgs;
- auto objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
+ Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
auto subObject = m_objects[objectIndex];
if (!subObject)
@@ -932,9 +932,19 @@ public:
PipelineType type;
GraphicsPipelineStateDesc graphics;
ComputePipelineStateDesc compute;
+ RayTracingPipelineStateDesc rayTracing;
ShaderProgramBase* getProgram()
{
- return static_cast<ShaderProgramBase*>(type == PipelineType::Compute ? compute.program : graphics.program);
+ switch (type)
+ {
+ case PipelineType::Compute:
+ return static_cast<ShaderProgramBase*>(compute.program);
+ case PipelineType::Graphics:
+ return static_cast<ShaderProgramBase*>(graphics.program);
+ case PipelineType::RayTracing:
+ return static_cast<ShaderProgramBase*>(rayTracing.program);
+ }
+ return nullptr;
}
} desc;
@@ -1105,6 +1115,8 @@ public:
public:
ExtendedShaderObjectTypeList specializationArgs;
// Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary.
+ // The newly specialized pipeline is held alive by the pipeline cache so users of `outNewPipeline` do not
+ // need to maintain its lifespan.
Result maybeSpecializePipeline(
PipelineStateBase* currentPipeline,
ShaderObjectBase* rootObject,
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index bc0271aa6..592cbaac1 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -1266,7 +1266,7 @@ public:
vkPushConstantRange.size = ordinaryDataSize;
vkPushConstantRange.stageFlags = VK_SHADER_STAGE_ALL; // TODO: be more clever
- while(m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex)
+ while((uint32_t)m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex)
{
VkPushConstantRange emptyRange = { 0 };
m_ownPushConstantRanges.add(emptyRange);
@@ -2995,7 +2995,7 @@ public:
case slang::BindingType::ConstantBuffer:
{
BindingOffset objOffset = rangeOffset;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// Binding a constant buffer sub-object is simple enough:
// we just call `bindAsConstantBuffer` on it to bind
@@ -3016,7 +3016,7 @@ public:
case slang::BindingType::ParameterBlock:
{
BindingOffset objOffset = rangeOffset;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// The case for `ParameterBlock<X>` is not that different
// from `ConstantBuffer<X>`, except that we call `bindAsParameterBlock`
@@ -3047,7 +3047,7 @@ public:
//
SimpleBindingOffset objOffset = rangeOffset.pending;
SimpleBindingOffset objStride = rangeStride.pending;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// An existential-type sub-object is always bound just as a value,
// which handles its nested bindings and descriptor sets, but
@@ -4258,6 +4258,25 @@ public:
_memoryBarrier(count, structures, srcAccess, destAccess);
}
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override
+ {
+ SLANG_UNUSED(pipeline);
+ SLANG_UNUSED(outRootObject);
+ }
+
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override
+ {
+ SLANG_UNUSED(rayGenShaderName);
+ SLANG_UNUSED(width);
+ SLANG_UNUSED(height);
+ SLANG_UNUSED(depth);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override
{
}