summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--examples/ray-tracing-pipeline/main.cpp22
-rw-r--r--slang-gfx.h34
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp272
-rw-r--r--tools/gfx/debug-layer.cpp17
-rw-r--r--tools/gfx/debug-layer.h12
-rw-r--r--tools/gfx/renderer-shared.cpp30
-rw-r--r--tools/gfx/renderer-shared.h50
-rw-r--r--tools/gfx/transient-resource-heap-base.h2
-rw-r--r--tools/gfx/vulkan/render-vk.cpp6
9 files changed, 290 insertions, 155 deletions
diff --git a/examples/ray-tracing-pipeline/main.cpp b/examples/ray-tracing-pipeline/main.cpp
index a4e6333e9..9b01d251c 100644
--- a/examples/ray-tracing-pipeline/main.cpp
+++ b/examples/ray-tracing-pipeline/main.cpp
@@ -222,6 +222,7 @@ ComPtr<gfx::IBufferResource> gTLASBuffer;
ComPtr<gfx::IAccelerationStructure> gTLAS;
ComPtr<gfx::ITextureResource> gResultTexture;
ComPtr<gfx::IResourceView> gResultTextureUAV;
+ComPtr<gfx::IShaderTable> gShaderTable;
uint64_t lastTime = 0;
@@ -521,6 +522,8 @@ Slang::Result initialize()
if (!gPresentPipelineState)
return SLANG_FAIL;
+ const char* hitgroupNames[] = {"hitgroup0", "hitgroup1"};
+
ComPtr<IShaderProgram> rayTracingProgram;
SLANG_RETURN_ON_FAIL(
loadShaderProgram(gDevice, true, rayTracingProgram.writeRef()));
@@ -529,18 +532,29 @@ Slang::Result initialize()
rtpDesc.hitGroupCount = 2;
HitGroupDesc hitGroups[2];
hitGroups[0].closestHitEntryPoint = "closestHitShader";
+ hitGroups[0].hitGroupName = hitgroupNames[0];
hitGroups[1].closestHitEntryPoint = "shadowRayHitShader";
+ hitGroups[1].hitGroupName = hitgroupNames[1];
rtpDesc.hitGroups = hitGroups;
rtpDesc.maxRayPayloadSize = 64;
rtpDesc.maxRecursion = 2;
- rtpDesc.shaderTableHitGroupCount = 2;
- int32_t shaderTable[] = {0, 1};
- rtpDesc.shaderTableHitGroupIndices = shaderTable;
SLANG_RETURN_ON_FAIL(
gDevice->createRayTracingPipelineState(rtpDesc, gRenderPipelineState.writeRef()));
if (!gRenderPipelineState)
return SLANG_FAIL;
+ IShaderTable::Desc shaderTableDesc = {};
+ const char* raygenName = "rayGenShader";
+ const char* missName = "missShader";
+ shaderTableDesc.program = rayTracingProgram;
+ shaderTableDesc.hitGroupCount = 2;
+ shaderTableDesc.hitGroupNames = hitgroupNames;
+ shaderTableDesc.rayGenShaderCount = 1;
+ shaderTableDesc.rayGenShaderEntryPointNames = &raygenName;
+ shaderTableDesc.missShaderCount = 1;
+ shaderTableDesc.missShaderEntryPointNames = &missName;
+ SLANG_RETURN_ON_FAIL(gDevice->createShaderTable(shaderTableDesc, gShaderTable.writeRef()));
+
createResultTexture();
return SLANG_OK;
}
@@ -626,7 +640,7 @@ virtual void renderFrame(int frameBufferIndex) override
cursor["uniforms"].setData(&gUniforms, sizeof(Uniforms));
cursor["sceneBVH"].setResource(gTLAS);
cursor["primitiveBuffer"].setResource(gPrimitiveBufferSRV);
- renderEncoder->dispatchRays(nullptr, windowWidth, windowHeight, 1);
+ renderEncoder->dispatchRays(0, gShaderTable, windowWidth, windowHeight, 1);
renderEncoder->endEncoding();
renderCommandBuffer->close();
gQueue->executeCommandBuffer(renderCommandBuffer);
diff --git a/slang-gfx.h b/slang-gfx.h
index c35892eb4..5949e5463 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -1247,6 +1247,7 @@ struct RayTracingPipelineFlags
struct HitGroupDesc
{
+ const char* hitGroupName = nullptr;
const char* closestHitEntryPoint = nullptr;
const char* anyHitEntryPoint = nullptr;
const char* intersectionEntryPoint = nullptr;
@@ -1257,13 +1258,33 @@ struct RayTracingPipelineStateDesc
IShaderProgram* program = nullptr;
int32_t hitGroupCount;
const HitGroupDesc* hitGroups;
- int32_t shaderTableHitGroupCount;
- int32_t* shaderTableHitGroupIndices;
int maxRecursion;
int maxRayPayloadSize;
RayTracingPipelineFlags::Enum flags;
};
+class IShaderTable : public ISlangUnknown
+{
+public:
+ struct Desc
+ {
+ uint32_t rayGenShaderCount;
+ const char** rayGenShaderEntryPointNames;
+
+ uint32_t missShaderCount;
+ const char** missShaderEntryPointNames;
+
+ uint32_t hitGroupCount;
+ const char** hitGroupNames;
+
+ IShaderProgram* program;
+ };
+};
+#define SLANG_UUID_IShaderTable \
+ { \
+ 0xa721522c, 0xdf31, 0x4c2f, { 0xa5, 0xe7, 0x3b, 0xe0, 0x12, 0x4b, 0x31, 0x78 } \
+ }
+
class IPipelineState : public ISlangUnknown
{
};
@@ -1657,10 +1678,10 @@ public:
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) = 0;
/// Issues a dispatch command to start ray tracing workload with a ray tracing pipeline.
- /// `rayGenShaderName` specifies the name of the ray generation shader to launch. Pass nullptr for
- /// the first ray generation shader defined in `raytracingPipeline`.
+ /// `rayGenShaderIndex` specifies the index into the shader table that identifies the ray generation shader.
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth) = 0;
@@ -2153,6 +2174,9 @@ public:
IShaderProgram* program,
IShaderObject** outObject) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) = 0;
+
virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
IShaderProgram** outProgram,
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 634fca03e..e0c324d34 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -137,6 +137,8 @@ public:
createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnostics) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createGraphicsPipelineState(
const GraphicsPipelineStateDesc& desc, IPipelineState** outState) override;
@@ -466,12 +468,6 @@ public:
{
public:
ComPtr<ID3D12StateObject> m_stateObject;
- D3D12_DISPATCH_RAYS_DESC m_dispatchDesc = {};
- Dictionary<String, int32_t> m_mapRayGenShaderNameToShaderTableIndex;
- // Shader Tables for each ray-tracing stage stored in GPU memory.
- RefPtr<BufferResourceImpl> m_rayGenShaderTable;
- RefPtr<BufferResourceImpl> m_hitgroupShaderTable;
- RefPtr<BufferResourceImpl> m_missShaderTable;
void init(const RayTracingPipelineStateDesc& inDesc)
{
PipelineStateDesc pipelineDesc;
@@ -479,10 +475,6 @@ public:
pipelineDesc.rayTracing = inDesc;
initializeBase(pipelineDesc);
}
- Result createShaderTables(
- D3D12Device* device,
- slang::IComponentType* slangProgram,
- const RayTracingPipelineStateDesc& desc);
};
#endif
@@ -882,7 +874,7 @@ public:
IBufferResource* uploadResource;
if (buffer->getDesc()->memoryType != MemoryType::Upload)
{
- transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::CopySource);
+ transientHeap->allocateStagingBuffer(size, uploadResource, ResourceState::General);
}
D3D12Resource& uploadResourceRef =
@@ -3253,6 +3245,95 @@ public:
D3D12DescriptorHeap m_cpuSamplerHeap;
};
+ class ShaderTableImpl : public ShaderTableBase
+ {
+ public:
+ uint32_t m_rayGenTableOffset;
+ uint32_t m_missTableOffset;
+ uint32_t m_hitGroupTableOffset;
+
+ D3D12Device* m_device;
+
+ virtual RefPtr<BufferResource> createDeviceBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder) override
+ {
+ uint32_t raygenTableSize = m_rayGenShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t missTableSize = m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ uint32_t hitgroupTableSize = m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ m_rayGenTableOffset = 0;
+ m_missTableOffset =
+ (uint32_t)D3DUtil::calcAligned(raygenTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ m_hitGroupTableOffset = (uint32_t)D3DUtil::calcAligned(
+ m_missTableOffset + missTableSize, D3D12_RAYTRACING_SHADER_TABLE_BYTE_ALIGNMENT);
+ uint32_t tableSize = m_hitGroupTableOffset + hitgroupTableSize;
+
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ ComPtr<IBufferResource> bufferResource;
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.memoryType = gfx::MemoryType::DeviceLocal;
+ bufferDesc.defaultState = ResourceState::General;
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.sizeInBytes = tableSize;
+ m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef());
+
+ ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
+ pipelineImpl->m_stateObject->QueryInterface(stateObjectProperties.writeRef());
+
+ TransientResourceHeapImpl* transientHeapImpl =
+ static_cast<TransientResourceHeapImpl*>(transientHeap);
+
+ IBufferResource* stagingBuffer = nullptr;
+ transientHeapImpl->allocateStagingBuffer(
+ tableSize, stagingBuffer, ResourceState::General);
+
+ assert(stagingBuffer);
+ void* stagingPtr = nullptr;
+ stagingBuffer->map(nullptr, &stagingPtr);
+
+ auto copyShaderIdInto = [&](void* dest, String& name)
+ {
+ void* shaderId = stateObjectProperties->GetShaderIdentifier(name.toWString().begin());
+ memcpy(dest, shaderId, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
+ };
+
+ uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr;
+ for (uint32_t i = 0; i < m_rayGenShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_rayGenTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[i]);
+ }
+ for (uint32_t i = 0; i < m_missShaderCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_missTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[m_rayGenShaderCount + i]);
+ }
+ for (uint32_t i = 0; i < m_hitGroupCount; i++)
+ {
+ copyShaderIdInto(
+ stagingBufferPtr + m_hitGroupTableOffset +
+ D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
+ m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i]);
+ }
+
+ stagingBuffer->unmap(nullptr);
+ encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize);
+ encoder->bufferBarrier(
+ 1,
+ bufferResource.readRef(),
+ gfx::ResourceState::CopyDestination,
+ gfx::ResourceState::ShaderResource);
+ RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get());
+ return _Move(resultPtr);
+ }
+
+ };
+
class CommandBufferImpl
: public ICommandBuffer
, public ComObject
@@ -4103,7 +4184,7 @@ public:
IBufferResource* stagingBuffer;
m_commandBuffer->m_transientHeap->allocateStagingBuffer(
- bufferSize, stagingBuffer, ResourceState::CopySource);
+ bufferSize, stagingBuffer, ResourceState::General);
BufferResourceImpl* bufferImpl = static_cast<BufferResourceImpl*>(stagingBuffer);
uint8_t* bufferData = nullptr;
@@ -4447,7 +4528,8 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth) override;
@@ -6945,6 +7027,15 @@ Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShad
return SLANG_OK;
}
+Result D3D12Device::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable)
+{
+ RefPtr<ShaderTableImpl> result = new ShaderTableImpl();
+ result->m_device = this;
+ result->init(desc);
+ returnComPtr(outShaderTable, result);
+ return SLANG_OK;
+}
+
Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
{
GraphicsPipelineStateDesc desc = inDesc;
@@ -7546,7 +7637,8 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::bindPipeline(
}
void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth)
@@ -7576,15 +7668,33 @@ void D3D12Device::CommandBufferImpl::RayTracingCommandEncoderImpl::dispatchRays(
pipeline = newPipeline.Ptr();
}
auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
- auto dispatchDesc = pipelineImpl->m_dispatchDesc;
- int32_t rayGenShaderOffset = 0;
- if (rayGenShaderName)
- {
- rayGenShaderOffset =
- pipelineImpl->m_mapRayGenShaderNameToShaderTableIndex[rayGenShaderName].GetValue() *
- D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
- }
- dispatchDesc.RayGenerationShaderRecord.StartAddress += rayGenShaderOffset;
+
+ auto shaderTableImpl = static_cast<ShaderTableImpl*>(shaderTable);
+
+ ResourceCommandEncoderImpl resourceCopyEncoder;
+ resourceCopyEncoder.init(m_renderer, m_commandBuffer);
+ auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(pipelineImpl, m_transientHeap, &resourceCopyEncoder);
+ auto shaderTableAddr = shaderTableBuffer->getDeviceAddress();
+
+ D3D12_DISPATCH_RAYS_DESC dispatchDesc = {};
+
+ dispatchDesc.RayGenerationShaderRecord.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_rayGenTableOffset +
+ rayGenShaderIndex * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.RayGenerationShaderRecord.SizeInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
+ dispatchDesc.MissShaderTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_missTableOffset;
+ dispatchDesc.MissShaderTable.SizeInBytes =
+ shaderTableImpl->m_missShaderCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.MissShaderTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
+ dispatchDesc.HitGroupTable.StartAddress =
+ shaderTableAddr + shaderTableImpl->m_hitGroupTableOffset;
+ dispatchDesc.HitGroupTable.SizeInBytes =
+ shaderTableImpl->m_hitGroupCount * D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+ dispatchDesc.HitGroupTable.StrideInBytes = D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES;
+
dispatchDesc.Width = (UINT)width;
dispatchDesc.Height = (UINT)height;
dispatchDesc.Depth = (UINT)depth;
@@ -7616,7 +7726,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
ChunkedList<ComPtr<ISlangBlob>> codeBlobs;
ComPtr<ISlangBlob> diagnostics;
ChunkedList<OSString> stringPool;
- int32_t rayGenIndex = 0;
for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
{
ComPtr<ISlangBlob> codeBlob;
@@ -7639,19 +7748,6 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY;
dxilSubObject.pDesc = dxilLibraries.add(library);
subObjects.add(dxilSubObject);
-
- auto entryPointLayout = programLayout->getEntryPointByIndex(i);
- switch (entryPointLayout->getStage())
- {
- case SLANG_STAGE_RAY_GENERATION:
- pipelineStateImpl
- ->m_mapRayGenShaderNameToShaderTableIndex[entryPointLayout->getName()] =
- rayGenIndex;
- rayGenIndex++;
- break;
- default:
- break;
- }
}
auto getWStr = [&](const char* name)
{
@@ -7679,9 +7775,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
{
hitGroupDesc.IntersectionShaderImport = getWStr(hitGroup.intersectionEntryPoint);
}
- StringBuilder hitGroupName;
- hitGroupName << "hitgroup_" << i;
- hitGroupDesc.HitGroupExport = getWStr(hitGroupName.ToString().getBuffer());
+ hitGroupDesc.HitGroupExport = getWStr(hitGroup.hitGroupName);
D3D12_STATE_SUBOBJECT hitGroupSubObject = {};
hitGroupSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_HIT_GROUP;
@@ -7719,108 +7813,10 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
rtpsoDesc.pSubobjects = subObjects.getBuffer();
SLANG_RETURN_ON_FAIL(m_device5->CreateStateObject(&rtpsoDesc, IID_PPV_ARGS(pipelineStateImpl->m_stateObject.writeRef())));
- SLANG_RETURN_ON_FAIL(pipelineStateImpl->createShaderTables(this, slangProgram, inDesc));
-
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
}
-Result D3D12Device::RayTracingPipelineStateImpl::createShaderTables(
- D3D12Device* device,
- slang::IComponentType* slangProgram,
- const RayTracingPipelineStateDesc& desc)
-{
- ComPtr<ID3D12StateObjectProperties> stateObjectProperties;
- m_stateObject->QueryInterface(stateObjectProperties.writeRef());
- auto programLayout = slangProgram->getLayout();
- struct ShaderIdentifier { uint32_t data[D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES / sizeof(uint32_t)]; };
- List<ShaderIdentifier> rayGenIdentifiers, missIdentifiers, hitgroupIdentifiers;
- for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
- {
- auto entryPointLayout = programLayout->getEntryPointByIndex(i);
- ShaderIdentifier identifier;
- switch (entryPointLayout->getStage())
- {
- case SLANG_STAGE_RAY_GENERATION:
- memcpy(
- &identifier,
- stateObjectProperties->GetShaderIdentifier(
- String(entryPointLayout->getName()).toWString().begin()),
- sizeof(ShaderIdentifier));
- rayGenIdentifiers.add(identifier);
- break;
- case SLANG_STAGE_MISS:
- memcpy(
- &identifier,
- stateObjectProperties->GetShaderIdentifier(
- String(entryPointLayout->getName()).toWString().begin()),
- sizeof(ShaderIdentifier));
- missIdentifiers.add(identifier);
- break;
- default:
- break;
- }
- }
- for (int i = 0; i < desc.shaderTableHitGroupCount; i++)
- {
- StringBuilder hitgroupName;
- hitgroupName << "hitgroup_" << desc.shaderTableHitGroupIndices[i];
- ShaderIdentifier hitgroupIdentifier;
- memcpy(
- &hitgroupIdentifier,
- stateObjectProperties->GetShaderIdentifier(hitgroupName.toWString().begin()),
- sizeof(ShaderIdentifier));
- hitgroupIdentifiers.add(hitgroupIdentifier);
- }
-
- auto createShaderTableResource = [&](ArrayView<ShaderIdentifier> content,
- RefPtr<BufferResourceImpl>& outResource) -> Result
- {
- IBufferResource::Desc bufferDesc = {};
- bufferDesc.type = IResource::Type::Buffer;
- bufferDesc.defaultState = ResourceState::ShaderResource;
- bufferDesc.allowedStates = ResourceStateSet(
- ResourceState::CopySource,
- ResourceState::UnorderedAccess,
- ResourceState::ShaderResource);
- bufferDesc.elementSize = 0;
- bufferDesc.sizeInBytes = content.getCount() * sizeof(ShaderIdentifier);
- bufferDesc.format = Format::Unknown;
- ComPtr<IBufferResource> shaderTableResource;
- SLANG_RETURN_ON_FAIL(device->createBufferResource(
- bufferDesc, content.getBuffer(), shaderTableResource.writeRef()));
- outResource = static_cast<BufferResourceImpl*>(shaderTableResource.get());
- return SLANG_OK;
- };
-
- if (desc.shaderTableHitGroupCount)
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(hitgroupIdentifiers.getArrayView(), m_hitgroupShaderTable));
- m_dispatchDesc.HitGroupTable.SizeInBytes =
- (uint64_t)(sizeof(ShaderIdentifier)) * desc.shaderTableHitGroupCount;
- m_dispatchDesc.HitGroupTable.StrideInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.HitGroupTable.StartAddress = m_hitgroupShaderTable->getDeviceAddress();
- }
- if (rayGenIdentifiers.getCount())
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(rayGenIdentifiers.getArrayView(), m_rayGenShaderTable));
- m_dispatchDesc.RayGenerationShaderRecord.SizeInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.RayGenerationShaderRecord.StartAddress = m_rayGenShaderTable->getDeviceAddress();
- }
- if (missIdentifiers.getCount())
- {
- SLANG_RETURN_ON_FAIL(
- createShaderTableResource(missIdentifiers.getArrayView(), m_missShaderTable));
- m_dispatchDesc.MissShaderTable.SizeInBytes =
- (uint64_t)(sizeof(ShaderIdentifier)) * missIdentifiers.getCount();
- m_dispatchDesc.MissShaderTable.StrideInBytes = sizeof(ShaderIdentifier);
- m_dispatchDesc.MissShaderTable.StartAddress = m_missShaderTable->getDeviceAddress();
- }
- return SLANG_OK;
-}
-
#endif // SLANG_GFX_HAS_DXR_SUPPORT
Result createNullDescriptor(
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index fa8fc6fda..53deec385 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -140,7 +140,7 @@ SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(TransientResourceHeap)
SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(QueryPool)
SLANG_GFX_DEBUG_GET_INTERFACE_IMPL_PARENT(AccelerationStructure, ResourceView)
SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(Fence)
-
+SLANG_GFX_DEBUG_GET_INTERFACE_IMPL(ShaderTable)
#undef SLANG_GFX_DEBUG_GET_INTERFACE_IMPL
#undef SLANG_GFX_DEBUG_GET_INTERFACE_IMPL_PARENT
@@ -179,6 +179,7 @@ SLANG_GFX_DEBUG_GET_OBJ_IMPL(TransientResourceHeap)
SLANG_GFX_DEBUG_GET_OBJ_IMPL(QueryPool)
SLANG_GFX_DEBUG_GET_OBJ_IMPL(AccelerationStructure)
SLANG_GFX_DEBUG_GET_OBJ_IMPL(Fence)
+SLANG_GFX_DEBUG_GET_OBJ_IMPL(ShaderTable)
#undef SLANG_GFX_DEBUG_GET_OBJ_IMPL
@@ -802,6 +803,15 @@ Result DebugDevice::getTextureAllocationInfo(
return baseObject->getTextureAllocationInfo(desc, outSize, outAlignment);
}
+Result DebugDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable)
+{
+ SLANG_GFX_API_FUNC;
+ RefPtr<DebugShaderTable> result = new DebugShaderTable();
+ SLANG_RETURN_ON_FAIL(baseObject->createShaderTable(desc, result->baseObject.writeRef()));
+ returnComPtr(outTable, result);
+ return SLANG_OK;
+}
+
IResource::Type DebugBufferResource::getType()
{
SLANG_GFX_API_FUNC;
@@ -1477,13 +1487,14 @@ void DebugRayTracingCommandEncoder::bindPipeline(
}
void DebugRayTracingCommandEncoder::dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth)
{
SLANG_GFX_API_FUNC;
- baseObject->dispatchRays(rayGenShaderName, width, height, depth);
+ baseObject->dispatchRays(rayGenShaderIndex, getInnerObj(shaderTable), width, height, depth);
}
const ICommandQueue::Desc& DebugCommandQueue::getDesc()
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index 04dafadf2..a2de7eb31 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -156,6 +156,15 @@ public:
uint64_t timeout) override;
virtual SLANG_NO_THROW Result SLANG_MCALL getTextureAllocationInfo(
const ITextureResource::Desc& desc, size_t* outSize, size_t* outAlignment) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) override;
+};
+
+class DebugShaderTable : public DebugObject<IShaderTable>
+{
+public:
+ SLANG_COM_OBJECT_IUNKNOWN_ALL;
+ IShaderTable* getInterface(const Slang::Guid& guid);
};
class DebugQueryPool : public DebugObject<IQueryPool>
@@ -511,7 +520,8 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
- const char* rayGenShaderName,
+ uint32_t rayGenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth) override;
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index fb7baf85d..52cf7ffac 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -32,6 +32,7 @@ const Slang::Guid GfxGUID::IID_ICommandQueue = SLANG_UUID_ICommandQueue;
const Slang::Guid GfxGUID::IID_IQueryPool = SLANG_UUID_IQueryPool;
const Slang::Guid GfxGUID::IID_IAccelerationStructure = SLANG_UUID_IAccelerationStructure;
const Slang::Guid GfxGUID::IID_IFence = SLANG_UUID_IFence;
+const Slang::Guid GfxGUID::IID_IShaderTable = SLANG_UUID_IShaderTable;
StageType translateStage(SlangStage slangStage)
@@ -438,6 +439,13 @@ Result RendererBase::createAccelerationStructure(
return SLANG_E_NOT_AVAILABLE;
}
+Result RendererBase::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable)
+{
+ SLANG_UNUSED(desc);
+ SLANG_UNUSED(outTable);
+ return SLANG_E_NOT_AVAILABLE;
+}
+
Result RendererBase::createRayTracingPipelineState(const RayTracingPipelineStateDesc& desc, IPipelineState** outState)
{
SLANG_UNUSED(desc);
@@ -847,4 +855,26 @@ Result ShaderObjectBase::copyFrom(IShaderObject* object, ITransientResourceHeap*
return SLANG_FAIL;
}
+Result ShaderTableBase::init(const IShaderTable::Desc& desc)
+{
+ m_rayGenShaderCount = desc.rayGenShaderCount;
+ m_missShaderCount = desc.missShaderCount;
+ m_hitGroupCount = desc.hitGroupCount;
+ m_entryPointNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
+ for (uint32_t i = 0; i < desc.rayGenShaderCount; i++)
+ {
+ m_entryPointNames.add(desc.rayGenShaderEntryPointNames[i]);
+ }
+ for (uint32_t i = 0; i < desc.missShaderCount; i++)
+ {
+ m_entryPointNames.add(desc.missShaderEntryPointNames[i]);
+ }
+ for (uint32_t i = 0; i < desc.hitGroupCount; i++)
+ {
+ m_entryPointNames.add(desc.hitGroupNames[i]);
+ }
+ return SLANG_OK;
+}
+
} // namespace gfx
+
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 33d51e9ad..6baade085 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -39,6 +39,7 @@ struct GfxGUID
static const Slang::Guid IID_IQueryPool;
static const Slang::Guid IID_IAccelerationStructure;
static const Slang::Guid IID_IFence;
+ static const Slang::Guid IID_IShaderTable;
};
// We use a `BreakableReference` to avoid the cyclic reference situation in gfx implementation.
@@ -1186,7 +1187,7 @@ public:
}
public:
SLANG_COM_OBJECT_IUNKNOWN_ALL
- ITransientResourceHeap* getInterface(const Slang::Guid& guid)
+ ITransientResourceHeap* getInterface(const Slang::Guid& guid)
{
if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_ITransientResourceHeap)
return static_cast<ITransientResourceHeap*>(this);
@@ -1194,6 +1195,48 @@ public:
}
};
+class ShaderTableBase
+ : public IShaderTable
+ , public Slang::ComObject
+{
+public:
+ Slang::List<Slang::String> m_entryPointNames;
+ uint32_t m_rayGenShaderCount;
+ uint32_t m_missShaderCount;
+ uint32_t m_hitGroupCount;
+
+ Slang::Dictionary<PipelineStateBase*, Slang::RefPtr<BufferResource>> m_deviceBuffers;
+
+ SLANG_COM_OBJECT_IUNKNOWN_ALL
+ IShaderTable* getInterface(const Slang::Guid& guid)
+ {
+ if (guid == GfxGUID::IID_ISlangUnknown || guid == GfxGUID::IID_IShaderTable)
+ return static_cast<IShaderTable*>(this);
+ return nullptr;
+ }
+
+ virtual Slang::RefPtr<BufferResource> createDeviceBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder) = 0;
+
+ BufferResource* getOrCreateBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder)
+ {
+ if (auto ptr = m_deviceBuffers.TryGetValue(pipeline))
+ {
+ return ptr->Ptr();
+ }
+ auto result = createDeviceBuffer(pipeline, transientHeap, encoder);
+ m_deviceBuffers[pipeline] = result;
+ return result;
+ }
+
+ Result init(const IShaderTable::Desc& desc);
+};
+
// Renderer implementation shared by all platforms.
// Responsible for shader compilation, specialization and caching.
class RendererBase : public IDevice, public Slang::ComObject
@@ -1256,6 +1299,11 @@ public:
// Provides a default implementation that returns SLANG_E_NOT_AVAILABLE for platforms
// without ray tracing support.
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outTable) override;
+
+ // Provides a default implementation that returns SLANG_E_NOT_AVAILABLE for platforms
+ // without ray tracing support.
virtual SLANG_NO_THROW Result SLANG_MCALL createRayTracingPipelineState(
const RayTracingPipelineStateDesc& desc, IPipelineState** outState) override;
diff --git a/tools/gfx/transient-resource-heap-base.h b/tools/gfx/transient-resource-heap-base.h
index f3df1c139..2dc16dcd4 100644
--- a/tools/gfx/transient-resource-heap-base.h
+++ b/tools/gfx/transient-resource-heap-base.h
@@ -55,7 +55,7 @@ public:
bufferDesc.defaultState = state;
bufferDesc.allowedStates =
ResourceStateSet(ResourceState::CopyDestination, ResourceState::CopySource);
- if (state == ResourceState::CopySource)
+ if (state == ResourceState::General)
bufferDesc.memoryType = MemoryType::Upload;
else
bufferDesc.memoryType = MemoryType::ReadBack;
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 9f5af2a1e..f31195cc8 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -5117,12 +5117,14 @@ public:
}
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
- const char* rayGenShaderName,
+ uint32_t raygenShaderIndex,
+ IShaderTable* shaderTable,
int32_t width,
int32_t height,
int32_t depth) override
{
- SLANG_UNUSED(rayGenShaderName);
+ SLANG_UNUSED(raygenShaderIndex);
+ SLANG_UNUSED(shaderTable);
SLANG_UNUSED(width);
SLANG_UNUSED(height);
SLANG_UNUSED(depth);