From b8982fcf43b86c1e39dcc3dd19bff2821633eda6 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 9 Feb 2022 15:30:38 -0800 Subject: Various fixes to gfx. (#2120) * Various fixes to gfx. * Fix. * Fixes. * Fix. Co-authored-by: Yong He --- tools/gfx/d3d12/render-d3d12.cpp | 341 +++++++++++++++++++++++++++++---------- 1 file changed, 256 insertions(+), 85 deletions(-) (limited to 'tools/gfx/d3d12/render-d3d12.cpp') diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp index a6d02cdc3..e5a1d1876 100644 --- a/tools/gfx/d3d12/render-d3d12.cpp +++ b/tools/gfx/d3d12/render-d3d12.cpp @@ -110,7 +110,10 @@ public: IResourceView::Desc const& desc, IResourceView** outView) override; virtual SLANG_NO_THROW Result SLANG_MCALL createBufferView( - IBufferResource* buffer, IResourceView::Desc const& desc, IResourceView** outView) override; + IBufferResource* buffer, + IBufferResource* counterBuffer, + IResourceView::Desc const& desc, + IResourceView** outView) override; virtual SLANG_NO_THROW Result SLANG_MCALL createFramebuffer(IFramebuffer::Desc const& desc, IFramebuffer** outFrameBuffer) override; @@ -377,7 +380,7 @@ public: { public: D3D12Descriptor m_descriptor; - Slang::RefPtr m_allocator; + Slang::RefPtr m_allocator; ~SamplerStateImpl() { m_allocator->free(m_descriptor); @@ -394,7 +397,7 @@ public: { public: D3D12Descriptor m_descriptor; - RefPtr m_allocator; + RefPtr m_allocator; ~ResourceViewInternalImpl() { m_allocator->free(m_descriptor); } }; @@ -2279,6 +2282,7 @@ public: struct ShaderBinary { SlangStage stage; + slang::EntryPointReflection* entryPointInfo; List code; }; @@ -3934,7 +3938,24 @@ public: virtual SLANG_NO_THROW void SLANG_MCALL dispatchComputeIndirect(IBufferResource* argBuffer, uint64_t offset) override { - SLANG_UNIMPLEMENTED_X("dispatchComputeIndirect"); + // Submit binding for compute + { + ComputeSubmitter submitter(m_d3dCmdList); + RefPtr newPipeline; + if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline))) + { + assert(!"Failed to bind render state"); + } + } + auto argBufferImpl = static_cast(argBuffer); + + m_d3dCmdList->ExecuteIndirect( + m_renderer->dispatchIndirectCmdSignature, + 1, + argBufferImpl->m_resource, + offset, + nullptr, + 0); } }; @@ -4233,11 +4254,13 @@ public: } auto rowSize = (footprint.Footprint.Width + formatInfo.blockWidth - 1) / formatInfo.blockWidth * formatInfo.blockSizeInBytes; + auto rowCount = (footprint.Footprint.Height + formatInfo.blockHeight - 1) / + formatInfo.blockHeight; footprint.Footprint.RowPitch = (UINT)D3DUtil::calcAligned( rowSize, (uint32_t)D3D12_TEXTURE_DATA_PITCH_ALIGNMENT); - auto bufferSize = footprint.Footprint.RowPitch * footprint.Footprint.Height * - footprint.Footprint.Depth; + auto bufferSize = + footprint.Footprint.RowPitch * rowCount * footprint.Footprint.Depth; IBufferResource* stagingBuffer; m_commandBuffer->m_transientHeap->allocateStagingBuffer( @@ -4249,11 +4272,10 @@ public: bufferImpl->m_resource.getResource()->Map(0, &mapRange, (void**)&bufferData); for (uint32_t z = 0; z < footprint.Footprint.Depth; z++) { - auto imageStart = bufferData + footprint.Footprint.RowPitch * - footprint.Footprint.Height * (size_t)z; + auto imageStart = bufferData + footprint.Footprint.RowPitch * rowCount * (size_t)z; auto srcData = (uint8_t*)subResourceData->data + subResourceData->strideZ * z; - for (uint32_t row = 0; row < footprint.Footprint.Height; row++) + for (uint32_t row = 0; row < rowCount; row++) { memcpy( imageStart + row * (size_t)footprint.Footprint.RowPitch, @@ -4979,14 +5001,14 @@ public: RefPtr m_resourceCommandQueue; RefPtr m_resourceCommandTransientHeap; - RefPtr m_rtvAllocator; - RefPtr m_dsvAllocator; + RefPtr m_rtvAllocator; + RefPtr m_dsvAllocator; // Space in the GPU-visible heaps is precious, so we will also keep // around CPU-visible heaps for storing descriptors in a format // that is ready for copying into the GPU-visible heaps as needed. // - RefPtr m_cpuViewHeap; ///< Cbv, Srv, Uav - RefPtr m_cpuSamplerHeap; ///< Heap for samplers + RefPtr m_cpuViewHeap; ///< Cbv, Srv, Uav + RefPtr m_cpuSamplerHeap; ///< Heap for samplers // Dll entry points PFN_D3D12_GET_DEBUG_INTERFACE m_D3D12GetDebugInterface = nullptr; @@ -4999,6 +5021,7 @@ public: // as well as the command type to be used (DrawInstanced and DrawIndexedInstanced, in this case). ComPtr drawIndirectCmdSignature; ComPtr drawIndexedIndirectCmdSignature; + ComPtr dispatchIndirectCmdSignature; }; SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchronizeAndReset() @@ -5189,18 +5212,34 @@ static void _initSrvDesc( { switch (desc.Dimension) { - case D3D12_RESOURCE_DIMENSION_TEXTURE1D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1D; break; - case D3D12_RESOURCE_DIMENSION_TEXTURE2D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D; break; - case D3D12_RESOURCE_DIMENSION_TEXTURE3D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; break; - default: assert(!"Unknown dimension"); + case D3D12_RESOURCE_DIMENSION_TEXTURE1D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1D; + descOut.Texture1D.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels - subresourceRange.mipLevel + : subresourceRange.mipLevelCount; + descOut.Texture1D.MostDetailedMip = subresourceRange.mipLevel; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE2D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D; + descOut.Texture2D.PlaneSlice = + D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask); + descOut.Texture2D.ResourceMinLODClamp = 0.0f; + descOut.Texture2D.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels - subresourceRange.mipLevel + : subresourceRange.mipLevelCount; + descOut.Texture2D.MostDetailedMip = subresourceRange.mipLevel; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE3D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; + descOut.Texture3D.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels - subresourceRange.mipLevel + : subresourceRange.mipLevelCount; + descOut.Texture3D.MostDetailedMip = subresourceRange.mipLevel; + break; + default: + assert(!"Unknown dimension"); } - descOut.Texture2D.MipLevels = - subresourceRange.mipLevelCount == 0 ? desc.MipLevels : subresourceRange.mipLevelCount; - descOut.Texture2D.MostDetailedMip = subresourceRange.mipLevel; - descOut.Texture2D.PlaneSlice = - D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask); - descOut.Texture2D.ResourceMinLODClamp = 0.0f; } else if (resourceType == IResource::Type::TextureCube) { @@ -5213,7 +5252,7 @@ static void _initSrvDesc( : subresourceRange.layerCount / 6; descOut.TextureCubeArray.First2DArrayFace = subresourceRange.baseArrayLayer; descOut.TextureCubeArray.MipLevels = subresourceRange.mipLevelCount == 0 - ? desc.MipLevels + ? desc.MipLevels - subresourceRange.mipLevel : subresourceRange.mipLevelCount; descOut.TextureCubeArray.MostDetailedMip = subresourceRange.mipLevel; descOut.TextureCubeArray.ResourceMinLODClamp = 0; @@ -5223,7 +5262,7 @@ static void _initSrvDesc( descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURECUBE; descOut.TextureCube.MipLevels = subresourceRange.mipLevelCount == 0 - ? desc.MipLevels + ? desc.MipLevels - subresourceRange.mipLevel : subresourceRange.mipLevelCount; descOut.TextureCube.MostDetailedMip = subresourceRange.mipLevel; descOut.TextureCube.ResourceMinLODClamp = 0; @@ -5235,22 +5274,46 @@ static void _initSrvDesc( switch (desc.Dimension) { - case D3D12_RESOURCE_DIMENSION_TEXTURE1D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1DARRAY; break; - case D3D12_RESOURCE_DIMENSION_TEXTURE2D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2DARRAY; break; - case D3D12_RESOURCE_DIMENSION_TEXTURE3D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; break; + case D3D12_RESOURCE_DIMENSION_TEXTURE1D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1DARRAY; + descOut.Texture1D.MostDetailedMip = subresourceRange.mipLevel; + descOut.Texture1D.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels + : subresourceRange.mipLevelCount; + descOut.Texture1DArray.ArraySize = subresourceRange.layerCount == 0 + ? desc.DepthOrArraySize + : subresourceRange.layerCount; + descOut.Texture1DArray.FirstArraySlice = subresourceRange.baseArrayLayer; + descOut.Texture1DArray.ResourceMinLODClamp = 0; + descOut.Texture1DArray.MostDetailedMip = subresourceRange.mipLevel; + descOut.Texture1DArray.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels - subresourceRange.mipLevel + : subresourceRange.mipLevelCount; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE2D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2DARRAY; + descOut.Texture2DArray.ArraySize = + subresourceRange.layerCount == 0 ? desc.DepthOrArraySize : subresourceRange.layerCount; + descOut.Texture2DArray.FirstArraySlice = subresourceRange.baseArrayLayer; + descOut.Texture2DArray.PlaneSlice = + D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask); + descOut.Texture2DArray.ResourceMinLODClamp = 0; + descOut.Texture2DArray.MostDetailedMip = subresourceRange.mipLevel; + descOut.Texture2DArray.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels - subresourceRange.mipLevel + : subresourceRange.mipLevelCount; + break; + case D3D12_RESOURCE_DIMENSION_TEXTURE3D: + descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; + descOut.Texture3D.MostDetailedMip = subresourceRange.mipLevel; + descOut.Texture3D.MipLevels = subresourceRange.mipLevelCount == 0 + ? desc.MipLevels + : subresourceRange.mipLevelCount; + break; - default: assert(!"Unknown dimension"); + default: + assert(!"Unknown dimension"); } - - descOut.Texture2DArray.ArraySize = - subresourceRange.layerCount == 0 ? desc.DepthOrArraySize : subresourceRange.layerCount; - descOut.Texture2DArray.MostDetailedMip = subresourceRange.mipLevel; - descOut.Texture2DArray.MipLevels = - subresourceRange.mipLevelCount == 0 ? desc.MipLevels : subresourceRange.mipLevelCount; - descOut.Texture2DArray.FirstArraySlice = subresourceRange.baseArrayLayer; - descOut.Texture2DArray.PlaneSlice = - D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask); - descOut.Texture2DArray.ResourceMinLODClamp = 0; } } @@ -5822,20 +5885,20 @@ Result D3D12Device::initialize(const Desc& desc) // since this object is already owned by `D3D12Device`. m_resourceCommandTransientHeap->breakStrongReferenceToDevice(); - m_cpuViewHeap = new D3D12GeneralDescriptorHeap(); + m_cpuViewHeap = new D3D12GeneralExpandingDescriptorHeap(); SLANG_RETURN_ON_FAIL(m_cpuViewHeap->init( m_device, 1024 * 1024, D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - m_cpuSamplerHeap = new D3D12GeneralDescriptorHeap(); + m_cpuSamplerHeap = new D3D12GeneralExpandingDescriptorHeap(); SLANG_RETURN_ON_FAIL(m_cpuSamplerHeap->init( m_device, 2048, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - m_rtvAllocator = new D3D12GeneralDescriptorHeap(); + m_rtvAllocator = new D3D12GeneralExpandingDescriptorHeap(); SLANG_RETURN_ON_FAIL(m_rtvAllocator->init( m_device, 16 * 1024, D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); - m_dsvAllocator = new D3D12GeneralDescriptorHeap(); + m_dsvAllocator = new D3D12GeneralExpandingDescriptorHeap(); SLANG_RETURN_ON_FAIL(m_dsvAllocator->init( m_device, 1024, D3D12_DESCRIPTOR_HEAP_TYPE_DSV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE)); @@ -5936,6 +5999,21 @@ Result D3D12Device::initialize(const Desc& desc) SLANG_RETURN_ON_FAIL(m_device->CreateCommandSignature(&desc, nullptr, IID_PPV_ARGS(drawIndexedIndirectCmdSignature.writeRef()))); } + // Allocate a D3D12 "command signature" object that matches the behavior + // of a D3D11-style `Dispatch` operation. + { + D3D12_INDIRECT_ARGUMENT_DESC args; + args.Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH; + + D3D12_COMMAND_SIGNATURE_DESC desc; + desc.ByteStride = sizeof(D3D12_DISPATCH_ARGUMENTS); + desc.NumArgumentDescs = 1; + desc.pArgumentDescs = &args; + desc.NodeMask = 0; + + SLANG_RETURN_ON_FAIL(m_device->CreateCommandSignature( + &desc, nullptr, IID_PPV_ARGS(dispatchIndirectCmdSignature.writeRef()))); + } m_isInitialized = true; return SLANG_OK; } @@ -6046,6 +6124,31 @@ static D3D12_RESOURCE_DIMENSION _calcResourceDimension(IResource::Type type) } } +DXGI_FORMAT getTypelessFormatFromDepthFormat(Format format) +{ + switch (format) + { + case Format::D16_UNORM: + return DXGI_FORMAT_R16_TYPELESS; + case Format::D32_FLOAT: + return DXGI_FORMAT_R32_TYPELESS; + default: + return D3DUtil::getMapFormat(format); + } +} + +BOOL isTypelessDepthFormat(DXGI_FORMAT format) +{ + switch (format) + { + case DXGI_FORMAT_R16_TYPELESS: + case DXGI_FORMAT_R32_TYPELESS: + return true; + default: + return false; + } +} + Result setupResourceDesc(D3D12_RESOURCE_DESC& resourceDesc, const ITextureResource::Desc& srcDesc) { const DXGI_FORMAT pixelFormat = D3DUtil::getMapFormat(srcDesc.format); @@ -6080,6 +6183,13 @@ Result setupResourceDesc(D3D12_RESOURCE_DESC& resourceDesc, const ITextureResour resourceDesc.Alignment = 0; + if (isDepthFormat(srcDesc.format) && + (srcDesc.allowedStates.contains(ResourceState::ShaderResource) || + srcDesc.allowedStates.contains(ResourceState::UnorderedAccess))) + { + resourceDesc.Format = getTypelessFormatFromDepthFormat(srcDesc.format); + } + return SLANG_OK; } @@ -6129,6 +6239,10 @@ Result D3D12Device::createTextureResource(const ITextureResource::Desc& descIn, { clearValuePtr = nullptr; } + if (isTypelessDepthFormat(resourceDesc.Format)) + { + clearValuePtr = nullptr; + } clearValue.Format = resourceDesc.Format; memcpy(clearValue.Color, &descIn.optimalClearValue.color, sizeof(clearValue.Color)); clearValue.DepthStencil.Depth = descIn.optimalClearValue.depthStencil.depth; @@ -6632,9 +6746,7 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView:: d3d12desc.ViewDimension = D3D12_UAV_DIMENSION_TEXTURE3D; d3d12desc.Texture3D.MipSlice = desc.subresourceRange.mipLevel; d3d12desc.Texture3D.FirstWSlice = desc.subresourceRange.baseArrayLayer; - d3d12desc.Texture3D.WSize = desc.subresourceRange.layerCount == 0 - ? resourceDesc.size.depth - : desc.subresourceRange.layerCount; + d3d12desc.Texture3D.WSize = resourceDesc.size.depth; break; default: return SLANG_FAIL; @@ -6719,7 +6831,11 @@ Result D3D12Device::getFormatSupportedResourceStates(Format format, ResourceStat return SLANG_OK; } -Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Desc const& desc, IResourceView** outView) +Result D3D12Device::createBufferView( + IBufferResource* buffer, + IBufferResource* counterBuffer, + IResourceView::Desc const& desc, + IResourceView** outView) { auto resourceImpl = (BufferResourceImpl*) buffer; auto resourceDesc = *resourceImpl->getDesc(); @@ -6745,14 +6861,14 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) - : desc.bufferRange.elementCount; + : (UINT)desc.bufferRange.elementCount; } else if(desc.format == Format::Unknown) { uavDesc.Format = DXGI_FORMAT_R32_TYPELESS; uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / 4) - : desc.bufferRange.elementCount / 4; + : UINT(desc.bufferRange.elementCount / 4); uavDesc.Buffer.Flags |= D3D12_BUFFER_UAV_FLAG_RAW; } else @@ -6763,16 +6879,16 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) - : desc.bufferRange.elementCount; + : (UINT)desc.bufferRange.elementCount; } - - - // TODO: need to support the separate "counter resource" for the case - // of append/consume buffers with attached counters. - + auto counterResourceImpl = static_cast(counterBuffer); SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); viewImpl->m_allocator = m_cpuViewHeap; - m_device->CreateUnorderedAccessView(resourceImpl->m_resource, nullptr, &uavDesc, viewImpl->m_descriptor.cpuHandle); + m_device->CreateUnorderedAccessView( + resourceImpl->m_resource, + counterResourceImpl ? counterResourceImpl->m_resource.getResource() : nullptr, + &uavDesc, + viewImpl->m_descriptor.cpuHandle); } break; @@ -6790,14 +6906,14 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize) - : desc.bufferRange.elementCount; + : (UINT)desc.bufferRange.elementCount; } else if (desc.format == Format::Unknown) { srvDesc.Format = DXGI_FORMAT_R32_TYPELESS; srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / 4) - : desc.bufferRange.elementCount / 4; + : UINT(desc.bufferRange.elementCount / 4); srvDesc.Buffer.Flags |= D3D12_BUFFER_SRV_FLAG_RAW; } else @@ -6808,7 +6924,7 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0 ? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes) - : desc.bufferRange.elementCount; + : (UINT)desc.bufferRange.elementCount; } SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor)); @@ -7021,12 +7137,12 @@ Result D3D12Device::readBufferResource( Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnosticBlob) { RefPtr shaderProgram = new ShaderProgramImpl(); - shaderProgram->slangProgram = desc.slangProgram; + shaderProgram->init(desc); ComPtr d3dDiagnosticBlob; auto rootShaderLayoutResult = RootShaderObjectLayoutImpl::create( this, - desc.slangProgram, - desc.slangProgram->getLayout(), + shaderProgram->linkedProgram, + shaderProgram->linkedProgram->getLayout(), shaderProgram->m_rootObjectLayout.writeRef(), d3dDiagnosticBlob.writeRef()); if (!SLANG_SUCCEEDED(rootShaderLayoutResult)) @@ -7039,22 +7155,22 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr } return rootShaderLayoutResult; } - if (desc.slangProgram->getSpecializationParamCount() != 0) + if (shaderProgram->isSpecializable()) { // For a specializable program, we don't invoke any actual slang compilation yet. returnComPtr(outProgram, shaderProgram); return SLANG_OK; } // For a fully specialized program, read and store its kernel code in `shaderProgram`. - auto programReflection = desc.slangProgram->getLayout(); - for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++) + auto compileShader = [&](slang::EntryPointReflection* entryPointInfo, + slang::IComponentType* entryPointComponent, + SlangInt entryPointIndex) { - auto entryPointInfo = programReflection->getEntryPointByIndex(i); auto stage = entryPointInfo->getStage(); ComPtr kernelCode; ComPtr diagnostics; - auto compileResult = desc.slangProgram->getEntryPointCode( - (SlangInt)i, 0, kernelCode.writeRef(), diagnostics.writeRef()); + auto compileResult = entryPointComponent->getEntryPointCode( + entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef()); if (diagnostics) { getDebugCallback()->handleMessage( @@ -7067,10 +7183,35 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr SLANG_RETURN_ON_FAIL(compileResult); ShaderBinary shaderBin; shaderBin.stage = stage; + shaderBin.entryPointInfo = entryPointInfo; shaderBin.code.addRange( reinterpret_cast(kernelCode->getBufferPointer()), (Index)kernelCode->getBufferSize()); shaderProgram->m_shaders.add(_Move(shaderBin)); + return SLANG_OK; + }; + + if (shaderProgram->linkedEntryPoints.getCount() == 0) + { + // If the user does not explicitly specify entry point components, find them from `linkedEntryPoints`. + auto programReflection = shaderProgram->linkedProgram->getLayout(); + for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++) + { + SLANG_RETURN_ON_FAIL(compileShader( + programReflection->getEntryPointByIndex(i), + shaderProgram->linkedProgram, + (SlangInt)i)); + } + } + else + { + // If the user specifies entry point components via the separated entry point array, compile code + // from there. + for (auto& entryPoint : shaderProgram->linkedEntryPoints) + { + SLANG_RETURN_ON_FAIL( + compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0)); + } } returnComPtr(outProgram, shaderProgram); return SLANG_OK; @@ -7150,6 +7291,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {}; psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature; + for (auto& shaderBin : programImpl->m_shaders) { switch (shaderBin.stage) @@ -7283,7 +7425,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc& if (m_pipelineCreationAPIDispatcher) { SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState( - this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef())); + this, programImpl->linkedProgram.get(), &psoDesc, (void**)pipelineState.writeRef())); } else { @@ -7312,7 +7454,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i // Only actually create a D3D12 pipeline state if the pipeline is fully specialized. ComPtr pipelineState; - if (!programImpl->slangProgram || programImpl->slangProgram->getSpecializationParamCount() == 0) + if (!programImpl->isSpecializable()) { // Describe and create the compute pipeline state object D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {}; @@ -7361,7 +7503,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i { SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState( this, - programImpl->slangProgram.get(), + programImpl->linkedProgram.get(), &computeDesc, (void**)pipelineState.writeRef())); } @@ -7835,8 +7977,8 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD pipelineStateImpl->init(inDesc); auto program = static_cast(inDesc.program); - auto slangProgram = program->slangProgram; - auto programLayout = slangProgram->getLayout(); + auto slangGlobalScope = program->linkedProgram; + auto programLayout = slangGlobalScope->getLayout(); if (!program->m_rootObjectLayout->m_rootSignature) { @@ -7847,13 +7989,24 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD ChunkedList dxilLibraries; ChunkedList hitGroups; ChunkedList> codeBlobs; + ChunkedList exports; + ChunkedList strPtrs; + ComPtr diagnostics; ChunkedList stringPool; - for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + auto getWStr = [&](const char* name) + { + String str = String(name); + auto wstr = str.toWString(); + return stringPool.add(wstr)->begin(); + }; + auto compileShader = [&](slang::EntryPointLayout* entryPointInfo, + slang::IComponentType* component, + SlangInt entryPointIndex) { ComPtr codeBlob; - auto compileResult = - slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef()); + auto compileResult = component->getEntryPointCode( + entryPointIndex, 0, codeBlob.writeRef(), diagnostics.writeRef()); if (diagnostics.get()) { getDebugCallback()->handleMessage( @@ -7864,20 +8017,38 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD SLANG_RETURN_ON_FAIL(compileResult); codeBlobs.add(codeBlob); D3D12_DXIL_LIBRARY_DESC library = {}; - library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();; + library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize(); library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer(); + library.NumExports = 1; + D3D12_EXPORT_DESC exportDesc = {}; + exportDesc.Name = getWStr(entryPointInfo->getNameOverride()); + exportDesc.ExportToRename = getWStr(entryPointInfo->getNameOverride()); + exportDesc.Flags = D3D12_EXPORT_FLAG_NONE; + library.pExports = exports.add(exportDesc); D3D12_STATE_SUBOBJECT dxilSubObject = {}; dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY; dxilSubObject.pDesc = dxilLibraries.add(library); subObjects.add(dxilSubObject); + return SLANG_OK; + }; + if (program->linkedEntryPoints.getCount() == 0) + { + for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++) + { + SLANG_RETURN_ON_FAIL(compileShader( + programLayout->getEntryPointByIndex(i), program->linkedProgram, (SlangInt)i)); + } } - auto getWStr = [&](const char* name) + else { - String str = String(name); - auto wstr = str.toWString(); - return stringPool.add(wstr)->begin(); - }; + for (auto& entryPoint : program->linkedEntryPoints) + { + SLANG_RETURN_ON_FAIL( + compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0)); + } + } + for (int i = 0; i < inDesc.hitGroupCount; i++) { auto hitGroup = inDesc.hitGroups[i]; @@ -7909,7 +8080,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {}; // According to DXR spec, fixed function triangle intersections must use float2 as ray attributes // that defines the barycentric coordinates at intersection. - shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2; + shaderConfig.MaxAttributeSizeInBytes = inDesc.maxAttributeSizeInBytes; shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize; D3D12_STATE_SUBOBJECT shaderConfigSubObject = {}; shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG; @@ -7932,7 +8103,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD if (m_pipelineCreationAPIDispatcher) { - m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram); + m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangGlobalScope); } D3D12_STATE_OBJECT_DESC rtpsoDesc = {}; @@ -7943,7 +8114,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD if (m_pipelineCreationAPIDispatcher) { - m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram); + m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangGlobalScope); } returnComPtr(outState, pipelineStateImpl); -- cgit v1.2.3