From 3192f34f57abd3245995342a0a5971ebbbbd945c Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 15 Apr 2024 23:28:28 -0700 Subject: [GFX] Fix d3d12 buffer view creation logic for StructuredBuffers. (#3954) --- .../gfx-unit-test-tool/gfx-unit-test-tool.vcxproj | 2 + .../gfx-unit-test-tool.vcxproj.filters | 6 + slang.h | 3 +- source/core/slang-array.h | 4 +- source/slang/slang-check-conformance.cpp | 10 - source/slang/slang-check-decl.cpp | 10 + source/slang/slang-check-impl.h | 3 +- source/slang/slang-parameter-binding.cpp | 36 ++-- source/slang/slang-reflection-api.cpp | 32 +++- source/slang/slang-type-layout.cpp | 13 +- source/slang/slang-type-layout.h | 8 +- tools/gfx-unit-test/uint16-buffer.slang | 14 ++ tools/gfx-unit-test/uint16-structured-buffer.cpp | 96 ++++++++++ tools/gfx/d3d12/d3d12-device.cpp | 130 +------------ tools/gfx/d3d12/d3d12-resource-views.cpp | 203 +++++++++++++++++++++ tools/gfx/d3d12/d3d12-resource-views.h | 23 +++ tools/gfx/d3d12/d3d12-shader-object-layout.cpp | 16 ++ tools/gfx/d3d12/d3d12-shader-object-layout.h | 3 + tools/gfx/d3d12/d3d12-shader-object.cpp | 15 +- 19 files changed, 464 insertions(+), 163 deletions(-) create mode 100644 tools/gfx-unit-test/uint16-buffer.slang create mode 100644 tools/gfx-unit-test/uint16-structured-buffer.cpp diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj index 955c702d9..8d204f9b3 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj @@ -323,6 +323,7 @@ + @@ -350,6 +351,7 @@ + diff --git a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters index 6d48a99ae..26e7a0a58 100644 --- a/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters +++ b/build/visual-studio/gfx-unit-test-tool/gfx-unit-test-tool.vcxproj.filters @@ -125,6 +125,9 @@ Source Files + + Source Files + Source Files @@ -202,5 +205,8 @@ Source Files + + Source Files + \ No newline at end of file diff --git a/slang.h b/slang.h index 4ed37d88c..77e9d3bd9 100644 --- a/slang.h +++ b/slang.h @@ -2390,6 +2390,7 @@ extern "C" SLANG_API size_t spReflectionTypeLayout_GetStride(SlangReflectionTypeLayout* type, SlangParameterCategory category); SLANG_API int32_t spReflectionTypeLayout_getAlignment(SlangReflectionTypeLayout* type, SlangParameterCategory category); + SLANG_API uint32_t spReflectionTypeLayout_GetFieldCount(SlangReflectionTypeLayout* type); SLANG_API SlangReflectionVariableLayout* spReflectionTypeLayout_GetFieldByIndex(SlangReflectionTypeLayout* type, unsigned index); SLANG_API SlangInt spReflectionTypeLayout_findFieldIndexByName(SlangReflectionTypeLayout* typeLayout, const char* nameBegin, const char* nameEnd); @@ -2884,7 +2885,7 @@ namespace slang unsigned int getFieldCount() { - return getType()->getFieldCount(); + return spReflectionTypeLayout_GetFieldCount((SlangReflectionTypeLayout*)this); } VariableLayoutReflection* getFieldByIndex(unsigned int index) diff --git a/source/core/slang-array.h b/source/core/slang-array.h index f1ed4fe0d..d24ff0970 100644 --- a/source/core/slang-array.h +++ b/source/core/slang-array.h @@ -159,14 +159,14 @@ namespace Slang void insertArray(Array& arr, const T& val, TArgs... args) { arr.add(val); - insertArray(arr, args...); + insertArray(arr, args...); } template auto makeArray(TArgs ...args) -> Array::Type, sizeof...(args)> { Array::Type, Index(sizeof...(args))> rs; - insertArray(rs, args...); + insertArray::Type>(rs, args...); return rs; } diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index ff4a40031..a0e51b180 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -227,16 +227,6 @@ namespace Slang return nullptr; } - bool SemanticsVisitor::isInterfaceType(Type* type) - { - if (auto declRefType = as(type)) - { - if (auto interfaceDeclRef = declRefType->getDeclRef().as()) - return true; - } - return false; - } - bool SemanticsVisitor::isValidGenericConstraintType(Type* type) { if (auto andType = as(type)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5190f8c0e..4f8dd3dc5 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1329,6 +1329,16 @@ namespace Slang return arrayType->isUnsized(); } + bool isInterfaceType(Type* type) + { + if (auto declRefType = as(type)) + { + if (auto interfaceDeclRef = declRefType->getDeclRef().as()) + return true; + } + return false; + } + EnumDecl* isEnumType(Type* type) { if (auto declRefType = as(type)) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e6e980fe8..ff7ce6978 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2040,7 +2040,6 @@ namespace Slang SubtypeWitness* checkAndConstructSubtypeWitness(Type* subType, Type* superType); - bool isInterfaceType(Type* type); bool isValidGenericConstraintType(Type* type); bool isTypeDifferentiable(Type* type); @@ -2763,6 +2762,8 @@ namespace Slang bool isUnsizedArrayType(Type* type); + bool isInterfaceType(Type* type); + EnumDecl* isEnumType(Type* type); DeclVisibility getDeclVisibility(Decl* decl); diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 6725380f2..516b9eb44 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -935,13 +935,13 @@ static void addExplicitParameterBinding( if (overlappedVarLayout) { //legal if atomicUint - if(parameterInfo->varLayout->varDecl.getDecl()->getType()->astNodeType == ASTNodeType::GLSLAtomicUintType - && overlappedVarLayout->varDecl.getDecl()->getType()->astNodeType == ASTNodeType::GLSLAtomicUintType) + if(parameterInfo->varLayout->getVariable()->getType()->astNodeType == ASTNodeType::GLSLAtomicUintType + && overlappedVarLayout->getVariable()->getType()->astNodeType == ASTNodeType::GLSLAtomicUintType) { return; } - auto paramA = parameterInfo->varLayout->varDecl.getDecl(); - auto paramB = overlappedVarLayout->varDecl.getDecl(); + auto paramA = parameterInfo->varLayout->getVariable(); + auto paramB = overlappedVarLayout->getVariable(); auto& diagnosticInfo = Diagnostics::parameterBindingsOverlap; @@ -1024,7 +1024,8 @@ static void addExplicitParameterBindings_HLSL( // TODO: warning here! } - addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); + if (auto varDeclBase = varDecl.as()) + addExplicitParameterBinding(context, parameterInfo, varDeclBase.getDecl(), semanticInfo, count); } } @@ -1133,10 +1134,12 @@ static void addExplicitParameterBindings_GLSL( auto count = resInfo->count; semanticInfo.kind = kind; - addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); - if (foundSubpass) - addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), subpassSemanticInfo, count); - + if (auto varDeclBase = varDecl.as()) + { + addExplicitParameterBinding(context, parameterInfo, varDeclBase.getDecl(), semanticInfo, count); + if (foundSubpass) + addExplicitParameterBinding(context, parameterInfo, varDeclBase.getDecl(), subpassSemanticInfo, count); + } return; } @@ -1147,7 +1150,7 @@ static void addExplicitParameterBindings_GLSL( // If we have the options, but cannot infer bindings, we don't need to go further if (hlslToVulkanLayoutOptions == nullptr || !hlslToVulkanLayoutOptions->canInferBindings()) { - _maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); + _maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl.as()); return; } @@ -1169,7 +1172,7 @@ static void addExplicitParameterBindings_GLSL( // We can't infer TextureSampler from HLSL (it's not an HLSL concept) // So use default layout - auto varType = varDecl.getDecl()->getType(); + auto varType = getType(context->getASTBuilder(), varDecl.as()); if (auto textureType = as(varType)) { if (textureType->isCombined()) @@ -1187,7 +1190,7 @@ static void addExplicitParameterBindings_GLSL( // If inference is not enabled for this kind, we can issue a warning if (!hlslToVulkanLayoutOptions->canInfer(vulkanKind, hlslInfo.space)) { - _maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl); + _maybeDiagnoseMissingVulkanLayoutModifier(context, varDecl.as()); return; } @@ -1201,7 +1204,7 @@ static void addExplicitParameterBindings_GLSL( const LayoutSize count = resInfo->count; - addExplicitParameterBinding(context, parameterInfo, varDecl.getDecl(), semanticInfo, count); + addExplicitParameterBinding(context, parameterInfo, as(varDecl.getDecl()), semanticInfo, count); } // Given a single parameter, collect whatever information we have on @@ -2408,7 +2411,10 @@ struct ScopeLayoutBuilder { auto rules = m_layoutContext.rules; m_pendingDataTypeLayoutBuilder.beginLayoutIfNeeded(nullptr, rules); - auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(varLayout->varDecl, fieldPendingDataTypeLayout); + auto varDeclBase = varLayout->varDecl.as(); + if (!varDeclBase) + return; + auto fieldPendingDataVarLayout = m_pendingDataTypeLayoutBuilder.addField(varDeclBase, fieldPendingDataTypeLayout); m_structLayout->pendingDataTypeLayout = m_pendingDataTypeLayoutBuilder.getTypeLayout(); @@ -4011,7 +4017,7 @@ RefPtr generateParameterBindings( if( varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform) ) { needDefaultConstantBuffer = true; - diagnoseGlobalUniform(&sharedContext, varLayout->varDecl.getDecl()); + diagnoseGlobalUniform(&sharedContext, as(varLayout->varDecl.getDecl())); } } } diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 90103d8d9..ab73ce7f4 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -55,12 +55,12 @@ static inline SpecializationParamLayout* convert(SlangReflectionTypeParameter * return (SpecializationParamLayout*) typeParam; } -static inline VarDeclBase* convert(SlangReflectionVariable* var) +static inline Decl* convert(SlangReflectionVariable* var) { - return (VarDeclBase*) var; + return (Decl*) var; } -static inline SlangReflectionVariable* convert(VarDeclBase* var) +static inline SlangReflectionVariable* convert(Decl* var) { return (SlangReflectionVariable*) var; } @@ -932,7 +932,7 @@ SLANG_API SlangInt spReflectionTypeLayout_findFieldIndexByName(SlangReflectionTy for(Index f = 0; f < fieldCount; ++f) { auto field = structTypeLayout->fields[f]; - if(getReflectionName(field->varDecl.getDecl())->text.getUnownedSlice() == name) + if(getReflectionName(field->getVariable())->text.getUnownedSlice() == name) return f; } } @@ -1059,6 +1059,18 @@ SLANG_API SlangParameterCategory spReflectionTypeLayout_GetParameterCategory(Sla return getParameterCategory(typeLayout); } +SLANG_API uint32_t spReflectionTypeLayout_GetFieldCount(SlangReflectionTypeLayout* inTypeLayout) +{ + auto typeLayout = convert(inTypeLayout); + if (!typeLayout) return 0; + + if (auto structTypeLayout = as(typeLayout)) + { + return (uint32_t)structTypeLayout->fields.getCount(); + } + return 0; +} + SLANG_API unsigned spReflectionTypeLayout_GetCategoryCount(SlangReflectionTypeLayout* inTypeLayout) { auto typeLayout = convert(inTypeLayout); @@ -2508,6 +2520,9 @@ SLANG_API SlangInt spReflectionTypeLayout_getSubObjectRangeDescriptorRangeSpaceO SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar) { auto var = convert(inVar); + if (as(var)) + return "$base"; + if(!var) return nullptr; // If the variable is one that has an "external" name that is supposed @@ -2521,14 +2536,19 @@ SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVa SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* inVar) { auto var = convert(inVar); + + if (auto inheritanceDecl = as(var)) + return convert(inheritanceDecl->base.type); + if(!var) return nullptr; - return convert(var->getType()); + return convert(as(var)->getType()); } SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* inVar, SlangModifierID modifierID) { auto var = convert(inVar); + if(!var) return nullptr; Modifier* modifier = nullptr; @@ -2571,7 +2591,7 @@ SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangR auto varLayout = convert(inVarLayout); if(!varLayout) return nullptr; - return convert(varLayout->varDecl.getDecl()); + return (SlangReflectionVariable*)(varLayout->varDecl.getDecl()); } SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout) diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index e791efadb..30c41d0fe 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -3,6 +3,7 @@ #include "slang-syntax.h" #include "slang-ir-insts.h" +#include "slang-check-impl.h" #include "../compiler-core/slang-artifact-desc-util.h" @@ -3331,7 +3332,7 @@ void StructTypeLayoutBuilder::beginLayoutIfNeeded( } RefPtr StructTypeLayoutBuilder::addField( - DeclRef field, + DeclRef field, TypeLayoutResult fieldResult) { SLANG_ASSERT(m_typeLayout); @@ -4082,6 +4083,16 @@ static TypeLayoutResult _createTypeLayout( _addLayout(context, type, typeLayout); + // Add all base fields first. + for (auto inheritanceDeclRef : getMembersOfType(context.astBuilder, structDeclRef)) + { + auto baseType = getSup(context.astBuilder, inheritanceDeclRef); + if (isInterfaceType(baseType)) + continue; + auto baseTypeLayout = _createTypeLayout(context, baseType); + typeLayoutBuilder.addField(inheritanceDeclRef, baseTypeLayout); + } + // First, add all fields with explicit offsets. for (auto field : getFields(context.astBuilder, structDeclRef, MemberFilterStyle::Instance)) { diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index eb0884287..a44efb085 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -514,8 +514,8 @@ class VarLayout : public Layout { public: // The variable we are laying out - DeclRef varDecl; - VarDeclBase* getVariable() { return varDecl.getDecl(); } + DeclRef varDecl; + VarDeclBase* getVariable() { return varDecl.as().getDecl(); } Name* getName() { return getVariable()->getName(); } @@ -721,7 +721,7 @@ public: // TODO: This should map from a declaration to the *index* // in the array above, rather than to the actual pointer, // so that we - Dictionary> mapVarToLayout; + Dictionary> mapVarToLayout; }; class GenericParamTypeLayout : public TypeLayout @@ -1207,7 +1207,7 @@ public: /// One of the `beginLayout*()` functions must have been called previously. /// RefPtr addField( - DeclRef field, + DeclRef field, TypeLayoutResult fieldResult); RefPtr addExplicitUniformField( diff --git a/tools/gfx-unit-test/uint16-buffer.slang b/tools/gfx-unit-test/uint16-buffer.slang new file mode 100644 index 000000000..bdc09c7d8 --- /dev/null +++ b/tools/gfx-unit-test/uint16-buffer.slang @@ -0,0 +1,14 @@ +// uint16-buffer.slang - Simple shader that takes a buffer of uint16 type and increments all elements by 1. + +// This is to verify that GFX can correct set correct buffer strides for structured buffer bindings. + +uniform RWStructuredBuffer buffer; + +[shader("compute")] +[numthreads(4,1,1)] +void computeMain( + uint3 sv_dispatchThreadID : SV_DispatchThreadID) +{ + var input = buffer[sv_dispatchThreadID.x]; + buffer[sv_dispatchThreadID.x] = input + 1; +} diff --git a/tools/gfx-unit-test/uint16-structured-buffer.cpp b/tools/gfx-unit-test/uint16-structured-buffer.cpp new file mode 100644 index 000000000..8f2f2cb97 --- /dev/null +++ b/tools/gfx-unit-test/uint16-structured-buffer.cpp @@ -0,0 +1,96 @@ +#include "tools/unit-test/slang-unit-test.h" + +#include "slang-gfx.h" +#include "gfx-test-util.h" +#include "tools/gfx-util/shader-cursor.h" +#include "source/core/slang-basic.h" + +using namespace gfx; + +namespace gfx_test +{ + void uint16BufferTestImpl(IDevice* device, UnitTestContext* context) + { + Slang::ComPtr transientHeap; + ITransientResourceHeap::Desc transientHeapDesc = {}; + transientHeapDesc.constantBufferSize = 4096; + GFX_CHECK_CALL_ABORT( + device->createTransientResourceHeap(transientHeapDesc, transientHeap.writeRef())); + + ComPtr shaderProgram; + slang::ProgramLayout* slangReflection; + GFX_CHECK_CALL_ABORT(loadComputeProgram(device, shaderProgram, "uint16-buffer", "computeMain", slangReflection)); + + ComputePipelineStateDesc pipelineDesc = {}; + pipelineDesc.program = shaderProgram.get(); + ComPtr pipelineState; + GFX_CHECK_CALL_ABORT( + device->createComputePipelineState(pipelineDesc, pipelineState.writeRef())); + + const int numberCount = 4; + uint16_t initialData[] = { 0, 1, 2, 3 }; + IBufferResource::Desc bufferDesc = {}; + bufferDesc.sizeInBytes = numberCount * sizeof(uint16_t); + bufferDesc.format = gfx::Format::Unknown; + // Note: we don't specify any element size here, and gfx should be able to derive the + // correct element size from the reflection infomation. + bufferDesc.elementSize = 0; + bufferDesc.allowedStates = ResourceStateSet( + ResourceState::ShaderResource, + ResourceState::UnorderedAccess, + ResourceState::CopyDestination, + ResourceState::CopySource); + bufferDesc.defaultState = ResourceState::UnorderedAccess; + bufferDesc.memoryType = MemoryType::DeviceLocal; + + ComPtr numbersBuffer; + GFX_CHECK_CALL_ABORT(device->createBufferResource( + bufferDesc, + (void*)initialData, + numbersBuffer.writeRef())); + + ComPtr bufferView; + IResourceView::Desc viewDesc = {}; + viewDesc.type = IResourceView::Type::UnorderedAccess; + viewDesc.format = Format::Unknown; + GFX_CHECK_CALL_ABORT( + device->createBufferView(numbersBuffer, nullptr, viewDesc, bufferView.writeRef())); + + // We have done all the set up work, now it is time to start recording a command buffer for + // GPU execution. + { + ICommandQueue::Desc queueDesc = { ICommandQueue::QueueType::Graphics }; + auto queue = device->createCommandQueue(queueDesc); + + auto commandBuffer = transientHeap->createCommandBuffer(); + auto encoder = commandBuffer->encodeComputeCommands(); + + auto rootObject = encoder->bindPipeline(pipelineState); + + // Bind buffer view to the entry point. + ShaderCursor(rootObject).getPath("buffer").setResource(bufferView); + + encoder->dispatchCompute(1, 1, 1); + encoder->endEncoding(); + commandBuffer->close(); + queue->executeCommandBuffer(commandBuffer); + queue->waitOnHost(); + } + + compareComputeResult( + device, + numbersBuffer, + Slang::makeArray(1, 2, 3, 4)); + } + + SLANG_UNIT_TEST(uint16BufferTestD3D12) + { + runTestImpl(uint16BufferTestImpl, unitTestContext, Slang::RenderApiFlag::D3D12); + } + + SLANG_UNIT_TEST(uint16BufferTestVulkan) + { + runTestImpl(uint16BufferTestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan); + } + +} diff --git a/tools/gfx/d3d12/d3d12-device.cpp b/tools/gfx/d3d12/d3d12-device.cpp index 33375497b..03f9997f4 100644 --- a/tools/gfx/d3d12/d3d12-device.cpp +++ b/tools/gfx/d3d12/d3d12-device.cpp @@ -1674,126 +1674,16 @@ Result DeviceImpl::createBufferView( viewImpl->m_counterResource = counterResourceImpl; viewImpl->m_desc = desc; - switch (desc.type) - { - default: - return SLANG_FAIL; - - case IResourceView::Type::UnorderedAccess: - { - D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {}; - uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; - uavDesc.Format = D3DUtil::getMapFormat(desc.format); - uavDesc.Buffer.FirstElement = desc.bufferRange.firstElement; - uint64_t viewSize = 0; - if (desc.bufferElementSize) - { - uavDesc.Buffer.StructureByteStride = (UINT)desc.bufferElementSize; - uavDesc.Buffer.NumElements = - desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) - : (UINT)desc.bufferRange.elementCount; - viewSize = (uint64_t)desc.bufferElementSize * uavDesc.Buffer.NumElements; - } - else if (desc.format == Format::Unknown) - { - uavDesc.Format = DXGI_FORMAT_R32_TYPELESS; - uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / 4) - : UINT(desc.bufferRange.elementCount / 4); - uavDesc.Buffer.Flags |= D3D12_BUFFER_UAV_FLAG_RAW; - viewSize = 4ull * uavDesc.Buffer.NumElements; - } - else - { - FormatInfo sizeInfo; - gfxGetFormatInfo(desc.format, &sizeInfo); - assert(sizeInfo.pixelsPerBlock == 1); - uavDesc.Buffer.NumElements = - desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) - : (UINT)desc.bufferRange.elementCount; - viewSize = (uint64_t)uavDesc.Buffer.NumElements * sizeInfo.blockSizeInBytes; - } - - if (viewSize >= (1ull << 32) - 8) - { - // D3D12 does not support view descriptors that has size near 4GB. - // We will not create actual SRV/UAVs for such large buffers. - // However, a buffer this large can still be bound as root parameter. - // So instead of failing, we quietly ignore descriptor creation. - viewImpl->m_descriptor.cpuHandle.ptr = 0; - } - else - { - SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = m_cpuViewHeap; - m_device->CreateUnorderedAccessView( - resourceImpl->m_resource, - counterResourceImpl ? counterResourceImpl->m_resource.getResource() : nullptr, - &uavDesc, - viewImpl->m_descriptor.cpuHandle); - } - } - break; - - case IResourceView::Type::ShaderResource: - { - D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; - srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; - srvDesc.Format = D3DUtil::getMapFormat(desc.format); - srvDesc.Buffer.StructureByteStride = 0; - srvDesc.Buffer.FirstElement = desc.bufferRange.firstElement; - srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; - uint64_t viewSize = 0; - if (desc.bufferElementSize) - { - srvDesc.Buffer.StructureByteStride = (UINT)desc.bufferElementSize; - srvDesc.Buffer.NumElements = - desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) - : (UINT)desc.bufferRange.elementCount; - viewSize = (uint64_t)desc.bufferElementSize * srvDesc.Buffer.NumElements; - } - else if (desc.format == Format::Unknown) - { - srvDesc.Format = DXGI_FORMAT_R32_TYPELESS; - srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / 4) - : UINT(desc.bufferRange.elementCount / 4); - srvDesc.Buffer.Flags |= D3D12_BUFFER_SRV_FLAG_RAW; - viewSize = 4ull * srvDesc.Buffer.NumElements; - } - else - { - FormatInfo sizeInfo; - gfxGetFormatInfo(desc.format, &sizeInfo); - assert(sizeInfo.pixelsPerBlock == 1); - srvDesc.Buffer.NumElements = - desc.bufferRange.elementCount == 0 - ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) - : (UINT)desc.bufferRange.elementCount; - viewSize = (uint64_t)srvDesc.Buffer.NumElements * sizeInfo.blockSizeInBytes; - } - if (viewSize >= (1ull << 32) - 8) - { - // D3D12 does not support view descriptors that has size near 4GB. - // We will not create actual SRV/UAVs for such large buffers. - // However, a buffer this large can still be bound as root parameter. - // So instead of failing, we quietly ignore descriptor creation. - viewImpl->m_descriptor.cpuHandle.ptr = 0; - } - else - { - SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); - viewImpl->m_allocator = m_cpuViewHeap; - m_device->CreateShaderResourceView( - resourceImpl->m_resource, &srvDesc, viewImpl->m_descriptor.cpuHandle); - } - } - break; - } - + SLANG_RETURN_ON_FAIL(createD3D12BufferDescriptor( + resourceImpl, + counterResourceImpl, + desc, + this, + m_cpuViewHeap.get(), + &viewImpl->m_descriptor)); + if (viewImpl->m_descriptor.cpuHandle.ptr != 0) + viewImpl->m_allocator = m_cpuViewHeap.get(); + returnComPtr(outView, viewImpl); return SLANG_OK; } diff --git a/tools/gfx/d3d12/d3d12-resource-views.cpp b/tools/gfx/d3d12/d3d12-resource-views.cpp index 5a044dd3b..b0b441f87 100644 --- a/tools/gfx/d3d12/d3d12-resource-views.cpp +++ b/tools/gfx/d3d12/d3d12-resource-views.cpp @@ -1,5 +1,6 @@ // d3d12-resource-views.cpp #include "d3d12-resource-views.h" +#include "d3d12-device.h" namespace gfx { @@ -12,6 +13,208 @@ ResourceViewInternalImpl::~ResourceViewInternalImpl() { if (m_descriptor.cpuHandle.ptr) m_allocator->free(m_descriptor); + for (auto desc : m_mapBufferStrideToDescriptor) + { + m_allocator->free(desc.second); + } +} + +SlangResult createD3D12BufferDescriptor( + BufferResourceImpl* buffer, + BufferResourceImpl* counterBuffer, + IResourceView::Desc const& desc, + DeviceImpl* device, + D3D12GeneralExpandingDescriptorHeap* descriptorHeap, + D3D12Descriptor* outDescriptor) +{ + + auto resourceImpl = (BufferResourceImpl*)buffer; + auto resourceDesc = *resourceImpl->getDesc(); + const auto counterResourceImpl = static_cast(counterBuffer); + + switch (desc.type) + { + default: + return SLANG_FAIL; + + case IResourceView::Type::UnorderedAccess: + { + D3D12_UNORDERED_ACCESS_VIEW_DESC uavDesc = {}; + uavDesc.ViewDimension = D3D12_UAV_DIMENSION_BUFFER; + uavDesc.Format = D3DUtil::getMapFormat(desc.format); + uavDesc.Buffer.FirstElement = desc.bufferRange.firstElement; + uint64_t viewSize = 0; + if (desc.bufferElementSize) + { + uavDesc.Buffer.StructureByteStride = (UINT)desc.bufferElementSize; + uavDesc.Buffer.NumElements = + desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) + : (UINT)desc.bufferRange.elementCount; + viewSize = (uint64_t)desc.bufferElementSize * uavDesc.Buffer.NumElements; + } + else if (desc.format == Format::Unknown) + { + uavDesc.Format = DXGI_FORMAT_R32_TYPELESS; + uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / 4) + : UINT(desc.bufferRange.elementCount / 4); + uavDesc.Buffer.Flags |= D3D12_BUFFER_UAV_FLAG_RAW; + viewSize = 4ull * uavDesc.Buffer.NumElements; + } + else + { + FormatInfo sizeInfo; + gfxGetFormatInfo(desc.format, &sizeInfo); + assert(sizeInfo.pixelsPerBlock == 1); + uavDesc.Buffer.NumElements = + desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) + : (UINT)desc.bufferRange.elementCount; + viewSize = (uint64_t)uavDesc.Buffer.NumElements * sizeInfo.blockSizeInBytes; + } + + if (viewSize >= (1ull << 32) - 8) + { + // D3D12 does not support view descriptors that has size near 4GB. + // We will not create actual SRV/UAVs for such large buffers. + // However, a buffer this large can still be bound as root parameter. + // So instead of failing, we quietly ignore descriptor creation. + outDescriptor->cpuHandle.ptr = 0; + } + else + { + SLANG_RETURN_ON_FAIL(descriptorHeap->allocate(outDescriptor)); + device->m_device->CreateUnorderedAccessView( + resourceImpl->m_resource, + counterResourceImpl ? counterResourceImpl->m_resource.getResource() : nullptr, + &uavDesc, + outDescriptor->cpuHandle); + } + } + break; + + case IResourceView::Type::ShaderResource: + { + D3D12_SHADER_RESOURCE_VIEW_DESC srvDesc = {}; + srvDesc.ViewDimension = D3D12_SRV_DIMENSION_BUFFER; + srvDesc.Format = D3DUtil::getMapFormat(desc.format); + srvDesc.Buffer.StructureByteStride = 0; + srvDesc.Buffer.FirstElement = desc.bufferRange.firstElement; + srvDesc.Shader4ComponentMapping = D3D12_DEFAULT_SHADER_4_COMPONENT_MAPPING; + uint64_t viewSize = 0; + if (desc.bufferElementSize) + { + srvDesc.Buffer.StructureByteStride = (UINT)desc.bufferElementSize; + srvDesc.Buffer.NumElements = + desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) + : (UINT)desc.bufferRange.elementCount; + viewSize = (uint64_t)desc.bufferElementSize * srvDesc.Buffer.NumElements; + } + else if (desc.format == Format::Unknown) + { + srvDesc.Format = DXGI_FORMAT_R32_TYPELESS; + srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / 4) + : UINT(desc.bufferRange.elementCount / 4); + srvDesc.Buffer.Flags |= D3D12_BUFFER_SRV_FLAG_RAW; + viewSize = 4ull * srvDesc.Buffer.NumElements; + } + else + { + FormatInfo sizeInfo; + gfxGetFormatInfo(desc.format, &sizeInfo); + assert(sizeInfo.pixelsPerBlock == 1); + srvDesc.Buffer.NumElements = + desc.bufferRange.elementCount == 0 + ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) + : (UINT)desc.bufferRange.elementCount; + viewSize = (uint64_t)srvDesc.Buffer.NumElements * sizeInfo.blockSizeInBytes; + } + if (viewSize >= (1ull << 32) - 8) + { + // D3D12 does not support view descriptors that has size near 4GB. + // We will not create actual SRV/UAVs for such large buffers. + // However, a buffer this large can still be bound as root parameter. + // So instead of failing, we quietly ignore descriptor creation. + outDescriptor->cpuHandle.ptr = 0; + } + else + { + SLANG_RETURN_ON_FAIL(descriptorHeap->allocate(outDescriptor)); + device->m_device->CreateShaderResourceView( + resourceImpl->m_resource, &srvDesc, outDescriptor->cpuHandle); + } + } + break; + } + return SLANG_OK; +} + +SlangResult ResourceViewInternalImpl::getBufferDescriptorForBinding( + DeviceImpl* device, + ResourceViewImpl* view, + uint32_t bufferStride, + D3D12Descriptor& outDescriptor) +{ + // If stride is 0, just use the default descriptor. + if (bufferStride == 0) + { + outDescriptor = m_descriptor; + return SLANG_OK; + } + + // Otherwise, look for an existing descriptor from the cache if it exists. + if (auto descriptor = m_mapBufferStrideToDescriptor.tryGetValue(bufferStride)) + { + outDescriptor = *descriptor; + return SLANG_OK; + } + + // We need to create and cache a d3d12 descriptor for the resource view that encodes + // the given buffer stride. + auto bufferResImpl = static_cast(view->m_resource.get()); + auto desc = view->m_desc; + uint64_t bufferSize = 0; + if (desc.bufferElementSize == 0) + { + // If buffer element size is 0, we assume the buffer range from original desc is in bytes. + bufferSize = desc.bufferRange.elementCount; + if (bufferSize == 0) + { + bufferSize = bufferResImpl->getDesc()->sizeInBytes - desc.bufferRange.firstElement; + } + desc.bufferElementSize = bufferStride; + desc.bufferRange.firstElement /= bufferStride; + desc.bufferRange.elementCount = bufferSize / bufferStride; + } + else + { + // If buffer element size is not 0, we assume the buffer range from original desc is in elements + // of original stride. + if (desc.bufferRange.elementCount == 0) + { + bufferSize = bufferResImpl->getDesc()->sizeInBytes - desc.bufferRange.firstElement * desc.bufferElementSize; + } + else + { + bufferSize = desc.bufferRange.elementCount * desc.bufferElementSize; + } + desc.bufferElementSize = bufferStride; + desc.bufferRange.firstElement = desc.bufferRange.firstElement * desc.bufferElementSize / bufferStride; + desc.bufferRange.elementCount = bufferSize / bufferStride; + } + SLANG_RETURN_ON_FAIL(createD3D12BufferDescriptor( + bufferResImpl, + static_cast(view->m_counterResource.get()), + desc, + device, + m_allocator, + &outDescriptor)); + m_mapBufferStrideToDescriptor[bufferStride] = outDescriptor; + + return SLANG_OK; } Result ResourceViewImpl::getNativeHandle(InteropHandle* outHandle) diff --git a/tools/gfx/d3d12/d3d12-resource-views.h b/tools/gfx/d3d12/d3d12-resource-views.h index fd3f44116..f8c6654f2 100644 --- a/tools/gfx/d3d12/d3d12-resource-views.h +++ b/tools/gfx/d3d12/d3d12-resource-views.h @@ -12,14 +12,37 @@ namespace d3d12 using namespace Slang; +class ResourceViewImpl; + class ResourceViewInternalImpl { public: + // The default descriptor for the view. D3D12Descriptor m_descriptor; + + // StructuredBuffer descriptors for different strides. + Dictionary m_mapBufferStrideToDescriptor; + RefPtr m_allocator; + ~ResourceViewInternalImpl(); + + // Get a d3d12 descriptor from the buffer view with the given buffer element stride. + SlangResult getBufferDescriptorForBinding( + DeviceImpl* device, + ResourceViewImpl* view, + uint32_t bufferStride, + D3D12Descriptor& outDescriptor); }; +SlangResult createD3D12BufferDescriptor( + BufferResourceImpl* buffer, + BufferResourceImpl* counterBuffer, + IResourceView::Desc const& desc, + DeviceImpl* device, + D3D12GeneralExpandingDescriptorHeap* descriptorHeap, + D3D12Descriptor* outDescriptor); + class ResourceViewImpl : public ResourceViewBase , public ResourceViewInternalImpl diff --git a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp index 0c0095621..ef59c3672 100644 --- a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp +++ b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp @@ -116,6 +116,7 @@ Result ShaderObjectLayoutImpl::Builder::setElementTypeLayout( uint32_t count = (uint32_t)typeLayout->getBindingRangeBindingCount(r); slang::TypeLayoutReflection* slangLeafTypeLayout = typeLayout->getBindingRangeLeafTypeLayout(r); + BindingRangeInfo bindingRangeInfo = {}; bindingRangeInfo.bindingType = slangBindingType; bindingRangeInfo.resourceShape = slangLeafTypeLayout->getResourceShape(); @@ -126,6 +127,21 @@ Result ShaderObjectLayoutImpl::Builder::setElementTypeLayout( typeLayout, r); bindingRangeInfo.isSpecializable = typeLayout->isBindingRangeSpecializable(r); + switch (slangBindingType) + { + case slang::BindingType::RawBuffer: + case slang::BindingType::TypedBuffer: + case slang::BindingType::MutableRawBuffer: + case slang::BindingType::MutableTypedBuffer: + { + auto bufferElementType = slangLeafTypeLayout->getElementTypeLayout(); + if (bufferElementType) + { + bindingRangeInfo.bufferElementStride = (uint32_t)bufferElementType->getStride(); + } + } + break; + } if (bindingRangeInfo.isRootParameter) { RootParameterInfo rootInfo = {}; diff --git a/tools/gfx/d3d12/d3d12-shader-object-layout.h b/tools/gfx/d3d12/d3d12-shader-object-layout.h index 4b46df0b0..b8b3082f6 100644 --- a/tools/gfx/d3d12/d3d12-shader-object-layout.h +++ b/tools/gfx/d3d12/d3d12-shader-object-layout.h @@ -66,6 +66,9 @@ public: /// as a sub-object. uint32_t subObjectIndex; + /// The stride of a structured buffer. + uint32_t bufferElementStride; + bool isRootParameter; /// Is this binding range represent a specialization point, such as an existential value, or a `ParameterBlock`. diff --git a/tools/gfx/d3d12/d3d12-shader-object.cpp b/tools/gfx/d3d12/d3d12-shader-object.cpp index 1b7b51106..3d38df5ab 100644 --- a/tools/gfx/d3d12/d3d12-shader-object.cpp +++ b/tools/gfx/d3d12/d3d12-shader-object.cpp @@ -939,6 +939,8 @@ Result ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* } ResourceViewInternalImpl* internalResourceView = nullptr; + auto resourceViewImpl = static_cast(resourceView); + switch (resourceView->getViewDesc()->type) { #if SLANG_GFX_HAS_DXR_SUPPORT @@ -953,7 +955,6 @@ Result ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* #endif default: { - auto resourceViewImpl = static_cast(resourceView); // Hold a reference to the resource to prevent its destruction. const auto resourceOffset = bindingRange.baseIndex + offset.bindingArrayIndex; m_boundResources[resourceOffset] = resourceViewImpl->m_resource; @@ -964,13 +965,21 @@ Result ShaderObjectImpl::setResource(ShaderOffset const& offset, IResourceView* } auto descriptorSlotIndex = bindingRange.baseIndex + (int32_t)offset.bindingArrayIndex; - if (internalResourceView->m_descriptor.cpuHandle.ptr) + D3D12Descriptor srcDescriptor = {}; + + SLANG_RETURN_ON_FAIL(internalResourceView->getBufferDescriptorForBinding( + static_cast(m_device.get()), + resourceViewImpl, + bindingRange.bufferElementStride, + srcDescriptor)); + + if (srcDescriptor.cpuHandle.ptr) { d3dDevice->CopyDescriptorsSimple( 1, m_descriptorSet.resourceTable.getCpuHandle( bindingRange.baseIndex + (int32_t)offset.bindingArrayIndex), - internalResourceView->m_descriptor.cpuHandle, + srcDescriptor.cpuHandle, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV); } else -- cgit v1.2.3