diff options
| author | Yong He <yonghe@outlook.com> | 2024-12-11 13:34:54 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-12-11 13:34:54 -0800 |
| commit | 941f07040a505f1f673c96da959bde839c629aba (patch) | |
| tree | fe5cd3cd0a63919ad8971d32cd18e8161f9cbd99 | |
| parent | e50aac13e2c161d672b137a62f6d66820d0f9ff1 (diff) | |
Fix attribute reflection. (#5823)
* Fix attribute reflection.
* Fix.
* Fix.
| -rw-r--r-- | include/slang.h | 46 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 15 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 10 | ||||
| -rw-r--r-- | tools/gfx/d3d12/d3d12-shader-object-layout.cpp | 2 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-attribute-reflection.cpp | 79 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 2 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 2 |
7 files changed, 123 insertions, 33 deletions
diff --git a/include/slang.h b/include/slang.h index 2dfee6b28..2ba7150ec 100644 --- a/include/slang.h +++ b/include/slang.h @@ -1776,6 +1776,7 @@ public: \ typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; + typedef SlangReflectionUserAttribute SlangReflectionAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; typedef struct SlangReflectionGeneric SlangReflectionGeneric; @@ -2140,46 +2141,48 @@ union GenericArgReflection bool boolVal; }; -struct UserAttribute +struct Attribute { char const* getName() { - return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this); + return spReflectionUserAttribute_GetName((SlangReflectionAttribute*)this); } uint32_t getArgumentCount() { return (uint32_t)spReflectionUserAttribute_GetArgumentCount( - (SlangReflectionUserAttribute*)this); + (SlangReflectionAttribute*)this); } TypeReflection* getArgumentType(uint32_t index) { return (TypeReflection*)spReflectionUserAttribute_GetArgumentType( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index); } SlangResult getArgumentValueInt(uint32_t index, int* value) { return spReflectionUserAttribute_GetArgumentValueInt( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, value); } SlangResult getArgumentValueFloat(uint32_t index, float* value) { return spReflectionUserAttribute_GetArgumentValueFloat( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, value); } const char* getArgumentValueString(uint32_t index, size_t* outSize) { return spReflectionUserAttribute_GetArgumentValueString( - (SlangReflectionUserAttribute*)this, + (SlangReflectionAttribute*)this, index, outSize); } }; +typedef Attribute UserAttribute; + struct TypeReflection { enum class Kind @@ -2320,13 +2323,15 @@ struct TypeReflection return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index); } - UserAttribute* findUserAttributeByName(char const* name) + UserAttribute* findAttributeByName(char const* name) { return (UserAttribute*)spReflectionType_FindUserAttributeByName( (SlangReflectionType*)this, name); } + UserAttribute* findUserAttributeByName(char const* name) { return findAttributeByName(name); } + TypeReflection* applySpecializations(GenericReflection* generic) { return (TypeReflection*)spReflectionType_applySpecializations( @@ -2777,14 +2782,14 @@ struct VariableReflection return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); } - UserAttribute* getUserAttributeByIndex(unsigned int index) + Attribute* getUserAttributeByIndex(unsigned int index) { return (UserAttribute*)spReflectionVariable_GetUserAttribute( (SlangReflectionVariable*)this, index); } - UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + Attribute* findAttributeByName(SlangSession* globalSession, char const* name) { return (UserAttribute*)spReflectionVariable_FindUserAttributeByName( (SlangReflectionVariable*)this, @@ -2792,6 +2797,11 @@ struct VariableReflection name); } + Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + { + return findAttributeByName(globalSession, name); + } + bool hasDefaultValue() { return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this); @@ -2908,20 +2918,22 @@ struct FunctionReflection { return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this); } - UserAttribute* getUserAttributeByIndex(unsigned int index) + Attribute* getUserAttributeByIndex(unsigned int index) { - return (UserAttribute*)spReflectionFunction_GetUserAttribute( - (SlangReflectionFunction*)this, - index); + return ( + Attribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index); } - UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + Attribute* findAttributeByName(SlangSession* globalSession, char const* name) { - return (UserAttribute*)spReflectionFunction_FindUserAttributeByName( + return (Attribute*)spReflectionFunction_FindUserAttributeByName( (SlangReflectionFunction*)this, globalSession, name); } - + Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name) + { + return findAttributeByName(globalSession, name); + } Modifier* findModifier(Modifier::ID id) { return (Modifier*)spReflectionFunction_FindModifier( diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index aebfe3b96..05eb978bc 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -752,18 +752,15 @@ Modifier* SemanticsVisitor::validateAttribute( { auto& arg = attr->args[paramIndex]; bool typeChecked = false; - if (auto basicType = as<BasicExpressionType>(paramDecl->getType())) + if (isValidCompileTimeConstantType(paramDecl->getType())) { - if (basicType->getBaseType() == BaseType::Int) + if (auto cint = checkConstantIntVal(arg)) { - if (auto cint = checkConstantIntVal(arg)) - { - for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) - attr->intArgVals.add(nullptr); - attr->intArgVals[(uint32_t)paramIndex] = cint; - } - typeChecked = true; + for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++) + attr->intArgVals.add(nullptr); + attr->intArgVals[(uint32_t)paramIndex] = cint; } + typeChecked = true; } if (!typeChecked) { diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 93e495194..d7f793d05 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -23,11 +23,11 @@ namespace Slang // Conversion routines to help with strongly-typed reflection API -static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib) +static inline Attribute* convert(SlangReflectionUserAttribute* attrib) { - return (UserDefinedAttribute*)attrib; + return (Attribute*)attrib; } -static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib) +static inline SlangReflectionUserAttribute* convert(Attribute* attrib) { return (SlangReflectionUserAttribute*)attrib; } @@ -154,7 +154,9 @@ static SlangReflectionUserAttribute* findUserAttributeByName( const char* name) { auto nameObj = session->tryGetNameObj(name); - for (auto x : decl->getModifiersOfType<UserDefinedAttribute>()) + if (!nameObj) + return nullptr; + for (auto x : decl->getModifiersOfType<Attribute>()) { if (x->keywordName == nameObj) return (SlangReflectionUserAttribute*)(x); diff --git a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp index 6cf51ee2b..8e2d24ad6 100644 --- a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp +++ b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp @@ -39,7 +39,7 @@ bool ShaderObjectLayoutImpl::isBindingRangeRootParameter( { if (auto leafVariable = typeLayout->getBindingRangeLeafVariable(bindingRangeIndex)) { - if (leafVariable->findUserAttributeByName(globalSession, rootParameterAttributeName)) + if (leafVariable->findAttributeByName(globalSession, rootParameterAttributeName)) { isRootParameter = true; } diff --git a/tools/slang-unit-test/unit-test-attribute-reflection.cpp b/tools/slang-unit-test/unit-test-attribute-reflection.cpp new file mode 100644 index 000000000..e60eeb2d4 --- /dev/null +++ b/tools/slang-unit-test/unit-test-attribute-reflection.cpp @@ -0,0 +1,79 @@ +// unit-test-translation-unit-import.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +// Test that the reflection API provides correct info about attributes. + +SLANG_UNIT_TEST(attributeReflection) +{ + const char* userSourceBody = R"( + public enum E + { + V0, + V1, + }; + + [__AttributeUsage(_AttributeTargets.Struct)] + public struct NormalTextureAttribute + { + public E Type; + }; + + [COM("042BE50B-CB01-4DBB-8367-3A9CDCBE2F49")] + interface IInterface { void f(); } + + [NormalTexture(E.V1)] + struct TS {}; + )"; + String userSource = userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto reflection = module->getLayout(); + + auto interfaceType = reflection->findTypeByName("IInterface"); + SLANG_CHECK(interfaceType != nullptr); + + auto comAttribute = interfaceType->findAttributeByName("COM"); + SLANG_CHECK(comAttribute != nullptr); + + size_t size = 0; + auto guid = comAttribute->getArgumentValueString(0, &size); + UnownedStringSlice stringSlice = UnownedStringSlice(guid, size); + SLANG_CHECK(stringSlice == "\"042BE50B-CB01-4DBB-8367-3A9CDCBE2F49\""); + + auto testType = reflection->findTypeByName("TS"); + SLANG_CHECK(testType != nullptr); + + auto normalTextureAttribute = testType->findAttributeByName("NormalTexture"); + SLANG_CHECK(normalTextureAttribute != nullptr); + + int value = 0; + normalTextureAttribute->getArgumentValueInt(0, &value); + SLANG_CHECK(value == 1); +} diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index 2ceb9981b..512be9be5 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -178,7 +178,7 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); SLANG_CHECK( - funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == + funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); } diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 52c2e795a..3ce6ab7a5 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -108,7 +108,7 @@ SLANG_UNIT_TEST(functionReflection) SLANG_CHECK(result == SLANG_OK); SLANG_CHECK(val == 1024); SLANG_CHECK( - funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == + funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); // Check overloaded method resolution |
