summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-07-18 13:11:19 -0400
committerGitHub <noreply@github.com>2024-07-18 13:11:19 -0400
commit0d06ebcefb36a19710d87832fc1ea027e21281af (patch)
tree05ac5af6fa26a658348335eac3d63199b6949507
parent89e836d42822e69dcaa4eb0a366d8c66e5aaa7e4 (diff)
Initial implementation for decl-tree reflection API (#4666)
* Initial implementation for decl-tree reflection API This patch adds Slang API methods for walking all the declarations in the AST. We expose this functionality through an abstract `DeclReflection` class that can be a type, function or a variable declaration. We also provide ways to cast the decl to a `FunctionReflection`, `TypeReflection` or `VariableReflection` and traverse through the child nodes (for instance, a struct type will have component variable declarations) This patch also adds `ISlangInternal` as an internal COM interface to allow us to cast IGlobalSession to the internal Session pointer while bypassing any wrappers (such as the capture interface) * Update slang.h * Remove `ISlangInternal` (its causing a diamond pattern w.r.t `ISlangUnknown`) and use `ComPtr` for proper ref management. * Update unit-test-decl-tree-reflection.cpp * Change `FunctionDeclBase` to use `DeclRef` instead of directly using the decl. * Update slang-reflection-api.cpp
-rw-r--r--include/slang.h160
-rw-r--r--source/slang-capture-replay/slang-global-session.cpp20
-rw-r--r--source/slang-capture-replay/slang-global-session.h5
-rw-r--r--source/slang-capture-replay/slang-module.cpp8
-rw-r--r--source/slang-capture-replay/slang-module.h1
-rwxr-xr-xsource/slang/slang-compiler.h17
-rw-r--r--source/slang/slang-reflection-api.cpp152
-rw-r--r--source/slang/slang.cpp25
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp167
9 files changed, 523 insertions, 32 deletions
diff --git a/include/slang.h b/include/slang.h
index 2f8f8d15d..86ca4649b 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -2107,6 +2107,7 @@ extern "C"
typedef struct SlangEntryPoint SlangEntryPoint;
typedef struct SlangEntryPointLayout SlangEntryPointLayout;
+ typedef struct SlangReflectionDecl SlangReflectionDecl;
typedef struct SlangReflectionModifier SlangReflectionModifier;
typedef struct SlangReflectionType SlangReflectionType;
typedef struct SlangReflectionTypeLayout SlangReflectionTypeLayout;
@@ -2174,6 +2175,18 @@ extern "C"
SLANG_SCALAR_TYPE_UINTPTR
};
+ // abstract decl reflection
+ typedef unsigned int SlangDeclKindIntegral;
+ enum SlangDeclKind : SlangDeclKindIntegral
+ {
+ SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION,
+ SLANG_DECL_KIND_STRUCT,
+ SLANG_DECL_KIND_FUNC,
+ SLANG_DECL_KIND_MODULE,
+ SLANG_DECL_KIND_GENERIC,
+ SLANG_DECL_KIND_VARIABLE
+ };
+
#ifndef SLANG_RESOURCE_SHAPE
# define SLANG_RESOURCE_SHAPE
typedef unsigned int SlangResourceShapeIntegral;
@@ -2394,6 +2407,7 @@ extern "C"
SLANG_MODIFIER_EXPORT,
SLANG_MODIFIER_EXTERN,
SLANG_MODIFIER_DIFFERENTIABLE,
+ SLANG_MODIFIER_MUTATING
};
// User Attribute
@@ -2528,6 +2542,7 @@ extern "C"
SLANG_API unsigned int spReflectionVariable_GetUserAttributeCount(SlangReflectionVariable* var);
SLANG_API SlangReflectionUserAttribute* spReflectionVariable_GetUserAttribute(SlangReflectionVariable* var, unsigned int index);
SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name);
+ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar);
// Variable Layout Reflection
@@ -2544,7 +2559,9 @@ extern "C"
// Function Reflection
+ SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* func);
SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* func);
+ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* var, SlangModifierID modifierID);
SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* func);
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* func, unsigned int index);
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* func, SlangSession* globalSession, char const* name);
@@ -2552,6 +2569,16 @@ extern "C"
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index);
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func);
+ // Abstract Decl Reflection
+
+ SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl);
+ SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index);
+ SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl);
+ SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl);
+ SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl);
+ SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl);
+
+
/** Get the stage that a variable belongs to (if any).
A variable "belongs" to a specific stage when it is a varying input/output
@@ -2693,6 +2720,7 @@ SLANG_API slang::ISession* spReflection_GetSession(SlangReflection* reflection);
namespace slang
{
struct BufferReflection;
+ struct DeclReflection;
struct TypeLayoutReflection;
struct TypeReflection;
struct VariableLayoutReflection;
@@ -3293,6 +3321,7 @@ namespace slang
Export = SLANG_MODIFIER_EXPORT,
Extern = SLANG_MODIFIER_EXTERN,
Differentiable = SLANG_MODIFIER_DIFFERENTIABLE,
+ Mutating = SLANG_MODIFIER_MUTATING
};
};
@@ -3325,6 +3354,11 @@ namespace slang
{
return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name);
}
+
+ bool hasDefaultValue()
+ {
+ return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this);
+ }
};
struct VariableLayoutReflection
@@ -3435,20 +3469,20 @@ namespace slang
unsigned int getUserAttributeCount()
{
- return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this);
+ return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this);
}
UserAttribute* getUserAttributeByIndex(unsigned int index)
{
- return (UserAttribute*)spReflectionVariable_GetUserAttribute((SlangReflectionVariable*)this, index);
+ return (UserAttribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index);
}
UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
{
- return (UserAttribute*)spReflectionVariable_FindUserAttributeByName((SlangReflectionVariable*)this, globalSession, name);
+ return (UserAttribute*)spReflectionFunction_FindUserAttributeByName((SlangReflectionFunction*)this, globalSession, name);
}
Modifier* findModifier(Modifier::ID id)
{
- return (Modifier*)spReflectionVariable_FindModifier((SlangReflectionVariable*)this, (SlangModifierID)id);
+ return (Modifier*)spReflectionFunction_FindModifier((SlangReflectionFunction*)this, (SlangModifierID)id);
}
};
@@ -3672,6 +3706,122 @@ namespace slang
}
};
+
+ struct DeclReflection
+ {
+ enum class Kind
+ {
+ Unsupported = SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION,
+ Struct = SLANG_DECL_KIND_STRUCT,
+ Func = SLANG_DECL_KIND_FUNC,
+ Module = SLANG_DECL_KIND_MODULE,
+ Generic = SLANG_DECL_KIND_GENERIC,
+ Variable = SLANG_DECL_KIND_VARIABLE,
+ };
+
+ Kind getKind()
+ {
+ return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this);
+ }
+
+ unsigned int getChildrenCount()
+ {
+ return spReflectionDecl_getChildrenCount((SlangReflectionDecl*)this);
+ }
+
+ DeclReflection* getChild(unsigned int index)
+ {
+ return (DeclReflection*)spReflectionDecl_getChild((SlangReflectionDecl*)this, index);
+ }
+
+ TypeReflection* getType(SlangSession* session)
+ {
+ return (TypeReflection*)spReflection_getTypeFromDecl(session, (SlangReflectionDecl*)this);
+ }
+
+ VariableReflection* asVariable()
+ {
+ return (VariableReflection*)spReflectionDecl_castToVariable((SlangReflectionDecl*)this);
+ }
+
+ FunctionReflection* asFunction()
+ {
+ return (FunctionReflection*)spReflectionDecl_castToFunction((SlangReflectionDecl*)this);
+ }
+
+ template <Kind K>
+ struct FilteredList
+ {
+ unsigned int count;
+ DeclReflection* parent;
+
+ struct FilteredIterator
+ {
+ DeclReflection* parent;
+ unsigned int count;
+ unsigned int index;
+
+ DeclReflection* operator*() { return parent->getChild(index); }
+ void operator++()
+ {
+ index++;
+ while (index < count && !(parent->getChild(index)->getKind() == K))
+ {
+ index++;
+ }
+ }
+ bool operator!=(FilteredIterator const& other) { return index != other.index; }
+ };
+
+ // begin/end for range-based for that checks the kind
+ FilteredIterator begin()
+ {
+ // Find the first child of the right kind
+ unsigned int index = 0;
+ while (index < count && !(parent->getChild(index)->getKind() == K))
+ {
+ index++;
+ }
+ return FilteredIterator{parent, count, index};
+ }
+
+ FilteredIterator end() { return FilteredIterator{parent, count, count}; }
+ };
+
+ template <Kind K>
+ FilteredList<K> getChildrenOfKind()
+ {
+ return FilteredList<K>{ getChildrenCount(), (DeclReflection*)this };
+ }
+
+ struct IteratedList
+ {
+ unsigned int count;
+ DeclReflection* parent;
+
+ struct Iterator
+ {
+ DeclReflection* parent;
+ unsigned int count;
+ unsigned int index;
+
+ DeclReflection* operator*() { return parent->getChild(index); }
+ void operator++() { index++; }
+ bool operator!=(Iterator const& other) { return index != other.index; }
+ };
+
+ // begin/end for range-based for that checks the kind
+ IteratedList::Iterator begin() { return IteratedList::Iterator{ parent, count, 0 }; }
+ IteratedList::Iterator end() { return IteratedList::Iterator{ parent, count, count }; }
+ };
+
+ IteratedList getChildren()
+ {
+ return IteratedList{ getChildrenCount(), (DeclReflection*)this };
+ }
+
+ };
+
typedef uint32_t CompileStdLibFlags;
struct CompileStdLibFlag
{
@@ -5132,6 +5282,8 @@ namespace slang
/// Get the path to a file this module depends on.
virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(
SlangInt32 index) = 0;
+
+ virtual SLANG_NO_THROW DeclReflection* SLANG_MCALL getModuleReflection() = 0;
};
#define SLANG_UUID_IModule IModule::getTypeGuid()
diff --git a/source/slang-capture-replay/slang-global-session.cpp b/source/slang-capture-replay/slang-global-session.cpp
index b8dfd7d40..d8f1e361d 100644
--- a/source/slang-capture-replay/slang-global-session.cpp
+++ b/source/slang-capture-replay/slang-global-session.cpp
@@ -28,11 +28,23 @@ namespace SlangCapture
m_actualGlobalSession->release();
}
- ISlangUnknown* GlobalSessionCapture::getInterface(const Guid& guid)
+ SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::queryInterface(SlangUUID const& uuid, void** outObject)
{
- if(guid == ISlangUnknown::getTypeGuid() || guid == IGlobalSession::getTypeGuid())
- return asExternal(this);
- return nullptr;
+ if (uuid == Session::getTypeGuid())
+ {
+ // no add-ref here, the query will cause the inner session to handle the add-ref.
+ this->m_actualGlobalSession->queryInterface(uuid, outObject);
+ return SLANG_OK;
+ }
+
+ if (uuid == ISlangUnknown::getTypeGuid() && uuid == IGlobalSession::getTypeGuid())
+ {
+ addReference();
+ *outObject = static_cast<slang::IGlobalSession*>(this);
+ return SLANG_OK;
+ }
+
+ return SLANG_E_NO_INTERFACE;
}
SLANG_NO_THROW SlangResult SLANG_MCALL GlobalSessionCapture::createSession(slang::SessionDesc const& desc, slang::ISession** outSession)
diff --git a/source/slang-capture-replay/slang-global-session.h b/source/slang-capture-replay/slang-global-session.h
index ae451505f..f478e60db 100644
--- a/source/slang-capture-replay/slang-global-session.h
+++ b/source/slang-capture-replay/slang-global-session.h
@@ -17,9 +17,10 @@ namespace SlangCapture
explicit GlobalSessionCapture(slang::IGlobalSession* session);
virtual ~GlobalSessionCapture();
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_REF_OBJECT_IUNKNOWN_ADD_REF
+ SLANG_REF_OBJECT_IUNKNOWN_RELEASE
- ISlangUnknown* getInterface(const Guid& guid);
+ SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE;
// slang::IGlobalSession
SLANG_NO_THROW SlangResult SLANG_MCALL createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override;
diff --git a/source/slang-capture-replay/slang-module.cpp b/source/slang-capture-replay/slang-module.cpp
index 8a1a80126..273faa59d 100644
--- a/source/slang-capture-replay/slang-module.cpp
+++ b/source/slang-capture-replay/slang-module.cpp
@@ -27,6 +27,14 @@ namespace SlangCapture
return nullptr;
}
+ SLANG_NO_THROW slang::DeclReflection* ModuleCapture::getModuleReflection()
+ {
+ // No need to capture this call as it is just a query.
+ slangCaptureLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
+ slang::DeclReflection* res = (slang::DeclReflection*)m_actualModule->getModuleReflection();
+ return res;
+ }
+
SLANG_NO_THROW SlangResult ModuleCapture::findEntryPointByName(
char const* name,
slang::IEntryPoint** outEntryPoint)
diff --git a/source/slang-capture-replay/slang-module.h b/source/slang-capture-replay/slang-module.h
index de63b7967..94539532c 100644
--- a/source/slang-capture-replay/slang-module.h
+++ b/source/slang-capture-replay/slang-module.h
@@ -83,6 +83,7 @@ namespace SlangCapture
uint32_t compilerOptionEntryCount,
slang::CompilerOptionEntry* compilerOptionEntries,
ISlangBlob** outDiagnostics = nullptr) override;
+ virtual SLANG_NO_THROW slang::DeclReflection* getModuleReflection() override;
slang::IModule* getActualModule() const { return m_actualModule; }
private:
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 35b2fcf44..2409cedfb 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -282,6 +282,7 @@ namespace Slang
HashSet<SourceFile*> m_fileSet;
};
+
class EntryPoint;
class ComponentType;
@@ -1051,7 +1052,7 @@ namespace Slang
SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE
{
- return (slang::FunctionReflection*)m_funcDeclRef.getDecl();
+ return (slang::FunctionReflection*)m_funcDeclRef.declRefBase;
}
protected:
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
@@ -1476,6 +1477,8 @@ namespace Slang
virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
+ virtual slang::DeclReflection* getModuleReflection() SLANG_OVERRIDE;
+
void setDigest(SHA1::Digest const& digest) { m_digest = digest; }
SHA1::Digest computeDigest();
@@ -3079,9 +3082,11 @@ namespace Slang
class Session : public RefObject, public slang::IGlobalSession
{
public:
- SLANG_REF_OBJECT_IUNKNOWN_ALL
+ SLANG_COM_INTERFACE(0xd6b767eb, 0xd786, 0x4343, { 0x2a, 0x8c, 0x6d, 0xa0, 0x3d, 0x5a, 0xb4, 0x4a })
- ISlangUnknown* getInterface(const Guid& guid);
+ SLANG_NO_THROW SlangResult SLANG_MCALL queryInterface(SlangUUID const& uuid, void** outObject) SLANG_OVERRIDE;
+ SLANG_REF_OBJECT_IUNKNOWN_ADD_REF
+ SLANG_REF_OBJECT_IUNKNOWN_RELEASE
// slang::IGlobalSession
SLANG_NO_THROW SlangResult SLANG_MCALL createSession(slang::SessionDesc const& desc, slang::ISession** outSession) override;
@@ -3287,9 +3292,11 @@ SLANG_FORCE_INLINE slang::IGlobalSession* asExternal(Session* session)
return static_cast<slang::IGlobalSession*>(session);
}
-SLANG_FORCE_INLINE Session* asInternal(slang::IGlobalSession* session)
+SLANG_FORCE_INLINE ComPtr<Session> asInternal(slang::IGlobalSession* session)
{
- return static_cast<Session*>(session);
+ Slang::Session* internalSession = nullptr;
+ session->queryInterface(SLANG_IID_PPV_ARGS(&internalSession));
+ return ComPtr<Session>(INIT_ATTACH, static_cast<Session*>(internalSession));
}
SLANG_FORCE_INLINE slang::ISession* asExternal(Linkage* linkage)
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index d0d09f814..9e08330e4 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -65,14 +65,15 @@ static inline SlangReflectionVariable* convert(Decl* var)
return (SlangReflectionVariable*) var;
}
-static inline FunctionDeclBase* convert(SlangReflectionFunction* func)
+static inline DeclRef<FunctionDeclBase> convert(SlangReflectionFunction* func)
{
- return (FunctionDeclBase*)func;
+ DeclRefBase* declBase = (DeclRefBase*)func;
+ return DeclRef<FunctionDeclBase>(declBase);
}
-static inline SlangReflectionFunction* convert(FunctionDeclBase* func)
+static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
{
- return (SlangReflectionFunction*)func;
+ return (SlangReflectionFunction*)func.declRefBase;
}
static inline VarLayout* convert(SlangReflectionVariableLayout* var)
@@ -773,7 +774,7 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
{
auto result = program->findDeclFromString(name, &sink);
if (auto funcDeclRef = result.as<FunctionDeclBase>())
- return (SlangReflectionFunction*)funcDeclRef.getDecl();
+ return convert(funcDeclRef);
}
catch (...)
{
@@ -2621,6 +2622,9 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec
case SLANG_MODIFIER_DIFFERENTIABLE:
modifier = var->findModifier<DifferentiableAttribute>();
break;
+ case SLANG_MODIFIER_MUTATING:
+ modifier = var->findModifier<MutatingAttribute>();
+ break;
default:
return nullptr;
}
@@ -2647,6 +2651,17 @@ SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeBy
return findUserAttributeByName(asInternal(session), varDecl, name);
}
+SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar)
+{
+ auto decl = convert(inVar);
+ if (auto varDecl = as<VarDeclBase>(decl))
+ {
+ return varDecl->initExpr != nullptr;
+ }
+
+ return false;
+}
+
// Variable Layout Reflection
SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout)
@@ -2793,11 +2808,18 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage(
// Function Reflection
+SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+ return (SlangReflectionDecl*)func.getDecl();
+}
+
SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
if (!func) return nullptr;
- return getText(func->getName()).getBuffer();
+ return getText(func.getDecl()->getName()).getBuffer();
}
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc)
@@ -2805,42 +2827,146 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio
auto func = convert(inFunc);
if (!func) return nullptr;
- return convert(func->returnType.type);
+ auto rawType = func.getDecl()->returnType.type;
+ auto astBuilder = rawType->getASTBuilderForReflection();
+
+ return convert((Type*)rawType->substitute(astBuilder, SubstitutionSet(func.declRefBase)));
+}
+
+SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID)
+{
+ auto funcDeclRef = convert(inFunc);
+ auto varRefl = convert(funcDeclRef.getDecl());
+ if (!varRefl) return nullptr;
+
+ return spReflectionVariable_FindModifier(varRefl, modifierID);
}
SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
if (!func) return 0;
- return getUserAttributeCount(func);
+ return getUserAttributeCount(func.getDecl());
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index)
{
auto func = convert(inFunc);
if (!func) return nullptr;
- return getUserAttributeByIndex(func, index);
+ return getUserAttributeByIndex(func.getDecl(), index);
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name)
{
auto func = convert(inFunc);
if (!func) return nullptr;
- return findUserAttributeByName(asInternal(session), func, name);
+ return findUserAttributeByName(asInternal(session), func.getDecl(), name);
}
SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc)
{
auto func = convert(inFunc);
if (!func) return 0;
- return (unsigned int)func->getParameters().getCount();
+ return (unsigned int)func.getDecl()->getParameters().getCount();
}
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index)
{
auto func = convert(inFunc);
if (!func) return nullptr;
- return convert(as<Decl>(func->getParameters()[index]));
+ return convert(as<Decl>(func.getDecl()->getParameters()[index]));
+}
+
+// Abstract decl reflection
+
+SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
+{
+ Decl* decl = (Decl*)parentDecl;
+ if (as<ContainerDecl>(decl))
+ {
+ return (unsigned int)as<ContainerDecl>(decl)->members.getCount();
+ }
+
+ return 0;
+}
+
+SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index)
+{
+ Decl* decl = (Decl*)parentDecl;
+ if (auto containerDecl = as<ContainerDecl>(decl))
+ {
+ if (containerDecl->members.getCount() > index)
+ return (SlangReflectionDecl*)containerDecl->members[index];
+ }
+
+ return nullptr;
+}
+
+SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl)
+{
+ Decl* slangDecl = (Decl*)decl;
+ if (as<StructDecl>(slangDecl))
+ {
+ return SLANG_DECL_KIND_STRUCT;
+ }
+ else if (as<VarDeclBase>(slangDecl))
+ {
+ return SLANG_DECL_KIND_VARIABLE;
+ }
+ else if (as<GenericDecl>(slangDecl))
+ {
+ return SLANG_DECL_KIND_GENERIC;
+ }
+ else if (as<FunctionDeclBase>(slangDecl))
+ {
+ return SLANG_DECL_KIND_FUNC;
+ }
+ else if (as<ModuleDecl>(slangDecl))
+ {
+ return SLANG_DECL_KIND_MODULE;
+ }
+ else
+ return SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION;
+}
+
+SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl)
+{
+ Decl* slangDecl = (Decl*) decl;
+ if (auto funcDecl = as<FunctionDeclBase>(slangDecl))
+ {
+ return convert(DeclRef<FunctionDeclBase>(funcDecl->getDefaultDeclRef()));
+ }
+
+ // Improper cast
+ return nullptr;
+}
+
+SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl)
+{
+ Decl* slangDecl = (Decl*) decl;
+ if (auto varDecl = as<VarDeclBase>(slangDecl))
+ {
+ return (SlangReflectionVariable*) varDecl;
+ }
+
+ // Improper cast
+ return nullptr;
+
+}
+
+SLANG_API SlangReflectionType* spReflection_getTypeFromDecl(SlangSession* session, SlangReflectionDecl* decl)
+{
+ Decl* slangDecl = (Decl*)decl;
+ auto slangSession = asInternal(session);
+
+ ASTBuilder* builder = slangSession->getGlobalASTBuilder();
+ if (auto type = DeclRefType::create(builder, slangDecl->getDefaultDeclRef()))
+ {
+ return convert(type);
+ }
+
+ // Couldn't create a type from the decl
+ return nullptr;
}
// Shader Parameter Reflection
@@ -2906,7 +3032,7 @@ SLANG_API SlangReflectionFunction* spReflectionEntryPoint_GetFunction(SlangRefle
auto entryPointLayout = convert(inEntryPoint);
if (entryPointLayout)
{
- return (SlangReflectionFunction*)entryPointLayout->entryPoint.getDecl();
+ return convert(entryPointLayout->entryPoint);
}
return nullptr;
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index f28d3cdeb..f440ce3cb 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -546,11 +546,23 @@ SlangResult Session::_readBuiltinModule(ISlangFileSystem* fileSystem, Scope* sco
return SLANG_OK;
}
-ISlangUnknown* Session::getInterface(const Guid& guid)
+SLANG_NO_THROW SlangResult SLANG_MCALL Session::queryInterface(SlangUUID const& uuid, void** outObject)
{
- if(guid == ISlangUnknown::getTypeGuid() || guid == IGlobalSession::getTypeGuid())
- return asExternal(this);
- return nullptr;
+ if (uuid == Session::getTypeGuid())
+ {
+ addReference();
+ *outObject = static_cast<Session*>(this);
+ return SLANG_OK;
+ }
+
+ if (uuid == ISlangUnknown::getTypeGuid() && uuid == IGlobalSession::getTypeGuid())
+ {
+ addReference();
+ *outObject = static_cast<slang::IGlobalSession*>(this);
+ return SLANG_OK;
+ }
+
+ return SLANG_E_NO_INTERFACE;
}
static size_t _getStructureSize(const uint8_t* src)
@@ -4070,6 +4082,11 @@ void Module::buildHash(DigestBuilder<SHA1>& builder)
builder.append(computeDigest());
}
+slang::DeclReflection* Module::getModuleReflection()
+{
+ return (slang::DeclReflection*)m_moduleDecl;
+}
+
SHA1::Digest Module::computeDigest()
{
if (m_digest == SHA1::Digest())
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
new file mode 100644
index 000000000..e443358fb
--- /dev/null
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -0,0 +1,167 @@
+// unit-test-translation-unit-import.cpp
+
+#include "slang.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "tools/unit-test/slang-unit-test.h"
+#include "slang-com-ptr.h"
+#include "../../source/core/slang-io.h"
+#include "../../source/core/slang-process.h"
+
+using namespace Slang;
+
+static String getTypeFullName(slang::TypeReflection* type)
+{
+ ComPtr<ISlangBlob> blob;
+ type->getFullName(blob.writeRef());
+ return String((const char*)blob->getBufferPointer());
+}
+
+static void printRefl(slang::DeclReflection* refl, unsigned int level = 0)
+{
+ // Mapping of kind ids to names
+ std::string names[] = {"Unsupported", "Struct", "Function", "Module", "Generic", "Variable"};
+ for (unsigned int i = 0; i < level; i++)
+ {
+ std::cout << " ";
+ }
+ std::cout<< "[" << names[(unsigned int)refl->getKind()] << "] (" << refl->getChildrenCount() << ")" << std::endl;
+
+ for (auto* child : refl->getChildren())
+ {
+ printRefl(child, level + 1);
+ }
+}
+
+// Test that the reflection API provides correct info about entry point and ordinary functions.
+
+SLANG_UNIT_TEST(declTreeReflection)
+{
+ // Source for a module that contains an undecorated entrypoint.
+ const char* userSourceBody = R"(
+ [__AttributeUsage(_AttributeTargets.Function)]
+ struct MyFuncPropertyAttribute {int v;}
+
+ [MyFuncProperty(1024)]
+ [Differentiable]
+ float ordinaryFunc(no_diff float x, int y) { return x + y; }
+
+ float4 fragMain(float4 pos:SV_Position) : SV_Position
+ {
+ return pos;
+ }
+ )";
+
+ auto moduleName = "moduleG" + String(Process::getId());
+ String userSource = "import " + moduleName + ";\n" + userSourceBody;
+ ComPtr<slang::IGlobalSession> globalSession;
+ SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
+ slang::TargetDesc targetDesc = {};
+ targetDesc.format = SLANG_HLSL;
+ targetDesc.profile = globalSession->findProfile("sm_5_0");
+ slang::SessionDesc sessionDesc = {};
+ sessionDesc.targetCount = 1;
+ sessionDesc.targets = &targetDesc;
+ ComPtr<slang::ISession> session;
+ SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
+
+ ComPtr<slang::IBlob> diagnosticBlob;
+ auto module = session->loadModuleFromSourceString("m", "m.slang", userSourceBody, diagnosticBlob.writeRef());
+ SLANG_CHECK(module != nullptr);
+
+ ComPtr<slang::IEntryPoint> entryPoint;
+ module->findAndCheckEntryPoint("fragMain", SLANG_STAGE_FRAGMENT, entryPoint.writeRef(), diagnosticBlob.writeRef());
+ SLANG_CHECK(entryPoint != nullptr);
+
+ auto moduleDeclReflection = module->getModuleReflection();
+ SLANG_CHECK(moduleDeclReflection != nullptr);
+ SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module);
+ SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 4);
+
+ // First declaration should be a struct with 1 variable
+ auto firstDecl = moduleDeclReflection->getChild(0);
+ SLANG_CHECK(firstDecl->getKind() == slang::DeclReflection::Kind::Struct);
+ SLANG_CHECK(firstDecl->getChildrenCount() == 1);
+
+ {
+ slang::TypeReflection* type = firstDecl->getType(globalSession);
+ SLANG_CHECK(getTypeFullName(type) == "MyFuncPropertyAttribute");
+
+ // Check the field of the struct.
+ SLANG_CHECK(type->getFieldCount() == 1);
+ auto field = type->getFieldByIndex(0);
+ SLANG_CHECK(UnownedStringSlice(field->getName()) == "v");
+ SLANG_CHECK(getTypeFullName(field->getType()) == "int");
+ }
+
+ // Second declaration should be a function
+ auto secondDecl = moduleDeclReflection->getChild(1);
+ SLANG_CHECK(secondDecl->getKind() == slang::DeclReflection::Kind::Func);
+ SLANG_CHECK(secondDecl->getChildrenCount() == 2); // Parameter declarations are children (return type is not)
+
+ {
+ auto funcReflection = secondDecl->asFunction();
+ SLANG_CHECK(funcReflection->findModifier(slang::Modifier::Differentiable) != nullptr);
+ SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "float");
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "ordinaryFunc");
+ SLANG_CHECK(funcReflection->getParameterCount() == 2);
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "x");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "float");
+ SLANG_CHECK(funcReflection->getParameterByIndex(0)->findModifier(slang::Modifier::NoDiff) != nullptr);
+
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(1)->getName()) == "y");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(1)->getType()) == "int");
+
+ SLANG_CHECK(funcReflection->getUserAttributeCount() == 1);
+ auto userAttribute = funcReflection->getUserAttributeByIndex(0);
+ SLANG_CHECK(UnownedStringSlice(userAttribute->getName()) == "MyFuncProperty");
+ SLANG_CHECK(userAttribute->getArgumentCount() == 1);
+ SLANG_CHECK(getTypeFullName(userAttribute->getArgumentType(0)) == "int");
+ int val = 0;
+ auto result = userAttribute->getArgumentValueInt(0, &val);
+ SLANG_CHECK(result == SLANG_OK);
+ SLANG_CHECK(val == 1024);
+ SLANG_CHECK(funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") == userAttribute);
+ }
+
+ // Third declaration should also be a function
+ auto thirdDecl = moduleDeclReflection->getChild(2);
+ SLANG_CHECK(thirdDecl->getKind() == slang::DeclReflection::Kind::Func);
+ SLANG_CHECK(thirdDecl->getChildrenCount() == 1);
+
+ {
+ auto funcReflection = thirdDecl->asFunction();
+ SLANG_CHECK(getTypeFullName(funcReflection->getReturnType()) == "vector<float,4>");
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getName()) == "fragMain");
+ SLANG_CHECK(funcReflection->getParameterCount() == 1);
+ SLANG_CHECK(UnownedStringSlice(funcReflection->getParameterByIndex(0)->getName()) == "pos");
+ SLANG_CHECK(getTypeFullName(funcReflection->getParameterByIndex(0)->getType()) == "vector<float,4>");
+ }
+
+ // Check iterators
+ {
+ unsigned int count = 0;
+ for (auto* child : moduleDeclReflection->getChildren())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 4);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 2);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Struct>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 1);
+ }
+}
+