summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--slang.h76
-rw-r--r--source/slang/check.cpp177
-rw-r--r--source/slang/compiler.h1
-rw-r--r--source/slang/core.meta.slang9
-rw-r--r--source/slang/core.meta.slang.h12
-rw-r--r--source/slang/diagnostic-defs.h5
-rw-r--r--source/slang/hlsl.meta.slang.h1
-rw-r--r--source/slang/modifier-defs.h8
-rw-r--r--source/slang/name.cpp8
-rw-r--r--source/slang/name.h4
-rw-r--r--source/slang/reflection.cpp158
-rw-r--r--source/slang/syntax.h11
-rw-r--r--tests/diagnostics/attribute-error.slang34
-rw-r--r--tests/diagnostics/attribute-error.slang.expected8
-rw-r--r--tests/reflection/attribute.slang42
-rw-r--r--tests/reflection/attribute.slang.expected118
-rw-r--r--tools/slang-reflection-test/slang-reflection-test-main.cpp96
18 files changed, 747 insertions, 22 deletions
diff --git a/.gitignore b/.gitignore
index 206dfca5e..b92a66e6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,3 +24,4 @@ build.*/
# Files generated by other shader compilers
*.spv
+/source/slang/core.meta.slang.temp.h
diff --git a/slang.h b/slang.h
index 12beedb5b..9a1880536 100644
--- a/slang.h
+++ b/slang.h
@@ -1357,6 +1357,7 @@ extern "C"
typedef struct SlangReflectionVariable SlangReflectionVariable;
typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout;
typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter;
+ typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute;
// get reflection data from a compilation request
SLANG_API SlangReflection* spGetReflection(
@@ -1492,9 +1493,26 @@ extern "C"
SLANG_MODIFIER_SHARED,
};
+ // User Attribute
+ SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib);
+ SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib);
+ SLANG_API SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index);
+ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs);
+ SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs);
+
+ /** Returns the string-typed value of a user attribute argument
+ The string returned is not null-terminated. The length of the string is returned via `outSize`.
+ If index of out of range, or if the specified argument is not a string, the function will return nullptr.
+ */
+ SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t * outSize);
+
// Type Reflection
SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* type);
+ SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type);
+ SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index);
+ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name);
+
SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type);
SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index);
@@ -1548,8 +1566,10 @@ extern "C"
SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* var);
SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* var);
-
SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* var, SlangModifierID modifierID);
+ SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* var);
+ SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index);
+ SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * session, char const* name);
// Variable Layout Reflection
@@ -1641,6 +1661,34 @@ namespace slang
struct TypeReflection;
struct VariableLayoutReflection;
struct VariableReflection;
+
+ struct UserAttribute
+ {
+ char const* getName()
+ {
+ return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this);
+ }
+ uint32_t getArgumentCount()
+ {
+ return (uint32_t)spReflectionUserAttribute_GetArgumentCount((SlangReflectionUserAttribute*)this);
+ }
+ TypeReflection* getArgumentType(uint32_t index)
+ {
+ return (TypeReflection*)spReflectionUserAttribute_GetArgumentType((SlangReflectionUserAttribute*)this, index);
+ }
+ SlangResult getArgumentValueInt(uint32_t index, int * value)
+ {
+ return spReflectionUserAttribute_GetArgumentValueInt((SlangReflectionUserAttribute*)this, index, value);
+ }
+ SlangResult getArgumentValueFloat(uint32_t index, float * value)
+ {
+ return spReflectionUserAttribute_GetArgumentValueFloat((SlangReflectionUserAttribute*)this, index, value);
+ }
+ const char* getArgumentValueString(uint32_t index, size_t * outSize)
+ {
+ return spReflectionUserAttribute_GetArgumentValueString((SlangReflectionUserAttribute*)this, index, outSize);
+ }
+ };
struct TypeReflection
{
@@ -1764,6 +1812,19 @@ namespace slang
{
return spReflectionType_GetName((SlangReflectionType*) this);
}
+
+ unsigned int getUserAttributeCount()
+ {
+ return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this);
+ }
+ UserAttribute* getUserAttributeByIndex(unsigned int index)
+ {
+ return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index);
+ }
+ UserAttribute* findUserAttributeByName(char const* name)
+ {
+ return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name);
+ }
};
enum ParameterCategory : SlangParameterCategory
@@ -1944,6 +2005,19 @@ namespace slang
{
return (Modifier*) spReflectionVariable_FindModifier((SlangReflectionVariable*) this, (SlangModifierID) id);
}
+
+ unsigned int getUserAttributeCount()
+ {
+ return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this);
+ }
+ UserAttribute* getUserAttributeByIndex(unsigned int index)
+ {
+ return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index);
+ }
+ UserAttribute* findUserAttributeByName(SlangSession* session, char const* name)
+ {
+ return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, session, name);
+ }
};
struct VariableLayoutReflection
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 42bf396fa..e100cbd1d 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -2389,20 +2389,75 @@ namespace Slang
// rules to keep us from seeing shadowing variable declarations.
auto lookupResult = lookUp(getSession(), this, attributeName, scope, LookupMask::Attribute);
- // If we didn't find anything, or the result was overloaded,
+ // If the result was overloaded,
// then we aren't going to be able to extract a single decl.
- if(!lookupResult.isValid() || lookupResult.isOverloaded())
+ if(lookupResult.isOverloaded())
return nullptr;
- auto decl = lookupResult.item.declRef.getDecl();
- if( auto attributeDecl = dynamic_cast<AttributeDecl*>(decl) )
+ if (lookupResult.isValid())
{
- return attributeDecl;
+ auto decl = lookupResult.item.declRef.getDecl();
+ if (auto attributeDecl = dynamic_cast<AttributeDecl*>(decl))
+ {
+ return attributeDecl;
+ }
+ else
+ {
+ return nullptr;
+ }
}
- else
+
+ // If we couldn't find a system attribute, try looking up as a user defined attribute
+ // A user defined attribute class is defined as a struct type with a "UserDefinedAttributeAttribute" modifier
+ lookupResult = lookUp(getSession(), this, getSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type);
+ if (lookupResult.isOverloaded())
{
- return nullptr;
+ // see if we have already created an AttributeDecl for this attribute struct
+ for (auto alt : lookupResult.items)
+ {
+ if (auto adecl = alt.declRef.As<AttributeDecl>())
+ return adecl.getDecl();
+ }
}
+ // If we still cannot find any thing, quit
+ if (!lookupResult.isValid() || lookupResult.isOverloaded())
+ return nullptr;
+ // Now construct an AttributeDecl for this user defined attribute class for future lookup
+ auto userDefAttribAttrib = lookupResult.item.declRef.decl->FindModifier<AttributeUsageAttribute>();
+ if (!userDefAttribAttrib)
+ return nullptr;
+ // create an AttributeDecl for the user defined attribute
+ auto structAttribDef = lookupResult.item.declRef.As<StructDecl>().getDecl();
+ RefPtr<AttributeDecl> attribDecl = new AttributeDecl();
+ attribDecl->nameAndLoc = structAttribDef->nameAndLoc;
+ attribDecl->loc = structAttribDef->loc;
+ attribDecl->nextInContainerWithSameName = structAttribDef->nextInContainerWithSameName;
+ // create a __attributeTarget modifier for the attribute class definition
+ RefPtr<AttributeTargetModifier> targetModifier = new AttributeTargetModifier();
+ targetModifier->syntaxClass = userDefAttribAttrib->targetSyntaxClass;
+ targetModifier->loc = structAttribDef->loc;
+ targetModifier->next = attribDecl->modifiers.first;
+ attribDecl->modifiers.first = targetModifier;
+ structAttribDef->nextInContainerWithSameName = attribDecl.Ptr();
+ // we should create UserDefinedAttribute nodes for all user defined attribute instances
+ attribDecl->syntaxClass = getSession()->findSyntaxClass(getSession()->getNameObj("UserDefinedAttribute"));
+ for (auto member : structAttribDef->Members)
+ {
+ if (auto varMember = member.As<VarDecl>())
+ {
+ RefPtr<ParamDecl> param = new ParamDecl();
+ param->nameAndLoc = member->nameAndLoc;
+ param->type = varMember->type;
+ param->loc = member->loc;
+ attribDecl->Members.Add(param);
+ }
+ }
+ // add the attribute class definition to the syntax tree, so it can be found
+ structAttribDef->ParentDecl->Members.Add(attribDecl.Ptr());
+ structAttribDef->ParentDecl->memberDictionaryIsValid = false;
+ // do necessary checks on this newly constructed node
+ checkDecl(attribDecl.Ptr());
+ return attribDecl.Ptr();
}
bool hasIntArgs(Attribute* attr, int numArgs)
@@ -2437,7 +2492,27 @@ namespace Slang
return true;
}
- bool validateAttribute(RefPtr<Attribute> attr)
+ bool getAttributeTargetSyntaxClasses(SyntaxClass<RefObject> & cls, uint32_t typeFlags)
+ {
+ if (typeFlags == (int)UserDefinedAttributeTargets::Struct)
+ {
+ cls = getSession()->findSyntaxClass(getSession()->getNameObj("StructDecl"));
+ return true;
+ }
+ if (typeFlags == (int)UserDefinedAttributeTargets::Var)
+ {
+ cls = getSession()->findSyntaxClass(getSession()->getNameObj("VarDecl"));
+ return true;
+ }
+ if (typeFlags == (int)UserDefinedAttributeTargets::Function)
+ {
+ cls = getSession()->findSyntaxClass(getSession()->getNameObj("FuncDecl"));
+ return true;
+ }
+ return false;
+ }
+
+ bool validateAttribute(RefPtr<Attribute> attr, AttributeDecl* attribClassDecl)
{
if(auto numThreadsAttr = attr.As<NumThreadsAttribute>())
{
@@ -2529,6 +2604,67 @@ namespace Slang
// Has no args
SLANG_ASSERT(attr->args.Count() == 0);
}
+ else if (auto attrUsageAttr = attr.As<AttributeUsageAttribute>())
+ {
+ uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None;
+ if (attr->args.Count() == 1)
+ {
+ RefPtr<IntVal> outIntVal;
+ if (auto cInt = checkConstantIntVal(attr->args[0]))
+ {
+ targetClassId = (uint32_t)(cInt->value);
+ }
+ else
+ {
+ getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name);
+ return false;
+ }
+ }
+ if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId))
+ {
+ getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget);
+ return false;
+ }
+ }
+ else if (auto userDefAttr = attr.As<UserDefinedAttribute>())
+ {
+ // check arguments against attribute parameters defined in attribClassDecl
+ uint32_t paramIndex = 0;
+ auto params = attribClassDecl->getMembersOfType<ParamDecl>();
+ for (auto paramDecl : params)
+ {
+ if (paramIndex < attr->args.Count())
+ {
+ auto & arg = attr->args[paramIndex];
+ bool typeChecked = false;
+ if (auto basicType = paramDecl->getType()->AsBasicType())
+ {
+ if (basicType->baseType == BaseType::Int)
+ {
+ if (auto cint = checkConstantIntVal(arg))
+ {
+ attr->intArgVals[(uint32_t)paramIndex] = cint;
+ }
+ typeChecked = true;
+ }
+ }
+ if (!typeChecked)
+ {
+ arg = CheckExpr(arg);
+ arg = Coerce(paramDecl->getType(), arg);
+ }
+ }
+ paramIndex++;
+ }
+ if (params.Count() < attr->args.Count())
+ {
+ getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.Count(), params.Count());
+ }
+ else if (params.Count() > attr->args.Count())
+ {
+ getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.Count(), params.Count());
+ }
+ }
else
{
if(attr->args.Count() == 0)
@@ -2597,11 +2733,9 @@ namespace Slang
UInt paramIndex = paramCounter++;
if( paramIndex < argCount )
{
- auto arg = attr->args[paramIndex];
-
// TODO: support checking the argument against the declared
// type for the parameter.
-
+
}
else
{
@@ -2653,7 +2787,7 @@ namespace Slang
}
// Now apply type-specific validation to the attribute.
- if(!validateAttribute(attr))
+ if(!validateAttribute(attr, attrDecl))
{
return uncheckedAttr;
}
@@ -2778,6 +2912,24 @@ namespace Slang
registerExtension(extDecl);
}
}
+ // check user defined attribute classes first
+ for (auto decl : programNode->Members)
+ {
+ if (auto typeMember = decl->As<StructDecl>())
+ {
+ bool isTypeAttributeClass = false;
+ for (auto attrib : typeMember->GetModifiersOfType<UncheckedAttribute>())
+ {
+ if (attrib->name == getSession()->getNameObj("AttributeUsageAttribute"))
+ {
+ isTypeAttributeClass = true;
+ break;
+ }
+ }
+ if (isTypeAttributeClass)
+ checkDecl(decl);
+ }
+ }
// check types
for (auto & s : programNode->getMembersOfType<TypeDefDecl>())
checkDecl(s.Ptr());
@@ -8588,6 +8740,7 @@ namespace Slang
RefPtr<Expr> visitStaticMemberExpr(StaticMemberExpr* /*expr*/)
{
+ // StaticMemberExpr means it is already checked
SLANG_UNEXPECTED("should not occur in unchecked AST");
UNREACHABLE_RETURN(nullptr);
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index ca82950be..9ca4bbc70 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -587,6 +587,7 @@ namespace Slang
RootNamePool* getRootNamePool() { return &rootNamePool; }
NamePool* getNamePool() { return &namePool; }
Name* getNameObj(String name) { return namePool.getName(name); }
+ Name* tryGetNameObj(String name) { return namePool.tryGetName(name); }
//
// Generated code for stdlib, etc.
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index e6bad3f50..50a4cbf96 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -1286,3 +1286,12 @@ attribute_syntax [mutating] : MutatingAttribute;
/// This is equivalent to the LLVM `readnone` function attribute.
__attributeTarget(FunctionDeclBase)
attribute_syntax [__readNone] : ReadNoneAttribute;
+
+enum AttributeTargets
+{
+ Struct = $((int) UserDefinedAttributeTargets::Struct),
+ Var = $((int) UserDefinedAttributeTargets::Var),
+ Function = $((int) UserDefinedAttributeTargets::Function),
+};
+__attributeTarget(StructDecl)
+attribute_syntax [AttributeUsage(target : AttributeTargets)] : AttributeUsageAttribute; \ No newline at end of file
diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h
index a29fc24ce..746e75bd9 100644
--- a/source/slang/core.meta.slang.h
+++ b/source/slang/core.meta.slang.h
@@ -1304,3 +1304,15 @@ SLANG_RAW(" ///\n")
SLANG_RAW(" /// This is equivalent to the LLVM `readnone` function attribute.\n")
SLANG_RAW("__attributeTarget(FunctionDeclBase)\n")
SLANG_RAW("attribute_syntax [__readNone] : ReadNoneAttribute;\n")
+SLANG_RAW("\n")
+SLANG_RAW("enum AttributeTargets\n")
+SLANG_RAW("{\n")
+
+ sb << "Struct = " << (int)UserDefinedAttributeTargets::Struct << ", ";
+ sb << "Var = " << (int)UserDefinedAttributeTargets::Var << ", ";
+ sb << "Function = " << (int)UserDefinedAttributeTargets::Function;
+
+SLANG_RAW("\n")
+SLANG_RAW("};\n")
+SLANG_RAW("__attributeTarget(StructDecl)\n")
+SLANG_RAW("attribute_syntax [AttributeUsage(target : AttributeTargets)] : AttributeUsageAttribute;")
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 2c4f00dd6..68790d4ab 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -260,9 +260,8 @@ DIAGNOSTIC(31005, Error, expectedSingleStringArg, "attribute '$0' expects a sing
DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0' for attribute'$1'")
-DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'")
-
-
+DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'")
+DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute")
// Enums
diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h
index 4dc75ab78..ef1a6d273 100644
--- a/source/slang/hlsl.meta.slang.h
+++ b/source/slang/hlsl.meta.slang.h
@@ -418,6 +418,7 @@ SLANG_RAW("\n")
SLANG_RAW("double asdouble(uint lowbits, uint highbits);\n")
SLANG_RAW("\n")
SLANG_RAW("// Reinterpret bits as a float (HLSL SM 4.0)\n")
+SLANG_RAW("\n")
SLANG_RAW("__target_intrinsic(glsl, \"intBitsToFloat\")\n")
SLANG_RAW("float asfloat(int x);\n")
SLANG_RAW("__target_intrinsic(glsl, \"uintBitsToFloat\")\n")
diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h
index 2276f37f8..164621620 100644
--- a/source/slang/modifier-defs.h
+++ b/source/slang/modifier-defs.h
@@ -315,6 +315,14 @@ END_SYNTAX_CLASS()
// A `[name(arg0, ...)]` style attribute that has been validated.
SYNTAX_CLASS(Attribute, AttributeBase)
+ FIELD(AttributeArgumentValueDict, intArgVals)
+END_SYNTAX_CLASS()
+
+SYNTAX_CLASS(UserDefinedAttribute, Attribute)
+END_SYNTAX_CLASS()
+
+SYNTAX_CLASS(AttributeUsageAttribute, Attribute)
+ FIELD(SyntaxClass<RefObject>, targetSyntaxClass)
END_SYNTAX_CLASS()
// An `[unroll]` or `[unroll(count)]` attribute
diff --git a/source/slang/name.cpp b/source/slang/name.cpp
index a5cbb6541..21586e0b6 100644
--- a/source/slang/name.cpp
+++ b/source/slang/name.cpp
@@ -26,4 +26,12 @@ Name* NamePool::getName(String const& text)
return name;
}
+Name* NamePool::tryGetName(String const& text)
+{
+ RefPtr<Name> name;
+ if (rootPool->names.TryGetValue(text, name))
+ return name;
+ return nullptr;
+}
+
} // namespace Slang
diff --git a/source/slang/name.h b/source/slang/name.h
index e1a055e60..a144fbb84 100644
--- a/source/slang/name.h
+++ b/source/slang/name.h
@@ -66,7 +66,9 @@ struct NamePool
{
// Find or create the `Name` that represents the given `text`.
Name* getName(String const& text);
-
+ // Try find the `Name` that represents the given `text`.
+ // If the name does not exist, return nullptr
+ Name* tryGetName(String const& text);
// Set the parent name pool to use for lookup
void setRootNamePool(RootNamePool* rootNamePool)
{
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 44920fb9f..9a5a5faf9 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -3,7 +3,7 @@
#include "compiler.h"
#include "type-layout.h"
-
+#include "syntax.h"
#include <assert.h>
// Don't signal errors for stuff we don't implement here,
@@ -18,7 +18,19 @@ using namespace Slang;
// Conversion routines to help with strongly-typed reflection API
+static inline Session* convert(SlangSession* session)
+{
+ return (Session*)session;
+}
+static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib)
+{
+ return (UserDefinedAttribute*)attrib;
+}
+static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib)
+{
+ return (SlangReflectionUserAttribute*)attrib;
+}
static inline Type* convert(SlangReflectionType* type)
{
return (Type*) type;
@@ -85,6 +97,101 @@ static inline SlangReflection* convert(ProgramLayout* program)
return (SlangReflection*) program;
}
+// user attaribute
+
+unsigned int getUserAttributeCount(Decl* decl)
+{
+ unsigned int count = 0;
+ for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>())
+ {
+ SLANG_UNUSED(x);
+ count++;
+ }
+ return count;
+}
+
+SlangReflectionUserAttribute* findUserAttributeByName(Session* session, Decl* decl, const char* name)
+{
+ auto nameObj = session->tryGetNameObj(name);
+ for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>())
+ {
+ if (x->name == nameObj)
+ return (SlangReflectionUserAttribute*)(x);
+ }
+ return nullptr;
+}
+
+SlangReflectionUserAttribute* getUserAttributeByIndex(Decl* decl, unsigned int index)
+{
+ unsigned int id = 0;
+ for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>())
+ {
+ if (id == index)
+ return convert(x);
+ id++;
+ }
+ return nullptr;
+}
+
+SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return nullptr;
+ return userAttr->getName()->text.Buffer();
+}
+SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return 0;
+ return (unsigned int)userAttr->args.Count();
+}
+SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return nullptr;
+ return convert(userAttr->args[index]->type.type.Ptr());
+}
+SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER;
+ if (index >= userAttr->args.Count()) return SLANG_ERROR_INVALID_PARAMETER;
+ RefPtr<RefObject> val;
+ if (userAttr->intArgVals.TryGetValue(index, val))
+ {
+ *rs = (int)val.As<ConstantIntVal>()->value;
+ return 0;
+ }
+ return SLANG_ERROR_INVALID_PARAMETER;
+}
+SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER;
+ if (index >= userAttr->args.Count()) return SLANG_ERROR_INVALID_PARAMETER;
+ if (auto cexpr = userAttr->args[index].As<FloatingPointLiteralExpr>())
+ {
+ *rs = (float)cexpr->value;
+ return 0;
+ }
+ return SLANG_ERROR_INVALID_PARAMETER;
+}
+SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t* bufLen)
+{
+ auto userAttr = convert(attrib);
+ if (!userAttr) return nullptr;
+ if (index >= userAttr->args.Count()) return nullptr;
+ if (auto cexpr = userAttr->args[index].As<StringLiteralExpr>())
+ {
+ if (bufLen)
+ *bufLen = cexpr->token.Content.size();
+ return cexpr->token.Content.begin();
+ }
+ return nullptr;
+}
+
+
+
// type Reflection
@@ -352,6 +459,37 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in
return SLANG_SCALAR_TYPE_NONE;
}
+SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* inType)
+{
+ auto type = convert(inType);
+ if (!type) return 0;
+ if (auto declRefType = type->AsDeclRefType())
+ {
+ return getUserAttributeCount(declRefType->declRef.getDecl());
+ }
+ return 0;
+}
+SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* inType, unsigned int index)
+{
+ auto type = convert(inType);
+ if (!type) return 0;
+ if (auto declRefType = type->AsDeclRefType())
+ {
+ return getUserAttributeByIndex(declRefType->declRef.getDecl(), index);
+ }
+ return 0;
+}
+SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* inType, char const* name)
+{
+ auto type = convert(inType);
+ if (!type) return 0;
+ if (auto declRefType = type->AsDeclRefType())
+ {
+ return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name);
+ }
+ return 0;
+}
+
SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType)
{
auto type = convert(inType);
@@ -757,6 +895,24 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec
return (SlangReflectionModifier*) modifier;
}
+SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* inVar)
+{
+ auto varDecl = convert(inVar);
+ if (!varDecl) return 0;
+ return getUserAttributeCount(varDecl);
+}
+SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* inVar, unsigned int index)
+{
+ auto varDecl = convert(inVar);
+ if (!varDecl) return 0;
+ return getUserAttributeByIndex(varDecl, index);
+}
+SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* inVar, SlangSession* session, char const* name)
+{
+ auto varDecl = convert(inVar);
+ if (!varDecl) return 0;
+ return findUserAttributeByName(convert(session), varDecl, name);
+}
// Variable Layout Reflection
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 74eda66a5..5db762f11 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1084,6 +1084,8 @@ namespace Slang
RequirementDictionary requirementDictionary;
};
+ typedef Dictionary<unsigned int, RefPtr<RefObject>> AttributeArgumentValueDict;
+
// Generate class definition for all syntax classes
#define SYNTAX_FIELD(TYPE, NAME) TYPE NAME;
#define FIELD(TYPE, NAME) TYPE NAME;
@@ -1343,6 +1345,15 @@ namespace Slang
RefPtr<Substitutions> outerSubst);
RefPtr<GenericSubstitution> findInnerMostGenericSubstitution(Substitutions* subst);
+
+ enum class UserDefinedAttributeTargets
+ {
+ None = 0,
+ Struct = 1,
+ Var = 2,
+ Function = 4,
+ All = 7
+ };
} // namespace Slang
#endif
diff --git a/tests/diagnostics/attribute-error.slang b/tests/diagnostics/attribute-error.slang
new file mode 100644
index 000000000..472593f3b
--- /dev/null
+++ b/tests/diagnostics/attribute-error.slang
@@ -0,0 +1,34 @@
+// attribute.slang
+
+// Tests reflection of user defined attributes.
+
+//TEST:REFLECTION:-stage compute -entry main -target hlsl
+
+[AttributeUsage(AttributeTargets.Struct)]
+struct MyStructAttribute
+{
+ int iParam;
+ float fParam;
+};
+[AttributeUsage(AttributeTargets.Var)]
+struct DefaultValueAttribute
+{
+ int iParam;
+};
+
+[MyStruct(0, "stringVal")] // attribute arg type mismatch
+struct A
+{
+ [MyStruct(0, 10.0)] // attribute does not apply to this construct
+ float x;
+ [DefaultValue(2.0)] // attribute arg type mismatch
+ float y;
+};
+
+ParameterBlock<A> param;
+
+[numthreads(1, 1, 1)]
+void main(
+ uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+} \ No newline at end of file
diff --git a/tests/diagnostics/attribute-error.slang.expected b/tests/diagnostics/attribute-error.slang.expected
new file mode 100644
index 000000000..e372eb957
--- /dev/null
+++ b/tests/diagnostics/attribute-error.slang.expected
@@ -0,0 +1,8 @@
+result code = 1
+standard error = {
+tests/diagnostics/attribute-error.slang(19): error 30019: expected an expression of type 'float', got 'String'
+tests/diagnostics/attribute-error.slang(22): error 31002: attribute 'MyStruct' is not valid here
+tests/diagnostics/attribute-error.slang(24): error 39999: expression does not evaluate to a compile-time constant
+}
+standard output = {
+}
diff --git a/tests/reflection/attribute.slang b/tests/reflection/attribute.slang
new file mode 100644
index 000000000..a3cda4f4b
--- /dev/null
+++ b/tests/reflection/attribute.slang
@@ -0,0 +1,42 @@
+// attribute.slang
+
+// Tests reflection of user defined attributes.
+
+//TEST:REFLECTION:-stage compute -entry main -target hlsl
+
+[AttributeUsage(AttributeTargets.Struct)]
+struct MyStructAttribute
+{
+ int iParam;
+ float fParam;
+};
+[AttributeUsage(AttributeTargets.Var)]
+struct DefaultValueAttribute
+{
+ int iParam;
+};
+
+[MyStruct(0, 1.0)]
+struct A
+{
+ float x;
+ [DefaultValue(1)]
+ float y;
+};
+
+[MyStruct(0, 2.0)]
+struct B
+{
+ float x;
+ [DefaultValue(1+1)]
+ float z;
+};
+
+ParameterBlock<A> param;
+ParameterBlock<B> param2;
+
+[numthreads(1, 1, 1)]
+void main(
+ uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+} \ No newline at end of file
diff --git a/tests/reflection/attribute.slang.expected b/tests/reflection/attribute.slang.expected
new file mode 100644
index 000000000..4348b898f
--- /dev/null
+++ b/tests/reflection/attribute.slang.expected
@@ -0,0 +1,118 @@
+result code = 0
+standard error = {
+}
+standard output = {
+{
+ "parameters": [
+ {
+ "name": "param",
+ "binding": {"kind": "constantBuffer", "index": 0},
+ "type": {
+ "kind": "parameterBlock",
+ "elementType": {
+ "kind": "struct",
+ "name": "A",
+ "fields": [
+ {
+ "name": "x",
+ "type": {
+ "kind": "scalar",
+ "scalarType": "float32"
+ },
+ "binding": {"kind": "uniform", "offset": 0, "size": 4}
+ },
+ {
+ "name": "y",
+ "type": {
+ "kind": "scalar",
+ "scalarType": "float32"
+ },
+ "binding": {"kind": "uniform", "offset": 4, "size": 4},
+ "userAttribs": [{
+ "name": "DefaultValue",
+ "arguments": [
+ 1
+ ]
+ }
+ ]
+ }
+ ],
+ "userAttribs": [{
+ "name": "MyStruct",
+ "arguments": [
+ 0,
+ 1.000000
+ ]
+ }
+ ]
+ }
+ }
+ },
+ {
+ "name": "param2",
+ "binding": {"kind": "constantBuffer", "index": 1},
+ "type": {
+ "kind": "parameterBlock",
+ "elementType": {
+ "kind": "struct",
+ "name": "B",
+ "fields": [
+ {
+ "name": "x",
+ "type": {
+ "kind": "scalar",
+ "scalarType": "float32"
+ },
+ "binding": {"kind": "uniform", "offset": 0, "size": 4}
+ },
+ {
+ "name": "z",
+ "type": {
+ "kind": "scalar",
+ "scalarType": "float32"
+ },
+ "binding": {"kind": "uniform", "offset": 4, "size": 4},
+ "userAttribs": [{
+ "name": "DefaultValue",
+ "arguments": [
+ 2
+ ]
+ }
+ ]
+ }
+ ],
+ "userAttribs": [{
+ "name": "MyStruct",
+ "arguments": [
+ 0,
+ 2.000000
+ ]
+ }
+ ]
+ }
+ }
+ }
+ ],
+ "entryPoints": [
+ {
+ "name": "main",
+ "stage:": "compute",
+ "parameters": [
+ {
+ "name": "dispatchThreadID",
+ "semanticName": "SV_DISPATCHTHREADID",
+ "type": {
+ "kind": "vector",
+ "elementCount": 3,
+ "elementType": {
+ "kind": "scalar",
+ "scalarType": "uint32"
+ }
+ }
+ }
+ ],
+ "threadGroupSize": [1, 1, 1]
+ }
+ ]
+}
+}
diff --git a/tools/slang-reflection-test/slang-reflection-test-main.cpp b/tools/slang-reflection-test/slang-reflection-test-main.cpp
index 209528927..636b16a1b 100644
--- a/tools/slang-reflection-test/slang-reflection-test-main.cpp
+++ b/tools/slang-reflection-test/slang-reflection-test-main.cpp
@@ -55,7 +55,7 @@ static void dedent(PrettyWriter& writer)
writer.indent--;
}
-static void write(PrettyWriter& writer, char const* text)
+static void write(PrettyWriter& writer, char const* text, size_t length = 0)
{
// TODO: can do this more efficiently...
char const* cursor = text;
@@ -63,7 +63,7 @@ static void write(PrettyWriter& writer, char const* text)
{
char c = *cursor++;
if (!c) break;
-
+ if (length && cursor - text == length) break;
if (c == '\n')
{
writer.startOfLine = true;
@@ -83,6 +83,18 @@ static void write(PrettyWriter& writer, SlangUInt val)
Slang::StdWriters::getOut().print("%llu", (unsigned long long)val);
}
+static void write(PrettyWriter& writer, int val)
+{
+ adjust(writer);
+ Slang::StdWriters::getOut().print("%d", val);
+}
+
+static void write(PrettyWriter& writer, float val)
+{
+ adjust(writer);
+ Slang::StdWriters::getOut().print("%f", val);
+}
+
static void emitReflectionVarInfoJSON(PrettyWriter& writer, slang::VariableReflection* var);
static void emitReflectionTypeLayoutJSON(PrettyWriter& writer, slang::TypeLayoutReflection* type);
static void emitReflectionTypeJSON(PrettyWriter& writer, slang::TypeReflection* type);
@@ -261,6 +273,76 @@ static void emitReflectionModifierInfoJSON(
}
}
+static void emitUserAttributeJSON(PrettyWriter& writer, slang::UserAttribute* userAttribute)
+{
+ write(writer, "{\n");
+ indent(writer);
+ write(writer, "\"name\": \"");
+ write(writer, userAttribute->getName());
+ write(writer, "\",\n");
+ write(writer, "\"arguments\": [\n");
+ indent(writer);
+ for (unsigned int i = 0; i < userAttribute->getArgumentCount(); i++)
+ {
+ int intVal;
+ float floatVal;
+ size_t bufSize = 0;
+ if (i > 0)
+ write(writer, ",\n");
+ if (SLANG_SUCCEEDED(userAttribute->getArgumentValueInt(i, &intVal)))
+ {
+ write(writer, intVal);
+ }
+ else if (SLANG_SUCCEEDED(userAttribute->getArgumentValueFloat(i, &floatVal)))
+ {
+ write(writer, floatVal);
+ }
+ else if (auto str = userAttribute->getArgumentValueString(i, &bufSize))
+ {
+ write(writer, str, bufSize);
+ }
+ else
+ write(writer, "\"invalid value\"");
+ }
+ dedent(writer);
+ write(writer, "\n]\n");
+ dedent(writer);
+ write(writer, "}\n");
+}
+
+static void emitUserAttributes(PrettyWriter& writer, slang::TypeReflection* type)
+{
+ auto attribCount = type->getUserAttributeCount();
+ if (attribCount)
+ {
+ write(writer, ",\n\"userAttribs\": [");
+ for (unsigned int i = 0; i < attribCount; i++)
+ {
+ if (i > 0)
+ write(writer, ",\n");
+ auto attrib = type->getUserAttributeByIndex(i);
+ emitUserAttributeJSON(writer, attrib);
+ }
+ write(writer, "]");
+ }
+}
+static void emitUserAttributes(PrettyWriter& writer, slang::VariableReflection* var)
+{
+ auto attribCount = var->getUserAttributeCount();
+ if (attribCount)
+ {
+ write(writer, ",\n\"userAttribs\": [");
+ for (unsigned int i = 0; i < attribCount; i++)
+ {
+ if (i > 0)
+ write(writer, ",\n");
+ auto attrib = var->getUserAttributeByIndex(i);
+ emitUserAttributeJSON(writer, attrib);
+ }
+ write(writer, "]");
+ }
+}
+
static void emitReflectionVarLayoutJSON(
PrettyWriter& writer,
slang::VariableLayoutReflection* var)
@@ -278,6 +360,7 @@ static void emitReflectionVarLayoutJSON(
emitReflectionVarBindingInfoJSON(writer, var);
+ emitUserAttributes(writer, var->getVariable());
dedent(writer);
write(writer, "\n}");
}
@@ -368,6 +451,7 @@ static void emitReflectionResourceTypeBaseInfoJSON(
}
}
+
static void emitReflectionTypeInfoJSON(
PrettyWriter& writer,
slang::TypeReflection* type)
@@ -467,10 +551,10 @@ static void emitReflectionTypeInfoJSON(
write(writer, "\"kind\": \"matrix\"");
write(writer, ",\n");
write(writer, "\"rowCount\": ");
- write(writer, type->getRowCount());
+ write(writer, (SlangUInt)type->getRowCount());
write(writer, ",\n");
write(writer, "\"columnCount\": ");
- write(writer, type->getColumnCount());
+ write(writer, (SlangUInt)type->getColumnCount());
write(writer, ",\n");
write(writer, "\"elementType\": ");
emitReflectionTypeJSON(
@@ -523,8 +607,10 @@ static void emitReflectionTypeInfoJSON(
assert(!"unhandled case");
break;
}
+ emitUserAttributes(writer, type);
}
+
static void emitReflectionTypeLayoutInfoJSON(
PrettyWriter& writer,
slang::TypeLayoutReflection* typeLayout)
@@ -580,6 +666,8 @@ static void emitReflectionTypeLayoutInfoJSON(
}
dedent(writer);
write(writer, "\n]");
+ emitUserAttributes(writer, structTypeLayout->getType());
+
}
break;