summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-07-28 12:24:12 -0700
committerGitHub <noreply@github.com>2021-07-28 12:24:12 -0700
commitc6f6ce12ec522b193b42bcd12d3a2540c7a6ff92 (patch)
treed5f77aa02df88c71ef4f898db40434bf4c1f3010 /tools
parent23d406f8a3b325f91fecd9ad52bd510ded5f49a7 (diff)
Experimental DXR1.0 support in gfx. (#1915)
* Experimental DXR1.0 support in gfx. - Add `dispatchRays` command. - Add `createRayTracingPipelineState` method to construct a D3D ray tracing state object from a linked slang program and user specified shader table. Limitations/simplifications: no local root signature support, shader table entries contains only shader identifiers and is specified at pipeline creation time, owned by the pipeline state object. * Root object binding for raytracing pipelines. * `maybeSpecializePipeline` implementation for raytracing pipelines. * Add ray-tracing-pipeline example. * Fixes. * Update README.md * Update comments on the lifespan of specialized pipelines Co-authored-by: Yong He <yhe@nvidia.com> Co-authored-by: jsmall-nvidia <jsmall@nvidia.com>
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp381
-rw-r--r--tools/gfx/debug-layer.cpp20
-rw-r--r--tools/gfx/debug-layer.h7
-rw-r--r--tools/gfx/renderer-shared.cpp8
-rw-r--r--tools/gfx/renderer-shared.h18
-rw-r--r--tools/gfx/vulkan/render-vk.cpp27
6 files changed, 430 insertions, 31 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 24a1fd93e..7e436d28d 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -10,6 +10,7 @@
#include "../d3d/d3d-swapchain.h"
#include "core/slang-blob.h"
#include "core/slang-basic.h"
+#include "core/slang-chunked-list.h"
// In order to use the Slang API, we need to include its header
@@ -123,8 +124,6 @@ public:
const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState(
const ComputePipelineStateDesc& desc, IPipelineState** outState) override;
- virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
- const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool(
const IQueryPool::Desc& desc, IQueryPool** outState) override;
@@ -156,6 +155,8 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL createAccelerationStructure(
const IAccelerationStructure::CreateDesc& desc,
IAccelerationStructure** outView) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
+ const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
#endif
public:
@@ -193,6 +194,7 @@ public:
virtual void setRootDescriptorTable(int index, D3D12_GPU_DESCRIPTOR_HANDLE BaseDescriptor) = 0;
virtual void setRootSignature(ID3D12RootSignature* rootSignature) = 0;
virtual void setRootConstants(Index rootParamIndex, Index dstOffsetIn32BitValues, Index countOf32BitValues, void const* srcData) = 0;
+ virtual void setPipelineState(PipelineStateBase* pipelineState) = 0;
};
class BufferResourceImpl: public gfx::BufferResource
@@ -340,6 +342,31 @@ public:
}
};
+#if SLANG_GFX_HAS_DXR_SUPPORT
+ class RayTracingPipelineStateImpl : public PipelineStateBase
+ {
+ public:
+ ComPtr<ID3D12StateObject> m_stateObject;
+ D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {};
+ Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex;
+ // Shader Tables for each ray-tracing stage stored in GPU memory.
+ RefPtr<BufferResourceImpl> m_rayGenShaderTable;
+ RefPtr<BufferResourceImpl> m_hitgroupShaderTable;
+ RefPtr<BufferResourceImpl> m_missShaderTable;
+ void init(const RayTracingPipelineStateDesc& inDesc)
+ {
+ PipelineStateDesc pipelineDesc;
+ pipelineDesc.type = PipelineType::RayTracing;
+ pipelineDesc.rayTracing = inDesc;
+ initializeBase(pipelineDesc);
+ }
+ Result createShaderTables(
+ D3D12Device* device,
+ slang::IComponentType* slangProgram,
+ const RayTracingPipelineStateDesc& desc);
+ };
+#endif
+
class QueryPoolImpl : public IQueryPool, public ComObject
{
public:
@@ -461,6 +488,11 @@ public:
{
m_commandList->SetGraphicsRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues));
}
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline);
+ m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get());
+ }
GraphicsSubmitter(ID3D12GraphicsCommandList* commandList):
m_commandList(commandList)
@@ -492,7 +524,11 @@ public:
{
m_commandList->SetComputeRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues));
}
-
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline);
+ m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get());
+ }
ComputeSubmitter(ID3D12GraphicsCommandList* commandList) :
m_commandList(commandList)
{
@@ -568,6 +604,7 @@ public:
{
uint64_t waitValue;
HANDLE fenceEvent;
+ ID3D12Fence* fence = nullptr;
};
ShortList<QueueWaitInfo, 4> m_waitInfos;
@@ -585,7 +622,7 @@ public:
m_waitInfos[i].fenceEvent = CreateEventEx(
nullptr,
false,
- CREATE_EVENT_INITIAL_SET | CREATE_EVENT_MANUAL_RESET,
+ 0,
EVENT_ALL_ACCESS);
}
return m_waitInfos[queueIndex];
@@ -666,7 +703,7 @@ public:
ID3D12GraphicsCommandList* m_d3dCmdList;
ID3D12GraphicsCommandList* m_preCmdList = nullptr;
- RefPtr<PipelineStateImpl> m_currentPipeline;
+ RefPtr<PipelineStateBase> m_currentPipeline;
static int getBindPointIndex(PipelineType type)
{
@@ -690,13 +727,14 @@ public:
m_d3dCmdList = m_commandBuffer->m_cmdList;
m_renderer = commandBuffer->m_renderer;
m_transientHeap = commandBuffer->m_transientHeap;
+ m_device = commandBuffer->m_renderer->m_device;
}
void endEncodingImpl() { m_isOpen = false; }
Result bindPipelineImpl(IPipelineState* pipelineState, IShaderObject** outRootObject)
{
- m_currentPipeline = static_cast<PipelineStateImpl*>(pipelineState);
+ m_currentPipeline = static_cast<PipelineStateBase*>(pipelineState);
auto rootObject = &m_commandBuffer->m_rootShaderObject;
SLANG_RETURN_ON_FAIL(rootObject->reset(
m_renderer,
@@ -707,7 +745,11 @@ public:
return SLANG_OK;
}
- Result _bindRenderState(Submitter* submitter);
+ /// Specializes the pipeline according to current root-object argument values,
+ /// applys the root object bindings and binds the pipeline state.
+ /// The newly specialized pipeline is held alive by the pipeline cache so users of
+ /// `newPipeline` do not need to maintain its lifespan.
+ Result _bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline);
};
struct DescriptorTable
@@ -2956,7 +2998,6 @@ public:
{
PipelineCommandEncoder::init(cmdBuffer);
m_preCmdList = nullptr;
- m_device = renderer->m_device;
m_renderPass = renderPass;
m_framebuffer = framebuffer;
m_transientHeap = transientHeap;
@@ -3174,7 +3215,8 @@ public:
// Submit - setting for graphics
{
GraphicsSubmitter submitter(m_d3dCmdList);
- if(SLANG_FAILED(_bindRenderState(&submitter)))
+ RefPtr<PipelineStateBase> newPipeline;
+ if(SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
{
assert(!"Failed to bind render state");
}
@@ -3314,7 +3356,6 @@ public:
{
PipelineCommandEncoder::init(cmdBuffer);
m_preCmdList = nullptr;
- m_device = renderer->m_device;
m_transientHeap = transientHeap;
m_currentPipeline = nullptr;
}
@@ -3330,7 +3371,8 @@ public:
// Submit binding for compute
{
ComputeSubmitter submitter(m_d3dCmdList);
- if(SLANG_FAILED(_bindRenderState(&submitter)))
+ RefPtr<PipelineStateBase> newPipeline;
+ if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
{
assert(!"Failed to bind render state");
}
@@ -3402,12 +3444,15 @@ public:
}
#if SLANG_GFX_HAS_DXR_SUPPORT
- class RayTracingCommandEncoderImpl : public IRayTracingCommandEncoder
+ class RayTracingCommandEncoderImpl
+ : public IRayTracingCommandEncoder
+ , public PipelineCommandEncoder
{
public:
CommandBufferImpl* m_commandBuffer;
void init(D3D12Device* renderer, CommandBufferImpl* commandBuffer)
{
+ PipelineCommandEncoder::init(commandBuffer);
m_commandBuffer = commandBuffer;
}
virtual SLANG_NO_THROW void SLANG_MCALL buildAccelerationStructure(
@@ -3434,6 +3479,13 @@ public:
IAccelerationStructure* const* structures,
AccessFlag::Enum sourceAccess,
AccessFlag::Enum destAccess) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override;
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() {}
virtual SLANG_NO_THROW void SLANG_MCALL
writeTimestamp(IQueryPool* pool, SlangInt index) override
@@ -3533,8 +3585,7 @@ public:
auto transientHeap = cmdImpl->m_transientHeap;
auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex);
waitInfo.waitValue = m_fenceValue;
- ResetEvent(waitInfo.fenceEvent);
- m_fence->SetEventOnCompletion(m_fenceValue, waitInfo.fenceEvent);
+ waitInfo.fence = m_fence;
}
m_d3dQueue->Signal(m_fence, m_fenceValue);
ResetEvent(globalWaitHandle);
@@ -3722,8 +3773,13 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchr
Array<HANDLE, 16> waitHandles;
for (auto& waitInfo : m_waitInfos)
{
- if (waitInfo.waitValue != 0)
+ if (waitInfo.waitValue == 0)
+ continue;
+ if (waitInfo.fence)
+ {
+ waitInfo.fence->SetEventOnCompletion(waitInfo.waitValue, waitInfo.fenceEvent);
waitHandles.add(waitInfo.fenceEvent);
+ }
}
WaitForMultipleObjects((DWORD)waitHandles.getCount(), waitHandles.getBuffer(), TRUE, INFINITE);
m_viewHeap.deallocateAll();
@@ -3763,16 +3819,15 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe
return SLANG_OK;
}
-Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter)
+Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline)
{
- RefPtr<PipelineStateBase> newPipeline;
RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootShaderObject;
m_renderer->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline);
- PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr());
+ PipelineStateBase* newPipelineImpl = static_cast<PipelineStateBase*>(newPipeline.Ptr());
auto commandList = m_d3dCmdList;
auto pipelineTypeIndex = (int)newPipelineImpl->desc.type;
auto programImpl = static_cast<ShaderProgramImpl*>(newPipelineImpl->m_program.Ptr());
- commandList->SetPipelineState(newPipelineImpl->m_pipelineState);
+ submitter->setPipelineState(newPipelineImpl);
submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature);
RefPtr<ShaderObjectLayoutImpl> specializedRootLayout;
SLANG_RETURN_ON_FAIL(rootObjectImpl->getSpecializedLayout(specializedRootLayout.writeRef()));
@@ -5469,11 +5524,6 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
return SLANG_OK;
}
-Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
-{
- return SLANG_E_NOT_AVAILABLE;
-}
-
Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device)
{
// Translate query type.
@@ -5801,7 +5851,290 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::memoryBarrier
m_commandBuffer->m_cmdList4->ResourceBarrier((UINT)count, barriers.getArrayView().getBuffer());
}
+void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline(
+ IPipelineState* state, IShaderObject** outRootObject)
+{
+ bindPipelineImpl(state, outRootObject);
+}
+
+void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth)
+{
+ RefPtr<PipelineStateBase> newPipeline;
+ PipelineStateBase* pipeline = m_currentPipeline.Ptr();
+ {
+ struct RayTracingSubmitter : public ComputeSubmitter
+ {
+ ID3D12GraphicsCommandList4* m_cmdList4;
+ RayTracingSubmitter(ID3D12GraphicsCommandList4* cmdList4)
+ : ComputeSubmitter(cmdList4), m_cmdList4(cmdList4)
+ {
+ }
+ virtual void setPipelineState(PipelineStateBase* pipeline) override
+ {
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ m_cmdList4->SetPipelineState1(pipelineImpl->m_stateObject.get());
+ }
+ };
+ RayTracingSubmitter submitter(m_commandBuffer->m_cmdList4);
+ if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
+ {
+ assert(!"Failed to bind render state");
+ }
+ if (newPipeline)
+ pipeline = newPipeline.Ptr();
+ }
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ auto dispatchDesc = pipelineImpl->m_dispatchDesc;
+ int32_t rayGenShaderOffset = 0;
+ if (rayGenShaderName)
+ {
+ rayGenShaderOffset =
+ pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() *
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ }
+ dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset;
+ dispatchDesc.Width = (UINT)width;
+ dispatchDesc.Height = (UINT)height;
+ dispatchDesc.Depth = (UINT)depth;
+ m_commandBuffer->m_cmdList4->DispatchRays(&dispatchDesc);
+}
+
+Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState)
+{
+ if (!m_device5)
+ {
+ return SLANG_E_NOT_AVAILABLE;
+ }
+
+ RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl();
+ pipelineStateImpl->init(inDesc);
+
+ auto program = static_cast<ShaderProgramImpl*>(inDesc.program);
+ auto slangProgram = program->slangProgram;
+ auto programLayout = slangProgram->getLayout();
+
+ if (!program->m_rootObjectLayout->m_rootSignature)
+ {
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+ }
+ List<D3D12_STATE_SUBOBJECT> subObjects;
+ ChunkedList<D3D12_DXIL_LIBRARY_DESC> dxilLibraries;
+ ChunkedList<D3D12_HIT_GROUP_DESC> hitGroups;
+ ChunkedList<ComPtr<ISlangBlob>> codeBlobs;
+ ComPtr<ISlangBlob> diagnostics;
+ ChunkedList<OSString> stringPool;
+ int32_t rayGenIndex = 0;
+ for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ {
+ ComPtr<ISlangBlob> codeBlob;
+ auto compileResult =
+ slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef());
+ if (diagnostics.get())
+ {
+ getDebugCallback()->handleMessage(
+ compileResult == SLANG_OK ? DebugMessageType::Warning : DebugMessageType::Error,
+ DebugMessageSource::Slang,
+ (char*)diagnostics->getBufferPointer());
+ }
+ SLANG_RETURN_ON_FAIL(compileResult);
+ codeBlobs.add(codeBlob);
+ D3D12_DXIL_LIBRARY_DESC library = {};
+ library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();;
+ library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer();
+
+ D3D12_STATE_SUBOBJECT dxilSubObject = {};
+ dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY;
+ dxilSubObject.pDesc = dxilLibraries.add(library);
+ subObjects.add(dxilSubObject);
+
+ auto entryPointLayout = programLayout->getEntryPointByIndex(i);
+ switch (entryPointLayout->getStage())
+ {
+ case SLANG_STAGE_RAY_GENERATION:
+ pipelineStateImpl
+ ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] =
+ rayGenIndex;
+ rayGenIndex++;
+ break;
+ default:
+ break;
+ }
+ }
+ auto getWStr = [&](const char* name)
+ {
+ String str = String(name);
+ auto wstr = str.toWString();
+ return stringPool.add(wstr)->begin();
+ };
+ for (int i = 0; i < inDesc.hitGroupCount; i++)
+ {
+ auto hitGroup = inDesc.hitGroups[i];
+ D3D12_HIT_GROUP_DESC hitGroupDesc = {};
+ hitGroupDesc.Type = hitGroup.intersectionEntryPoint == nullptr
+ ? D3D12_HIT_GROUP_TYPE_TRIANGLES
+ : D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE;
+
+ if (hitGroup.anyHitEntryPoint)
+ {
+ hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint);
+ }
+ if (hitGroup.closestHitEntryPoint)
+ {
+ hitGroupDesc.ClosestHitShaderImport = getWStr(hitGroup.closestHitEntryPoint);
+ }
+ if (hitGroup.intersectionEntryPoint)
+ {
+ hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint);
+ }
+ StringBuilder hitGroupName;
+ hitGroupName << "hitgroup_" << i;
+ hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer());
+
+ D3D12_STATE_SUBOBJECT hitGroupSubObject = {};
+ hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP;
+ hitGroupSubObject.pDesc = hitGroups.add(hitGroupDesc);
+ subObjects.add(hitGroupSubObject);
+ }
+
+ D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {};
+ // According to DXR spec, fixed function triangle intersections must use float2 as ray attributes
+ // that defines the barycentric coordinates at intersection.
+ shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2;
+ shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize;
+ D3D12_STATE_SUBOBJECT shaderConfigSubObject = {};
+ shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG;
+ shaderConfigSubObject.pDesc = &shaderConfig;
+ subObjects.add(shaderConfigSubObject);
+
+ D3D12_GLOBAL_ROOT_SIGNATURE globalSignatureDesc = {};
+ globalSignatureDesc.pGlobalRootSignature = program->m_rootObjectLayout->m_rootSignature.get();
+ D3D12_STATE_SUBOBJECT globalSignatureSubobject = {};
+ globalSignatureSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE;
+ globalSignatureSubobject.pDesc = &globalSignatureDesc;
+ subObjects.add(globalSignatureSubobject);
+
+ D3D12_RAYTRACING_PIPELINE_CONFIG pipelineConfig = {};
+ pipelineConfig.MaxTraceRecursionDepth = inDesc.maxRecursion;
+ D3D12_STATE_SUBOBJECT pipelineConfigSubobject = {};
+ pipelineConfigSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG;
+ pipelineConfigSubobject.pDesc = &pipelineConfig;
+ subObjects.add(pipelineConfigSubobject);
+
+ D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
+ rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE;
+ rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount();
+ rtpsoDesc.pSubobjects = subObjects.getBuffer();
+ SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
+
+ SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc));
+
+ returnComPtr(outState, pipelineStateImpl);
+ return SLANG_OK;
+}
+
+Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables(
+ D3D12Device* device,
+ slang::IComponentType* slangProgram,
+ const RayTracingPipelineStateDesc& desc)
+{
+ ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
+ m_stateObject->QueryInterface(stateObjectProperties.writeRef());
+ auto programLayout = slangProgram->getLayout();
+ struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; };
+ List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers;
+ for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ {
+ auto entryPointLayout = programLayout->getEntryPointByIndex(i);
+ ShaderIdentifier identifier;
+ switch (entryPointLayout->getStage())
+ {
+ case SLANG_STAGE_RAY_GENERATION:
+ memcpy(
+ &identifier,
+ stateObjectProperties->GetShaderIdentifier(
+ String(entryPointLayout->getName()).toWString().begin()),
+ sizeof(ShaderIdentifier));
+ rayGenIdentifiers.add(identifier);
+ break;
+ case SLANG_STAGE_MISS:
+ memcpy(
+ &identifier,
+ stateObjectProperties->GetShaderIdentifier(
+ String(entryPointLayout->getName()).toWString().begin()),
+ sizeof(ShaderIdentifier));
+ missIdentifiers.add(identifier);
+ break;
+ default:
+ break;
+ }
+ }
+ for (int i = 0; i < desc.shaderTableHitGroupCount; i++)
+ {
+ StringBuilder hitgroupName;
+ hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i];
+ ShaderIdentifier hitgroupIdentifier;
+ memcpy(
+ &hitgroupIdentifier,
+ stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()),
+ sizeof(ShaderIdentifier));
+ hitgroupIdentifiers.add(hitgroupIdentifier);
+ }
+
+ auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content,
+ RefPtr<BufferResourceImpl>& outResource) -> Result
+ {
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.defaultState = ResourceState::ShaderResource;
+ bufferDesc.allowedStates = ResourceStateSet(
+ ResourceState::CopySource,
+ ResourceState::UnorderedAccess,
+ ResourceState::ShaderResource);
+ bufferDesc.elementSize = 0;
+ bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier);
+ bufferDesc.format = Format::Unknown;
+ ComPtr<IBufferResource> shaderTableResource;
+ SLANG_RETURN_ON_FAIL(device->createBufferResource(
+ bufferDesc, content.getBuffer(), shaderTableResource.writeRef()));
+ outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get());
+ return SLANG_OK;
+ };
+
+ if (desc.shaderTableHitGroupCount)
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable));
+ m_dispatchDesc.HitGroupTable.SizeInBytes =
+ (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount;
+ m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress();
+ }
+ if (rayGenIdentifiers.getCount())
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable));
+ m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress();
+ }
+ if (missIdentifiers.getCount())
+ {
+ SLANG_RETURN_ON_FAIL(
+ createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable));
+ m_dispatchDesc.MissShaderTable.SizeInBytes =
+ (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount();
+ m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier);
+ m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress();
+ }
+ return SLANG_OK;
+}
+
#endif // SLANG_GFX_HAS_DXR_SUPPORT
+
Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* resourceView)
{
if (offset.bindingRangeIndex < 0)
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 067581559..50cacc6c2 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -705,6 +705,7 @@ DebugCommandBuffer::DebugCommandBuffer()
m_renderCommandEncoder.commandBuffer = this;
m_computeCommandEncoder.commandBuffer = this;
m_resourceCommandEncoder.commandBuffer = this;
+ m_rayTracingCommandEncoder.commandBuffer = this;
}
void DebugCommandBuffer::encodeRenderCommands(
@@ -1084,6 +1085,25 @@ void DebugRayTracingCommandEncoder::memoryBarrier(
baseObject->memoryBarrier(count, innerAS.getBuffer(), sourceAccess, destAccess);
}
+void DebugRayTracingCommandEncoder::bindPipeline(
+ IPipelineState* state, IShaderObject** outRootObject)
+{
+ SLANG_GFX_API_FUNC;
+ auto innerPipeline = getInnerObj(state);
+ baseObject->bindPipeline(innerPipeline, commandBuffer->rootObject.baseObject.writeRef());
+ *outRootObject = &commandBuffer->rootObject;
+}
+
+void DebugRayTracingCommandEncoder::dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth)
+{
+ SLANG_GFX_API_FUNC;
+ baseObject->dispatchRays(rayGenShaderName, width, height, depth);
+}
+
const ICommandQueue::Desc& DebugCommandQueue::getDesc()
{
SLANG_GFX_API_FUNC;
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index 7433db966..c7de48149 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -351,6 +351,13 @@ public:
IAccelerationStructure* const* structures,
AccessFlag::Enum sourceAccess,
AccessFlag::Enum destAccess) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override;
public:
DebugCommandBuffer* commandBuffer;
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 2eb19b6e9..bb80c4f53 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -605,6 +605,14 @@ Result RendererBase::maybeSpecializePipeline(
pipelineDesc, specializedPipelineComPtr.writeRef()));
break;
}
+ case PipelineType::RayTracing:
+ {
+ auto pipelineDesc = currentPipeline->desc.rayTracing;
+ pipelineDesc.program = specializedProgram;
+ SLANG_RETURN_ON_FAIL(createRayTracingPipelineState(
+ pipelineDesc, specializedPipelineComPtr.writeRef()));
+ break;
+ }
default:
break;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 1f0a3eaab..31a7566a2 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -766,7 +766,7 @@ public:
auto bindingRangeIndex = offset.bindingRangeIndex;
auto bindingRange = layout->getBindingRange(bindingRangeIndex);
- auto objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex;
+ Slang::Index objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex;
if (objectIndex >= m_userProvidedSpecializationArgs.getCount())
m_userProvidedSpecializationArgs.setCount(objectIndex + 1);
if (!m_userProvidedSpecializationArgs[objectIndex])
@@ -816,7 +816,7 @@ public:
subObjectIndexInRange++)
{
ExtendedShaderObjectTypeList typeArgs;
- auto objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
+ Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
auto subObject = m_objects[objectIndex];
if (!subObject)
@@ -932,9 +932,19 @@ public:
PipelineType type;
GraphicsPipelineStateDesc graphics;
ComputePipelineStateDesc compute;
+ RayTracingPipelineStateDesc rayTracing;
ShaderProgramBase* getProgram()
{
- return static_cast<ShaderProgramBase*>(type == PipelineType::Compute ? compute.program : graphics.program);
+ switch (type)
+ {
+ case PipelineType::Compute:
+ return static_cast<ShaderProgramBase*>(compute.program);
+ case PipelineType::Graphics:
+ return static_cast<ShaderProgramBase*>(graphics.program);
+ case PipelineType::RayTracing:
+ return static_cast<ShaderProgramBase*>(rayTracing.program);
+ }
+ return nullptr;
}
} desc;
@@ -1105,6 +1115,8 @@ public:
public:
ExtendedShaderObjectTypeList specializationArgs;
// Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary.
+ // The newly specialized pipeline is held alive by the pipeline cache so users of `outNewPipeline` do not
+ // need to maintain its lifespan.
Result maybeSpecializePipeline(
PipelineStateBase* currentPipeline,
ShaderObjectBase* rootObject,
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index bc0271aa6..592cbaac1 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -1266,7 +1266,7 @@ public:
vkPushConstantRange.size = ordinaryDataSize;
vkPushConstantRange.stageFlags = VK_SHADER_STAGE_ALL; // TODO: be more clever
- while(m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex)
+ while((uint32_t)m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex)
{
VkPushConstantRange emptyRange = { 0 };
m_ownPushConstantRanges.add(emptyRange);
@@ -2995,7 +2995,7 @@ public:
case slang::BindingType::ConstantBuffer:
{
BindingOffset objOffset = rangeOffset;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// Binding a constant buffer sub-object is simple enough:
// we just call `bindAsConstantBuffer` on it to bind
@@ -3016,7 +3016,7 @@ public:
case slang::BindingType::ParameterBlock:
{
BindingOffset objOffset = rangeOffset;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// The case for `ParameterBlock<X>` is not that different
// from `ConstantBuffer<X>`, except that we call `bindAsParameterBlock`
@@ -3047,7 +3047,7 @@ public:
//
SimpleBindingOffset objOffset = rangeOffset.pending;
SimpleBindingOffset objStride = rangeStride.pending;
- for (uint32_t i = 0; i < count; ++i)
+ for (Index i = 0; i < count; ++i)
{
// An existential-type sub-object is always bound just as a value,
// which handles its nested bindings and descriptor sets, but
@@ -4258,6 +4258,25 @@ public:
_memoryBarrier(count, structures, srcAccess, destAccess);
}
+ virtual SLANG_NO_THROW void SLANG_MCALL
+ bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override
+ {
+ SLANG_UNUSED(pipeline);
+ SLANG_UNUSED(outRootObject);
+ }
+
+ virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
+ const char* rayGenShaderName,
+ int32_t width,
+ int32_t height,
+ int32_t depth) override
+ {
+ SLANG_UNUSED(rayGenShaderName);
+ SLANG_UNUSED(width);
+ SLANG_UNUSED(height);
+ SLANG_UNUSED(depth);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override
{
}