summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-08-27 16:47:05 -0400
committerGitHub <noreply@github.com>2024-08-27 16:47:05 -0400
commit4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch)
treef266e3c7c3a646473ac4af80ddbcd72702ced917
parentd40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff)
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features. + Lookup methods and members (by string) on types + Fix issue with looking up non-static members through the scope operator '::' + `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints + `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info + `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object + `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes) + `DeclReflection::getParent`: go to parent declaration. + Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more) * Fix Falcor issue * Initial namespace reflection support * FIx issue with specializing witness tables * Add API method for specializing parameters of a generic decl * Add ability to specialize generic references to functions, types and more This PR adds the following end-points: - `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution. - `applySpecializations()` method to then copy these specializations onto an existing type or function reflection. - `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`) This PR also: - Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl) - Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types). * Update slang.h * Fix up naming * Update slang-compiler.h * Update slang-reflection-api.cpp * Update slang.cpp * Update slang.cpp * Update slang.cpp * Use `checkGenericAppWithCheckedArgs` to do specialization * Update slang-reflection-api.cpp * Update slang-check-decl.cpp
-rw-r--r--include/slang.h102
-rwxr-xr-xsource/slang/slang-compiler.h17
-rw-r--r--source/slang/slang-reflection-api.cpp188
-rw-r--r--source/slang/slang-syntax.h7
-rw-r--r--source/slang/slang.cpp58
-rw-r--r--tools/slang-unit-test/unit-test-decl-tree-reflection.cpp126
6 files changed, 486 insertions, 12 deletions
diff --git a/include/slang.h b/include/slang.h
index 8f4d53a0d..7c417e74e 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -2110,6 +2110,20 @@ extern "C"
typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute;
typedef struct SlangReflectionFunction SlangReflectionFunction;
typedef struct SlangReflectionGeneric SlangReflectionGeneric;
+
+ union SlangReflectionGenericArg
+ {
+ SlangReflectionType* typeVal;
+ int64_t intVal;
+ bool boolVal;
+ };
+
+ enum SlangReflectionGenericArgType
+ {
+ SLANG_GENERIC_ARG_TYPE = 0,
+ SLANG_GENERIC_ARG_INT = 1,
+ SLANG_GENERIC_ARG_BOOL = 2
+ };
/*
Type aliases to maintain backward compatibility.
@@ -2179,7 +2193,8 @@ extern "C"
SLANG_DECL_KIND_FUNC,
SLANG_DECL_KIND_MODULE,
SLANG_DECL_KIND_GENERIC,
- SLANG_DECL_KIND_VARIABLE
+ SLANG_DECL_KIND_VARIABLE,
+ SLANG_DECL_KIND_NAMESPACE
};
#ifndef SLANG_RESOURCE_SHAPE
@@ -2428,6 +2443,7 @@ extern "C"
SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type);
SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index);
SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name);
+ SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* type, SlangReflectionGeneric* generic);
SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type);
SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index);
@@ -2544,6 +2560,7 @@ extern "C"
SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name);
SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar);
SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var);
+ SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic);
// Variable Layout Reflection
@@ -2570,11 +2587,13 @@ extern "C"
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index);
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func);
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func);
+ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic);
// Abstract Decl Reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl);
SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index);
+ SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl);
SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl);
SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl);
SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl);
@@ -2597,6 +2616,7 @@ extern "C"
SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic);
SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam);
SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam);
+ SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic);
/** Get the stage that a variable belongs to (if any).
@@ -2698,12 +2718,25 @@ extern "C"
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
- SLANG_API SlangReflectionType* spReflection_specializeType(
+ SLANG_API SlangReflectionType* spReflection_specializeType(
SlangReflection* reflection,
SlangReflectionType* type,
SlangInt specializationArgCount,
SlangReflectionType* const* specializationArgs,
ISlangBlob** outDiagnostics);
+
+ SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric(
+ SlangReflection* inProgramLayout,
+ SlangReflectionGeneric* generic,
+ SlangInt argCount,
+ SlangReflectionGenericArgType const* argTypes,
+ SlangReflectionGenericArg const* args,
+ ISlangBlob** outDiagnostics);
+
+ SLANG_API bool spReflection_isSubType(
+ SlangReflection * reflection,
+ SlangReflectionType* subType,
+ SlangReflectionType* superType);
/// Get the number of hashed strings
SLANG_API SlangUInt spReflection_getHashedStringCount(
@@ -2750,6 +2783,13 @@ namespace slang
struct FunctionReflection;
struct GenericReflection;
+ union GenericArgReflection
+ {
+ TypeReflection* typeVal;
+ int64_t intVal;
+ bool boolVal;
+ };
+
struct UserAttribute
{
char const* getName()
@@ -2919,18 +2959,25 @@ namespace slang
{
return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this);
}
+
UserAttribute* getUserAttributeByIndex(unsigned int index)
{
return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index);
}
+
UserAttribute* findUserAttributeByName(char const* name)
{
return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name);
}
- SlangReflectionGeneric* getGenericContainer()
+ TypeReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (TypeReflection*)spReflectionType_applySpecializations((SlangReflectionType*)this, (SlangReflectionGeneric*)generic);
+ }
+
+ GenericReflection* getGenericContainer()
{
- return (SlangReflectionGeneric*) spReflectionType_GetGenericContainer((SlangReflectionType*) this);
+ return (GenericReflection*) spReflectionType_GetGenericContainer((SlangReflectionType*) this);
}
};
@@ -3399,6 +3446,11 @@ namespace slang
{
return (GenericReflection*)spReflectionVariable_GetGenericContainer((SlangReflectionVariable*)this);
}
+
+ VariableReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (VariableReflection*)spReflectionVariable_applySpecializations((SlangReflectionVariable*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct VariableLayoutReflection
@@ -3529,6 +3581,11 @@ namespace slang
{
return (GenericReflection*)spReflectionFunction_GetGenericContainer((SlangReflectionFunction*)this);
}
+
+ FunctionReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct GenericReflection
@@ -3599,6 +3656,10 @@ namespace slang
return spReflectionGeneric_GetConcreteIntVal((SlangReflectionGeneric*)this, (SlangReflectionVariable*)valueParam);
}
+ GenericReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (GenericReflection*)spReflectionGeneric_applySpecializations((SlangReflectionGeneric*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct EntryPointReflection
@@ -3701,6 +3762,7 @@ namespace slang
};
typedef struct ShaderReflection ProgramLayout;
+ typedef enum SlangReflectionGenericArgType GenericArgType;
struct ShaderReflection
{
@@ -3820,6 +3882,32 @@ namespace slang
outDiagnostics);
}
+ GenericReflection* specializeGeneric(
+ GenericReflection* generic,
+ SlangInt specializationArgCount,
+ GenericArgType const* specializationArgTypes,
+ GenericArgReflection const* specializationArgVals,
+ ISlangBlob** outDiagnostics)
+ {
+ return (GenericReflection*) spReflection_specializeGeneric(
+ (SlangReflection*) this,
+ (SlangReflectionGeneric*) generic,
+ specializationArgCount,
+ (SlangReflectionGenericArgType const*) specializationArgTypes,
+ (SlangReflectionGenericArg const*) specializationArgVals,
+ outDiagnostics);
+ }
+
+ bool isSubType(
+ TypeReflection* subType,
+ TypeReflection* superType)
+ {
+ return spReflection_isSubType(
+ (SlangReflection*) this,
+ (SlangReflectionType*) subType,
+ (SlangReflectionType*) superType);
+ }
+
SlangUInt getHashedStringCount() const { return spReflection_getHashedStringCount((SlangReflection*)this); }
const char* getHashedString(SlangUInt index, size_t* outCount) const
@@ -3849,8 +3937,14 @@ namespace slang
Module = SLANG_DECL_KIND_MODULE,
Generic = SLANG_DECL_KIND_GENERIC,
Variable = SLANG_DECL_KIND_VARIABLE,
+ Namespace = SLANG_DECL_KIND_NAMESPACE,
};
+ char const* getName()
+ {
+ return spReflectionDecl_getName((SlangReflectionDecl*) this);
+ }
+
Kind getKind()
{
return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this);
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 5ebefb888..9d796f48d 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -426,6 +426,8 @@ namespace Slang
String const& name,
LookupMask mask,
DiagnosticSink* sink);
+
+ bool isSubType(Type* subType, Type* superType);
Dictionary<String, IntVal*>& getMangledNameToIntValMap();
ConstantIntVal* tryFoldIntVal(IntVal* intVal);
@@ -2162,6 +2164,11 @@ namespace Slang
void setFileSystem(ISlangFileSystem* fileSystem);
+ DeclRef<Decl> specializeGeneric(
+ DeclRef<Decl> declRef,
+ List<Expr*> argExprs,
+ DiagnosticSink* sink);
+
DiagnosticSink::Flags diagnosticSinkFlags = 0;
bool m_requireCacheFileSystem = false;
@@ -3374,6 +3381,16 @@ SLANG_FORCE_INLINE slang::TypeReflection* asExternal(Type* type)
return reinterpret_cast<slang::TypeReflection*>(type);
}
+SLANG_FORCE_INLINE DeclRef<Decl> asInternal(slang::GenericReflection* generic)
+{
+ return DeclRef<Decl>(reinterpret_cast<DeclRefBase*>(generic));
+}
+
+SLANG_FORCE_INLINE slang::GenericReflection* asExternal(DeclRef<Decl> generic)
+{
+ return reinterpret_cast<slang::GenericReflection*>(generic.declRefBase);
+}
+
SLANG_FORCE_INLINE TypeLayout* asInternal(slang::TypeLayoutReflection* type)
{
return reinterpret_cast<TypeLayout*>(type);
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index edfd27489..efa9a20a9 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -670,6 +670,17 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName
return 0;
}
+SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* inType, SlangReflectionGeneric* generic)
+{
+ auto type = convert(inType);
+ auto genericDeclRef = convertGenericToDeclRef(generic);
+
+ if (!type || !genericDeclRef)
+ return nullptr;
+
+ return convert(substituteType(SubstitutionSet(genericDeclRef), type->getASTBuilderForReflection(), type));
+}
+
SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType)
{
auto type = convert(inType);
@@ -859,6 +870,21 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re
try
{
Type* result = program->getTypeFromString(name, &sink);
+
+ ASTBuilder* astBuilder = program->getLinkage()->getASTBuilder();
+
+ if (auto genericType = as<GenericDeclRefType>(result))
+ {
+ auto genericDeclRef = genericType->getDeclRef();
+ auto innerDeclRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
+ return convert(
+ DeclRefType::create(
+ astBuilder,
+ createDefaultSubstitutionsIfNeeded(
+ astBuilder, nullptr, innerDeclRef)));
+ }
+
if (as<ErrorType>(result))
return nullptr;
return (SlangReflectionType*)result;
@@ -869,6 +895,35 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re
}
}
+
+SLANG_API bool spReflection_isSubType(
+ SlangReflection * reflection,
+ SlangReflectionType* subType,
+ SlangReflectionType* superType)
+{
+ auto programLayout = convert(reflection);
+ auto program = programLayout->getProgram();
+
+ // TODO: We should extend this API to support getting error messages
+ // when type lookup fails.
+ //
+ Slang::DiagnosticSink sink(
+ programLayout->getTargetReq()->getLinkage()->getSourceManager(),
+ Lexer::sourceLocationLexer);
+
+ try
+ {
+ auto sub = convert(subType);
+ auto super = convert(superType);
+
+ return program->isSubType(sub, super);
+ }
+ catch( ... )
+ {
+ return false;
+ }
+}
+
SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
{
auto decl = declRef.getDecl();
@@ -890,12 +945,17 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef)
SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type)
{
- auto declRefType = as<DeclRefType>(convert(type));
- if (!declRefType)
- return nullptr;
+ auto slangType = convert(type);
+ if (auto declRefType = as<DeclRefType>(slangType))
+ {
+ return getInnermostGenericParent(declRefType->getDeclRef());
+ }
+ else if (auto genericDeclRefType = as<GenericDeclRefType>(slangType))
+ {
+ return getInnermostGenericParent(genericDeclRefType->getDeclRef());
+ }
- auto declRef = declRefType->getDeclRef();
- return getInnermostGenericParent(declRef);
+ return nullptr;
}
SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(
@@ -2778,6 +2838,19 @@ SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(Slang
return getInnermostGenericParent(declRef);
}
+SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic)
+{
+ auto declRef = convert(var);
+ auto genericDeclRef = convertGenericToDeclRef(generic);
+ if (!declRef || !genericDeclRef)
+ return nullptr;
+
+ auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();
+
+ auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef);
+ return convert(substDeclRef);
+}
+
// Variable Layout Reflection
SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout)
@@ -3002,6 +3075,19 @@ SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(Slang
return getInnermostGenericParent(declRef);
}
+SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic)
+{
+ auto declRef = convert(func);
+ auto genericDeclRef = convertGenericToDeclRef(generic);
+ if (!declRef || !genericDeclRef)
+ return nullptr;
+
+ auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();
+
+ auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef);
+ return convert(substDeclRef.as<FunctionDeclBase>());
+}
+
// Abstract decl reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl)
@@ -3027,6 +3113,16 @@ SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* pa
return nullptr;
}
+SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl)
+{
+ Decl* slangDecl = (Decl*)decl;
+
+ if (auto name = slangDecl->getName())
+ return getText(name).getBuffer();
+
+ return nullptr;
+}
+
SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl)
{
Decl* slangDecl = (Decl*)decl;
@@ -3050,6 +3146,10 @@ SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl)
{
return SLANG_DECL_KIND_MODULE;
}
+ else if (as<NamespaceDecl>(slangDecl))
+ {
+ return SLANG_DECL_KIND_NAMESPACE;
+ }
else
return SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION;
}
@@ -3276,6 +3376,19 @@ SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric*
return 0;
}
+SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic)
+{
+ auto declRef = convertGenericToDeclRef(currGeneric);
+ auto genericDeclRef = convertGenericToDeclRef(generic);
+ if (!declRef || !genericDeclRef)
+ return nullptr;
+
+ auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder();
+
+ auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef);
+ return convertDeclToGeneric(substDeclRef);
+}
+
// Shader Parameter Reflection
@@ -3681,6 +3794,71 @@ SLANG_API SlangReflectionType* spReflection_specializeType(
return convert(specializedType);
}
+
+SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric(
+ SlangReflection* inProgramLayout,
+ SlangReflectionGeneric* generic,
+ SlangInt argCount,
+ SlangReflectionGenericArgType const* argTypes,
+ SlangReflectionGenericArg const* args,
+ ISlangBlob** outDiagnostics)
+{
+ auto programLayout = convert(inProgramLayout);
+ auto slangGeneric = convertGenericToDeclRef(generic);
+ if (!slangGeneric) return nullptr;
+ auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder();
+
+ auto linkage = programLayout->getProgram()->getLinkage();
+
+ DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
+
+ List<Expr*> argExprs;
+ for (SlangInt i = 0; i < argCount; ++i)
+ {
+ auto argType = argTypes[i];
+ auto arg = args[i];
+
+ switch (argType)
+ {
+ case SLANG_GENERIC_ARG_TYPE:
+ {
+ auto type = convert(arg.typeVal);
+ auto declRefType = as<DeclRefType>(type);
+ auto declRefExpr = astBuilder->create<DeclRefExpr>();
+ declRefExpr->declRef = declRefType->getDeclRef();
+ declRefExpr->type.type = astBuilder->getOrCreate<TypeType>(type);
+ argExprs.add(declRefExpr);
+ break;
+ }
+ case SLANG_GENERIC_ARG_INT:
+ {
+ auto literalExpr = astBuilder->create<IntegerLiteralExpr>();
+ literalExpr->value = args[i].intVal;
+ literalExpr->type = astBuilder->getIntType();
+ argExprs.add(literalExpr);
+ break;
+ }
+ case SLANG_GENERIC_ARG_BOOL:
+ {
+ auto literalExpr = astBuilder->create<BoolLiteralExpr>();
+ literalExpr->value = args[i].boolVal;
+ literalExpr->type = astBuilder->getBoolType();
+ argExprs.add(literalExpr);
+ break;
+ }
+ default:
+ // abort (TODO: throw a proper error)
+ return nullptr;
+ }
+ }
+
+ auto specialized = linkage->specializeGeneric(slangGeneric, argExprs, &sink);
+ sink.getBlobIfNeeded(outDiagnostics);
+
+ return convertDeclToGeneric(specialized);
+}
+
+
SLANG_API SlangUInt spReflection_getHashedStringCount(
SlangReflection* reflection)
{
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index cd8f426ea..ba7909aae 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -180,6 +180,13 @@ namespace Slang
List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl);
+ SubstitutionSet makeSubstitutionFromIncompleteSet(
+ ASTBuilder* astBuilder,
+ SemanticsVisitor* semantics,
+ DeclRef<GenericDecl> genericDeclRef,
+ Dictionary<Decl*, Val*> paramArgMap,
+ DiagnosticSink* sink);
+
Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst);
ParameterDirection getParameterDirection(VarDeclBase* varDecl);
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 3072ef0a7..aa114e44d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1348,6 +1348,43 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType(
return asExternal(specializedType);
}
+
+DeclRef<Decl> Linkage::specializeGeneric(
+ DeclRef<Decl> declRef,
+ List<Expr*> argExprs,
+ DiagnosticSink* sink)
+{
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
+ SLANG_ASSERT(declRef);
+
+ SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink);
+ SemanticsVisitor visitor(&sharedSemanticsContext);
+
+ // Create substituted parent decl ref.
+ auto decl = declRef.getDecl();
+
+ while (!as<GenericDecl>(decl))
+ {
+ decl = decl->parentDecl;
+ }
+
+ auto genericDecl = as<GenericDecl>(decl);
+ auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as<GenericDecl>();
+ genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as<GenericDecl>();
+
+
+ DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>();
+ declRefExpr->declRef = genericDeclRef;
+
+ GenericAppExpr* genericAppExpr = getASTBuilder()->create<GenericAppExpr>();
+ genericAppExpr->functionExpr = declRefExpr;
+ genericAppExpr->arguments = argExprs;
+
+ auto specializedDeclRef = as<DeclRefExpr>(visitor.checkGenericAppWithCheckedArgs(genericAppExpr))->declRef;
+
+ return specializedDeclRef;
+}
+
SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout(
slang::TypeReflection* inType,
SlangInt targetIndex,
@@ -2373,9 +2410,28 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType(
result = declRefExpr->declRef;
}
+ if (auto genericDeclRef = result.as<GenericDecl>())
+ {
+ result = createDefaultSubstitutionsIfNeeded(
+ astBuilder, &visitor, DeclRef(genericDeclRef.getDecl()->inner));
+ result = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, result);
+ }
+
return result;
}
+bool ComponentType::isSubType(Type* subType, Type* superType)
+{
+ SharedSemanticsContext sharedSemanticsContext(
+ getLinkage(),
+ nullptr,
+ nullptr);
+ SemanticsContext context(&sharedSemanticsContext);
+ SemanticsVisitor visitor(context);
+
+ return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr);
+}
+
static void collectExportedConstantInContainer(
Dictionary<String, IntVal*>& dict,
ASTBuilder* builder,
@@ -6805,4 +6861,4 @@ SlangResult EndToEndCompileRequest::isParameterLocationUsed(Int entryPointIndex,
return SLANG_OK;
}
-} // namespace Slang
+} // namespace Slang \ No newline at end of file
diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
index fb35f323c..d98ea0423 100644
--- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp
@@ -70,6 +70,15 @@ SLANG_UNIT_TEST(declTreeReflection)
T j<let N : int>(T x, out int o) { o = N; return x; }
}
+
+ namespace MyNamespace
+ {
+ struct MyStruct
+ {
+ int x;
+ }
+ }
+
)";
auto moduleName = "moduleG" + String(Process::getId());
@@ -101,7 +110,7 @@ SLANG_UNIT_TEST(declTreeReflection)
auto moduleDeclReflection = module->getModuleReflection();
SLANG_CHECK(moduleDeclReflection != nullptr);
SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module);
- SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 7);
+ SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8);
// First declaration should be a struct with 1 variable
auto firstDecl = moduleDeclReflection->getChild(0);
@@ -180,6 +189,11 @@ SLANG_UNIT_TEST(declTreeReflection)
auto innerStruct = genericReflection->getInnerDecl();
SLANG_CHECK(innerStruct->getKind() == slang::DeclReflection::Kind::Struct);
+ // Check that the seventh declaration is a namespace
+ auto seventhDecl = moduleDeclReflection->getChild(6);
+ SLANG_CHECK(seventhDecl->getKind() == slang::DeclReflection::Kind::Namespace);
+ SLANG_CHECK(UnownedStringSlice(seventhDecl->getName()) == "MyNamespace");
+
// Check type-lookup-by-name
{
@@ -262,7 +276,108 @@ SLANG_UNIT_TEST(declTreeReflection)
SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10);
}
+
+ // Check specializeGeneric() and applySpecializations()
+ {
+ auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
+ SLANG_CHECK(unspecializedType != nullptr);
+ auto halfType = compositeProgram->getLayout()->findTypeByName("half");
+ SLANG_CHECK(halfType != nullptr);
+
+ slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer();
+ SLANG_CHECK(genericContainer != nullptr);
+ //auto typeParamT = genericContainer->getTypeParameter(0);
+
+ List<slang::GenericArgType> argTypes;
+ List<slang::GenericArgReflection> args;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
+ args.add({halfType});
+ auto specializedContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+
+ SLANG_CHECK(specializedContainer != nullptr);
+
+ auto specializedType = unspecializedType->applySpecializations(specializedContainer);
+ SLANG_CHECK(specializedType != nullptr);
+ SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>");
+
+ }
+
+ // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct)
+ {
+ auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType");
+ auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j");
+
+ SLANG_CHECK(unspecializedFunc != nullptr);
+ auto halfType = compositeProgram->getLayout()->findTypeByName("half");
+ SLANG_CHECK(halfType != nullptr);
+
+ slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer();
+ SLANG_CHECK(genericFuncContainer != nullptr);
+ slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer();
+ SLANG_CHECK(genericStructContainer != nullptr);
+
+ // Specialize the outer container with half
+ List<slang::GenericArgType> argTypes;
+ List<slang::GenericArgReflection> args;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE);
+ args.add({halfType});
+ auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+ SLANG_CHECK(specializedStructContainer != nullptr);
+
+ // apply T=half. N is still left unspecialized.
+ genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer);
+
+ // Specialize the inner container with 10 separately..
+ argTypes.clear();
+ args.clear();
+
+ slang::GenericArgReflection argN;
+ argN.intVal = 10;
+ argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_INT);
+ args.add(argN);
+
+ auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric(
+ genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr);
+
+ auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer);
+ SLANG_CHECK(specializedFunc != nullptr);
+
+ // ------ check the specialized function
+ auto specializationInfo = specializedFunc->getGenericContainer();
+ SLANG_CHECK(specializationInfo != nullptr);
+ SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j");
+ SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ SLANG_CHECK(specializationInfo->getValueParameterCount() == 1);
+ auto valueParam = specializationInfo->getValueParameter(0);
+ SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name
+ SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10);
+
+ // check outer container
+ specializationInfo = specializationInfo->getOuterGenericContainer();
+ SLANG_CHECK(specializationInfo != nullptr);
+ SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType");
+ SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic);
+ // Check type parameters
+ SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1);
+ auto typeParam = specializationInfo->getTypeParameter(0);
+ SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name
+ SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half");
+ }
+
+ // Check sub-type relations
+ {
+ auto floatType = compositeProgram->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+ auto diffType = compositeProgram->getLayout()->findTypeByName("IDifferentiable");
+ SLANG_CHECK(diffType != nullptr);
+ SLANG_CHECK(compositeProgram->getLayout()->isSubType(floatType, diffType) == true);
+
+ auto uintType = compositeProgram->getLayout()->findTypeByName("uint");
+ SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false);
+ }
// Check iterators
{
@@ -271,7 +386,7 @@ SLANG_UNIT_TEST(declTreeReflection)
{
count++;
}
- SLANG_CHECK(count == 7);
+ SLANG_CHECK(count == 8);
count = 0;
for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>())
@@ -293,6 +408,13 @@ SLANG_UNIT_TEST(declTreeReflection)
count++;
}
SLANG_CHECK(count == 1);
+
+ count = 0;
+ for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>())
+ {
+ count++;
+ }
+ SLANG_CHECK(count == 1);
}
}