summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-07-10 14:09:18 -0700
committerGitHub <noreply@github.com>2024-07-10 14:09:18 -0700
commitb89421cb3b803165455020f5b70d582b6aec6e76 (patch)
tree461327ffbb55e7cea0ca73ae11cfa18425c904a8 /source
parent45ef0ce906c93c16495755fec2e597573e8631c4 (diff)
Add reflection API for functions. (#4587)
* Add reflection API for functions. This change adds `SlangFunctionReflection` type in the reflection API that provides methods for querying function result type, parameters and user-defined attributes. `ProgramLayout::findFunctionByName` can now find a function with the given name and returns a `FunctionReflection`. `IEntryPoint` now has a `getFunctionReflection` method that returns an `FunctionReflection` for the entrypoint. * More modifiers; make reflection API consistent.
Diffstat (limited to 'source')
-rw-r--r--source/slang-capture-replay/slang-entrypoint.cpp6
-rw-r--r--source/slang-capture-replay/slang-entrypoint.h2
-rwxr-xr-xsource/slang/slang-compiler.h11
-rw-r--r--source/slang/slang-reflection-api.cpp127
-rw-r--r--source/slang/slang.cpp50
5 files changed, 192 insertions, 4 deletions
diff --git a/source/slang-capture-replay/slang-entrypoint.cpp b/source/slang-capture-replay/slang-entrypoint.cpp
index c8ce7a663..46d701f98 100644
--- a/source/slang-capture-replay/slang-entrypoint.cpp
+++ b/source/slang-capture-replay/slang-entrypoint.cpp
@@ -304,4 +304,10 @@ namespace SlangCapture
return res;
}
+
+ SLANG_NO_THROW slang::FunctionReflection* EntryPointCapture::getFunctionReflection()
+ {
+ return m_actualEntryPoint->getFunctionReflection();
+ }
+
}
diff --git a/source/slang-capture-replay/slang-entrypoint.h b/source/slang-capture-replay/slang-entrypoint.h
index 87b3cc430..06a577069 100644
--- a/source/slang-capture-replay/slang-entrypoint.h
+++ b/source/slang-capture-replay/slang-entrypoint.h
@@ -66,7 +66,7 @@ namespace SlangCapture
uint32_t compilerOptionEntryCount,
slang::CompilerOptionEntry* compilerOptionEntries,
ISlangBlob** outDiagnostics = nullptr) override;
-
+ virtual SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() override;
slang::IEntryPoint* getActualEntryPoint() const { return m_actualEntryPoint; }
private:
Slang::ComPtr<slang::IEntryPoint> m_actualEntryPoint;
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 46724119b..e9660f7ed 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -416,6 +416,10 @@ namespace Slang
String const& typeStr,
DiagnosticSink* sink);
+ DeclRef<Decl> findDeclFromString(
+ String const& name,
+ DiagnosticSink* sink);
+
Dictionary<String, IntVal*>& getMangledNameToIntValMap();
ConstantIntVal* tryFoldIntVal(IntVal* intVal);
@@ -560,6 +564,9 @@ namespace Slang
//
Dictionary<String, Type*> m_types;
+ // Any decls looked up dynamically using `findDeclFromString`.
+ Dictionary<String, DeclRef<Decl>> m_decls;
+
Scope* m_lookupScope = nullptr;
std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal;
};
@@ -1042,6 +1049,10 @@ namespace Slang
List<ExpandedSpecializationArg> existentialSpecializationArgs;
};
+ SLANG_NO_THROW slang::FunctionReflection* SLANG_MCALL getFunctionReflection() SLANG_OVERRIDE
+ {
+ return (slang::FunctionReflection*)m_funcDeclRef.getDecl();
+ }
protected:
void acceptVisitor(ComponentTypeVisitor* visitor, SpecializationInfo* specializationInfo) SLANG_OVERRIDE;
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index e7f4b9bf3..0642caeb6 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -65,6 +65,16 @@ static inline SlangReflectionVariable* convert(Decl* var)
return (SlangReflectionVariable*) var;
}
+static inline FunctionDeclBase* convert(SlangReflectionFunction* func)
+{
+ return (FunctionDeclBase*)func;
+}
+
+static inline SlangReflectionFunction* convert(FunctionDeclBase* func)
+{
+ return (SlangReflectionFunction*)func;
+}
+
static inline VarLayout* convert(SlangReflectionVariableLayout* var)
{
return (VarLayout*) var;
@@ -729,13 +739,48 @@ SLANG_API char const* spReflectionType_GetName(SlangReflectionType* inType)
auto decl = declRef.getDecl();
if(decl->hasModifier<ImplicitParameterGroupElementTypeModifier>())
return nullptr;
-
return getText(declRef.getName()).begin();
}
return nullptr;
}
+SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType, ISlangBlob** outNameBlob)
+{
+ auto type = convert(inType);
+
+ if (!type) return SLANG_FAIL;
+
+ StringBuilder sb;
+ type->toText(sb);
+ *outNameBlob = StringUtil::createStringBlob(sb.produceString()).detach();
+ return SLANG_OK;
+}
+
+SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name)
+{
+ auto programLayout = convert(reflection);
+ auto program = programLayout->getProgram();
+
+ // TODO: We should extend this API to support getting error messages
+ // when type lookup fails.
+ //
+ Slang::DiagnosticSink sink(
+ programLayout->getTargetReq()->getLinkage()->getSourceManager(),
+ Lexer::sourceLocationLexer);
+
+ try
+ {
+ auto result = program->findDeclFromString(name, &sink);
+ if (auto funcDeclRef = result.as<FunctionDeclBase>())
+ return (SlangReflectionFunction*)funcDeclRef.getDecl();
+ }
+ catch (...)
+ {
+ }
+ return nullptr;
+}
+
SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * reflection, char const * name)
{
auto programLayout = convert(reflection);
@@ -2558,7 +2603,24 @@ SLANG_API SlangReflectionModifier* spReflectionVariable_FindModifier(SlangReflec
case SLANG_MODIFIER_SHARED:
modifier = var->findModifier<HLSLEffectSharedModifier>();
break;
-
+ case SLANG_MODIFIER_CONST:
+ modifier = var->findModifier<ConstModifier>();
+ break;
+ case SLANG_MODIFIER_NO_DIFF:
+ modifier = var->findModifier<NoDiffModifier>();
+ break;
+ case SLANG_MODIFIER_STATIC:
+ modifier = var->findModifier<HLSLStaticModifier>();
+ break;
+ case SLANG_MODIFIER_EXPORT:
+ modifier = var->findModifier<HLSLExportModifier>();
+ break;
+ case SLANG_MODIFIER_EXTERN:
+ modifier = var->findModifier<ExternModifier>();
+ break;
+ case SLANG_MODIFIER_DIFFERENTIABLE:
+ modifier = var->findModifier<DifferentiableAttribute>();
+ break;
default:
return nullptr;
}
@@ -2729,6 +2791,57 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage(
return (SlangStage) varLayout->stage;
}
+// Function Reflection
+
+SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+ return getText(func->getName()).getBuffer();
+}
+
+SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+
+ return convert(func->returnType.type);
+}
+
+SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc)
+{
+ auto func = convert(inFunc);
+ if (!func) return 0;
+ return getUserAttributeCount(func);
+}
+
+SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+ return getUserAttributeByIndex(func, index);
+}
+
+SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+ return findUserAttributeByName(asInternal(session), func, name);
+}
+
+SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc)
+{
+ auto func = convert(inFunc);
+ if (!func) return 0;
+ return (unsigned int)func->getParameters().getCount();
+}
+
+SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index)
+{
+ auto func = convert(inFunc);
+ if (!func) return nullptr;
+ return convert(as<Decl>(func->getParameters()[index]));
+}
// Shader Parameter Reflection
@@ -2788,6 +2901,16 @@ SLANG_API char const* spReflectionEntryPoint_getNameOverride(SlangReflectionEntr
return nullptr;
}
+SLANG_API SlangReflectionFunction* spReflectionEntryPoint_GetFunction(SlangReflectionEntryPoint* inEntryPoint)
+{
+ auto entryPointLayout = convert(inEntryPoint);
+ if (entryPointLayout)
+ {
+ return (SlangReflectionFunction*)entryPointLayout->entryPoint.getDecl();
+ }
+ return nullptr;
+}
+
SLANG_API unsigned spReflectionEntryPoint_getParameterCount(
SlangReflectionEntryPoint* inEntryPoint)
{
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index adfa031f5..d601dcd0e 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -2229,6 +2229,55 @@ Type* ComponentType::getTypeFromString(
return type;
}
+DeclRef<Decl> ComponentType::findDeclFromString(
+ String const& name,
+ DiagnosticSink* sink)
+{
+ // If we've looked up this type name before,
+ // then we can re-use it.
+ //
+ DeclRef<Decl> result;
+ if (m_decls.tryGetValue(name, result))
+ return result;
+
+
+ // TODO(JS): For now just used the linkages ASTBuilder to keep on scope
+ //
+ // The parseTermString uses the linkage ASTBuilder for it's parsing.
+ //
+ // It might be possible to just create a temporary ASTBuilder - the worry though is
+ // that the parsing sets a member variable in AST node to one of these scopes, and then
+ // it become a dangling pointer. So for now we go with the linkages.
+ auto astBuilder = getLinkage()->getASTBuilder();
+
+ // Otherwise, we need to start looking in
+ // the modules that were directly or
+ // indirectly referenced.
+ //
+ Scope* scope = _getOrCreateScopeForLegacyLookup(astBuilder);
+
+ auto linkage = getLinkage();
+
+ SLANG_AST_BUILDER_RAII(linkage->getASTBuilder());
+
+ Expr* expr = linkage->parseTermString(name, scope);
+
+ SharedSemanticsContext sharedSemanticsContext(
+ linkage,
+ nullptr,
+ sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+
+ auto checkedExpr = visitor.CheckExpr(expr);
+ if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
+ {
+ result = declRefExpr->declRef;
+ }
+
+ m_decls[name] = result;
+ return result;
+}
+
static void collectExportedConstantInContainer(
Dictionary<String, IntVal*>& dict,
ASTBuilder* builder,
@@ -4075,7 +4124,6 @@ RefPtr<EntryPoint> Module::findEntryPointByName(UnownedStringSlice const& name)
return nullptr;
}
-
RefPtr<EntryPoint> Module::findAndCheckEntryPoint(
UnownedStringSlice const& name,
SlangStage stage,