diff options
| author | Yong He <yonghe@google.com> | 2019-01-29 11:41:54 -0800 |
|---|---|---|
| committer | Yong He <yonghe@google.com> | 2019-01-29 11:41:54 -0800 |
| commit | b7f8f7abcc3cc1dfa820ebba47a772b78d6a4cfb (patch) | |
| tree | 26d81dec1162ee9d26b811f0b7621e74ade9e06f | |
| parent | f8b8ea0055ad877551198e1e295d33860b504672 (diff) | |
Add support for user defined attributes.
| -rw-r--r-- | .gitignore | 1 | ||||
| -rw-r--r-- | slang.h | 76 | ||||
| -rw-r--r-- | source/slang/check.cpp | 177 | ||||
| -rw-r--r-- | source/slang/compiler.h | 1 | ||||
| -rw-r--r-- | source/slang/core.meta.slang | 9 | ||||
| -rw-r--r-- | source/slang/core.meta.slang.h | 12 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang.h | 1 | ||||
| -rw-r--r-- | source/slang/modifier-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/name.cpp | 8 | ||||
| -rw-r--r-- | source/slang/name.h | 4 | ||||
| -rw-r--r-- | source/slang/reflection.cpp | 158 | ||||
| -rw-r--r-- | source/slang/syntax.h | 11 | ||||
| -rw-r--r-- | tests/diagnostics/attribute-error.slang | 34 | ||||
| -rw-r--r-- | tests/diagnostics/attribute-error.slang.expected | 8 | ||||
| -rw-r--r-- | tests/reflection/attribute.slang | 42 | ||||
| -rw-r--r-- | tests/reflection/attribute.slang.expected | 118 | ||||
| -rw-r--r-- | tools/slang-reflection-test/slang-reflection-test-main.cpp | 96 |
18 files changed, 747 insertions, 22 deletions
diff --git a/.gitignore b/.gitignore index 206dfca5e..b92a66e6e 100644 --- a/.gitignore +++ b/.gitignore @@ -24,3 +24,4 @@ build.*/ # Files generated by other shader compilers *.spv +/source/slang/core.meta.slang.temp.h @@ -1357,6 +1357,7 @@ extern "C" typedef struct SlangReflectionVariable SlangReflectionVariable; typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout; typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; + typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; // get reflection data from a compilation request SLANG_API SlangReflection* spGetReflection( @@ -1492,9 +1493,26 @@ extern "C" SLANG_MODIFIER_SHARED, }; + // User Attribute + SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib); + SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib); + SLANG_API SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index); + SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs); + SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs); + + /** Returns the string-typed value of a user attribute argument + The string returned is not null-terminated. The length of the string is returned via `outSize`. + If index of out of range, or if the specified argument is not a string, the function will return nullptr. + */ + SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t * outSize); + // Type Reflection SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* type); + SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type); + SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index); + SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name); + SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type); SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index); @@ -1548,8 +1566,10 @@ extern "C" SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* var); SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVariable* var); - SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* var, SlangModifierID modifierID); + SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* var); + SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index); + SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * session, char const* name); // Variable Layout Reflection @@ -1641,6 +1661,34 @@ namespace slang struct TypeReflection; struct VariableLayoutReflection; struct VariableReflection; + + struct UserAttribute + { + char const* getName() + { + return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this); + } + uint32_t getArgumentCount() + { + return (uint32_t)spReflectionUserAttribute_GetArgumentCount((SlangReflectionUserAttribute*)this); + } + TypeReflection* getArgumentType(uint32_t index) + { + return (TypeReflection*)spReflectionUserAttribute_GetArgumentType((SlangReflectionUserAttribute*)this, index); + } + SlangResult getArgumentValueInt(uint32_t index, int * value) + { + return spReflectionUserAttribute_GetArgumentValueInt((SlangReflectionUserAttribute*)this, index, value); + } + SlangResult getArgumentValueFloat(uint32_t index, float * value) + { + return spReflectionUserAttribute_GetArgumentValueFloat((SlangReflectionUserAttribute*)this, index, value); + } + const char* getArgumentValueString(uint32_t index, size_t * outSize) + { + return spReflectionUserAttribute_GetArgumentValueString((SlangReflectionUserAttribute*)this, index, outSize); + } + }; struct TypeReflection { @@ -1764,6 +1812,19 @@ namespace slang { return spReflectionType_GetName((SlangReflectionType*) this); } + + unsigned int getUserAttributeCount() + { + return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this); + } + UserAttribute* getUserAttributeByIndex(unsigned int index) + { + return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index); + } + UserAttribute* findUserAttributeByName(char const* name) + { + return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name); + } }; enum ParameterCategory : SlangParameterCategory @@ -1944,6 +2005,19 @@ namespace slang { return (Modifier*) spReflectionVariable_FindModifier((SlangReflectionVariable*) this, (SlangModifierID) id); } + + unsigned int getUserAttributeCount() + { + return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); + } + UserAttribute* getUserAttributeByIndex(unsigned int index) + { + return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index); + } + UserAttribute* findUserAttributeByName(SlangSession* session, char const* name) + { + return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, session, name); + } }; struct VariableLayoutReflection diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 42bf396fa..e100cbd1d 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -2389,20 +2389,75 @@ namespace Slang // rules to keep us from seeing shadowing variable declarations. auto lookupResult = lookUp(getSession(), this, attributeName, scope, LookupMask::Attribute); - // If we didn't find anything, or the result was overloaded, + // If the result was overloaded, // then we aren't going to be able to extract a single decl. - if(!lookupResult.isValid() || lookupResult.isOverloaded()) + if(lookupResult.isOverloaded()) return nullptr; - auto decl = lookupResult.item.declRef.getDecl(); - if( auto attributeDecl = dynamic_cast<AttributeDecl*>(decl) ) + if (lookupResult.isValid()) { - return attributeDecl; + auto decl = lookupResult.item.declRef.getDecl(); + if (auto attributeDecl = dynamic_cast<AttributeDecl*>(decl)) + { + return attributeDecl; + } + else + { + return nullptr; + } } - else + + // If we couldn't find a system attribute, try looking up as a user defined attribute + // A user defined attribute class is defined as a struct type with a "UserDefinedAttributeAttribute" modifier + lookupResult = lookUp(getSession(), this, getSession()->getNameObj(attributeName->text + "Attribute"), scope, LookupMask::type); + if (lookupResult.isOverloaded()) { - return nullptr; + // see if we have already created an AttributeDecl for this attribute struct + for (auto alt : lookupResult.items) + { + if (auto adecl = alt.declRef.As<AttributeDecl>()) + return adecl.getDecl(); + } } + // If we still cannot find any thing, quit + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return nullptr; + // Now construct an AttributeDecl for this user defined attribute class for future lookup + auto userDefAttribAttrib = lookupResult.item.declRef.decl->FindModifier<AttributeUsageAttribute>(); + if (!userDefAttribAttrib) + return nullptr; + // create an AttributeDecl for the user defined attribute + auto structAttribDef = lookupResult.item.declRef.As<StructDecl>().getDecl(); + RefPtr<AttributeDecl> attribDecl = new AttributeDecl(); + attribDecl->nameAndLoc = structAttribDef->nameAndLoc; + attribDecl->loc = structAttribDef->loc; + attribDecl->nextInContainerWithSameName = structAttribDef->nextInContainerWithSameName; + // create a __attributeTarget modifier for the attribute class definition + RefPtr<AttributeTargetModifier> targetModifier = new AttributeTargetModifier(); + targetModifier->syntaxClass = userDefAttribAttrib->targetSyntaxClass; + targetModifier->loc = structAttribDef->loc; + targetModifier->next = attribDecl->modifiers.first; + attribDecl->modifiers.first = targetModifier; + structAttribDef->nextInContainerWithSameName = attribDecl.Ptr(); + // we should create UserDefinedAttribute nodes for all user defined attribute instances + attribDecl->syntaxClass = getSession()->findSyntaxClass(getSession()->getNameObj("UserDefinedAttribute")); + for (auto member : structAttribDef->Members) + { + if (auto varMember = member.As<VarDecl>()) + { + RefPtr<ParamDecl> param = new ParamDecl(); + param->nameAndLoc = member->nameAndLoc; + param->type = varMember->type; + param->loc = member->loc; + attribDecl->Members.Add(param); + } + } + // add the attribute class definition to the syntax tree, so it can be found + structAttribDef->ParentDecl->Members.Add(attribDecl.Ptr()); + structAttribDef->ParentDecl->memberDictionaryIsValid = false; + // do necessary checks on this newly constructed node + checkDecl(attribDecl.Ptr()); + return attribDecl.Ptr(); } bool hasIntArgs(Attribute* attr, int numArgs) @@ -2437,7 +2492,27 @@ namespace Slang return true; } - bool validateAttribute(RefPtr<Attribute> attr) + bool getAttributeTargetSyntaxClasses(SyntaxClass<RefObject> & cls, uint32_t typeFlags) + { + if (typeFlags == (int)UserDefinedAttributeTargets::Struct) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("StructDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Var) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("VarDecl")); + return true; + } + if (typeFlags == (int)UserDefinedAttributeTargets::Function) + { + cls = getSession()->findSyntaxClass(getSession()->getNameObj("FuncDecl")); + return true; + } + return false; + } + + bool validateAttribute(RefPtr<Attribute> attr, AttributeDecl* attribClassDecl) { if(auto numThreadsAttr = attr.As<NumThreadsAttribute>()) { @@ -2529,6 +2604,67 @@ namespace Slang // Has no args SLANG_ASSERT(attr->args.Count() == 0); } + else if (auto attrUsageAttr = attr.As<AttributeUsageAttribute>()) + { + uint32_t targetClassId = (uint32_t)UserDefinedAttributeTargets::None; + if (attr->args.Count() == 1) + { + RefPtr<IntVal> outIntVal; + if (auto cInt = checkConstantIntVal(attr->args[0])) + { + targetClassId = (uint32_t)(cInt->value); + } + else + { + getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->name); + return false; + } + } + if (!getAttributeTargetSyntaxClasses(attrUsageAttr->targetSyntaxClass, targetClassId)) + { + getSink()->diagnose(attr, Diagnostics::invalidAttributeTarget); + return false; + } + } + else if (auto userDefAttr = attr.As<UserDefinedAttribute>()) + { + // check arguments against attribute parameters defined in attribClassDecl + uint32_t paramIndex = 0; + auto params = attribClassDecl->getMembersOfType<ParamDecl>(); + for (auto paramDecl : params) + { + if (paramIndex < attr->args.Count()) + { + auto & arg = attr->args[paramIndex]; + bool typeChecked = false; + if (auto basicType = paramDecl->getType()->AsBasicType()) + { + if (basicType->baseType == BaseType::Int) + { + if (auto cint = checkConstantIntVal(arg)) + { + attr->intArgVals[(uint32_t)paramIndex] = cint; + } + typeChecked = true; + } + } + if (!typeChecked) + { + arg = CheckExpr(arg); + arg = Coerce(paramDecl->getType(), arg); + } + } + paramIndex++; + } + if (params.Count() < attr->args.Count()) + { + getSink()->diagnose(attr, Diagnostics::tooManyArguments, attr->args.Count(), params.Count()); + } + else if (params.Count() > attr->args.Count()) + { + getSink()->diagnose(attr, Diagnostics::notEnoughArguments, attr->args.Count(), params.Count()); + } + } else { if(attr->args.Count() == 0) @@ -2597,11 +2733,9 @@ namespace Slang UInt paramIndex = paramCounter++; if( paramIndex < argCount ) { - auto arg = attr->args[paramIndex]; - // TODO: support checking the argument against the declared // type for the parameter. - + } else { @@ -2653,7 +2787,7 @@ namespace Slang } // Now apply type-specific validation to the attribute. - if(!validateAttribute(attr)) + if(!validateAttribute(attr, attrDecl)) { return uncheckedAttr; } @@ -2778,6 +2912,24 @@ namespace Slang registerExtension(extDecl); } } + // check user defined attribute classes first + for (auto decl : programNode->Members) + { + if (auto typeMember = decl->As<StructDecl>()) + { + bool isTypeAttributeClass = false; + for (auto attrib : typeMember->GetModifiersOfType<UncheckedAttribute>()) + { + if (attrib->name == getSession()->getNameObj("AttributeUsageAttribute")) + { + isTypeAttributeClass = true; + break; + } + } + if (isTypeAttributeClass) + checkDecl(decl); + } + } // check types for (auto & s : programNode->getMembersOfType<TypeDefDecl>()) checkDecl(s.Ptr()); @@ -8588,6 +8740,7 @@ namespace Slang RefPtr<Expr> visitStaticMemberExpr(StaticMemberExpr* /*expr*/) { + // StaticMemberExpr means it is already checked SLANG_UNEXPECTED("should not occur in unchecked AST"); UNREACHABLE_RETURN(nullptr); } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index ca82950be..9ca4bbc70 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -587,6 +587,7 @@ namespace Slang RootNamePool* getRootNamePool() { return &rootNamePool; } NamePool* getNamePool() { return &namePool; } Name* getNameObj(String name) { return namePool.getName(name); } + Name* tryGetNameObj(String name) { return namePool.tryGetName(name); } // // Generated code for stdlib, etc. diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index e6bad3f50..50a4cbf96 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -1286,3 +1286,12 @@ attribute_syntax [mutating] : MutatingAttribute; /// This is equivalent to the LLVM `readnone` function attribute. __attributeTarget(FunctionDeclBase) attribute_syntax [__readNone] : ReadNoneAttribute; + +enum AttributeTargets +{ + Struct = $((int) UserDefinedAttributeTargets::Struct), + Var = $((int) UserDefinedAttributeTargets::Var), + Function = $((int) UserDefinedAttributeTargets::Function), +}; +__attributeTarget(StructDecl) +attribute_syntax [AttributeUsage(target : AttributeTargets)] : AttributeUsageAttribute;
\ No newline at end of file diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h index a29fc24ce..746e75bd9 100644 --- a/source/slang/core.meta.slang.h +++ b/source/slang/core.meta.slang.h @@ -1304,3 +1304,15 @@ SLANG_RAW(" ///\n") SLANG_RAW(" /// This is equivalent to the LLVM `readnone` function attribute.\n") SLANG_RAW("__attributeTarget(FunctionDeclBase)\n") SLANG_RAW("attribute_syntax [__readNone] : ReadNoneAttribute;\n") +SLANG_RAW("\n") +SLANG_RAW("enum AttributeTargets\n") +SLANG_RAW("{\n") + + sb << "Struct = " << (int)UserDefinedAttributeTargets::Struct << ", "; + sb << "Var = " << (int)UserDefinedAttributeTargets::Var << ", "; + sb << "Function = " << (int)UserDefinedAttributeTargets::Function; + +SLANG_RAW("\n") +SLANG_RAW("};\n") +SLANG_RAW("__attributeTarget(StructDecl)\n") +SLANG_RAW("attribute_syntax [AttributeUsage(target : AttributeTargets)] : AttributeUsageAttribute;") diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 2c4f00dd6..68790d4ab 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -260,9 +260,8 @@ DIAGNOSTIC(31005, Error, expectedSingleStringArg, "attribute '$0' expects a sing DIAGNOSTIC(31006, Error, attributeFunctionNotFound, "Could not find function '$0' for attribute'$1'") -DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") - - +DIAGNOSTIC(31100, Error, unknownStageName, "unknown stage name '$0'") +DIAGNOSTIC(31120, Error, invalidAttributeTarget, "invalid syntax target for user defined attribute") // Enums diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index 4dc75ab78..ef1a6d273 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -418,6 +418,7 @@ SLANG_RAW("\n") SLANG_RAW("double asdouble(uint lowbits, uint highbits);\n") SLANG_RAW("\n") SLANG_RAW("// Reinterpret bits as a float (HLSL SM 4.0)\n") +SLANG_RAW("\n") SLANG_RAW("__target_intrinsic(glsl, \"intBitsToFloat\")\n") SLANG_RAW("float asfloat(int x);\n") SLANG_RAW("__target_intrinsic(glsl, \"uintBitsToFloat\")\n") diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index 2276f37f8..164621620 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -315,6 +315,14 @@ END_SYNTAX_CLASS() // A `[name(arg0, ...)]` style attribute that has been validated. SYNTAX_CLASS(Attribute, AttributeBase) + FIELD(AttributeArgumentValueDict, intArgVals) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(UserDefinedAttribute, Attribute) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(AttributeUsageAttribute, Attribute) + FIELD(SyntaxClass<RefObject>, targetSyntaxClass) END_SYNTAX_CLASS() // An `[unroll]` or `[unroll(count)]` attribute diff --git a/source/slang/name.cpp b/source/slang/name.cpp index a5cbb6541..21586e0b6 100644 --- a/source/slang/name.cpp +++ b/source/slang/name.cpp @@ -26,4 +26,12 @@ Name* NamePool::getName(String const& text) return name; } +Name* NamePool::tryGetName(String const& text) +{ + RefPtr<Name> name; + if (rootPool->names.TryGetValue(text, name)) + return name; + return nullptr; +} + } // namespace Slang diff --git a/source/slang/name.h b/source/slang/name.h index e1a055e60..a144fbb84 100644 --- a/source/slang/name.h +++ b/source/slang/name.h @@ -66,7 +66,9 @@ struct NamePool { // Find or create the `Name` that represents the given `text`. Name* getName(String const& text); - + // Try find the `Name` that represents the given `text`. + // If the name does not exist, return nullptr + Name* tryGetName(String const& text); // Set the parent name pool to use for lookup void setRootNamePool(RootNamePool* rootNamePool) { diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 44920fb9f..9a5a5faf9 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -3,7 +3,7 @@ #include "compiler.h" #include "type-layout.h" - +#include "syntax.h" #include <assert.h> // Don't signal errors for stuff we don't implement here, @@ -18,7 +18,19 @@ using namespace Slang; // Conversion routines to help with strongly-typed reflection API +static inline Session* convert(SlangSession* session) +{ + return (Session*)session; +} +static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib) +{ + return (UserDefinedAttribute*)attrib; +} +static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib) +{ + return (SlangReflectionUserAttribute*)attrib; +} static inline Type* convert(SlangReflectionType* type) { return (Type*) type; @@ -85,6 +97,101 @@ static inline SlangReflection* convert(ProgramLayout* program) return (SlangReflection*) program; } +// user attaribute + +unsigned int getUserAttributeCount(Decl* decl) +{ + unsigned int count = 0; + for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>()) + { + SLANG_UNUSED(x); + count++; + } + return count; +} + +SlangReflectionUserAttribute* findUserAttributeByName(Session* session, Decl* decl, const char* name) +{ + auto nameObj = session->tryGetNameObj(name); + for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>()) + { + if (x->name == nameObj) + return (SlangReflectionUserAttribute*)(x); + } + return nullptr; +} + +SlangReflectionUserAttribute* getUserAttributeByIndex(Decl* decl, unsigned int index) +{ + unsigned int id = 0; + for (auto x : decl->GetModifiersOfType<UserDefinedAttribute>()) + { + if (id == index) + return convert(x); + id++; + } + return nullptr; +} + +SLANG_API char const* spReflectionUserAttribute_GetName(SlangReflectionUserAttribute* attrib) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + return userAttr->getName()->text.Buffer(); +} +SLANG_API unsigned int spReflectionUserAttribute_GetArgumentCount(SlangReflectionUserAttribute* attrib) +{ + auto userAttr = convert(attrib); + if (!userAttr) return 0; + return (unsigned int)userAttr->args.Count(); +} +SlangReflectionType* spReflectionUserAttribute_GetArgumentType(SlangReflectionUserAttribute* attrib, unsigned int index) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + return convert(userAttr->args[index]->type.type.Ptr()); +} +SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueInt(SlangReflectionUserAttribute* attrib, unsigned int index, int * rs) +{ + auto userAttr = convert(attrib); + if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; + if (index >= userAttr->args.Count()) return SLANG_ERROR_INVALID_PARAMETER; + RefPtr<RefObject> val; + if (userAttr->intArgVals.TryGetValue(index, val)) + { + *rs = (int)val.As<ConstantIntVal>()->value; + return 0; + } + return SLANG_ERROR_INVALID_PARAMETER; +} +SLANG_API SlangResult spReflectionUserAttribute_GetArgumentValueFloat(SlangReflectionUserAttribute* attrib, unsigned int index, float * rs) +{ + auto userAttr = convert(attrib); + if (!userAttr) return SLANG_ERROR_INVALID_PARAMETER; + if (index >= userAttr->args.Count()) return SLANG_ERROR_INVALID_PARAMETER; + if (auto cexpr = userAttr->args[index].As<FloatingPointLiteralExpr>()) + { + *rs = (float)cexpr->value; + return 0; + } + return SLANG_ERROR_INVALID_PARAMETER; +} +SLANG_API const char* spReflectionUserAttribute_GetArgumentValueString(SlangReflectionUserAttribute* attrib, unsigned int index, size_t* bufLen) +{ + auto userAttr = convert(attrib); + if (!userAttr) return nullptr; + if (index >= userAttr->args.Count()) return nullptr; + if (auto cexpr = userAttr->args[index].As<StringLiteralExpr>()) + { + if (bufLen) + *bufLen = cexpr->token.Content.size(); + return cexpr->token.Content.begin(); + } + return nullptr; +} + + + // type Reflection @@ -352,6 +459,37 @@ SLANG_API SlangScalarType spReflectionType_GetScalarType(SlangReflectionType* in return SLANG_SCALAR_TYPE_NONE; } +SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* inType) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = type->AsDeclRefType()) + { + return getUserAttributeCount(declRefType->declRef.getDecl()); + } + return 0; +} +SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* inType, unsigned int index) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = type->AsDeclRefType()) + { + return getUserAttributeByIndex(declRefType->declRef.getDecl(), index); + } + return 0; +} +SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* inType, char const* name) +{ + auto type = convert(inType); + if (!type) return 0; + if (auto declRefType = type->AsDeclRefType()) + { + return findUserAttributeByName(declRefType->getSession(), declRefType->declRef.getDecl(), name); + } + return 0; +} + SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType) { auto type = convert(inType); @@ -757,6 +895,24 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec return (SlangReflectionModifier*) modifier; } +SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* inVar) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return getUserAttributeCount(varDecl); +} +SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* inVar, unsigned int index) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return getUserAttributeByIndex(varDecl, index); +} +SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* inVar, SlangSession* session, char const* name) +{ + auto varDecl = convert(inVar); + if (!varDecl) return 0; + return findUserAttributeByName(convert(session), varDecl, name); +} // Variable Layout Reflection diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 74eda66a5..5db762f11 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1084,6 +1084,8 @@ namespace Slang RequirementDictionary requirementDictionary; }; + typedef Dictionary<unsigned int, RefPtr<RefObject>> AttributeArgumentValueDict; + // Generate class definition for all syntax classes #define SYNTAX_FIELD(TYPE, NAME) TYPE NAME; #define FIELD(TYPE, NAME) TYPE NAME; @@ -1343,6 +1345,15 @@ namespace Slang RefPtr<Substitutions> outerSubst); RefPtr<GenericSubstitution> findInnerMostGenericSubstitution(Substitutions* subst); + + enum class UserDefinedAttributeTargets + { + None = 0, + Struct = 1, + Var = 2, + Function = 4, + All = 7 + }; } // namespace Slang #endif diff --git a/tests/diagnostics/attribute-error.slang b/tests/diagnostics/attribute-error.slang new file mode 100644 index 000000000..472593f3b --- /dev/null +++ b/tests/diagnostics/attribute-error.slang @@ -0,0 +1,34 @@ +// attribute.slang + +// Tests reflection of user defined attributes. + +//TEST:REFLECTION:-stage compute -entry main -target hlsl + +[AttributeUsage(AttributeTargets.Struct)] +struct MyStructAttribute +{ + int iParam; + float fParam; +}; +[AttributeUsage(AttributeTargets.Var)] +struct DefaultValueAttribute +{ + int iParam; +}; + +[MyStruct(0, "stringVal")] // attribute arg type mismatch +struct A +{ + [MyStruct(0, 10.0)] // attribute does not apply to this construct + float x; + [DefaultValue(2.0)] // attribute arg type mismatch + float y; +}; + +ParameterBlock<A> param; + +[numthreads(1, 1, 1)] +void main( + uint3 dispatchThreadID : SV_DispatchThreadID) +{ +}
\ No newline at end of file diff --git a/tests/diagnostics/attribute-error.slang.expected b/tests/diagnostics/attribute-error.slang.expected new file mode 100644 index 000000000..e372eb957 --- /dev/null +++ b/tests/diagnostics/attribute-error.slang.expected @@ -0,0 +1,8 @@ +result code = 1 +standard error = { +tests/diagnostics/attribute-error.slang(19): error 30019: expected an expression of type 'float', got 'String' +tests/diagnostics/attribute-error.slang(22): error 31002: attribute 'MyStruct' is not valid here +tests/diagnostics/attribute-error.slang(24): error 39999: expression does not evaluate to a compile-time constant +} +standard output = { +} diff --git a/tests/reflection/attribute.slang b/tests/reflection/attribute.slang new file mode 100644 index 000000000..a3cda4f4b --- /dev/null +++ b/tests/reflection/attribute.slang @@ -0,0 +1,42 @@ +// attribute.slang + +// Tests reflection of user defined attributes. + +//TEST:REFLECTION:-stage compute -entry main -target hlsl + +[AttributeUsage(AttributeTargets.Struct)] +struct MyStructAttribute +{ + int iParam; + float fParam; +}; +[AttributeUsage(AttributeTargets.Var)] +struct DefaultValueAttribute +{ + int iParam; +}; + +[MyStruct(0, 1.0)] +struct A +{ + float x; + [DefaultValue(1)] + float y; +}; + +[MyStruct(0, 2.0)] +struct B +{ + float x; + [DefaultValue(1+1)] + float z; +}; + +ParameterBlock<A> param; +ParameterBlock<B> param2; + +[numthreads(1, 1, 1)] +void main( + uint3 dispatchThreadID : SV_DispatchThreadID) +{ +}
\ No newline at end of file diff --git a/tests/reflection/attribute.slang.expected b/tests/reflection/attribute.slang.expected new file mode 100644 index 000000000..4348b898f --- /dev/null +++ b/tests/reflection/attribute.slang.expected @@ -0,0 +1,118 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "param", + "binding": {"kind": "constantBuffer", "index": 0}, + "type": { + "kind": "parameterBlock", + "elementType": { + "kind": "struct", + "name": "A", + "fields": [ + { + "name": "x", + "type": { + "kind": "scalar", + "scalarType": "float32" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 4} + }, + { + "name": "y", + "type": { + "kind": "scalar", + "scalarType": "float32" + }, + "binding": {"kind": "uniform", "offset": 4, "size": 4}, + "userAttribs": [{ + "name": "DefaultValue", + "arguments": [ + 1 + ] + } + ] + } + ], + "userAttribs": [{ + "name": "MyStruct", + "arguments": [ + 0, + 1.000000 + ] + } + ] + } + } + }, + { + "name": "param2", + "binding": {"kind": "constantBuffer", "index": 1}, + "type": { + "kind": "parameterBlock", + "elementType": { + "kind": "struct", + "name": "B", + "fields": [ + { + "name": "x", + "type": { + "kind": "scalar", + "scalarType": "float32" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 4} + }, + { + "name": "z", + "type": { + "kind": "scalar", + "scalarType": "float32" + }, + "binding": {"kind": "uniform", "offset": 4, "size": 4}, + "userAttribs": [{ + "name": "DefaultValue", + "arguments": [ + 2 + ] + } + ] + } + ], + "userAttribs": [{ + "name": "MyStruct", + "arguments": [ + 0, + 2.000000 + ] + } + ] + } + } + } + ], + "entryPoints": [ + { + "name": "main", + "stage:": "compute", + "parameters": [ + { + "name": "dispatchThreadID", + "semanticName": "SV_DISPATCHTHREADID", + "type": { + "kind": "vector", + "elementCount": 3, + "elementType": { + "kind": "scalar", + "scalarType": "uint32" + } + } + } + ], + "threadGroupSize": [1, 1, 1] + } + ] +} +} diff --git a/tools/slang-reflection-test/slang-reflection-test-main.cpp b/tools/slang-reflection-test/slang-reflection-test-main.cpp index 209528927..636b16a1b 100644 --- a/tools/slang-reflection-test/slang-reflection-test-main.cpp +++ b/tools/slang-reflection-test/slang-reflection-test-main.cpp @@ -55,7 +55,7 @@ static void dedent(PrettyWriter& writer) writer.indent--; } -static void write(PrettyWriter& writer, char const* text) +static void write(PrettyWriter& writer, char const* text, size_t length = 0) { // TODO: can do this more efficiently... char const* cursor = text; @@ -63,7 +63,7 @@ static void write(PrettyWriter& writer, char const* text) { char c = *cursor++; if (!c) break; - + if (length && cursor - text == length) break; if (c == '\n') { writer.startOfLine = true; @@ -83,6 +83,18 @@ static void write(PrettyWriter& writer, SlangUInt val) Slang::StdWriters::getOut().print("%llu", (unsigned long long)val); } +static void write(PrettyWriter& writer, int val) +{ + adjust(writer); + Slang::StdWriters::getOut().print("%d", val); +} + +static void write(PrettyWriter& writer, float val) +{ + adjust(writer); + Slang::StdWriters::getOut().print("%f", val); +} + static void emitReflectionVarInfoJSON(PrettyWriter& writer, slang::VariableReflection* var); static void emitReflectionTypeLayoutJSON(PrettyWriter& writer, slang::TypeLayoutReflection* type); static void emitReflectionTypeJSON(PrettyWriter& writer, slang::TypeReflection* type); @@ -261,6 +273,76 @@ static void emitReflectionModifierInfoJSON( } } +static void emitUserAttributeJSON(PrettyWriter& writer, slang::UserAttribute* userAttribute) +{ + write(writer, "{\n"); + indent(writer); + write(writer, "\"name\": \""); + write(writer, userAttribute->getName()); + write(writer, "\",\n"); + write(writer, "\"arguments\": [\n"); + indent(writer); + for (unsigned int i = 0; i < userAttribute->getArgumentCount(); i++) + { + int intVal; + float floatVal; + size_t bufSize = 0; + if (i > 0) + write(writer, ",\n"); + if (SLANG_SUCCEEDED(userAttribute->getArgumentValueInt(i, &intVal))) + { + write(writer, intVal); + } + else if (SLANG_SUCCEEDED(userAttribute->getArgumentValueFloat(i, &floatVal))) + { + write(writer, floatVal); + } + else if (auto str = userAttribute->getArgumentValueString(i, &bufSize)) + { + write(writer, str, bufSize); + } + else + write(writer, "\"invalid value\""); + } + dedent(writer); + write(writer, "\n]\n"); + dedent(writer); + write(writer, "}\n"); +} + +static void emitUserAttributes(PrettyWriter& writer, slang::TypeReflection* type) +{ + auto attribCount = type->getUserAttributeCount(); + if (attribCount) + { + write(writer, ",\n\"userAttribs\": ["); + for (unsigned int i = 0; i < attribCount; i++) + { + if (i > 0) + write(writer, ",\n"); + auto attrib = type->getUserAttributeByIndex(i); + emitUserAttributeJSON(writer, attrib); + } + write(writer, "]"); + } +} +static void emitUserAttributes(PrettyWriter& writer, slang::VariableReflection* var) +{ + auto attribCount = var->getUserAttributeCount(); + if (attribCount) + { + write(writer, ",\n\"userAttribs\": ["); + for (unsigned int i = 0; i < attribCount; i++) + { + if (i > 0) + write(writer, ",\n"); + auto attrib = var->getUserAttributeByIndex(i); + emitUserAttributeJSON(writer, attrib); + } + write(writer, "]"); + } +} + static void emitReflectionVarLayoutJSON( PrettyWriter& writer, slang::VariableLayoutReflection* var) @@ -278,6 +360,7 @@ static void emitReflectionVarLayoutJSON( emitReflectionVarBindingInfoJSON(writer, var); + emitUserAttributes(writer, var->getVariable()); dedent(writer); write(writer, "\n}"); } @@ -368,6 +451,7 @@ static void emitReflectionResourceTypeBaseInfoJSON( } } + static void emitReflectionTypeInfoJSON( PrettyWriter& writer, slang::TypeReflection* type) @@ -467,10 +551,10 @@ static void emitReflectionTypeInfoJSON( write(writer, "\"kind\": \"matrix\""); write(writer, ",\n"); write(writer, "\"rowCount\": "); - write(writer, type->getRowCount()); + write(writer, (SlangUInt)type->getRowCount()); write(writer, ",\n"); write(writer, "\"columnCount\": "); - write(writer, type->getColumnCount()); + write(writer, (SlangUInt)type->getColumnCount()); write(writer, ",\n"); write(writer, "\"elementType\": "); emitReflectionTypeJSON( @@ -523,8 +607,10 @@ static void emitReflectionTypeInfoJSON( assert(!"unhandled case"); break; } + emitUserAttributes(writer, type); } + static void emitReflectionTypeLayoutInfoJSON( PrettyWriter& writer, slang::TypeLayoutReflection* typeLayout) @@ -580,6 +666,8 @@ static void emitReflectionTypeLayoutInfoJSON( } dedent(writer); write(writer, "\n]"); + emitUserAttributes(writer, structTypeLayout->getType()); + } break; |
