summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp6
-rw-r--r--tools/gfx/renderer-shared.cpp8
-rw-r--r--tools/gfx/renderer-shared.h2
-rw-r--r--tools/gfx/vulkan/render-vk.cpp271
-rw-r--r--tools/gfx/vulkan/vk-api.h15
-rw-r--r--tools/gfx/vulkan/vk-util.cpp2
-rw-r--r--tools/gfx/vulkan/vk-util.h3
7 files changed, 273 insertions, 34 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 12f8bafa6..9be157605 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -3512,7 +3512,7 @@ public:
copyShaderIdInto(
stagingBufferPtr + m_rayGenTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[i],
+ m_shaderGroupNames[i],
m_recordOverwrites[i]);
}
for (uint32_t i = 0; i < m_missShaderCount; i++)
@@ -3520,7 +3520,7 @@ public:
copyShaderIdInto(
stagingBufferPtr + m_missTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[m_rayGenShaderCount + i],
+ m_shaderGroupNames[m_rayGenShaderCount + i],
m_recordOverwrites[m_rayGenShaderCount + i]);
}
for (uint32_t i = 0; i < m_hitGroupCount; i++)
@@ -3528,7 +3528,7 @@ public:
copyShaderIdInto(
stagingBufferPtr + m_hitGroupTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i],
+ m_shaderGroupNames[m_rayGenShaderCount + m_missShaderCount + i],
m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]);
}
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index a97462c73..244146ca3 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -956,11 +956,11 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
m_rayGenShaderCount = desc.rayGenShaderCount;
m_missShaderCount = desc.missShaderCount;
m_hitGroupCount = desc.hitGroupCount;
- m_entryPointNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
+ m_shaderGroupNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
for (uint32_t i = 0; i < desc.rayGenShaderCount; i++)
{
- m_entryPointNames.add(desc.rayGenShaderEntryPointNames[i]);
+ m_shaderGroupNames.add(desc.rayGenShaderEntryPointNames[i]);
if (desc.rayGenShaderRecordOverwrites)
{
m_recordOverwrites.add(desc.rayGenShaderRecordOverwrites[i]);
@@ -972,7 +972,7 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
}
for (uint32_t i = 0; i < desc.missShaderCount; i++)
{
- m_entryPointNames.add(desc.missShaderEntryPointNames[i]);
+ m_shaderGroupNames.add(desc.missShaderEntryPointNames[i]);
if (desc.missShaderRecordOverwrites)
{
m_recordOverwrites.add(desc.missShaderRecordOverwrites[i]);
@@ -984,7 +984,7 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
}
for (uint32_t i = 0; i < desc.hitGroupCount; i++)
{
- m_entryPointNames.add(desc.hitGroupNames[i]);
+ m_shaderGroupNames.add(desc.hitGroupNames[i]);
if (desc.hitGroupRecordOverwrites)
{
m_recordOverwrites.add(desc.hitGroupRecordOverwrites[i]);
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 4ebd28ed2..5afc77027 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -1076,7 +1076,7 @@ class ShaderTableBase
, public Slang::ComObject
{
public:
- Slang::List<Slang::String> m_entryPointNames;
+ Slang::List<Slang::String> m_shaderGroupNames;
Slang::List<ShaderRecordOverwrite> m_recordOverwrites;
uint32_t m_rayGenShaderCount;
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index f11ca59aa..1a42f3261 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -114,6 +114,8 @@ public:
createMutableRootShaderObject(
IShaderProgram* program, IShaderObject** outObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(
const IShaderProgram::Desc& desc,
IShaderProgram** outProgram,
@@ -1013,6 +1015,17 @@ public:
VkPipeline m_pipeline = VK_NULL_HANDLE;
};
+ class RayTracingPipelineStateImpl : public PipelineStateImpl
+ {
+ public:
+ Dictionary<String, Index> shaderGroupNameToIndex;
+ Int shaderGroupCount;
+
+ RayTracingPipelineStateImpl(VKDevice* device)
+ : PipelineStateImpl(device)
+ {};
+ };
+
// In order to bind shader parameters to the correct locations, we need to
// be able to describe those locations. Most shader parameters in Vulkan
// simply consume a single `binding`, but we also need to deal with
@@ -2399,6 +2412,7 @@ public:
BreakableReference<VKDevice> m_device;
Array<VkPipelineShaderStageCreateInfo, 8> m_stageCreateInfos;
+ Array<String, 8> m_entryPointNames;
Array<ComPtr<ISlangBlob>, 8> m_codeBlobs; //< To keep storage of code in scope
Array<VkShaderModule, 8> m_modules;
RefPtr<RootShaderObjectLayout> m_rootObjectLayout;
@@ -3885,6 +3899,119 @@ public:
List<RefPtr<EntryPointShaderObject>> m_entryPoints;
};
+ class ShaderTableImpl : public ShaderTableBase
+ {
+ public:
+ uint32_t m_raygenTableSize;
+ uint32_t m_missTableSize;
+ uint32_t m_hitTableSize;
+ uint32_t m_callableTableSize;
+
+ VKDevice* m_device;
+
+ virtual RefPtr<BufferResource> createDeviceBuffer(
+ PipelineStateBase* pipeline,
+ TransientResourceHeapBase* transientHeap,
+ IResourceCommandEncoder* encoder) override
+ {
+ auto vkApi = m_device->m_api;
+ auto rtProps = vkApi.m_rtProperties;
+ uint32_t handleSize = rtProps.shaderGroupHandleSize;
+ m_raygenTableSize = (uint32_t)VulkanUtil::calcAligned(m_rayGenShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_missTableSize = (uint32_t)VulkanUtil::calcAligned(m_missShaderCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_hitTableSize = (uint32_t)VulkanUtil::calcAligned(m_hitGroupCount * handleSize, rtProps.shaderGroupBaseAlignment);
+ m_callableTableSize = 0; // TODO: Are callable shaders needed?
+ uint32_t tableSize = m_raygenTableSize + m_missTableSize + m_hitTableSize + m_callableTableSize;
+
+ auto pipelineImpl = static_cast<RayTracingPipelineStateImpl*>(pipeline);
+ ComPtr<IBufferResource> bufferResource;
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.memoryType = MemoryType::DeviceLocal;
+ bufferDesc.defaultState = ResourceState::General;
+ bufferDesc.allowedStates = ResourceStateSet(ResourceState::General, ResourceState::CopyDestination);
+ bufferDesc.type = IResource::Type::Buffer;
+ bufferDesc.sizeInBytes = tableSize;
+ m_device->createBufferResource(bufferDesc, nullptr, bufferResource.writeRef());
+
+ TransientResourceHeapImpl* transientHeapImpl =
+ static_cast<TransientResourceHeapImpl*>(transientHeap);
+
+ IBufferResource* stagingBuffer = nullptr;
+ transientHeapImpl->allocateStagingBuffer(
+ tableSize, stagingBuffer, ResourceState::General);
+
+ assert(stagingBuffer);
+ void* stagingPtr = nullptr;
+ stagingBuffer->map(nullptr, &stagingPtr);
+
+ List<uint8_t> handles;
+ auto handleCount = pipelineImpl->shaderGroupCount;
+ auto totalHandleSize = handleSize * handleCount;
+ handles.setCount(totalHandleSize);
+ auto result = vkApi.vkGetRayTracingShaderGroupHandlesKHR(m_device->m_device, pipelineImpl->m_pipeline, 0, (uint32_t)handleCount, totalHandleSize, handles.getBuffer());
+
+ uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr;
+ auto subTablePtr = stagingBufferPtr;
+ Int shaderTableEntryCounter = 0;
+
+ // Each loop calculates the copy source and destination locations by fetching the name of the shader
+ // group from the list of shader group names and getting its corresponding index in the buffer of handles.
+ for (uint32_t i = 0; i < m_rayGenShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_raygenTableSize;
+
+ for (uint32_t i = 0; i < m_missShaderCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_missTableSize;
+
+ for (uint32_t i = 0; i < m_hitGroupCount; i++)
+ {
+ auto dstHandlePtr = subTablePtr + i * handleSize;
+ auto shaderGroupName = m_shaderGroupNames[shaderTableEntryCounter++];
+ auto shaderGroupIndexPtr = pipelineImpl->shaderGroupNameToIndex.TryGetValue(shaderGroupName);
+ if (!shaderGroupIndexPtr)
+ continue;
+
+ auto shaderGroupIndex = *shaderGroupIndexPtr;
+ auto srcHandlePtr = handles.getBuffer() + shaderGroupIndex * handleSize;
+ memcpy(dstHandlePtr, srcHandlePtr, handleSize);
+ }
+ subTablePtr += m_hitTableSize;
+
+ // TODO: Callable shaders?
+
+ stagingBuffer->unmap(nullptr);
+ encoder->copyBuffer(bufferResource, 0, stagingBuffer, 0, tableSize);
+ encoder->bufferBarrier(
+ 1,
+ bufferResource.readRef(),
+ gfx::ResourceState::CopyDestination,
+ gfx::ResourceState::ShaderResource);
+ RefPtr<BufferResource> resultPtr = static_cast<BufferResource*>(bufferResource.get());
+ return _Move(resultPtr);
+ }
+ };
+
class CommandBufferImpl
: public ICommandBuffer
, public ComObject
@@ -5506,11 +5633,9 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* pipeline, IShaderObject** outRootObject) override
{
- SLANG_UNUSED(pipeline);
- SLANG_UNUSED(outRootObject);
+ setPipelineStateImpl(pipeline, outRootObject);
}
- // TODO: Implement after implementing createRayTracingPipelineState
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
uint32_t raygenShaderIndex,
IShaderTable* shaderTable,
@@ -5518,15 +5643,46 @@ public:
int32_t height,
int32_t depth) override
{
- SLANG_UNUSED(raygenShaderIndex);
- SLANG_UNUSED(shaderTable);
- SLANG_UNUSED(width);
- SLANG_UNUSED(height);
- SLANG_UNUSED(depth);
+ auto vkApi = m_commandBuffer->m_renderer->m_api;
+ auto vkCommandBuffer = m_commandBuffer->m_commandBuffer;
+
+ flushBindingState(VK_PIPELINE_BIND_POINT_RAY_TRACING_KHR);
+
+ auto rtProps = vkApi.m_rtProperties;
+ auto shaderTableImpl = (ShaderTableImpl*)shaderTable;
+ auto alignedHandleSize = VulkanUtil::calcAligned(rtProps.shaderGroupHandleSize, rtProps.shaderGroupHandleAlignment);
+
+ ResourceCommandEncoder resourceCopyEncoder;
+ resourceCopyEncoder.init(m_commandBuffer);
+ auto shaderTableBuffer = shaderTableImpl->getOrCreateBuffer(m_currentPipeline, m_commandBuffer->m_transientHeap, &resourceCopyEncoder);
+
+ VkStridedDeviceAddressRegionKHR raygenSBT;
+ raygenSBT.deviceAddress = shaderTableBuffer->getDeviceAddress();
+ raygenSBT.stride = VulkanUtil::calcAligned(alignedHandleSize, rtProps.shaderGroupBaseAlignment);
+ raygenSBT.size = raygenSBT.stride;
+
+ VkStridedDeviceAddressRegionKHR missSBT;
+ missSBT.deviceAddress = raygenSBT.deviceAddress + raygenSBT.size;
+ missSBT.stride = alignedHandleSize;
+ missSBT.size = shaderTableImpl->m_missTableSize;
+
+ VkStridedDeviceAddressRegionKHR hitSBT;
+ hitSBT.deviceAddress = missSBT.deviceAddress + missSBT.size;
+ hitSBT.stride = alignedHandleSize;
+ hitSBT.size = shaderTableImpl->m_hitTableSize;
+
+ // TODO: Are callable shaders needed?
+ VkStridedDeviceAddressRegionKHR callableSBT;
+ callableSBT.deviceAddress = 0;
+ callableSBT.stride = 0;
+ callableSBT.size = 0;
+
+ vkApi.vkCmdTraceRaysKHR(vkCommandBuffer, &raygenSBT, &missSBT, &hitSBT, &callableSBT, (uint32_t)width, (uint32_t)height, (uint32_t)depth);
}
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override
{
+ endEncodingImpl();
}
};
@@ -6811,6 +6967,10 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
extendedFeatures.rayQueryFeatures.pNext = deviceFeatures2.pNext;
deviceFeatures2.pNext = &extendedFeatures.rayQueryFeatures;
+ // Ray tracing pipeline features
+ extendedFeatures.rayTracingPipelineFeatures.pNext = deviceFeatures2.pNext;
+ deviceFeatures2.pNext = &extendedFeatures.rayTracingPipelineFeatures;
+
// Acceleration structure features
extendedFeatures.accelerationStructureFeatures.pNext = deviceFeatures2.pNext;
deviceFeatures2.pNext = &extendedFeatures.accelerationStructureFeatures;
@@ -6941,6 +7101,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
deviceExtensions.add(VK_KHR_SHADER_SUBGROUP_EXTENDED_TYPES_EXTENSION_NAME);
m_features.add("shader-subgroup-extended-types");
}
+
if (extendedFeatures.accelerationStructureFeatures.accelerationStructure)
{
extendedFeatures.accelerationStructureFeatures.pNext = (void*)deviceCreateInfo.pNext;
@@ -6949,6 +7110,15 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
deviceExtensions.add(VK_KHR_DEFERRED_HOST_OPERATIONS_EXTENSION_NAME);
m_features.add("acceleration-structure");
}
+
+ if (extendedFeatures.rayTracingPipelineFeatures.rayTracingPipeline)
+ {
+ extendedFeatures.rayTracingPipelineFeatures.pNext = (void*)deviceCreateInfo.pNext;
+ deviceCreateInfo.pNext = &extendedFeatures.rayTracingPipelineFeatures;
+ deviceExtensions.add(VK_KHR_RAY_TRACING_PIPELINE_EXTENSION_NAME);
+ m_features.add("ray-tracing-pipeline");
+ }
+
if (extendedFeatures.rayQueryFeatures.rayQuery)
{
extendedFeatures.rayQueryFeatures.pNext = (void*)deviceCreateInfo.pNext;
@@ -6957,6 +7127,7 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
m_features.add("ray-query");
m_features.add("ray-tracing");
}
+
if (extendedFeatures.bufferDeviceAddressFeatures.bufferDeviceAddress)
{
extendedFeatures.bufferDeviceAddressFeatures.pNext = (void*)deviceCreateInfo.pNext;
@@ -6982,6 +7153,12 @@ Result VKDevice::initVulkanInstanceAndDevice(const InteropHandle* handles, bool
m_features.add("robustness2");
}
+ VkPhysicalDeviceProperties2 extendedProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2 };
+ VkPhysicalDeviceRayTracingPipelinePropertiesKHR rtProps = { VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_PROPERTIES_KHR };
+ extendedProps.pNext = &rtProps;
+ m_api.vkGetPhysicalDeviceProperties2(m_api.m_physicalDevice, &extendedProps);
+ m_api.m_rtProperties = rtProps;
+
uint32_t extensionCount = 0;
m_api.vkEnumerateDeviceExtensionProperties(m_api.m_physicalDevice, NULL, &extensionCount, NULL);
Slang::List<VkExtensionProperties> extensions;
@@ -8672,14 +8849,16 @@ Result VKDevice::createProgram(
// uses "main" as the name. We should introduce a compiler parameter
// to control the entry point naming behavior in SPIRV-direct path
// so we can remove the ad-hoc logic here.
- const char* entryPointName = "main";
+ auto realEntryPointName = entryPointInfo->getNameOverride();
+ const char* spirvBinaryEntryPointName = "main";
if (m_desc.slang.targetFlags & SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY)
- entryPointName = entryPointInfo->getName();
+ spirvBinaryEntryPointName = realEntryPointName;
shaderProgram->m_stageCreateInfos.add(compileEntryPoint(
- entryPointName,
+ spirvBinaryEntryPointName,
kernelCode,
(VkShaderStageFlagBits)VulkanUtil::getShaderStage(stage),
shaderModule));
+ shaderProgram->m_entryPointNames.add(realEntryPointName);
shaderProgram->m_modules.add(shaderModule);
return SLANG_OK;
};
@@ -8749,6 +8928,15 @@ Result VKDevice::createMutableRootShaderObject(
return SLANG_OK;
}
+Result VKDevice::createShaderTable(const IShaderTable::Desc& desc, IShaderTable** outShaderTable)
+{
+ RefPtr<ShaderTableImpl> result = new ShaderTableImpl();
+ result->m_device = this;
+ result->init(desc);
+ returnComPtr(outShaderTable, result);
+ return SLANG_OK;
+}
+
VkSampleCountFlagBits translateSampleCount(uint32_t sampleCount)
{
switch (sampleCount)
@@ -8966,6 +9154,12 @@ Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& in
auto& vkBlendDesc = colorBlendAttachments[0];
memset(&vkBlendDesc, 0, sizeof(vkBlendDesc));
vkBlendDesc.blendEnable = VK_FALSE;
+ vkBlendDesc.srcColorBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.dstColorBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.colorBlendOp = VK_BLEND_OP_ADD;
+ vkBlendDesc.srcAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.dstAlphaBlendFactor = VK_BLEND_FACTOR_ONE;
+ vkBlendDesc.alphaBlendOp = VK_BLEND_OP_ADD;
vkBlendDesc.colorWriteMask = (VkColorComponentFlags)RenderTargetWriteMask::EnableAll;
}
else
@@ -9101,13 +9295,14 @@ VkPipelineCreateFlags translateFlags(RayTracingPipelineFlags::Enum flags)
return vkFlags;
}
-uint32_t findShaderIndexByName(const VkPipelineShaderStageCreateInfo* stageCreateInfos, size_t stageCount, const char* name)
+uint32_t findEntryPointIndexByName(const Dictionary<String, Index>& entryPointNameToIndex, const char* name)
{
- // TODO: Linear search is inefficient, use a Dictionary?
- for (size_t i = 0; i < stageCount; ++i)
- {
- if (strcmp(stageCreateInfos[i].pName, name)) return (uint32_t)i;
- }
+ if (!name) return VK_SHADER_UNUSED_KHR;
+
+ auto indexPtr = entryPointNameToIndex.TryGetValue(String(name));
+ if (indexPtr)
+ return (uint32_t)*indexPtr;
+ // TODO: Error reporting?
return VK_SHADER_UNUSED_KHR;
}
@@ -9128,11 +9323,39 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc
raytracingPipelineInfo.pNext = nullptr;
raytracingPipelineInfo.flags = translateFlags(desc.flags);
- VkPipelineShaderStageCreateInfo shaderStageInfo = { VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO };
raytracingPipelineInfo.stageCount = (uint32_t)programImpl->m_stageCreateInfos.getCount();
raytracingPipelineInfo.pStages = programImpl->m_stageCreateInfos.getBuffer();
+ // Build Dictionary from group name to group index
+ Dictionary<String, Index> shaderGroupNameToIndex;
+ // Build Dictionary from entry point name to entry point index (stageCreateInfos index) for all hit shaders - findShaderIndexByName
+ Dictionary<String, Index> entryPointNameToIndex;
+
List<VkRayTracingShaderGroupCreateInfoKHR> shaderGroupInfos;
+ for (uint32_t i = 0; i < raytracingPipelineInfo.stageCount; ++i)
+ {
+ auto stageCreateInfo = programImpl->m_stageCreateInfos[i];
+ auto entryPointName = programImpl->m_entryPointNames[i];
+ entryPointNameToIndex.Add(entryPointName, i);
+ if (stageCreateInfo.stage & (VK_SHADER_STAGE_ANY_HIT_BIT_KHR | VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR | VK_SHADER_STAGE_INTERSECTION_BIT_KHR))
+ continue;
+
+ VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
+ shaderGroupInfo.pNext = nullptr;
+ shaderGroupInfo.type = VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR;
+ shaderGroupInfo.generalShader = i;
+ shaderGroupInfo.closestHitShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.anyHitShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.intersectionShader = VK_SHADER_UNUSED_KHR;
+ shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
+
+ // For groups with a single entry point, the group name is the entry point name.
+ auto shaderGroupName = entryPointName;
+ auto shaderGroupIndex = shaderGroupInfos.getCount();
+ shaderGroupInfos.add(shaderGroupInfo);
+ shaderGroupNameToIndex.Add(shaderGroupName, shaderGroupIndex);
+ }
+
for (int32_t i = 0; i < desc.hitGroupCount; ++i)
{
VkRayTracingShaderGroupCreateInfoKHR shaderGroupInfo = { VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
@@ -9142,12 +9365,14 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc
shaderGroupInfo.type = (groupDesc.intersectionEntryPoint)
? VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR : VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR;
shaderGroupInfo.generalShader = VK_SHADER_UNUSED_KHR;
- shaderGroupInfo.closestHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.closestHitEntryPoint);
- shaderGroupInfo.anyHitShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.anyHitEntryPoint);
- shaderGroupInfo.intersectionShader = findShaderIndexByName(raytracingPipelineInfo.pStages, raytracingPipelineInfo.stageCount, groupDesc.intersectionEntryPoint);
+ shaderGroupInfo.closestHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.closestHitEntryPoint);
+ shaderGroupInfo.anyHitShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.anyHitEntryPoint);
+ shaderGroupInfo.intersectionShader = findEntryPointIndexByName(entryPointNameToIndex, groupDesc.intersectionEntryPoint);
shaderGroupInfo.pShaderGroupCaptureReplayHandle = nullptr;
+ auto shaderGroupIndex = shaderGroupInfos.getCount();
shaderGroupInfos.add(shaderGroupInfo);
+ shaderGroupNameToIndex.Add(String(groupDesc.hitGroupName), shaderGroupIndex);
}
raytracingPipelineInfo.groupCount = (uint32_t)shaderGroupInfos.getCount();
@@ -9168,10 +9393,12 @@ Result VKDevice::createRayTracingPipelineState(const RayTracingPipelineStateDesc
VkPipeline pipeline = VK_NULL_HANDLE;
SLANG_VK_CHECK(m_api.vkCreateRayTracingPipelinesKHR(m_device, VK_NULL_HANDLE, pipelineCache, 1, &raytracingPipelineInfo, nullptr, &pipeline));
- RefPtr<PipelineStateImpl> pipelineStateImpl = new PipelineStateImpl(this);
+ RefPtr<RayTracingPipelineStateImpl> pipelineStateImpl = new RayTracingPipelineStateImpl(this);
pipelineStateImpl->m_pipeline = pipeline;
pipelineStateImpl->init(desc);
pipelineStateImpl->establishStrongDeviceReference();
+ pipelineStateImpl->shaderGroupNameToIndex = shaderGroupNameToIndex;
+ pipelineStateImpl->shaderGroupCount = shaderGroupInfos.getCount();
m_deviceObjectsWithPotentialBackReferences.add(pipelineStateImpl);
returnComPtr(outState, pipelineStateImpl);
return SLANG_OK;
diff --git a/tools/gfx/vulkan/vk-api.h b/tools/gfx/vulkan/vk-api.h
index 8c87e6254..fa601ef35 100644
--- a/tools/gfx/vulkan/vk-api.h
+++ b/tools/gfx/vulkan/vk-api.h
@@ -164,6 +164,8 @@ namespace gfx {
x(vkDestroySwapchainKHR) \
x(vkAcquireNextImageKHR) \
x(vkCreateRayTracingPipelinesKHR) \
+ x(vkCmdTraceRaysKHR) \
+ x(vkGetRayTracingShaderGroupHandlesKHR) \
/* */
#if SLANG_WINDOWS_FAMILY
@@ -251,6 +253,10 @@ struct VulkanExtendedFeatureProperties
// Acceleration structure features
VkPhysicalDeviceAccelerationStructureFeaturesKHR accelerationStructureFeatures = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ACCELERATION_STRUCTURE_FEATURES_KHR};
+ // Ray tracing pipeline features
+ VkPhysicalDeviceRayTracingPipelineFeaturesKHR rayTracingPipelineFeatures = {
+ VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_TRACING_PIPELINE_FEATURES_KHR
+ };
// Ray query (inline ray-tracing) features
VkPhysicalDeviceRayQueryFeaturesKHR rayQueryFeatures = {
VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_RAY_QUERY_FEATURES_KHR};
@@ -303,10 +309,11 @@ struct VulkanApi
VkDevice m_device = VK_NULL_HANDLE;
VkPhysicalDevice m_physicalDevice = VK_NULL_HANDLE;
- VkPhysicalDeviceProperties m_deviceProperties;
- VkPhysicalDeviceFeatures m_deviceFeatures;
- VkPhysicalDeviceMemoryProperties m_deviceMemoryProperties;
- VulkanExtendedFeatureProperties m_extendedFeatures;
+ VkPhysicalDeviceProperties m_deviceProperties;
+ VkPhysicalDeviceRayTracingPipelinePropertiesKHR m_rtProperties;
+ VkPhysicalDeviceFeatures m_deviceFeatures;
+ VkPhysicalDeviceMemoryProperties m_deviceMemoryProperties;
+ VulkanExtendedFeatureProperties m_extendedFeatures;
};
} // renderer_test
diff --git a/tools/gfx/vulkan/vk-util.cpp b/tools/gfx/vulkan/vk-util.cpp
index 633239f24..974a61422 100644
--- a/tools/gfx/vulkan/vk-util.cpp
+++ b/tools/gfx/vulkan/vk-util.cpp
@@ -140,6 +140,8 @@ VkShaderStageFlags VulkanUtil::getShaderStage(SlangStage stage)
return VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT;
case SLANG_STAGE_INTERSECTION:
return VK_SHADER_STAGE_INTERSECTION_BIT_KHR;
+ case SLANG_STAGE_MISS:
+ return VK_SHADER_STAGE_MISS_BIT_KHR;
case SLANG_STAGE_RAY_GENERATION:
return VK_SHADER_STAGE_RAYGEN_BIT_KHR;
case SLANG_STAGE_VERTEX:
diff --git a/tools/gfx/vulkan/vk-util.h b/tools/gfx/vulkan/vk-util.h
index cdf7bcc79..4c32b1615 100644
--- a/tools/gfx/vulkan/vk-util.h
+++ b/tools/gfx/vulkan/vk-util.h
@@ -44,6 +44,9 @@ struct VulkanUtil
static VkImageLayout getImageLayoutFromState(ResourceState state);
+ /// Calculate size taking into account alignment. Alignment must be a power of 2
+ static UInt calcAligned(UInt size, UInt alignment) { return (size + alignment - 1) & ~(alignment - 1); }
+
static inline bool isDepthFormat(VkFormat format)
{
switch (format)