summaryrefslogtreecommitdiffstats
path: root/tools
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2021-05-25 15:22:39 -0700
committerGitHub <noreply@github.com>2021-05-25 15:22:39 -0700
commit89f67d9c626fa193dba4adafcb54e46b13aa5e98 (patch)
tree769e11debb4194595a99e484d69af7b3704389c3 /tools
parentba24264275c640e0ac3732f0f5720e1f5816cded (diff)
Rework shader object specialization control interface. (#1857)
Diffstat (limited to 'tools')
-rw-r--r--tools/gfx-util/shader-cursor.h5
-rw-r--r--tools/gfx/debug-layer.cpp21
-rw-r--r--tools/gfx/debug-layer.h7
-rw-r--r--tools/gfx/renderer-shared.h140
-rw-r--r--tools/render-test/render-test-main.cpp38
-rw-r--r--tools/render-test/shader-input-layout.cpp57
-rw-r--r--tools/render-test/shader-input-layout.h12
7 files changed, 162 insertions, 118 deletions
diff --git a/tools/gfx-util/shader-cursor.h b/tools/gfx-util/shader-cursor.h
index 7512c62ee..41d0b3945 100644
--- a/tools/gfx-util/shader-cursor.h
+++ b/tools/gfx-util/shader-cursor.h
@@ -92,6 +92,11 @@ struct ShaderCursor
return m_baseObject->setObject(m_offset, object);
}
+ SlangResult setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) const
+ {
+ return m_baseObject->setSpecializationArgs(m_offset, args, count);
+ }
+
SlangResult setResource(IResourceView* resourceView) const
{
return m_baseObject->setResource(m_offset, resourceView);
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 3fa2eee9d..26e55ca7e 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -955,23 +955,12 @@ Result DebugShaderObject::setCombinedTextureSampler(
}
Result DebugShaderObject::setSpecializationArgs(
+ ShaderOffset const& offset,
const slang::SpecializationArg* args,
uint32_t count)
{
- ComPtr<slang::ISession> session;
- m_device->getSlangSession(session.writeRef());
- auto expectedCount = (uint32_t)session->getTypeLayout(m_slangType)
- ->getSize(SLANG_PARAMETER_CATEGORY_EXISTENTIAL_TYPE_PARAM);
- if (expectedCount != count)
- {
- GFX_DIAGNOSE_ERROR_FORMAT(
- "specialization argument count for shader object type %s mismatch: expecting %d but %d "
- "provided.",
- m_typeName.getBuffer(),
- expectedCount,
- count);
- };
- return baseObject->setSpecializationArgs(args, count);
+
+ return baseObject->setSpecializationArgs(offset, args, count);
}
DebugObjectBase::DebugObjectBase()
@@ -981,11 +970,11 @@ DebugObjectBase::DebugObjectBase()
}
Result DebugRootShaderObject::setSpecializationArgs(
+ ShaderOffset const& offset,
const slang::SpecializationArg* args,
uint32_t count)
{
- GFX_DIAGNOSE_ERROR("`setSpecializationArgs` should not be called directly on root objects.");
- return baseObject->setSpecializationArgs(args, count);
+ return baseObject->setSpecializationArgs(offset, args, count);
}
} // namespace gfx
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index f9ecc7dfe..89ee9d837 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -168,6 +168,7 @@ public:
IResourceView* textureView,
ISamplerState* sampler) override;
virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs(
+ ShaderOffset const& offset,
const slang::SpecializationArg* args,
uint32_t count) override;
@@ -204,8 +205,10 @@ class DebugRootShaderObject : public DebugShaderObject
public:
virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; }
virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; }
- virtual SLANG_NO_THROW Result SLANG_MCALL
- setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs(
+ ShaderOffset const& offset,
+ const slang::SpecializationArg* args,
+ uint32_t count) override;
};
class DebugCommandBuffer;
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index 83799ac25..0dc0f75ae 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -473,10 +473,10 @@ class ShaderObjectBaseImpl : public ShaderObjectBase
protected:
TShaderObjectData m_data;
Slang::List<Slang::RefPtr<TShaderObjectImpl>> m_objects;
+ Slang::List<Slang::RefPtr<ExtendedShaderObjectTypeListObject>> m_userProvidedSpecializationArgs;
// Specialization args for a StructuredBuffer object.
ExtendedShaderObjectTypeList m_structuredBufferSpecializationArgs;
- Slang::RefPtr<ExtendedShaderObjectTypeListObject> m_userProvidedSpecializationArgs;
public:
TShaderObjectLayoutImpl* getLayout()
@@ -502,6 +502,39 @@ public:
return SLANG_OK;
}
+ void setSpecializationArgsForContainerElement(ExtendedShaderObjectTypeList& specializationArgs)
+ {
+ // Compute specialization args for the structured buffer object.
+ // If we haven't filled anything to `m_structuredBufferSpecializationArgs` yet,
+ // use `specializationArgs` directly.
+ if (m_structuredBufferSpecializationArgs.getCount() == 0)
+ {
+ m_structuredBufferSpecializationArgs = Slang::_Move(specializationArgs);
+ }
+ else
+ {
+ // If `m_structuredBufferSpecializationArgs` already contains some arguments, we
+ // need to check if they are the same as `specializationArgs`, and replace
+ // anything that is different with `__Dynamic` because we cannot specialize the
+ // buffer type if the element types are not the same.
+ SLANG_ASSERT(
+ m_structuredBufferSpecializationArgs.getCount() == specializationArgs.getCount());
+ auto device = getRenderer();
+ for (Slang::Index i = 0; i < m_structuredBufferSpecializationArgs.getCount(); i++)
+ {
+ if (m_structuredBufferSpecializationArgs[i].componentID !=
+ specializationArgs[i].componentID)
+ {
+ auto dynamicType = device->slangContext.session->getDynamicType();
+ m_structuredBufferSpecializationArgs.componentIDs[i] =
+ device->shaderCache.getComponentId(dynamicType);
+ m_structuredBufferSpecializationArgs.components[i] =
+ slang::SpecializationArg::fromType(dynamicType);
+ }
+ }
+ }
+ }
+
virtual SLANG_NO_THROW Result SLANG_MCALL
setObject(ShaderOffset const& offset, IShaderObject* object) SLANG_OVERRIDE
{
@@ -563,36 +596,7 @@ public:
subObject->m_data.getBuffer(),
(size_t)subObject->m_data.getCount()));
- // Compute specialization args for the structured buffer object.
- // If we haven't filled anything to `m_structuredBufferSpecializationArgs` yet,
- // use `specializationArgs` directly.
- if (m_structuredBufferSpecializationArgs.getCount() == 0)
- {
- m_structuredBufferSpecializationArgs = Slang::_Move(specializationArgs);
- }
- else
- {
- // If `m_structuredBufferSpecializationArgs` already contains some arguments, we
- // need to check if they are the same as `specializationArgs`, and replace
- // anything that is different with `__Dynamic` because we cannot specialize the
- // buffer type if the element types are not the same.
- SLANG_ASSERT(
- m_structuredBufferSpecializationArgs.getCount() ==
- specializationArgs.getCount());
- auto device = getRenderer();
- for (Slang::Index i = 0; i < m_structuredBufferSpecializationArgs.getCount(); i++)
- {
- if (m_structuredBufferSpecializationArgs[i].componentID !=
- specializationArgs[i].componentID)
- {
- auto dynamicType = device->slangContext.session->getDynamicType();
- m_structuredBufferSpecializationArgs.componentIDs[i] =
- device->shaderCache.getComponentId(dynamicType);
- m_structuredBufferSpecializationArgs.components[i] =
- slang::SpecializationArg::fromType(dynamicType);
- }
- }
- }
+ setSpecializationArgsForContainerElement(specializationArgs);
return SLANG_OK;
}
@@ -693,17 +697,11 @@ public:
return SLANG_OK;
}
- virtual SLANG_NO_THROW Result SLANG_MCALL
- setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override
+ Result getExtendedShaderTypeListFromSpecializationArgs(
+ ExtendedShaderObjectTypeList& list,
+ const slang::SpecializationArg* args,
+ uint32_t count)
{
- if (!m_userProvidedSpecializationArgs)
- {
- m_userProvidedSpecializationArgs = new ExtendedShaderObjectTypeListObject();
- }
- else
- {
- m_userProvidedSpecializationArgs->clear();
- }
auto device = getRenderer();
for (uint32_t i = 0; i < count; i++)
{
@@ -718,20 +716,57 @@ public:
SLANG_ASSERT(false && "Unexpected specialization argument kind.");
return SLANG_FAIL;
}
- m_userProvidedSpecializationArgs->add(extendedType);
+ list.add(extendedType);
}
return SLANG_OK;
}
- // Appends all types that are used to specialize the element type of this shader object in
- // `args` list.
- virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs(
+ ShaderOffset const& offset,
+ const slang::SpecializationArg* args,
+ uint32_t count) override
{
- if (m_userProvidedSpecializationArgs)
+ auto layout = getLayout();
+
+ // If the shader object is a container, delegate the processing to
+ // `setSpecializationArgsForContainerElements`.
+ if (layout->getContainerType() != ShaderObjectContainerType::None)
{
- args.addRange(*m_userProvidedSpecializationArgs);
+ ExtendedShaderObjectTypeList argList;
+ SLANG_RETURN_ON_FAIL(
+ getExtendedShaderTypeListFromSpecializationArgs(argList, args, count));
+ setSpecializationArgsForContainerElement(argList);
return SLANG_OK;
}
+
+ if (offset.bindingRangeIndex < 0)
+ return SLANG_E_INVALID_ARG;
+ if (offset.bindingRangeIndex >= layout->getBindingRangeCount())
+ return SLANG_E_INVALID_ARG;
+
+ auto bindingRangeIndex = offset.bindingRangeIndex;
+ auto bindingRange = layout->getBindingRange(bindingRangeIndex);
+ auto objectIndex = bindingRange.subObjectIndex + offset.bindingArrayIndex;
+ if (objectIndex >= m_userProvidedSpecializationArgs.getCount())
+ m_userProvidedSpecializationArgs.setCount(objectIndex + 1);
+ if (!m_userProvidedSpecializationArgs[objectIndex])
+ {
+ m_userProvidedSpecializationArgs[objectIndex] =
+ new ExtendedShaderObjectTypeListObject();
+ }
+ else
+ {
+ m_userProvidedSpecializationArgs[objectIndex]->clear();
+ }
+ SLANG_RETURN_ON_FAIL(getExtendedShaderTypeListFromSpecializationArgs(
+ *m_userProvidedSpecializationArgs[objectIndex], args, count));
+ return SLANG_OK;
+ }
+
+ // Appends all types that are used to specialize the element type of this shader object in
+ // `args` list.
+ virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
+ {
if (m_layout->getContainerType() != ShaderObjectContainerType::None)
{
args.addRange(m_structuredBufferSpecializationArgs);
@@ -761,12 +796,19 @@ public:
subObjectIndexInRange++)
{
ExtendedShaderObjectTypeList typeArgs;
-
- auto subObject = m_objects[bindingRange.subObjectIndex + subObjectIndexInRange];
+ auto objectIndex = bindingRange.subObjectIndex + subObjectIndexInRange;
+ auto subObject = m_objects[objectIndex];
if (!subObject)
continue;
+ if (objectIndex < m_userProvidedSpecializationArgs.getCount() &&
+ m_userProvidedSpecializationArgs[objectIndex])
+ {
+ args.addRange(*m_userProvidedSpecializationArgs[objectIndex]);
+ continue;
+ }
+
switch (bindingRange.bindingType)
{
case slang::BindingType::ExistentialValue:
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index 5c8011889..ab73af5b5 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -340,29 +340,27 @@ struct AssignValsFromLayoutContext
ComPtr<IShaderObject> shaderObject = device->createShaderObject(slangType);
SLANG_RETURN_ON_FAIL(assign(ShaderCursor(shaderObject), srcVal->contentVal));
+ dstCursor.setObject(shaderObject);
+ return SLANG_OK;
+ }
- if (srcVal->specializationArgs.getCount())
+ SlangResult assignValWithSpecializationArg(
+ ShaderCursor const& dstCursor,
+ ShaderInputLayout::SpecializeVal* srcVal)
+ {
+ assign(dstCursor, srcVal->contentVal);
+ List<slang::SpecializationArg> args;
+ for (auto& typeName : srcVal->typeArgs)
{
- List<slang::SpecializationArg> args;
- for (auto srcArg : srcVal->specializationArgs)
+ auto slangType = slangReflection->findTypeByName(typeName.getBuffer());
+ if (!slangType)
{
- auto argType = slangReflection->findTypeByName(srcArg.getBuffer());
- if (argType)
- {
- slang::SpecializationArg arg = slang::SpecializationArg::fromType(argType);
- args.add(arg);
- }
- else
- {
- StdWriters::getError().print(
- "error: could not find shader type '%s'\n", srcArg.getBuffer());
- return SLANG_E_INVALID_ARG;
- }
+ StdWriters::getError().print("error: could not find shader type '%s'\n", typeName.getBuffer());
+ return SLANG_E_INVALID_ARG;
}
- shaderObject->setSpecializationArgs(args.getBuffer(), args.getCount());
+ args.add(slang::SpecializationArg::fromType(slangType));
}
- dstCursor.setObject(shaderObject);
- return SLANG_OK;
+ return dstCursor.setSpecializationArgs(args.getBuffer(), (uint32_t)args.getCount());
}
SlangResult assignArray(ShaderCursor const& dstCursor, ShaderInputLayout::ArrayVal* srcVal)
@@ -399,6 +397,10 @@ struct AssignValsFromLayoutContext
case ShaderInputType::Object:
return assignObject(dstCursor, (ShaderInputLayout::ObjectVal*) srcVal.Ptr());
+ case ShaderInputType::Specialize:
+ return assignValWithSpecializationArg(
+ dstCursor, (ShaderInputLayout::SpecializeVal*)srcVal.Ptr());
+
case ShaderInputType::Aggregate:
return assignAggregate(dstCursor, (ShaderInputLayout::AggVal*) srcVal.Ptr());
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index af15d5f9f..3ab0366a5 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -346,37 +346,6 @@ namespace renderer_test
return SLANG_OK;
}
- SlangResult parseObjectAttributes(ShaderInputLayout::ObjectVal* val, Misc::TokenReader& parser)
- {
- if (parser.AdvanceIf(":"))
- {
- while (!parser.IsEnd() && parser.NextToken().Type == Misc::TokenType::Identifier)
- {
- if (parser.AdvanceIf("specialization_args"))
- {
- parser.Read(Misc::TokenType::LParent);
- while (!parser.IsEnd() &&
- parser.NextToken().Type != Misc::TokenType::RParent)
- {
- val->specializationArgs.add(parseTypeName(parser));
- if (!parser.AdvanceIf(","))
- break;
- }
- parser.Read(Misc::TokenType::RParent);
- }
- else
- {
- throw ShaderInputLayoutFormatException(
- StringBuilder() << "Unknown attribute \'" << parser.NextToken().Content << "\' ("
- << parser.NextToken().Position.Line << ")");
-
- return SLANG_FAIL;
- }
- }
- }
- return SLANG_OK;
- }
-
Format parseFormatOption(Misc::TokenReader& parser)
{
parser.Read("=");
@@ -542,7 +511,6 @@ namespace renderer_test
}
val->contentVal = parseValExpr(parser);
- parseObjectAttributes(val, parser);
return val;
}
else if( parser.AdvanceIf("out") )
@@ -551,6 +519,31 @@ namespace renderer_test
val->isOutput = true;
return val;
}
+ else if (parser.AdvanceIf("specialize"))
+ {
+ RefPtr<ShaderInputLayout::SpecializeVal> val =
+ new ShaderInputLayout::SpecializeVal();
+
+ parser.Read(Misc::TokenType::LParent);
+ while (!parser.IsEnd() &&
+ parser.NextToken().Type != Misc::TokenType::RParent)
+ {
+ val->typeArgs.add(parseTypeName(parser));
+ if (!parser.AdvanceIf(","))
+ break;
+ }
+ parser.Read(Misc::TokenType::RParent);
+ val->contentVal = parseValExpr(parser);
+ return val;
+ }
+ else if (parser.AdvanceIf("dynamic"))
+ {
+ RefPtr<ShaderInputLayout::SpecializeVal> val =
+ new ShaderInputLayout::SpecializeVal();
+ val->typeArgs.add("__Dynamic");
+ val->contentVal = parseValExpr(parser);
+ return val;
+ }
else
{
// We assume that any other word is introducing one of the other
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index 86e7641f0..ed2f57370 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -22,6 +22,7 @@ enum class ShaderInputType
UniformData,
Object,
Aggregate,
+ Specialize,
};
enum class InputTextureContent
@@ -248,7 +249,16 @@ public:
Slang::String typeName;
ValPtr contentVal;
- Slang::List<Slang::String> specializationArgs;
+ };
+
+ class SpecializeVal : public Val
+ {
+ public:
+ ValPtr contentVal;
+ Slang::List<Slang::String> typeArgs;
+ SpecializeVal()
+ : Val(ShaderInputType::Specialize)
+ {}
};
class ArrayVal : public ParentVal