diff options
| -rw-r--r-- | examples/ray-tracing-pipeline/main.cpp | 22 | ||||
| -rw-r--r-- | slang-gfx.h | 34 | ||||
| -rw-r--r-- | tools/gfx/d3d12/render-d3d12.cpp | 272 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.cpp | 17 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.h | 12 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.cpp | 30 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 50 | ||||
| -rw-r--r-- | tools/gfx/transient-resource-heap-base.h | 2 | ||||
| -rw-r--r-- | tools/gfx/vulkan/render-vk.cpp | 6 |
9 files changed, 290 insertions, 155 deletions
diff --git a/examples/ray-tracing-pipeline/main.cpp b/examples/ray-tracing-pipeline/main.cpp index a4e6333e9..9b01d251c 100644 --- a/examples/ray-tracing-pipeline/main.cpp +++ b/examples/ray-tracing-pipeline/main.cpp @@ -222,6 +222,7 @@ ComPtr<gfx::IBufferResource> gTLASBuffer; ComPtr<gfx::IAccelerationStructure> gTLAS; ComPtr<gfx::ITextureResource> gResultTexture; ComPtr<gfx::IResourceView> gResultTextureUAV; +ComPtr<gfx::IShaderTable> gShaderTable; uint64_t lastTime = 0; @@ -521,6 +522,8 @@ Slang::Result initialize() if (!gPresentPipelineState) return SLANG_FAIL; + const char* hitgroupNames[] = {"hitgroup0", "hitgroup1"}; + ComPtr<IShaderProgram> rayTracingProgram; SLANG_RETURN_ON_FAIL( loadShaderProgram(gDevice, true, rayTracingProgram.writeRef())); @@ -529,18 +532,29 @@ Slang::Result initialize() rtpDesc.hitGroupCount = 2; HitGroupDesc hitGroups[2]; hitGroups[0].closestHitEntryPoint = "closestHitShader"; + hitGroups[0].hitGroupName = hitgroupNames[0]; hitGroups[1].closestHitEntryPoint = "shadowRayHitShader"; + hitGroups[1].hitGroupName = hitgroupNames[1]; rtpDesc.hitGroups = hitGroups; rtpDesc.maxRayPayloadSize = 64; rtpDesc.maxRecursion = 2; - rtpDesc.shaderTableHitGroupCount = 2; - int32_t shaderTable[] = {0, 1}; - rtpDesc.shaderTableHitGroupIndices = shaderTable; SLANG_RETURN_ON_FAIL( gDevice->createRayTracingPipelineState(rtpDesc, gRenderPipelineState.writeRef())); if (!gRenderPipelineState) return SLANG_FAIL; + IShaderTable::Desc shaderTableDesc = {}; + const char* raygenName = "rayGenShader"; + const char* missName = "missShader"; + shaderTableDesc.program = rayTracingProgram; + shaderTableDesc.hitGroupCount = 2; + shaderTableDesc.hitGroupNames = hitgroupNames; + shaderTableDesc.rayGenShaderCount = 1; + shaderTableDesc.rayGenShaderEntryPointNames = &raygenName; + shaderTableDesc.missShaderCount = 1; + shaderTableDesc.missShaderEntryPointNames = &missName; + SLANG_RETURN_ON_FAIL(gDevice->createShaderTable(shaderTableDesc, gShaderTable.writeRef())); + createResultTexture(); return SLANG_OK; } @@ -626,7 +640,7 @@ virtual void renderFrame(int frameBufferIndex) override cursor["uniforms"].setData(&gUniforms, sizeof(Uniforms)); cursor["sceneBVH"].setResource(gTLAS); cursor["primitiveBuffer"].setResource(gPrimitiveBufferSRV); - renderEncoder->dispatchRays(nullptr, windowWidth, windowHeight, 1); + renderEncoder->dispatchRays(0, gShaderTable, windowWidth, windowHeight, 1); renderEncoder->endEncoding(); renderCommandBuffer->close(); gQueue->executeCommandBuffer(renderCommandBuffer); diff --git a/slang-gfx.h b/slang-gfx.h index c35892eb4..5949e5463 100644 --- a/slang-gfx.h +++ b/slang-gfx.h @@ -1247,6 +1247,7 @@ struct RayTracingPipelineFlags struct HitGroupDesc { + const char* hitGroupName = nullptr; const char* closestHitEntryPoint = nullptr; const char* anyHitEntryPoint = nullptr; const char* intersectionEntryPoint = nullptr; @@ -1257,13 +1258,33 @@ struct RayTracingPipelineStateDesc IShaderProgram* program = nullptr; int32_t hitGroupCount; const HitGroupDesc* hitGroups; - int32_t shaderTableHitGroupCount; - int32_t* shaderTableHitGroupIndices; int maxRecursion; int maxRayPayloadSize; RayTracingPipelineFlags::Enum flags; }; +class IShaderTable : public ISlangUnknown +{ +public: + struct Desc + { + uint32_t rayGenShaderCount; + const char** rayGenShaderEntryPointNames; + + uint32_t missShaderCount; + const char** missShaderEntryPointNames; + + uint32_t hitGroupCount; + const char** hitGroupNames; + + IShaderProgram* program; + }; +}; +#define SLANG_UUID_IShaderTable \ + { \ + 0xa721522c, 0xdf31, 0x4c2f, { 0xa5, 0xe7, 0x3b, 0xe0, 0x12, 0x4b, 0x31, 0x78 } \ + } + class IPipelineState : public ISlangUnknown { }; @@ -1657,10 +1678,10 @@ public: bindPipeline(IPipelineState* state, IShaderObject** outRootObject) = 0; /// Issues a dispatch command to start ray tracing workload with a ray tracing pipeline. - /// `rayGenShaderName` specifies the name of the ray generation shader to launch. Pass nullptr for - /// the first ray generation shader defined in `raytracingPipeline`. + /// `rayGenShaderIndex` specifies the index into the shader table that identifies the ray generation shader. virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) = 0; @@ -2153,6 +2174,9 @@ public: IShaderProgram* program, IShaderObject** outObject) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) = 0; + virtual SLANG_NO_THROW Result SLANG_MCALL createProgram( const IShaderProgram::Desc& desc, IShaderProgram** outProgram, diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 634fca03e..e0c324d34 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -137,6 +137,8 @@ public: createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override; virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override; + virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override; virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState( const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override; @@ -466,12 +468,6 @@ public: { public: ComPtr<ID3D12StateObject> m_stateObject; - D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {}; - Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex; - // Shader Tables for each ray-tracing stage stored in GPU memory. - RefPtr<BufferResourceImpl> m_rayGenShaderTable; - RefPtr<BufferResourceImpl> m_hitgroupShaderTable; - RefPtr<BufferResourceImpl> m_missShaderTable; void init(const RayTracingPipelineStateDesc& inDesc) { PipelineStateDesc pipelineDesc; @@ -479,10 +475,6 @@ public: pipelineDesc.rayTracing = inDesc; initializeBase(pipelineDesc); } - Result createShaderTables( - D3D12Device* device, - slang::IComponentType* slangProgram, - const RayTracingPipelineStateDesc& desc); }; #endif @@ -882,7 +874,7 @@ public: IBufferResource* uploadResource; if (buffer->getDesc()->memoryType != MemoryType::Upload) { - transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::CopySource); + transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::General); } D3D12Resource& uploadResourceRef = @@ -3253,6 +3245,95 @@ public: D3D12DescriptorHeap m_cpuSamplerHeap; }; + class ShaderTableImpl : public ShaderTableBase + { + public: + uint32_t m_rayGenTableOffset; + uint32_t m_missTableOffset; + uint32_t m_hitGroupTableOffset; + + D3D12Device* m_device; + + virtual RefPtr<BufferResource> createDeviceBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) override + { + uint32_t raygenTableSize = m_rayGenShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + m_rayGenTableOffset = 0; + m_missTableOffset = + (uint32_t)D3DUtil::calcAligned(raygenTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); + m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned( + m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT); + uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize; + + auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); + ComPtr<IBufferResource> bufferResource; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.memoryType = gfx::MemoryType::DeviceLocal; + bufferDesc.defaultState = ResourceState::General; + bufferDesc.type = IResource::Type::Buffer; + bufferDesc.sizeInBytes = tableSize; + m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef()); + + ComPtr<ID3D12StateObjectProperties> stateObjectProperties; + pipelineImpl->m_stateObject->QueryInterface(stateObjectProperties.writeRef()); + + TransientResourceHeapImpl* transientHeapImpl = + static_cast<TransientResourceHeapImpl*>(transientHeap); + + IBufferResource* stagingBuffer = nullptr; + transientHeapImpl->allocateStagingBuffer( + tableSize, stagingBuffer, ResourceState::General); + + assert(stagingBuffer); + void* stagingPtr = nullptr; + stagingBuffer->map(nullptr, &stagingPtr); + + auto copyShaderIdInto = [&](void* dest, String& name) + { + void* shaderId = stateObjectProperties->GetShaderIdentifier(name.toWString().begin()); + memcpy(dest, shaderId, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES); + }; + + uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr; + for (uint32_t i = 0; i < m_rayGenShaderCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_rayGenTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[i]); + } + for (uint32_t i = 0; i < m_missShaderCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_missTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[m_rayGenShaderCount + i]); + } + for (uint32_t i = 0; i < m_hitGroupCount; i++) + { + copyShaderIdInto( + stagingBufferPtr + m_hitGroupTableOffset + + D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i, + m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i]); + } + + stagingBuffer->unmap(nullptr); + encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize); + encoder->bufferBarrier( + 1, + bufferResource.readRef(), + gfx::ResourceState::CopyDestination, + gfx::ResourceState::ShaderResource); + RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get()); + return _Move(resultPtr); + } + + }; + class CommandBufferImpl : public ICommandBuffer , public ComObject @@ -4103,7 +4184,7 @@ public: IBufferResource* stagingBuffer; m_commandBuffer->m_transientHeap->allocateStagingBuffer( - bufferSize, stagingBuffer, ResourceState::CopySource); + bufferSize, stagingBuffer, ResourceState::General); BufferResourceImpl* bufferImpl = static_cast<BufferResourceImpl*>(stagingBuffer); uint8_t* bufferData = nullptr; @@ -4447,7 +4528,8 @@ public: virtual SLANG_NO_THROW void SLANG_MCALL bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) override; @@ -6945,6 +7027,15 @@ Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShad return SLANG_OK; } +Result D3D12Device::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) +{ + RefPtr<ShaderTableImpl> result = new ShaderTableImpl(); + result->m_device = this; + result->init(desc); + returnComPtr(outShaderTable, result); + return SLANG_OK; +} + Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState) { GraphicsPipelineStateDesc desc = inDesc; @@ -7546,7 +7637,8 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline( } void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) @@ -7576,15 +7668,33 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays( pipeline = newPipeline.Ptr(); } auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline); - auto dispatchDesc = pipelineImpl->m_dispatchDesc; - int32_t rayGenShaderOffset = 0; - if (rayGenShaderName) - { - rayGenShaderOffset = - pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() * - D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; - } - dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset; + + auto shaderTableImpl = static_cast<ShaderTableImpl*>(shaderTable); + + ResourceCommandEncoderImpl resourceCopyEncoder; + resourceCopyEncoder.init(m_renderer, m_commandBuffer); + auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(pipelineImpl, m_transientHeap, &resourceCopyEncoder); + auto shaderTableAddr = shaderTableBuffer->getDeviceAddress(); + + D3D12_DISPATCH_RAYS_DESC dispatchDesc = {}; + + dispatchDesc.RayGenerationShaderRecord.StartAddress = + shaderTableAddr + shaderTableImpl->m_rayGenTableOffset + + rayGenShaderIndex * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.RayGenerationShaderRecord.SizeInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + + dispatchDesc.MissShaderTable.StartAddress = + shaderTableAddr + shaderTableImpl->m_missTableOffset; + dispatchDesc.MissShaderTable.SizeInBytes = + shaderTableImpl->m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.MissShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + + dispatchDesc.HitGroupTable.StartAddress = + shaderTableAddr + shaderTableImpl->m_hitGroupTableOffset; + dispatchDesc.HitGroupTable.SizeInBytes = + shaderTableImpl->m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES; + dispatchDesc.Width = (UINT)width; dispatchDesc.Height = (UINT)height; dispatchDesc.Depth = (UINT)depth; @@ -7616,7 +7726,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD ChunkedList<ComPtr<ISlangBlob>> codeBlobs; ComPtr<ISlangBlob> diagnostics; ChunkedList<OSString> stringPool; - int32_t rayGenIndex = 0; for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) { ComPtr<ISlangBlob> codeBlob; @@ -7639,19 +7748,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; dxilSubObject.pDesc = dxilLibraries.add(library); subObjects.add(dxilSubObject); - - auto entryPointLayout = programLayout->getEntryPointByIndex(i); - switch (entryPointLayout->getStage()) - { - case SLANG_STAGE_RAY_GENERATION: - pipelineStateImpl - ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] = - rayGenIndex; - rayGenIndex++; - break; - default: - break; - } } auto getWStr = [&](const char* name) { @@ -7679,9 +7775,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD { hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint); } - StringBuilder hitGroupName; - hitGroupName << "hitgroup_" << i; - hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer()); + hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName); D3D12_STATE_SUBOBJECT hitGroupSubObject = {}; hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP; @@ -7719,108 +7813,10 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD rtpsoDesc.pSubobjects = subObjects.getBuffer(); SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef()))); - SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc)); - returnComPtr(outState, pipelineStateImpl); return SLANG_OK; } -Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables( - D3D12Device* device, - slang::IComponentType* slangProgram, - const RayTracingPipelineStateDesc& desc) -{ - ComPtr<ID3D12StateObjectProperties> stateObjectProperties; - m_stateObject->QueryInterface(stateObjectProperties.writeRef()); - auto programLayout = slangProgram->getLayout(); - struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; }; - List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers; - for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) - { - auto entryPointLayout = programLayout->getEntryPointByIndex(i); - ShaderIdentifier identifier; - switch (entryPointLayout->getStage()) - { - case SLANG_STAGE_RAY_GENERATION: - memcpy( - &identifier, - stateObjectProperties->GetShaderIdentifier( - String(entryPointLayout->getName()).toWString().begin()), - sizeof(ShaderIdentifier)); - rayGenIdentifiers.add(identifier); - break; - case SLANG_STAGE_MISS: - memcpy( - &identifier, - stateObjectProperties->GetShaderIdentifier( - String(entryPointLayout->getName()).toWString().begin()), - sizeof(ShaderIdentifier)); - missIdentifiers.add(identifier); - break; - default: - break; - } - } - for (int i = 0; i < desc.shaderTableHitGroupCount; i++) - { - StringBuilder hitgroupName; - hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i]; - ShaderIdentifier hitgroupIdentifier; - memcpy( - &hitgroupIdentifier, - stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()), - sizeof(ShaderIdentifier)); - hitgroupIdentifiers.add(hitgroupIdentifier); - } - - auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content, - RefPtr<BufferResourceImpl>& outResource) -> Result - { - IBufferResource::Desc bufferDesc = {}; - bufferDesc.type = IResource::Type::Buffer; - bufferDesc.defaultState = ResourceState::ShaderResource; - bufferDesc.allowedStates = ResourceStateSet( - ResourceState::CopySource, - ResourceState::UnorderedAccess, - ResourceState::ShaderResource); - bufferDesc.elementSize = 0; - bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier); - bufferDesc.format = Format::Unknown; - ComPtr<IBufferResource> shaderTableResource; - SLANG_RETURN_ON_FAIL(device->createBufferResource( - bufferDesc, content.getBuffer(), shaderTableResource.writeRef())); - outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get()); - return SLANG_OK; - }; - - if (desc.shaderTableHitGroupCount) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable)); - m_dispatchDesc.HitGroupTable.SizeInBytes = - (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount; - m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress(); - } - if (rayGenIdentifiers.getCount()) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable)); - m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress(); - } - if (missIdentifiers.getCount()) - { - SLANG_RETURN_ON_FAIL( - createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable)); - m_dispatchDesc.MissShaderTable.SizeInBytes = - (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount(); - m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier); - m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress(); - } - return SLANG_OK; -} - #endif // SLANG_GFX_HAS_DXR_SUPPORT Result createNullDescriptor( diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp index fa8fc6fda..53deec385 100644 --- a/tools/gfx/debug-layer.cpp +++ b/tools/gfx/debug-layer.cpp @@ -140,7 +140,7 @@ SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(TransientResourceHeap) SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(QueryPool) SLANG_GFX_DEBUG_GET_INTERFACE_IMPL_PARENT(AccelerationStructure, ResourceView) SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(Fence) - +SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(ShaderTable) #undef SLANG_GFX_DEBUG_GET_INTERFACE_IMPL #undef SLANG_GFX_DEBUG_GET_INTERFACE_IMPL_PARENT @@ -179,6 +179,7 @@ SLANG_GFX_DEBUG_GET_OBJ_IMPL(TransientResourceHeap) SLANG_GFX_DEBUG_GET_OBJ_IMPL(QueryPool) SLANG_GFX_DEBUG_GET_OBJ_IMPL(AccelerationStructure) SLANG_GFX_DEBUG_GET_OBJ_IMPL(Fence) +SLANG_GFX_DEBUG_GET_OBJ_IMPL(ShaderTable) #undef SLANG_GFX_DEBUG_GET_OBJ_IMPL @@ -802,6 +803,15 @@ Result DebugDevice::getTextureAllocationInfo( return baseObject->getTextureAllocationInfo(desc, outSize, outAlignment); } +Result DebugDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) +{ + SLANG_GFX_API_FUNC; + RefPtr<DebugShaderTable> result = new DebugShaderTable(); + SLANG_RETURN_ON_FAIL(baseObject->createShaderTable(desc, result->baseObject.writeRef())); + returnComPtr(outTable, result); + return SLANG_OK; +} + IResource::Type DebugBufferResource::getType() { SLANG_GFX_API_FUNC; @@ -1477,13 +1487,14 @@ void DebugRayTracingCommandEncoder::bindPipeline( } void DebugRayTracingCommandEncoder::dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) { SLANG_GFX_API_FUNC; - baseObject->dispatchRays(rayGenShaderName, width, height, depth); + baseObject->dispatchRays(rayGenShaderIndex, getInnerObj(shaderTable), width, height, depth); } const ICommandQueue::Desc& DebugCommandQueue::getDesc() diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h index 04dafadf2..a2de7eb31 100644 --- a/tools/gfx/debug-layer.h +++ b/tools/gfx/debug-layer.h @@ -156,6 +156,15 @@ public: uint64_t timeout) override; virtual SLANG_NO_THROW Result SLANG_MCALL getTextureAllocationInfo( const ITextureResource::Desc& desc, size_t* outSize, size_t* outAlignment) override; + virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) override; +}; + +class DebugShaderTable : public DebugObject<IShaderTable> +{ +public: + SLANG_COM_OBJECT_IUNKNOWN_ALL; + IShaderTable* getInterface(const Slang::Guid& guid); }; class DebugQueryPool : public DebugObject<IQueryPool> @@ -511,7 +520,8 @@ public: virtual SLANG_NO_THROW void SLANG_MCALL bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override; virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( - const char* rayGenShaderName, + uint32_t rayGenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) override; diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index fb7baf85d..52cf7ffac 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -32,6 +32,7 @@ const Slang::Guid GfxGUID::IID_ICommandQueue = SLANG_UUID_ICommandQueue; const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool; const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure; const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence; +const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable; StageType translateStage(SlangStage slangStage) @@ -438,6 +439,13 @@ Result RendererBase::createAccelerationStructure( return SLANG_E_NOT_AVAILABLE; } +Result RendererBase::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) +{ + SLANG_UNUSED(desc); + SLANG_UNUSED(outTable); + return SLANG_E_NOT_AVAILABLE; +} + Result RendererBase::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState) { SLANG_UNUSED(desc); @@ -847,4 +855,26 @@ Result ShaderObjectBase::copyFrom(IShaderObject* object, ITransientResourceHeap* return SLANG_FAIL; } +Result ShaderTableBase::init(const IShaderTable::Desc& desc) +{ + m_rayGenShaderCount = desc.rayGenShaderCount; + m_missShaderCount = desc.missShaderCount; + m_hitGroupCount = desc.hitGroupCount; + m_entryPointNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount); + for (uint32_t i = 0; i < desc.rayGenShaderCount; i++) + { + m_entryPointNames.add(desc.rayGenShaderEntryPointNames[i]); + } + for (uint32_t i = 0; i < desc.missShaderCount; i++) + { + m_entryPointNames.add(desc.missShaderEntryPointNames[i]); + } + for (uint32_t i = 0; i < desc.hitGroupCount; i++) + { + m_entryPointNames.add(desc.hitGroupNames[i]); + } + return SLANG_OK; +} + } // namespace gfx + diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index 33d51e9ad..6baade085 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -39,6 +39,7 @@ struct GfxGUID static const Slang::Guid IID_IQueryPool; static const Slang::Guid IID_IAccelerationStructure; static const Slang::Guid IID_IFence; + static const Slang::Guid IID_IShaderTable; }; // We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation. @@ -1186,7 +1187,7 @@ public: } public: SLANG_COM_OBJECT_IUNKNOWN_ALL - ITransientResourceHeap* getInterface(const Slang::Guid& guid) + ITransientResourceHeap* getInterface(const Slang::Guid& guid) { if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ITransientResourceHeap) return static_cast<ITransientResourceHeap*>(this); @@ -1194,6 +1195,48 @@ public: } }; +class ShaderTableBase + : public IShaderTable + , public Slang::ComObject +{ +public: + Slang::List<Slang::String> m_entryPointNames; + uint32_t m_rayGenShaderCount; + uint32_t m_missShaderCount; + uint32_t m_hitGroupCount; + + Slang::Dictionary<PipelineStateBase*, Slang::RefPtr<BufferResource>> m_deviceBuffers; + + SLANG_COM_OBJECT_IUNKNOWN_ALL + IShaderTable* getInterface(const Slang::Guid& guid) + { + if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderTable) + return static_cast<IShaderTable*>(this); + return nullptr; + } + + virtual Slang::RefPtr<BufferResource> createDeviceBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) = 0; + + BufferResource* getOrCreateBuffer( + PipelineStateBase* pipeline, + TransientResourceHeapBase* transientHeap, + IResourceCommandEncoder* encoder) + { + if (auto ptr = m_deviceBuffers.TryGetValue(pipeline)) + { + return ptr->Ptr(); + } + auto result = createDeviceBuffer(pipeline, transientHeap, encoder); + m_deviceBuffers[pipeline] = result; + return result; + } + + Result init(const IShaderTable::Desc& desc); +}; + // Renderer implementation shared by all platforms. // Responsible for shader compilation, specialization and caching. class RendererBase : public IDevice, public Slang::ComObject @@ -1256,6 +1299,11 @@ public: // Provides a default implementation that returns SLANG_E_NOT_AVAILABLE for platforms // without ray tracing support. + virtual SLANG_NO_THROW Result SLANG_MCALL + createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) override; + + // Provides a default implementation that returns SLANG_E_NOT_AVAILABLE for platforms + // without ray tracing support. virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState( const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override; diff --git a/tools/gfx/transient-resource-heap-base.h b/tools/gfx/transient-resource-heap-base.h index f3df1c139..2dc16dcd4 100644 --- a/tools/gfx/transient-resource-heap-base.h +++ b/tools/gfx/transient-resource-heap-base.h @@ -55,7 +55,7 @@ public: bufferDesc.defaultState = state; bufferDesc.allowedStates = ResourceStateSet(ResourceState::CopyDestination, ResourceState::CopySource); - if (state == ResourceState::CopySource) + if (state == ResourceState::General) bufferDesc.memoryType = MemoryType::Upload; else bufferDesc.memoryType = MemoryType::ReadBack; diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 9f5af2a1e..f31195cc8 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -5117,12 +5117,14 @@ public: } virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays( - const char* rayGenShaderName, + uint32_t raygenShaderIndex, + IShaderTable* shaderTable, int32_t width, int32_t height, int32_t depth) override { - SLANG_UNUSED(rayGenShaderName); + SLANG_UNUSED(raygenShaderIndex); + SLANG_UNUSED(shaderTable); SLANG_UNUSED(width); SLANG_UNUSED(height); SLANG_UNUSED(depth); |
