diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-07 02:04:37 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-07 02:04:37 -0400 |
| commit | 2f2ae8c31490ab01ce0d0cc76d5d7fcf1d21efe7 (patch) | |
| tree | 135cb2109f2717dba3a11929c4cdf163b5ad5c50 | |
| parent | 366c9b4526b4b940c8aafce459d6784211e862bc (diff) | |
More reflection API features. (#4740)
* More reflection API features.
+ Lookup methods and members (by string) on types
+ Fix issue with looking up non-static members through the scope operator '::'
+ `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints
+ `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info
+ `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object
+ `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes)
+ `DeclReflection::getParent`: go to parent declaration.
+ Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more)
* Fix Falcor issue
| -rw-r--r-- | include/slang.h | 147 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 342 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 87 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 139 |
6 files changed, 689 insertions, 33 deletions
diff --git a/include/slang.h b/include/slang.h index c68157759..05ab1a0ce 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2121,6 +2121,7 @@ extern "C" typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter; typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; + typedef struct SlangReflectionGeneric SlangReflectionGeneric; /* Type aliases to maintain backward compatibility. @@ -2469,6 +2470,7 @@ extern "C" SLANG_API char const* spReflectionType_GetName(SlangReflectionType* type); SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* type, ISlangBlob** outNameBlob); + SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type); // Type Layout Reflection @@ -2553,6 +2555,7 @@ extern "C" SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index); SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name); SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar); + SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var); // Variable Layout Reflection @@ -2578,6 +2581,7 @@ extern "C" SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* func); SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index); SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); + SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); // Abstract Decl Reflection @@ -2586,8 +2590,26 @@ extern "C" SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl); SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl); SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl); - SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl); - + SLANG_API SlangReflectionGeneric* spReflectionDecl_castToGeneric(SlangReflectionDecl* decl); + SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangReflectionDecl* decl); + SLANG_API SlangReflectionDecl* spReflectionDecl_getParent(SlangReflectionDecl* decl); + + // Generic Reflection + + SLANG_API SlangReflectionDecl* spReflectionGeneric_asDecl(SlangReflectionGeneric* generic); + SLANG_API char const* spReflectionGeneric_GetName(SlangReflectionGeneric* generic); + SLANG_API unsigned int spReflectionGeneric_GetTypeParameterCount(SlangReflectionGeneric* generic); + SLANG_API SlangReflectionVariable* spReflectionGeneric_GetTypeParameter(SlangReflectionGeneric* generic, unsigned index); + SLANG_API unsigned int spReflectionGeneric_GetValueParameterCount(SlangReflectionGeneric* generic); + SLANG_API SlangReflectionVariable* spReflectionGeneric_GetValueParameter(SlangReflectionGeneric* generic, unsigned index); + SLANG_API unsigned int spReflectionGeneric_GetTypeParameterConstraintCount(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam); + SLANG_API SlangReflectionType* spReflectionGeneric_GetTypeParameterConstraintType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam, unsigned index); + SLANG_API SlangDeclKind spReflectionGeneric_GetInnerKind(SlangReflectionGeneric* generic); + SLANG_API SlangReflectionDecl* spReflectionGeneric_GetInnerDecl(SlangReflectionGeneric* generic); + SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic); + SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam); + SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam); + /** Get the stage that a variable belongs to (if any). @@ -2678,6 +2700,8 @@ extern "C" SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(SlangReflection* reflection, SlangReflectionType* reflectionType, SlangLayoutRules rules); SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name); + SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangReflection* reflection, SlangReflectionType* reflType, char const* name); + SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflection* reflection, SlangReflectionType* reflType, char const* name); SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection); SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index); @@ -2735,6 +2759,8 @@ namespace slang struct TypeReflection; struct VariableLayoutReflection; struct VariableReflection; + struct FunctionReflection; + struct GenericReflection; struct UserAttribute { @@ -2913,6 +2939,11 @@ namespace slang { return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name); } + + SlangReflectionGeneric* getGenericContainer() + { + return (SlangReflectionGeneric*) spReflectionType_GetGenericContainer((SlangReflectionType*) this); + } }; enum ParameterCategory : SlangParameterCategoryIntegral @@ -3360,10 +3391,12 @@ namespace slang { return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); } + UserAttribute* getUserAttributeByIndex(unsigned int index) { return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index); } + UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) { return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name); @@ -3373,6 +3406,11 @@ namespace slang { return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this); } + + GenericReflection* getGenericContainer() + { + return (GenericReflection*)spReflectionVariable_GetGenericContainer((SlangReflectionVariable*)this); + } }; struct VariableLayoutReflection @@ -3498,6 +3536,81 @@ namespace slang { return (Modifier*)spReflectionFunction_FindModifier((SlangReflectionFunction*)this, (SlangModifierID)id); } + + GenericReflection* getGenericContainer() + { + return (GenericReflection*)spReflectionFunction_GetGenericContainer((SlangReflectionFunction*)this); + } + }; + + struct GenericReflection + { + + DeclReflection* asDecl() + { + return (DeclReflection*)spReflectionGeneric_asDecl((SlangReflectionGeneric*)this); + } + + char const* getName() + { + return spReflectionGeneric_GetName((SlangReflectionGeneric*)this); + } + + unsigned int getTypeParameterCount() + { + return spReflectionGeneric_GetTypeParameterCount((SlangReflectionGeneric*)this); + } + + VariableReflection* getTypeParameter(unsigned index) + { + return (VariableReflection*)spReflectionGeneric_GetTypeParameter((SlangReflectionGeneric*)this, index); + } + + unsigned int getValueParameterCount() + { + return spReflectionGeneric_GetValueParameterCount((SlangReflectionGeneric*)this); + } + + VariableReflection* getValueParameter(unsigned index) + { + return (VariableReflection*)spReflectionGeneric_GetValueParameter((SlangReflectionGeneric*)this, index); + } + + unsigned int getTypeParameterConstraintCount(VariableReflection* typeParam) + { + return spReflectionGeneric_GetTypeParameterConstraintCount((SlangReflectionGeneric*)this, (SlangReflectionVariable*)typeParam); + } + + TypeReflection* getTypeParameterConstraintType(VariableReflection* typeParam, unsigned index) + { + return (TypeReflection*)spReflectionGeneric_GetTypeParameterConstraintType((SlangReflectionGeneric*)this, (SlangReflectionVariable*)typeParam, index); + } + + DeclReflection* getInnerDecl() + { + return (DeclReflection*)spReflectionGeneric_GetInnerDecl((SlangReflectionGeneric*)this); + } + + SlangDeclKind getInnerKind() + { + return spReflectionGeneric_GetInnerKind((SlangReflectionGeneric*)this); + } + + GenericReflection* getOuterGenericContainer() + { + return (GenericReflection*)spReflectionGeneric_GetOuterGenericContainer((SlangReflectionGeneric*)this); + } + + TypeReflection* getConcreteType(VariableReflection* typeParam) + { + return (TypeReflection*)spReflectionGeneric_GetConcreteType((SlangReflectionGeneric*)this, (SlangReflectionVariable*)typeParam); + } + + int64_t getConcreteIntVal(VariableReflection* valueParam) + { + return spReflectionGeneric_GetConcreteIntVal((SlangReflectionGeneric*)this, (SlangReflectionVariable*)valueParam); + } + }; struct EntryPointReflection @@ -3672,6 +3785,22 @@ namespace slang name); } + FunctionReflection* findFunctionByNameInType(TypeReflection* type, const char* name) + { + return (FunctionReflection*)spReflection_FindFunctionByNameInType( + (SlangReflection*) this, + (SlangReflectionType*) type, + name); + } + + VariableReflection* findVarByNameInType(TypeReflection* type, const char* name) + { + return (VariableReflection*)spReflection_FindVarByNameInType( + (SlangReflection*) this, + (SlangReflectionType*) type, + name); + } + TypeLayoutReflection* getTypeLayout( TypeReflection* type, LayoutRules rules = LayoutRules::Default) @@ -3749,9 +3878,9 @@ namespace slang return (DeclReflection*)spReflectionDecl_getChild((SlangReflectionDecl*)this, index); } - TypeReflection* getType(SlangSession* session) + TypeReflection* getType() { - return (TypeReflection*)spReflection_getTypeFromDecl(session, (SlangReflectionDecl*)this); + return (TypeReflection*)spReflection_getTypeFromDecl((SlangReflectionDecl*)this); } VariableReflection* asVariable() @@ -3764,6 +3893,16 @@ namespace slang return (FunctionReflection*)spReflectionDecl_castToFunction((SlangReflectionDecl*)this); } + GenericReflection* asGeneric() + { + return (GenericReflection*)spReflectionDecl_castToGeneric((SlangReflectionDecl*)this); + } + + DeclReflection* getParent() + { + return (DeclReflection*)spReflectionDecl_getParent((SlangReflectionDecl*)this); + } + template <Kind K> struct FilteredList { diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 8e1cb0a88..1180d4be2 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -420,6 +420,12 @@ namespace Slang DeclRef<Decl> findDeclFromString( String const& name, DiagnosticSink* sink); + + DeclRef<Decl> findDeclFromStringInType( + Type* type, + String const& name, + LookupMask mask, + DiagnosticSink* sink); Dictionary<String, IntVal*>& getMangledNameToIntValMap(); ConstantIntVal* tryFoldIntVal(IntVal* intVal); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 6510793f3..10b52611f 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2203,6 +2203,7 @@ namespace Slang case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: + case TokenType::EndOfFile: { return parseGenericApp(parser, base); } diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index a212f0d1b..edfd27489 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -8,6 +8,7 @@ #include "slang-compiler.h" #include "slang-type-layout.h" #include "slang-syntax.h" +#include "slang-check.h" #include <assert.h> // Don't signal errors for stuff we don't implement here, @@ -55,14 +56,14 @@ static inline SpecializationParamLayout* convert(SlangReflectionTypeParameter * return (SpecializationParamLayout*) typeParam; } -static inline Decl* convert(SlangReflectionVariable* var) +static inline DeclRef<Decl> convert(SlangReflectionVariable* var) { - return (Decl*) var; + return DeclRef<Decl>((DeclRefBase*) var); } -static inline SlangReflectionVariable* convert(Decl* var) +static inline SlangReflectionVariable* convert(DeclRef<Decl> var) { - return (SlangReflectionVariable*) var; + return (SlangReflectionVariable*) var.declRefBase; } static inline DeclRef<FunctionDeclBase> convert(SlangReflectionFunction* func) @@ -76,6 +77,17 @@ static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func) return (SlangReflectionFunction*)func.declRefBase; } +static inline DeclRef<Decl> convertGenericToDeclRef(SlangReflectionGeneric* func) +{ + DeclRefBase* declBase = (DeclRefBase*)func; + return DeclRef<Decl>(declBase); +} + +static inline SlangReflectionGeneric* convertDeclToGeneric(DeclRef<Decl> func) +{ + return (SlangReflectionGeneric*)func.declRefBase; +} + static inline VarLayout* convert(SlangReflectionVariableLayout* var) { return (VarLayout*) var; @@ -482,7 +494,7 @@ SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflect auto fields = getFields( getModule(declRef.getDecl())->getLinkage()->getASTBuilder(), structDeclRef, MemberFilterStyle::Instance); auto fieldDeclRef = fields[index]; - return (SlangReflectionVariable*) fieldDeclRef.getDecl(); + return convert(fieldDeclRef); } } @@ -786,6 +798,52 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti return nullptr; } +SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangReflection* reflection, SlangReflectionType* reflType, char const* name) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + auto type = convert(reflType); + + Slang::DiagnosticSink sink( + programLayout->getTargetReq()->getLinkage()->getSourceManager(), + Lexer::sourceLocationLexer); + + try + { + auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink); + if (auto funcDeclRef = result.as<FunctionDeclBase>()) + return convert(funcDeclRef); + } + catch (...) + { + } + return nullptr; +} + +SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflection* reflection, SlangReflectionType* reflType, char const* name) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + auto type = convert(reflType); + + Slang::DiagnosticSink sink( + programLayout->getTargetReq()->getLinkage()->getSourceManager(), + Lexer::sourceLocationLexer); + + try + { + auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink); + if (auto varDeclRef = result.as<VarDeclBase>()) + return convert(varDeclRef.as<Decl>()); + } + catch (...) + { + } + return nullptr; +} + SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name) { auto programLayout = convert(reflection); @@ -811,6 +869,35 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re } } +SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef) +{ + auto decl = declRef.getDecl(); + auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder(); + auto parentDecl = decl; + while(parentDecl) + { + if(parentDecl->parentDecl && as<GenericDecl>(parentDecl->parentDecl)) + return convertDeclToGeneric( + substituteDeclRef( + SubstitutionSet(declRef), + astBuilder, + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)))); + parentDecl = parentDecl->parentDecl; + } + + return nullptr; +} + +SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type) +{ + auto declRefType = as<DeclRefType>(convert(type)); + if (!declRefType) + return nullptr; + + auto declRef = declRefType->getDeclRef(); + return getInnermostGenericParent(declRef); +} + SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( SlangReflection* reflection, SlangReflectionType* inType, @@ -2228,7 +2315,7 @@ SLANG_API SlangReflectionVariable* spReflectionTypeLayout_getBindingRangeLeafVar return 0; auto& bindingRange = extTypeLayout->m_bindingRanges[index]; - return convert(bindingRange.leafVariable); + return convert(DeclRef<Decl>(bindingRange.leafVariable)); } @@ -2571,7 +2658,7 @@ SLANG_API SlangInt spReflectionTypeLayout_getSubObjectRangeDescriptorRangeSpaceO SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* inVar) { - auto var = convert(inVar); + auto var = convert(inVar).getDecl(); if (as<InheritanceDecl>(var)) return "$base"; @@ -2589,17 +2676,26 @@ SLANG_API SlangReflectionType* spReflectionVariable_GetType(SlangReflectionVaria { auto var = convert(inVar); - if (auto inheritanceDecl = as<InheritanceDecl>(var)) - return convert(inheritanceDecl->base.type); - if(!var) return nullptr; - return convert(as<VarDeclBase>(var)->getType()); + auto astBuilder = getModule(var.getDecl())->getLinkage()->getASTBuilder(); + + if (auto inheritanceDecl = as<InheritanceDecl>(var.getDecl())) + return convert(inheritanceDecl->base.type); + + if (auto varDecl = as<VarDeclBase>(var.getDecl())) + return convert( + substituteType( + SubstitutionSet(var), + astBuilder, + varDecl->getType())); + + return nullptr; } SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflectionVariable* inVar, SlangModifierID modifierID) { - auto var = convert(inVar); + auto var = convert(inVar).getDecl(); if(!var) return nullptr; @@ -2648,26 +2744,26 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* inVar) { - auto varDecl = convert(inVar); + auto varDecl = convert(inVar).getDecl(); if (!varDecl) return 0; return getUserAttributeCount(varDecl); } SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* inVar, unsigned int index) { - auto varDecl = convert(inVar); + auto varDecl = convert(inVar).getDecl(); if (!varDecl) return 0; return getUserAttributeByIndex(varDecl, index); } SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* inVar, SlangSession* session, char const* name) { - auto varDecl = convert(inVar); + auto varDecl = convert(inVar).getDecl(); if (!varDecl) return 0; return findUserAttributeByName(asInternal(session), varDecl, name); } SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar) { - auto decl = convert(inVar); + auto decl = convert(inVar).getDecl(); if (auto varDecl = as<VarDeclBase>(decl)) { return varDecl->initExpr != nullptr; @@ -2676,6 +2772,12 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV return false; } +SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var) +{ + auto declRef = convert(var); + return getInnermostGenericParent(declRef); +} + // Variable Layout Reflection SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) @@ -2683,7 +2785,7 @@ SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangR auto varLayout = convert(inVarLayout); if(!varLayout) return nullptr; - return (SlangReflectionVariable*)(varLayout->varDecl.getDecl()); + return convert(varLayout->varDecl); } SLANG_API SlangReflectionTypeLayout* spReflectionVariableLayout_GetTypeLayout(SlangReflectionVariableLayout* inVarLayout) @@ -2850,7 +2952,7 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID) { auto funcDeclRef = convert(inFunc); - auto varRefl = convert(funcDeclRef.getDecl()); + auto varRefl = convert(funcDeclRef.as<Decl>()); if (!varRefl) return nullptr; return spReflectionVariable_FindModifier(varRefl, modifierID); @@ -2888,7 +2990,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec { auto func = convert(inFunc); if (!func) return nullptr; - return convert(as<Decl>(func.getDecl()->getParameters()[index])); + + auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder(); + + return convert(getParameters(astBuilder, func)[index]); +} + +SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) +{ + auto declRef = convert(func); + return getInnermostGenericParent(declRef); } // Abstract decl reflection @@ -2960,20 +3071,31 @@ SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflecti Decl* slangDecl = (Decl*) decl; if (auto varDecl = as<VarDeclBase>(slangDecl)) { - return (SlangReflectionVariable*) varDecl; + return convert(DeclRef(varDecl)); } // Improper cast return nullptr; +} + +SLANG_API SlangReflectionGeneric* spReflectionDecl_castToGeneric(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*) decl; + if (auto genericInnerDecl = as<GenericDecl>(slangDecl)->inner) + { + return convertDeclToGeneric(genericInnerDecl); + } + // Improper cast + return nullptr; } -SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl) +SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangReflectionDecl* decl) { Decl* slangDecl = (Decl*)decl; - auto slangSession = asInternal(session); - ASTBuilder* builder = slangSession->getGlobalASTBuilder(); + ASTBuilder* builder = getModule(slangDecl)->getLinkage()->getASTBuilder(); + // TODO: create default substitutions if (auto type = DeclRefType::create(builder, slangDecl->getDefaultDeclRef())) { return convert(type); @@ -2983,6 +3105,178 @@ SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* sessio return nullptr; } +SLANG_API SlangReflectionDecl* spReflectionDecl_getParent(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*)decl; + if (auto parentDecl = slangDecl->parentDecl) + { + return (SlangReflectionDecl*)parentDecl; + } + + return nullptr; +} + +// Generic Reflection + +SLANG_API SlangReflectionDecl* spReflectionGeneric_asDecl(SlangReflectionGeneric* generic) +{ + return (SlangReflectionDecl*) convertGenericToDeclRef(generic).getDecl()->parentDecl; +} + +SLANG_API char const* spReflectionGeneric_GetName(SlangReflectionGeneric* generic) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + return getText(slangGeneric.getDecl()->getName()).getBuffer(); +} + +SLANG_API unsigned int spReflectionGeneric_GetTypeParameterCount(SlangReflectionGeneric* generic) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return 0; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + return (unsigned int) getMembersOfType<GenericTypeParamDecl>(astBuilder, slangGeneric.getDecl()->parentDecl).getCount(); +} + +SLANG_API SlangReflectionVariable* spReflectionGeneric_GetTypeParameter(SlangReflectionGeneric* generic, unsigned index) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + auto paramDeclRef = getMembersOfType<GenericTypeParamDecl>(astBuilder, slangGeneric.getDecl()->parentDecl)[index]; + + return convert(substituteDeclRef(SubstitutionSet(slangGeneric), astBuilder, paramDeclRef)); +} + +SLANG_API unsigned int spReflectionGeneric_GetValueParameterCount(SlangReflectionGeneric* generic) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return 0; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + return (unsigned int) getMembersOfType<GenericValueParamDecl>(astBuilder, slangGeneric.getDecl()->parentDecl).getCount(); +} + +SLANG_API SlangReflectionVariable* spReflectionGeneric_GetValueParameter(SlangReflectionGeneric* generic, unsigned index) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + auto paramDeclRef = getMembersOfType<GenericValueParamDecl>(astBuilder, slangGeneric.getDecl()->parentDecl)[index]; + + return convert(substituteDeclRef(SubstitutionSet(slangGeneric), astBuilder, paramDeclRef)); +} + +SLANG_API unsigned int spReflectionGeneric_GetTypeParameterConstraintCount(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return 0; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + if (auto typeParamDecl = as<GenericTypeParamDecl>(convert(typeParam).getDecl())) + { + auto constraints = getCanonicalGenericConstraints( + astBuilder, + DeclRef<GenericDecl>(slangGeneric.getDecl()->parentDecl)); + return (unsigned int)(constraints[typeParamDecl]).getValue().getCount(); + } + + return 0; +} + +SLANG_API SlangReflectionType* spReflectionGeneric_GetTypeParameterConstraintType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam, unsigned index) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + if (auto typeParamDecl = as<GenericTypeParamDecl>(convert(typeParam).getDecl())) + { + auto constraints = getCanonicalGenericConstraints( + astBuilder, + DeclRef<GenericDecl>(slangGeneric.getDecl()->parentDecl)); + if (auto constraint = (constraints[typeParamDecl]).getValue()[index]) + { + return convert(substituteType(SubstitutionSet(slangGeneric), astBuilder, constraint)); + } + } + + return nullptr; +} + +SLANG_API SlangDeclKind spReflectionGeneric_GetInnerKind(SlangReflectionGeneric* generic) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION; + + return spReflectionDecl_getKind((SlangReflectionDecl*)slangGeneric.getDecl()); +} + +SLANG_API SlangReflectionDecl* spReflectionGeneric_GetInnerDecl(SlangReflectionGeneric* generic) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + + return (SlangReflectionDecl*)slangGeneric.getDecl(); +} + +SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic) +{ + auto declRef = convertGenericToDeclRef(generic); + + auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); + + return getInnermostGenericParent( + substituteDeclRef( + SubstitutionSet(declRef), + astBuilder, + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))); +} + +SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + auto genericType = DeclRefType::create(astBuilder, convert(typeParam)); + + auto substType = substituteType(SubstitutionSet(slangGeneric), astBuilder, genericType); + + if (genericType != substType) + { + return convert(substType); + } + + return nullptr; +} + +SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam) +{ + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return 0; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + auto valueParamDeclRef = convert(valueParam); + + Val* valResult = astBuilder->getOrCreate<GenericParamIntVal>( + valueParamDeclRef.substitute(astBuilder, as<GenericValueParamDecl>(valueParamDeclRef.getDecl())->getType()), + valueParamDeclRef); + valResult = valResult->substitute(astBuilder, SubstitutionSet(slangGeneric)); + + auto intVal = as<ConstantIntVal>(valResult); + if (intVal) + { + return intVal->getValue(); + } + + return 0; +} + + // Shader Parameter Reflection SLANG_API unsigned spReflectionParameter_GetBindingIndex(SlangReflectionParameter* inVarLayout) @@ -3046,7 +3340,7 @@ SLANG_API SlangReflectionFunction* spReflectionEntryPoint_GetFunction(SlangRefle auto entryPointLayout = convert(inEntryPoint); if (entryPointLayout) { - return convert(entryPointLayout->entryPoint); + return convert(entryPointLayout->entryPoint.as<FunctionDeclBase>()); } return nullptr; } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index cf8c19033..9ba02ee50 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -26,6 +26,8 @@ #include "slang-parser.h" #include "slang-preprocessor.h" #include "slang-type-layout.h" +#include "slang-lookup.h" + # #include "slang-options.h" @@ -2279,7 +2281,10 @@ DeclRef<Decl> ComponentType::findDeclFromString( linkage, nullptr, sink); - SemanticsVisitor visitor(&sharedSemanticsContext); + SemanticsContext context(&sharedSemanticsContext); + context = context.allowStaticReferenceToNonStaticMember(); + + SemanticsVisitor visitor(context); auto checkedExpr = visitor.CheckExpr(expr); if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) @@ -2291,6 +2296,86 @@ DeclRef<Decl> ComponentType::findDeclFromString( return result; } +DeclRef<Decl> ComponentType::findDeclFromStringInType( + Type* type, + String const& name, + LookupMask mask, + DiagnosticSink* sink) +{ + DeclRef<Decl> result; + + // Only look up in the type if it is a DeclRefType + if (!as<DeclRefType>(type)) + return DeclRef<Decl>(); + + // TODO(JS): For now just used the linkages ASTBuilder to keep on scope + // + // The parseTermString uses the linkage ASTBuilder for it's parsing. + // + // It might be possible to just create a temporary ASTBuilder - the worry though is + // that the parsing sets a member variable in AST node to one of these scopes, and then + // it become a dangling pointer. So for now we go with the linkages. + auto astBuilder = getLinkage()->getASTBuilder(); + + // Otherwise, we need to start looking in + // the modules that were directly or + // indirectly referenced. + // + Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder); + + auto linkage = getLinkage(); + + SLANG_AST_BUILDER_RAII(linkage->getASTBuilder()); + + Expr* expr = linkage->parseTermString(name, scope); + + SharedSemanticsContext sharedSemanticsContext( + linkage, + nullptr, + sink); + SemanticsContext context(&sharedSemanticsContext); + context = context.allowStaticReferenceToNonStaticMember(); + + SemanticsVisitor visitor(context); + + GenericAppExpr* genericOuterExpr = nullptr; + if (as<GenericAppExpr>(expr)) + { + // Unwrap the generic application, and re-wrap it around the static-member expr + genericOuterExpr = as<GenericAppExpr>(expr); + expr = genericOuterExpr->functionExpr; + } + + if (!as<VarExpr>(expr)) + return result; + + auto rs = astBuilder->create<StaticMemberExpr>(); + auto typeExpr = astBuilder->create<SharedTypeExpr>(); + auto typetype = astBuilder->getOrCreate<TypeType>(type); + typeExpr->type = typetype; + rs->baseExpression = typeExpr; + rs->name = as<VarExpr>(expr)->name; + + expr = rs; + + // If we have a generic-app expression, re-wrap the static-member expr + if (genericOuterExpr) + { + genericOuterExpr->functionExpr = expr; + expr = genericOuterExpr; + } + + auto checkedTerm = visitor.CheckTerm(expr); + auto resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); + + if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm)) + { + result = declRefExpr->declRef; + } + + return result; +} + static void collectExportedConstantInContainer( Dictionary<String, IntVal*>& dict, ASTBuilder* builder, diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index e443358fb..fb35f323c 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -52,6 +52,24 @@ SLANG_UNIT_TEST(declTreeReflection) { return pos; } + + uint f(uint y) { return y; } + + struct MyType + { + int x; + float f(float x) { return x; } + } + + struct MyGenericType<T : IArithmetic & IFloat> + { + T z; + T g() { return z; } + U h<U>(U x, out T y) { y = z; return x; } + + T j<let N : int>(T x, out int o) { o = N; return x; } + } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -75,10 +93,15 @@ SLANG_UNIT_TEST(declTreeReflection) module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); SLANG_CHECK(entryPoint != nullptr); + ComPtr<slang::IComponentType> compositeProgram; + slang::IComponentType* components[] = { module, entryPoint.get() }; + session->createCompositeComponentType(components, 2, compositeProgram.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(compositeProgram != nullptr); + auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 4); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 7); // First declaration should be a struct with 1 variable auto firstDecl = moduleDeclReflection->getChild(0); @@ -86,7 +109,7 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(firstDecl->getChildrenCount() == 1); { - slang::TypeReflection* type = firstDecl->getType(globalSession); + slang::TypeReflection* type = firstDecl->getType(); SLANG_CHECK(getTypeFullName(type) == "MyFuncPropertyAttribute"); // Check the field of the struct. @@ -140,6 +163,107 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector<float,4>"); } + // Sixth declaration should be a generic struct + auto sixthDecl = moduleDeclReflection->getChild(5); + SLANG_CHECK(sixthDecl->getKind() == slang::DeclReflection::Kind::Generic); + auto genericReflection = sixthDecl->asGeneric(); + SLANG_CHECK(genericReflection->getTypeParameterCount() == 1); + auto typeParamT = genericReflection->getTypeParameter(0); + SLANG_CHECK(UnownedStringSlice(typeParamT->getName()) == "T"); + auto typeParamTConstraintCount = genericReflection->getTypeParameterConstraintCount(typeParamT); + SLANG_CHECK(typeParamTConstraintCount == 2); + auto typeParamTConstraintType1 = genericReflection->getTypeParameterConstraintType(typeParamT, 0); + SLANG_CHECK(getTypeFullName(typeParamTConstraintType1) == "IArithmetic"); + auto typeParamTConstraintType2 = genericReflection->getTypeParameterConstraintType(typeParamT, 1); + SLANG_CHECK(getTypeFullName(typeParamTConstraintType2) == "IFloat"); + + auto innerStruct = genericReflection->getInnerDecl(); + SLANG_CHECK(innerStruct->getKind() == slang::DeclReflection::Kind::Struct); + + + // Check type-lookup-by-name + { + auto type = compositeProgram->getLayout()->findTypeByName("MyType"); + SLANG_CHECK(type != nullptr); + //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); + SLANG_CHECK(UnownedStringSlice(type->getName()) == "MyType"); + auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "f"); + SLANG_CHECK(funcReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "f"); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float"); + SLANG_CHECK(funcReflection->getParameterCount() == 1); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); + } + + // Check type-lookup-by-name for generic type + { + auto type = compositeProgram->getLayout()->findTypeByName("MyGenericType<half>"); + SLANG_CHECK(type != nullptr); + //SLANG_CHECK(type->getKind() == slang::DeclReflection::Kind::Struct); + SLANG_CHECK(getTypeFullName(type) == "MyGenericType<half>"); + auto funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "g"); + SLANG_CHECK(funcReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "g"); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "half"); + SLANG_CHECK(funcReflection->getParameterCount() == 0); + + auto varReflection = compositeProgram->getLayout()->findVarByNameInType(type, "z"); + SLANG_CHECK(varReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(varReflection->getName()) == "z"); + SLANG_CHECK(getTypeFullName(varReflection->getType()) == "half"); + + funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "h<float>"); + SLANG_CHECK(funcReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "h"); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float"); + SLANG_CHECK(funcReflection->getParameterCount() == 2); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "half"); + + // Access parent generic container from a specialized method. + auto specializationInfo = funcReflection->getGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "h"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + // Check type parameters + SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); + auto typeParam = specializationInfo->getTypeParameter(0); + SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "U"); // generic name + SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "float"); // specialized type name under the context in which the generic is obtained + SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 0); + + // Go up another level to the generic struct + specializationInfo = specializationInfo->getOuterGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + // Check type parameters + SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); + typeParam = specializationInfo->getTypeParameter(0); + SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name + SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half"); // specialized type name under the context in which the generic is obtained + SLANG_CHECK(specializationInfo->getTypeParameterConstraintCount(typeParam) == 2); + + // Query 'j' on the type 'half' + funcReflection = compositeProgram->getLayout()->findFunctionByNameInType(type, "j<10>"); + SLANG_CHECK(funcReflection != nullptr); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "j"); + + // Check the generic parameters + specializationInfo = funcReflection->getGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK(specializationInfo->getValueParameterCount() == 1); + auto valueParam = specializationInfo->getValueParameter(0); + SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name + SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); + } + + // Check iterators { unsigned int count = 0; @@ -147,20 +271,27 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 4); + SLANG_CHECK(count == 7); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) { count++; } - SLANG_CHECK(count == 2); + SLANG_CHECK(count == 3); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>()) { count++; } + SLANG_CHECK(count == 2); + + count = 0; + for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Generic>()) + { + count++; + } SLANG_CHECK(count == 1); } } |
