summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang-gfx.h2
-rw-r--r--tools/gfx-unit-test/root-mutable-shader-object.cpp120
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp29
-rw-r--r--tools/gfx/debug-layer.cpp4
-rw-r--r--tools/gfx/debug-layer.h3
-rw-r--r--tools/gfx/mutable-shader-object.h163
-rw-r--r--tools/gfx/renderer-shared.cpp28
-rw-r--r--tools/gfx/renderer-shared.h11
-rw-r--r--tools/gfx/vulkan/render-vk.cpp31
9 files changed, 376 insertions, 15 deletions
diff --git a/slang-gfx.h b/slang-gfx.h
index 971c2fffc..98e8fe9c5 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -1018,7 +1018,7 @@ public:
IShaderObject** outObject) = 0;
/// Copies contents from another shader object to this object.
- virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other) = 0;
+ virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) = 0;
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() = 0;
diff --git a/tools/gfx-unit-test/root-mutable-shader-object.cpp b/tools/gfx-unit-test/root-mutable-shader-object.cpp
new file mode 100644
index 000000000..8cd2abbd6
--- /dev/null
+++ b/tools/gfx-unit-test/root-mutable-shader-object.cpp
@@ -0,0 +1,120 @@
+#include "tools/unit-test/slang-unit-test.h"
+
+#include "slang-gfx.h"
+#include "gfx-test-util.h"
+#include "tools/gfx-util/shader-cursor.h"
+#include "source/core/slang-basic.h"
+
+using namespace gfx;
+
+namespace gfx_test
+{
+ void mutableRootShaderObjectTestImpl(IDevice* device, UnitTestContext* context)
+ {
+ Slang::ComPtr<ITransientResourceHeap> transientHeap;
+ ITransientResourceHeap::Desc transientHeapDesc = {};
+ transientHeapDesc.constantBufferSize = 4096;
+ GFX_CHECK_CALL_ABORT(
+ device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef()));
+
+ ComPtr<IShaderProgram> shaderProgram;
+ slang::ProgramLayout* slangReflection;
+ GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "mutable-shader-object", "computeMain", slangReflection));
+
+ ComputePipelineStateDesc pipelineDesc = {};
+ pipelineDesc.program = shaderProgram.get();
+ ComPtr<gfx::IPipelineState> pipelineState;
+ GFX_CHECK_CALL_ABORT(
+ device->createComputePipelineState(pipelineDesc, pipelineState.writeRef()));
+
+ float initialData[] = { 0.0f, 1.0f, 2.0f, 3.0f };
+ const int numberCount = SLANG_COUNT_OF(initialData);
+ IBufferResource::Desc bufferDesc = {};
+ bufferDesc.sizeInBytes = sizeof(initialData);
+ bufferDesc.format = gfx::Format::Unknown;
+ bufferDesc.elementSize = sizeof(float);
+ bufferDesc.allowedStates = ResourceStateSet(
+ ResourceState::ShaderResource,
+ ResourceState::UnorderedAccess,
+ ResourceState::CopyDestination,
+ ResourceState::CopySource);
+ bufferDesc.defaultState = ResourceState::UnorderedAccess;
+ bufferDesc.cpuAccessFlags = AccessFlag::Write | AccessFlag::Read;
+
+ ComPtr<IBufferResource> numbersBuffer;
+ GFX_CHECK_CALL_ABORT(device->createBufferResource(
+ bufferDesc,
+ (void*)initialData,
+ numbersBuffer.writeRef()));
+
+ ComPtr<IResourceView> bufferView;
+ IResourceView::Desc viewDesc = {};
+ viewDesc.type = IResourceView::Type::UnorderedAccess;
+ viewDesc.format = Format::Unknown;
+ GFX_CHECK_CALL_ABORT(device->createBufferView(numbersBuffer, viewDesc, bufferView.writeRef()));
+
+ ComPtr<IShaderObject> rootObject;
+ device->createMutableRootShaderObject(shaderProgram, rootObject.writeRef());
+ auto entryPointCursor = ShaderCursor(rootObject->getEntryPoint(0));
+ entryPointCursor.getPath("buffer").setResource(bufferView);
+
+ slang::TypeReflection* addTransformerType =
+ slangReflection->findTypeByName("AddTransformer");
+ ComPtr<IShaderObject> transformer;
+ GFX_CHECK_CALL_ABORT(device->createMutableShaderObject(
+ addTransformerType, ShaderObjectContainerType::None, transformer.writeRef()));
+ entryPointCursor.getPath("transformer").setObject(transformer);
+
+ // Set the `c` field of the `AddTransformer`.
+ float c = 1.0f;
+ ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float));
+
+ {
+ ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics };
+ auto queue = device->createCommandQueue(queueDesc);
+
+ auto commandBuffer = transientHeap->createCommandBuffer();
+ {
+ auto encoder = commandBuffer->encodeComputeCommands();
+ auto root = encoder->bindPipeline(pipelineState);
+ root->copyFrom(rootObject, transientHeap);
+ encoder->dispatchCompute(1, 1, 1);
+ encoder->endEncoding();
+ }
+
+ auto barrierEncoder = commandBuffer->encodeResourceCommands();
+ barrierEncoder->bufferBarrier(1, numbersBuffer.readRef(), ResourceState::UnorderedAccess, ResourceState::UnorderedAccess);
+ barrierEncoder->endEncoding();
+
+ // Mutate `transformer` object and run again.
+ c = 2.0f;
+ ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float));
+ {
+ auto encoder = commandBuffer->encodeComputeCommands();
+ auto root = encoder->bindPipeline(pipelineState);
+ root->copyFrom(rootObject, transientHeap);
+ encoder->dispatchCompute(1, 1, 1);
+ encoder->endEncoding();
+ }
+
+ commandBuffer->close();
+ queue->executeCommandBuffer(commandBuffer);
+ queue->wait();
+ }
+
+ compareComputeResult(
+ device,
+ numbersBuffer,
+ Slang::makeArray<float>(3.0f, 4.0f, 5.0f, 6.0f));
+ }
+
+ SLANG_UNIT_TEST(mutableRootShaderObjectD3D12)
+ {
+ runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
+ }
+
+ SLANG_UNIT_TEST(mutableRootShaderObjectVulkan)
+ {
+ runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);
+ }
+}
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 4d5845b85..5a0478f68 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -131,6 +131,8 @@ public:
override;
virtual Result createMutableShaderObject(
ShaderObjectLayoutBase* layout, IShaderObject** outObject) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override;
@@ -2819,6 +2821,21 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override
+ {
+ SLANG_RETURN_ON_FAIL(Super::copyFrom(object, transientHeap));
+ if (auto srcObj = dynamic_cast<MutableRootShaderObject*>(object))
+ {
+ for (Index i = 0; i < srcObj->m_entryPoints.getCount(); i++)
+ {
+ m_entryPoints[i]->copyFrom(srcObj->m_entryPoints[i], transientHeap);
+ }
+ return SLANG_OK;
+ }
+ return SLANG_FAIL;
+ }
+
public:
Result bindAsRoot(
BindingContext* context,
@@ -4935,8 +4952,8 @@ Result D3D12Device::getTextureAllocationInfo(
D3D12_RESOURCE_DESC resourceDesc = {};
setupResourceDesc(resourceDesc, srcDesc);
auto allocInfo = m_device->GetResourceAllocationInfo(0xFF, 1, &resourceDesc);
- *outSize = allocInfo.SizeInBytes;
- *outAlignment = allocInfo.Alignment;
+ *outSize = (size_t)allocInfo.SizeInBytes;
+ *outAlignment = (size_t)allocInfo.Alignment;
return SLANG_OK;
}
@@ -5843,6 +5860,14 @@ Result D3D12Device::createMutableShaderObject(
return SLANG_OK;
}
+Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject)
+{
+ RefPtr<MutableRootShaderObject> result =
+ new MutableRootShaderObject(this, static_cast<ShaderProgramBase*>(program));
+ returnComPtr(outObject, result);
+ return SLANG_OK;
+}
+
Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
{
GraphicsPipelineStateDesc desc = inDesc;
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 6266a6175..831566c7c 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -1686,10 +1686,10 @@ Result DebugShaderObject::getCurrentVersion(
return SLANG_OK;
}
-Result DebugShaderObject::copyFrom(IShaderObject* other)
+Result DebugShaderObject::copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap)
{
SLANG_GFX_API_FUNC;
- return baseObject->copyFrom(getInnerObj(other));
+ return baseObject->copyFrom(getInnerObj(other), getInnerObj(transientHeap));
}
const void* DebugShaderObject::getRawData()
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index 28ad8c4b5..f61917835 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -280,7 +280,8 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion(
ITransientResourceHeap* transientHeap, IShaderObject** outObject) override;
- virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) override;
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override;
virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() override;
virtual SLANG_NO_THROW Result SLANG_MCALL
diff --git a/tools/gfx/mutable-shader-object.h b/tools/gfx/mutable-shader-object.h
index 3c551023a..182334351 100644
--- a/tools/gfx/mutable-shader-object.h
+++ b/tools/gfx/mutable-shader-object.h
@@ -220,4 +220,167 @@ namespace gfx
return m_shaderObjectVersions.getLastAllocation().object;
}
};
+
+ // A proxy shader object to hold mutable shader parameters for global scope and entry-points.
+ class MutableRootShaderObject : public ShaderObjectBase
+ {
+ public:
+ Slang::List<uint8_t> m_data;
+ Slang::OrderedDictionary<ShaderOffset, Slang::RefPtr<ResourceViewBase>> m_resources;
+ Slang::OrderedDictionary<ShaderOffset, Slang::RefPtr<SamplerStateBase>> m_samplers;
+ Slang::OrderedDictionary<ShaderOffset, Slang::RefPtr<ShaderObjectBase>> m_objects;
+ Slang::OrderedDictionary<ShaderOffset, Slang::List<slang::SpecializationArg>> m_specializationArgs;
+ Slang::List<Slang::RefPtr<MutableRootShaderObject>> m_entryPoints;
+ Slang::RefPtr<BufferResource> m_constantBufferOverride;
+ slang::TypeLayoutReflection* m_elementTypeLayout;
+
+ MutableRootShaderObject(RendererBase* device, slang::TypeLayoutReflection* entryPointLayout)
+ {
+ this->m_device = device;
+ m_elementTypeLayout = entryPointLayout;
+ m_data.setCount(entryPointLayout->getSize());
+ memset(m_data.begin(), 0, m_data.getCount());
+ }
+
+ MutableRootShaderObject(RendererBase* device, Slang::RefPtr<ShaderProgramBase> program)
+ {
+ this->m_device = device;
+ auto programLayout = program->slangProgram->getLayout();
+ SlangInt entryPointCount = programLayout->getEntryPointCount();
+ for (SlangInt e = 0; e < entryPointCount; ++e)
+ {
+ auto slangEntryPoint = programLayout->getEntryPointByIndex(e);
+ Slang::RefPtr<MutableRootShaderObject> entryPointObject =
+ new MutableRootShaderObject(device, slangEntryPoint->getTypeLayout()->getElementTypeLayout());
+
+ m_entryPoints.add(entryPointObject);
+ }
+ m_data.setCount(programLayout->getGlobalParamsTypeLayout()->getSize());
+ memset(m_data.begin(), 0, m_data.getCount());
+ m_elementTypeLayout = programLayout->getGlobalParamsTypeLayout();
+ }
+
+
+ virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL
+ getElementTypeLayout() override
+ {
+ return m_elementTypeLayout;
+ }
+
+ virtual SLANG_NO_THROW ShaderObjectContainerType SLANG_MCALL getContainerType() override
+ {
+ return ShaderObjectContainerType::None;
+ }
+
+ virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override
+ {
+ return (UInt)m_entryPoints.getCount();
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ getEntryPoint(UInt index, IShaderObject** entryPoint) override
+ {
+ returnComPtr(entryPoint, m_entryPoints[index]);
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setData(ShaderOffset const& offset, void const* data, size_t size) override
+ {
+ auto newSize = Slang::Index(size + offset.uniformOffset);
+ if (newSize > m_data.getCount())
+ m_data.setCount((Slang::Index)newSize);
+ memcpy(m_data.begin() + offset.uniformOffset, data, size);
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ getObject(ShaderOffset const& offset, IShaderObject** object) override
+ {
+ *object = nullptr;
+
+ Slang::RefPtr<ShaderObjectBase> subObject;
+ if (m_objects.TryGetValue(offset, subObject))
+ {
+ returnComPtr(object, subObject);
+ }
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setObject(ShaderOffset const& offset, IShaderObject* object) override
+ {
+ m_objects[offset] = static_cast<ShaderObjectBase*>(object);
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setResource(ShaderOffset const& offset, IResourceView* resourceView) override
+ {
+ m_resources[offset] = static_cast<ResourceViewBase*>(resourceView);
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setSampler(ShaderOffset const& offset, ISamplerState* sampler) override
+ {
+ m_samplers[offset] = static_cast<SamplerStateBase*>(sampler);
+ return SLANG_OK;
+ }
+ virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler(
+ ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) override
+ {
+ m_resources[offset] = static_cast<ResourceViewBase*>(textureView);
+ m_samplers[offset] = static_cast<SamplerStateBase*>(sampler);
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs(
+ ShaderOffset const& offset,
+ const slang::SpecializationArg* args,
+ uint32_t count) override
+ {
+ Slang::List<slang::SpecializationArg> specArgs;
+ specArgs.addRange(args, count);
+ m_specializationArgs[offset] = specArgs;
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion(
+ ITransientResourceHeap* transientHeap, IShaderObject** outObject) override
+ {
+ return SLANG_FAIL;
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) override
+ {
+ auto otherObject = static_cast<MutableRootShaderObject*>(other);
+ *this = *otherObject;
+ return SLANG_OK;
+ }
+
+ virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override
+ {
+ return m_data.begin();
+ }
+
+ virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() override
+ {
+ return (size_t)m_data.getCount();
+ }
+
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setConstantBufferOverride(IBufferResource* constantBuffer) override
+ {
+ m_constantBufferOverride = static_cast<BufferResource*>(constantBuffer);
+ return SLANG_OK;
+ }
+
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
+ SLANG_UNUSED(args);
+ return SLANG_OK;
+ }
+ };
+
}
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index 5b2dc949b..c212f21eb 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -799,4 +799,32 @@ IDebugCallback* _getNullDebugCallback()
return &result;
}
+Result ShaderObjectBase::copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap)
+{
+ if (auto srcObj = dynamic_cast<MutableRootShaderObject*>(object))
+ {
+ setData(gfx::ShaderOffset(), srcObj->m_data.begin(), (size_t)srcObj->m_data.getCount());
+ for (auto& kv : srcObj->m_objects)
+ {
+ ComPtr<IShaderObject> subObject;
+ SLANG_RETURN_ON_FAIL(kv.Value->getCurrentVersion(transientHeap, subObject.writeRef()));
+ setObject(kv.Key, subObject);
+ }
+ for (auto& kv : srcObj->m_resources)
+ {
+ setResource(kv.Key, kv.Value.Ptr());
+ }
+ for (auto& kv : srcObj->m_samplers)
+ {
+ setSampler(kv.Key, kv.Value.Ptr());
+ }
+ for (auto& kv : srcObj->m_specializationArgs)
+ {
+ setSpecializationArgs(kv.Key, kv.Value.begin(), (uint32_t)kv.Value.getCount());
+ }
+ return SLANG_OK;
+ }
+ return SLANG_FAIL;
+}
+
} // namespace gfx
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index abbeb285f..1d5d1860e 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -533,15 +533,12 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion(
ITransientResourceHeap* transientHeap, IShaderObject** outObject) override
{
- SLANG_UNUSED(outObject);
- return SLANG_E_NOT_AVAILABLE;
+ returnComPtr(outObject, this);
+ return SLANG_OK;
}
- virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other) override
- {
- SLANG_UNUSED(other);
- return SLANG_E_NOT_AVAILABLE;
- }
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override;
virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override
{
diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp
index 5c3e75110..1fb06b8e7 100644
--- a/tools/gfx/vulkan/render-vk.cpp
+++ b/tools/gfx/vulkan/render-vk.cpp
@@ -106,6 +106,9 @@ public:
override;
virtual Result createMutableShaderObject(ShaderObjectLayoutBase* layout, IShaderObject** outObject)
override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ createMutableRootShaderObject(
+ IShaderProgram* program, IShaderObject** outObject) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override;
@@ -3449,6 +3452,21 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override
+ {
+ SLANG_RETURN_ON_FAIL(Super::copyFrom(object, transientHeap));
+ if (auto srcObj = dynamic_cast<MutableRootShaderObject*>(object))
+ {
+ for (Index i = 0; i < srcObj->m_entryPoints.getCount(); i++)
+ {
+ m_entryPoints[i]->copyFrom(srcObj->m_entryPoints[i], transientHeap);
+ }
+ return SLANG_OK;
+ }
+ return SLANG_FAIL;
+ }
+
/// Bind this object as a root shader object
Result bindAsRoot(
PipelineCommandEncoder* encoder,
@@ -6744,8 +6762,8 @@ Result VKDevice::getTextureAllocationInfo(
VkMemoryRequirements memRequirements;
m_api.vkGetImageMemoryRequirements(m_device, image, &memRequirements);
- *outSize = memRequirements.size;
- *outAlignment = memRequirements.alignment;
+ *outSize = (size_t)memRequirements.size;
+ *outAlignment = (size_t)memRequirements.alignment;
m_api.vkDestroyImage(m_device, image, nullptr);
return SLANG_OK;
@@ -7555,6 +7573,15 @@ Result VKDevice::createMutableShaderObject(
return SLANG_OK;
}
+Result VKDevice::createMutableRootShaderObject(
+ IShaderProgram* program, IShaderObject** outObject)
+{
+ RefPtr<MutableRootShaderObject> result =
+ new MutableRootShaderObject(this, static_cast<ShaderProgramBase*>(program));
+ returnComPtr(outObject, result);
+ return SLANG_OK;
+}
+
Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState)
{
GraphicsPipelineStateDesc desc = inDesc;