From d866c0b9dfc0fdc8ad8cede4d7a8593f7ddf4716 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 16 Sep 2024 16:04:45 -0400 Subject: Add API method to specialize function reference with argument types (#4966) * Add `FunctionReflection::specializeWithArgTypes()` * Update slang.cpp * Use a shared semantics context on linkage Improve performance on reflection queries * Try to fix linux/mac compile errors --- source/slang/slang-reflection-api.cpp | 69 ++++++++++++++++++++++++++++------- 1 file changed, 55 insertions(+), 14 deletions(-) (limited to 'source/slang/slang-reflection-api.cpp') diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index efa9a20a9..38129babf 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti programLayout->getTargetReq()->getLinkage()->getSourceManager(), Lexer::sourceLocationLexer); + auto astBuilder = program->getLinkage()->getASTBuilder(); try { auto result = program->findDeclFromString(name, &sink); + + if (auto genericDeclRef = result.as()) + { + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + } + if (auto funcDeclRef = result.as()) return convert(funcDeclRef); } @@ -924,7 +933,7 @@ SLANG_API bool spReflection_isSubType( } } -SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) +DeclRef getInnermostGenericParent(DeclRef declRef) { auto decl = declRef.getDecl(); auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder(); @@ -932,15 +941,14 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef declRef) while(parentDecl) { if(parentDecl->parentDecl && as(parentDecl->parentDecl)) - return convertDeclToGeneric( - substituteDeclRef( + return substituteDeclRef( SubstitutionSet(declRef), astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)))); + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))); parentDecl = parentDecl->parentDecl; } - return nullptr; + return DeclRef(); } SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type) @@ -948,11 +956,13 @@ SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangRefl auto slangType = convert(type); if (auto declRefType = as(slangType)) { - return getInnermostGenericParent(declRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(declRefType->getDeclRef())); } else if (auto genericDeclRefType = as(slangType)) { - return getInnermostGenericParent(genericDeclRefType->getDeclRef()); + return convertDeclToGeneric( + getInnermostGenericParent(genericDeclRefType->getDeclRef())); } return nullptr; @@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var) { auto declRef = convert(var); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic) @@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func) { auto declRef = convert(func); - return getInnermostGenericParent(declRef); + return convertDeclToGeneric(getInnermostGenericParent(declRef)); } SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) @@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla return convert(substDeclRef.as()); } +SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( + SlangReflectionFunction* func, + SlangInt argTypeCount, + SlangReflectionType* const* argTypes) +{ + auto declRef = convert(func); + if (!declRef) + return nullptr; + + + auto linkage = getModule(declRef.getDecl())->getLinkage(); + + List argTypeList; + for (SlangInt ii = 0; ii < argTypeCount; ++ii) + { + auto argType = convert(argTypes[ii]); + argTypeList.add(argType); + } + + try + { + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as()); + } + catch (...) + { + return nullptr; + } +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) @@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); - return getInnermostGenericParent( - substituteDeclRef( - SubstitutionSet(declRef), - astBuilder, - createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))); + return convertDeclToGeneric( + getInnermostGenericParent( + substituteDeclRef( + SubstitutionSet(declRef), + astBuilder, + createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))))); } SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam) -- cgit v1.2.3