From dd3d80e61b316390a468a142de2be2fb85b73d0d Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 19 Sep 2024 16:27:50 -0400 Subject: Allow lookups of overloaded methods. (#5110) * Allow lookups of overloaded methods. * Update slang-reflection-api.cpp * Update slang.cpp --------- Co-authored-by: Yong He --- source/slang/slang-reflection-api.cpp | 167 ++++++++++++++++++++++++++-------- 1 file changed, 131 insertions(+), 36 deletions(-) (limited to 'source/slang/slang-reflection-api.cpp') diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 38129babf..b6fc05986 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -66,10 +66,21 @@ static inline SlangReflectionVariable* convert(DeclRef var) return (SlangReflectionVariable*) var.declRefBase; } -static inline DeclRef convert(SlangReflectionFunction* func) +static inline DeclRef convertToFunc(SlangReflectionFunction* func) { - DeclRefBase* declBase = (DeclRefBase*)func; - return DeclRef(declBase); + NodeBase* nodeBase = (NodeBase*)func; + if (DeclRefBase* declRefBase = as(nodeBase)) + { + return DeclRef(declRefBase); + } + + return DeclRef(); +} + +static inline OverloadedExpr* convertToOverloadedFunc(SlangReflectionFunction* func) +{ + NodeBase* nodeBase = (NodeBase*)func; + return as(nodeBase); } static inline SlangReflectionFunction* convert(DeclRef func) @@ -77,6 +88,11 @@ static inline SlangReflectionFunction* convert(DeclRef func) return (SlangReflectionFunction*)func.declRefBase; } +static inline SlangReflectionFunction* convert(OverloadedExpr* overloadedFunc) +{ + return (SlangReflectionFunction*)overloadedFunc; +} + static inline DeclRef convertGenericToDeclRef(SlangReflectionGeneric* func) { DeclRefBase* declBase = (DeclRefBase*)func; @@ -785,6 +801,27 @@ SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType, return SLANG_OK; } +SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuilder, Expr* expr) +{ + if (auto declRefExpr = as(expr)) + { + auto declRef = declRefExpr->declRef; + if (auto genericDeclRef = declRef.as()) + { + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + } + + if (auto funcDeclRef = declRef.as()) + return convert(funcDeclRef); + } + else if (auto overloadedExpr = as(expr)) + return convert(overloadedExpr); + + return nullptr; +} + SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name) { auto programLayout = convert(reflection); @@ -800,17 +837,9 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti auto astBuilder = program->getLinkage()->getASTBuilder(); try { - auto result = program->findDeclFromString(name, &sink); - - if (auto genericDeclRef = result.as()) - { - auto innerDeclRef = substituteDeclRef( - SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); - result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); - } - - if (auto funcDeclRef = result.as()) - return convert(funcDeclRef); + return tryConvertExprToFunctionReflection( + astBuilder, + program->findDeclFromString(name, &sink)); } catch (...) { @@ -828,12 +857,13 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangRe Slang::DiagnosticSink sink( programLayout->getTargetReq()->getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); - + + auto astBuilder = program->getLinkage()->getASTBuilder(); + try { auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink); - if (auto funcDeclRef = result.as()) - return convert(funcDeclRef); + return tryConvertExprToFunctionReflection(astBuilder, result); } catch (...) { @@ -855,8 +885,11 @@ SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflect try { auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink); - if (auto varDeclRef = result.as()) - return convert(varDeclRef.as()); + if (auto declRefExpr = as(result)) + { + if (auto varDeclRef = declRefExpr->declRef.as()) + return convert(varDeclRef.as()); + } } catch (...) { @@ -3009,21 +3042,23 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage( SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return (SlangReflectionDecl*)func.getDecl(); } SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return getText(func.getDecl()->getName()).getBuffer(); } SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; auto rawType = func.getDecl()->returnType.type; @@ -3034,7 +3069,9 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID) { - auto funcDeclRef = convert(inFunc); + auto funcDeclRef = convertToFunc(inFunc); + if (!funcDeclRef) return nullptr; + auto varRefl = convert(funcDeclRef.as()); if (!varRefl) return nullptr; @@ -3043,35 +3080,38 @@ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflec SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return 0; + return getUserAttributeCount(func.getDecl()); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; return getUserAttributeByIndex(func.getDecl(), index); } SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; + return findUserAttributeByName(asInternal(session), func.getDecl(), name); } SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return 0; + return (unsigned int)func.getDecl()->getParameters().getCount(); } SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index) { - auto func = convert(inFunc); + auto func = convertToFunc(inFunc); if (!func) return nullptr; auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder(); @@ -3081,13 +3121,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) { - auto declRef = convert(func); + auto declRef = convertToFunc(func); + if (!declRef) + return nullptr; + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) { - auto declRef = convert(func); + auto declRef = convertToFunc(func); auto genericDeclRef = convertGenericToDeclRef(generic); if (!declRef || !genericDeclRef) return nullptr; @@ -3103,12 +3146,25 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( SlangInt argTypeCount, SlangReflectionType* const* argTypes) { - auto declRef = convert(func); - if (!declRef) + Linkage* linkage = nullptr; + Expr* funcExpr = nullptr; + + if (auto funcDeclRef = convertToFunc(func)) + { + linkage = getModule(funcDeclRef.getDecl())->getLinkage(); + auto declRefExpr = linkage->getASTBuilder()->create(); + declRefExpr->declRef = funcDeclRef; + funcExpr = declRefExpr; + } + else if (auto overloadedExpr = convertToOverloadedFunc(func)) + { + linkage = getModule(overloadedExpr->lookupResult2.items[0].declRef.getDecl())->getLinkage(); + funcExpr = overloadedExpr; + } + else + { return nullptr; - - - auto linkage = getModule(declRef.getDecl())->getLinkage(); + } List argTypeList; for (SlangInt ii = 0; ii < argTypeCount; ++ii) @@ -3120,7 +3176,7 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( try { DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as()); + return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as()); } catch (...) { @@ -3128,6 +3184,45 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( } } +SLANG_API bool spReflectionFunction_isOverloaded( + SlangReflectionFunction* func) +{ + return (convertToOverloadedFunc(func) != nullptr); +} + +SLANG_API unsigned int spReflectionFunction_getOverloadCount( + SlangReflectionFunction* func) +{ + auto overloadedFunc = convertToOverloadedFunc(func); + if (!overloadedFunc) return 1; + + return (unsigned int) overloadedFunc->lookupResult2.items.getCount(); +} + +SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload( + SlangReflectionFunction* func, + unsigned int index) +{ + auto overloadedFunc = convertToOverloadedFunc(func); + if (!overloadedFunc) return nullptr; + + auto declRef = overloadedFunc->lookupResult2.items[index].declRef; + if (auto funcDeclRef = declRef.as()) + { + return convert(declRef.as()); + } + else if (auto genericDeclRef = declRef.as()) + { + auto astBuilder = getModule(genericDeclRef.getDecl())->getLinkage()->getASTBuilder(); + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + return convert( + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef).as()); + } + + return nullptr; +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) -- cgit v1.2.3