summaryrefslogtreecommitdiffstats
path: root/include
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-08-27 16:47:05 -0400
committerGitHub <noreply@github.com>2024-08-27 16:47:05 -0400
commit4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch)
treef266e3c7c3a646473ac4af80ddbcd72702ced917 /include
parentd40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff)
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features. + Lookup methods and members (by string) on types + Fix issue with looking up non-static members through the scope operator '::' + `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints + `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info + `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object + `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes) + `DeclReflection::getParent`: go to parent declaration. + Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more) * Fix Falcor issue * Initial namespace reflection support * FIx issue with specializing witness tables * Add API method for specializing parameters of a generic decl * Add ability to specialize generic references to functions, types and more This PR adds the following end-points: - `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution. - `applySpecializations()` method to then copy these specializations onto an existing type or function reflection. - `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`) This PR also: - Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl) - Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types). * Update slang.h * Fix up naming * Update slang-compiler.h * Update slang-reflection-api.cpp * Update slang.cpp * Update slang.cpp * Update slang.cpp * Use `checkGenericAppWithCheckedArgs` to do specialization * Update slang-reflection-api.cpp * Update slang-check-decl.cpp
Diffstat (limited to 'include')
-rw-r--r--include/slang.h102
1 files changed, 98 insertions, 4 deletions
diff --git a/include/slang.h b/include/slang.h
index 8f4d53a0d..7c417e74e 100644
--- a/include/slang.h
+++ b/include/slang.h
@@ -2110,6 +2110,20 @@ extern "C"
typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute;
typedef struct SlangReflectionFunction SlangReflectionFunction;
typedef struct SlangReflectionGeneric SlangReflectionGeneric;
+
+ union SlangReflectionGenericArg
+ {
+ SlangReflectionType* typeVal;
+ int64_t intVal;
+ bool boolVal;
+ };
+
+ enum SlangReflectionGenericArgType
+ {
+ SLANG_GENERIC_ARG_TYPE = 0,
+ SLANG_GENERIC_ARG_INT = 1,
+ SLANG_GENERIC_ARG_BOOL = 2
+ };
/*
Type aliases to maintain backward compatibility.
@@ -2179,7 +2193,8 @@ extern "C"
SLANG_DECL_KIND_FUNC,
SLANG_DECL_KIND_MODULE,
SLANG_DECL_KIND_GENERIC,
- SLANG_DECL_KIND_VARIABLE
+ SLANG_DECL_KIND_VARIABLE,
+ SLANG_DECL_KIND_NAMESPACE
};
#ifndef SLANG_RESOURCE_SHAPE
@@ -2428,6 +2443,7 @@ extern "C"
SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type);
SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index);
SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name);
+ SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* type, SlangReflectionGeneric* generic);
SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type);
SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index);
@@ -2544,6 +2560,7 @@ extern "C"
SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name);
SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar);
SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var);
+ SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic);
// Variable Layout Reflection
@@ -2570,11 +2587,13 @@ extern "C"
SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index);
SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func);
SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func);
+ SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic);
// Abstract Decl Reflection
SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl);
SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index);
+ SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl);
SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl);
SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl);
SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl);
@@ -2597,6 +2616,7 @@ extern "C"
SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic);
SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam);
SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam);
+ SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic);
/** Get the stage that a variable belongs to (if any).
@@ -2698,12 +2718,25 @@ extern "C"
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
- SLANG_API SlangReflectionType* spReflection_specializeType(
+ SLANG_API SlangReflectionType* spReflection_specializeType(
SlangReflection* reflection,
SlangReflectionType* type,
SlangInt specializationArgCount,
SlangReflectionType* const* specializationArgs,
ISlangBlob** outDiagnostics);
+
+ SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric(
+ SlangReflection* inProgramLayout,
+ SlangReflectionGeneric* generic,
+ SlangInt argCount,
+ SlangReflectionGenericArgType const* argTypes,
+ SlangReflectionGenericArg const* args,
+ ISlangBlob** outDiagnostics);
+
+ SLANG_API bool spReflection_isSubType(
+ SlangReflection * reflection,
+ SlangReflectionType* subType,
+ SlangReflectionType* superType);
/// Get the number of hashed strings
SLANG_API SlangUInt spReflection_getHashedStringCount(
@@ -2750,6 +2783,13 @@ namespace slang
struct FunctionReflection;
struct GenericReflection;
+ union GenericArgReflection
+ {
+ TypeReflection* typeVal;
+ int64_t intVal;
+ bool boolVal;
+ };
+
struct UserAttribute
{
char const* getName()
@@ -2919,18 +2959,25 @@ namespace slang
{
return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this);
}
+
UserAttribute* getUserAttributeByIndex(unsigned int index)
{
return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index);
}
+
UserAttribute* findUserAttributeByName(char const* name)
{
return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name);
}
- SlangReflectionGeneric* getGenericContainer()
+ TypeReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (TypeReflection*)spReflectionType_applySpecializations((SlangReflectionType*)this, (SlangReflectionGeneric*)generic);
+ }
+
+ GenericReflection* getGenericContainer()
{
- return (SlangReflectionGeneric*) spReflectionType_GetGenericContainer((SlangReflectionType*) this);
+ return (GenericReflection*) spReflectionType_GetGenericContainer((SlangReflectionType*) this);
}
};
@@ -3399,6 +3446,11 @@ namespace slang
{
return (GenericReflection*)spReflectionVariable_GetGenericContainer((SlangReflectionVariable*)this);
}
+
+ VariableReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (VariableReflection*)spReflectionVariable_applySpecializations((SlangReflectionVariable*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct VariableLayoutReflection
@@ -3529,6 +3581,11 @@ namespace slang
{
return (GenericReflection*)spReflectionFunction_GetGenericContainer((SlangReflectionFunction*)this);
}
+
+ FunctionReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct GenericReflection
@@ -3599,6 +3656,10 @@ namespace slang
return spReflectionGeneric_GetConcreteIntVal((SlangReflectionGeneric*)this, (SlangReflectionVariable*)valueParam);
}
+ GenericReflection* applySpecializations(GenericReflection* generic)
+ {
+ return (GenericReflection*)spReflectionGeneric_applySpecializations((SlangReflectionGeneric*)this, (SlangReflectionGeneric*)generic);
+ }
};
struct EntryPointReflection
@@ -3701,6 +3762,7 @@ namespace slang
};
typedef struct ShaderReflection ProgramLayout;
+ typedef enum SlangReflectionGenericArgType GenericArgType;
struct ShaderReflection
{
@@ -3820,6 +3882,32 @@ namespace slang
outDiagnostics);
}
+ GenericReflection* specializeGeneric(
+ GenericReflection* generic,
+ SlangInt specializationArgCount,
+ GenericArgType const* specializationArgTypes,
+ GenericArgReflection const* specializationArgVals,
+ ISlangBlob** outDiagnostics)
+ {
+ return (GenericReflection*) spReflection_specializeGeneric(
+ (SlangReflection*) this,
+ (SlangReflectionGeneric*) generic,
+ specializationArgCount,
+ (SlangReflectionGenericArgType const*) specializationArgTypes,
+ (SlangReflectionGenericArg const*) specializationArgVals,
+ outDiagnostics);
+ }
+
+ bool isSubType(
+ TypeReflection* subType,
+ TypeReflection* superType)
+ {
+ return spReflection_isSubType(
+ (SlangReflection*) this,
+ (SlangReflectionType*) subType,
+ (SlangReflectionType*) superType);
+ }
+
SlangUInt getHashedStringCount() const { return spReflection_getHashedStringCount((SlangReflection*)this); }
const char* getHashedString(SlangUInt index, size_t* outCount) const
@@ -3849,8 +3937,14 @@ namespace slang
Module = SLANG_DECL_KIND_MODULE,
Generic = SLANG_DECL_KIND_GENERIC,
Variable = SLANG_DECL_KIND_VARIABLE,
+ Namespace = SLANG_DECL_KIND_NAMESPACE,
};
+ char const* getName()
+ {
+ return spReflectionDecl_getName((SlangReflectionDecl*) this);
+ }
+
Kind getKind()
{
return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this);