diff options
| -rw-r--r-- | slang-gfx.h | 7 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-14.slang | 3 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.cpp | 30 | ||||
| -rw-r--r-- | tools/gfx/debug-layer.h | 7 | ||||
| -rw-r--r-- | tools/gfx/renderer-shared.h | 53 | ||||
| -rw-r--r-- | tools/render-test/render-test-main.cpp | 20 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.cpp | 32 | ||||
| -rw-r--r-- | tools/render-test/shader-input-layout.h | 1 |
8 files changed, 147 insertions, 6 deletions
diff --git a/slang-gfx.h b/slang-gfx.h index 7428f4c56..f32aee1ce 100644 --- a/slang-gfx.h +++ b/slang-gfx.h @@ -544,6 +544,13 @@ public: setSampler(ShaderOffset const& offset, ISamplerState* sampler) = 0; virtual SLANG_NO_THROW Result SLANG_MCALL setCombinedTextureSampler( ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) = 0; + + /// Manually setting the specialization arguments for the shader object, overriding + /// the default arguments computed from the sub-objects. + /// Specialization arguments are passed to the shader compiler to specialize the type + /// of interface-typed shader parameters. + virtual SLANG_NO_THROW Result SLANG_MCALL + setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) = 0; }; #define SLANG_UUID_IShaderObject \ { \ diff --git a/tests/compute/dynamic-dispatch-14.slang b/tests/compute/dynamic-dispatch-14.slang index 8361cd317..9bbed215e 100644 --- a/tests/compute/dynamic-dispatch-14.slang +++ b/tests/compute/dynamic-dispatch-14.slang @@ -29,8 +29,7 @@ RWStructuredBuffer<int> gOutputBuffer; //TEST_INPUT: set gCb = new StructuredBuffer<IInterface>{new MyImpl{1}}; RWStructuredBuffer<IInterface> gCb; -// Add two elements into the structured buffer to prevent specialization. -//TEST_INPUT: set gCb1 = new StructuredBuffer<IInterface>{new MyImpl{1}, new MyImpl2{2}}; +//TEST_INPUT: set gCb1 = new StructuredBuffer<IInterface>{new MyImpl{1}} : specialization_args(__Dynamic); RWStructuredBuffer<IInterface> gCb1; [numthreads(4, 1, 1)] diff --git a/tools/gfx/debug-layer.cpp b/tools/gfx/debug-layer.cpp index 56ae4fdab..3fa2eee9d 100644 --- a/tools/gfx/debug-layer.cpp +++ b/tools/gfx/debug-layer.cpp @@ -385,6 +385,8 @@ Result DebugDevice::createShaderObject( auto result = baseObject->createShaderObject(type, containerType, outObject->baseObject.writeRef()); outObject->m_typeName = typeName; + outObject->m_device = this; + outObject->m_slangType = type; if (SLANG_FAILED(result)) return result; returnComPtr(outShaderObject, outObject); @@ -952,10 +954,38 @@ Result DebugShaderObject::setCombinedTextureSampler( offset, viewImpl->baseObject.get(), samplerImpl->baseObject.get()); } +Result DebugShaderObject::setSpecializationArgs( + const slang::SpecializationArg* args, + uint32_t count) +{ + ComPtr<slang::ISession> session; + m_device->getSlangSession(session.writeRef()); + auto expectedCount = (uint32_t)session->getTypeLayout(m_slangType) + ->getSize(SLANG_PARAMETER_CATEGORY_EXISTENTIAL_TYPE_PARAM); + if (expectedCount != count) + { + GFX_DIAGNOSE_ERROR_FORMAT( + "specialization argument count for shader object type %s mismatch: expecting %d but %d " + "provided.", + m_typeName.getBuffer(), + expectedCount, + count); + }; + return baseObject->setSpecializationArgs(args, count); +} + DebugObjectBase::DebugObjectBase() { static uint64_t uidCounter = 0; uid = ++uidCounter; } +Result DebugRootShaderObject::setSpecializationArgs( + const slang::SpecializationArg* args, + uint32_t count) +{ + GFX_DIAGNOSE_ERROR("`setSpecializationArgs` should not be called directly on root objects."); + return baseObject->setSpecializationArgs(args, count); +} + } // namespace gfx diff --git a/tools/gfx/debug-layer.h b/tools/gfx/debug-layer.h index a4e201e4f..f9ecc7dfe 100644 --- a/tools/gfx/debug-layer.h +++ b/tools/gfx/debug-layer.h @@ -167,6 +167,9 @@ public: ShaderOffset const& offset, IResourceView* textureView, ISamplerState* sampler) override; + virtual SLANG_NO_THROW Result SLANG_MCALL setSpecializationArgs( + const slang::SpecializationArg* args, + uint32_t count) override; public: struct ShaderOffsetKey @@ -188,6 +191,8 @@ public: } }; Slang::String m_typeName; + slang::TypeReflection* m_slangType = nullptr; + DebugDevice* m_device; Slang::List<Slang::RefPtr<DebugShaderObject>> m_entryPoints; Slang::Dictionary<ShaderOffsetKey, Slang::RefPtr<DebugShaderObject>> m_objects; Slang::Dictionary<ShaderOffsetKey, Slang::RefPtr<DebugResourceView>> m_resources; @@ -199,6 +204,8 @@ class DebugRootShaderObject : public DebugShaderObject public: virtual SLANG_NO_THROW uint32_t SLANG_MCALL addRef() override { return 1; } virtual SLANG_NO_THROW uint32_t SLANG_MCALL release() override { return 1; } + virtual SLANG_NO_THROW Result SLANG_MCALL + setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override; }; class DebugCommandBuffer; diff --git a/tools/gfx/renderer-shared.h b/tools/gfx/renderer-shared.h index ba0327cf4..83799ac25 100644 --- a/tools/gfx/renderer-shared.h +++ b/tools/gfx/renderer-shared.h @@ -296,6 +296,11 @@ struct ExtendedShaderObjectTypeList } }; +struct ExtendedShaderObjectTypeListObject + : public ExtendedShaderObjectTypeList + , public Slang::RefObject +{}; + class ShaderObjectLayoutBase : public Slang::RefObject { protected: @@ -471,6 +476,7 @@ protected: // Specialization args for a StructuredBuffer object. ExtendedShaderObjectTypeList m_structuredBufferSpecializationArgs; + Slang::RefPtr<ExtendedShaderObjectTypeListObject> m_userProvidedSpecializationArgs; public: TShaderObjectLayoutImpl* getLayout() @@ -687,10 +693,51 @@ public: return SLANG_OK; } + virtual SLANG_NO_THROW Result SLANG_MCALL + setSpecializationArgs(const slang::SpecializationArg* args, uint32_t count) override + { + if (!m_userProvidedSpecializationArgs) + { + m_userProvidedSpecializationArgs = new ExtendedShaderObjectTypeListObject(); + } + else + { + m_userProvidedSpecializationArgs->clear(); + } + auto device = getRenderer(); + for (uint32_t i = 0; i < count; i++) + { + gfx::ExtendedShaderObjectType extendedType; + switch (args[i].kind) + { + case slang::SpecializationArg::Kind::Type: + extendedType.slangType = args[i].type; + extendedType.componentID = device->shaderCache.getComponentId(args[i].type); + break; + default: + SLANG_ASSERT(false && "Unexpected specialization argument kind."); + return SLANG_FAIL; + } + m_userProvidedSpecializationArgs->add(extendedType); + } + return SLANG_OK; + } + // Appends all types that are used to specialize the element type of this shader object in // `args` list. virtual Result collectSpecializationArgs(ExtendedShaderObjectTypeList& args) override { + if (m_userProvidedSpecializationArgs) + { + args.addRange(*m_userProvidedSpecializationArgs); + return SLANG_OK; + } + if (m_layout->getContainerType() != ShaderObjectContainerType::None) + { + args.addRange(m_structuredBufferSpecializationArgs); + return SLANG_OK; + } + auto device = getRenderer(); auto& subObjectRanges = getLayout()->getSubObjectRanges(); // The following logic is built on the assumption that all fields that involve @@ -740,6 +787,8 @@ public: } case slang::BindingType::ParameterBlock: case slang::BindingType::ConstantBuffer: + case slang::BindingType::RawBuffer: + case slang::BindingType::MutableRawBuffer: // Currently we only handle the case where the field's type is // `ParameterBlock<SomeStruct>` or `ConstantBuffer<SomeStruct>`, where // `SomeStruct` is a struct type (not directly an interface type). In this case, @@ -751,10 +800,6 @@ public: // `ExistentialValue` case here, but currently we lack a mechanism to // distinguish the two scenarios. break; - case slang::BindingType::RawBuffer: - case slang::BindingType::MutableRawBuffer: - typeArgs.addRange(subObject->m_structuredBufferSpecializationArgs); - break; } auto addedTypeArgCountForCurrentRange = args.getCount() - oldArgsCount; diff --git a/tools/render-test/render-test-main.cpp b/tools/render-test/render-test-main.cpp index c04d6db00..5c8011889 100644 --- a/tools/render-test/render-test-main.cpp +++ b/tools/render-test/render-test-main.cpp @@ -341,6 +341,26 @@ struct AssignValsFromLayoutContext SLANG_RETURN_ON_FAIL(assign(ShaderCursor(shaderObject), srcVal->contentVal)); + if (srcVal->specializationArgs.getCount()) + { + List<slang::SpecializationArg> args; + for (auto srcArg : srcVal->specializationArgs) + { + auto argType = slangReflection->findTypeByName(srcArg.getBuffer()); + if (argType) + { + slang::SpecializationArg arg = slang::SpecializationArg::fromType(argType); + args.add(arg); + } + else + { + StdWriters::getError().print( + "error: could not find shader type '%s'\n", srcArg.getBuffer()); + return SLANG_E_INVALID_ARG; + } + } + shaderObject->setSpecializationArgs(args.getBuffer(), args.getCount()); + } dstCursor.setObject(shaderObject); return SLANG_OK; } diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 9362efb45..af15d5f9f 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -346,6 +346,37 @@ namespace renderer_test return SLANG_OK; } + SlangResult parseObjectAttributes(ShaderInputLayout::ObjectVal* val, Misc::TokenReader& parser) + { + if (parser.AdvanceIf(":")) + { + while (!parser.IsEnd() && parser.NextToken().Type == Misc::TokenType::Identifier) + { + if (parser.AdvanceIf("specialization_args")) + { + parser.Read(Misc::TokenType::LParent); + while (!parser.IsEnd() && + parser.NextToken().Type != Misc::TokenType::RParent) + { + val->specializationArgs.add(parseTypeName(parser)); + if (!parser.AdvanceIf(",")) + break; + } + parser.Read(Misc::TokenType::RParent); + } + else + { + throw ShaderInputLayoutFormatException( + StringBuilder() << "Unknown attribute \'" << parser.NextToken().Content << "\' (" + << parser.NextToken().Position.Line << ")"); + + return SLANG_FAIL; + } + } + } + return SLANG_OK; + } + Format parseFormatOption(Misc::TokenReader& parser) { parser.Read("="); @@ -511,6 +542,7 @@ namespace renderer_test } val->contentVal = parseValExpr(parser); + parseObjectAttributes(val, parser); return val; } else if( parser.AdvanceIf("out") ) diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h index adb30c7ec..86e7641f0 100644 --- a/tools/render-test/shader-input-layout.h +++ b/tools/render-test/shader-input-layout.h @@ -248,6 +248,7 @@ public: Slang::String typeName; ValPtr contentVal; + Slang::List<Slang::String> specializationArgs; }; class ArrayVal : public ParentVal |
