From 0d06ebcefb36a19710d87832fc1ea027e21281af Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 18 Jul 2024 13:11:19 -0400 Subject: Initial implementation for decl-tree reflection API (#4666) * Initial implementation for decl-tree reflection API This patch adds Slang API methods for walking all the declarations in the AST. We expose this functionality through an abstract `DeclReflection` class that can be a type, function or a variable declaration. We also provide ways to cast the decl to a `FunctionReflection`, `TypeReflection` or `VariableReflection` and traverse through the child nodes (for instance, a struct type will have component variable declarations) This patch also adds `ISlangInternal` as an internal COM interface to allow us to cast IGlobalSession to the internal Session pointer while bypassing any wrappers (such as the capture interface) * Update slang.h * Remove `ISlangInternal` (its causing a diamond pattern w.r.t `ISlangUnknown`) and use `ComPtr` for proper ref management. * Update unit-test-decl-tree-reflection.cpp * Change `FunctionDeclBase` to use `DeclRef` instead of directly using the decl. * Update slang-reflection-api.cpp --- include/slang.h | 160 +++++++++++++++++++- .../slang-capture-replay/slang-global-session.cpp | 20 ++- source/slang-capture-replay/slang-global-session.h | 5 +- source/slang-capture-replay/slang-module.cpp | 8 + source/slang-capture-replay/slang-module.h | 1 + source/slang/slang-compiler.h | 17 ++- source/slang/slang-reflection-api.cpp | 152 +++++++++++++++++-- source/slang/slang.cpp | 25 ++- .../unit-test-decl-tree-reflection.cpp | 167 +++++++++++++++++++++ 9 files changed, 523 insertions(+), 32 deletions(-) create mode 100644 tools/slang-unit-test/unit-test-decl-tree-reflection.cpp diff --git a/include/slang.h b/include/slang.h index 2f8f8d15d..86ca4649b 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2107,6 +2107,7 @@ extern "C" typedef struct SlangEntryPoint SlangEntryPoint; typedef struct SlangEntryPointLayout SlangEntryPointLayout; + typedef struct SlangReflectionDecl SlangReflectionDecl; typedef struct SlangReflectionModifier SlangReflectionModifier; typedef struct SlangReflectionType SlangReflectionType; typedef struct SlangReflectionTypeLayout SlangReflectionTypeLayout; @@ -2174,6 +2175,18 @@ extern "C" SLANG_SCALAR_TYPE_UINTPTR }; + // abstract decl reflection + typedef unsigned int SlangDeclKindIntegral; + enum SlangDeclKind : SlangDeclKindIntegral + { + SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION, + SLANG_DECL_KIND_STRUCT, + SLANG_DECL_KIND_FUNC, + SLANG_DECL_KIND_MODULE, + SLANG_DECL_KIND_GENERIC, + SLANG_DECL_KIND_VARIABLE + }; + #ifndef SLANG_RESOURCE_SHAPE # define SLANG_RESOURCE_SHAPE typedef unsigned int SlangResourceShapeIntegral; @@ -2394,6 +2407,7 @@ extern "C" SLANG_MODIFIER_EXPORT, SLANG_MODIFIER_EXTERN, SLANG_MODIFIER_DIFFERENTIABLE, + SLANG_MODIFIER_MUTATING }; // User Attribute @@ -2528,6 +2542,7 @@ extern "C" SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* var); SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index); SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name); + SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar); // Variable Layout Reflection @@ -2544,7 +2559,9 @@ extern "C" // Function Reflection + SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* func); SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* func); + SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* var, SlangModifierID modifierID); SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* func); SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* func, unsigned int index); SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* func, SlangSession* globalSession, char const* name); @@ -2552,6 +2569,16 @@ extern "C" SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index); SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); + // Abstract Decl Reflection + + SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl); + SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index); + SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl); + SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl); + SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl); + SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl); + + /** Get the stage that a variable belongs to (if any). A variable "belongs" to a specific stage when it is a varying input/output @@ -2693,6 +2720,7 @@ SLANG_API slang::ISession* spReflection_GetSession(SlangReflection* reflection); namespace slang { struct BufferReflection; + struct DeclReflection; struct TypeLayoutReflection; struct TypeReflection; struct VariableLayoutReflection; @@ -3293,6 +3321,7 @@ namespace slang Export = SLANG_MODIFIER_EXPORT, Extern = SLANG_MODIFIER_EXTERN, Differentiable = SLANG_MODIFIER_DIFFERENTIABLE, + Mutating = SLANG_MODIFIER_MUTATING }; }; @@ -3325,6 +3354,11 @@ namespace slang { return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name); } + + bool hasDefaultValue() + { + return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this); + } }; struct VariableLayoutReflection @@ -3435,20 +3469,20 @@ namespace slang unsigned int getUserAttributeCount() { - return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this); + return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this); } UserAttribute* getUserAttributeByIndex(unsigned int index) { - return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index); + return (UserAttribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index); } UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name) { - return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name); + return (UserAttribute*)spReflectionFunction_FindUserAttributeByName((SlangReflectionFunction*)this, globalSession, name); } Modifier* findModifier(Modifier::ID id) { - return (Modifier*)spReflectionVariable_FindModifier((SlangReflectionVariable*)this, (SlangModifierID)id); + return (Modifier*)spReflectionFunction_FindModifier((SlangReflectionFunction*)this, (SlangModifierID)id); } }; @@ -3672,6 +3706,122 @@ namespace slang } }; + + struct DeclReflection + { + enum class Kind + { + Unsupported = SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION, + Struct = SLANG_DECL_KIND_STRUCT, + Func = SLANG_DECL_KIND_FUNC, + Module = SLANG_DECL_KIND_MODULE, + Generic = SLANG_DECL_KIND_GENERIC, + Variable = SLANG_DECL_KIND_VARIABLE, + }; + + Kind getKind() + { + return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this); + } + + unsigned int getChildrenCount() + { + return spReflectionDecl_getChildrenCount((SlangReflectionDecl*)this); + } + + DeclReflection* getChild(unsigned int index) + { + return (DeclReflection*)spReflectionDecl_getChild((SlangReflectionDecl*)this, index); + } + + TypeReflection* getType(SlangSession* session) + { + return (TypeReflection*)spReflection_getTypeFromDecl(session, (SlangReflectionDecl*)this); + } + + VariableReflection* asVariable() + { + return (VariableReflection*)spReflectionDecl_castToVariable((SlangReflectionDecl*)this); + } + + FunctionReflection* asFunction() + { + return (FunctionReflection*)spReflectionDecl_castToFunction((SlangReflectionDecl*)this); + } + + template + struct FilteredList + { + unsigned int count; + DeclReflection* parent; + + struct FilteredIterator + { + DeclReflection* parent; + unsigned int count; + unsigned int index; + + DeclReflection* operator*() { return parent->getChild(index); } + void operator++() + { + index++; + while (index < count && !(parent->getChild(index)->getKind() == K)) + { + index++; + } + } + bool operator!=(FilteredIterator const& other) { return index != other.index; } + }; + + // begin/end for range-based for that checks the kind + FilteredIterator begin() + { + // Find the first child of the right kind + unsigned int index = 0; + while (index < count && !(parent->getChild(index)->getKind() == K)) + { + index++; + } + return FilteredIterator{parent, count, index}; + } + + FilteredIterator end() { return FilteredIterator{parent, count, count}; } + }; + + template + FilteredList getChildrenOfKind() + { + return FilteredList{ getChildrenCount(), (DeclReflection*)this }; + } + + struct IteratedList + { + unsigned int count; + DeclReflection* parent; + + struct Iterator + { + DeclReflection* parent; + unsigned int count; + unsigned int index; + + DeclReflection* operator*() { return parent->getChild(index); } + void operator++() { index++; } + bool operator!=(Iterator const& other) { return index != other.index; } + }; + + // begin/end for range-based for that checks the kind + IteratedList::Iterator begin() { return IteratedList::Iterator{ parent, count, 0 }; } + IteratedList::Iterator end() { return IteratedList::Iterator{ parent, count, count }; } + }; + + IteratedList getChildren() + { + return IteratedList{ getChildrenCount(), (DeclReflection*)this }; + } + + }; + typedef uint32_t CompileStdLibFlags; struct CompileStdLibFlag { @@ -5132,6 +5282,8 @@ namespace slang /// Get the path to a file this module depends on. virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath( SlangInt32 index) = 0; + + virtual SLANG_NO_THROW DeclReflection* SLANG_MCALL getModuleReflection() = 0; }; #define SLANG_UUID_IModule IModule::getTypeGuid() diff --git a/source/slang-capture-replay/slang-global-session.cpp b/source/slang-capture-replay/slang-global-session.cpp index b8dfd7d40..d8f1e361d 100644 --- a/source/slang-capture-replay/slang-global-session.cpp +++ b/source/slang-capture-replay/slang-global-session.cpp @@ -28,11 +28,23 @@ namespace SlangCapture m_actualGlobalSession->release(); } - ISlangUnknown* GlobalSessionCapture::getInterface(const Guid& guid) + SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::queryInterface(SlangUUID const& uuid, void** outObject) { - if(guid == ISlangUnknown::getTypeGuid() || guid == IGlobalSession::getTypeGuid()) - return asExternal(this); - return nullptr; + if (uuid == Session::getTypeGuid()) + { + // no add-ref here, the query will cause the inner session to handle the add-ref. + this->m_actualGlobalSession->queryInterface(uuid, outObject); + return SLANG_OK; + } + + if (uuid == ISlangUnknown::getTypeGuid() && uuid == IGlobalSession::getTypeGuid()) + { + addReference(); + *outObject = static_cast(this); + return SLANG_OK; + } + + return SLANG_E_NO_INTERFACE; } SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createSession(slang::SessionDesc const& desc, slang::ISession** outSession) diff --git a/source/slang-capture-replay/slang-global-session.h b/source/slang-capture-replay/slang-global-session.h index ae451505f..f478e60db 100644 --- a/source/slang-capture-replay/slang-global-session.h +++ b/source/slang-capture-replay/slang-global-session.h @@ -17,9 +17,10 @@ namespace SlangCapture explicit GlobalSessionCapture(slang::IGlobalSession* session); virtual ~GlobalSessionCapture(); - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE - ISlangUnknown* getInterface(const Guid& guid); + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE; // slang::IGlobalSession SLANG_NO_THROW SlangResult SLANG_MCALL createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp index 8a1a80126..273faa59d 100644 --- a/source/slang-capture-replay/slang-module.cpp +++ b/source/slang-capture-replay/slang-module.cpp @@ -27,6 +27,14 @@ namespace SlangCapture return nullptr; } + SLANG_NO_THROW slang::DeclReflection* ModuleCapture::getModuleReflection() + { + // No need to capture this call as it is just a query. + slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__); + slang::DeclReflection* res = (slang::DeclReflection*)m_actualModule->getModuleReflection(); + return res; + } + SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName( char const* name, slang::IEntryPoint** outEntryPoint) diff --git a/source/slang-capture-replay/slang-module.h b/source/slang-capture-replay/slang-module.h index de63b7967..94539532c 100644 --- a/source/slang-capture-replay/slang-module.h +++ b/source/slang-capture-replay/slang-module.h @@ -83,6 +83,7 @@ namespace SlangCapture uint32_t compilerOptionEntryCount, slang::CompilerOptionEntry* compilerOptionEntries, ISlangBlob** outDiagnostics = nullptr) override; + virtual SLANG_NO_THROW slang::DeclReflection* getModuleReflection() override; slang::IModule* getActualModule() const { return m_actualModule; } private: diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 35b2fcf44..2409cedfb 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -282,6 +282,7 @@ namespace Slang HashSet m_fileSet; }; + class EntryPoint; class ComponentType; @@ -1051,7 +1052,7 @@ namespace Slang SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE { - return (slang::FunctionReflection*)m_funcDeclRef.getDecl(); + return (slang::FunctionReflection*)m_funcDeclRef.declRefBase; } protected: void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE; @@ -1476,6 +1477,8 @@ namespace Slang virtual void buildHash(DigestBuilder& builder) SLANG_OVERRIDE; + virtual slang::DeclReflection* getModuleReflection() SLANG_OVERRIDE; + void setDigest(SHA1::Digest const& digest) { m_digest = digest; } SHA1::Digest computeDigest(); @@ -3079,9 +3082,11 @@ namespace Slang class Session : public RefObject, public slang::IGlobalSession { public: - SLANG_REF_OBJECT_IUNKNOWN_ALL + SLANG_COM_INTERFACE(0xd6b767eb, 0xd786, 0x4343, { 0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a }) - ISlangUnknown* getInterface(const Guid& guid); + SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE; + SLANG_REF_OBJECT_IUNKNOWN_ADD_REF + SLANG_REF_OBJECT_IUNKNOWN_RELEASE // slang::IGlobalSession SLANG_NO_THROW SlangResult SLANG_MCALL createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override; @@ -3287,9 +3292,11 @@ SLANG_FORCE_INLINE slang::IGlobalSession* asExternal(Session* session) return static_cast(session); } -SLANG_FORCE_INLINE Session* asInternal(slang::IGlobalSession* session) +SLANG_FORCE_INLINE ComPtr asInternal(slang::IGlobalSession* session) { - return static_cast(session); + Slang::Session* internalSession = nullptr; + session->queryInterface(SLANG_IID_PPV_ARGS(&internalSession)); + return ComPtr(INIT_ATTACH, static_cast(internalSession)); } SLANG_FORCE_INLINE slang::ISession* asExternal(Linkage* linkage) diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d0d09f814..9e08330e4 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -65,14 +65,15 @@ static inline SlangReflectionVariable* convert(Decl* var) return (SlangReflectionVariable*) var; } -static inline FunctionDeclBase* convert(SlangReflectionFunction* func) +static inline DeclRef convert(SlangReflectionFunction* func) { - return (FunctionDeclBase*)func; + DeclRefBase* declBase = (DeclRefBase*)func; + return DeclRef(declBase); } -static inline SlangReflectionFunction* convert(FunctionDeclBase* func) +static inline SlangReflectionFunction* convert(DeclRef func) { - return (SlangReflectionFunction*)func; + return (SlangReflectionFunction*)func.declRefBase; } static inline VarLayout* convert(SlangReflectionVariableLayout* var) @@ -773,7 +774,7 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti { auto result = program->findDeclFromString(name, &sink); if (auto funcDeclRef = result.as()) - return (SlangReflectionFunction*)funcDeclRef.getDecl(); + return convert(funcDeclRef); } catch (...) { @@ -2621,6 +2622,9 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec case SLANG_MODIFIER_DIFFERENTIABLE: modifier = var->findModifier(); break; + case SLANG_MODIFIER_MUTATING: + modifier = var->findModifier(); + break; default: return nullptr; } @@ -2647,6 +2651,17 @@ SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeBy return findUserAttributeByName(asInternal(session), varDecl, name); } +SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar) +{ + auto decl = convert(inVar); + if (auto varDecl = as(decl)) + { + return varDecl->initExpr != nullptr; + } + + return false; +} + // Variable Layout Reflection SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) @@ -2793,11 +2808,18 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage( // Function Reflection +SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc) +{ + auto func = convert(inFunc); + if (!func) return nullptr; + return (SlangReflectionDecl*)func.getDecl(); +} + SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc) { auto func = convert(inFunc); if (!func) return nullptr; - return getText(func->getName()).getBuffer(); + return getText(func.getDecl()->getName()).getBuffer(); } SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc) @@ -2805,42 +2827,146 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio auto func = convert(inFunc); if (!func) return nullptr; - return convert(func->returnType.type); + auto rawType = func.getDecl()->returnType.type; + auto astBuilder = rawType->getASTBuilderForReflection(); + + return convert((Type*)rawType->substitute(astBuilder, SubstitutionSet(func.declRefBase))); +} + +SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID) +{ + auto funcDeclRef = convert(inFunc); + auto varRefl = convert(funcDeclRef.getDecl()); + if (!varRefl) return nullptr; + + return spReflectionVariable_FindModifier(varRefl, modifierID); } SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc) { auto func = convert(inFunc); if (!func) return 0; - return getUserAttributeCount(func); + return getUserAttributeCount(func.getDecl()); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index) { auto func = convert(inFunc); if (!func) return nullptr; - return getUserAttributeByIndex(func, index); + return getUserAttributeByIndex(func.getDecl(), index); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name) { auto func = convert(inFunc); if (!func) return nullptr; - return findUserAttributeByName(asInternal(session), func, name); + return findUserAttributeByName(asInternal(session), func.getDecl(), name); } SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc) { auto func = convert(inFunc); if (!func) return 0; - return (unsigned int)func->getParameters().getCount(); + return (unsigned int)func.getDecl()->getParameters().getCount(); } SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index) { auto func = convert(inFunc); if (!func) return nullptr; - return convert(as(func->getParameters()[index])); + return convert(as(func.getDecl()->getParameters()[index])); +} + +// Abstract decl reflection + +SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) +{ + Decl* decl = (Decl*)parentDecl; + if (as(decl)) + { + return (unsigned int)as(decl)->members.getCount(); + } + + return 0; +} + +SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index) +{ + Decl* decl = (Decl*)parentDecl; + if (auto containerDecl = as(decl)) + { + if (containerDecl->members.getCount() > index) + return (SlangReflectionDecl*)containerDecl->members[index]; + } + + return nullptr; +} + +SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*)decl; + if (as(slangDecl)) + { + return SLANG_DECL_KIND_STRUCT; + } + else if (as(slangDecl)) + { + return SLANG_DECL_KIND_VARIABLE; + } + else if (as(slangDecl)) + { + return SLANG_DECL_KIND_GENERIC; + } + else if (as(slangDecl)) + { + return SLANG_DECL_KIND_FUNC; + } + else if (as(slangDecl)) + { + return SLANG_DECL_KIND_MODULE; + } + else + return SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION; +} + +SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*) decl; + if (auto funcDecl = as(slangDecl)) + { + return convert(DeclRef(funcDecl->getDefaultDeclRef())); + } + + // Improper cast + return nullptr; +} + +SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*) decl; + if (auto varDecl = as(slangDecl)) + { + return (SlangReflectionVariable*) varDecl; + } + + // Improper cast + return nullptr; + +} + +SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*)decl; + auto slangSession = asInternal(session); + + ASTBuilder* builder = slangSession->getGlobalASTBuilder(); + if (auto type = DeclRefType::create(builder, slangDecl->getDefaultDeclRef())) + { + return convert(type); + } + + // Couldn't create a type from the decl + return nullptr; } // Shader Parameter Reflection @@ -2906,7 +3032,7 @@ SLANG_API SlangReflectionFunction* spReflectionEntryPoint_GetFunction(SlangRefle auto entryPointLayout = convert(inEntryPoint); if (entryPointLayout) { - return (SlangReflectionFunction*)entryPointLayout->entryPoint.getDecl(); + return convert(entryPointLayout->entryPoint); } return nullptr; } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index f28d3cdeb..f440ce3cb 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -546,11 +546,23 @@ SlangResult Session::_readBuiltinModule(ISlangFileSystem* fileSystem, Scope* sco return SLANG_OK; } -ISlangUnknown* Session::getInterface(const Guid& guid) +SLANG_NO_THROW SlangResult SLANG_MCALL Session::queryInterface(SlangUUID const& uuid, void** outObject) { - if(guid == ISlangUnknown::getTypeGuid() || guid == IGlobalSession::getTypeGuid()) - return asExternal(this); - return nullptr; + if (uuid == Session::getTypeGuid()) + { + addReference(); + *outObject = static_cast(this); + return SLANG_OK; + } + + if (uuid == ISlangUnknown::getTypeGuid() && uuid == IGlobalSession::getTypeGuid()) + { + addReference(); + *outObject = static_cast(this); + return SLANG_OK; + } + + return SLANG_E_NO_INTERFACE; } static size_t _getStructureSize(const uint8_t* src) @@ -4070,6 +4082,11 @@ void Module::buildHash(DigestBuilder& builder) builder.append(computeDigest()); } +slang::DeclReflection* Module::getModuleReflection() +{ + return (slang::DeclReflection*)m_moduleDecl; +} + SHA1::Digest Module::computeDigest() { if (m_digest == SHA1::Digest()) diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp new file mode 100644 index 000000000..e443358fb --- /dev/null +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -0,0 +1,167 @@ +// unit-test-translation-unit-import.cpp + +#include "slang.h" + +#include +#include + +#include "tools/unit-test/slang-unit-test.h" +#include "slang-com-ptr.h" +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" + +using namespace Slang; + +static String getTypeFullName(slang::TypeReflection* type) +{ + ComPtr blob; + type->getFullName(blob.writeRef()); + return String((const char*)blob->getBufferPointer()); +} + +static void printRefl(slang::DeclReflection* refl, unsigned int level = 0) +{ + // Mapping of kind ids to names + std::string names[] = {"Unsupported", "Struct", "Function", "Module", "Generic", "Variable"}; + for (unsigned int i = 0; i < level; i++) + { + std::cout << " "; + } + std::cout<< "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() << ")" << std::endl; + + for (auto* child : refl->getChildren()) + { + printRefl(child, level + 1); + } +} + +// Test that the reflection API provides correct info about entry point and ordinary functions. + +SLANG_UNIT_TEST(declTreeReflection) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + [__AttributeUsage(_AttributeTargets.Function)] + struct MyFuncPropertyAttribute {int v;} + + [MyFuncProperty(1024)] + [Differentiable] + float ordinaryFunc(no_diff float x, int y) { return x + y; } + + float4 fragMain(float4 pos:SV_Position) : SV_Position + { + return pos; + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr diagnosticBlob; + auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + ComPtr entryPoint; + module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef()); + SLANG_CHECK(entryPoint != nullptr); + + auto moduleDeclReflection = module->getModuleReflection(); + SLANG_CHECK(moduleDeclReflection != nullptr); + SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 4); + + // First declaration should be a struct with 1 variable + auto firstDecl = moduleDeclReflection->getChild(0); + SLANG_CHECK(firstDecl->getKind() == slang::DeclReflection::Kind::Struct); + SLANG_CHECK(firstDecl->getChildrenCount() == 1); + + { + slang::TypeReflection* type = firstDecl->getType(globalSession); + SLANG_CHECK(getTypeFullName(type) == "MyFuncPropertyAttribute"); + + // Check the field of the struct. + SLANG_CHECK(type->getFieldCount() == 1); + auto field = type->getFieldByIndex(0); + SLANG_CHECK(UnownedStringSlice(field->getName()) == "v"); + SLANG_CHECK(getTypeFullName(field->getType()) == "int"); + } + + // Second declaration should be a function + auto secondDecl = moduleDeclReflection->getChild(1); + SLANG_CHECK(secondDecl->getKind() == slang::DeclReflection::Kind::Func); + SLANG_CHECK(secondDecl->getChildrenCount() == 2); // Parameter declarations are children (return type is not) + + { + auto funcReflection = secondDecl->asFunction(); + SLANG_CHECK(funcReflection->findModifier(slang::Modifier::Differentiable) != nullptr); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float"); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "ordinaryFunc"); + SLANG_CHECK(funcReflection->getParameterCount() == 2); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr); + + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int"); + + SLANG_CHECK(funcReflection->getUserAttributeCount() == 1); + auto userAttribute = funcReflection->getUserAttributeByIndex(0); + SLANG_CHECK(UnownedStringSlice(userAttribute->getName()) == "MyFuncProperty"); + SLANG_CHECK(userAttribute->getArgumentCount() == 1); + SLANG_CHECK(getTypeFullName(userAttribute->getArgumentType(0)) == "int"); + int val = 0; + auto result = userAttribute->getArgumentValueInt(0, &val); + SLANG_CHECK(result == SLANG_OK); + SLANG_CHECK(val == 1024); + SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute); + } + + // Third declaration should also be a function + auto thirdDecl = moduleDeclReflection->getChild(2); + SLANG_CHECK(thirdDecl->getKind() == slang::DeclReflection::Kind::Func); + SLANG_CHECK(thirdDecl->getChildrenCount() == 1); + + { + auto funcReflection = thirdDecl->asFunction(); + SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "vector"); + SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "fragMain"); + SLANG_CHECK(funcReflection->getParameterCount() == 1); + SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "pos"); + SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector"); + } + + // Check iterators + { + unsigned int count = 0; + for (auto* child : moduleDeclReflection->getChildren()) + { + count++; + } + SLANG_CHECK(count == 4); + + count = 0; + for (auto* child : moduleDeclReflection->getChildrenOfKind()) + { + count++; + } + SLANG_CHECK(count == 2); + + count = 0; + for (auto* child : moduleDeclReflection->getChildrenOfKind()) + { + count++; + } + SLANG_CHECK(count == 1); + } +} + -- cgit v1.2.3