From 5cbd61774c6ef2209fa0afc79b1dbbb68514346b Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 6 Dec 2021 09:06:16 -0800 Subject: gfx Mutable Root shader object implementation. (#2042) * gfx Mutable Root shader object implementation. * Fix x86 build. Co-authored-by: Yong He --- tools/gfx-unit-test/root-mutable-shader-object.cpp | 120 +++++++++++++++ tools/gfx/d3d12/render-d3d12.cpp | 29 +++- tools/gfx/debug-layer.cpp | 4 +- tools/gfx/debug-layer.h | 3 +- tools/gfx/mutable-shader-object.h | 163 +++++++++++++++++++++ tools/gfx/renderer-shared.cpp | 28 ++++ tools/gfx/renderer-shared.h | 11 +- tools/gfx/vulkan/render-vk.cpp | 31 +++- 8 files changed, 375 insertions(+), 14 deletions(-) create mode 100644 tools/gfx-unit-test/root-mutable-shader-object.cpp (limited to 'tools') diff --git a/tools/gfx-unit-test/root-mutable-shader-object.cpp b/tools/gfx-unit-test/root-mutable-shader-object.cpp new file mode 100644 index 000000000..8cd2abbd6 --- /dev/null +++ b/tools/gfx-unit-test/root-mutable-shader-object.cpp @@ -0,0 +1,120 @@ +#include "tools/unit-test/slang-unit-test.h" + +#include "slang-gfx.h" +#include "gfx-test-util.h" +#include "tools/gfx-util/shader-cursor.h" +#include "source/core/slang-basic.h" + +using namespace gfx; + +namespace gfx_test +{ + void mutableRootShaderObjectTestImpl(IDevice* device, UnitTestContext* context) + { + Slang::ComPtr transientHeap; + ITransientResourceHeap::Desc transientHeapDesc = {}; + transientHeapDesc.constantBufferSize = 4096; + GFX_CHECK_CALL_ABORT( + device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef())); + + ComPtr shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "mutable-shader-object", "computeMain", slangReflection)); + + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr pipelineState; + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + float initialData[] = { 0.0f, 1.0f, 2.0f, 3.0f }; + const int numberCount = SLANG_COUNT_OF(initialData); + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = sizeof(initialData); + bufferDesc.format = gfx::Format::Unknown; + bufferDesc.elementSize = sizeof(float); + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::ShaderResource, + ResourceState::UnorderedAccess, + ResourceState::CopyDestination, + ResourceState::CopySource); + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.cpuAccessFlags = AccessFlag::Write | AccessFlag::Read; + + ComPtr numbersBuffer; + GFX_CHECK_CALL_ABORT(device->createBufferResource( + bufferDesc, + (void*)initialData, + numbersBuffer.writeRef())); + + ComPtr bufferView; + IResourceView::Desc viewDesc = {}; + viewDesc.type = IResourceView::Type::UnorderedAccess; + viewDesc.format = Format::Unknown; + GFX_CHECK_CALL_ABORT(device->createBufferView(numbersBuffer, viewDesc, bufferView.writeRef())); + + ComPtr rootObject; + device->createMutableRootShaderObject(shaderProgram, rootObject.writeRef()); + auto entryPointCursor = ShaderCursor(rootObject->getEntryPoint(0)); + entryPointCursor.getPath("buffer").setResource(bufferView); + + slang::TypeReflection* addTransformerType = + slangReflection->findTypeByName("AddTransformer"); + ComPtr transformer; + GFX_CHECK_CALL_ABORT(device->createMutableShaderObject( + addTransformerType, ShaderObjectContainerType::None, transformer.writeRef())); + entryPointCursor.getPath("transformer").setObject(transformer); + + // Set the `c` field of the `AddTransformer`. + float c = 1.0f; + ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float)); + + { + ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics }; + auto queue = device->createCommandQueue(queueDesc); + + auto commandBuffer = transientHeap->createCommandBuffer(); + { + auto encoder = commandBuffer->encodeComputeCommands(); + auto root = encoder->bindPipeline(pipelineState); + root->copyFrom(rootObject, transientHeap); + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + } + + auto barrierEncoder = commandBuffer->encodeResourceCommands(); + barrierEncoder->bufferBarrier(1, numbersBuffer.readRef(), ResourceState::UnorderedAccess, ResourceState::UnorderedAccess); + barrierEncoder->endEncoding(); + + // Mutate `transformer` object and run again. + c = 2.0f; + ShaderCursor(transformer).getPath("c").setData(&c, sizeof(float)); + { + auto encoder = commandBuffer->encodeComputeCommands(); + auto root = encoder->bindPipeline(pipelineState); + root->copyFrom(rootObject, transientHeap); + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + } + + commandBuffer->close(); + queue->executeCommandBuffer(commandBuffer); + queue->wait(); + } + + compareComputeResult( + device, + numbersBuffer, + Slang::makeArray(3.0f, 4.0f, 5.0f, 6.0f)); + } + + SLANG_UNIT_TEST(mutableRootShaderObjectD3D12) + { + runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + } + + SLANG_UNIT_TEST(mutableRootShaderObjectVulkan) + { + runTestImpl(mutableRootShaderObjectTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + } +} diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index 4d5845b85..5a0478f68 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -131,6 +131,8 @@ public: override; virtual Result createMutableShaderObject( ShaderObjectLayoutBase* layout, IShaderObject** outObject) override; + virtual SLANG_NO_THROW Result SLANG_MCALL + createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) override; virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override; @@ -2819,6 +2821,21 @@ public: return SLANG_OK; } + virtual SLANG_NO_THROW Result SLANG_MCALL + copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override + { + SLANG_RETURN_ON_FAIL(Super::copyFrom(object, transientHeap)); + if (auto srcObj = dynamic_cast(object)) + { + for (Index i = 0; i < srcObj->m_entryPoints.getCount(); i++) + { + m_entryPoints[i]->copyFrom(srcObj->m_entryPoints[i], transientHeap); + } + return SLANG_OK; + } + return SLANG_FAIL; + } + public: Result bindAsRoot( BindingContext* context, @@ -4935,8 +4952,8 @@ Result D3D12Device::getTextureAllocationInfo( D3D12_RESOURCE_DESC resourceDesc = {}; setupResourceDesc(resourceDesc, srcDesc); auto allocInfo = m_device->GetResourceAllocationInfo(0xFF, 1, &resourceDesc); - *outSize = allocInfo.SizeInBytes; - *outAlignment = allocInfo.Alignment; + *outSize = (size_t)allocInfo.SizeInBytes; + *outAlignment = (size_t)allocInfo.Alignment; return SLANG_OK; } @@ -5843,6 +5860,14 @@ Result D3D12Device::createMutableShaderObject( return SLANG_OK; } +Result D3D12Device::createMutableRootShaderObject(IShaderProgram* program, IShaderObject** outObject) +{ + RefPtr result = + new MutableRootShaderObject(this, static_cast(program)); + returnComPtr(outObject, result); + return SLANG_OK; +} + Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState) { GraphicsPipelineStateDesc desc = inDesc; diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp index 6266a6175..831566c7c 100644 --- a/tools/gfx/debug-layer.cpp +++ b/tools/gfx/debug-layer.cpp @@ -1686,10 +1686,10 @@ Result DebugShaderObject::getCurrentVersion( return SLANG_OK; } -Result DebugShaderObject::copyFrom(IShaderObject* other) +Result DebugShaderObject::copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) { SLANG_GFX_API_FUNC; - return baseObject->copyFrom(getInnerObj(other)); + return baseObject->copyFrom(getInnerObj(other), getInnerObj(transientHeap)); } const void* DebugShaderObject::getRawData() diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h index 28ad8c4b5..f61917835 100644 --- a/tools/gfx/debug-layer.h +++ b/tools/gfx/debug-layer.h @@ -280,7 +280,8 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion( ITransientResourceHeap* transientHeap, IShaderObject** outObject) override; - virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other) override; + virtual SLANG_NO_THROW Result SLANG_MCALL + copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) override; virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override; virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() override; virtual SLANG_NO_THROW Result SLANG_MCALL diff --git a/tools/gfx/mutable-shader-object.h b/tools/gfx/mutable-shader-object.h index 3c551023a..182334351 100644 --- a/tools/gfx/mutable-shader-object.h +++ b/tools/gfx/mutable-shader-object.h @@ -220,4 +220,167 @@ namespace gfx return m_shaderObjectVersions.getLastAllocation().object; } }; + + // A proxy shader object to hold mutable shader parameters for global scope and entry-points. + class MutableRootShaderObject : public ShaderObjectBase + { + public: + Slang::List m_data; + Slang::OrderedDictionary> m_resources; + Slang::OrderedDictionary> m_samplers; + Slang::OrderedDictionary> m_objects; + Slang::OrderedDictionary> m_specializationArgs; + Slang::List> m_entryPoints; + Slang::RefPtr m_constantBufferOverride; + slang::TypeLayoutReflection* m_elementTypeLayout; + + MutableRootShaderObject(RendererBase* device, slang::TypeLayoutReflection* entryPointLayout) + { + this->m_device = device; + m_elementTypeLayout = entryPointLayout; + m_data.setCount(entryPointLayout->getSize()); + memset(m_data.begin(), 0, m_data.getCount()); + } + + MutableRootShaderObject(RendererBase* device, Slang::RefPtr program) + { + this->m_device = device; + auto programLayout = program->slangProgram->getLayout(); + SlangInt entryPointCount = programLayout->getEntryPointCount(); + for (SlangInt e = 0; e < entryPointCount; ++e) + { + auto slangEntryPoint = programLayout->getEntryPointByIndex(e); + Slang::RefPtr entryPointObject = + new MutableRootShaderObject(device, slangEntryPoint->getTypeLayout()->getElementTypeLayout()); + + m_entryPoints.add(entryPointObject); + } + m_data.setCount(programLayout->getGlobalParamsTypeLayout()->getSize()); + memset(m_data.begin(), 0, m_data.getCount()); + m_elementTypeLayout = programLayout->getGlobalParamsTypeLayout(); + } + + + virtual SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL + getElementTypeLayout() override + { + return m_elementTypeLayout; + } + + virtual SLANG_NO_THROW ShaderObjectContainerType SLANG_MCALL getContainerType() override + { + return ShaderObjectContainerType::None; + } + + virtual SLANG_NO_THROW UInt SLANG_MCALL getEntryPointCount() override + { + return (UInt)m_entryPoints.getCount(); + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + getEntryPoint(UInt index, IShaderObject** entryPoint) override + { + returnComPtr(entryPoint, m_entryPoints[index]); + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + setData(ShaderOffset const& offset, void const* data, size_t size) override + { + auto newSize = Slang::Index(size + offset.uniformOffset); + if (newSize > m_data.getCount()) + m_data.setCount((Slang::Index)newSize); + memcpy(m_data.begin() + offset.uniformOffset, data, size); + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + getObject(ShaderOffset const& offset, IShaderObject** object) override + { + *object = nullptr; + + Slang::RefPtr subObject; + if (m_objects.TryGetValue(offset, subObject)) + { + returnComPtr(object, subObject); + } + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + setObject(ShaderOffset const& offset, IShaderObject* object) override + { + m_objects[offset] = static_cast(object); + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + setResource(ShaderOffset const& offset, IResourceView* resourceView) override + { + m_resources[offset] = static_cast(resourceView); + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + setSampler(ShaderOffset const& offset, ISamplerState* sampler) override + { + m_samplers[offset] = static_cast(sampler); + return SLANG_OK; + } + virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler( + ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) override + { + m_resources[offset] = static_cast(textureView); + m_samplers[offset] = static_cast(sampler); + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs( + ShaderOffset const& offset, + const slang::SpecializationArg* args, + uint32_t count) override + { + Slang::List specArgs; + specArgs.addRange(args, count); + m_specializationArgs[offset] = specArgs; + return SLANG_OK; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion( + ITransientResourceHeap* transientHeap, IShaderObject** outObject) override + { + return SLANG_FAIL; + } + + virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other, ITransientResourceHeap* transientHeap) override + { + auto otherObject = static_cast(other); + *this = *otherObject; + return SLANG_OK; + } + + virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override + { + return m_data.begin(); + } + + virtual SLANG_NO_THROW size_t SLANG_MCALL getSize() override + { + return (size_t)m_data.getCount(); + } + + virtual SLANG_NO_THROW Result SLANG_MCALL + setConstantBufferOverride(IBufferResource* constantBuffer) override + { + m_constantBufferOverride = static_cast(constantBuffer); + return SLANG_OK; + } + + virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override + { + SLANG_UNUSED(args); + return SLANG_OK; + } + }; + } diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp index 5b2dc949b..c212f21eb 100644 --- a/tools/gfx/renderer-shared.cpp +++ b/tools/gfx/renderer-shared.cpp @@ -799,4 +799,32 @@ IDebugCallback* _getNullDebugCallback() return &result; } +Result ShaderObjectBase::copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) +{ + if (auto srcObj = dynamic_cast(object)) + { + setData(gfx::ShaderOffset(), srcObj->m_data.begin(), (size_t)srcObj->m_data.getCount()); + for (auto& kv : srcObj->m_objects) + { + ComPtr subObject; + SLANG_RETURN_ON_FAIL(kv.Value->getCurrentVersion(transientHeap, subObject.writeRef())); + setObject(kv.Key, subObject); + } + for (auto& kv : srcObj->m_resources) + { + setResource(kv.Key, kv.Value.Ptr()); + } + for (auto& kv : srcObj->m_samplers) + { + setSampler(kv.Key, kv.Value.Ptr()); + } + for (auto& kv : srcObj->m_specializationArgs) + { + setSpecializationArgs(kv.Key, kv.Value.begin(), (uint32_t)kv.Value.getCount()); + } + return SLANG_OK; + } + return SLANG_FAIL; +} + } // namespace gfx diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index abbeb285f..1d5d1860e 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -533,15 +533,12 @@ public: virtual SLANG_NO_THROW Result SLANG_MCALL getCurrentVersion( ITransientResourceHeap* transientHeap, IShaderObject** outObject) override { - SLANG_UNUSED(outObject); - return SLANG_E_NOT_AVAILABLE; + returnComPtr(outObject, this); + return SLANG_OK; } - virtual SLANG_NO_THROW Result SLANG_MCALL copyFrom(IShaderObject* other) override - { - SLANG_UNUSED(other); - return SLANG_E_NOT_AVAILABLE; - } + virtual SLANG_NO_THROW Result SLANG_MCALL + copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override; virtual SLANG_NO_THROW const void* SLANG_MCALL getRawData() override { diff --git a/tools/gfx/vulkan/render-vk.cpp b/tools/gfx/vulkan/render-vk.cpp index 5c3e75110..1fb06b8e7 100644 --- a/tools/gfx/vulkan/render-vk.cpp +++ b/tools/gfx/vulkan/render-vk.cpp @@ -106,6 +106,9 @@ public: override; virtual Result createMutableShaderObject(ShaderObjectLayoutBase* layout, IShaderObject** outObject) override; + virtual SLANG_NO_THROW Result SLANG_MCALL + createMutableRootShaderObject( + IShaderProgram* program, IShaderObject** outObject) override; virtual SLANG_NO_THROW Result SLANG_MCALL createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram) override; @@ -3449,6 +3452,21 @@ public: return SLANG_OK; } + virtual SLANG_NO_THROW Result SLANG_MCALL + copyFrom(IShaderObject* object, ITransientResourceHeap* transientHeap) override + { + SLANG_RETURN_ON_FAIL(Super::copyFrom(object, transientHeap)); + if (auto srcObj = dynamic_cast(object)) + { + for (Index i = 0; i < srcObj->m_entryPoints.getCount(); i++) + { + m_entryPoints[i]->copyFrom(srcObj->m_entryPoints[i], transientHeap); + } + return SLANG_OK; + } + return SLANG_FAIL; + } + /// Bind this object as a root shader object Result bindAsRoot( PipelineCommandEncoder* encoder, @@ -6744,8 +6762,8 @@ Result VKDevice::getTextureAllocationInfo( VkMemoryRequirements memRequirements; m_api.vkGetImageMemoryRequirements(m_device, image, &memRequirements); - *outSize = memRequirements.size; - *outAlignment = memRequirements.alignment; + *outSize = (size_t)memRequirements.size; + *outAlignment = (size_t)memRequirements.alignment; m_api.vkDestroyImage(m_device, image, nullptr); return SLANG_OK; @@ -7555,6 +7573,15 @@ Result VKDevice::createMutableShaderObject( return SLANG_OK; } +Result VKDevice::createMutableRootShaderObject( + IShaderProgram* program, IShaderObject** outObject) +{ + RefPtr result = + new MutableRootShaderObject(this, static_cast(program)); + returnComPtr(outObject, result); + return SLANG_OK; +} + Result VKDevice::createGraphicsPipelineState(const GraphicsPipelineStateDesc& inDesc, IPipelineState** outState) { GraphicsPipelineStateDesc desc = inDesc; -- cgit v1.2.3