summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-reflection-api.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 16:27:50 -0400
committerGitHub <noreply@github.com>2024-09-19 13:27:50 -0700
commitdd3d80e61b316390a468a142de2be2fb85b73d0d (patch)
treecbfd3ddcbaed84de335818e9e618d7c3ebff6ecd /source/slang/slang-reflection-api.cpp
parent9d40ce4e8921ef468281c91f052dbd443ecf56e2 (diff)
Allow lookups of overloaded methods. (#5110)
* Allow lookups of overloaded methods. * Update slang-reflection-api.cpp * Update slang.cpp --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-reflection-api.cpp')
-rw-r--r--source/slang/slang-reflection-api.cpp167
1 files changed, 131 insertions, 36 deletions
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 38129babf..b6fc05986 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -66,10 +66,21 @@ static inline SlangReflectionVariable* convert(DeclRef<Decl> var)
return (SlangReflectionVariable*) var.declRefBase;
}
-static inline DeclRef<FunctionDeclBase> convert(SlangReflectionFunction* func)
+static inline DeclRef<FunctionDeclBase> convertToFunc(SlangReflectionFunction* func)
{
- DeclRefBase* declBase = (DeclRefBase*)func;
- return DeclRef<FunctionDeclBase>(declBase);
+ NodeBase* nodeBase = (NodeBase*)func;
+ if (DeclRefBase* declRefBase = as<DeclRefBase>(nodeBase))
+ {
+ return DeclRef<FunctionDeclBase>(declRefBase);
+ }
+
+ return DeclRef<FunctionDeclBase>();
+}
+
+static inline OverloadedExpr* convertToOverloadedFunc(SlangReflectionFunction* func)
+{
+ NodeBase* nodeBase = (NodeBase*)func;
+ return as<OverloadedExpr>(nodeBase);
}
static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
@@ -77,6 +88,11 @@ static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
return (SlangReflectionFunction*)func.declRefBase;
}
+static inline SlangReflectionFunction* convert(OverloadedExpr* overloadedFunc)
+{
+ return (SlangReflectionFunction*)overloadedFunc;
+}
+
static inline DeclRef<Decl> convertGenericToDeclRef(SlangReflectionGeneric* func)
{
DeclRefBase* declBase = (DeclRefBase*)func;
@@ -785,6 +801,27 @@ SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType,
return SLANG_OK;
}
+SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuilder, Expr* expr)
+{
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ {
+ auto declRef = declRefExpr->declRef;
+ if (auto genericDeclRef = declRef.as<GenericDecl>())
+ {
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
+ }
+
+ if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
+ return convert(funcDeclRef);
+ }
+ else if (auto overloadedExpr = as<OverloadedExpr>(expr))
+ return convert(overloadedExpr);
+
+ return nullptr;
+}
+
SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name)
{
auto programLayout = convert(reflection);
@@ -800,17 +837,9 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
- auto result = program->findDeclFromString(name, &sink);
-
- if (auto genericDeclRef = result.as<GenericDecl>())
- {
- auto innerDeclRef = substituteDeclRef(
- SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
- result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
- }
-
- if (auto funcDeclRef = result.as<FunctionDeclBase>())
- return convert(funcDeclRef);
+ return tryConvertExprToFunctionReflection(
+ astBuilder,
+ program->findDeclFromString(name, &sink));
}
catch (...)
{
@@ -828,12 +857,13 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangRe
Slang::DiagnosticSink sink(
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);
-
+
+ auto astBuilder = program->getLinkage()->getASTBuilder();
+
try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink);
- if (auto funcDeclRef = result.as<FunctionDeclBase>())
- return convert(funcDeclRef);
+ return tryConvertExprToFunctionReflection(astBuilder, result);
}
catch (...)
{
@@ -855,8 +885,11 @@ SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflect
try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink);
- if (auto varDeclRef = result.as<VarDeclBase>())
- return convert(varDeclRef.as<Decl>());
+ if (auto declRefExpr = as<DeclRefExpr>(result))
+ {
+ if (auto varDeclRef = declRefExpr->declRef.as<VarDeclBase>())
+ return convert(varDeclRef.as<Decl>());
+ }
}
catch (...)
{
@@ -3009,21 +3042,23 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage(
SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return (SlangReflectionDecl*)func.getDecl();
}
SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return getText(func.getDecl()->getName()).getBuffer();
}
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
auto rawType = func.getDecl()->returnType.type;
@@ -3034,7 +3069,9 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio
SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID)
{
- auto funcDeclRef = convert(inFunc);
+ auto funcDeclRef = convertToFunc(inFunc);
+ if (!funcDeclRef) return nullptr;
+
auto varRefl = convert(funcDeclRef.as<Decl>());
if (!varRefl) return nullptr;
@@ -3043,35 +3080,38 @@ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflec
SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return 0;
+
return getUserAttributeCount(func.getDecl());
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
return getUserAttributeByIndex(func.getDecl(), index);
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return findUserAttributeByName(asInternal(session), func.getDecl(), name);
}
SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return 0;
+
return (unsigned int)func.getDecl()->getParameters().getCount();
}
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder();
@@ -3081,13 +3121,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
- auto declRef = convert(func);
+ auto declRef = convertToFunc(func);
+ if (!declRef)
+ return nullptr;
+
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
{
- auto declRef = convert(func);
+ auto declRef = convertToFunc(func);
auto genericDeclRef = convertGenericToDeclRef(generic);
if (!declRef || !genericDeclRef)
return nullptr;
@@ -3103,12 +3146,25 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
SlangInt argTypeCount,
SlangReflectionType* const* argTypes)
{
- auto declRef = convert(func);
- if (!declRef)
+ Linkage* linkage = nullptr;
+ Expr* funcExpr = nullptr;
+
+ if (auto funcDeclRef = convertToFunc(func))
+ {
+ linkage = getModule(funcDeclRef.getDecl())->getLinkage();
+ auto declRefExpr = linkage->getASTBuilder()->create<DeclRefExpr>();
+ declRefExpr->declRef = funcDeclRef;
+ funcExpr = declRefExpr;
+ }
+ else if (auto overloadedExpr = convertToOverloadedFunc(func))
+ {
+ linkage = getModule(overloadedExpr->lookupResult2.items[0].declRef.getDecl())->getLinkage();
+ funcExpr = overloadedExpr;
+ }
+ else
+ {
return nullptr;
-
-
- auto linkage = getModule(declRef.getDecl())->getLinkage();
+ }
List<Type*> argTypeList;
for (SlangInt ii = 0; ii < argTypeCount; ++ii)
@@ -3120,7 +3176,7 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
- return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
+ return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>());
}
catch (...)
{
@@ -3128,6 +3184,45 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
}
}
+SLANG_API bool spReflectionFunction_isOverloaded(
+ SlangReflectionFunction* func)
+{
+ return (convertToOverloadedFunc(func) != nullptr);
+}
+
+SLANG_API unsigned int spReflectionFunction_getOverloadCount(
+ SlangReflectionFunction* func)
+{
+ auto overloadedFunc = convertToOverloadedFunc(func);
+ if (!overloadedFunc) return 1;
+
+ return (unsigned int) overloadedFunc->lookupResult2.items.getCount();
+}
+
+SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload(
+ SlangReflectionFunction* func,
+ unsigned int index)
+{
+ auto overloadedFunc = convertToOverloadedFunc(func);
+ if (!overloadedFunc) return nullptr;
+
+ auto declRef = overloadedFunc->lookupResult2.items[index].declRef;
+ if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
+ {
+ return convert(declRef.as<FunctionDeclBase>());
+ }
+ else if (auto genericDeclRef = declRef.as<GenericDecl>())
+ {
+ auto astBuilder = getModule(genericDeclRef.getDecl())->getLinkage()->getASTBuilder();
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ return convert(
+ createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef).as<FunctionDeclBase>());
+ }
+
+ return nullptr;
+}
+
// Abstract decl reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)