summaryrefslogtreecommitdiffstats
path: root/tools/gfx/d3d12/render-d3d12.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-09 15:30:38 -0800
committerGitHub <noreply@github.com>2022-02-09 15:30:38 -0800
commitb8982fcf43b86c1e39dcc3dd19bff2821633eda6 (patch)
tree0d66dbf46b50e760cce4aee232bd6a020976e6fb /tools/gfx/d3d12/render-d3d12.cpp
parent59f3fdc0a372d19ce4e989514ee3e9ecbcbf234c (diff)
Various fixes to gfx. (#2120)
* Various fixes to gfx. * Fix. * Fixes. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'tools/gfx/d3d12/render-d3d12.cpp')
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp341
1 files changed, 256 insertions, 85 deletions
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index a6d02cdc3..e5a1d1876 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -110,7 +110,10 @@ public:
IResourceView::Desc const& desc,
IResourceView** outView) override;
virtual SLANG_NO_THROW Result SLANG_MCALL createBufferView(
- IBufferResource* buffer, IResourceView::Desc const& desc, IResourceView** outView) override;
+ IBufferResource* buffer,
+ IBufferResource* counterBuffer,
+ IResourceView::Desc const& desc,
+ IResourceView** outView) override;
virtual SLANG_NO_THROW Result SLANG_MCALL
createFramebuffer(IFramebuffer::Desc const& desc, IFramebuffer** outFrameBuffer) override;
@@ -377,7 +380,7 @@ public:
{
public:
D3D12Descriptor m_descriptor;
- Slang::RefPtr<D3D12GeneralDescriptorHeap> m_allocator;
+ Slang::RefPtr<D3D12GeneralExpandingDescriptorHeap> m_allocator;
~SamplerStateImpl()
{
m_allocator->free(m_descriptor);
@@ -394,7 +397,7 @@ public:
{
public:
D3D12Descriptor m_descriptor;
- RefPtr<D3D12GeneralDescriptorHeap> m_allocator;
+ RefPtr<D3D12GeneralExpandingDescriptorHeap> m_allocator;
~ResourceViewInternalImpl() { m_allocator->free(m_descriptor); }
};
@@ -2279,6 +2282,7 @@ public:
struct ShaderBinary
{
SlangStage stage;
+ slang::EntryPointReflection* entryPointInfo;
List<uint8_t> code;
};
@@ -3934,7 +3938,24 @@ public:
virtual SLANG_NO_THROW void SLANG_MCALL
dispatchComputeIndirect(IBufferResource* argBuffer, uint64_t offset) override
{
- SLANG_UNIMPLEMENTED_X("dispatchComputeIndirect");
+ // Submit binding for compute
+ {
+ ComputeSubmitter submitter(m_d3dCmdList);
+ RefPtr<PipelineStateBase> newPipeline;
+ if (SLANG_FAILED(_bindRenderState(&submitter, newPipeline)))
+ {
+ assert(!"Failed to bind render state");
+ }
+ }
+ auto argBufferImpl = static_cast<BufferResourceImpl*>(argBuffer);
+
+ m_d3dCmdList->ExecuteIndirect(
+ m_renderer->dispatchIndirectCmdSignature,
+ 1,
+ argBufferImpl->m_resource,
+ offset,
+ nullptr,
+ 0);
}
};
@@ -4233,11 +4254,13 @@ public:
}
auto rowSize = (footprint.Footprint.Width + formatInfo.blockWidth - 1) /
formatInfo.blockWidth * formatInfo.blockSizeInBytes;
+ auto rowCount = (footprint.Footprint.Height + formatInfo.blockHeight - 1) /
+ formatInfo.blockHeight;
footprint.Footprint.RowPitch = (UINT)D3DUtil::calcAligned(
rowSize, (uint32_t)D3D12_TEXTURE_DATA_PITCH_ALIGNMENT);
- auto bufferSize = footprint.Footprint.RowPitch * footprint.Footprint.Height *
- footprint.Footprint.Depth;
+ auto bufferSize =
+ footprint.Footprint.RowPitch * rowCount * footprint.Footprint.Depth;
IBufferResource* stagingBuffer;
m_commandBuffer->m_transientHeap->allocateStagingBuffer(
@@ -4249,11 +4272,10 @@ public:
bufferImpl->m_resource.getResource()->Map(0, &mapRange, (void**)&bufferData);
for (uint32_t z = 0; z < footprint.Footprint.Depth; z++)
{
- auto imageStart = bufferData + footprint.Footprint.RowPitch *
- footprint.Footprint.Height * (size_t)z;
+ auto imageStart = bufferData + footprint.Footprint.RowPitch * rowCount * (size_t)z;
auto srcData =
(uint8_t*)subResourceData->data + subResourceData->strideZ * z;
- for (uint32_t row = 0; row < footprint.Footprint.Height; row++)
+ for (uint32_t row = 0; row < rowCount; row++)
{
memcpy(
imageStart + row * (size_t)footprint.Footprint.RowPitch,
@@ -4979,14 +5001,14 @@ public:
RefPtr<CommandQueueImpl> m_resourceCommandQueue;
RefPtr<TransientResourceHeapImpl> m_resourceCommandTransientHeap;
- RefPtr<D3D12GeneralDescriptorHeap> m_rtvAllocator;
- RefPtr<D3D12GeneralDescriptorHeap> m_dsvAllocator;
+ RefPtr<D3D12GeneralExpandingDescriptorHeap> m_rtvAllocator;
+ RefPtr<D3D12GeneralExpandingDescriptorHeap> m_dsvAllocator;
// Space in the GPU-visible heaps is precious, so we will also keep
// around CPU-visible heaps for storing descriptors in a format
// that is ready for copying into the GPU-visible heaps as needed.
//
- RefPtr<D3D12GeneralDescriptorHeap> m_cpuViewHeap; ///< Cbv, Srv, Uav
- RefPtr<D3D12GeneralDescriptorHeap> m_cpuSamplerHeap; ///< Heap for samplers
+ RefPtr<D3D12GeneralExpandingDescriptorHeap> m_cpuViewHeap; ///< Cbv, Srv, Uav
+ RefPtr<D3D12GeneralExpandingDescriptorHeap> m_cpuSamplerHeap; ///< Heap for samplers
// Dll entry points
PFN_D3D12_GET_DEBUG_INTERFACE m_D3D12GetDebugInterface = nullptr;
@@ -4999,6 +5021,7 @@ public:
// as well as the command type to be used (DrawInstanced and DrawIndexedInstanced, in this case).
ComPtr<ID3D12CommandSignature> drawIndirectCmdSignature;
ComPtr<ID3D12CommandSignature> drawIndexedIndirectCmdSignature;
+ ComPtr<ID3D12CommandSignature> dispatchIndirectCmdSignature;
};
SLANG_NO_THROW Result SLANG_MCALL D3D12Device::TransientResourceHeapImpl::synchronizeAndReset()
@@ -5189,18 +5212,34 @@ static void _initSrvDesc(
{
switch (desc.Dimension)
{
- case D3D12_RESOURCE_DIMENSION_TEXTURE1D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1D; break;
- case D3D12_RESOURCE_DIMENSION_TEXTURE2D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D; break;
- case D3D12_RESOURCE_DIMENSION_TEXTURE3D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; break;
- default: assert(!"Unknown dimension");
+ case D3D12_RESOURCE_DIMENSION_TEXTURE1D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1D;
+ descOut.Texture1D.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels - subresourceRange.mipLevel
+ : subresourceRange.mipLevelCount;
+ descOut.Texture1D.MostDetailedMip = subresourceRange.mipLevel;
+ break;
+ case D3D12_RESOURCE_DIMENSION_TEXTURE2D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2D;
+ descOut.Texture2D.PlaneSlice =
+ D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask);
+ descOut.Texture2D.ResourceMinLODClamp = 0.0f;
+ descOut.Texture2D.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels - subresourceRange.mipLevel
+ : subresourceRange.mipLevelCount;
+ descOut.Texture2D.MostDetailedMip = subresourceRange.mipLevel;
+ break;
+ case D3D12_RESOURCE_DIMENSION_TEXTURE3D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D;
+ descOut.Texture3D.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels - subresourceRange.mipLevel
+ : subresourceRange.mipLevelCount;
+ descOut.Texture3D.MostDetailedMip = subresourceRange.mipLevel;
+ break;
+ default:
+ assert(!"Unknown dimension");
}
- descOut.Texture2D.MipLevels =
- subresourceRange.mipLevelCount == 0 ? desc.MipLevels : subresourceRange.mipLevelCount;
- descOut.Texture2D.MostDetailedMip = subresourceRange.mipLevel;
- descOut.Texture2D.PlaneSlice =
- D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask);
- descOut.Texture2D.ResourceMinLODClamp = 0.0f;
}
else if (resourceType == IResource::Type::TextureCube)
{
@@ -5213,7 +5252,7 @@ static void _initSrvDesc(
: subresourceRange.layerCount / 6;
descOut.TextureCubeArray.First2DArrayFace = subresourceRange.baseArrayLayer;
descOut.TextureCubeArray.MipLevels = subresourceRange.mipLevelCount == 0
- ? desc.MipLevels
+ ? desc.MipLevels - subresourceRange.mipLevel
: subresourceRange.mipLevelCount;
descOut.TextureCubeArray.MostDetailedMip = subresourceRange.mipLevel;
descOut.TextureCubeArray.ResourceMinLODClamp = 0;
@@ -5223,7 +5262,7 @@ static void _initSrvDesc(
descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURECUBE;
descOut.TextureCube.MipLevels = subresourceRange.mipLevelCount == 0
- ? desc.MipLevels
+ ? desc.MipLevels - subresourceRange.mipLevel
: subresourceRange.mipLevelCount;
descOut.TextureCube.MostDetailedMip = subresourceRange.mipLevel;
descOut.TextureCube.ResourceMinLODClamp = 0;
@@ -5235,22 +5274,46 @@ static void _initSrvDesc(
switch (desc.Dimension)
{
- case D3D12_RESOURCE_DIMENSION_TEXTURE1D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1DARRAY; break;
- case D3D12_RESOURCE_DIMENSION_TEXTURE2D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2DARRAY; break;
- case D3D12_RESOURCE_DIMENSION_TEXTURE3D: descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D; break;
+ case D3D12_RESOURCE_DIMENSION_TEXTURE1D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE1DARRAY;
+ descOut.Texture1D.MostDetailedMip = subresourceRange.mipLevel;
+ descOut.Texture1D.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels
+ : subresourceRange.mipLevelCount;
+ descOut.Texture1DArray.ArraySize = subresourceRange.layerCount == 0
+ ? desc.DepthOrArraySize
+ : subresourceRange.layerCount;
+ descOut.Texture1DArray.FirstArraySlice = subresourceRange.baseArrayLayer;
+ descOut.Texture1DArray.ResourceMinLODClamp = 0;
+ descOut.Texture1DArray.MostDetailedMip = subresourceRange.mipLevel;
+ descOut.Texture1DArray.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels - subresourceRange.mipLevel
+ : subresourceRange.mipLevelCount;
+ break;
+ case D3D12_RESOURCE_DIMENSION_TEXTURE2D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE2DARRAY;
+ descOut.Texture2DArray.ArraySize =
+ subresourceRange.layerCount == 0 ? desc.DepthOrArraySize : subresourceRange.layerCount;
+ descOut.Texture2DArray.FirstArraySlice = subresourceRange.baseArrayLayer;
+ descOut.Texture2DArray.PlaneSlice =
+ D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask);
+ descOut.Texture2DArray.ResourceMinLODClamp = 0;
+ descOut.Texture2DArray.MostDetailedMip = subresourceRange.mipLevel;
+ descOut.Texture2DArray.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels - subresourceRange.mipLevel
+ : subresourceRange.mipLevelCount;
+ break;
+ case D3D12_RESOURCE_DIMENSION_TEXTURE3D:
+ descOut.ViewDimension = D3D12_SRV_DIMENSION_TEXTURE3D;
+ descOut.Texture3D.MostDetailedMip = subresourceRange.mipLevel;
+ descOut.Texture3D.MipLevels = subresourceRange.mipLevelCount == 0
+ ? desc.MipLevels
+ : subresourceRange.mipLevelCount;
+ break;
- default: assert(!"Unknown dimension");
+ default:
+ assert(!"Unknown dimension");
}
-
- descOut.Texture2DArray.ArraySize =
- subresourceRange.layerCount == 0 ? desc.DepthOrArraySize : subresourceRange.layerCount;
- descOut.Texture2DArray.MostDetailedMip = subresourceRange.mipLevel;
- descOut.Texture2DArray.MipLevels =
- subresourceRange.mipLevelCount == 0 ? desc.MipLevels : subresourceRange.mipLevelCount;
- descOut.Texture2DArray.FirstArraySlice = subresourceRange.baseArrayLayer;
- descOut.Texture2DArray.PlaneSlice =
- D3DUtil::getPlaneSlice(descOut.Format, subresourceRange.aspectMask);
- descOut.Texture2DArray.ResourceMinLODClamp = 0;
}
}
@@ -5822,20 +5885,20 @@ Result D3D12Device::initialize(const Desc& desc)
// since this object is already owned by `D3D12Device`.
m_resourceCommandTransientHeap->breakStrongReferenceToDevice();
- m_cpuViewHeap = new D3D12GeneralDescriptorHeap();
+ m_cpuViewHeap = new D3D12GeneralExpandingDescriptorHeap();
SLANG_RETURN_ON_FAIL(m_cpuViewHeap->init(
m_device,
1024 * 1024,
D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV,
D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
- m_cpuSamplerHeap = new D3D12GeneralDescriptorHeap();
+ m_cpuSamplerHeap = new D3D12GeneralExpandingDescriptorHeap();
SLANG_RETURN_ON_FAIL(m_cpuSamplerHeap->init(
m_device, 2048, D3D12_DESCRIPTOR_HEAP_TYPE_SAMPLER, D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
- m_rtvAllocator = new D3D12GeneralDescriptorHeap();
+ m_rtvAllocator = new D3D12GeneralExpandingDescriptorHeap();
SLANG_RETURN_ON_FAIL(m_rtvAllocator->init(
m_device, 16 * 1024, D3D12_DESCRIPTOR_HEAP_TYPE_RTV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
- m_dsvAllocator = new D3D12GeneralDescriptorHeap();
+ m_dsvAllocator = new D3D12GeneralExpandingDescriptorHeap();
SLANG_RETURN_ON_FAIL(m_dsvAllocator->init(
m_device, 1024, D3D12_DESCRIPTOR_HEAP_TYPE_DSV, D3D12_DESCRIPTOR_HEAP_FLAG_NONE));
@@ -5936,6 +5999,21 @@ Result D3D12Device::initialize(const Desc& desc)
SLANG_RETURN_ON_FAIL(m_device->CreateCommandSignature(&desc, nullptr, IID_PPV_ARGS(drawIndexedIndirectCmdSignature.writeRef())));
}
+ // Allocate a D3D12 "command signature" object that matches the behavior
+ // of a D3D11-style `Dispatch` operation.
+ {
+ D3D12_INDIRECT_ARGUMENT_DESC args;
+ args.Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH;
+
+ D3D12_COMMAND_SIGNATURE_DESC desc;
+ desc.ByteStride = sizeof(D3D12_DISPATCH_ARGUMENTS);
+ desc.NumArgumentDescs = 1;
+ desc.pArgumentDescs = &args;
+ desc.NodeMask = 0;
+
+ SLANG_RETURN_ON_FAIL(m_device->CreateCommandSignature(
+ &desc, nullptr, IID_PPV_ARGS(dispatchIndirectCmdSignature.writeRef())));
+ }
m_isInitialized = true;
return SLANG_OK;
}
@@ -6046,6 +6124,31 @@ static D3D12_RESOURCE_DIMENSION _calcResourceDimension(IResource::Type type)
}
}
+DXGI_FORMAT getTypelessFormatFromDepthFormat(Format format)
+{
+ switch (format)
+ {
+ case Format::D16_UNORM:
+ return DXGI_FORMAT_R16_TYPELESS;
+ case Format::D32_FLOAT:
+ return DXGI_FORMAT_R32_TYPELESS;
+ default:
+ return D3DUtil::getMapFormat(format);
+ }
+}
+
+BOOL isTypelessDepthFormat(DXGI_FORMAT format)
+{
+ switch (format)
+ {
+ case DXGI_FORMAT_R16_TYPELESS:
+ case DXGI_FORMAT_R32_TYPELESS:
+ return true;
+ default:
+ return false;
+ }
+}
+
Result setupResourceDesc(D3D12_RESOURCE_DESC& resourceDesc, const ITextureResource::Desc& srcDesc)
{
const DXGI_FORMAT pixelFormat = D3DUtil::getMapFormat(srcDesc.format);
@@ -6080,6 +6183,13 @@ Result setupResourceDesc(D3D12_RESOURCE_DESC& resourceDesc, const ITextureResour
resourceDesc.Alignment = 0;
+ if (isDepthFormat(srcDesc.format) &&
+ (srcDesc.allowedStates.contains(ResourceState::ShaderResource) ||
+ srcDesc.allowedStates.contains(ResourceState::UnorderedAccess)))
+ {
+ resourceDesc.Format = getTypelessFormatFromDepthFormat(srcDesc.format);
+ }
+
return SLANG_OK;
}
@@ -6129,6 +6239,10 @@ Result D3D12Device::createTextureResource(const ITextureResource::Desc& descIn,
{
clearValuePtr = nullptr;
}
+ if (isTypelessDepthFormat(resourceDesc.Format))
+ {
+ clearValuePtr = nullptr;
+ }
clearValue.Format = resourceDesc.Format;
memcpy(clearValue.Color, &descIn.optimalClearValue.color, sizeof(clearValue.Color));
clearValue.DepthStencil.Depth = descIn.optimalClearValue.depthStencil.depth;
@@ -6632,9 +6746,7 @@ Result D3D12Device::createTextureView(ITextureResource* texture, IResourceView::
d3d12desc.ViewDimension = D3D12_UAV_DIMENSION_TEXTURE3D;
d3d12desc.Texture3D.MipSlice = desc.subresourceRange.mipLevel;
d3d12desc.Texture3D.FirstWSlice = desc.subresourceRange.baseArrayLayer;
- d3d12desc.Texture3D.WSize = desc.subresourceRange.layerCount == 0
- ? resourceDesc.size.depth
- : desc.subresourceRange.layerCount;
+ d3d12desc.Texture3D.WSize = resourceDesc.size.depth;
break;
default:
return SLANG_FAIL;
@@ -6719,7 +6831,11 @@ Result D3D12Device::getFormatSupportedResourceStates(Format format, ResourceStat
return SLANG_OK;
}
-Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Desc const& desc, IResourceView** outView)
+Result D3D12Device::createBufferView(
+ IBufferResource* buffer,
+ IBufferResource* counterBuffer,
+ IResourceView::Desc const& desc,
+ IResourceView** outView)
{
auto resourceImpl = (BufferResourceImpl*) buffer;
auto resourceDesc = *resourceImpl->getDesc();
@@ -6745,14 +6861,14 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des
uavDesc.Buffer.NumElements =
desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize)
- : desc.bufferRange.elementCount;
+ : (UINT)desc.bufferRange.elementCount;
}
else if(desc.format == Format::Unknown)
{
uavDesc.Format = DXGI_FORMAT_R32_TYPELESS;
uavDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / 4)
- : desc.bufferRange.elementCount / 4;
+ : UINT(desc.bufferRange.elementCount / 4);
uavDesc.Buffer.Flags |= D3D12_BUFFER_UAV_FLAG_RAW;
}
else
@@ -6763,16 +6879,16 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des
uavDesc.Buffer.NumElements =
desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes)
- : desc.bufferRange.elementCount;
+ : (UINT)desc.bufferRange.elementCount;
}
-
-
- // TODO: need to support the separate "counter resource" for the case
- // of append/consume buffers with attached counters.
-
+ auto counterResourceImpl = static_cast<BufferResourceImpl*>(counterBuffer);
SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor));
viewImpl->m_allocator = m_cpuViewHeap;
- m_device->CreateUnorderedAccessView(resourceImpl->m_resource, nullptr, &uavDesc, viewImpl->m_descriptor.cpuHandle);
+ m_device->CreateUnorderedAccessView(
+ resourceImpl->m_resource,
+ counterResourceImpl ? counterResourceImpl->m_resource.getResource() : nullptr,
+ &uavDesc,
+ viewImpl->m_descriptor.cpuHandle);
}
break;
@@ -6790,14 +6906,14 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des
srvDesc.Buffer.NumElements =
desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / desc.bufferElementSize)
- : desc.bufferRange.elementCount;
+ : (UINT)desc.bufferRange.elementCount;
}
else if (desc.format == Format::Unknown)
{
srvDesc.Format = DXGI_FORMAT_R32_TYPELESS;
srvDesc.Buffer.NumElements = desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / 4)
- : desc.bufferRange.elementCount / 4;
+ : UINT(desc.bufferRange.elementCount / 4);
srvDesc.Buffer.Flags |= D3D12_BUFFER_SRV_FLAG_RAW;
}
else
@@ -6808,7 +6924,7 @@ Result D3D12Device::createBufferView(IBufferResource* buffer, IResourceView::Des
srvDesc.Buffer.NumElements =
desc.bufferRange.elementCount == 0
? UINT(resourceDesc.sizeInBytes / sizeInfo.blockSizeInBytes)
- : desc.bufferRange.elementCount;
+ : (UINT)desc.bufferRange.elementCount;
}
SLANG_RETURN_ON_FAIL(m_cpuViewHeap->allocate(&viewImpl->m_descriptor));
@@ -7021,12 +7137,12 @@ Result D3D12Device::readBufferResource(
Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgram** outProgram, ISlangBlob** outDiagnosticBlob)
{
RefPtr<ShaderProgramImpl> shaderProgram = new ShaderProgramImpl();
- shaderProgram->slangProgram = desc.slangProgram;
+ shaderProgram->init(desc);
ComPtr<ID3DBlob> d3dDiagnosticBlob;
auto rootShaderLayoutResult = RootShaderObjectLayoutImpl::create(
this,
- desc.slangProgram,
- desc.slangProgram->getLayout(),
+ shaderProgram->linkedProgram,
+ shaderProgram->linkedProgram->getLayout(),
shaderProgram->m_rootObjectLayout.writeRef(),
d3dDiagnosticBlob.writeRef());
if (!SLANG_SUCCEEDED(rootShaderLayoutResult))
@@ -7039,22 +7155,22 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr
}
return rootShaderLayoutResult;
}
- if (desc.slangProgram->getSpecializationParamCount() != 0)
+ if (shaderProgram->isSpecializable())
{
// For a specializable program, we don't invoke any actual slang compilation yet.
returnComPtr(outProgram, shaderProgram);
return SLANG_OK;
}
// For a fully specialized program, read and store its kernel code in `shaderProgram`.
- auto programReflection = desc.slangProgram->getLayout();
- for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
+ auto compileShader = [&](slang::EntryPointReflection* entryPointInfo,
+ slang::IComponentType* entryPointComponent,
+ SlangInt entryPointIndex)
{
- auto entryPointInfo = programReflection->getEntryPointByIndex(i);
auto stage = entryPointInfo->getStage();
ComPtr<ISlangBlob> kernelCode;
ComPtr<ISlangBlob> diagnostics;
- auto compileResult = desc.slangProgram->getEntryPointCode(
- (SlangInt)i, 0, kernelCode.writeRef(), diagnostics.writeRef());
+ auto compileResult = entryPointComponent->getEntryPointCode(
+ entryPointIndex, 0, kernelCode.writeRef(), diagnostics.writeRef());
if (diagnostics)
{
getDebugCallback()->handleMessage(
@@ -7067,10 +7183,35 @@ Result D3D12Device::createProgram(const IShaderProgram::Desc& desc, IShaderProgr
SLANG_RETURN_ON_FAIL(compileResult);
ShaderBinary shaderBin;
shaderBin.stage = stage;
+ shaderBin.entryPointInfo = entryPointInfo;
shaderBin.code.addRange(
reinterpret_cast<const uint8_t*>(kernelCode->getBufferPointer()),
(Index)kernelCode->getBufferSize());
shaderProgram->m_shaders.add(_Move(shaderBin));
+ return SLANG_OK;
+ };
+
+ if (shaderProgram->linkedEntryPoints.getCount() == 0)
+ {
+ // If the user does not explicitly specify entry point components, find them from `linkedEntryPoints`.
+ auto programReflection = shaderProgram->linkedProgram->getLayout();
+ for (SlangUInt i = 0; i < programReflection->getEntryPointCount(); i++)
+ {
+ SLANG_RETURN_ON_FAIL(compileShader(
+ programReflection->getEntryPointByIndex(i),
+ shaderProgram->linkedProgram,
+ (SlangInt)i));
+ }
+ }
+ else
+ {
+ // If the user specifies entry point components via the separated entry point array, compile code
+ // from there.
+ for (auto& entryPoint : shaderProgram->linkedEntryPoints)
+ {
+ SLANG_RETURN_ON_FAIL(
+ compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
+ }
}
returnComPtr(outProgram, shaderProgram);
return SLANG_OK;
@@ -7150,6 +7291,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
D3D12_GRAPHICS_PIPELINE_STATE_DESC psoDesc = {};
psoDesc.pRootSignature = programImpl->m_rootObjectLayout->m_rootSignature;
+
for (auto& shaderBin : programImpl->m_shaders)
{
switch (shaderBin.stage)
@@ -7283,7 +7425,7 @@ Result D3D12Device::createGraphicsPipelineState(const GraphicsPipelineStateDesc&
if (m_pipelineCreationAPIDispatcher)
{
SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createGraphicsPipelineState(
- this, programImpl->slangProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
+ this, programImpl->linkedProgram.get(), &psoDesc, (void**)pipelineState.writeRef()));
}
else
{
@@ -7312,7 +7454,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
// Only actually create a D3D12 pipeline state if the pipeline is fully specialized.
ComPtr<ID3D12PipelineState> pipelineState;
- if (!programImpl->slangProgram || programImpl->slangProgram->getSpecializationParamCount() == 0)
+ if (!programImpl->isSpecializable())
{
// Describe and create the compute pipeline state object
D3D12_COMPUTE_PIPELINE_STATE_DESC computeDesc = {};
@@ -7361,7 +7503,7 @@ Result D3D12Device::createComputePipelineState(const ComputePipelineStateDesc& i
{
SLANG_RETURN_ON_FAIL(m_pipelineCreationAPIDispatcher->createComputePipelineState(
this,
- programImpl->slangProgram.get(),
+ programImpl->linkedProgram.get(),
&computeDesc,
(void**)pipelineState.writeRef()));
}
@@ -7835,8 +7977,8 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
pipelineStateImpl->init(inDesc);
auto program = static_cast<ShaderProgramImpl*>(inDesc.program);
- auto slangProgram = program->slangProgram;
- auto programLayout = slangProgram->getLayout();
+ auto slangGlobalScope = program->linkedProgram;
+ auto programLayout = slangGlobalScope->getLayout();
if (!program->m_rootObjectLayout->m_rootSignature)
{
@@ -7847,13 +7989,24 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
ChunkedList<D3D12_DXIL_LIBRARY_DESC> dxilLibraries;
ChunkedList<D3D12_HIT_GROUP_DESC> hitGroups;
ChunkedList<ComPtr<ISlangBlob>> codeBlobs;
+ ChunkedList<D3D12_EXPORT_DESC> exports;
+ ChunkedList<const wchar_t*> strPtrs;
+
ComPtr<ISlangBlob> diagnostics;
ChunkedList<OSString> stringPool;
- for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ auto getWStr = [&](const char* name)
+ {
+ String str = String(name);
+ auto wstr = str.toWString();
+ return stringPool.add(wstr)->begin();
+ };
+ auto compileShader = [&](slang::EntryPointLayout* entryPointInfo,
+ slang::IComponentType* component,
+ SlangInt entryPointIndex)
{
ComPtr<ISlangBlob> codeBlob;
- auto compileResult =
- slangProgram->getEntryPointCode(i, 0, codeBlob.writeRef(), diagnostics.writeRef());
+ auto compileResult = component->getEntryPointCode(
+ entryPointIndex, 0, codeBlob.writeRef(), diagnostics.writeRef());
if (diagnostics.get())
{
getDebugCallback()->handleMessage(
@@ -7864,20 +8017,38 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
SLANG_RETURN_ON_FAIL(compileResult);
codeBlobs.add(codeBlob);
D3D12_DXIL_LIBRARY_DESC library = {};
- library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();;
+ library.DXILLibrary.BytecodeLength = codeBlob->getBufferSize();
library.DXILLibrary.pShaderBytecode = codeBlob->getBufferPointer();
+ library.NumExports = 1;
+ D3D12_EXPORT_DESC exportDesc = {};
+ exportDesc.Name = getWStr(entryPointInfo->getNameOverride());
+ exportDesc.ExportToRename = getWStr(entryPointInfo->getNameOverride());
+ exportDesc.Flags = D3D12_EXPORT_FLAG_NONE;
+ library.pExports = exports.add(exportDesc);
D3D12_STATE_SUBOBJECT dxilSubObject = {};
dxilSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_DXIL_LIBRARY;
dxilSubObject.pDesc = dxilLibraries.add(library);
subObjects.add(dxilSubObject);
+ return SLANG_OK;
+ };
+ if (program->linkedEntryPoints.getCount() == 0)
+ {
+ for (SlangUInt i = 0; i < programLayout->getEntryPointCount(); i++)
+ {
+ SLANG_RETURN_ON_FAIL(compileShader(
+ programLayout->getEntryPointByIndex(i), program->linkedProgram, (SlangInt)i));
+ }
}
- auto getWStr = [&](const char* name)
+ else
{
- String str = String(name);
- auto wstr = str.toWString();
- return stringPool.add(wstr)->begin();
- };
+ for (auto& entryPoint : program->linkedEntryPoints)
+ {
+ SLANG_RETURN_ON_FAIL(
+ compileShader(entryPoint->getLayout()->getEntryPointByIndex(0), entryPoint, 0));
+ }
+ }
+
for (int i = 0; i < inDesc.hitGroupCount; i++)
{
auto hitGroup = inDesc.hitGroups[i];
@@ -7909,7 +8080,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
D3D12_RAYTRACING_SHADER_CONFIG shaderConfig = {};
// According to DXR spec, fixed function triangle intersections must use float2 as ray attributes
// that defines the barycentric coordinates at intersection.
- shaderConfig.MaxAttributeSizeInBytes = sizeof(float) * 2;
+ shaderConfig.MaxAttributeSizeInBytes = inDesc.maxAttributeSizeInBytes;
shaderConfig.MaxPayloadSizeInBytes = inDesc.maxRayPayloadSize;
D3D12_STATE_SUBOBJECT shaderConfigSubObject = {};
shaderConfigSubObject.Type = D3D12_STATE_SUBOBJECT_TYPE_RAYTRACING_SHADER_CONFIG;
@@ -7932,7 +8103,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
if (m_pipelineCreationAPIDispatcher)
{
- m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangProgram);
+ m_pipelineCreationAPIDispatcher->beforeCreateRayTracingState(this, slangGlobalScope);
}
D3D12_STATE_OBJECT_DESC rtpsoDesc = {};
@@ -7943,7 +8114,7 @@ Result D3D12Device::createRayTracingPipelineState(const RayTracingPipelineStateD
if (m_pipelineCreationAPIDispatcher)
{
- m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangProgram);
+ m_pipelineCreationAPIDispatcher->afterCreateRayTracingState(this, slangGlobalScope);
}
returnComPtr(outState, pipelineStateImpl);