summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang-gfx.h7
-rw-r--r--tests/compute/dynamic-dispatch-14.slang3
-rw-r--r--tools/gfx/debug-layer.cpp30
-rw-r--r--tools/gfx/debug-layer.h7
-rw-r--r--tools/gfx/renderer-shared.h53
-rw-r--r--tools/render-test/render-test-main.cpp20
-rw-r--r--tools/render-test/shader-input-layout.cpp32
-rw-r--r--tools/render-test/shader-input-layout.h1
8 files changed, 147 insertions, 6 deletions
diff --git a/slang-gfx.h b/slang-gfx.h
index 7428f4c56..f32aee1ce 100644
--- a/slang-gfx.h
+++ b/slang-gfx.h
@@ -544,6 +544,13 @@ public:
setSampler(ShaderOffset const& offset, ISamplerState* sampler) = 0;
virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler(
ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) = 0;
+
+ /// Manually setting the specialization arguments for the shader object, overriding
+ /// the default arguments computed from the sub-objects.
+ /// Specialization arguments are passed to the shader compiler to specialize the type
+ /// of interface-typed shader parameters.
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) = 0;
};
#define SLANG_UUID_IShaderObject \
{ \
diff --git a/tests/compute/dynamic-dispatch-14.slang b/tests/compute/dynamic-dispatch-14.slang
index 8361cd317..9bbed215e 100644
--- a/tests/compute/dynamic-dispatch-14.slang
+++ b/tests/compute/dynamic-dispatch-14.slang
@@ -29,8 +29,7 @@ RWStructuredBuffer<int> gOutputBuffer;
//TEST_INPUT: set gCb = new StructuredBuffer<IInterface>{new MyImpl{1}};
RWStructuredBuffer<IInterface> gCb;
-// Add two elements into the structured buffer to prevent specialization.
-//TEST_INPUT: set gCb1 = new StructuredBuffer<IInterface>{new MyImpl{1}, new MyImpl2{2}};
+//TEST_INPUT: set gCb1 = new StructuredBuffer<IInterface>{new MyImpl{1}} : specialization_args(__Dynamic);
RWStructuredBuffer<IInterface> gCb1;
[numthreads(4, 1, 1)]
diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp
index 56ae4fdab..3fa2eee9d 100644
--- a/tools/gfx/debug-layer.cpp
+++ b/tools/gfx/debug-layer.cpp
@@ -385,6 +385,8 @@ Result DebugDevice::createShaderObject(
auto result =
baseObject->createShaderObject(type, containerType, outObject->baseObject.writeRef());
outObject->m_typeName = typeName;
+ outObject->m_device = this;
+ outObject->m_slangType = type;
if (SLANG_FAILED(result))
return result;
returnComPtr(outShaderObject, outObject);
@@ -952,10 +954,38 @@ Result DebugShaderObject::setCombinedTextureSampler(
offset, viewImpl->baseObject.get(), samplerImpl->baseObject.get());
}
+Result DebugShaderObject::setSpecializationArgs(
+ const slang::SpecializationArg* args,
+ uint32_t count)
+{
+ ComPtr<slang::ISession> session;
+ m_device->getSlangSession(session.writeRef());
+ auto expectedCount = (uint32_t)session->getTypeLayout(m_slangType)
+ ->getSize(SLANG_PARAMETER_CATEGORY_EXISTENTIAL_TYPE_PARAM);
+ if (expectedCount != count)
+ {
+ GFX_DIAGNOSE_ERROR_FORMAT(
+ "specialization argument count for shader object type %s mismatch: expecting %d but %d "
+ "provided.",
+ m_typeName.getBuffer(),
+ expectedCount,
+ count);
+ };
+ return baseObject->setSpecializationArgs(args, count);
+}
+
DebugObjectBase::DebugObjectBase()
{
static uint64_t uidCounter = 0;
uid = ++uidCounter;
}
+Result DebugRootShaderObject::setSpecializationArgs(
+ const slang::SpecializationArg* args,
+ uint32_t count)
+{
+ GFX_DIAGNOSE_ERROR("`setSpecializationArgs` should not be called directly on root objects.");
+ return baseObject->setSpecializationArgs(args, count);
+}
+
} // namespace gfx
diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h
index a4e201e4f..f9ecc7dfe 100644
--- a/tools/gfx/debug-layer.h
+++ b/tools/gfx/debug-layer.h
@@ -167,6 +167,9 @@ public:
ShaderOffset const& offset,
IResourceView* textureView,
ISamplerState* sampler) override;
+ virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs(
+ const slang::SpecializationArg* args,
+ uint32_t count) override;
public:
struct ShaderOffsetKey
@@ -188,6 +191,8 @@ public:
}
};
Slang::String m_typeName;
+ slang::TypeReflection* m_slangType = nullptr;
+ DebugDevice* m_device;
Slang::List<Slang::RefPtr<DebugShaderObject>> m_entryPoints;
Slang::Dictionary<ShaderOffsetKey, Slang::RefPtr<DebugShaderObject>> m_objects;
Slang::Dictionary<ShaderOffsetKey, Slang::RefPtr<DebugResourceView>> m_resources;
@@ -199,6 +204,8 @@ class DebugRootShaderObject : public DebugShaderObject
public:
virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; }
virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; }
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override;
};
class DebugCommandBuffer;
diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h
index ba0327cf4..83799ac25 100644
--- a/tools/gfx/renderer-shared.h
+++ b/tools/gfx/renderer-shared.h
@@ -296,6 +296,11 @@ struct ExtendedShaderObjectTypeList
}
};
+struct ExtendedShaderObjectTypeListObject
+ : public ExtendedShaderObjectTypeList
+ , public Slang::RefObject
+{};
+
class ShaderObjectLayoutBase : public Slang::RefObject
{
protected:
@@ -471,6 +476,7 @@ protected:
// Specialization args for a StructuredBuffer object.
ExtendedShaderObjectTypeList m_structuredBufferSpecializationArgs;
+ Slang::RefPtr<ExtendedShaderObjectTypeListObject> m_userProvidedSpecializationArgs;
public:
TShaderObjectLayoutImpl* getLayout()
@@ -687,10 +693,51 @@ public:
return SLANG_OK;
}
+ virtual SLANG_NO_THROW Result SLANG_MCALL
+ setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override
+ {
+ if (!m_userProvidedSpecializationArgs)
+ {
+ m_userProvidedSpecializationArgs = new ExtendedShaderObjectTypeListObject();
+ }
+ else
+ {
+ m_userProvidedSpecializationArgs->clear();
+ }
+ auto device = getRenderer();
+ for (uint32_t i = 0; i < count; i++)
+ {
+ gfx::ExtendedShaderObjectType extendedType;
+ switch (args[i].kind)
+ {
+ case slang::SpecializationArg::Kind::Type:
+ extendedType.slangType = args[i].type;
+ extendedType.componentID = device->shaderCache.getComponentId(args[i].type);
+ break;
+ default:
+ SLANG_ASSERT(false && "Unexpected specialization argument kind.");
+ return SLANG_FAIL;
+ }
+ m_userProvidedSpecializationArgs->add(extendedType);
+ }
+ return SLANG_OK;
+ }
+
// Appends all types that are used to specialize the element type of this shader object in
// `args` list.
virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override
{
+ if (m_userProvidedSpecializationArgs)
+ {
+ args.addRange(*m_userProvidedSpecializationArgs);
+ return SLANG_OK;
+ }
+ if (m_layout->getContainerType() != ShaderObjectContainerType::None)
+ {
+ args.addRange(m_structuredBufferSpecializationArgs);
+ return SLANG_OK;
+ }
+
auto device = getRenderer();
auto& subObjectRanges = getLayout()->getSubObjectRanges();
// The following logic is built on the assumption that all fields that involve
@@ -740,6 +787,8 @@ public:
}
case slang::BindingType::ParameterBlock:
case slang::BindingType::ConstantBuffer:
+ case slang::BindingType::RawBuffer:
+ case slang::BindingType::MutableRawBuffer:
// Currently we only handle the case where the field's type is
// `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where
// `SomeStruct` is a struct type (not directly an interface type). In this case,
@@ -751,10 +800,6 @@ public:
// `ExistentialValue` case here, but currently we lack a mechanism to
// distinguish the two scenarios.
break;
- case slang::BindingType::RawBuffer:
- case slang::BindingType::MutableRawBuffer:
- typeArgs.addRange(subObject->m_structuredBufferSpecializationArgs);
- break;
}
auto addedTypeArgCountForCurrentRange = args.getCount() - oldArgsCount;
diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp
index c04d6db00..5c8011889 100644
--- a/tools/render-test/render-test-main.cpp
+++ b/tools/render-test/render-test-main.cpp
@@ -341,6 +341,26 @@ struct AssignValsFromLayoutContext
SLANG_RETURN_ON_FAIL(assign(ShaderCursor(shaderObject), srcVal->contentVal));
+ if (srcVal->specializationArgs.getCount())
+ {
+ List<slang::SpecializationArg> args;
+ for (auto srcArg : srcVal->specializationArgs)
+ {
+ auto argType = slangReflection->findTypeByName(srcArg.getBuffer());
+ if (argType)
+ {
+ slang::SpecializationArg arg = slang::SpecializationArg::fromType(argType);
+ args.add(arg);
+ }
+ else
+ {
+ StdWriters::getError().print(
+ "error: could not find shader type '%s'\n", srcArg.getBuffer());
+ return SLANG_E_INVALID_ARG;
+ }
+ }
+ shaderObject->setSpecializationArgs(args.getBuffer(), args.getCount());
+ }
dstCursor.setObject(shaderObject);
return SLANG_OK;
}
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index 9362efb45..af15d5f9f 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -346,6 +346,37 @@ namespace renderer_test
return SLANG_OK;
}
+ SlangResult parseObjectAttributes(ShaderInputLayout::ObjectVal* val, Misc::TokenReader& parser)
+ {
+ if (parser.AdvanceIf(":"))
+ {
+ while (!parser.IsEnd() && parser.NextToken().Type == Misc::TokenType::Identifier)
+ {
+ if (parser.AdvanceIf("specialization_args"))
+ {
+ parser.Read(Misc::TokenType::LParent);
+ while (!parser.IsEnd() &&
+ parser.NextToken().Type != Misc::TokenType::RParent)
+ {
+ val->specializationArgs.add(parseTypeName(parser));
+ if (!parser.AdvanceIf(","))
+ break;
+ }
+ parser.Read(Misc::TokenType::RParent);
+ }
+ else
+ {
+ throw ShaderInputLayoutFormatException(
+ StringBuilder() << "Unknown attribute \'" << parser.NextToken().Content << "\' ("
+ << parser.NextToken().Position.Line << ")");
+
+ return SLANG_FAIL;
+ }
+ }
+ }
+ return SLANG_OK;
+ }
+
Format parseFormatOption(Misc::TokenReader& parser)
{
parser.Read("=");
@@ -511,6 +542,7 @@ namespace renderer_test
}
val->contentVal = parseValExpr(parser);
+ parseObjectAttributes(val, parser);
return val;
}
else if( parser.AdvanceIf("out") )
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index adb30c7ec..86e7641f0 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -248,6 +248,7 @@ public:
Slang::String typeName;
ValPtr contentVal;
+ Slang::List<Slang::String> specializationArgs;
};
class ArrayVal : public ParentVal