summaryrefslogtreecommitdiffstats
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp272
1 files changed, 134 insertions, 138 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 634fca03e..e0c324d34 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -137,6 +137,8 @@ public:
createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState(
const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override;
@@ -466,12 +468,6 @@ public:
{
public:
ComPtr<ID3D12StateObject> m_stateObject;
- D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {};
- Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex;
- // Shader Tables for each ray-tracing stage stored in GPU memory.
- RefPtr<BufferResourceImpl> m_rayGenShaderTable;
- RefPtr<BufferResourceImpl> m_hitgroupShaderTable;
- RefPtr<BufferResourceImpl> m_missShaderTable;
void init(const RayTracingPipelineStateDesc& inDesc)
{
PipelineStateDesc pipelineDesc;
@@ -479,10 +475,6 @@ public:
pipelineDesc.rayTracing = inDesc;
initializeBase(pipelineDesc);
}
- Result createShaderTables(
- D3D12Device* device,
- slang::IComponentType* slangProgram,
- const RayTracingPipelineStateDesc& desc);
};
#endif
@@ -882,7 +874,7 @@ public:
IBufferResource* uploadResource;
if (buffer->getDesc()->memoryType != MemoryType::Upload)
{
- transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::CopySource);
+ transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::General);
}
D3D12Resource& uploadResourceRef =
@@ -3253,6 +3245,95 @@ public:
D3D12DescriptorHeap m_cpuSamplerHeap;
};
+ class ShaderTableImpl : public ShaderTableBase
+ {
+ public:
+ uint32_t m_rayGenTableOffset;
+ uint32_t m_missTableOffset;
+ uint32_t m_hitGroupTableOffset;
+
+ D3D12Device* m_device;
+
+ virtual RefPtr<BufferResource> createDeviceBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder) override
+ {
+ uint32_t raygenTableSize = m_rayGenShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ m_rayGenTableOffset = 0;
+ m_missTableOffset =
+ (uint32_t)D3DUtil::calcAligned(raygenTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
+ m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;
+
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ ComPtr<IBufferResource> bufferResource;
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.memoryType = gfx::MemoryType::DeviceLocal;
+ bufferDesc.defaultState = ResourceState::General;
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.sizeInBytes = tableSize;
+ m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef());
+
+ ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
+ pipelineImpl->m_stateObject->QueryInterface(stateObjectProperties.writeRef());
+
+ TransientResourceHeapImpl* transientHeapImpl =
+ static_cast<TransientResourceHeapImpl*>(transientHeap);
+
+ IBufferResource* stagingBuffer = nullptr;
+ transientHeapImpl->allocateStagingBuffer(
+ tableSize, stagingBuffer, ResourceState::General);
+
+ assert(stagingBuffer);
+ void* stagingPtr = nullptr;
+ stagingBuffer->map(nullptr, &stagingPtr);
+
+ auto copyShaderIdInto = [&](void* dest, String& name)
+ {
+ void* shaderId = stateObjectProperties->GetShaderIdentifier(name.toWString().begin());
+ memcpy(dest, shaderId, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
+ };
+
+ uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr;
+ for (uint32_t i = 0; i < m_rayGenShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_rayGenTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[i]);
+ }
+ for (uint32_t i = 0; i < m_missShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_missTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[m_rayGenShaderCount + i]);
+ }
+ for (uint32_t i = 0; i < m_hitGroupCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_hitGroupTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i]);
+ }
+
+ stagingBuffer->unmap(nullptr);
+ encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize);
+ encoder->bufferBarrier(
+ 1,
+ bufferResource.readRef(),
+ gfx::ResourceState::CopyDestination,
+ gfx::ResourceState::ShaderResource);
+ RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get());
+ return _Move(resultPtr);
+ }
+
+ };
+
class CommandBufferImpl
: public ICommandBuffer
, public ComObject
@@ -4103,7 +4184,7 @@ public:
IBufferResource* stagingBuffer;
m_commandBuffer->m_transientHeap->allocateStagingBuffer(
- bufferSize, stagingBuffer, ResourceState::CopySource);
+ bufferSize, stagingBuffer, ResourceState::General);
BufferResourceImpl* bufferImpl = static_cast<BufferResourceImpl*>(stagingBuffer);
uint8_t* bufferData = nullptr;
@@ -4447,7 +4528,8 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth) override;
@@ -6945,6 +7027,15 @@ Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShad
return SLANG_OK;
}
+Result D3D12Device::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable)
+{
+ RefPtr<ShaderTableImpl> result = new ShaderTableImpl();
+ result->m_device = this;
+ result->init(desc);
+ returnComPtr(outShaderTable, result);
+ return SLANG_OK;
+}
+
Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
{
GraphicsPipelineStateDesc desc = inDesc;
@@ -7546,7 +7637,8 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline(
}
void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth)
@@ -7576,15 +7668,33 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
pipeline = newPipeline.Ptr();
}
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
- auto dispatchDesc = pipelineImpl->m_dispatchDesc;
- int32_t rayGenShaderOffset = 0;
- if (rayGenShaderName)
- {
- rayGenShaderOffset =
- pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() *
- D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
- }
- dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset;
+
+ auto shaderTableImpl = static_cast<ShaderTableImpl*>(shaderTable);
+
+ ResourceCommandEncoderImpl resourceCopyEncoder;
+ resourceCopyEncoder.init(m_renderer, m_commandBuffer);
+ auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(pipelineImpl, m_transientHeap, &resourceCopyEncoder);
+ auto shaderTableAddr = shaderTableBuffer->getDeviceAddress();
+
+ D3D12_DISPATCH_RAYS_DESC dispatchDesc = {};
+
+ dispatchDesc.RayGenerationShaderRecord.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_rayGenTableOffset +
+ rayGenShaderIndex * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.RayGenerationShaderRecord.SizeInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
+ dispatchDesc.MissShaderTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_missTableOffset;
+ dispatchDesc.MissShaderTable.SizeInBytes =
+ shaderTableImpl->m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.MissShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
+ dispatchDesc.HitGroupTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_hitGroupTableOffset;
+ dispatchDesc.HitGroupTable.SizeInBytes =
+ shaderTableImpl->m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
dispatchDesc.Width = (UINT)width;
dispatchDesc.Height = (UINT)height;
dispatchDesc.Depth = (UINT)depth;
@@ -7616,7 +7726,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
ChunkedList<ComPtr<ISlangBlob>> codeBlobs;
ComPtr<ISlangBlob> diagnostics;
ChunkedList<OSString> stringPool;
- int32_t rayGenIndex = 0;
for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
{
ComPtr<ISlangBlob> codeBlob;
@@ -7639,19 +7748,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY;
dxilSubObject.pDesc = dxilLibraries.add(library);
subObjects.add(dxilSubObject);
-
- auto entryPointLayout = programLayout->getEntryPointByIndex(i);
- switch (entryPointLayout->getStage())
- {
- case SLANG_STAGE_RAY_GENERATION:
- pipelineStateImpl
- ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] =
- rayGenIndex;
- rayGenIndex++;
- break;
- default:
- break;
- }
}
auto getWStr = [&](const char* name)
{
@@ -7679,9 +7775,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
{
hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint);
}
- StringBuilder hitGroupName;
- hitGroupName << "hitgroup_" << i;
- hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer());
+ hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName);
D3D12_STATE_SUBOBJECT hitGroupSubObject = {};
hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP;
@@ -7719,108 +7813,10 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
- SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc));
-
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
-Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables(
- D3D12Device* device,
- slang::IComponentType* slangProgram,
- const RayTracingPipelineStateDesc& desc)
-{
- ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
- m_stateObject->QueryInterface(stateObjectProperties.writeRef());
- auto programLayout = slangProgram->getLayout();
- struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; };
- List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers;
- for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
- {
- auto entryPointLayout = programLayout->getEntryPointByIndex(i);
- ShaderIdentifier identifier;
- switch (entryPointLayout->getStage())
- {
- case SLANG_STAGE_RAY_GENERATION:
- memcpy(
- &identifier,
- stateObjectProperties->GetShaderIdentifier(
- String(entryPointLayout->getName()).toWString().begin()),
- sizeof(ShaderIdentifier));
- rayGenIdentifiers.add(identifier);
- break;
- case SLANG_STAGE_MISS:
- memcpy(
- &identifier,
- stateObjectProperties->GetShaderIdentifier(
- String(entryPointLayout->getName()).toWString().begin()),
- sizeof(ShaderIdentifier));
- missIdentifiers.add(identifier);
- break;
- default:
- break;
- }
- }
- for (int i = 0; i < desc.shaderTableHitGroupCount; i++)
- {
- StringBuilder hitgroupName;
- hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i];
- ShaderIdentifier hitgroupIdentifier;
- memcpy(
- &hitgroupIdentifier,
- stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()),
- sizeof(ShaderIdentifier));
- hitgroupIdentifiers.add(hitgroupIdentifier);
- }
-
- auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content,
- RefPtr<BufferResourceImpl>& outResource) -> Result
- {
- IBufferResource::Desc bufferDesc = {};
- bufferDesc.type = IResource::Type::Buffer;
- bufferDesc.defaultState = ResourceState::ShaderResource;
- bufferDesc.allowedStates = ResourceStateSet(
- ResourceState::CopySource,
- ResourceState::UnorderedAccess,
- ResourceState::ShaderResource);
- bufferDesc.elementSize = 0;
- bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier);
- bufferDesc.format = Format::Unknown;
- ComPtr<IBufferResource> shaderTableResource;
- SLANG_RETURN_ON_FAIL(device->createBufferResource(
- bufferDesc, content.getBuffer(), shaderTableResource.writeRef()));
- outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get());
- return SLANG_OK;
- };
-
- if (desc.shaderTableHitGroupCount)
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable));
- m_dispatchDesc.HitGroupTable.SizeInBytes =
- (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount;
- m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress();
- }
- if (rayGenIdentifiers.getCount())
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable));
- m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress();
- }
- if (missIdentifiers.getCount())
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable));
- m_dispatchDesc.MissShaderTable.SizeInBytes =
- (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount();
- m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress();
- }
- return SLANG_OK;
-}
-
#endif // SLANG_GFX_HAS_DXR_SUPPORT
Result createNullDescriptor(