summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-16 16:04:45 -0400
committerGitHub <noreply@github.com>2024-09-16 16:04:45 -0400
commitd866c0b9dfc0fdc8ad8cede4d7a8593f7ddf4716 (patch)
tree77cd8713987e575aaf8c7436cd9d2fda8ddc9e63 /source
parentc46ca4cfeff2c78078aa3c4014cd6b0341ee01fc (diff)
Add API method to specialize function reference with argument types (#4966)
* Add `FunctionReflection::specializeWithArgTypes()` * Update slang.cpp * Use a shared semantics context on linkage Improve performance on reflection queries * Try to fix linux/mac compile errors
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-impl.h2
-rwxr-xr-xsource/slang/slang-compiler.h14
-rw-r--r--source/slang/slang-reflection-api.cpp69
-rw-r--r--source/slang/slang.cpp100
4 files changed, 136 insertions, 49 deletions
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index dc4568f8a..ad3539a21 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -598,7 +598,7 @@ namespace Slang
};
/// Shared state for a semantics-checking session.
- struct SharedSemanticsContext
+ struct SharedSemanticsContext : public RefObject
{
Linkage* m_linkage = nullptr;
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 62e4c5f4a..0c788ae18 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -40,7 +40,9 @@
namespace Slang
{
struct PathInfo;
- struct IncludeHandler;
+ struct IncludeHandler;
+ struct SharedSemanticsContext;
+
class ProgramLayout;
class PtrType;
class TargetProgram;
@@ -2170,6 +2172,11 @@ namespace Slang
DeclRef<Decl> declRef,
List<Expr*> argExprs,
DiagnosticSink* sink);
+
+ DeclRef<Decl> specializeWithArgTypes(
+ DeclRef<Decl> funcDeclRef,
+ List<Type*> argTypes,
+ DiagnosticSink* sink);
DiagnosticSink::Flags diagnosticSinkFlags = 0;
@@ -2183,6 +2190,9 @@ namespace Slang
m_retainedSession = nullptr;
}
+ // Get shared semantics information for reflection purposes.
+ SharedSemanticsContext* getSemanticsForReflection();
+
private:
/// The global Slang library session that this linkage is a child of
Session* m_session = nullptr;
@@ -2236,6 +2246,8 @@ namespace Slang
List<Type*> m_specializedTypes;
+ RefPtr<SharedSemanticsContext> m_semanticsForReflection;
+
};
/// Shared functionality between front- and back-end compile requests.
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index efa9a20a9..38129babf 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -797,9 +797,18 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);
+ auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
auto result = program->findDeclFromString(name, &sink);
+
+ if (auto genericDeclRef = result.as<GenericDecl>())
+ {
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
+ }
+
if (auto funcDeclRef = result.as<FunctionDeclBase>())
return convert(funcDeclRef);
}
@@ -924,7 +933,7 @@ SLANG_API bool spReflection_isSubType(
}
}
-SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
+DeclRef<Decl> getInnermostGenericParent(DeclRef<Decl> declRef)
{
auto decl = declRef.getDecl();
auto astBuilder = getModule(decl)->getLinkage()->getASTBuilder();
@@ -932,15 +941,14 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
while(parentDecl)
{
if(parentDecl->parentDecl && as<GenericDecl>(parentDecl->parentDecl))
- return convertDeclToGeneric(
- substituteDeclRef(
+ return substituteDeclRef(
SubstitutionSet(declRef),
astBuilder,
- createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl))));
+ createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(parentDecl)));
parentDecl = parentDecl->parentDecl;
}
- return nullptr;
+ return DeclRef<Decl>();
}
SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type)
@@ -948,11 +956,13 @@ SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangRefl
auto slangType = convert(type);
if (auto declRefType = as<DeclRefType>(slangType))
{
- return getInnermostGenericParent(declRefType->getDeclRef());
+ return convertDeclToGeneric(
+ getInnermostGenericParent(declRefType->getDeclRef()));
}
else if (auto genericDeclRefType = as<GenericDeclRefType>(slangType))
{
- return getInnermostGenericParent(genericDeclRefType->getDeclRef());
+ return convertDeclToGeneric(
+ getInnermostGenericParent(genericDeclRefType->getDeclRef()));
}
return nullptr;
@@ -2835,7 +2845,7 @@ SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inV
SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var)
{
auto declRef = convert(var);
- return getInnermostGenericParent(declRef);
+ return convertDeclToGeneric(getInnermostGenericParent(declRef));
}
SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic)
@@ -3072,7 +3082,7 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
auto declRef = convert(func);
- return getInnermostGenericParent(declRef);
+ return convertDeclToGeneric(getInnermostGenericParent(declRef));
}
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
@@ -3088,6 +3098,36 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(Sla
return convert(substDeclRef.as<FunctionDeclBase>());
}
+SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
+ SlangReflectionFunction* func,
+ SlangInt argTypeCount,
+ SlangReflectionType* const* argTypes)
+{
+ auto declRef = convert(func);
+ if (!declRef)
+ return nullptr;
+
+
+ auto linkage = getModule(declRef.getDecl())->getLinkage();
+
+ List<Type*> argTypeList;
+ for (SlangInt ii = 0; ii < argTypeCount; ++ii)
+ {
+ auto argType = convert(argTypes[ii]);
+ argTypeList.add(argType);
+ }
+
+ try
+ {
+ DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
+ return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
+ }
+ catch (...)
+ {
+ return nullptr;
+ }
+}
+
// Abstract decl reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
@@ -3329,11 +3369,12 @@ SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(S
auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();
- return getInnermostGenericParent(
- substituteDeclRef(
- SubstitutionSet(declRef),
- astBuilder,
- createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl))));
+ return convertDeclToGeneric(
+ getInnermostGenericParent(
+ substituteDeclRef(
+ SubstitutionSet(declRef),
+ astBuilder,
+ createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, DeclRef(declRef.getDecl()->parentDecl)))));
}
SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam)
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index c78348a86..6c152cddd 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -28,7 +28,6 @@
#include "slang-type-layout.h"
#include "slang-lookup.h"
-#
#include "slang-options.h"
#include "slang-repro.h"
@@ -1069,8 +1068,12 @@ Linkage::Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinka
for (const auto& nameToMod : builtinLinkage->mapNameToLoadedModules)
mapNameToLoadedModules.add(nameToMod);
}
+
+ m_semanticsForReflection = new SharedSemanticsContext(this, nullptr, nullptr);
}
+SharedSemanticsContext* Linkage::getSemanticsForReflection() { return m_semanticsForReflection.get(); }
+
ISlangUnknown* Linkage::getInterface(const Guid& guid)
{
if(guid == ISlangUnknown::getTypeGuid() || guid == ISession::getTypeGuid())
@@ -1348,18 +1351,11 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
return asExternal(specializedType);
}
-
-DeclRef<Decl> Linkage::specializeGeneric(
- DeclRef<Decl> declRef,
- List<Expr*> argExprs,
- DiagnosticSink* sink)
+DeclRef<GenericDecl> getGenericParentDeclRef(
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* visitor,
+ DeclRef<Decl> declRef)
{
- SLANG_AST_BUILDER_RAII(getASTBuilder());
- SLANG_ASSERT(declRef);
-
- SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink);
- SemanticsVisitor visitor(&sharedSemanticsContext);
-
// Create substituted parent decl ref.
auto decl = declRef.getDecl();
@@ -1369,9 +1365,58 @@ DeclRef<Decl> Linkage::specializeGeneric(
}
auto genericDecl = as<GenericDecl>(decl);
- auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as<GenericDecl>();
- genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as<GenericDecl>();
+ auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as<GenericDecl>();
+ return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as<GenericDecl>();
+}
+
+DeclRef<Decl> Linkage::specializeWithArgTypes(
+ DeclRef<Decl> funcDeclRef,
+ List<Type*> argTypes,
+ DiagnosticSink* sink)
+{
+ SemanticsVisitor visitor(getSemanticsForReflection());
+ visitor = visitor.withSink(sink);
+
+ ASTBuilder* astBuilder = getASTBuilder();
+ List<Expr*> argExprs;
+ for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa)
+ {
+ auto argType = argTypes[aa];
+
+ // Create an 'empty' expr with the given type. Ideally, the expression itself should not matter
+ // only its checked type.
+ //
+ auto argExpr = astBuilder->create<VarExpr>();
+ argExpr->type = argType;
+ argExprs.add(argExpr);
+ }
+
+ // Construct invoke expr.
+ auto invokeExpr = astBuilder->create<InvokeExpr>();
+ auto declRefExpr = astBuilder->create<DeclRefExpr>();
+
+ declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef);
+ invokeExpr->functionExpr = declRefExpr;
+ invokeExpr->arguments = argExprs;
+
+ auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr);
+ return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef;
+}
+
+
+DeclRef<Decl> Linkage::specializeGeneric(
+ DeclRef<Decl> declRef,
+ List<Expr*> argExprs,
+ DiagnosticSink* sink)
+{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+ SLANG_ASSERT(declRef);
+
+ SemanticsVisitor visitor(getSemanticsForReflection());
+ visitor = visitor.withSink(sink);
+
+ auto genericDeclRef = getGenericParentDeclRef(getASTBuilder(), &visitor, declRef);
DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>();
declRefExpr->declRef = genericDeclRef;
@@ -1561,8 +1606,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL Linkage::createTypeConformanceComponentTy
try
{
- SharedSemanticsContext sharedSemanticsContext(this, nullptr, &sink);
- SemanticsVisitor visitor(&sharedSemanticsContext);
+ SemanticsVisitor visitor(getSemanticsForReflection());
+ visitor = visitor.withSink(&sink);
+
auto witness =
visitor.isSubtype((Slang::Type*)type, (Slang::Type*)interfaceType, IsSubTypeOptions::None);
if (auto subtypeWitness = as<SubtypeWitness>(witness))
@@ -2318,12 +2364,8 @@ DeclRef<Decl> ComponentType::findDeclFromString(
Expr* expr = linkage->parseTermString(name, scope);
- SharedSemanticsContext sharedSemanticsContext(
- linkage,
- nullptr,
- sink);
- SemanticsContext context(&sharedSemanticsContext);
- context = context.allowStaticReferenceToNonStaticMember();
+ SemanticsContext context(linkage->getSemanticsForReflection());
+ context = context.allowStaticReferenceToNonStaticMember().withSink(sink);
SemanticsVisitor visitor(context);
@@ -2377,12 +2419,8 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(
Expr* expr = linkage->parseTermString(name, scope);
- SharedSemanticsContext sharedSemanticsContext(
- linkage,
- nullptr,
- sink);
- SemanticsContext context(&sharedSemanticsContext);
- context = context.allowStaticReferenceToNonStaticMember();
+ SemanticsContext context(linkage->getSemanticsForReflection());
+ context = context.allowStaticReferenceToNonStaticMember().withSink(sink);
SemanticsVisitor visitor(context);
@@ -2433,11 +2471,7 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(
bool ComponentType::isSubType(Type* subType, Type* superType)
{
- SharedSemanticsContext sharedSemanticsContext(
- getLinkage(),
- nullptr,
- nullptr);
- SemanticsContext context(&sharedSemanticsContext);
+ SemanticsContext context(getLinkage()->getSemanticsForReflection());
SemanticsVisitor visitor(context);
return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr);