diff options
| author | Yong He <yonghe@outlook.com> | 2021-07-28 12:24:12 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-07-28 12:24:12 -0700 |
| commit | c6f6ce12ec522b193b42bcd12d3a2540c7a6ff92 (patch) | |
| tree | d5f77aa02df88c71ef4f898db40434bf4c1f3010 /tools | |
| parent | 23d406f8a3b325f91fecd9ad52bd510ded5f49a7 (diff) | |
Experimental DXR1.0 support in gfx. (#1915)
* Experimental DXR1.0 support in gfx.
- Add `dispatchRays` command.
- Add `createRayTracingPipelineState` method to construct a D3D ray tracing state object from a linked slang program and a user-specified shader table.
Limitations/simplifications: no local root signature support; shader table entries contain only shader identifiers; the shader table is specified at pipeline creation time and is owned by the pipeline state object.
* Root object binding for raytracing pipelines.
* `maybeSpecializePipeline` implementation for raytracing pipelines.
* Add ray-tracing-pipeline example.
* Fixes.
* Update README.md
* Update comments on the lifespan of specialized pipelines
Co-authored-by: Yong He <yhe@nvidia.com>
Co-authored-by: jsmall-nvidia <jsmall@nvidia.com>
Diffstat (limited to 'tools')
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 381 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.cpp | 20 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.h | 7 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 8 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 18 | ||||
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 27 |
6 files changed, 430 insertions, 31 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 24a1fd93e..7e436d28d 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -10,6 +10,7 @@ #include "../d3d/d3d-swapchain.h" #include "core/slang-blob.h" #include "core/slang-basic.h" +#include "core/slang-chunked-list.h" // In order to use the Slang API, we need to include its header @@ -123,8 +124,6 @@ public: const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createComputePipelineState( const ComputePipelineStateDesc& desc, IPipelineState** outState) override; - virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( - const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; virtual SLANG_NO_THROW Result SLANG_MCALL createQueryPool( const IQueryPool::Desc& desc, IQueryPool** outState) override; @@ -156,6 +155,8 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL createAccelerationStructure( const IAccelerationStructure::CreateDesc& desc, IAccelerationStructure** outView) override; + virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( + const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; #endif public: @@ -193,6 +194,7 @@ public: virtual void setRootDescriptorTable(int index, D3D12_GPU_DESCRIPTOR_HANDLE BaseDescriptor) = 0; virtual void setRootSignature(ID3D12RootSignature* rootSignature) = 0; virtual void setRootConstants(Index rootParamIndex, Index dstOffsetIn32BitValues, Index countOf32BitValues, void const* srcData) = 0; + virtual void setPipelineState(PipelineStateBase* pipelineState) = 0; }; class BufferResourceImpl: public gfx::BufferResource @@ -340,6 +342,31 @@ public: } }; +#if SLANG_GFX_HAS_DXR_SUPPORT + class RayTracingPipelineStateImpl : public PipelineStateBase + { + public: + ComPtr<ID3D12StateObject> m_stateObject; + D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {}; + 
Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex; + // Shader Tables for each ray-tracing stage stored in GPU memory. + RefPtr<BufferResourceImpl> m_rayGenShaderTable; + RefPtr<BufferResourceImpl> m_hitgroupShaderTable; + RefPtr<BufferResourceImpl> m_missShaderTable; + void init(const RayTracingPipelineStateDesc& inDesc) + { + PipelineStateDesc pipelineDesc; + pipelineDesc.type = PipelineType::RayTracing; + pipelineDesc.rayTracing = inDesc; + initializeBase(pipelineDesc); + } + Result createShaderTables( + D3D12Device* device, + slang::IComponentType* slangProgram, + const RayTracingPipelineStateDesc& desc); + }; +#endif + class QueryPoolImpl : public IQueryPool, public ComObject { public: @@ -461,6 +488,11 @@ public: { m_commandList->SetGraphicsRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues)); } + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline); + m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get()); + } GraphicsSubmitter(ID3D12GraphicsCommandList* commandList): m_commandList(commandList) @@ -492,7 +524,11 @@ public: { m_commandList->SetComputeRoot32BitConstants(UINT(rootParamIndex), UINT(countOf32BitValues), srcData, UINT(dstOffsetIn32BitValues)); } - + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast<PipelineStateImpl*>(pipeline); + m_commandList->SetPipelineState(pipelineImpl->m_pipelineState.get()); + } ComputeSubmitter(ID3D12GraphicsCommandList* commandList) : m_commandList(commandList) { @@ -568,6 +604,7 @@ public: { uint64_t waitValue; HANDLE fenceEvent; + ID3D12Fence* fence = nullptr; }; ShortList<QueueWaitInfo, 4> m_waitInfos; @@ -585,7 +622,7 @@ public: m_waitInfos[i].fenceEvent = CreateEventEx( nullptr, false, - CREATE_EVENT_INITIAL_SET | CREATE_EVENT_MANUAL_RESET, + 0, EVENT_ALL_ACCESS); } return 
m_waitInfos[queueIndex]; @@ -666,7 +703,7 @@ public: ID3D12GraphicsCommandList* m_d3dCmdList; ID3D12GraphicsCommandList* m_preCmdList = nullptr; - RefPtr<PipelineStateImpl> m_currentPipeline; + RefPtr<PipelineStateBase> m_currentPipeline; static int getBindPointIndex(PipelineType type) { @@ -690,13 +727,14 @@ public: m_d3dCmdList = m_commandBuffer->m_cmdList; m_renderer = commandBuffer->m_renderer; m_transientHeap = commandBuffer->m_transientHeap; + m_device = commandBuffer->m_renderer->m_device; } void endEncodingImpl() { m_isOpen = false; } Result bindPipelineImpl(IPipelineState* pipelineState, IShaderObject** outRootObject) { - m_currentPipeline = static_cast<PipelineStateImpl*>(pipelineState); + m_currentPipeline = static_cast<PipelineStateBase*>(pipelineState); auto rootObject = &m_commandBuffer->m_rootShaderObject; SLANG_RETURN_ON_FAIL(rootObject->reset( m_renderer, @@ -707,7 +745,11 @@ public: return SLANG_OK; } - Result _bindRenderState(Submitter* submitter); + /// Specializes the pipeline according to current root-object argument values, + /// applies the root object bindings and binds the pipeline state. + /// The newly specialized pipeline is held alive by the pipeline cache so users of + /// `newPipeline` do not need to maintain its lifespan. 
+ Result _bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline); }; struct DescriptorTable @@ -2956,7 +2998,6 @@ public: { PipelineCommandEncoder::init(cmdBuffer); m_preCmdList = nullptr; - m_device = renderer->m_device; m_renderPass = renderPass; m_framebuffer = framebuffer; m_transientHeap = transientHeap; @@ -3174,7 +3215,8 @@ public: // Submit - setting for graphics { GraphicsSubmitter submitter(m_d3dCmdList); - if(SLANG_FAILED(_bindRenderState(&submitter))) + RefPtr<PipelineStateBase> newPipeline; + if(SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) { assert(!"Failed to bind render state"); } @@ -3314,7 +3356,6 @@ public: { PipelineCommandEncoder::init(cmdBuffer); m_preCmdList = nullptr; - m_device = renderer->m_device; m_transientHeap = transientHeap; m_currentPipeline = nullptr; } @@ -3330,7 +3371,8 @@ public: // Submit binding for compute { ComputeSubmitter submitter(m_d3dCmdList); - if(SLANG_FAILED(_bindRenderState(&submitter))) + RefPtr<PipelineStateBase> newPipeline; + if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) { assert(!"Failed to bind render state"); } @@ -3402,12 +3444,15 @@ public: } #if SLANG_GFX_HAS_DXR_SUPPORT - class RayTracingCommandEncoderImpl : public IRayTracingCommandEncoder + class RayTracingCommandEncoderImpl + : public IRayTracingCommandEncoder + , public PipelineCommandEncoder { public: CommandBufferImpl* m_commandBuffer; void init(D3D12Device* renderer, CommandBufferImpl* commandBuffer) { + PipelineCommandEncoder::init(commandBuffer); m_commandBuffer = commandBuffer; } virtual SLANG_NO_THROW void SLANG_MCALL buildAccelerationStructure( @@ -3434,6 +3479,13 @@ public: IAccelerationStructure* const* structures, AccessFlag::Enum sourceAccess, AccessFlag::Enum destAccess) override; + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + 
int32_t width, + int32_t height, + int32_t depth) override; virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() {} virtual SLANG_NO_THROW void SLANG_MCALL writeTimestamp(IQueryPool* pool, SlangInt index) override @@ -3533,8 +3585,7 @@ public: auto transientHeap = cmdImpl->m_transientHeap; auto& waitInfo = transientHeap->getQueueWaitInfo(m_queueIndex); waitInfo.waitValue = m_fenceValue; - ResetEvent(waitInfo.fenceEvent); - m_fence->SetEventOnCompletion(m_fenceValue, waitInfo.fenceEvent); + waitInfo.fence = m_fence; } m_d3dQueue->Signal(m_fence, m_fenceValue); ResetEvent(globalWaitHandle); @@ -3722,8 +3773,13 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchr Array<HANDLE, 16> waitHandles; for (auto& waitInfo : m_waitInfos) { - if (waitInfo.waitValue != 0) + if (waitInfo.waitValue == 0) + continue; + if (waitInfo.fence) + { + waitInfo.fence->SetEventOnCompletion(waitInfo.waitValue, waitInfo.fenceEvent); waitHandles.add(waitInfo.fenceEvent); + } } WaitForMultipleObjects((DWORD)waitHandles.getCount(), waitHandles.getBuffer(), TRUE, INFINITE); m_viewHeap.deallocateAll(); @@ -3763,16 +3819,15 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe return SLANG_OK; } -Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter) +Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline) { - RefPtr<PipelineStateBase> newPipeline; RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootShaderObject; m_renderer->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline); - PipelineStateImpl* newPipelineImpl = static_cast<PipelineStateImpl*>(newPipeline.Ptr()); + PipelineStateBase* newPipelineImpl = static_cast<PipelineStateBase*>(newPipeline.Ptr()); auto commandList = m_d3dCmdList; auto pipelineTypeIndex = (int)newPipelineImpl->desc.type; auto programImpl = 
static_cast<ShaderProgramImpl*>(newPipelineImpl->m_program.Ptr()); - commandList->SetPipelineState(newPipelineImpl->m_pipelineState); + submitter->setPipelineState(newPipelineImpl); submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature); RefPtr<ShaderObjectLayoutImpl> specializedRootLayout; SLANG_RETURN_ON_FAIL(rootObjectImpl->getSpecializedLayout(specializedRootLayout.writeRef())); @@ -5469,11 +5524,6 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i return SLANG_OK; } -Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState) -{ - return SLANG_E_NOT_AVAILABLE; -} - Result D3D12Device::QueryPoolImpl::init(const IQueryPool::Desc& desc, D3D12Device* device) { // Translate query type. @@ -5801,7 +5851,290 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::memoryBarrier m_commandBuffer->m_cmdList4->ResourceBarrier((UINT)count, barriers.getArrayView().getBuffer()); } +void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline( + IPipelineState* state, IShaderObject** outRootObject) +{ + bindPipelineImpl(state, outRootObject); +} + +void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) +{ + RefPtr<PipelineStateBase> newPipeline; + PipelineStateBase* pipeline = m_currentPipeline.Ptr(); + { + struct RayTracingSubmitter : public ComputeSubmitter + { + ID3D12GraphicsCommandList4* m_cmdList4; + RayTracingSubmitter(ID3D12GraphicsCommandList4* cmdList4) + : ComputeSubmitter(cmdList4), m_cmdList4(cmdList4) + { + } + virtual void setPipelineState(PipelineStateBase* pipeline) override + { + auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); + m_cmdList4->SetPipelineState1(pipelineImpl->m_stateObject.get()); + } + }; + RayTracingSubmitter submitter(m_commandBuffer->m_cmdList4); + if 
(SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) + { + assert(!"Failed to bind render state"); + } + if (newPipeline) + pipeline = newPipeline.Ptr(); + } + auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); + auto dispatchDesc = pipelineImpl->m_dispatchDesc; + int32_t rayGenShaderOffset = 0; + if (rayGenShaderName) + { + rayGenShaderOffset = + pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() * + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + } + dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset; + dispatchDesc.Width = (UINT)width; + dispatchDesc.Height = (UINT)height; + dispatchDesc.Depth = (UINT)depth; + m_commandBuffer->m_cmdList4->DispatchRays(&dispatchDesc); +} + +Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateDesc& inDesc, IPipelineState** outState) +{ + if (!m_device5) + { + return SLANG_E_NOT_AVAILABLE; + } + + RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(); + pipelineStateImpl->init(inDesc); + + auto program = static_cast<ShaderProgramImpl*>(inDesc.program); + auto slangProgram = program->slangProgram; + auto programLayout = slangProgram->getLayout(); + + if (!program->m_rootObjectLayout->m_rootSignature) + { + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; + } + List<D3D12_STATE_SUBOBJECT> subObjects; + ChunkedList<D3D12_DXIL_LIBRARY_DESC> dxilLibraries; + ChunkedList<D3D12_HIT_GROUP_DESC> hitGroups; + ChunkedList<ComPtr<ISlangBlob>> codeBlobs; + ComPtr<ISlangBlob> diagnostics; + ChunkedList<OSString> stringPool; + int32_t rayGenIndex = 0; + for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + { + ComPtr<ISlangBlob> codeBlob; + auto compileResult = + slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef()); + if (diagnostics.get()) + { + getDebugCallback()->handleMessage( + compileResult == SLANG_OK ? 
DebugMessageType::Warning : DebugMessageType::Error, + DebugMessageSource::Slang, + (char*)diagnostics->getBufferPointer()); + } + SLANG_RETURN_ON_FAIL(compileResult); + codeBlobs.add(codeBlob); + D3D12_DXIL_LIBRARY_DESC library = {}; + library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();; + library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer(); + + D3D12_STATE_SUBOBJECT dxilSubObject = {}; + dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; + dxilSubObject.pDesc = dxilLibraries.add(library); + subObjects.add(dxilSubObject); + + auto entryPointLayout = programLayout->getEntryPointByIndex(i); + switch (entryPointLayout->getStage()) + { + case SLANG_STAGE_RAY_GENERATION: + pipelineStateImpl + ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] = + rayGenIndex; + rayGenIndex++; + break; + default: + break; + } + } + auto getWStr = [&](const char* name) + { + String str = String(name); + auto wstr = str.toWString(); + return stringPool.add(wstr)->begin(); + }; + for (int i = 0; i < inDesc.hitGroupCount; i++) + { + auto hitGroup = inDesc.hitGroups[i]; + D3D12_HIT_GROUP_DESC hitGroupDesc = {}; + hitGroupDesc.Type = hitGroup.intersectionEntryPoint == nullptr + ? 
D3D12_HIT_GROUP_TYPE_TRIANGLES + : D3D12_HIT_GROUP_TYPE_PROCEDURAL_PRIMITIVE; + + if (hitGroup.anyHitEntryPoint) + { + hitGroupDesc.AnyHitShaderImport = getWStr(hitGroup.anyHitEntryPoint); + } + if (hitGroup.closestHitEntryPoint) + { + hitGroupDesc.ClosestHitShaderImport = getWStr(hitGroup.closestHitEntryPoint); + } + if (hitGroup.intersectionEntryPoint) + { + hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint); + } + StringBuilder hitGroupName; + hitGroupName << "hitgroup_" << i; + hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer()); + + D3D12_STATE_SUBOBJECT hitGroupSubObject = {}; + hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; + hitGroupSubObject.pDesc = hitGroups.add(hitGroupDesc); + subObjects.add(hitGroupSubObject); + } + + D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {}; + // According to DXR spec, fixed function triangle intersections must use float2 as ray attributes + // that defines the barycentric coordinates at intersection. 
+ shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2; + shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize; + D3D12_STATE_SUBOBJECT shaderConfigSubObject = {}; + shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG; + shaderConfigSubObject.pDesc = &shaderConfig; + subObjects.add(shaderConfigSubObject); + + D3D12_GLOBAL_ROOT_SIGNATURE globalSignatureDesc = {}; + globalSignatureDesc.pGlobalRootSignature = program->m_rootObjectLayout->m_rootSignature.get(); + D3D12_STATE_SUBOBJECT globalSignatureSubobject = {}; + globalSignatureSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_GLOBAL_ROOT_SIGNATURE; + globalSignatureSubobject.pDesc = &globalSignatureDesc; + subObjects.add(globalSignatureSubobject); + + D3D12_RAYTRACING_PIPELINE_CONFIG pipelineConfig = {}; + pipelineConfig.MaxTraceRecursionDepth = inDesc.maxRecursion; + D3D12_STATE_SUBOBJECT pipelineConfigSubobject = {}; + pipelineConfigSubobject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_PIPELINE_CONFIG; + pipelineConfigSubobject.pDesc = &pipelineConfig; + subObjects.add(pipelineConfigSubobject); + + D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; + rtpsoDesc.Type = D3D12_STATE_OBJECT_TYPE_RAYTRACING_PIPELINE; + rtpsoDesc.NumSubobjects = (UINT)subObjects.getCount(); + rtpsoDesc.pSubobjects = subObjects.getBuffer(); + SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); + + SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc)); + + returnComPtr(outState, pipelineStateImpl); + return SLANG_OK; +} + +Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables( + D3D12Device* device, + slang::IComponentType* slangProgram, + const RayTracingPipelineStateDesc& desc) +{ + ComPtr<ID3D12StateObjectProperties> stateObjectProperties; + m_stateObject->QueryInterface(stateObjectProperties.writeRef()); + auto programLayout = slangProgram->getLayout(); + struct ShaderIdentifier { uint32_t 
data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; }; + List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers; + for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + { + auto entryPointLayout = programLayout->getEntryPointByIndex(i); + ShaderIdentifier identifier; + switch (entryPointLayout->getStage()) + { + case SLANG_STAGE_RAY_GENERATION: + memcpy( + &identifier, + stateObjectProperties->GetShaderIdentifier( + String(entryPointLayout->getName()).toWString().begin()), + sizeof(ShaderIdentifier)); + rayGenIdentifiers.add(identifier); + break; + case SLANG_STAGE_MISS: + memcpy( + &identifier, + stateObjectProperties->GetShaderIdentifier( + String(entryPointLayout->getName()).toWString().begin()), + sizeof(ShaderIdentifier)); + missIdentifiers.add(identifier); + break; + default: + break; + } + } + for (int i = 0; i < desc.shaderTableHitGroupCount; i++) + { + StringBuilder hitgroupName; + hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i]; + ShaderIdentifier hitgroupIdentifier; + memcpy( + &hitgroupIdentifier, + stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()), + sizeof(ShaderIdentifier)); + hitgroupIdentifiers.add(hitgroupIdentifier); + } + + auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content, + RefPtr<BufferResourceImpl>& outResource) -> Result + { + IBufferResource::Desc bufferDesc = {}; + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.defaultState = ResourceState::ShaderResource; + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::CopySource, + ResourceState::UnorderedAccess, + ResourceState::ShaderResource); + bufferDesc.elementSize = 0; + bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier); + bufferDesc.format = Format::Unknown; + ComPtr<IBufferResource> shaderTableResource; + SLANG_RETURN_ON_FAIL(device->createBufferResource( + bufferDesc, content.getBuffer(), shaderTableResource.writeRef())); + 
outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get()); + return SLANG_OK; + }; + + if (desc.shaderTableHitGroupCount) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable)); + m_dispatchDesc.HitGroupTable.SizeInBytes = + (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount; + m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress(); + } + if (rayGenIdentifiers.getCount()) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable)); + m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress(); + } + if (missIdentifiers.getCount()) + { + SLANG_RETURN_ON_FAIL( + createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable)); + m_dispatchDesc.MissShaderTable.SizeInBytes = + (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount(); + m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier); + m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress(); + } + return SLANG_OK; +} + #endif // SLANG_GFX_HAS_DXR_SUPPORT + Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* resourceView) { if (offset.bindingRangeIndex < 0) diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp index 067581559..50cacc6c2 100644 --- a/tools/gfx/debug-layer.cpp +++ b/tools/gfx/debug-layer.cpp @@ -705,6 +705,7 @@ DebugCommandBuffer::DebugCommandBuffer() m_renderCommandEncoder.commandBuffer = this; m_computeCommandEncoder.commandBuffer = this; m_resourceCommandEncoder.commandBuffer = this; + m_rayTracingCommandEncoder.commandBuffer = this; } void DebugCommandBuffer::encodeRenderCommands( @@ -1084,6 +1085,25 @@ void 
DebugRayTracingCommandEncoder::memoryBarrier( baseObject->memoryBarrier(count, innerAS.getBuffer(), sourceAccess, destAccess); } +void DebugRayTracingCommandEncoder::bindPipeline( + IPipelineState* state, IShaderObject** outRootObject) +{ + SLANG_GFX_API_FUNC; + auto innerPipeline = getInnerObj(state); + baseObject->bindPipeline(innerPipeline, commandBuffer->rootObject.baseObject.writeRef()); + *outRootObject = &commandBuffer->rootObject; +} + +void DebugRayTracingCommandEncoder::dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) +{ + SLANG_GFX_API_FUNC; + baseObject->dispatchRays(rayGenShaderName, width, height, depth); +} + const ICommandQueue::Desc& DebugCommandQueue::getDesc() { SLANG_GFX_API_FUNC; diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h index 7433db966..c7de48149 100644 --- a/tools/gfx/debug-layer.h +++ b/tools/gfx/debug-layer.h @@ -351,6 +351,13 @@ public: IAccelerationStructure* const* structures, AccessFlag::Enum sourceAccess, AccessFlag::Enum destAccess) override; + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) override; public: DebugCommandBuffer* commandBuffer; diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 2eb19b6e9..bb80c4f53 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -605,6 +605,14 @@ Result RendererBase::maybeSpecializePipeline( pipelineDesc, specializedPipelineComPtr.writeRef())); break; } + case PipelineType::RayTracing: + { + auto pipelineDesc = currentPipeline->desc.rayTracing; + pipelineDesc.program = specializedProgram; + SLANG_RETURN_ON_FAIL(createRayTracingPipelineState( + pipelineDesc, specializedPipelineComPtr.writeRef())); + break; + } default: break; } diff --git 
a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 1f0a3eaab..31a7566a2 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -766,7 +766,7 @@ public: auto bindingRangeIndex = offset.bindingRangeIndex; auto bindingRange = layout->getBindingRange(bindingRangeIndex); - auto objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex; + Slang::Index objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex; if (objectIndex >= m_userProvidedSpecializationArgs.getCount()) m_userProvidedSpecializationArgs.setCount(objectIndex + 1); if (!m_userProvidedSpecializationArgs[objectIndex]) @@ -816,7 +816,7 @@ public: subObjectIndexInRange++) { ExtendedShaderObjectTypeList typeArgs; - auto objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange; + Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange; auto subObject = m_objects[objectIndex]; if (!subObject) @@ -932,9 +932,19 @@ public: PipelineType type; GraphicsPipelineStateDesc graphics; ComputePipelineStateDesc compute; + RayTracingPipelineStateDesc rayTracing; ShaderProgramBase* getProgram() { - return static_cast<ShaderProgramBase*>(type == PipelineType::Compute ? compute.program : graphics.program); + switch (type) + { + case PipelineType::Compute: + return static_cast<ShaderProgramBase*>(compute.program); + case PipelineType::Graphics: + return static_cast<ShaderProgramBase*>(graphics.program); + case PipelineType::RayTracing: + return static_cast<ShaderProgramBase*>(rayTracing.program); + } + return nullptr; } } desc; @@ -1105,6 +1115,8 @@ public: public: ExtendedShaderObjectTypeList specializationArgs; // Given current pipeline and root shader object binding, generate and bind a specialized pipeline if necessary. + // The newly specialized pipeline is held alive by the pipeline cache so users of `outNewPipeline` do not + // need to maintain its lifespan. 
Result maybeSpecializePipeline( PipelineStateBase* currentPipeline, ShaderObjectBase* rootObject, diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index bc0271aa6..592cbaac1 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -1266,7 +1266,7 @@ public: vkPushConstantRange.size = ordinaryDataSize; vkPushConstantRange.stageFlags = VK_SHADER_STAGE_ALL; // TODO: be more clever - while(m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex) + while((uint32_t)m_ownPushConstantRanges.getCount() <= pushConstantRangeIndex) { VkPushConstantRange emptyRange = { 0 }; m_ownPushConstantRanges.add(emptyRange); @@ -2995,7 +2995,7 @@ public: case slang::BindingType::ConstantBuffer: { BindingOffset objOffset = rangeOffset; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // Binding a constant buffer sub-object is simple enough: // we just call `bindAsConstantBuffer` on it to bind @@ -3016,7 +3016,7 @@ public: case slang::BindingType::ParameterBlock: { BindingOffset objOffset = rangeOffset; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // The case for `ParameterBlock<X>` is not that different // from `ConstantBuffer<X>`, except that we call `bindAsParameterBlock` @@ -3047,7 +3047,7 @@ public: // SimpleBindingOffset objOffset = rangeOffset.pending; SimpleBindingOffset objStride = rangeStride.pending; - for (uint32_t i = 0; i < count; ++i) + for (Index i = 0; i < count; ++i) { // An existential-type sub-object is always bound just as a value, // which handles its nested bindings and descriptor sets, but @@ -4258,6 +4258,25 @@ public: _memoryBarrier(count, structures, srcAccess, destAccess); } + virtual SLANG_NO_THROW void SLANG_MCALL + bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override + { + SLANG_UNUSED(pipeline); + SLANG_UNUSED(outRootObject); + } + + virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( + const char* 
rayGenShaderName, + int32_t width, + int32_t height, + int32_t depth) override + { + SLANG_UNUSED(rayGenShaderName); + SLANG_UNUSED(width); + SLANG_UNUSED(height); + SLANG_UNUSED(depth); + } + virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override { } |
