summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-12-11 13:34:54 -0800
committerGitHub <noreply@github.com>2024-12-11 13:34:54 -0800
commit941f07040a505f1f673c96da959bde839c629aba (patch)
treefe5cd3cd0a63919ad8971d32cd18e8161f9cbd99
parente50aac13e2c161d672b137a62f6d66820d0f9ff1 (diff)
Fix attribute reflection. (#5823)
* Fix attribute reflection. * Fix. * Fix.
-rw-r--r--include/slang.h46
-rw-r--r--source/slang/slang-check-modifier.cpp15
-rw-r--r--source/slang/slang-reflection-api.cpp10
-rw-r--r--tools/gfx/d3d12/d3d12-shader-object-layout.cpp2
-rw-r--r--tools/slang-unit-test/unit-test-attribute-reflection.cpp79
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp2
-rw-r--r--tools/slang-unit-test/unit-test-function-reflection.cpp2
7 files changed, 123 insertions, 33 deletions
diff --git a/include/slang.h b/include/slang.h
index 2dfee6b28..2ba7150ec 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -1776,6 +1776,7 @@ public: \
typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout;
typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter;
typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute;
+ typedef SlangReflectionUserAttribute SlangReflectionAttribute;
typedef struct SlangReflectionFunction SlangReflectionFunction;
typedef struct SlangReflectionGeneric SlangReflectionGeneric;
@@ -2140,46 +2141,48 @@ union GenericArgReflection
bool boolVal;
};
-struct UserAttribute
+struct Attribute
{
char const* getName()
{
- return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this);
+ return spReflectionUserAttribute_GetName((SlangReflectionAttribute*)this);
}
uint32_t getArgumentCount()
{
return (uint32_t)spReflectionUserAttribute_GetArgumentCount(
- (SlangReflectionUserAttribute*)this);
+ (SlangReflectionAttribute*)this);
}
TypeReflection* getArgumentType(uint32_t index)
{
return (TypeReflection*)spReflectionUserAttribute_GetArgumentType(
- (SlangReflectionUserAttribute*)this,
+ (SlangReflectionAttribute*)this,
index);
}
SlangResult getArgumentValueInt(uint32_t index, int* value)
{
return spReflectionUserAttribute_GetArgumentValueInt(
- (SlangReflectionUserAttribute*)this,
+ (SlangReflectionAttribute*)this,
index,
value);
}
SlangResult getArgumentValueFloat(uint32_t index, float* value)
{
return spReflectionUserAttribute_GetArgumentValueFloat(
- (SlangReflectionUserAttribute*)this,
+ (SlangReflectionAttribute*)this,
index,
value);
}
const char* getArgumentValueString(uint32_t index, size_t* outSize)
{
return spReflectionUserAttribute_GetArgumentValueString(
- (SlangReflectionUserAttribute*)this,
+ (SlangReflectionAttribute*)this,
index,
outSize);
}
};
+typedef Attribute UserAttribute;
+
struct TypeReflection
{
enum class Kind
@@ -2320,13 +2323,15 @@ struct TypeReflection
return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index);
}
- UserAttribute* findUserAttributeByName(char const* name)
+ UserAttribute* findAttributeByName(char const* name)
{
return (UserAttribute*)spReflectionType_FindUserAttributeByName(
(SlangReflectionType*)this,
name);
}
+ UserAttribute* findUserAttributeByName(char const* name) { return findAttributeByName(name); }
+
TypeReflection* applySpecializations(GenericReflection* generic)
{
return (TypeReflection*)spReflectionType_applySpecializations(
@@ -2777,14 +2782,14 @@ struct VariableReflection
return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this);
}
- UserAttribute* getUserAttributeByIndex(unsigned int index)
+ Attribute* getUserAttributeByIndex(unsigned int index)
{
return (UserAttribute*)spReflectionVariable_GetUserAttribute(
(SlangReflectionVariable*)this,
index);
}
- UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
+ Attribute* findAttributeByName(SlangSession* globalSession, char const* name)
{
return (UserAttribute*)spReflectionVariable_FindUserAttributeByName(
(SlangReflectionVariable*)this,
@@ -2792,6 +2797,11 @@ struct VariableReflection
name);
}
+ Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
+ {
+ return findAttributeByName(globalSession, name);
+ }
+
bool hasDefaultValue()
{
return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this);
@@ -2908,20 +2918,22 @@ struct FunctionReflection
{
return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this);
}
- UserAttribute* getUserAttributeByIndex(unsigned int index)
+ Attribute* getUserAttributeByIndex(unsigned int index)
{
- return (UserAttribute*)spReflectionFunction_GetUserAttribute(
- (SlangReflectionFunction*)this,
- index);
+ return (
+ Attribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index);
}
- UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
+ Attribute* findAttributeByName(SlangSession* globalSession, char const* name)
{
- return (UserAttribute*)spReflectionFunction_FindUserAttributeByName(
+ return (Attribute*)spReflectionFunction_FindUserAttributeByName(
(SlangReflectionFunction*)this,
globalSession,
name);
}
-
+ Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
+ {
+ return findAttributeByName(globalSession, name);
+ }
Modifier* findModifier(Modifier::ID id)
{
return (Modifier*)spReflectionFunction_FindModifier(
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index aebfe3b96..05eb978bc 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -752,18 +752,15 @@ Modifier* SemanticsVisitor::validateAttribute(
{
auto& arg = attr->args[paramIndex];
bool typeChecked = false;
- if (auto basicType = as<BasicExpressionType>(paramDecl->getType()))
+ if (isValidCompileTimeConstantType(paramDecl->getType()))
{
- if (basicType->getBaseType() == BaseType::Int)
+ if (auto cint = checkConstantIntVal(arg))
{
- if (auto cint = checkConstantIntVal(arg))
- {
- for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++)
- attr->intArgVals.add(nullptr);
- attr->intArgVals[(uint32_t)paramIndex] = cint;
- }
- typeChecked = true;
+ for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++)
+ attr->intArgVals.add(nullptr);
+ attr->intArgVals[(uint32_t)paramIndex] = cint;
}
+ typeChecked = true;
}
if (!typeChecked)
{
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 93e495194..d7f793d05 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -23,11 +23,11 @@ namespace Slang
// Conversion routines to help with strongly-typed reflection API
-static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib)
+static inline Attribute* convert(SlangReflectionUserAttribute* attrib)
{
- return (UserDefinedAttribute*)attrib;
+ return (Attribute*)attrib;
}
-static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib)
+static inline SlangReflectionUserAttribute* convert(Attribute* attrib)
{
return (SlangReflectionUserAttribute*)attrib;
}
@@ -154,7 +154,9 @@ static SlangReflectionUserAttribute* findUserAttributeByName(
const char* name)
{
auto nameObj = session->tryGetNameObj(name);
- for (auto x : decl->getModifiersOfType<UserDefinedAttribute>())
+ if (!nameObj)
+ return nullptr;
+ for (auto x : decl->getModifiersOfType<Attribute>())
{
if (x->keywordName == nameObj)
return (SlangReflectionUserAttribute*)(x);
diff --git a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp
index 6cf51ee2b..8e2d24ad6 100644
--- a/tools/gfx/d3d12/d3d12-shader-object-layout.cpp
+++ b/tools/gfx/d3d12/d3d12-shader-object-layout.cpp
@@ -39,7 +39,7 @@ bool ShaderObjectLayoutImpl::isBindingRangeRootParameter(
{
if (auto leafVariable = typeLayout->getBindingRangeLeafVariable(bindingRangeIndex))
{
- if (leafVariable->findUserAttributeByName(globalSession, rootParameterAttributeName))
+ if (leafVariable->findAttributeByName(globalSession, rootParameterAttributeName))
{
isRootParameter = true;
}
diff --git a/tools/slang-unit-test/unit-test-attribute-reflection.cpp b/tools/slang-unit-test/unit-test-attribute-reflection.cpp
new file mode 100644
index 000000000..e60eeb2d4
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-attribute-reflection.cpp
@@ -0,0 +1,79 @@
+// unit-test-translation-unit-import.cpp
+
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+#include "slang-com-ptr.h"
+#include "slang.h"
+#include "unit-test/slang-unit-test.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+using namespace Slang;
+
+// Test that the reflection API provides correct info about attributes.
+
+SLANG_UNIT_TEST(attributeReflection)
+{
+ const char* userSourceBody = R"(
+ public enum E
+ {
+ V0,
+ V1,
+ };
+
+ [__AttributeUsage(_AttributeTargets.Struct)]
+ public struct NormalTextureAttribute
+ {
+ public E Type;
+ };
+
+ [COM("042BE50B-CB01-4DBB-8367-3A9CDCBE2F49")]
+ interface IInterface { void f(); }
+
+ [NormalTexture(E.V1)]
+ struct TS {};
+ )";
+ String userSource = userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString(
+ "m",
+ "m.slang",
+ userSourceBody,
+ diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ auto reflection = module->getLayout();
+
+ auto interfaceType = reflection->findTypeByName("IInterface");
+ SLANG_CHECK(interfaceType != nullptr);
+
+ auto comAttribute = interfaceType->findAttributeByName("COM");
+ SLANG_CHECK(comAttribute != nullptr);
+
+ size_t size = 0;
+ auto guid = comAttribute->getArgumentValueString(0, &size);
+ UnownedStringSlice stringSlice = UnownedStringSlice(guid, size);
+ SLANG_CHECK(stringSlice == "\"042BE50B-CB01-4DBB-8367-3A9CDCBE2F49\"");
+
+ auto testType = reflection->findTypeByName("TS");
+ SLANG_CHECK(testType != nullptr);
+
+ auto normalTextureAttribute = testType->findAttributeByName("NormalTexture");
+ SLANG_CHECK(normalTextureAttribute != nullptr);
+
+ int value = 0;
+ normalTextureAttribute->getArgumentValueInt(0, &value);
+ SLANG_CHECK(value == 1);
+}
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
index 2ceb9981b..512be9be5 100644
--- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -178,7 +178,7 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(result == SLANG_OK);
SLANG_CHECK(val == 1024);
SLANG_CHECK(
- funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") ==
+ funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") ==
userAttribute);
}
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp
index 52c2e795a..3ce6ab7a5 100644
--- a/tools/slang-unit-test/unit-test-function-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-function-reflection.cpp
@@ -108,7 +108,7 @@ SLANG_UNIT_TEST(functionReflection)
SLANG_CHECK(result == SLANG_OK);
SLANG_CHECK(val == 1024);
SLANG_CHECK(
- funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") ==
+ funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") ==
userAttribute);
// Check overloaded method resolution