diff options
| author | Yong He <yonghe@outlook.com> | 2025-08-06 01:07:41 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-06 08:07:41 +0000 |
| commit | 68b0125226464cb3c9e9b7f50bfb53cda97723b4 (patch) | |
| tree | 5f0833c6d9aa759b2769f7f6ac9b3ca6ed9a10f0 | |
| parent | 83675103a1a4fefde11b314aed26f4d37860efe7 (diff) | |
Add reflection api for overload candidate filtering. (#8066)
* Add reflection api for overload candidate filtering.
* Fix API.
* Fix.
* Update build.
* Update test.
* Update formatting.
| -rw-r--r-- | include/slang-deprecated.h | 4 | ||||
| -rw-r--r-- | include/slang.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-language-server.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-linkable.cpp | 59 | ||||
| -rw-r--r-- | source/slang/slang-linkable.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 40 | ||||
| -rw-r--r-- | tools/CMakeLists.txt | 4 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-lookup-resolution.cpp | 91 |
13 files changed, 207 insertions, 42 deletions
diff --git a/include/slang-deprecated.h b/include/slang-deprecated.h index f210e8c48..32db65007 100644 --- a/include/slang-deprecated.h +++ b/include/slang-deprecated.h @@ -904,6 +904,10 @@ extern "C" SlangReflection* reflection, SlangReflectionType* reflType, char const* name); + SLANG_API SlangReflectionFunction* spReflection_TryResolveOverloadedFunction( + SlangReflection* reflection, + uint32_t candidateCount, + SlangReflectionFunction** candidates); SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection); SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex( diff --git a/include/slang.h b/include/slang.h index 65449a1ff..7462644a2 100644 --- a/include/slang.h +++ b/include/slang.h @@ -502,6 +502,12 @@ convention for interface methods. #include <stddef.h> #endif // ! SLANG_NO_STDDEF +#ifdef SLANG_NO_DEPRECATION + #define SLANG_DEPRECATED +#else + #define SLANG_DEPRECATED [[deprecated]] +#endif + #ifdef __cplusplus extern "C" { @@ -3345,6 +3351,16 @@ struct ShaderReflection name); } + SLANG_DEPRECATED FunctionReflection* tryResolveOverloadedFunction( + uint32_t candidateCount, + FunctionReflection** candidates) + { + return (FunctionReflection*)spReflection_TryResolveOverloadedFunction( + (SlangReflection*)this, + candidateCount, + (SlangReflectionFunction**)candidates); + } + VariableReflection* findVarByNameInType(TypeReflection* type, const char* name) { return (VariableReflection*)spReflection_FindVarByNameInType( diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 1100cdee4..32c724e10 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -79,7 +79,7 @@ class OverloadedExpr2 : public Expr Expr* base = nullptr; // The lookup result that was ambiguous - List<Expr*> candidiateExprs; + List<Expr*> candidateExprs; }; FIDDLE(abstract) diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 3cce8df59..c29a42665 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -181,7 +181,7 @@ struct ASTIterator { iterator->maybeDispatchCallback(expr); dispatchIfNotNull(expr->base); - for (auto candidate : expr->candidiateExprs) + for (auto candidate : expr->candidateExprs) { dispatchIfNotNull(candidate); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f70760c5d..d2f338e65 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -354,7 +354,8 @@ DeclRefExpr* SemanticsVisitor::ConstructDeclRefExpr( // This is the bottleneck for using declarations which might be // deprecated, diagnose here. - diagnoseDeprecatedDeclRefUsage(declRef, loc, originalExpr); + if (getSink()) + diagnoseDeprecatedDeclRefUsage(declRef, loc, originalExpr); // Construct an appropriate expression based on the structured of // the declaration reference. @@ -389,7 +390,7 @@ DeclRefExpr* SemanticsVisitor::ConstructDeclRefExpr( auto expr = m_astBuilder->create<StaticMemberExpr>(); expr->loc = loc; expr->type = type; - if (!isDeclUsableAsStaticMember(declRef.getDecl())) + if (getSink() && !isDeclUsableAsStaticMember(declRef.getDecl())) { getSink()->diagnose( loc, @@ -1234,7 +1235,8 @@ Expr* SemanticsVisitor::_resolveOverloadedExprImpl( DiagnosticSink* diagSink) { auto lookupResult = overloadedExpr->lookupResult2; - SLANG_RELEASE_ASSERT(lookupResult.isValid() && lookupResult.isOverloaded()); + if (!lookupResult.isValid() || !lookupResult.isOverloaded()) + return overloadedExpr; // Take the lookup result we had, and refine it based on what is expected in context. // @@ -3827,7 +3829,7 @@ static Expr* _checkHigherOrderInvokeExpr( auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, lookupResultExpr); candidateExpr->loc = expr->loc; - result->candidiateExprs.add(candidateExpr); + result->candidateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); result->loc = expr->loc; @@ -3836,12 +3838,12 @@ static Expr* _checkHigherOrderInvokeExpr( else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction)) { OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); - for (auto item : overloadedExpr2->candidiateExprs) + for (auto item : overloadedExpr2->candidateExprs) { auto candidateExpr = actions->createHigherOrderInvokeExpr(semantics); actions->fillHigherOrderInvokeExpr(candidateExpr, semantics, item); candidateExpr->loc = expr->loc; - result->candidiateExprs.add(candidateExpr); + result->candidateExprs.add(candidateExpr); } result->type.type = astBuilder->getOverloadedType(); result->loc = expr->loc; @@ -5181,7 +5183,7 @@ Expr* SemanticsVisitor::_lookupStaticMember(DeclRefExpr* expr, Expr* baseExpress } else if (auto overloaded2 = as<OverloadedExpr2>(baseExpression)) { - for (auto candidate : overloaded2->candidiateExprs) + for (auto candidate : overloaded2->candidateExprs) { handleLeafExpr(candidate); } diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 72de3fd61..2c17a6380 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -2492,7 +2492,7 @@ void SemanticsVisitor::AddOverloadCandidates(Expr* funcExpr, OverloadResolveCont } else if (auto overloadedExpr2 = as<OverloadedExpr2>(funcExpr)) { - for (auto item : overloadedExpr2->candidiateExprs) + for (auto item : overloadedExpr2->candidateExprs) { AddOverloadCandidates(item, context); } @@ -3142,7 +3142,7 @@ Expr* SemanticsVisitor::checkGenericAppWithCheckedArgs(GenericAppExpr* genericAp for (auto candidate : context.bestCandidates) { auto candidateExpr = CompleteOverloadCandidate(context, candidate); - overloadedExpr->candidiateExprs.add(candidateExpr); + overloadedExpr->candidateExprs.add(candidateExpr); } return overloadedExpr; } diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index aa3040f08..cfb8e4d61 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -311,7 +311,7 @@ public: return true; bool result = false; PushNode pushNode(context, expr); - for (auto candidate : expr->candidiateExprs) + for (auto candidate : expr->candidateExprs) { result |= dispatchIfNotNull(candidate); } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index d65dba71c..aa7f5e495 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -663,7 +663,7 @@ LanguageServerResult<LanguageServerProtocol::Hover> LanguageServerCore::hover( } else if (auto overloadedExpr2 = as<OverloadedExpr2>(node)) { - numOverloads = overloadedExpr2->candidiateExprs.getCount(); + numOverloads = overloadedExpr2->candidateExprs.getCount(); } } if (numOverloads > 1) @@ -872,9 +872,9 @@ LanguageServerResult<LanguageServerProtocol::Hover> LanguageServerCore::hover( } else if (auto overloadedExpr2 = as<OverloadedExpr2>(leafNode)) { - if (overloadedExpr2->candidiateExprs.getCount() > 0) + if (overloadedExpr2->candidateExprs.getCount() > 0) { - auto candidateExpr = overloadedExpr2->candidiateExprs[0]; + auto candidateExpr = overloadedExpr2->candidateExprs[0]; fillExprHoverInfo(candidateExpr); } } @@ -1896,7 +1896,7 @@ LanguageServerResult<LanguageServerProtocol::SignatureHelp> LanguageServerCore:: } else if (auto overloadedExpr2 = as<OverloadedExpr2>(funcExpr)) { - for (auto item : overloadedExpr2->candidiateExprs) + for (auto item : overloadedExpr2->candidateExprs) { addExpr(item); } diff --git a/source/slang/slang-linkable.cpp b/source/slang/slang-linkable.cpp index da4cec823..f7dc28171 100644 --- a/source/slang/slang-linkable.cpp +++ b/source/slang/slang-linkable.cpp @@ -8,6 +8,7 @@ #include "core/slang-memory-file-system.h" #include "slang-check-impl.h" #include "slang-compiler.h" +#include "slang-lookup.h" #include "slang-mangle.h" namespace Slang @@ -772,6 +773,14 @@ Type* ComponentType::getTypeFromString(String const& typeStr, DiagnosticSink* si return type; } +Expr* ComponentType::tryResolveOverloadedExpr(Expr* exprIn) +{ + auto linkage = getLinkage(); + SemanticsContext context(linkage->getSemanticsForReflection()); + SemanticsVisitor visitor(context); + return visitor.maybeResolveOverloadedExpr(exprIn, LookupMask::Function, nullptr); +} + Expr* ComponentType::findDeclFromString(String const& name, DiagnosticSink* sink) { // If we've looked up this type name before, @@ -905,39 +914,39 @@ Expr* ComponentType::findDeclFromStringInType( auto checkedTerm = visitor.CheckTerm(expr); - // Check if checkedTerm is overloaded functions and avoid resolving if so - // to preserve all function overloads with different signatures - Expr* resolvedTerm = checkedTerm; if (auto overloadedExpr = as<OverloadedExpr>(checkedTerm)) { - // Check if all candidates are function references - bool allAreFunctions = true; - for (auto item : overloadedExpr->lookupResult2.items) + // For functions, since we don't know the argument list yet, we will have to defer + // non-parameter-related candidate comparison logic into its separate step. + if (mask != LookupMask::Function) + return visitor.maybeResolveOverloadedExpr(checkedTerm, mask, nullptr); + overloadedExpr->lookupResult2 = refineLookup(overloadedExpr->lookupResult2, mask); + + // Filter out abstract base interface method implementations for reflection. + if (!isInterfaceType(type)) { - if (!as<FunctionDeclBase>(item.declRef.getDecl())) + LookupResult filteredResult; + for (auto candidate : overloadedExpr->lookupResult2) { - allAreFunctions = false; - break; + if (as<InterfaceDecl>(getParentDecl(candidate.declRef.getDecl()))) + { + if (!candidate.declRef.getDecl() + ->hasModifier<HasInterfaceDefaultImplModifier>()) + continue; + } + AddToLookupResult(filteredResult, candidate); } + if (filteredResult.isValid() && !filteredResult.isOverloaded()) + { + // If there are exactly one candidate after filtering, we can + // safely return resolved expr. + return visitor.maybeResolveOverloadedExpr(checkedTerm, mask, nullptr); + } + overloadedExpr->lookupResult2 = filteredResult; } - - // If not all are functions, resolve the overload as usual - if (!allAreFunctions) - { - resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); - } - } - else - { - // Not overloaded, resolve as usual - resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink); - } - - if (auto overloadedExpr = as<OverloadedExpr>(resolvedTerm)) - { return overloadedExpr; } - if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm)) + if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) { return declRefExpr; } diff --git a/source/slang/slang-linkable.h b/source/slang/slang-linkable.h index e900fd275..04a5c3c88 100644 --- a/source/slang/slang-linkable.h +++ b/source/slang/slang-linkable.h @@ -264,6 +264,7 @@ public: Type* getTypeFromString(String const& typeStr, DiagnosticSink* sink); Expr* findDeclFromString(String const& name, DiagnosticSink* sink); + Expr* tryResolveOverloadedExpr(Expr* exprIn); Expr* findDeclFromStringInType( Type* type, diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 56a82e17e..95dcc6249 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -5,6 +5,7 @@ #include "slang-check.h" #include "slang-compiler.h" #include "slang-deprecated.h" +#include "slang-lookup.h" #include "slang-syntax.h" #include "slang-type-layout.h" #include "slang.h" @@ -994,6 +995,45 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType( return nullptr; } + +SLANG_API SlangReflectionFunction* spReflection_TryResolveOverloadedFunction( + SlangReflection* reflection, + uint32_t candidateCount, + SlangReflectionFunction** candidates) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + auto astBuilder = program->getLinkage()->getASTBuilder(); + SLANG_AST_BUILDER_RAII(astBuilder); + OverloadedExpr* overloadedFunc = nullptr; + if (candidateCount == 1) + { + overloadedFunc = convertToOverloadedFunc(candidates[0]); + if (!overloadedFunc) + return candidates[0]; + } + else + { + overloadedFunc = astBuilder->create<OverloadedExpr>(); + overloadedFunc->type = astBuilder->getOrCreate<OverloadGroupType>(); + for (uint32_t i = 0; i < candidateCount; i++) + { + auto func = convertToFunc(candidates[i]); + AddToLookupResult(overloadedFunc->lookupResult2, LookupResultItem(func)); + } + } + + try + { + auto result = program->tryResolveOverloadedExpr(overloadedFunc); + return tryConvertExprToFunctionReflection(astBuilder, result); + } + catch (...) + { + } + return nullptr; +} + SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType( SlangReflection* reflection, SlangReflectionType* reflType, diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 00bd93c49..7b28e960f 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -377,7 +377,9 @@ if(SLANG_ENABLE_TESTS) slang-unit-test MODULE EXCLUDE_FROM_ALL - EXTRA_COMPILE_DEFINITIONS_PRIVATE SLANG_SHARED_LIBRARY_TOOL + EXTRA_COMPILE_DEFINITIONS_PRIVATE + SLANG_SHARED_LIBRARY_TOOL + SLANG_NO_DEPRECATION USE_FEWER_WARNINGS LINK_WITH_PRIVATE core compiler-core unit-test slang Threads::Threads OUTPUT_NAME slang-unit-test-tool diff --git a/tools/slang-unit-test/unit-test-function-lookup-resolution.cpp b/tools/slang-unit-test/unit-test-function-lookup-resolution.cpp new file mode 100644 index 000000000..539c9ac48 --- /dev/null +++ b/tools/slang-unit-test/unit-test-function-lookup-resolution.cpp @@ -0,0 +1,91 @@ +// unit-test-function-lookup-resolution.cpp + +#include "../../source/core/slang-io.h" +#include "../../source/core/slang-process.h" +#include "slang-com-ptr.h" +#include "slang.h" +#include "unit-test/slang-unit-test.h" + +#include <stdio.h> +#include <stdlib.h> + +using namespace Slang; + +static String getTypeFullName(slang::TypeReflection* type) +{ + ComPtr<ISlangBlob> blob; + type->getFullName(blob.writeRef()); + return String((const char*)blob->getBufferPointer()); +} + +// Test that the reflection API provides correctly resolved lookup results. + +SLANG_UNIT_TEST(functionLookupResolution) +{ + // Source for a module that contains an undecorated entrypoint. + const char* userSourceBody = R"( + public interface IBase + { + public void step(inout float f); + public void method(int x) {} + } + + public struct Impl : IBase + { + public void step(inout float f) + { + f += 1.0f; + } + public override void method(int x) {} + } + public extension<T : IBase> T { + public void method(int x) {} + } + )"; + + auto moduleName = "moduleG" + String(Process::getId()); + String userSource = "import " + moduleName + ";\n" + userSourceBody; + ComPtr<slang::IGlobalSession> globalSession; + SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK); + slang::TargetDesc targetDesc = {}; + targetDesc.format = SLANG_HLSL; + targetDesc.profile = globalSession->findProfile("sm_5_0"); + slang::SessionDesc sessionDesc = {}; + sessionDesc.targetCount = 1; + sessionDesc.targets = &targetDesc; + ComPtr<slang::ISession> session; + SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK); + + ComPtr<slang::IBlob> diagnosticBlob; + auto module = session->loadModuleFromSourceString( + "m", + "m.slang", + userSourceBody, + diagnosticBlob.writeRef()); + SLANG_CHECK(module != nullptr); + + auto layout = module->getLayout(); + auto type = layout->findTypeByName("Impl"); + SLANG_CHECK_ABORT(type != nullptr); + + auto func = layout->findFunctionByNameInType(type, "step"); + SLANG_CHECK_ABORT(func && !func->isOverloaded()); + + + auto func1 = layout->findFunctionByNameInType(type, "method"); + SLANG_CHECK_ABORT(func1->isOverloaded()); + SLANG_CHECK(func1->getOverloadCount() == 3); + if (func1->isOverloaded()) + { + List<slang::FunctionReflection*> candidates; + for (uint32_t i = 0; i < func1->getOverloadCount(); i++) + { + candidates.add(func1->getOverload(i)); + } + func1 = layout->tryResolveOverloadedFunction( + (uint32_t)candidates.getCount(), + candidates.getBuffer()); + } + SLANG_CHECK(!func1->isOverloaded()); + SLANG_CHECK(String(func1->getName()) == "method"); +} |
