diff options
| author | Yong He <yonghe@outlook.com> | 2022-01-21 10:17:39 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-01-21 10:17:39 -0800 |
| commit | f85bc7ae98486b37518958e659f659f1ff9b125c (patch) | |
| tree | b18f40a62ac51ee77bdf651a6d9a26f277019ab4 /tools/gfx/d3d12/render-d3d12.cpp | |
| parent | 11d248293f1b56a790faadead1e3d94de81f29a2 (diff) | |
GFX: seperated ShaderTable. (#2090)
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 272 |
1 files changed, 134 insertions, 138 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 634fca03e..e0c324d34 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -137,6 +137,8 @@ public: createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override; virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override; + virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override; virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState( const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override; @@ -466,12 +468,6 @@ public: { public: ComPtr<ID3D12StateObject> m_stateObject; - D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {}; - Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex; - // Shader Tables for each ray-tracing stage stored in GPU memory. - RefPtr<BufferResourceImpl> m_rayGenShaderTable; - RefPtr<BufferResourceImpl> m_hitgroupShaderTable; - RefPtr<BufferResourceImpl> m_missShaderTable; void init(const RayTracingPipelineStateDesc& inDesc) { PipelineStateDesc pipelineDesc; @@ -479,10 +475,6 @@ public: pipelineDesc.rayTracing = inDesc; initializeBase(pipelineDesc); } - Result createShaderTables( - D3D12Device* device, - slang::IComponentType* slangProgram, - const RayTracingPipelineStateDesc& desc); }; #endif @@ -882,7 +874,7 @@ public: IBufferResource* uploadResource; if (buffer->getDesc()->memoryType != MemoryType::Upload) { - transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::CopySource); + transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::General); } D3D12Resource& uploadResourceRef = @@ -3253,6 +3245,95 @@ public: D3D12DescriptorHeap m_cpuSamplerHeap; }; + class ShaderTableImpl : public ShaderTableBase + { + public: + uint32_t m_rayGenTableOffset; + uint32_t m_missTableOffset; + uint32_t m_hitGroupTableOffset; + + D3D12Device* m_device; + + virtual RefPtr<BufferResource> createDeviceBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) override + { + uint32_t raygenTableSize = m_rayGenShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + m_rayGenTableOffset = 0; + m_missTableOffset = + (uint32_t)D3DUtil::calcAligned(raygenTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); + m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned( + m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); + uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize; + + auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); + ComPtr<IBufferResource> bufferResource; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.memoryType = gfx::MemoryType::DeviceLocal; + bufferDesc.defaultState = ResourceState::General; + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.sizeInBytes = tableSize; + m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef()); + + ComPtr<ID3D12StateObjectProperties> stateObjectProperties; + pipelineImpl->m_stateObject->QueryInterface(stateObjectProperties.writeRef()); + + TransientResourceHeapImpl* transientHeapImpl = + static_cast<TransientResourceHeapImpl*>(transientHeap); + + IBufferResource* stagingBuffer = nullptr; + transientHeapImpl->allocateStagingBuffer( + tableSize, stagingBuffer, ResourceState::General); + + assert(stagingBuffer); + void* stagingPtr = nullptr; + stagingBuffer->map(nullptr, &stagingPtr); + + auto copyShaderIdInto = [&](void* dest, String& name) + { + void* shaderId = stateObjectProperties->GetShaderIdentifier(name.toWString().begin()); + memcpy(dest, shaderId, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES); + }; + + uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr; + for (uint32_t i = 0; i < m_rayGenShaderCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_rayGenTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[i]); + } + for (uint32_t i = 0; i < m_missShaderCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_missTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[m_rayGenShaderCount + i]); + } + for (uint32_t i = 0; i < m_hitGroupCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_hitGroupTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i]); + } + + stagingBuffer->unmap(nullptr); + encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize); + encoder->bufferBarrier( + 1, + bufferResource.readRef(), + gfx::ResourceState::CopyDestination, + gfx::ResourceState::ShaderResource); + RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get()); + return _Move(resultPtr); + } + + }; + class CommandBufferImpl : public ICommandBuffer , public ComObject @@ -4103,7 +4184,7 @@ public: IBufferResource* stagingBuffer; m_commandBuffer->m_transientHeap->allocateStagingBuffer( - bufferSize, stagingBuffer, ResourceState::CopySource); + bufferSize, stagingBuffer, ResourceState::General); BufferResourceImpl* bufferImpl = static_cast<BufferResourceImpl*>(stagingBuffer); uint8_t* bufferData = nullptr; @@ -4447,7 +4528,8 @@ public: virtual SLANG_NO_THROW void SLANG_MCALL bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) override; @@ -6945,6 +7027,15 @@ Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShad return SLANG_OK; } +Result D3D12Device::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) +{ + RefPtr<ShaderTableImpl> result = new ShaderTableImpl(); + result->m_device = this; + result->init(desc); + returnComPtr(outShaderTable, result); + return SLANG_OK; +} + Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState) { GraphicsPipelineStateDesc desc = inDesc; @@ -7546,7 +7637,8 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline( } void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) @@ -7576,15 +7668,33 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( pipeline = newPipeline.Ptr(); } auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); - auto dispatchDesc = pipelineImpl->m_dispatchDesc; - int32_t rayGenShaderOffset = 0; - if (rayGenShaderName) - { - rayGenShaderOffset = - pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() * - D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; - } - dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset; + + auto shaderTableImpl = static_cast<ShaderTableImpl*>(shaderTable); + + ResourceCommandEncoderImpl resourceCopyEncoder; + resourceCopyEncoder.init(m_renderer, m_commandBuffer); + auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(pipelineImpl, m_transientHeap, &resourceCopyEncoder); + auto shaderTableAddr = shaderTableBuffer->getDeviceAddress(); + + D3D12_DISPATCH_RAYS_DESC dispatchDesc = {}; + + dispatchDesc.RayGenerationShaderRecord.StartAddress = + shaderTableAddr + shaderTableImpl->m_rayGenTableOffset + + rayGenShaderIndex * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.RayGenerationShaderRecord.SizeInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + + dispatchDesc.MissShaderTable.StartAddress = + shaderTableAddr + shaderTableImpl->m_missTableOffset; + dispatchDesc.MissShaderTable.SizeInBytes = + shaderTableImpl->m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.MissShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + + dispatchDesc.HitGroupTable.StartAddress = + shaderTableAddr + shaderTableImpl->m_hitGroupTableOffset; + dispatchDesc.HitGroupTable.SizeInBytes = + shaderTableImpl->m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.Width = (UINT)width; dispatchDesc.Height = (UINT)height; dispatchDesc.Depth = (UINT)depth; @@ -7616,7 +7726,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD ChunkedList<ComPtr<ISlangBlob>> codeBlobs; ComPtr<ISlangBlob> diagnostics; ChunkedList<OSString> stringPool; - int32_t rayGenIndex = 0; for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) { ComPtr<ISlangBlob> codeBlob; @@ -7639,19 +7748,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; dxilSubObject.pDesc = dxilLibraries.add(library); subObjects.add(dxilSubObject); - - auto entryPointLayout = programLayout->getEntryPointByIndex(i); - switch (entryPointLayout->getStage()) - { - case SLANG_STAGE_RAY_GENERATION: - pipelineStateImpl - ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] = - rayGenIndex; - rayGenIndex++; - break; - default: - break; - } } auto getWStr = [&](const char* name) { @@ -7679,9 +7775,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD { hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint); } - StringBuilder hitGroupName; - hitGroupName << "hitgroup_" << i; - hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer()); + hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName); D3D12_STATE_SUBOBJECT hitGroupSubObject = {}; hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; @@ -7719,108 +7813,10 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD rtpsoDesc.pSubobjects = subObjects.getBuffer(); SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); - SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc)); - returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } -Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables( - D3D12Device* device, - slang::IComponentType* slangProgram, - const RayTracingPipelineStateDesc& desc) -{ - ComPtr<ID3D12StateObjectProperties> stateObjectProperties; - m_stateObject->QueryInterface(stateObjectProperties.writeRef()); - auto programLayout = slangProgram->getLayout(); - struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; }; - List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers; - for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) - { - auto entryPointLayout = programLayout->getEntryPointByIndex(i); - ShaderIdentifier identifier; - switch (entryPointLayout->getStage()) - { - case SLANG_STAGE_RAY_GENERATION: - memcpy( - &identifier, - stateObjectProperties->GetShaderIdentifier( - String(entryPointLayout->getName()).toWString().begin()), - sizeof(ShaderIdentifier)); - rayGenIdentifiers.add(identifier); - break; - case SLANG_STAGE_MISS: - memcpy( - &identifier, - stateObjectProperties->GetShaderIdentifier( - String(entryPointLayout->getName()).toWString().begin()), - sizeof(ShaderIdentifier)); - missIdentifiers.add(identifier); - break; - default: - break; - } - } - for (int i = 0; i < desc.shaderTableHitGroupCount; i++) - { - StringBuilder hitgroupName; - hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i]; - ShaderIdentifier hitgroupIdentifier; - memcpy( - &hitgroupIdentifier, - stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()), - sizeof(ShaderIdentifier)); - hitgroupIdentifiers.add(hitgroupIdentifier); - } - - auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content, - RefPtr<BufferResourceImpl>& outResource) -> Result - { - IBufferResource::Desc bufferDesc = {}; - bufferDesc.type = IResource::Type::Buffer; - bufferDesc.defaultState = ResourceState::ShaderResource; - bufferDesc.allowedStates = ResourceStateSet( - ResourceState::CopySource, - ResourceState::UnorderedAccess, - ResourceState::ShaderResource); - bufferDesc.elementSize = 0; - bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier); - bufferDesc.format = Format::Unknown; - ComPtr<IBufferResource> shaderTableResource; - SLANG_RETURN_ON_FAIL(device->createBufferResource( - bufferDesc, content.getBuffer(), shaderTableResource.writeRef())); - outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get()); - return SLANG_OK; - }; - - if (desc.shaderTableHitGroupCount) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable)); - m_dispatchDesc.HitGroupTable.SizeInBytes = - (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount; - m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress(); - } - if (rayGenIdentifiers.getCount()) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable)); - m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress(); - } - if (missIdentifiers.getCount()) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable)); - m_dispatchDesc.MissShaderTable.SizeInBytes = - (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount(); - m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress(); - } - return SLANG_OK; -} - #endif // SLANG_GFX_HAS_DXR_SUPPORT Result createNullDescriptor( |
