summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-shader.cpp13
-rwxr-xr-xsource/slang/slang-compiler.h8
-rw-r--r--source/slang/slang-reflection-api.cpp167
-rw-r--r--source/slang/slang.cpp61
4 files changed, 175 insertions, 74 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp
index 1718c3afd..08bac1f78 100644
--- a/source/slang/slang-check-shader.cpp
+++ b/source/slang/slang-check-shader.cpp
@@ -235,8 +235,17 @@ namespace Slang
Name* name,
DiagnosticSink* sink)
{
- auto declRef = translationUnit->findDeclFromString(getText(name), sink);
- FuncDecl* entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();
+ FuncDecl* entryPointFuncDecl = nullptr;
+
+ auto expr = translationUnit->findDeclFromString(getText(name), sink);
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ {
+ auto declRef = declRefExpr->declRef;
+ entryPointFuncDecl = declRef.as<FuncDecl>().getDecl();
+
+ if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
+ entryPointFuncDecl = nullptr;
+ }
if (entryPointFuncDecl && getModule(entryPointFuncDecl) != translationUnit)
entryPointFuncDecl = nullptr;
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 4b20d1f76..820cab03f 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -420,11 +420,11 @@ namespace Slang
String const& typeStr,
DiagnosticSink* sink);
- DeclRef<Decl> findDeclFromString(
+ Expr* findDeclFromString(
String const& name,
DiagnosticSink* sink);
- DeclRef<Decl> findDeclFromStringInType(
+ Expr* findDeclFromStringInType(
Type* type,
String const& name,
LookupMask mask,
@@ -576,7 +576,7 @@ namespace Slang
Dictionary<String, Type*> m_types;
// Any decls looked up dynamically using `findDeclFromString`.
- Dictionary<String, DeclRef<Decl>> m_decls;
+ Dictionary<String, Expr*> m_decls;
Scope* m_lookupScope = nullptr;
std::unique_ptr<Dictionary<String, IntVal*>> m_mapMangledNameToIntVal;
@@ -2174,7 +2174,7 @@ namespace Slang
DiagnosticSink* sink);
DeclRef<Decl> specializeWithArgTypes(
- DeclRef<Decl> funcDeclRef,
+ Expr* funcExpr,
List<Type*> argTypes,
DiagnosticSink* sink);
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 38129babf..b6fc05986 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -66,10 +66,21 @@ static inline SlangReflectionVariable* convert(DeclRef<Decl> var)
return (SlangReflectionVariable*) var.declRefBase;
}
-static inline DeclRef<FunctionDeclBase> convert(SlangReflectionFunction* func)
+static inline DeclRef<FunctionDeclBase> convertToFunc(SlangReflectionFunction* func)
{
- DeclRefBase* declBase = (DeclRefBase*)func;
- return DeclRef<FunctionDeclBase>(declBase);
+ NodeBase* nodeBase = (NodeBase*)func;
+ if (DeclRefBase* declRefBase = as<DeclRefBase>(nodeBase))
+ {
+ return DeclRef<FunctionDeclBase>(declRefBase);
+ }
+
+ return DeclRef<FunctionDeclBase>();
+}
+
+static inline OverloadedExpr* convertToOverloadedFunc(SlangReflectionFunction* func)
+{
+ NodeBase* nodeBase = (NodeBase*)func;
+ return as<OverloadedExpr>(nodeBase);
}
static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
@@ -77,6 +88,11 @@ static inline SlangReflectionFunction* convert(DeclRef<FunctionDeclBase> func)
return (SlangReflectionFunction*)func.declRefBase;
}
+static inline SlangReflectionFunction* convert(OverloadedExpr* overloadedFunc)
+{
+ return (SlangReflectionFunction*)overloadedFunc;
+}
+
static inline DeclRef<Decl> convertGenericToDeclRef(SlangReflectionGeneric* func)
{
DeclRefBase* declBase = (DeclRefBase*)func;
@@ -785,6 +801,27 @@ SLANG_API SlangResult spReflectionType_GetFullName(SlangReflectionType* inType,
return SLANG_OK;
}
+SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuilder, Expr* expr)
+{
+ if (auto declRefExpr = as<DeclRefExpr>(expr))
+ {
+ auto declRef = declRefExpr->declRef;
+ if (auto genericDeclRef = declRef.as<GenericDecl>())
+ {
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
+ }
+
+ if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
+ return convert(funcDeclRef);
+ }
+ else if (auto overloadedExpr = as<OverloadedExpr>(expr))
+ return convert(overloadedExpr);
+
+ return nullptr;
+}
+
SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflection* reflection, char const* name)
{
auto programLayout = convert(reflection);
@@ -800,17 +837,9 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByName(SlangReflecti
auto astBuilder = program->getLinkage()->getASTBuilder();
try
{
- auto result = program->findDeclFromString(name, &sink);
-
- if (auto genericDeclRef = result.as<GenericDecl>())
- {
- auto innerDeclRef = substituteDeclRef(
- SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
- result = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
- }
-
- if (auto funcDeclRef = result.as<FunctionDeclBase>())
- return convert(funcDeclRef);
+ return tryConvertExprToFunctionReflection(
+ astBuilder,
+ program->findDeclFromString(name, &sink));
}
catch (...)
{
@@ -828,12 +857,13 @@ SLANG_API SlangReflectionFunction* spReflection_FindFunctionByNameInType(SlangRe
Slang::DiagnosticSink sink(
programLayout->getTargetReq()->getLinkage()->getSourceManager(),
Lexer::sourceLocationLexer);
-
+
+ auto astBuilder = program->getLinkage()->getASTBuilder();
+
try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Function, &sink);
- if (auto funcDeclRef = result.as<FunctionDeclBase>())
- return convert(funcDeclRef);
+ return tryConvertExprToFunctionReflection(astBuilder, result);
}
catch (...)
{
@@ -855,8 +885,11 @@ SLANG_API SlangReflectionVariable* spReflection_FindVarByNameInType(SlangReflect
try
{
auto result = program->findDeclFromStringInType(type, name, LookupMask::Value, &sink);
- if (auto varDeclRef = result.as<VarDeclBase>())
- return convert(varDeclRef.as<Decl>());
+ if (auto declRefExpr = as<DeclRefExpr>(result))
+ {
+ if (auto varDeclRef = declRefExpr->declRef.as<VarDeclBase>())
+ return convert(varDeclRef.as<Decl>());
+ }
}
catch (...)
{
@@ -3009,21 +3042,23 @@ SLANG_API SlangStage spReflectionVariableLayout_getStage(
SLANG_API SlangReflectionDecl* spReflectionFunction_asDecl(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return (SlangReflectionDecl*)func.getDecl();
}
SLANG_API char const* spReflectionFunction_GetName(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return getText(func.getDecl()->getName()).getBuffer();
}
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
auto rawType = func.getDecl()->returnType.type;
@@ -3034,7 +3069,9 @@ SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectio
SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflectionFunction* inFunc, SlangModifierID modifierID)
{
- auto funcDeclRef = convert(inFunc);
+ auto funcDeclRef = convertToFunc(inFunc);
+ if (!funcDeclRef) return nullptr;
+
auto varRefl = convert(funcDeclRef.as<Decl>());
if (!varRefl) return nullptr;
@@ -3043,35 +3080,38 @@ SLANG_API SlangReflectionModifier* spReflectionFunction_FindModifier(SlangReflec
SLANG_API unsigned int spReflectionFunction_GetUserAttributeCount(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return 0;
+
return getUserAttributeCount(func.getDecl());
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_GetUserAttribute(SlangReflectionFunction* inFunc, unsigned int index)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
return getUserAttributeByIndex(func.getDecl(), index);
}
SLANG_API SlangReflectionUserAttribute* spReflectionFunction_FindUserAttributeByName(SlangReflectionFunction* inFunc, SlangSession* session, char const* name)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
+
return findUserAttributeByName(asInternal(session), func.getDecl(), name);
}
SLANG_API unsigned int spReflectionFunction_GetParameterCount(SlangReflectionFunction* inFunc)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return 0;
+
return (unsigned int)func.getDecl()->getParameters().getCount();
}
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* inFunc, unsigned int index)
{
- auto func = convert(inFunc);
+ auto func = convertToFunc(inFunc);
if (!func) return nullptr;
auto astBuilder = getModule(func.getDecl())->getLinkage()->getASTBuilder();
@@ -3081,13 +3121,16 @@ SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflec
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func)
{
- auto declRef = convert(func);
+ auto declRef = convertToFunc(func);
+ if (!declRef)
+ return nullptr;
+
return convertDeclToGeneric(getInnermostGenericParent(declRef));
}
SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
{
- auto declRef = convert(func);
+ auto declRef = convertToFunc(func);
auto genericDeclRef = convertGenericToDeclRef(generic);
if (!declRef || !genericDeclRef)
return nullptr;
@@ -3103,12 +3146,25 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
SlangInt argTypeCount,
SlangReflectionType* const* argTypes)
{
- auto declRef = convert(func);
- if (!declRef)
+ Linkage* linkage = nullptr;
+ Expr* funcExpr = nullptr;
+
+ if (auto funcDeclRef = convertToFunc(func))
+ {
+ linkage = getModule(funcDeclRef.getDecl())->getLinkage();
+ auto declRefExpr = linkage->getASTBuilder()->create<DeclRefExpr>();
+ declRefExpr->declRef = funcDeclRef;
+ funcExpr = declRefExpr;
+ }
+ else if (auto overloadedExpr = convertToOverloadedFunc(func))
+ {
+ linkage = getModule(overloadedExpr->lookupResult2.items[0].declRef.getDecl())->getLinkage();
+ funcExpr = overloadedExpr;
+ }
+ else
+ {
return nullptr;
-
-
- auto linkage = getModule(declRef.getDecl())->getLinkage();
+ }
List<Type*> argTypeList;
for (SlangInt ii = 0; ii < argTypeCount; ++ii)
@@ -3120,7 +3176,7 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
- return convert(linkage->specializeWithArgTypes(declRef, argTypeList, &sink).as<FunctionDeclBase>());
+ return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>());
}
catch (...)
{
@@ -3128,6 +3184,45 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
}
}
+SLANG_API bool spReflectionFunction_isOverloaded(
+ SlangReflectionFunction* func)
+{
+ return (convertToOverloadedFunc(func) != nullptr);
+}
+
+SLANG_API unsigned int spReflectionFunction_getOverloadCount(
+ SlangReflectionFunction* func)
+{
+ auto overloadedFunc = convertToOverloadedFunc(func);
+ if (!overloadedFunc) return 1;
+
+ return (unsigned int) overloadedFunc->lookupResult2.items.getCount();
+}
+
+SLANG_API SlangReflectionFunction* spReflectionFunction_getOverload(
+ SlangReflectionFunction* func,
+ unsigned int index)
+{
+ auto overloadedFunc = convertToOverloadedFunc(func);
+ if (!overloadedFunc) return nullptr;
+
+ auto declRef = overloadedFunc->lookupResult2.items[index].declRef;
+ if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
+ {
+ return convert(declRef.as<FunctionDeclBase>());
+ }
+ else if (auto genericDeclRef = declRef.as<GenericDecl>())
+ {
+ auto astBuilder = getModule(genericDeclRef.getDecl())->getLinkage()->getASTBuilder();
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ return convert(
+ createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef).as<FunctionDeclBase>());
+ }
+
+ return nullptr;
+}
+
// Abstract decl reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 6c152cddd..dc5f9a755 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1370,7 +1370,7 @@ DeclRef<GenericDecl> getGenericParentDeclRef(
}
DeclRef<Decl> Linkage::specializeWithArgTypes(
- DeclRef<Decl> funcDeclRef,
+ Expr* funcExpr,
List<Type*> argTypes,
DiagnosticSink* sink)
{
@@ -1378,6 +1378,16 @@ DeclRef<Decl> Linkage::specializeWithArgTypes(
visitor = visitor.withSink(sink);
ASTBuilder* astBuilder = getASTBuilder();
+
+ if (auto declRefFuncExpr = as<DeclRefExpr>(funcExpr))
+ {
+ auto genericDeclRefExpr = astBuilder->create<DeclRefExpr>();
+ genericDeclRefExpr->declRef = getGenericParentDeclRef(
+ getASTBuilder(),
+ &visitor,
+ declRefFuncExpr->declRef);
+ funcExpr = genericDeclRefExpr;
+ }
List<Expr*> argExprs;
for (SlangInt aa = 0; aa < argTypes.getCount(); ++aa)
@@ -1394,10 +1404,7 @@ DeclRef<Decl> Linkage::specializeWithArgTypes(
// Construct invoke expr.
auto invokeExpr = astBuilder->create<InvokeExpr>();
- auto declRefExpr = astBuilder->create<DeclRefExpr>();
-
- declRefExpr->declRef = getGenericParentDeclRef(getASTBuilder(), &visitor, funcDeclRef);
- invokeExpr->functionExpr = declRefExpr;
+ invokeExpr->functionExpr = funcExpr;
invokeExpr->arguments = argExprs;
auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr);
@@ -2331,14 +2338,14 @@ Type* ComponentType::getTypeFromString(
return type;
}
-DeclRef<Decl> ComponentType::findDeclFromString(
+Expr* ComponentType::findDeclFromString(
String const& name,
DiagnosticSink* sink)
{
// If we've looked up this type name before,
// then we can re-use it.
//
- DeclRef<Decl> result;
+ Expr* result = nullptr;
if (m_decls.tryGetValue(name, result))
return result;
@@ -2369,34 +2376,26 @@ DeclRef<Decl> ComponentType::findDeclFromString(
SemanticsVisitor visitor(context);
- auto checkedExpr = visitor.CheckExpr(expr);
- if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
- {
- result = declRefExpr->declRef;
- }
- else if (auto overloadedExpr = as<OverloadedExpr>(checkedExpr))
+ auto checkedExpr = visitor.CheckTerm(expr);
+
+ if (as<DeclRefExpr>(checkedExpr) || as<OverloadedExpr>(checkedExpr))
{
- sink->diagnose(SourceLoc(), Diagnostics::ambiguousReference, name);
- for (auto candidate : overloadedExpr->lookupResult2)
- {
- sink->diagnose(candidate.declRef.getDecl(), Diagnostics::overloadCandidate, candidate.declRef);
- }
+ result = checkedExpr;
}
+
m_decls[name] = result;
return result;
}
-DeclRef<Decl> ComponentType::findDeclFromStringInType(
+Expr* ComponentType::findDeclFromStringInType(
Type* type,
String const& name,
LookupMask mask,
DiagnosticSink* sink)
{
- DeclRef<Decl> result;
-
// Only look up in the type if it is a DeclRefType
if (!as<DeclRefType>(type))
- return DeclRef<Decl>();
+ return nullptr;
// TODO(JS): For now just used the linkages ASTBuilder to keep on scope
//
@@ -2433,7 +2432,7 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(
}
if (!as<VarExpr>(expr))
- return result;
+ return nullptr;
auto rs = astBuilder->create<StaticMemberExpr>();
auto typeExpr = astBuilder->create<SharedTypeExpr>();
@@ -2453,20 +2452,18 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(
auto checkedTerm = visitor.CheckTerm(expr);
auto resolvedTerm = visitor.maybeResolveOverloadedExpr(checkedTerm, mask, sink);
+
- if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm))
+ if (auto overloadedExpr = as<OverloadedExpr>(resolvedTerm))
{
- result = declRefExpr->declRef;
+ return overloadedExpr;
}
-
- if (auto genericDeclRef = result.as<GenericDecl>())
- {
- result = createDefaultSubstitutionsIfNeeded(
- astBuilder, &visitor, DeclRef(genericDeclRef.getDecl()->inner));
- result = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, result);
+ if (auto declRefExpr = as<DeclRefExpr>(resolvedTerm))
+ {
+ return declRefExpr;
}
- return result;
+ return nullptr;
}
bool ComponentType::isSubType(Type* subType, Type* superType)