From c6f6ce12ec522b193b42bcd12d3a2540c7a6ff92 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 28 Jul 2021 12:24:12 -0700 Subject: Experimental DXR1.0 support in gfx. (#1915) * Experimental DXR1.0 support in gfx. - Add `dispatchRays` command. - Add `createRayTracingPipelineState` method to construct a D3D ray tracing state object from a linked slang program and user specified shader table. Limitations/simplifications: no local root signature support, shader table entries contains only shader identifiers and is specified at pipeline creation time, owned by the pipeline state object. * Root object binding for raytracing pipelines. * `maybeSpecializePipeline` implementation for raytracing pipelines. * Add ray-tracing-pipeline example. * Fixes. * Update README.md * Update comments on the lifespan of specialized pipelines Co-authored-by: Yong He Co-authored-by: jsmall-nvidia --- .../ray-tracing-pipeline.vcxproj | 193 ++++++ .../ray-tracing-pipeline.vcxproj.filters | 18 + examples/ray-tracing-pipeline/README.md | 9 + examples/ray-tracing-pipeline/main.cpp | 665 +++++++++++++++++++++ examples/ray-tracing-pipeline/shaders.slang | 108 ++++ premake5.lua | 1 + slang-gfx.h | 23 +- slang.sln | 11 + tools/gfx/d3d12/render-d3d12.cpp | 381 +++++++++++- tools/gfx/debug-layer.cpp | 20 + tools/gfx/debug-layer.h | 7 + tools/gfx/renderer-shared.cpp | 8 + tools/gfx/renderer-shared.h | 18 +- tools/gfx/vulkan/render-vk.cpp | 27 +- 14 files changed, 1457 insertions(+), 32 deletions(-) create mode 100644 build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj create mode 100644 build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters create mode 100644 examples/ray-tracing-pipeline/README.md create mode 100644 examples/ray-tracing-pipeline/main.cpp create mode 100644 examples/ray-tracing-pipeline/shaders.slang diff --git a/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj 
b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj new file mode 100644 index 000000000..b439eeb84 --- /dev/null +++ b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj @@ -0,0 +1,193 @@ + + + + + Debug + Win32 + + + Debug + x64 + + + Release + Win32 + + + Release + x64 + + + + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC} + true + Win32Proj + ray-tracing-pipeline + + + + Application + true + Unicode + v142 + + + Application + true + Unicode + v142 + + + Application + false + Unicode + v142 + + + Application + false + Unicode + v142 + + + + + + + + + + + + + + + + + + + true + ..\..\..\bin\windows-x86\debug\ + ..\..\..\intermediate\windows-x86\debug\ray-tracing-pipeline\ + ray-tracing-pipeline + .exe + + + true + ..\..\..\bin\windows-x64\debug\ + ..\..\..\intermediate\windows-x64\debug\ray-tracing-pipeline\ + ray-tracing-pipeline + .exe + + + false + ..\..\..\bin\windows-x86\release\ + ..\..\..\intermediate\windows-x86\release\ray-tracing-pipeline\ + ray-tracing-pipeline + .exe + + + false + ..\..\..\bin\windows-x64\release\ + ..\..\..\intermediate\windows-x64\release\ray-tracing-pipeline\ + ray-tracing-pipeline + .exe + + + + NotUsing + Level3 + _DEBUG;%(PreprocessorDefinitions) + ..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories) + EditAndContinue + Disabled + MultiThreadedDebug + + + Windows + true + + + + + NotUsing + Level3 + _DEBUG;%(PreprocessorDefinitions) + ..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories) + EditAndContinue + Disabled + MultiThreadedDebug + + + Windows + true + + + + + NotUsing + Level3 + NDEBUG;%(PreprocessorDefinitions) + ..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories) + Full + true + true + false + true + MultiThreaded + + + Windows + true + true + + + + + NotUsing + Level3 + NDEBUG;%(PreprocessorDefinitions) + ..\..\..;..\..\..\tools;%(AdditionalIncludeDirectories) + Full + true + true + false + true + MultiThreaded + + + Windows + true + true + + + + + + + + + + + 
{37BED5B5-23FA-D81F-8C0C-F1167867813A} + + + {DB00DA62-0533-4AFD-B59F-A67D5B3A0808} + + + {222F7498-B40C-4F3F-A704-DDEB91A4484A} + + + {F5ADB74E-02A7-44FB-AA3B-FC02F8AC7A4B} + + + {3565FE5E-4FA3-11EB-AE93-0242AC130002} + + + {F9BE7957-8399-899E-0C49-E714FDDD4B65} + + + + + + \ No newline at end of file diff --git a/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters new file mode 100644 index 000000000..650faecbb --- /dev/null +++ b/build/visual-studio/ray-tracing-pipeline/ray-tracing-pipeline.vcxproj.filters @@ -0,0 +1,18 @@ + + + + + {E9C7FDCE-D52A-8D73-7EB0-C5296AF258F6} + + + + + Source Files + + + + + Source Files + + + \ No newline at end of file diff --git a/examples/ray-tracing-pipeline/README.md b/examples/ray-tracing-pipeline/README.md new file mode 100644 index 000000000..48cec4c18 --- /dev/null +++ b/examples/ray-tracing-pipeline/README.md @@ -0,0 +1,9 @@ +Slang "Ray Tracing Pipeline" Example +====================================== + +The goal of this example is to demonstrate how to write shaders for ray-tracing pipelines in Slang. + +The `shaders.slang` file contains a set of ray-tracing shader entry points that trace primary rays from the camera and shade intersections with basic lighting and ray-traced shadows. The file also defines a vertex and a fragment shader entry point for displaying the ray-traced image produced by the ray-generation shader. + +The `main.cpp` file contains the C++ application code, showing how to use the Slang API to load and compile the shader code, and how to use a graphics API abstraction layer implemented in `tools/gfx` to set up and use ray-tracing pipelines (DXR 1.0 equivalent API). +Note that this abstraction layer is *not* required in order to work with Slang; it is just there to help us write example and test applications more conveniently. 
diff --git a/examples/ray-tracing-pipeline/main.cpp b/examples/ray-tracing-pipeline/main.cpp new file mode 100644 index 000000000..3c83447b4 --- /dev/null +++ b/examples/ray-tracing-pipeline/main.cpp @@ -0,0 +1,665 @@ +// main.cpp + +// This file implements an example of hardware ray-tracing using +// Slang shaders and the `gfx` graphics API. + +#include +#include "slang-gfx.h" +#include "gfx-util/shader-cursor.h" +#include "tools/platform/window.h" +#include "tools/platform/vector-math.h" +#include "slang-com-ptr.h" +#include "source/core/slang-basic.h" +#include "examples/example-base/example-base.h" + +using namespace gfx; +using namespace Slang; + +struct Uniforms +{ + float screenWidth, screenHeight; + float focalLength = 24.0f, frameHeight = 24.0f; + float cameraDir[4]; + float cameraUp[4]; + float cameraRight[4]; + float cameraPosition[4]; + float lightDir[4]; +}; + +struct Vertex +{ + float position[3]; +}; + +// Define geometry data for our test scene. +// The scene contains a floor plane, and a cube placed on top of it at the center. +static const int kVertexCount = 24; +static const Vertex kVertexData[kVertexCount] = +{ + // Floor plane + {{-100.0f, 0, 100.0f}}, + {{100.0f, 0, 100.0f}}, + {{100.0f, 0, -100.0f}}, + {{-100.0f, 0, -100.0f}}, + // Cube face (+y). + {{-1.0f, 2.0, 1.0f}}, + {{1.0f, 2.0, 1.0f}}, + {{1.0f, 2.0, -1.0f}}, + {{-1.0f, 2.0, -1.0f}}, + // Cube face (+z). + {{-1.0f, 0.0, 1.0f}}, + {{1.0f, 0.0, 1.0f}}, + {{1.0f, 2.0, 1.0f}}, + {{-1.0f, 2.0, 1.0f}}, + // Cube face (-z). + {{-1.0f, 0.0, -1.0f}}, + {{-1.0f, 2.0, -1.0f}}, + {{1.0f, 2.0, -1.0f}}, + {{1.0f, 0.0, -1.0f}}, + // Cube face (-x). + {{-1.0f, 0.0, -1.0f}}, + {{-1.0f, 0.0, 1.0f}}, + {{-1.0f, 2.0, 1.0f}}, + {{-1.0f, 2.0, -1.0f}}, + // Cube face (+x). 
+ {{1.0f, 2.0, -1.0f}}, + {{1.0f, 2.0, 1.0f}}, + {{1.0f, 0.0, 1.0f}}, + {{1.0f, 0.0, -1.0f}}, +}; +static const int kIndexCount = 36; +static const int kIndexData[kIndexCount] = +{ + 0, 1, 2, 0, 2, 3, + 4, 5, 6, 4, 6, 7, + 8, 9, 10, 8, 10, 11, + 12, 13, 14, 12, 14, 15, + 16, 17, 18, 16, 18, 19, + 20, 21, 22, 20, 22, 23 +}; + +struct Primitive +{ + float data[4]; + float color[4]; +}; +static const int kPrimitiveCount = 12; +static const Primitive kPrimitiveData[kPrimitiveCount] = +{ + {{0.0f, 1.0f, 0.0f, 0.0f}, {0.75f, 0.8f, 0.85f, 1.0f}}, + {{0.0f, 1.0f, 0.0f, 0.0f}, {0.75f, 0.8f, 0.85f, 1.0f}}, + {{0.0f, 1.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{0.0f, 1.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{0.0f, 0.0f, 1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{0.0f, 0.0f, 1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{0.0f, 0.0f, -1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{0.0f, 0.0f, -1.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{-1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{-1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, + {{1.0f, 0.0f, 0.0f, 0.0f}, {0.95f, 0.85f, 0.05f, 1.0f}}, +}; + + +// We need to use a rasterization pipeline to copy the ray-traced image +// to the swapchain. To do so we need to render a full-screen triangle. +// We will define a small helper type that defines the data for such a triangle. +// +struct FullScreenTriangle +{ + struct Vertex + { + float position[2]; + }; + + enum + { + kVertexCount = 3 + }; + + static const Vertex kVertices[kVertexCount]; +}; +const FullScreenTriangle::Vertex FullScreenTriangle::kVertices[FullScreenTriangle::kVertexCount] = { + {{-1, -1}}, + {{-1, 3}}, + {{3, -1}}, +}; + +// The example application will be implemented as a `struct`, so that +// we can scope the resources it allocates without using global variables. 
+// +struct RayTracing : public WindowedAppBase +{ + + +Uniforms gUniforms = {}; + + +// Many Slang API functions return detailed diagnostic information +// (error messages, warnings, etc.) as a "blob" of data, or return +// a null blob pointer instead if there were no issues. +// +// For convenience, we define a subroutine that will dump the information +// in a diagnostic blob if one is produced, and skip it otherwise. +// +void diagnoseIfNeeded(slang::IBlob* diagnosticsBlob) +{ + if( diagnosticsBlob != nullptr ) + { + printf("%s", (const char*) diagnosticsBlob->getBufferPointer()); +#ifdef _WIN32 + _Win32OutputDebugString((const char*)diagnosticsBlob->getBufferPointer()); +#endif + } +} + +// Load and compile shader code from source. +gfx::Result loadShaderProgram( + gfx::IDevice* device, + gfx::PipelineType pipelineType, + gfx::IShaderProgram** outProgram) +{ + ComPtr slangSession; + slangSession = device->getSlangSession(); + + ComPtr diagnosticsBlob; + slang::IModule* module = slangSession->loadModule("shaders", diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + if(!module) + return SLANG_FAIL; + + Slang::List componentTypes; + componentTypes.add(module); + if (pipelineType == PipelineType::RayTracing) + { + ComPtr entryPoint; + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("rayGenShader", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("missShader", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName("closestHitShader", entryPoint.writeRef())); + componentTypes.add(entryPoint); + SLANG_RETURN_ON_FAIL( + module->findEntryPointByName("shadowRayHitShader", entryPoint.writeRef())); + componentTypes.add(entryPoint); + } + else + { + ComPtr entryPoint; + SLANG_RETURN_ON_FAIL(module->findEntryPointByName("vertexMain", entryPoint.writeRef())); + componentTypes.add(entryPoint); + 
SLANG_RETURN_ON_FAIL(module->findEntryPointByName("fragmentMain", entryPoint.writeRef())); + componentTypes.add(entryPoint); + } + + ComPtr linkedProgram; + SlangResult result = slangSession->createCompositeComponentType( + componentTypes.getBuffer(), + componentTypes.getCount(), + linkedProgram.writeRef(), + diagnosticsBlob.writeRef()); + diagnoseIfNeeded(diagnosticsBlob); + SLANG_RETURN_ON_FAIL(result); + + gfx::IShaderProgram::Desc programDesc = {}; + programDesc.pipelineType = pipelineType; + programDesc.slangProgram = linkedProgram; + SLANG_RETURN_ON_FAIL(device->createProgram(programDesc, outProgram)); + + return SLANG_OK; +} + +ComPtr gPresentPipelineState; +ComPtr gRenderPipelineState; +ComPtr gFullScreenVertexBuffer; +ComPtr gVertexBuffer; +ComPtr gIndexBuffer; +ComPtr gPrimitiveBuffer; +ComPtr gTransformBuffer; +ComPtr gPrimitiveBufferSRV; +ComPtr gInstanceBuffer; +ComPtr gBLASBuffer; +ComPtr gBLAS; +ComPtr gTLASBuffer; +ComPtr gTLAS; +ComPtr gResultTexture; +ComPtr gResultTextureUAV; + +uint64_t lastTime = 0; + +// glm::vec3 lightDir = normalize(glm::vec3(10, 10, 10)); +// glm::vec3 lightColor = glm::vec3(1, 1, 1); + +glm::vec3 cameraPosition = glm::vec3(-2.53f, 2.72f, 4.3f); +float cameraOrientationAngles[2] = {-0.475f, -0.35f}; // Spherical angles (theta, phi). 
+ +float translationScale = 0.5f; +float rotationScale = 0.01f; + +// In order to control camera movement, we will +// use good old WASD +bool wPressed = false; +bool aPressed = false; +bool sPressed = false; +bool dPressed = false; + +bool isMouseDown = false; +float lastMouseX = 0.0f; +float lastMouseY = 0.0f; + +void setKeyState(platform::KeyCode key, bool state) +{ + switch (key) + { + default: + break; + case platform::KeyCode::W: + wPressed = state; + break; + case platform::KeyCode::A: + aPressed = state; + break; + case platform::KeyCode::S: + sPressed = state; + break; + case platform::KeyCode::D: + dPressed = state; + break; + } +} +void onKeyDown(platform::KeyEventArgs args) { setKeyState(args.key, true); } +void onKeyUp(platform::KeyEventArgs args) { setKeyState(args.key, false); } + +void onMouseDown(platform::MouseEventArgs args) +{ + isMouseDown = true; + lastMouseX = (float)args.x; + lastMouseY = (float)args.y; +} + +void onMouseMove(platform::MouseEventArgs args) +{ + if (isMouseDown) + { + float deltaX = args.x - lastMouseX; + float deltaY = args.y - lastMouseY; + + cameraOrientationAngles[0] += -deltaX * rotationScale; + cameraOrientationAngles[1] += -deltaY * rotationScale; + lastMouseX = (float)args.x; + lastMouseY = (float)args.y; + } +} +void onMouseUp(platform::MouseEventArgs args) { isMouseDown = false; } + +Slang::Result initialize() +{ + initializeBase("Ray Tracing Pipeline", 1024, 768); + gWindow->events.mouseMove = [this](const platform::MouseEventArgs& e) { onMouseMove(e); }; + gWindow->events.mouseUp = [this](const platform::MouseEventArgs& e) { onMouseUp(e); }; + gWindow->events.mouseDown = [this](const platform::MouseEventArgs& e) { onMouseDown(e); }; + gWindow->events.keyDown = [this](const platform::KeyEventArgs& e) { onKeyDown(e); }; + gWindow->events.keyUp = [this](const platform::KeyEventArgs& e) { onKeyUp(e); }; + + IBufferResource::Desc vertexBufferDesc; + vertexBufferDesc.type = IResource::Type::Buffer; + 
vertexBufferDesc.sizeInBytes = kVertexCount * sizeof(Vertex); + vertexBufferDesc.defaultState = ResourceState::ShaderResource; + gVertexBuffer = gDevice->createBufferResource(vertexBufferDesc, &kVertexData[0]); + if(!gVertexBuffer) return SLANG_FAIL; + + IBufferResource::Desc indexBufferDesc; + indexBufferDesc.type = IResource::Type::Buffer; + indexBufferDesc.sizeInBytes = kIndexCount * sizeof(int32_t); + indexBufferDesc.defaultState = ResourceState::ShaderResource; + gIndexBuffer = gDevice->createBufferResource(indexBufferDesc, &kIndexData[0]); + if (!gIndexBuffer) + return SLANG_FAIL; + + IBufferResource::Desc primitiveBufferDesc; + primitiveBufferDesc.type = IResource::Type::Buffer; + primitiveBufferDesc.sizeInBytes = kPrimitiveCount * sizeof(Primitive); + primitiveBufferDesc.defaultState = ResourceState::ShaderResource; + gPrimitiveBuffer = gDevice->createBufferResource(primitiveBufferDesc, &kPrimitiveData[0]); + if (!gPrimitiveBuffer) + return SLANG_FAIL; + + IResourceView::Desc primitiveSRVDesc = {}; + primitiveSRVDesc.format = Format::Unknown; + primitiveSRVDesc.type = IResourceView::Type::ShaderResource; + gPrimitiveBufferSRV = gDevice->createBufferView(gPrimitiveBuffer, primitiveSRVDesc); + + IBufferResource::Desc transformBufferDesc; + transformBufferDesc.type = IResource::Type::Buffer; + transformBufferDesc.sizeInBytes = sizeof(float) * 12; + transformBufferDesc.defaultState = ResourceState::ShaderResource; + float transformData[12] = { + 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}; + gTransformBuffer = gDevice->createBufferResource(transformBufferDesc, &transformData); + if (!gTransformBuffer) + return SLANG_FAIL; + // Build bottom level acceleration structure. 
+ { + IAccelerationStructure::BuildInputs accelerationStructureBuildInputs; + IAccelerationStructure::PrebuildInfo accelerationStructurePrebuildInfo; + accelerationStructureBuildInputs.descCount = 1; + accelerationStructureBuildInputs.kind = IAccelerationStructure::Kind::BottomLevel; + accelerationStructureBuildInputs.flags = + IAccelerationStructure::BuildFlags::AllowCompaction; + IAccelerationStructure::GeometryDesc geomDesc; + geomDesc.flags = IAccelerationStructure::GeometryFlags::Opaque; + geomDesc.type = IAccelerationStructure::GeometryType::Triangles; + geomDesc.content.triangles.indexCount = kIndexCount; + geomDesc.content.triangles.indexData = gIndexBuffer->getDeviceAddress(); + geomDesc.content.triangles.indexFormat = Format::R_UInt32; + geomDesc.content.triangles.vertexCount = kVertexCount; + geomDesc.content.triangles.vertexData = gVertexBuffer->getDeviceAddress(); + geomDesc.content.triangles.vertexFormat = Format::RGB_Float32; + geomDesc.content.triangles.vertexStride = sizeof(Vertex); + geomDesc.content.triangles.transform3x4 = gTransformBuffer->getDeviceAddress(); + accelerationStructureBuildInputs.geometryDescs = &geomDesc; + + // Query buffer size for acceleration structure build. + SLANG_RETURN_ON_FAIL(gDevice->getAccelerationStructurePrebuildInfo( + accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo)); + // Allocate buffers for acceleration structure. 
+ IBufferResource::Desc asDraftBufferDesc; + asDraftBufferDesc.type = IResource::Type::Buffer; + asDraftBufferDesc.defaultState = ResourceState::AccelerationStructure; + asDraftBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize; + ComPtr draftBuffer = gDevice->createBufferResource(asDraftBufferDesc); + IBufferResource::Desc scratchBufferDesc; + scratchBufferDesc.type = IResource::Type::Buffer; + scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; + scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize; + ComPtr scratchBuffer = gDevice->createBufferResource(scratchBufferDesc); + + // Build acceleration structure. + ComPtr compactedSizeQuery; + IQueryPool::Desc queryPoolDesc; + queryPoolDesc.count = 1; + queryPoolDesc.type = QueryType::AccelerationStructureCompactedSize; + SLANG_RETURN_ON_FAIL( + gDevice->createQueryPool(queryPoolDesc, compactedSizeQuery.writeRef())); + + ComPtr draftAS; + IAccelerationStructure::CreateDesc draftCreateDesc; + draftCreateDesc.buffer = draftBuffer; + draftCreateDesc.kind = IAccelerationStructure::Kind::BottomLevel; + draftCreateDesc.offset = 0; + draftCreateDesc.size = accelerationStructurePrebuildInfo.resultDataMaxSize; + SLANG_RETURN_ON_FAIL( + gDevice->createAccelerationStructure(draftCreateDesc, draftAS.writeRef())); + + auto commandBuffer = gTransientHeaps[0]->createCommandBuffer(); + auto encoder = commandBuffer->encodeRayTracingCommands(); + IAccelerationStructure::BuildDesc buildDesc = {}; + buildDesc.dest = draftAS; + buildDesc.inputs = accelerationStructureBuildInputs; + buildDesc.scratchData = scratchBuffer->getDeviceAddress(); + AccelerationStructureQueryDesc compactedSizeQueryDesc = {}; + compactedSizeQueryDesc.queryPool = compactedSizeQuery; + compactedSizeQueryDesc.queryType = QueryType::AccelerationStructureCompactedSize; + encoder->buildAccelerationStructure(buildDesc, 1, &compactedSizeQueryDesc); + encoder->endEncoding(); + commandBuffer->close(); + 
gQueue->executeCommandBuffer(commandBuffer); + gQueue->wait(); + + uint64_t compactedSize = 0; + compactedSizeQuery->getResult(0, 1, &compactedSize); + IBufferResource::Desc asBufferDesc; + asBufferDesc.type = IResource::Type::Buffer; + asBufferDesc.defaultState = ResourceState::AccelerationStructure; + asBufferDesc.sizeInBytes = compactedSize; + gBLASBuffer = gDevice->createBufferResource(asBufferDesc); + IAccelerationStructure::CreateDesc createDesc; + createDesc.buffer = gBLASBuffer; + createDesc.kind = IAccelerationStructure::Kind::BottomLevel; + createDesc.offset = 0; + createDesc.size = compactedSize; + gDevice->createAccelerationStructure(createDesc, gBLAS.writeRef()); + + commandBuffer = gTransientHeaps[0]->createCommandBuffer(); + encoder = commandBuffer->encodeRayTracingCommands(); + encoder->copyAccelerationStructure(gBLAS, draftAS, AccelerationStructureCopyMode::Compact); + encoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + gQueue->wait(); + } + + // Build top level acceleration structure. 
+ { + List instanceDescs; + instanceDescs.setCount(1); + instanceDescs[0].accelerationStructure = gBLAS->getDeviceAddress(); + instanceDescs[0].flags = + IAccelerationStructure::GeometryInstanceFlags::TriangleFacingCullDisable; + instanceDescs[0].instanceContributionToHitGroupIndex = 0; + instanceDescs[0].instanceID = 0; + instanceDescs[0].instanceMask = 0xFF; + float transformMatrix[] = {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f}; + memcpy(&instanceDescs[0].transform[0][0], transformMatrix, sizeof(float) * 12); + + IBufferResource::Desc instanceBufferDesc; + instanceBufferDesc.type = IResource::Type::Buffer; + instanceBufferDesc.sizeInBytes = + instanceDescs.getCount() * sizeof(IAccelerationStructure::InstanceDesc); + instanceBufferDesc.defaultState = ResourceState::ShaderResource; + gInstanceBuffer = gDevice->createBufferResource(instanceBufferDesc, instanceDescs.getBuffer()); + if (!gInstanceBuffer) + return SLANG_FAIL; + + IAccelerationStructure::BuildInputs accelerationStructureBuildInputs = {}; + IAccelerationStructure::PrebuildInfo accelerationStructurePrebuildInfo = {}; + accelerationStructureBuildInputs.descCount = 1; + accelerationStructureBuildInputs.kind = IAccelerationStructure::Kind::TopLevel; + accelerationStructureBuildInputs.instanceDescs = gInstanceBuffer->getDeviceAddress(); + + // Query buffer size for acceleration structure build. 
+ SLANG_RETURN_ON_FAIL(gDevice->getAccelerationStructurePrebuildInfo( + accelerationStructureBuildInputs, &accelerationStructurePrebuildInfo)); + + IBufferResource::Desc asBufferDesc; + asBufferDesc.type = IResource::Type::Buffer; + asBufferDesc.defaultState = ResourceState::AccelerationStructure; + asBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.resultDataMaxSize; + gTLASBuffer = gDevice->createBufferResource(asBufferDesc); + + IBufferResource::Desc scratchBufferDesc; + scratchBufferDesc.type = IResource::Type::Buffer; + scratchBufferDesc.defaultState = ResourceState::UnorderedAccess; + scratchBufferDesc.sizeInBytes = accelerationStructurePrebuildInfo.scratchDataSize; + ComPtr scratchBuffer = gDevice->createBufferResource(scratchBufferDesc); + + IAccelerationStructure::CreateDesc createDesc; + createDesc.buffer = gTLASBuffer; + createDesc.kind = IAccelerationStructure::Kind::TopLevel; + createDesc.offset = 0; + createDesc.size = accelerationStructurePrebuildInfo.resultDataMaxSize; + SLANG_RETURN_ON_FAIL(gDevice->createAccelerationStructure(createDesc, gTLAS.writeRef())); + + auto commandBuffer = gTransientHeaps[0]->createCommandBuffer(); + auto encoder = commandBuffer->encodeRayTracingCommands(); + IAccelerationStructure::BuildDesc buildDesc = {}; + buildDesc.dest = gTLAS; + buildDesc.inputs = accelerationStructureBuildInputs; + buildDesc.scratchData = scratchBuffer->getDeviceAddress(); + encoder->buildAccelerationStructure(buildDesc, 0, nullptr); + encoder->endEncoding(); + commandBuffer->close(); + gQueue->executeCommandBuffer(commandBuffer); + gQueue->wait(); + } + + IBufferResource::Desc fullScreenVertexBufferDesc; + fullScreenVertexBufferDesc.type = IResource::Type::Buffer; + fullScreenVertexBufferDesc.sizeInBytes = + FullScreenTriangle::kVertexCount * sizeof(FullScreenTriangle::Vertex); + fullScreenVertexBufferDesc.defaultState = ResourceState::VertexBuffer; + gFullScreenVertexBuffer = gDevice->createBufferResource( + fullScreenVertexBufferDesc, 
&FullScreenTriangle::kVertices[0]); + if (!gFullScreenVertexBuffer) + return SLANG_FAIL; + + InputElementDesc inputElements[] = { + {"POSITION", 0, Format::RG_Float32, offsetof(FullScreenTriangle::Vertex, position)}, + }; + auto inputLayout = gDevice->createInputLayout(&inputElements[0], SLANG_COUNT_OF(inputElements)); + if (!inputLayout) + return SLANG_FAIL; + + ComPtr shaderProgram; + SLANG_RETURN_ON_FAIL(loadShaderProgram(gDevice, PipelineType::Graphics, shaderProgram.writeRef())); + GraphicsPipelineStateDesc desc; + desc.inputLayout = inputLayout; + desc.program = shaderProgram; + desc.framebufferLayout = gFramebufferLayout; + gPresentPipelineState = gDevice->createGraphicsPipelineState(desc); + if (!gPresentPipelineState) + return SLANG_FAIL; + + ComPtr rayTracingProgram; + SLANG_RETURN_ON_FAIL( + loadShaderProgram(gDevice, PipelineType::RayTracing, rayTracingProgram.writeRef())); + RayTracingPipelineStateDesc rtpDesc = {}; + rtpDesc.program = rayTracingProgram; + rtpDesc.hitGroupCount = 2; + HitGroupDesc hitGroups[2]; + hitGroups[0].closestHitEntryPoint = "closestHitShader"; + hitGroups[1].closestHitEntryPoint = "shadowRayHitShader"; + rtpDesc.hitGroups = hitGroups; + rtpDesc.maxRayPayloadSize = 64; + rtpDesc.maxRecursion = 2; + rtpDesc.shaderTableHitGroupCount = 2; + int32_t shaderTable[] = {0, 1}; + rtpDesc.shaderTableHitGroupIndices = shaderTable; + SLANG_RETURN_ON_FAIL( + gDevice->createRayTracingPipelineState(rtpDesc, gRenderPipelineState.writeRef())); + if (!gRenderPipelineState) + return SLANG_FAIL; + + createResultTexture(); + return SLANG_OK; +} + +void createResultTexture() +{ + ITextureResource::Desc resultTextureDesc = {}; + resultTextureDesc.type = IResource::Type::Texture2D; + resultTextureDesc.numMipLevels = 1; + resultTextureDesc.size.width = windowWidth; + resultTextureDesc.size.height = windowHeight; + resultTextureDesc.size.depth = 1; + resultTextureDesc.defaultState = ResourceState::UnorderedAccess; + resultTextureDesc.format = 
Format::RGBA_Float16; + gResultTexture = gDevice->createTextureResource(resultTextureDesc); + IResourceView::Desc resultUAVDesc = {}; + resultUAVDesc.format = resultTextureDesc.format; + resultUAVDesc.type = IResourceView::Type::UnorderedAccess; + gResultTextureUAV = gDevice->createTextureView(gResultTexture, resultUAVDesc); +} + +virtual void windowSizeChanged() override +{ + WindowedAppBase::windowSizeChanged(); + createResultTexture(); +} + +glm::vec3 getVectorFromSphericalAngles(float theta, float phi) +{ + auto sinTheta = sin(theta); + auto cosTheta = cos(theta); + auto sinPhi = sin(phi); + auto cosPhi = cos(phi); + return glm::vec3(-sinTheta * cosPhi, sinPhi, -cosTheta * cosPhi); +} +void updateUniforms() +{ + gUniforms.screenWidth = (float)windowWidth; + gUniforms.screenHeight = (float)windowHeight; + if (!lastTime) + lastTime = getCurrentTime(); + uint64_t currentTime = getCurrentTime(); + float deltaTime = float(double(currentTime - lastTime) / double(getTimerFrequency())); + lastTime = currentTime; + + auto camDir = + getVectorFromSphericalAngles(cameraOrientationAngles[0], cameraOrientationAngles[1]); + auto camUp = getVectorFromSphericalAngles( + cameraOrientationAngles[0], cameraOrientationAngles[1] + glm::pi() * 0.5f); + auto camRight = glm::cross(camDir, camUp); + + glm::vec3 movement = glm::vec3(0); + if (wPressed) + movement += camDir; + if (sPressed) + movement -= camDir; + if (aPressed) + movement -= camRight; + if (dPressed) + movement += camRight; + + cameraPosition += deltaTime * translationScale * movement; + + memcpy(gUniforms.cameraDir, &camDir, sizeof(float) * 3); + memcpy(gUniforms.cameraUp, &camUp, sizeof(float) * 3); + memcpy(gUniforms.cameraRight, &camRight, sizeof(float) * 3); + memcpy(gUniforms.cameraPosition, &cameraPosition, sizeof(float) * 3); + auto lightDir = glm::normalize(glm::vec3(1.0f, 3.0f, 2.0f)); + memcpy(gUniforms.lightDir, &lightDir, sizeof(float) * 3); +} + +virtual void renderFrame(int frameBufferIndex) override +{ + 
updateUniforms(); + { + ComPtr renderCommandBuffer = + gTransientHeaps[frameBufferIndex]->createCommandBuffer(); + auto renderEncoder = renderCommandBuffer->encodeRayTracingCommands(); + IShaderObject* rootObject = nullptr; + renderEncoder->bindPipeline(gRenderPipelineState, &rootObject); + auto cursor = ShaderCursor(rootObject); + cursor["resultTexture"].setResource(gResultTextureUAV); + cursor["uniforms"].setData(&gUniforms, sizeof(Uniforms)); + cursor["sceneBVH"].setResource(gTLAS); + cursor["primitiveBuffer"].setResource(gPrimitiveBufferSRV); + renderEncoder->dispatchRays(nullptr, windowWidth, windowHeight, 1); + renderEncoder->endEncoding(); + renderCommandBuffer->close(); + gQueue->executeCommandBuffer(renderCommandBuffer); + } + + { + ComPtr presentCommandBuffer = + gTransientHeaps[frameBufferIndex]->createCommandBuffer(); + auto presentEncoder = presentCommandBuffer->encodeRenderCommands( + gRenderPass, gFramebuffers[frameBufferIndex]); + gfx::Viewport viewport = {}; + viewport.maxZ = 1.0f; + viewport.extentX = (float)windowWidth; + viewport.extentY = (float)windowHeight; + presentEncoder->setViewportAndScissor(viewport); + auto rootObject = presentEncoder->bindPipeline(gPresentPipelineState); + auto cursor = ShaderCursor(rootObject->getEntryPoint(1)); + cursor["t"].setResource(gResultTextureUAV); + presentEncoder->setVertexBuffer( + 0, gFullScreenVertexBuffer, sizeof(FullScreenTriangle::Vertex)); + presentEncoder->setPrimitiveTopology(PrimitiveTopology::TriangleList); + presentEncoder->draw(3); + presentEncoder->endEncoding(); + presentCommandBuffer->close(); + gQueue->executeCommandBuffer(presentCommandBuffer); + } + // With that, we are done drawing for one frame, and ready for the next. + // + gSwapchain->present(); +} + +}; + +// This macro instantiates an appropriate main function to +// run the application defined above. 
+PLATFORM_UI_MAIN(innerMain) diff --git a/examples/ray-tracing-pipeline/shaders.slang b/examples/ray-tracing-pipeline/shaders.slang new file mode 100644 index 000000000..77193f08e --- /dev/null +++ b/examples/ray-tracing-pipeline/shaders.slang @@ -0,0 +1,108 @@ +// shaders.slang + +struct Uniforms +{ + float screenWidth, screenHeight; + float focalLength, frameHeight; + float4 cameraDir; + float4 cameraUp; + float4 cameraRight; + float4 cameraPosition; + float4 lightDir; +}; + +struct Primitive +{ + float4 data0; + float4 color; + float3 getNormal() { return data0.xyz; } + float3 getColor() { return color.xyz; } +}; + +struct RayPayload +{ + float4 color; +}; + +uniform RWTexture2D resultTexture; +uniform RaytracingAccelerationStructure sceneBVH; +uniform StructuredBuffer primitiveBuffer; +uniform Uniforms uniforms; + +[shader("raygeneration")] +void rayGenShader() +{ + uint2 threadIdx = DispatchRaysIndex().xy; + if (threadIdx.x >= (int)uniforms.screenWidth) return; + if (threadIdx.y >= (int)uniforms.screenHeight) return; + + float frameWidth = uniforms.screenWidth / uniforms.screenHeight * uniforms.frameHeight; + float imageY = (threadIdx.y / uniforms.screenHeight - 0.5f) * uniforms.frameHeight; + float imageX = (threadIdx.x / uniforms.screenWidth - 0.5f) * frameWidth; + float imageZ = uniforms.focalLength; + float3 rayDir = normalize(uniforms.cameraDir.xyz*imageZ - uniforms.cameraUp.xyz * imageY + uniforms.cameraRight.xyz * imageX); + + // Trace the ray. 
+ RayDesc ray; + ray.Origin = uniforms.cameraPosition.xyz; + ray.Direction = rayDir; + ray.TMin = 0.001; + ray.TMax = 10000.0; + RayPayload payload = { float4(0, 0, 0, 0) }; + TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload); + + resultTexture[threadIdx.xy] = payload.color; +} + +[shader("miss")] +void missShader(inout RayPayload payload) +{ + payload.color = float4(0, 0, 0, 1); +} + +[shader("closesthit")] +void closestHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ + float3 hitLocation = WorldRayOrigin() + WorldRayDirection() * RayTCurrent(); + float3 shadowRayDir = uniforms.lightDir.xyz; + + RayDesc ray; + ray.Origin = hitLocation; + ray.Direction = shadowRayDir; + ray.TMin = 0.001; + ray.TMax = 10000.0; + RayPayload shadowPayload = { float4(0, 0, 0, 0) }; + TraceRay(sceneBVH, RAY_FLAG_ACCEPT_FIRST_HIT_AND_END_SEARCH, ~0, 1, 0, 0, ray, shadowPayload); + float shadow = 1.0 - shadowPayload.color.x; + + let primitiveIndex = PrimitiveIndex(); + float3 normal = primitiveBuffer[primitiveIndex].getNormal(); + float3 color = primitiveBuffer[primitiveIndex].getColor(); + float ndotl = max(0.0, shadow * dot(normal, uniforms.lightDir.xyz)); + float intensity = ndotl * 0.7 + 0.3; + payload.color = float4(color * intensity, 1.0f); +} + +[shader("closesthit")] +void shadowRayHitShader(inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr) +{ + payload.color = float4(1.0, 1.0, 1.0, 1.0); +} + +/// Vertex and fragment shader for displaying the final image. 
+ +[shader("vertex")] +float4 vertexMain(float2 position : POSITION) + : SV_Position +{ + return float4(position, 0.5, 1.0); +} + +[shader("fragment")] +float4 fragmentMain( + float4 sv_position : SV_Position, + uniform RWTexture2D t) + : SV_Target +{ + return t.Load(sv_position.xy); +} diff --git a/premake5.lua b/premake5.lua index 2ed40ba16..71af44d4a 100644 --- a/premake5.lua +++ b/premake5.lua @@ -653,6 +653,7 @@ example "hello-world" example "triangle" example "ray-tracing" +example "ray-tracing-pipeline" example "gpu-printing" kind "ConsoleApp" diff --git a/slang-gfx.h b/slang-gfx.h index f76788b03..f4a70d25c 100644 --- a/slang-gfx.h +++ b/slang-gfx.h @@ -919,10 +919,20 @@ struct RayTracingPipelineFlags }; }; +struct HitGroupDesc +{ + const char* closestHitEntryPoint = nullptr; + const char* anyHitEntryPoint = nullptr; + const char* intersectionEntryPoint = nullptr; +}; + struct RayTracingPipelineStateDesc { IShaderProgram* program = nullptr; - + int32_t hitGroupCount; + const HitGroupDesc* hitGroups; + int32_t shaderTableHitGroupCount; + int32_t* shaderTableHitGroupIndices; int maxRecursion; int maxRayPayloadSize; RayTracingPipelineFlags::Enum flags; @@ -1191,6 +1201,17 @@ public: IAccelerationStructure* const* structures, AccessFlag::Enum sourceAccess, AccessFlag::Enum destAccess) = 0; + + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* state, IShaderObject** outRootObject) = 0; + /// Issues a dispatch command to start a ray tracing workload with the bound ray tracing pipeline. + /// `rayGenShaderName` specifies the name of the ray generation shader to launch. Pass nullptr for + /// the first ray generation shader defined in the pipeline bound via `bindPipeline`.
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) = 0; }; #define SLANG_UUID_IRayTracingCommandEncoder \ { \ diff --git a/slang.sln b/slang.sln index 1551b7110..c0e80b982 100644 --- a/slang.sln +++ b/slang.sln @@ -35,6 +35,8 @@ Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "model-viewer", "build\visua EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ray-tracing", "build\visual-studio\ray-tracing\ray-tracing.vcxproj", "{71AC0F50-5DFD-FA91-8661-E95372118EFB}" EndProject +Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "ray-tracing-pipeline", "build\visual-studio\ray-tracing-pipeline\ray-tracing-pipeline.vcxproj", "{17BA8E32-034E-84DA-6C12-DE8E58C5BECC}" +EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-object", "build\visual-studio\shader-object\shader-object.vcxproj", "{25512BFB-1138-EDF2-BA88-5310A64E6659}" EndProject Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "shader-toy", "build\visual-studio\shader-toy\shader-toy.vcxproj", "{0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6}" @@ -195,6 +197,14 @@ Global {71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|Win32.Build.0 = Release|Win32 {71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|x64.ActiveCfg = Release|x64 {71AC0F50-5DFD-FA91-8661-E95372118EFB}.Release|x64.Build.0 = Release|x64 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|Win32.ActiveCfg = Debug|Win32 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|Win32.Build.0 = Debug|Win32 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|x64.ActiveCfg = Debug|x64 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Debug|x64.Build.0 = Debug|x64 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|Win32.ActiveCfg = Release|Win32 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|Win32.Build.0 = Release|Win32 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|x64.ActiveCfg = Release|x64 + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC}.Release|x64.Build.0 = Release|x64 
{25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.ActiveCfg = Debug|Win32 {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|Win32.Build.0 = Debug|Win32 {25512BFB-1138-EDF2-BA88-5310A64E6659}.Debug|x64.ActiveCfg = Debug|x64 @@ -293,6 +303,7 @@ Global {010BE414-ED5B-CF56-16C0-BD18027062C0} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {2F8724C6-1BC3-2730-84D5-3F277030D04A} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {71AC0F50-5DFD-FA91-8661-E95372118EFB} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} + {17BA8E32-034E-84DA-6C12-DE8E58C5BECC} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {25512BFB-1138-EDF2-BA88-5310A64E6659} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {0FC5DE93-FBEA-A8FA-E430-2EC6D0F5CDC6} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} {3BB99068-27C9-3C39-9082-A1577CB12BD2} = {EB5FC2C6-D72D-B6CC-C0C1-26F3AC2E9231} diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 24a1fd93e..7e436d28d 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -10,6 +10,7 @@ #include "../d3d/d3d-swapchain.h" #include "core/slang-blob.h" #include "core/slang-basic.h" +#include "core/slang-chunked-list.h" // In order to use the Slang API, we need to include its header @@ -123,8 +124,6 @@ public: const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( const ComputePipelineStateDesc& desc, IPipelineState** outState) override; - virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( - const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outState) override; @@ -156,6 +155,8 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL createAccelerationStructure( const IAccelerationStructure::CreateDesc& desc, IAccelerationStructure** outView) override; + virtual SLANG_NO_THROW Result SLANG_MCALL 
createRayTracingPipelineState( + const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; #endif public: @@ -193,6 +194,7 @@ public: virtual void setRootDescriptorTable(int index, D3D12_GPU_DESCRIPTOR_HANDLE BaseDescriptor) = 0; virtual void setRootSignature(ID3D12RootSignature* rootSignature) = 0; virtual void setRootConstants(Index rootParamIndex, Index dstOffsetIn32BitValues, Index countOf32BitValues, void const* srcData) = 0; + virtual void setPipelineState(PipelineStateBase* pipelineState) = 0; }; class BufferResourceImpl: public gfx::BufferResource @@ -340,6 +342,31 @@ public: } }; +#if SLANG_GFX_HAS_DXR_SUPPORT + class RayTracingPipelineStateImpl : public PipelineStateBase + { + public: + ComPtr m_stateObject; + D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {}; + Dictionary m_mapRayGenShaderNameToShaderTableIndex; + // Shader Tables for each ray-tracing stage stored in GPU memory. + RefPtr m_rayGenShaderTable; + RefPtr m_hitgroupShaderTable; + RefPtr m_missShaderTable; + void init(const RayTracingPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::RayTracing; + pipelineDesc.rayTracing = inDesc; + initializeBase(pipelineDesc); + } + Result createShaderTables( + D3D12Device* device, + slang::IComponentType* slangProgram, + const RayTracingPipelineStateDesc& desc); + }; +#endif + class QueryPoolImpl : public IQueryPool, public ComObject { public: @@ -461,6 +488,11 @@ public: { m_commandList->SetGraphicsRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues)); } + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast(pipeline); + m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get()); + } GraphicsSubmitter(ID3D12GraphicsCommandList* commandList): m_commandList(commandList) @@ -492,7 +524,11 @@ public: { m_commandList->SetComputeRoot32BitConstants(UINT(rootParamIndex), 
UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues)); } - + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast(pipeline); + m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get()); + } ComputeSubmitter(ID3D12GraphicsCommandList* commandList) : m_commandList(commandList) { @@ -568,6 +604,7 @@ public: { uint64_t waitValue; HANDLE fenceEvent; + ID3D12Fence* fence = nullptr; }; ShortList m_waitInfos; @@ -585,7 +622,7 @@ public: m_waitInfos[i].fenceEvent = CreateEventEx( nullptr, false, - CREATE_EVENT_INITIAL_SET | CREATE_EVENT_MANUAL_RESET, + 0, EVENT_ALL_ACCESS); } return m_waitInfos[queueIndex]; @@ -666,7 +703,7 @@ public: ID3D12GraphicsCommandList* m_d3dCmdList; ID3D12GraphicsCommandList* m_preCmdList = nullptr; - RefPtr m_currentPipeline; + RefPtr m_currentPipeline; static int getBindPointIndex(PipelineType type) { @@ -690,13 +727,14 @@ public: m_d3dCmdList = m_commandBuffer->m_cmdList; m_renderer = commandBuffer->m_renderer; m_transientHeap = commandBuffer->m_transientHeap; + m_device = commandBuffer->m_renderer->m_device; } void endEncodingImpl() { m_isOpen = false; } Result bindPipelineImpl(IPipelineState* pipelineState, IShaderObject** outRootObject) { - m_currentPipeline = static_cast(pipelineState); + m_currentPipeline = static_cast(pipelineState); auto rootObject = &m_commandBuffer->m_rootShaderObject; SLANG_RETURN_ON_FAIL(rootObject->reset( m_renderer, @@ -707,7 +745,11 @@ public: return SLANG_OK; } - Result _bindRenderState(Submitter* submitter); + /// Specializes the pipeline according to current root-object argument values, + /// applies the root object bindings and binds the pipeline state. + /// The newly specialized pipeline is held alive by the pipeline cache so users of + /// `newPipeline` do not need to maintain its lifespan.
+ Result _bindRenderState(Submitter* submitter, RefPtr& newPipeline); }; struct DescriptorTable @@ -2956,7 +2998,6 @@ public: { PipelineCommandEncoder::init(cmdBuffer); m_preCmdList = nullptr; - m_device = renderer->m_device; m_renderPass = renderPass; m_framebuffer = framebuffer; m_transientHeap = transientHeap; @@ -3174,7 +3215,8 @@ public: // Submit - setting for graphics { GraphicsSubmitter submitter(m_d3dCmdList); - if(SLANG_FAILED(_bindRenderState(&submitter))) + RefPtr newPipeline; + if(SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) { assert(!"Failed to bind render state"); } @@ -3314,7 +3356,6 @@ public: { PipelineCommandEncoder::init(cmdBuffer); m_preCmdList = nullptr; - m_device = renderer->m_device; m_transientHeap = transientHeap; m_currentPipeline = nullptr; } @@ -3330,7 +3371,8 @@ public: // Submit binding for compute { ComputeSubmitter submitter(m_d3dCmdList); - if(SLANG_FAILED(_bindRenderState(&submitter))) + RefPtr newPipeline; + if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) { assert(!"Failed to bind render state"); } @@ -3402,12 +3444,15 @@ public: } #if SLANG_GFX_HAS_DXR_SUPPORT - class RayTracingCommandEncoderImpl : public IRayTracingCommandEncoder + class RayTracingCommandEncoderImpl + : public IRayTracingCommandEncoder + , public PipelineCommandEncoder { public: CommandBufferImpl* m_commandBuffer; void init(D3D12Device* renderer, CommandBufferImpl* commandBuffer) { + PipelineCommandEncoder::init(commandBuffer); m_commandBuffer = commandBuffer; } virtual SLANG_NO_THROW void SLANG_MCALL buildAccelerationStructure( @@ -3434,6 +3479,13 @@ public: IAccelerationStructure* const* structures, AccessFlag::Enum sourceAccess, AccessFlag::Enum destAccess) override; + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) override; 
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() {} virtual SLANG_NO_THROW void SLANG_MCALL writeTimestamp(IQueryPool* pool, SlangInt index) override @@ -3533,8 +3585,7 @@ public: auto transientHeap = cmdImpl->m_transientHeap; auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex); waitInfo.waitValue = m_fenceValue; - ResetEvent(waitInfo.fenceEvent); - m_fence->SetEventOnCompletion(m_fenceValue, waitInfo.fenceEvent); + waitInfo.fence = m_fence; } m_d3dQueue->Signal(m_fence, m_fenceValue); ResetEvent(globalWaitHandle); @@ -3722,8 +3773,13 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchr Array waitHandles; for (auto& waitInfo : m_waitInfos) { - if (waitInfo.waitValue != 0) + if (waitInfo.waitValue == 0) + continue; + if (waitInfo.fence) + { + waitInfo.fence->SetEventOnCompletion(waitInfo.waitValue, waitInfo.fenceEvent); waitHandles.add(waitInfo.fenceEvent); + } } WaitForMultipleObjects((DWORD)waitHandles.getCount(), waitHandles.getBuffer(), TRUE, INFINITE); m_viewHeap.deallocateAll(); @@ -3763,16 +3819,15 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe return SLANG_OK; } -Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter) +Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter, RefPtr& newPipeline) { - RefPtr newPipeline; RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootShaderObject; m_renderer->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline); - PipelineStateImpl* newPipelineImpl = static_cast(newPipeline.Ptr()); + PipelineStateBase* newPipelineImpl = static_cast(newPipeline.Ptr()); auto commandList = m_d3dCmdList; auto pipelineTypeIndex = (int)newPipelineImpl->desc.type; auto programImpl = static_cast(newPipelineImpl->m_program.Ptr()); - commandList->SetPipelineState(newPipelineImpl->m_pipelineState); + submitter->setPipelineState(newPipelineImpl); 
submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature); RefPtr specializedRootLayout; SLANG_RETURN_ON_FAIL(rootObjectImpl->getSpecializedLayout(specializedRootLayout.writeRef())); @@ -5469,11 +5524,6 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i return SLANG_OK; } -Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState) -{ - return SLANG_E_NOT_AVAILABLE; -} - Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device) { // Translate query type. @@ -5801,7 +5851,290 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::memoryBarrier m_commandBuffer->m_cmdList4->ResourceBarrier((UINT)count, barriers.getArrayView().getBuffer()); } +void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline( + IPipelineState* state, IShaderObject** outRootObject) +{ + bindPipelineImpl(state, outRootObject); +} + +void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) +{ + RefPtr newPipeline; + PipelineStateBase* pipeline = m_currentPipeline.Ptr(); + { + struct RayTracingSubmitter : public ComputeSubmitter + { + ID3D12GraphicsCommandList4* m_cmdList4; + RayTracingSubmitter(ID3D12GraphicsCommandList4* cmdList4) + : ComputeSubmitter(cmdList4), m_cmdList4(cmdList4) + { + } + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast(pipeline); + m_cmdList4->SetPipelineState1(pipelineImpl->m_stateObject.get()); + } + }; + RayTracingSubmitter submitter(m_commandBuffer->m_cmdList4); + if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) + { + assert(!"Failed to bind render state"); + } + if (newPipeline) + pipeline = newPipeline.Ptr(); + } + auto pipelineImpl = static_cast(pipeline); + auto dispatchDesc = pipelineImpl->m_dispatchDesc; + 
int32_t rayGenShaderOffset = 0; + if (rayGenShaderName) + { + rayGenShaderOffset = + pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() * + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + } + dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset; + dispatchDesc.Width = (UINT)width; + dispatchDesc.Height = (UINT)height; + dispatchDesc.Depth = (UINT)depth; + m_commandBuffer->m_cmdList4->DispatchRays(&dispatchDesc); +} + +Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState) +{ + if (!m_device5) + { + return SLANG_E_NOT_AVAILABLE; + } + + RefPtr pipelineStateImpl = new RayTracingPipelineStateImpl(); + pipelineStateImpl->init(inDesc); + + auto program = static_cast(inDesc.program); + auto slangProgram = program->slangProgram; + auto programLayout = slangProgram->getLayout(); + + if (!program->m_rootObjectLayout->m_rootSignature) + { + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; + } + List subObjects; + ChunkedList dxilLibraries; + ChunkedList hitGroups; + ChunkedList> codeBlobs; + ComPtr diagnostics; + ChunkedList stringPool; + int32_t rayGenIndex = 0; + for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + { + ComPtr codeBlob; + auto compileResult = + slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef()); + if (diagnostics.get()) + { + getDebugCallback()->handleMessage( + compileResult == SLANG_OK ? 
DebugMessageType::Warning : DebugMessageType::Error, + DebugMessageSource::Slang, + (char*)diagnostics->getBufferPointer()); + } + SLANG_RETURN_ON_FAIL(compileResult); + codeBlobs.add(codeBlob); + D3D12_DXIL_LIBRARY_DESC library = {}; + library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();; + library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer(); + + D3D12_STATE_SUBOBJECT dxilSubObject = {}; + dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; + dxilSubObject.pDesc = dxilLibraries.add(library); + subObjects.add(dxilSubObject); + + auto entryPointLayout = programLayout->getEntryPointByIndex(i); + switch (entryPointLayout->getStage()) + { + case SLANG_STAGE_RAY_GENERATION: + pipelineStateImpl + ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] = + rayGenIndex; + rayGenIndex++; + break; + default: + break; + } + } + auto getWStr = [&](const char* name) + { + String str = String(name); + auto wstr = str.toWString(); + return stringPool.add(wstr)->begin(); + }; + for (int i = 0; i < inDesc.hitGroupCount; i++) + { + auto hitGroup = inDesc.hitGroups[i]; + D3D12_HIT_GROUP_DESC hitGroupDesc = {}; + hitGroupDesc.Type = hitGroup.intersectionEntryPoint == nullptr + ? 
D3D12_HIT_GROUP_TYPE_TRIANGLES + : D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE; + + if (hitGroup.anyHitEntryPoint) + { + hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint); + } + if (hitGroup.closestHitEntryPoint) + { + hitGroupDesc.ClosestHitShaderImport = getWStr(hitGroup.closestHitEntryPoint); + } + if (hitGroup.intersectionEntryPoint) + { + hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint); + } + StringBuilder hitGroupName; + hitGroupName << "hitgroup_" << i; + hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer()); + + D3D12_STATE_SUBOBJECT hitGroupSubObject = {}; + hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; + hitGroupSubObject.pDesc = hitGroups.add(hitGroupDesc); + subObjects.add(hitGroupSubObject); + } + + D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {}; + // According to the DXR spec, fixed-function triangle intersections must use float2 as ray attributes + // that define the barycentric coordinates at the intersection.
+ shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2; + shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize; + D3D12_STATE_SUBOBJECT shaderConfigSubObject = {}; + shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG; + shaderConfigSubObject.pDesc = &shaderConfig; + subObjects.add(shaderConfigSubObject); + + D3D12_GLOBAL_ROOT_SIGNATURE globalSignatureDesc = {}; + globalSignatureDesc.pGlobalRootSignature = program->m_rootObjectLayout->m_rootSignature.get(); + D3D12_STATE_SUBOBJECT globalSignatureSubobject = {}; + globalSignatureSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE; + globalSignatureSubobject.pDesc = &globalSignatureDesc; + subObjects.add(globalSignatureSubobject); + + D3D12_RAYTRACING_PIPELINE_CONFIG pipelineConfig = {}; + pipelineConfig.MaxTraceRecursionDepth = inDesc.maxRecursion; + D3D12_STATE_SUBOBJECT pipelineConfigSubobject = {}; + pipelineConfigSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG; + pipelineConfigSubobject.pDesc = &pipelineConfig; + subObjects.add(pipelineConfigSubobject); + + D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; + rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE; + rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount(); + rtpsoDesc.pSubobjects = subObjects.getBuffer(); + SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); + + SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc)); + + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; +} + +Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables( + D3D12Device* device, + slang::IComponentType* slangProgram, + const RayTracingPipelineStateDesc& desc) +{ + ComPtr stateObjectProperties; + m_stateObject->QueryInterface(stateObjectProperties.writeRef()); + auto programLayout = slangProgram->getLayout(); + struct ShaderIdentifier { uint32_t 
data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; }; + List rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers; + for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + { + auto entryPointLayout = programLayout->getEntryPointByIndex(i); + ShaderIdentifier identifier; + switch (entryPointLayout->getStage()) + { + case SLANG_STAGE_RAY_GENERATION: + memcpy( + &identifier, + stateObjectProperties->GetShaderIdentifier( + String(entryPointLayout->getName()).toWString().begin()), + sizeof(ShaderIdentifier)); + rayGenIdentifiers.add(identifier); + break; + case SLANG_STAGE_MISS: + memcpy( + &identifier, + stateObjectProperties->GetShaderIdentifier( + String(entryPointLayout->getName()).toWString().begin()), + sizeof(ShaderIdentifier)); + missIdentifiers.add(identifier); + break; + default: + break; + } + } + for (int i = 0; i < desc.shaderTableHitGroupCount; i++) + { + StringBuilder hitgroupName; + hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i]; + ShaderIdentifier hitgroupIdentifier; + memcpy( + &hitgroupIdentifier, + stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()), + sizeof(ShaderIdentifier)); + hitgroupIdentifiers.add(hitgroupIdentifier); + } + + auto createShaderTableResource = [&](ArrayView content, + RefPtr& outResource) -> Result + { + IBufferResource::Desc bufferDesc = {}; + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.defaultState = ResourceState::ShaderResource; + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::CopySource, + ResourceState::UnorderedAccess, + ResourceState::ShaderResource); + bufferDesc.elementSize = 0; + bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier); + bufferDesc.format = Format::Unknown; + ComPtr shaderTableResource; + SLANG_RETURN_ON_FAIL(device->createBufferResource( + bufferDesc, content.getBuffer(), shaderTableResource.writeRef())); + outResource = static_cast(shaderTableResource.get()); + return SLANG_OK; + }; + 
+ if (desc.shaderTableHitGroupCount) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable)); + m_dispatchDesc.HitGroupTable.SizeInBytes = + (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount; + m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress(); + } + if (rayGenIdentifiers.getCount()) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable)); + m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress(); + } + if (missIdentifiers.getCount()) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable)); + m_dispatchDesc.MissShaderTable.SizeInBytes = + (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount(); + m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress(); + } + return SLANG_OK; +} + #endif // SLANG_GFX_HAS_DXR_SUPPORT + Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* resourceView) { if (offset.bindingRangeIndex < 0) diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp index 067581559..50cacc6c2 100644 --- a/tools/gfx/debug-layer.cpp +++ b/tools/gfx/debug-layer.cpp @@ -705,6 +705,7 @@ DebugCommandBuffer::DebugCommandBuffer() m_renderCommandEncoder.commandBuffer = this; m_computeCommandEncoder.commandBuffer = this; m_resourceCommandEncoder.commandBuffer = this; + m_rayTracingCommandEncoder.commandBuffer = this; } void DebugCommandBuffer::encodeRenderCommands( @@ -1084,6 +1085,25 @@ void DebugRayTracingCommandEncoder::memoryBarrier( baseObject->memoryBarrier(count, innerAS.getBuffer(), 
sourceAccess, destAccess); } +void DebugRayTracingCommandEncoder::bindPipeline( + IPipelineState* state, IShaderObject** outRootObject) +{ + SLANG_GFX_API_FUNC; + auto innerPipeline = getInnerObj(state); + baseObject->bindPipeline(innerPipeline, commandBuffer->rootObject.baseObject.writeRef()); + *outRootObject = &commandBuffer->rootObject; +} + +void DebugRayTracingCommandEncoder::dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) +{ + SLANG_GFX_API_FUNC; + baseObject->dispatchRays(rayGenShaderName, width, height, depth); +} + const ICommandQueue::Desc& DebugCommandQueue::getDesc() { SLANG_GFX_API_FUNC; diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h index 7433db966..c7de48149 100644 --- a/tools/gfx/debug-layer.h +++ b/tools/gfx/debug-layer.h @@ -351,6 +351,13 @@ public: IAccelerationStructure* const* structures, AccessFlag::Enum sourceAccess, AccessFlag::Enum destAccess) override; + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) override; public: DebugCommandBuffer* commandBuffer; diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 2eb19b6e9..bb80c4f53 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -605,6 +605,14 @@ Result RendererBase::maybeSpecializePipeline( pipelineDesc, specializedPipelineComPtr.writeRef())); break; } + case PipelineType::RayTracing: + { + auto pipelineDesc = currentPipeline->desc.rayTracing; + pipelineDesc.program = specializedProgram; + SLANG_RETURN_ON_FAIL(createRayTracingPipelineState( + pipelineDesc, specializedPipelineComPtr.writeRef())); + break; + } default: break; } diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 1f0a3eaab..31a7566a2 100644 --- 
a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -766,7 +766,7 @@ public: auto bindingRangeIndex = offset.bindingRangeIndex; auto bindingRange = layout->getBindingRange(bindingRangeIndex); - auto objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex; + Slang::Index objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex; if (objectIndex >= m_userProvidedSpecializationArgs.getCount()) m_userProvidedSpecializationArgs.setCount(objectIndex + 1); if (!m_userProvidedSpecializationArgs[objectIndex]) @@ -816,7 +816,7 @@ public: subObjectIndexInRange++) { ExtendedShaderObjectTypeList typeArgs; - auto objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange; + Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange; auto subObject = m_objects[objectIndex]; if (!subObject) @@ -932,9 +932,19 @@ public: PipelineType type; GraphicsPipelineStateDesc graphics; ComputePipelineStateDesc compute; + RayTracingPipelineStateDesc rayTracing; ShaderProgramBase* getProgram() { - return static_cast(type == PipelineType::Compute ? compute.program : graphics.program); + switch (type) + { + case PipelineType::Compute: + return static_cast(compute.program); + case PipelineType::Graphics: + return static_cast(graphics.program); + case PipelineType::RayTracing: + return static_cast(rayTracing.program); + } + return nullptr; } } desc; @@ -1105,6 +1115,8 @@ public: public: ExtendedShaderObjectTypeList specializationArgs; // Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary. + // The newly specialized pipeline is held alive by the pipeline cache so users of `outNewPipeline` do not + // need to maintain its lifespan. 
Result maybeSpecializePipeline( PipelineStateBase* currentPipeline, ShaderObjectBase* rootObject, diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index bc0271aa6..592cbaac1 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -1266,7 +1266,7 @@ public: vkPushConstantRange.size = ordinaryDataSize; vkPushConstantRange.stageFlags = VK_SHADER_STAGE_ALL; // TODO: be more clever - while(m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex) + while((uint32_t)m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex) { VkPushConstantRange emptyRange = { 0 }; m_ownPushConstantRanges.add(emptyRange); @@ -2995,7 +2995,7 @@ public: case slang::BindingType::ConstantBuffer: { BindingOffset objOffset = rangeOffset; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // Binding a constant buffer sub-object is simple enough: // we just call `bindAsConstantBuffer` on it to bind @@ -3016,7 +3016,7 @@ public: case slang::BindingType::ParameterBlock: { BindingOffset objOffset = rangeOffset; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // The case for `ParameterBlock` is not that different // from `ConstantBuffer`, except that we call `bindAsParameterBlock` @@ -3047,7 +3047,7 @@ public: // SimpleBindingOffset objOffset = rangeOffset.pending; SimpleBindingOffset objStride = rangeStride.pending; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // An existential-type sub-object is always bound just as a value, // which handles its nested bindings and descriptor sets, but @@ -4258,6 +4258,25 @@ public: _memoryBarrier(count, structures, srcAccess, destAccess); } + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override + { + SLANG_UNUSED(pipeline); + SLANG_UNUSED(outRootObject); + } + + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + 
int32_t width, + int32_t height, + int32_t depth) override + { + SLANG_UNUSED(rayGenShaderName); + SLANG_UNUSED(width); + SLANG_UNUSED(height); + SLANG_UNUSED(depth); + } + virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override { } -- cgit v1.2.3