summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-02-10 12:39:55 -0800
committerGitHub <noreply@github.com>2022-02-10 12:39:55 -0800
commit120f97fb8d4e22b057cea43b503611f8292ade37 (patch)
treee1f7bae615b499425702f7e82bc556a312c7515c
parent0c04885da9edc3df7a1ef5cb520be1bd29eb13e4 (diff)
gfx: support shader record overwrite and fix QueryPool. (#2123)
* Various fixes to gfx. * Fix. * Fixes. * Fix. * gfx: support root parameter via user-defined attribute. * Fix. * Fix. * Skip d3d12 tests on win x86. * Fixes. * gfx: support shader record overwrite. * Fix QueyPool implementation. * Rename to `getBindingRangeLeafVariable` Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--slang-gfx.h11
-rw-r--r--slang.h6
-rw-r--r--source/slang/slang-reflection-api.cpp10
-rw-r--r--source/slang/slang-type-layout.h2
-rw-r--r--tools/gfx/d3d12/render-d3d12.cpp29
-rw-r--r--tools/gfx/renderer-shared.cpp25
-rw-r--r--tools/gfx/renderer-shared.h2
7 files changed, 68 insertions, 17 deletions
diff --git a/slang-gfx.h b/slang-gfx.h
index a9dfe7b28..4b40eab30 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -1318,16 +1318,27 @@ struct RayTracingPipelineStateDesc
class IShaderTable : public ISlangUnknown
{
public:
+ // Specifies the bytes to overwrite into a record in the shader table.
+ struct ShaderRecordOverwrite
+ {
+ uint32_t offset; // Offset within the shader record.
+ uint32_t size; // Number of bytes to overwrite.
+ uint8_t data[8]; // Content to overwrite.
+ };
+
struct Desc
{
uint32_t rayGenShaderCount;
const char** rayGenShaderEntryPointNames;
+ const ShaderRecordOverwrite* rayGenShaderRecordOverwrites;
uint32_t missShaderCount;
const char** missShaderEntryPointNames;
+ const ShaderRecordOverwrite* missShaderRecordOverwrites;
uint32_t hitGroupCount;
const char** hitGroupNames;
+ const ShaderRecordOverwrite* hitGroupRecordOverwrites;
IShaderProgram* program;
};
diff --git a/slang.h b/slang.h
index b3169e4b0..9872180fb 100644
--- a/slang.h
+++ b/slang.h
@@ -2043,7 +2043,7 @@ extern "C"
SLANG_API SlangBindingType spReflectionTypeLayout_getBindingRangeType(SlangReflectionTypeLayout* typeLayout, SlangInt index);
SLANG_API SlangInt spReflectionTypeLayout_getBindingRangeBindingCount(SlangReflectionTypeLayout* typeLayout, SlangInt index);
SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getBindingRangeLeafTypeLayout(SlangReflectionTypeLayout* typeLayout, SlangInt index);
- SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeVariable(SlangReflectionTypeLayout* typeLayout, SlangInt index);
+ SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeLeafVariable(SlangReflectionTypeLayout* typeLayout, SlangInt index);
SLANG_API SlangInt spReflectionTypeLayout_getFieldBindingRangeOffset(SlangReflectionTypeLayout* typeLayout, SlangInt fieldIndex);
SLANG_API SlangInt spReflectionTypeLayout_getBindingRangeDescriptorSetIndex(SlangReflectionTypeLayout* typeLayout, SlangInt index);
@@ -2660,9 +2660,9 @@ namespace slang
index);
}
- VariableReflection* getBindingRangeVariable(SlangInt index)
+ VariableReflection* getBindingRangeLeafVariable(SlangInt index)
{
- return (VariableReflection*)spReflectionTypeLayout_getBindingRangeVariable(
+ return (VariableReflection*)spReflectionTypeLayout_getBindingRangeLeafVariable(
(SlangReflectionTypeLayout*)this, index);
}
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 8bcab8ada..8919e2ba5 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -1559,7 +1559,7 @@ namespace Slang
TypeLayout::ExtendedInfo::BindingRangeInfo bindingRange;
bindingRange.leafTypeLayout = typeLayout;
- bindingRange.variable = path.primary ? path.primary->var->getVariable() : nullptr;
+ bindingRange.leafVariable = path.primary ? path.primary->var->getVariable() : nullptr;
bindingRange.bindingType = bindingType;
bindingRange.count = multiplier;
bindingRange.descriptorSetIndex = -1;
@@ -1746,7 +1746,7 @@ namespace Slang
//
TypeLayout::ExtendedInfo::BindingRangeInfo bindingRange;
bindingRange.leafTypeLayout = typeLayout;
- bindingRange.variable = path.primary ? path.primary->var->getVariable() : nullptr;
+ bindingRange.leafVariable = path.primary ? path.primary->var->getVariable() : nullptr;
bindingRange.bindingType = SLANG_BINDING_TYPE_EXISTENTIAL_VALUE;
bindingRange.count = multiplier;
bindingRange.descriptorSetIndex = 0;
@@ -1819,7 +1819,7 @@ namespace Slang
//
TypeLayout::ExtendedInfo::BindingRangeInfo bindingRange;
bindingRange.leafTypeLayout = typeLayout;
- bindingRange.variable = path.primary ? path.primary->var->getVariable() : nullptr;
+ bindingRange.leafVariable = path.primary ? path.primary->var->getVariable() : nullptr;
bindingRange.bindingType = bindingType;
bindingRange.count = multiplier;
bindingRange.descriptorSetIndex = 0;
@@ -2019,7 +2019,7 @@ SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getBindingRangeLeafT
return convert(bindingRange.leafTypeLayout);
}
-SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeVariable(
+SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeLeafVariable(
SlangReflectionTypeLayout* inTypeLayout, SlangInt index)
{
auto typeLayout = convert(inTypeLayout);
@@ -2033,7 +2033,7 @@ SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeVariabl
return 0;
auto& bindingRange = extTypeLayout->m_bindingRanges[index];
- return convert(bindingRange.variable);
+ return convert(bindingRange.leafVariable);
}
diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h
index d66a77365..6e28b6c9d 100644
--- a/source/slang/slang-type-layout.h
+++ b/source/slang/slang-type-layout.h
@@ -431,7 +431,7 @@ public:
struct BindingRangeInfo
{
- VarDeclBase* variable;
+ VarDeclBase* leafVariable;
TypeLayout* leafTypeLayout;
SlangBindingType bindingType;
LayoutSize count;
diff --git a/tools/gfx/d3d12/render-d3d12.cpp b/tools/gfx/d3d12/render-d3d12.cpp
index 72474f549..0539b8111 100644
--- a/tools/gfx/d3d12/render-d3d12.cpp
+++ b/tools/gfx/d3d12/render-d3d12.cpp
@@ -519,7 +519,13 @@ public:
virtual SLANG_NO_THROW Result SLANG_MCALL getResult(SlangInt queryIndex, SlangInt count, uint64_t* data) override
{
m_commandList->Reset(m_commandAllocator, nullptr);
- m_commandList->ResolveQueryData(m_queryHeap, m_queryType, (UINT)queryIndex, (UINT)count, m_readBackBuffer, 0);
+ m_commandList->ResolveQueryData(
+ m_queryHeap,
+ m_queryType,
+ (UINT)queryIndex,
+ (UINT)count,
+ m_readBackBuffer,
+ sizeof(uint64_t) * queryIndex);
m_commandList->Close();
ID3D12CommandList* cmdList = m_commandList;
m_commandQueue->ExecuteCommandLists(1, &cmdList);
@@ -529,7 +535,7 @@ public:
WaitForSingleObject(m_waitEvent, INFINITE);
int8_t* mappedData = nullptr;
- D3D12_RANGE readRange = { sizeof(uint64_t) * queryIndex,sizeof(uint64_t) * (queryIndex + count) };
+ D3D12_RANGE readRange = { sizeof(uint64_t) * queryIndex, sizeof(uint64_t) * (queryIndex + count) };
m_readBackBuffer.getResource()->Map(0, &readRange, (void**)&mappedData);
memcpy(data, mappedData + sizeof(uint64_t) * queryIndex, sizeof(uint64_t) * count);
m_readBackBuffer.getResource()->Unmap(0, nullptr);
@@ -1228,9 +1234,9 @@ public:
bool isRootParameter = false;
if (rootParameterAttributeName)
{
- if (auto variable = typeLayout->getBindingRangeVariable(bindingRangeIndex))
+ if (auto leafVariable = typeLayout->getBindingRangeLeafVariable(bindingRangeIndex))
{
- if (variable->findUserAttributeByName(
+ if (leafVariable->findUserAttributeByName(
globalSession, rootParameterAttributeName))
{
isRootParameter = true;
@@ -3470,7 +3476,7 @@ public:
void* stagingPtr = nullptr;
stagingBuffer->map(nullptr, &stagingPtr);
- auto copyShaderIdInto = [&](void* dest, String& name)
+ auto copyShaderIdInto = [&](void* dest, String& name, const ShaderRecordOverwrite& overwrite)
{
if (name.getLength())
{
@@ -3481,6 +3487,10 @@ public:
{
memset(dest, 0, D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES);
}
+ if (overwrite.size)
+ {
+ memcpy((uint8_t*)dest + overwrite.offset, overwrite.data, overwrite.size);
+ }
};
uint8_t* stagingBufferPtr = (uint8_t*)stagingPtr;
@@ -3489,21 +3499,24 @@ public:
copyShaderIdInto(
stagingBufferPtr + m_rayGenTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[i]);
+ m_entryPointNames[i],
+ m_recordOverwrites[i]);
}
for (uint32_t i = 0; i < m_missShaderCount; i++)
{
copyShaderIdInto(
stagingBufferPtr + m_missTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[m_rayGenShaderCount + i]);
+ m_entryPointNames[m_rayGenShaderCount + i],
+ m_recordOverwrites[m_rayGenShaderCount + i]);
}
for (uint32_t i = 0; i < m_hitGroupCount; i++)
{
copyShaderIdInto(
stagingBufferPtr + m_hitGroupTableOffset +
D3D12_SHADER_IDENTIFIER_SIZE_IN_BYTES * i,
- m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i]);
+ m_entryPointNames[m_rayGenShaderCount + m_missShaderCount + i],
+ m_recordOverwrites[m_rayGenShaderCount + m_missShaderCount + i]);
}
stagingBuffer->unmap(nullptr);
diff --git a/tools/gfx/renderer-shared.cpp b/tools/gfx/renderer-shared.cpp
index f9ffe8dfd..a97462c73 100644
--- a/tools/gfx/renderer-shared.cpp
+++ b/tools/gfx/renderer-shared.cpp
@@ -957,17 +957,42 @@ Result ShaderTableBase::init(const IShaderTable::Desc& desc)
m_missShaderCount = desc.missShaderCount;
m_hitGroupCount = desc.hitGroupCount;
m_entryPointNames.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
+ m_recordOverwrites.reserve(desc.hitGroupCount + desc.missShaderCount + desc.rayGenShaderCount);
for (uint32_t i = 0; i < desc.rayGenShaderCount; i++)
{
m_entryPointNames.add(desc.rayGenShaderEntryPointNames[i]);
+ if (desc.rayGenShaderRecordOverwrites)
+ {
+ m_recordOverwrites.add(desc.rayGenShaderRecordOverwrites[i]);
+ }
+ else
+ {
+ m_recordOverwrites.add(ShaderRecordOverwrite{});
+ }
}
for (uint32_t i = 0; i < desc.missShaderCount; i++)
{
m_entryPointNames.add(desc.missShaderEntryPointNames[i]);
+ if (desc.missShaderRecordOverwrites)
+ {
+ m_recordOverwrites.add(desc.missShaderRecordOverwrites[i]);
+ }
+ else
+ {
+ m_recordOverwrites.add(ShaderRecordOverwrite{});
+ }
}
for (uint32_t i = 0; i < desc.hitGroupCount; i++)
{
m_entryPointNames.add(desc.hitGroupNames[i]);
+ if (desc.hitGroupRecordOverwrites)
+ {
+ m_recordOverwrites.add(desc.hitGroupRecordOverwrites[i]);
+ }
+ else
+ {
+ m_recordOverwrites.add(ShaderRecordOverwrite{});
+ }
}
return SLANG_OK;
}
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index e6f54ae10..5ba6bbc50 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -1234,6 +1234,8 @@ class ShaderTableBase
{
public:
Slang::List<Slang::String> m_entryPointNames;
+ Slang::List<ShaderRecordOverwrite> m_recordOverwrites;
+
uint32_t m_rayGenShaderCount;
uint32_t m_missShaderCount;
uint32_t m_hitGroupCount;