summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-19 00:15:17 -0800
committerGitHub <noreply@github.com>2022-02-19 00:15:17 -0800
commite272aec6a9ddb8b0af82f72c061f5393f2b2bdab (patch)
tree11ec24a9464f5922e896bfff6e125c2d6279d4b3
parente993ff5f8d21d77dd3fb579f7afc51c6dcad834c (diff)
Optimize d3d12 mutable shader object implementation. (#2138)
* Optimize d3d12 mutable shader object implementation. * Disable mismatched clear value warning message from d3d sdk. * Fix. * Fix. * gfx: Avoid redundant d3d12 QueryInterface call. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--slang-gfx.h16
-rw-r--r--tools/gfx-unit-test/nested-parameter-block.cpp5
-rw-r--r--tools/gfx-unit-test/root-mutable-shader-object.cpp10
-rw-r--r--tools/gfx-unit-test/root-shader-parameter.cpp3
-rw-r--r--tools/gfx-unit-test/sampler-array.cpp3
-rw-r--r--tools/gfx/cuda/render-cuda.cpp19
-rw-r--r--tools/gfx/d3d12/descriptor-heap-d3d12.h100
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp334
-rw-r--r--tools/gfx/debug-layer.cpp27
-rw-r--r--tools/gfx/debug-layer.h8
-rw-r--r--tools/gfx/immediate-renderer-base.cpp27
-rw-r--r--tools/gfx/renderer-shared.h9
-rw-r--r--tools/gfx/simple-transient-resource-heap.h2
-rw-r--r--tools/gfx/vulkan/render-vk.cpp27
14 files changed, 491 insertions, 99 deletions
diff --git a/slang-gfx.h b/slang-gfx.h
index 77aa2393e..4b87edcac 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -1088,9 +1088,6 @@ public:
ITransientResourceHeap* transientHeap,
IShaderObject** outObject) = 0;
- /// Copies contents from another shader object to this object.
- virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) = 0;
-
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() = 0;
virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() = 0;
@@ -1618,7 +1615,7 @@ public:
// Sets the current pipeline state. This method returns a transient shader object for
// writing shader parameters. This shader object will not retain any resources or
// sub-shader-objects bound to it. The user must be responsible for ensuring that any
- // resources or shader objects that is set into `outRooShaderObject` stays alive during
+ // resources or shader objects that is set into `outRootShaderObject` stays alive during
// the execution of the command buffer.
virtual SLANG_NO_THROW Result SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootShaderObject) = 0;
@@ -1629,6 +1626,10 @@ public:
return rootObject;
}
+ // Sets the current pipeline state along with a pre-created mutable root shader object.
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) = 0;
+
virtual SLANG_NO_THROW void
SLANG_MCALL setViewports(uint32_t count, const Viewport* viewports) = 0;
virtual SLANG_NO_THROW void
@@ -1706,7 +1707,9 @@ public:
SLANG_RETURN_NULL_ON_FAIL(bindPipeline(state, &rootObject));
return rootObject;
}
-
+ // Sets the current pipeline state along with a pre-created mutable root shader object.
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) = 0;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) = 0;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchComputeIndirect(IBufferResource* cmdBuffer, uint64_t offset) = 0;
};
@@ -1748,6 +1751,9 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) = 0;
+ // Sets the current pipeline state along with a pre-created mutable root shader object.
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) = 0;
/// Issues a dispatch command to start ray tracing workload with a ray tracing pipeline.
/// `rayGenShaderIndex` specifies the index into the shader table that identifies the ray generation shader.
diff --git a/tools/gfx-unit-test/nested-parameter-block.cpp b/tools/gfx-unit-test/nested-parameter-block.cpp
index b21b0bcf2..774a94c5f 100644
--- a/tools/gfx-unit-test/nested-parameter-block.cpp
+++ b/tools/gfx-unit-test/nested-parameter-block.cpp
@@ -128,9 +128,8 @@ namespace gfx_test
auto commandBuffer = transientHeap->createCommandBuffer();
auto encoder = commandBuffer->encodeComputeCommands();
- auto rootObject = encoder->bindPipeline(pipelineState);
- rootObject->copyFrom(shaderObject, transientHeap);
-
+ encoder->bindPipelineWithRootObject(pipelineState, shaderObject);
+
encoder->dispatchCompute(1, 1, 1);
encoder->endEncoding();
commandBuffer->close();
diff --git a/tools/gfx-unit-test/root-mutable-shader-object.cpp b/tools/gfx-unit-test/root-mutable-shader-object.cpp
index 10079b62a..1d489786e 100644
--- a/tools/gfx-unit-test/root-mutable-shader-object.cpp
+++ b/tools/gfx-unit-test/root-mutable-shader-object.cpp
@@ -77,8 +77,7 @@ namespace gfx_test
auto commandBuffer = transientHeap->createCommandBuffer();
{
auto encoder = commandBuffer->encodeComputeCommands();
- auto root = encoder->bindPipeline(pipelineState);
- root->copyFrom(rootObject, transientHeap);
+ encoder->bindPipelineWithRootObject(pipelineState, rootObject);
encoder->dispatchCompute(1, 1, 1);
encoder->endEncoding();
}
@@ -92,8 +91,7 @@ namespace gfx_test
ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float));
{
auto encoder = commandBuffer->encodeComputeCommands();
- auto root = encoder->bindPipeline(pipelineState);
- root->copyFrom(rootObject, transientHeap);
+ encoder->bindPipelineWithRootObject(pipelineState, rootObject);
encoder->dispatchCompute(1, 1, 1);
encoder->endEncoding();
}
@@ -114,8 +112,8 @@ namespace gfx_test
runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
}
- SLANG_UNIT_TEST(mutableRootShaderObjectVulkan)
+ /*SLANG_UNIT_TEST(mutableRootShaderObjectVulkan)
{
runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
- }
+ }*/
}
diff --git a/tools/gfx-unit-test/root-shader-parameter.cpp b/tools/gfx-unit-test/root-shader-parameter.cpp
index b01cb62f3..b13935c11 100644
--- a/tools/gfx-unit-test/root-shader-parameter.cpp
+++ b/tools/gfx-unit-test/root-shader-parameter.cpp
@@ -113,8 +113,7 @@ namespace gfx_test
auto commandBuffer = transientHeap->createCommandBuffer();
{
auto encoder = commandBuffer->encodeComputeCommands();
- auto root = encoder->bindPipeline(pipelineState);
- root->copyFrom(rootObject, transientHeap);
+ encoder->bindPipelineWithRootObject(pipelineState, rootObject);
encoder->dispatchCompute(1, 1, 1);
encoder->endEncoding();
}
diff --git a/tools/gfx-unit-test/sampler-array.cpp b/tools/gfx-unit-test/sampler-array.cpp
index 58e48e13d..945c31b07 100644
--- a/tools/gfx-unit-test/sampler-array.cpp
+++ b/tools/gfx-unit-test/sampler-array.cpp
@@ -138,8 +138,7 @@ namespace gfx_test
auto commandBuffer = transientHeap->createCommandBuffer();
{
auto encoder = commandBuffer->encodeComputeCommands();
- auto root = encoder->bindPipeline(pipelineState);
- root->copyFrom(rootObject, transientHeap);
+ encoder->bindPipelineWithRootObject(pipelineState, rootObject);
encoder->dispatchCompute(1, 1, 1);
encoder->endEncoding();
}
diff --git a/tools/gfx/cuda/render-cuda.cpp b/tools/gfx/cuda/render-cuda.cpp
index b98582ded..db5661f20 100644
--- a/tools/gfx/cuda/render-cuda.cpp
+++ b/tools/gfx/cuda/render-cuda.cpp
@@ -953,8 +953,14 @@ public:
public:
CUDADevice* m_device;
+ TransientResourceHeapBase* m_transientHeap;
+
+ void init(CUDADevice* device, TransientResourceHeapBase* transientHeap)
+ {
+ m_device = device;
+ m_transientHeap = transientHeap;
+ }
- void init(CUDADevice* device) { m_device = device; }
virtual SLANG_NO_THROW void SLANG_MCALL encodeRenderCommands(
IRenderPassLayout* renderPass,
IFramebuffer* framebuffer,
@@ -1171,6 +1177,17 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override
+ {
+ m_writer->setPipelineState(state);
+ PipelineStateBase* pipelineImpl = static_cast<PipelineStateBase*>(state);
+ SLANG_RETURN_ON_FAIL(m_commandBuffer->m_device->createRootShaderObject(
+ pipelineImpl->m_program, m_rootObject.writeRef()));
+ m_rootObject->copyFrom(rootObject, m_commandBuffer->m_transientHeap);
+ return SLANG_OK;
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
m_writer->bindRootShaderObject(m_rootObject);
diff --git a/tools/gfx/d3d12/descriptor-heap-d3d12.h b/tools/gfx/d3d12/descriptor-heap-d3d12.h
index 35574dd12..6f82a3f42 100644
--- a/tools/gfx/d3d12/descriptor-heap-d3d12.h
+++ b/tools/gfx/d3d12/descriptor-heap-d3d12.h
@@ -5,9 +5,9 @@
#include <d3d12.h>
#include "slang-com-ptr.h"
-#include "core/slang-smart-pointer.h"
-#include "core/slang-list.h"
#include "core/slang-virtual-object-pool.h"
+#include "core/slang-short-list.h"
+#include "core/slang-basic.h"
namespace gfx {
@@ -273,17 +273,85 @@ public:
}
};
+class D3D12LinearExpandingDescriptorHeap : public Slang::RefObject
+{
+ ID3D12Device* m_device;
+ D3D12_DESCRIPTOR_HEAP_TYPE m_type;
+ D3D12_DESCRIPTOR_HEAP_FLAGS m_flag;
+ int m_chunkSize;
+ Slang::ShortList<D3D12DescriptorHeap, 4> m_subHeaps;
+ int32_t m_subHeapIndex;
+
+public:
+ Slang::Result newSubHeap()
+ {
+ m_subHeapIndex++;
+ if (m_subHeapIndex <= m_subHeaps.getCount())
+ {
+ D3D12DescriptorHeap subHeap;
+ SLANG_RETURN_ON_FAIL(subHeap.init(m_device, m_chunkSize, m_type, m_flag));
+ m_subHeaps.add(Slang::_Move(subHeap));
+ }
+ return SLANG_OK;
+ }
+
+ Slang::Result init(
+ ID3D12Device* device,
+ int chunkSize,
+ D3D12_DESCRIPTOR_HEAP_TYPE type,
+ D3D12_DESCRIPTOR_HEAP_FLAGS flag)
+ {
+ m_device = device;
+ m_chunkSize = chunkSize;
+ m_type = type;
+ m_flag = flag;
+ m_subHeapIndex = -1;
+ return newSubHeap();
+ }
+
+ int allocate(int count)
+ {
+ auto result = m_subHeaps[m_subHeapIndex].allocate(count);
+ if (result == -1)
+ {
+ newSubHeap();
+ return allocate(count);
+ }
+ assert(result <= 0xFFFFFF);
+ assert(m_subHeapIndex <= 255);
+ return (m_subHeapIndex << 24) + result;
+ }
+
+ SLANG_FORCE_INLINE D3D12_CPU_DESCRIPTOR_HANDLE getCpuHandle(int index) const
+ {
+ auto subHeapIndex = ((uint32_t)(index >> 24) & 0xFF);
+ return m_subHeaps[subHeapIndex].getCpuHandle(index & 0xFFFFFF);
+ }
+
+ void free(int index, int count) { assert(0 && "not supported"); }
+
+ void free(D3D12Descriptor descriptor) { assert(0 && "not supported"); }
+
+ void freeAll()
+ {
+ for (auto& subHeap : m_subHeaps)
+ subHeap.deallocateAll();
+ m_subHeapIndex = 0;
+ }
+};
+
struct DescriptorHeapReference
{
enum class Type
{
- Linear, General, ExpandingGeneral
+ Linear, General, ExpandingGeneral, ExpandingLinear
};
union Ptr
{
D3D12DescriptorHeap* linearHeap;
D3D12GeneralDescriptorHeap* generalHeap;
D3D12GeneralExpandingDescriptorHeap* generalExpandingHeap;
+ D3D12LinearExpandingDescriptorHeap* linearExpandingHeap;
};
Type type;
Ptr ptr;
@@ -303,6 +371,11 @@ struct DescriptorHeapReference
type = Type::ExpandingGeneral;
ptr.generalExpandingHeap = heap;
}
+ DescriptorHeapReference(D3D12LinearExpandingDescriptorHeap* heap)
+ {
+ type = Type::ExpandingLinear;
+ ptr.linearExpandingHeap = heap;
+ }
D3D12_CPU_DESCRIPTOR_HANDLE getCpuHandle(int index) const
{
switch (type)
@@ -311,8 +384,12 @@ struct DescriptorHeapReference
return ptr.linearHeap->getCpuHandle(index);
case Type::General:
return ptr.generalHeap->getCpuHandle(index);
- default:
+ case Type::ExpandingGeneral:
return ptr.generalExpandingHeap->getCpuHandle(index);
+ case Type::ExpandingLinear:
+ return ptr.linearExpandingHeap->getCpuHandle(index);
+ default:
+ return D3D12_CPU_DESCRIPTOR_HANDLE();
}
}
D3D12_GPU_DESCRIPTOR_HANDLE getGpuHandle(int index) const
@@ -323,8 +400,10 @@ struct DescriptorHeapReference
return ptr.linearHeap->getGpuHandle(index);
case Type::General:
return ptr.generalHeap->getGpuHandle(index);
- default:
+ case Type::ExpandingGeneral:
return ptr.generalExpandingHeap->getGpuHandle(index);
+ default:
+ return D3D12_GPU_DESCRIPTOR_HANDLE();
}
}
int allocate(int numDescriptors)
@@ -335,20 +414,23 @@ struct DescriptorHeapReference
return ptr.linearHeap->allocate(numDescriptors);
case Type::General:
return ptr.generalHeap->allocate(numDescriptors);
- default:
+ case Type::ExpandingGeneral:
return ptr.generalExpandingHeap->allocate(numDescriptors);
+ default:
+ return ptr.linearExpandingHeap->allocate(numDescriptors);
}
}
void free(int index, int count)
{
switch (type)
{
+ default:
case Type::Linear:
SLANG_ASSERT(!"Linear heap does not support free().");
break;
case Type::General:
return ptr.generalHeap->free(index, count);
- default:
+ case Type::ExpandingGeneral:
return ptr.generalExpandingHeap->free(index, count);
}
}
@@ -360,8 +442,10 @@ struct DescriptorHeapReference
return;
case Type::General:
return ptr.generalHeap->free(index, count);
- default:
+ case Type::ExpandingGeneral:
return ptr.generalExpandingHeap->free(index, count);
+ default:
+ break;
}
}
};
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 5500b024a..29335ed1d 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -61,7 +61,11 @@ struct ID3D12GraphicsCommandList1 {};
#endif
//
+#ifdef _DEBUG
#define ENABLE_DEBUG_LAYER 1
+#else
+#define ENABLE_DEBUG_LAYER 0
+#endif
namespace gfx {
@@ -824,6 +828,9 @@ public:
D3D12DescriptorHeap& getCurrentViewHeap() { return m_viewHeaps[m_currentViewHeapIndex]; }
D3D12DescriptorHeap& getCurrentSamplerHeap() { return m_samplerHeaps[m_currentSamplerHeapIndex]; }
+ D3D12LinearExpandingDescriptorHeap m_stagingCpuViewHeap;
+ D3D12LinearExpandingDescriptorHeap m_stagingCpuSamplerHeap;
+
~TransientResourceHeapImpl()
{
synchronizeAndReset();
@@ -844,6 +851,17 @@ public:
m_viewHeapSize = viewHeapSize;
m_samplerHeapSize = samplerHeapSize;
+ m_stagingCpuViewHeap.init(
+ device->m_device,
+ 1000000,
+ D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
+ D3D12_DESCRIPTOR_HEAP_FLAG_NONE);
+ m_stagingCpuSamplerHeap.init(
+ device->m_device,
+ 1000000,
+ D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
+ D3D12_DESCRIPTOR_HEAP_FLAG_NONE);
+
auto d3dDevice = device->m_device;
SLANG_RETURN_ON_FAIL(d3dDevice->CreateCommandAllocator(
D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(m_commandAllocator.writeRef())));
@@ -1021,6 +1039,7 @@ public:
{
m_currentPipeline = static_cast<PipelineStateBase*>(pipelineState);
auto rootObject = &m_commandBuffer->m_rootShaderObject;
+ m_commandBuffer->m_mutableRootShaderObject = nullptr;
SLANG_RETURN_ON_FAIL(rootObject->reset(
m_renderer,
m_currentPipeline->getProgram<ShaderProgramImpl>()->m_rootObjectLayout,
@@ -1030,6 +1049,14 @@ public:
return SLANG_OK;
}
+ Result bindPipelineWithRootObjectImpl(IPipelineState* pipelineState, IShaderObject* rootObject)
+ {
+ m_currentPipeline = static_cast<PipelineStateBase*>(pipelineState);
+ m_commandBuffer->m_mutableRootShaderObject = static_cast<MutableRootShaderObjectImpl*>(rootObject);
+ m_bindingDirty = true;
+ return SLANG_OK;
+ }
+
/// Specializes the pipeline according to current root-object argument values,
/// applys the root object bindings and binds the pipeline state.
/// The newly specialized pipeline is held alive by the pipeline cache so users of
@@ -1484,6 +1511,22 @@ public:
switch(slangBindingType)
{
default:
+ {
+ // We only treat buffers of interface types as actual sub-object binding range.
+ auto bindingRangeTypeLayout =
+ typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex);
+ if (!bindingRangeTypeLayout)
+ continue;
+ auto elementType =
+ typeLayout->getBindingRangeLeafTypeLayout(bindingRangeIndex)
+ ->getElementTypeLayout();
+ if (!elementType)
+ continue;
+ if (elementType->getKind() != slang::TypeReflection::Kind::Interface)
+ {
+ continue;
+ }
+ }
break;
case slang::BindingType::ConstantBuffer:
@@ -1562,7 +1605,7 @@ public:
}
// Once we've computed the usage for each object in the range, we can
- // easily compute the rusage for the entire range.
+ // easily compute the usage for the entire range.
//
auto rangeResourceCount = count * objectCounts.resource;
auto rangeSamplerCount = count * objectCounts.sampler;
@@ -2429,6 +2472,11 @@ public:
ShaderObjectLayoutImpl,
SimpleShaderObjectData>
{
+ typedef ShaderObjectBaseImpl<
+ ShaderObjectImpl,
+ ShaderObjectLayoutImpl,
+ SimpleShaderObjectData>
+ Super;
public:
static Result create(
D3D12Device* device,
@@ -2495,6 +2543,24 @@ public:
m_isConstantBufferDirty = true;
+ m_version++;
+
+ return SLANG_OK;
+ }
+
+ SLANG_NO_THROW Result SLANG_MCALL
+ setObject(ShaderOffset const& offset, IShaderObject* object) SLANG_OVERRIDE
+ {
+ SLANG_RETURN_ON_FAIL(Super::setObject(offset, object));
+ if (m_isMutable)
+ {
+ auto subObjectIndex = getSubObjectIndex(offset);
+ if (subObjectIndex >= m_subObjectVersions.getCount())
+ m_subObjectVersions.setCount(subObjectIndex + 1);
+ m_subObjectVersions[subObjectIndex] =
+ static_cast<ShaderObjectImpl*>(object)->m_version;
+ m_version++;
+ }
return SLANG_OK;
}
@@ -2519,6 +2585,7 @@ public:
(int32_t)offset.bindingArrayIndex),
samplerImpl->m_descriptor.cpuHandle,
D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
+ m_version++;
return SLANG_OK;
}
@@ -2554,11 +2621,9 @@ public:
samplerImpl->m_descriptor.cpuHandle,
D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER);
#endif
+ m_version++;
return SLANG_OK;
}
-
- public:
-
protected:
Result init(
D3D12Device* device,
@@ -2763,8 +2828,12 @@ public:
bool shouldAllocateConstantBuffer(TransientResourceHeapImpl* transientHeap)
{
- return m_isConstantBufferDirty || m_cachedTransientHeap != transientHeap ||
- m_cachedTransientHeapVersion != transientHeap->getVersion();
+ if (m_isConstantBufferDirty || m_cachedTransientHeap != transientHeap ||
+ m_cachedTransientHeapVersion != transientHeap->getVersion())
+ {
+ return true;
+ }
+ return false;
}
/// Ensure that the `m_ordinaryDataBuffer` has been created, if it is needed
@@ -2838,7 +2907,38 @@ public:
}
public:
-
+ void updateSubObjectsRecursive()
+ {
+ if (!m_isMutable)
+ return;
+ auto& subObjectRanges = getLayout()->getSubObjectRanges();
+ for (Slang::Index subObjectRangeIndex = 0;
+ subObjectRangeIndex < subObjectRanges.getCount();
+ subObjectRangeIndex++)
+ {
+ auto const& subObjectRange = subObjectRanges[subObjectRangeIndex];
+ auto const& bindingRange =
+ getLayout()->getBindingRange(subObjectRange.bindingRangeIndex);
+ Slang::Index count = bindingRange.count;
+
+ for (Slang::Index subObjectIndexInRange = 0; subObjectIndexInRange < count;
+ subObjectIndexInRange++)
+ {
+ Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
+ auto subObject = m_objects[objectIndex].Ptr();
+ if (!subObject)
+ continue;
+ subObject->updateSubObjectsRecursive();
+ if (m_subObjectVersions[objectIndex] != m_objects[objectIndex]->m_version)
+ {
+ ShaderOffset offset;
+ offset.bindingRangeIndex = subObjectRange.bindingRangeIndex;
+ offset.bindingArrayIndex = subObjectIndexInRange;
+ setObject(offset, subObject);
+ }
+ }
+ }
+ }
/// Prepare to bind this object as a parameter block.
///
/// This involves allocating and binding any descriptor tables necessary
@@ -2924,30 +3024,88 @@ public:
return SLANG_OK;
}
+ bool checkIfCachedDescriptorSetIsValidRecursive(BindingContext* context)
+ {
+ if (shouldAllocateConstantBuffer(context->transientHeap))
+ return false;
+ if (m_isMutable && m_version != m_cachedGPUDescriptorSetVersion)
+ return false;
+ if (m_cachedGPUDescriptorSet.resourceTable.getDescriptorCount() != 0 &&
+ m_cachedGPUDescriptorSet.resourceTable.m_heap.ptr.linearHeap->getHeap() !=
+ m_cachedTransientHeap->getCurrentViewHeap().getHeap())
+ return false;
+ if (m_cachedGPUDescriptorSet.samplerTable.getDescriptorCount() != 0 &&
+ m_cachedGPUDescriptorSet.samplerTable.m_heap.ptr.linearHeap->getHeap() !=
+ m_cachedTransientHeap->getCurrentSamplerHeap().getHeap())
+ return false;
+
+ auto& subObjectRanges = getLayout()->getSubObjectRanges();
+ for (Slang::Index subObjectRangeIndex = 0;
+ subObjectRangeIndex < subObjectRanges.getCount();
+ subObjectRangeIndex++)
+ {
+ auto const& subObjectRange = subObjectRanges[subObjectRangeIndex];
+ auto const& bindingRange =
+ getLayout()->getBindingRange(subObjectRange.bindingRangeIndex);
+ if (bindingRange.bindingType != slang::BindingType::ParameterBlock)
+ continue;
+ Slang::Index count = bindingRange.count;
+
+ for (Slang::Index subObjectIndexInRange = 0; subObjectIndexInRange < count;
+ subObjectIndexInRange++)
+ {
+ Slang::Index objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
+ auto subObject = m_objects[objectIndex].Ptr();
+ if (!subObject)
+ continue;
+ if (subObject->checkIfCachedDescriptorSetIsValidRecursive(context))
+ return false;
+ }
+ }
+ return true;
+ }
+
/// Bind this object as a `ParameterBlock<X>`
Result bindAsParameterBlock(
BindingContext* context,
BindingOffset const& offset,
ShaderObjectLayoutImpl* specializedLayout)
{
- if (m_cachedTransientHeap == context->transientHeap &&
- m_cachedTransientHeapVersion == m_cachedTransientHeap->getVersion())
+ if (checkIfCachedDescriptorSetIsValidRecursive(context))
{
-
+ // If we already have a valid gpu descriptor table in the current
+ // heap, bind it.
+ auto rootParamIndex = offset.rootParam;
+ if (m_cachedGPUDescriptorSet.resourceTable.getDescriptorCount())
+ {
+ auto tableRootParamIndex = rootParamIndex++;
+ context->submitter->setRootDescriptorTable(
+ tableRootParamIndex, m_cachedGPUDescriptorSet.resourceTable.getGpuHandle());
+ }
+ if (m_cachedGPUDescriptorSet.samplerTable.getDescriptorCount())
+ {
+ auto tableRootParamIndex = rootParamIndex++;
+ context->submitter->setRootDescriptorTable(
+ tableRootParamIndex, m_cachedGPUDescriptorSet.samplerTable.getGpuHandle());
+ }
+ return SLANG_OK;
}
+
// The first step to binding an object as a parameter block is to allocate a descriptor
// set (consisting of zero or one resource descriptor table and zero or one sampler
// descriptor table) to represent its values.
//
BindingOffset subOffset = offset;
- DescriptorSet descriptorSet;
SLANG_RETURN_ON_FAIL(prepareToBindAsParameterBlock(
- context, /* inout */ subOffset, specializedLayout, descriptorSet));
+ context, /* inout */ subOffset, specializedLayout, m_cachedGPUDescriptorSet));
// Next we bind the object into that descriptor set as if it were being used
// as a `ConstantBuffer<X>`.
//
- SLANG_RETURN_ON_FAIL(bindAsConstantBuffer(context, descriptorSet, subOffset, specializedLayout));
+ SLANG_RETURN_ON_FAIL(bindAsConstantBuffer(
+ context, m_cachedGPUDescriptorSet, subOffset, specializedLayout));
+
+ m_cachedGPUDescriptorSetVersion = m_version;
return SLANG_OK;
}
@@ -3141,7 +3299,10 @@ public:
}
for (auto& subObject : m_objects)
{
- SLANG_RETURN_ON_FAIL(subObject->bindRootArguments(context, index));
+ if (subObject)
+ {
+ SLANG_RETURN_ON_FAIL(subObject->bindRootArguments(context, index));
+ }
}
return SLANG_OK;
}
@@ -3167,6 +3328,15 @@ public:
/// The version of the transient heap when the constant buffer and descriptor set is allocated.
uint64_t m_cachedTransientHeapVersion;
+ /// Whether this shader object is allowed to be mutable.
+ bool m_isMutable = false;
+ /// The version of a mutable shader object.
+ uint32_t m_version = 0;
+ /// The version of this mutable shader object when the gpu descriptor table is cached.
+ uint32_t m_cachedGPUDescriptorSetVersion = -1;
+ /// The versions of bound subobjects.
+ List<uint32_t> m_subObjectVersions;
+
/// Get the layout of this shader object with specialization arguments considered
///
/// This operation should only be called after the shader object has been
@@ -3205,9 +3375,6 @@ public:
RefPtr<ShaderObjectLayoutImpl> m_specializedLayout;
};
- class MutableShaderObjectImpl : public MutableShaderObject<MutableShaderObjectImpl, ShaderObjectLayoutImpl>
- {};
-
class RootShaderObjectImpl : public ShaderObjectImpl
{
typedef ShaderObjectImpl Super;
@@ -3247,13 +3414,9 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL
copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override
{
- SLANG_RETURN_ON_FAIL(Super::copyFrom(object, transientHeap));
- if (auto srcObj = dynamic_cast<MutableRootShaderObject*>(object))
+ if (auto srcObj = dynamic_cast<MutableRootShaderObjectImpl*>(object))
{
- for (Index i = 0; i < srcObj->m_entryPoints.getCount(); i++)
- {
- m_entryPoints[i]->copyFrom(srcObj->m_entryPoints[i], transientHeap);
- }
+ *this = *srcObj;
return SLANG_OK;
}
return SLANG_FAIL;
@@ -3264,6 +3427,9 @@ public:
BindingContext* context,
RootShaderObjectLayoutImpl* specializedLayout)
{
+ // Pull updates from sub-objects when this is a mutable root shader object.
+ updateSubObjectsRecursive();
+
// A root shader object always binds as if it were a parameter block,
// insofar as it needs to allocate a descriptor set to hold the bindings
// for its own state and any sub-objects.
@@ -3292,6 +3458,8 @@ public:
auto entryPointOffset = rootOffset;
entryPointOffset += entryPointInfo.offset;
+ entryPoint->updateSubObjectsRecursive();
+
SLANG_RETURN_ON_FAIL(entryPoint->bindAsConstantBuffer(context, descriptorSet, entryPointOffset, entryPointInfo.layout));
}
@@ -3302,27 +3470,18 @@ public:
Result init(D3D12Device* device)
{
- SLANG_RETURN_ON_FAIL(m_cpuViewHeap.init(
- device->m_device,
- 64,
- D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
- D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
- SLANG_RETURN_ON_FAIL(m_cpuSamplerHeap.init(
- device->m_device,
- 8,
- D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER,
- D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
return SLANG_OK;
}
- Result reset(
+ Result resetImpl(
D3D12Device* device,
RootShaderObjectLayoutImpl* layout,
- TransientResourceHeapImpl* heap)
+ DescriptorHeapReference viewHeap,
+ DescriptorHeapReference samplerHeap,
+ bool isMutable)
{
- m_cpuViewHeap.deallocateAll();
- m_cpuSamplerHeap.deallocateAll();
- SLANG_RETURN_ON_FAIL(Super::init(device, layout, &m_cpuViewHeap, &m_cpuSamplerHeap));
+ SLANG_RETURN_ON_FAIL(Super::init(device, layout, viewHeap, samplerHeap));
+ m_isMutable = isMutable;
m_specializedLayout = nullptr;
m_entryPoints.clear();
for (auto entryPointInfo : layout->getEntryPoints())
@@ -3330,11 +3489,21 @@ public:
RefPtr<ShaderObjectImpl> entryPoint;
SLANG_RETURN_ON_FAIL(
ShaderObjectImpl::create(device, entryPointInfo.layout, entryPoint.writeRef()));
+ entryPoint->m_isMutable = isMutable;
m_entryPoints.add(entryPoint);
}
return SLANG_OK;
}
+ Result reset(
+ D3D12Device* device,
+ RootShaderObjectLayoutImpl* layout,
+ TransientResourceHeapImpl* heap)
+ {
+ return resetImpl(
+ device, layout, &heap->m_stagingCpuViewHeap, &heap->m_stagingCpuSamplerHeap, false);
+ }
+
protected:
Result _createSpecializedLayout(ShaderObjectLayoutImpl** outLayout) SLANG_OVERRIDE
{
@@ -3434,12 +3603,18 @@ public:
}
List<RefPtr<ShaderObjectImpl>> m_entryPoints;
+ };
+ class MutableRootShaderObjectImpl : public RootShaderObjectImpl
+ {
public:
- // Descriptor heaps for the root object. Resets with the life cycle of each root shader
- // object use.
- D3D12DescriptorHeap m_cpuViewHeap;
- D3D12DescriptorHeap m_cpuSamplerHeap;
+ // Override default reference counting behavior to disable lifetime management via ComPtr.
+ // Root objects are managed by command buffer and does not need to be freed by the user.
+ SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return ShaderObjectBase::addRef(); }
+ SLANG_NO_THROW uint32_t SLANG_MCALL release() override
+ {
+ return ShaderObjectBase::release();
+ }
};
class ShaderTableImpl : public ShaderTableBase
@@ -3586,6 +3761,7 @@ public:
// device.
D3D12Device* m_renderer;
RootShaderObjectImpl m_rootShaderObject;
+ RefPtr<MutableRootShaderObjectImpl> m_mutableRootShaderObject;
void bindDescriptorHeaps()
{
@@ -3596,6 +3772,12 @@ public:
m_cmdList->SetDescriptorHeaps(SLANG_COUNT_OF(heaps), heaps);
}
+ void reinit()
+ {
+ bindDescriptorHeaps();
+ m_rootShaderObject.init(m_renderer);
+ }
+
void init(
D3D12Device* renderer,
ID3D12GraphicsCommandList* d3dCommandList,
@@ -3605,11 +3787,15 @@ public:
m_renderer = renderer;
m_cmdList = d3dCommandList;
- bindDescriptorHeaps();
- m_rootShaderObject.init(renderer);
+ reinit();
#if SLANG_GFX_HAS_DXR_SUPPORT
m_cmdList->QueryInterface<ID3D12GraphicsCommandList4>(m_cmdList4.writeRef());
+ if (m_cmdList4)
+ {
+ m_cmdList1 = m_cmdList4;
+ return;
+ }
#endif
m_cmdList->QueryInterface<ID3D12GraphicsCommandList1>(m_cmdList1.writeRef());
}
@@ -4473,6 +4659,12 @@ public:
return bindPipelineImpl(state, outRootObject);
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override
+ {
+ return bindPipelineWithRootObjectImpl(state, rootObject);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL
setViewports(uint32_t count, const Viewport* viewports) override
{
@@ -4799,6 +4991,12 @@ public:
return bindPipelineImpl(state, outRootObject);
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL bindPipelineWithRootObject(
+ IPipelineState* state, IShaderObject* rootObject) override
+ {
+ return bindPipelineWithRootObjectImpl(state, rootObject);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
// Submit binding for compute
@@ -4874,6 +5072,11 @@ public:
DeviceAddress source) override;
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL bindPipelineWithRootObject(
+ IPipelineState* state, IShaderObject* rootObject) override
+ {
+ return bindPipelineWithRootObjectImpl(state, rootObject);
+ }
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
uint32_t rayGenShaderIndex,
IShaderTable* shaderTable,
@@ -5053,7 +5256,6 @@ public:
waitInfo.fence = m_fence;
}
m_d3dQueue->Signal(m_fence, m_fenceValue);
- ResetEvent(globalWaitHandle);
if (fence)
{
@@ -5290,6 +5492,8 @@ SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchr
m_currentSamplerHeapIndex = -1;
allocateNewViewDescriptorHeap(m_device);
allocateNewSamplerDescriptorHeap(m_device);
+ m_stagingCpuSamplerHeap.freeAll();
+ m_stagingCpuViewHeap.freeAll();
m_commandListAllocId = 0;
SLANG_RETURN_ON_FAIL(m_commandAllocator->Reset());
Super::reset();
@@ -5303,7 +5507,7 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe
auto result = static_cast<D3D12Device::CommandBufferImpl*>(
m_commandBufferPool[m_commandListAllocId].Ptr());
m_d3dCommandListPool[m_commandListAllocId]->Reset(m_commandAllocator, nullptr);
- result->init(m_device, m_d3dCommandListPool[m_commandListAllocId], this);
+ result->reinit();
++m_commandListAllocId;
returnComPtr(outCmdBuffer, result);
return SLANG_OK;
@@ -5327,7 +5531,9 @@ Result D3D12Device::TransientResourceHeapImpl::createCommandBuffer(ICommandBuffe
Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitter, RefPtr<PipelineStateBase>& newPipeline)
{
- RootShaderObjectImpl* rootObjectImpl = &m_commandBuffer->m_rootShaderObject;
+ RootShaderObjectImpl* rootObjectImpl = m_commandBuffer->m_mutableRootShaderObject
+ ? m_commandBuffer->m_mutableRootShaderObject.Ptr()
+ : &m_commandBuffer->m_rootShaderObject;
SLANG_RETURN_ON_FAIL(m_renderer->maybeSpecializePipeline(m_currentPipeline, rootObjectImpl, newPipeline));
PipelineStateBase* newPipelineImpl = static_cast<PipelineStateBase*>(newPipeline.Ptr());
auto commandList = m_d3dCmdList;
@@ -5335,10 +5541,7 @@ Result D3D12Device::PipelineCommandEncoder::_bindRenderState(Submitter* submitte
auto programImpl = static_cast<ShaderProgramImpl*>(newPipelineImpl->m_program.Ptr());
submitter->setRootSignature(programImpl->m_rootObjectLayout->m_rootSignature);
submitter->setPipelineState(newPipelineImpl);
- RefPtr<ShaderObjectLayoutImpl> specializedRootLayout;
- SLANG_RETURN_ON_FAIL(rootObjectImpl->getSpecializedLayout(specializedRootLayout.writeRef()));
- RootShaderObjectLayoutImpl* rootLayoutImpl =
- static_cast<RootShaderObjectLayoutImpl*>(specializedRootLayout.Ptr());
+ RootShaderObjectLayoutImpl* rootLayoutImpl = programImpl->m_rootObjectLayout;
// We need to set up a context for binding shader objects to the pipeline state.
// This type mostly exists to bundle together a bunch of parameters that would
@@ -5803,7 +6006,15 @@ Result D3D12Device::_createDevice(DeviceCheckFlags deviceCheckFlags, const Unown
{
infoQueue->SetBreakOnSeverity(D3D12_MESSAGE_SEVERITY_ERROR, true);
}
-
+ D3D12_MESSAGE_ID hideMessages[] = {
+ D3D12_MESSAGE_ID_CLEARRENDERTARGETVIEW_MISMATCHINGCLEARVALUE,
+ D3D12_MESSAGE_ID_CLEARDEPTHSTENCILVIEW_MISMATCHINGCLEARVALUE,
+ };
+ D3D12_INFO_QUEUE_FILTER f = {};
+ f.DenyList.NumIDs = (UINT)SLANG_COUNT_OF(hideMessages);
+ f.DenyList.pIDList = hideMessages;
+ infoQueue->AddStorageFilterEntries(&f);
+
// Apparently there is a problem with sm 6.3 with spurious errors, with debug layer enabled
D3D12_FEATURE_DATA_SHADER_MODEL featureShaderModel;
featureShaderModel.HighestShaderModel = D3D_SHADER_MODEL(0x63);
@@ -7549,19 +7760,18 @@ Result D3D12Device::createMutableShaderObject(
ShaderObjectLayoutBase* layout,
IShaderObject** outObject)
{
- auto layoutImpl = static_cast<ShaderObjectLayoutImpl*>(layout);
-
- RefPtr<MutableShaderObjectImpl> result = new MutableShaderObjectImpl();
- SLANG_RETURN_ON_FAIL(result->init(this, layoutImpl));
- returnComPtr(outObject, result);
-
- return SLANG_OK;
+ auto result = createShaderObject(layout, outObject);
+ SLANG_RETURN_ON_FAIL(result);
+ static_cast<ShaderObjectImpl*>(*outObject)->m_isMutable = true;
+ return result;
}
Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject)
{
- RefPtr<MutableRootShaderObject> result =
- new MutableRootShaderObject(this, static_cast<ShaderProgramBase*>(program));
+ RefPtr<MutableRootShaderObjectImpl> result = new MutableRootShaderObjectImpl();
+ result->init(this);
+ auto programImpl = static_cast<ShaderProgramImpl*>(program);
+ result->resetImpl(this, programImpl->m_rootObjectLayout, m_cpuViewHeap.Ptr(), m_cpuSamplerHeap.Ptr(), true);
returnComPtr(outObject, result);
return SLANG_OK;
}
@@ -8518,6 +8728,8 @@ Result D3D12Device::ShaderObjectImpl::setResource(ShaderOffset const& offset, IR
if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
return SLANG_E_INVALID_ARG;
+ m_version++;
+
ID3D12Device* d3dDevice = static_cast<D3D12Device*>(getDevice())->m_device;
auto& bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 6d735fb8d..8ecdf405a 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -1063,6 +1063,13 @@ Result DebugComputeCommandEncoder::bindPipeline(
return result;
}
+Result DebugComputeCommandEncoder::bindPipelineWithRootObject(
+ IPipelineState* state, IShaderObject* rootObject)
+{
+ SLANG_GFX_API_FUNC;
+ return baseObject->bindPipelineWithRootObject(getInnerObj(state), getInnerObj(rootObject));
+}
+
void DebugComputeCommandEncoder::dispatchCompute(int x, int y, int z)
{
SLANG_GFX_API_FUNC;
@@ -1098,6 +1105,13 @@ Result DebugRenderCommandEncoder::bindPipeline(
return result;
}
+Result DebugRenderCommandEncoder::bindPipelineWithRootObject(
+ IPipelineState* state, IShaderObject* rootObject)
+{
+ SLANG_GFX_API_FUNC;
+ return baseObject->bindPipelineWithRootObject(getInnerObj(state), getInnerObj(rootObject));
+}
+
void DebugRenderCommandEncoder::setViewports(uint32_t count, const Viewport* viewports)
{
SLANG_GFX_API_FUNC;
@@ -1485,6 +1499,13 @@ void DebugRayTracingCommandEncoder::bindPipeline(
*outRootObject = &commandBuffer->rootObject;
}
+Result DebugRayTracingCommandEncoder::bindPipelineWithRootObject(
+ IPipelineState* state, IShaderObject* rootObject)
+{
+ SLANG_GFX_API_FUNC;
+ return baseObject->bindPipelineWithRootObject(getInnerObj(state), getInnerObj(rootObject));
+}
+
void DebugRayTracingCommandEncoder::dispatchRays(
uint32_t rayGenShaderIndex,
IShaderTable* shaderTable,
@@ -1767,12 +1788,6 @@ Result DebugShaderObject::getCurrentVersion(
return SLANG_OK;
}
-Result DebugShaderObject::copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap)
-{
- SLANG_GFX_API_FUNC;
- return baseObject->copyFrom(getInnerObj(other), getInnerObj(transientHeap));
-}
-
const void* DebugShaderObject::getRawData()
{
SLANG_GFX_API_FUNC;
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index e0f994a36..a8cdc1b4f 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -305,8 +305,6 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion(
ITransientResourceHeap* transientHeap, IShaderObject** outObject) override;
- virtual SLANG_NO_THROW Result SLANG_MCALL
- copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) override;
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override;
virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() override;
virtual SLANG_NO_THROW Result SLANG_MCALL
@@ -437,6 +435,8 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override;
virtual SLANG_NO_THROW Result SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootShaderObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override;
virtual SLANG_NO_THROW void SLANG_MCALL
dispatchComputeIndirect(IBufferResource* cmdBuffer, uint64_t offset) override;
@@ -482,6 +482,8 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL endEncoding() override;
virtual SLANG_NO_THROW Result SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootShaderObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL
setViewports(uint32_t count, const Viewport* viewports) override;
virtual SLANG_NO_THROW void SLANG_MCALL
@@ -566,6 +568,8 @@ public:
DeviceAddress source) override;
virtual SLANG_NO_THROW void SLANG_MCALL
bindPipeline(IPipelineState* state, IShaderObject** outRootObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override;
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
uint32_t rayGenShaderIndex,
IShaderTable* shaderTable,
diff --git a/tools/gfx/immediate-renderer-base.cpp b/tools/gfx/immediate-renderer-base.cpp
index c5e5032e5..e18727bdf 100644
--- a/tools/gfx/immediate-renderer-base.cpp
+++ b/tools/gfx/immediate-renderer-base.cpp
@@ -38,10 +38,12 @@ public:
bool m_hasWriteTimestamps = false;
RefPtr<ImmediateRendererBase> m_renderer;
RefPtr<ShaderObjectBase> m_rootShaderObject;
+ TransientResourceHeapBase* m_transientHeap;
- void init(ImmediateRendererBase* renderer)
+ void init(ImmediateRendererBase* renderer, TransientResourceHeapBase* transientHeap)
{
m_renderer = renderer;
+ m_transientHeap = transientHeap;
}
void reset()
@@ -289,6 +291,17 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override
+ {
+ m_writer->setPipelineState(state);
+ auto stateImpl = static_cast<PipelineStateBase*>(state);
+ SLANG_RETURN_ON_FAIL(m_commandBuffer->m_renderer->createRootShaderObject(
+ stateImpl->m_program, m_commandBuffer->m_rootShaderObject.writeRef()));
+ m_commandBuffer->m_rootShaderObject->copyFrom(rootObject, m_commandBuffer->m_transientHeap);
+ return SLANG_OK;
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL
setViewports(uint32_t count, const Viewport* viewports) override
{
@@ -435,6 +448,18 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* state, IShaderObject* rootObject) override
+ {
+ m_writer->setPipelineState(state);
+ auto stateImpl = static_cast<PipelineStateBase*>(state);
+ SLANG_RETURN_ON_FAIL(m_commandBuffer->m_renderer->createRootShaderObject(
+ stateImpl->m_program, m_commandBuffer->m_rootShaderObject.writeRef()));
+ m_commandBuffer->m_rootShaderObject->copyFrom(
+ rootObject, m_commandBuffer->m_transientHeap);
+ return SLANG_OK;
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
m_writer->bindRootShaderObject(m_commandBuffer->m_rootShaderObject);
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 8fc74158c..8136d6735 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -546,7 +546,7 @@ public:
}
virtual SLANG_NO_THROW Result SLANG_MCALL
- copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override;
+ copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap);
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override
{
@@ -596,6 +596,13 @@ public:
void setSpecializationArgsForContainerElement(ExtendedShaderObjectTypeList& specializationArgs);
+ Slang::Index getSubObjectIndex(ShaderOffset offset)
+ {
+ auto layout = getLayout();
+ auto bindingRange = layout->getBindingRange(offset.bindingRangeIndex);
+ return bindingRange.subObjectIndex + offset.bindingArrayIndex;
+ }
+
virtual SLANG_NO_THROW Result SLANG_MCALL
setObject(ShaderOffset const& offset, IShaderObject* object) SLANG_OVERRIDE
{
diff --git a/tools/gfx/simple-transient-resource-heap.h b/tools/gfx/simple-transient-resource-heap.h
index 5706c6b1d..c94fe2a7b 100644
--- a/tools/gfx/simple-transient-resource-heap.h
+++ b/tools/gfx/simple-transient-resource-heap.h
@@ -33,7 +33,7 @@ public:
createCommandBuffer(ICommandBuffer** outCommandBuffer) override
{
Slang::RefPtr<TCommandBuffer> newCmdBuffer = new TCommandBuffer();
- newCmdBuffer->init(m_device);
+ newCmdBuffer->init(m_device, this);
returnComPtr(outCommandBuffer, newCmdBuffer);
return SLANG_OK;
}
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 0a66c653b..ebf25513b 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -2530,6 +2530,14 @@ public:
return SLANG_OK;
}
+ Result setPipelineStateWithRootObjectImpl(IPipelineState* state, IShaderObject* inObject)
+ {
+ IShaderObject* rootObject = nullptr;
+ SLANG_RETURN_ON_FAIL(setPipelineStateImpl(state, &rootObject));
+ static_cast<ShaderObjectBase*>(rootObject)->copyFrom(inObject, m_commandBuffer->m_transientHeap);
+ return SLANG_OK;
+ }
+
void flushBindingState(VkPipelineBindPoint pipelineBindPoint)
{
auto& api = *m_api;
@@ -5098,6 +5106,12 @@ public:
return setPipelineStateImpl(pipelineState, outRootObject);
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ bindPipelineWithRootObject(IPipelineState* pipelineState, IShaderObject* rootObject) override
+ {
+ return setPipelineStateWithRootObjectImpl(pipelineState, rootObject);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL
setViewports(uint32_t count, const Viewport* viewports) override
{
@@ -5380,6 +5394,12 @@ public:
return setPipelineStateImpl(pipelineState, outRootObject);
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL bindPipelineWithRootObject(
+ IPipelineState* pipelineState, IShaderObject* rootObject) override
+ {
+ return setPipelineStateWithRootObjectImpl(pipelineState, rootObject);
+ }
+
virtual SLANG_NO_THROW void SLANG_MCALL dispatchCompute(int x, int y, int z) override
{
auto pipeline = static_cast<PipelineStateImpl*>(m_currentPipeline.Ptr());
@@ -5653,6 +5673,13 @@ public:
setPipelineStateImpl(pipeline, outRootObject);
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL bindPipelineWithRootObject(
+ IPipelineState* pipelineState, IShaderObject* rootObject) override
+ {
+ return setPipelineStateWithRootObjectImpl(pipelineState, rootObject);
+ }
+
+ // TODO: Implement after implementing createRayTracingPipelineState
virtual SLANG_NO_THROW void SLANG_MCALL dispatchRays(
uint32_t raygenShaderIndex,
IShaderTable* shaderTable,