diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-27 16:47:05 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-27 16:47:05 -0400 |
| commit | 4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch) | |
| tree | f266e3c7c3a646473ac4af80ddbcd72702ced917 /include/slang.h | |
| parent | d40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff) | |
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features.
+ Lookup methods and members (by string) on types
+ Fix issue with looking up non-static members through the scope operator '::'
+ `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints
+ `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info
+ `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object
+ `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes)
+ `DeclReflection::getParent`: go to parent declaration.
+ Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more)
* Fix Falcor issue
* Initial namespace reflection support
* FIx issue with specializing witness tables
* Add API method for specializing parameters of a generic decl
* Add ability to specialize generic references to functions, types and more
This PR adds the following end-points:
- `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution.
- `applySpecializations()` method to then copy these specializations onto an existing type or function reflection.
- `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`)
This PR also:
- Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl)
- Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types).
* Update slang.h
* Fix up naming
* Update slang-compiler.h
* Update slang-reflection-api.cpp
* Update slang.cpp
* Update slang.cpp
* Update slang.cpp
* Use `checkGenericAppWithCheckedArgs` to do specialization
* Update slang-reflection-api.cpp
* Update slang-check-decl.cpp
Diffstat (limited to 'include/slang.h')
| -rw-r--r-- | include/slang.h | 102 |
1 files changed, 98 insertions, 4 deletions
diff --git a/include/slang.h b/include/slang.h index 8f4d53a0d..7c417e74e 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2110,6 +2110,20 @@ extern "C" typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; typedef struct SlangReflectionGeneric SlangReflectionGeneric; + + union SlangReflectionGenericArg + { + SlangReflectionType* typeVal; + int64_t intVal; + bool boolVal; + }; + + enum SlangReflectionGenericArgType + { + SLANG_GENERIC_ARG_TYPE = 0, + SLANG_GENERIC_ARG_INT = 1, + SLANG_GENERIC_ARG_BOOL = 2 + }; /* Type aliases to maintain backward compatibility. @@ -2179,7 +2193,8 @@ extern "C" SLANG_DECL_KIND_FUNC, SLANG_DECL_KIND_MODULE, SLANG_DECL_KIND_GENERIC, - SLANG_DECL_KIND_VARIABLE + SLANG_DECL_KIND_VARIABLE, + SLANG_DECL_KIND_NAMESPACE }; #ifndef SLANG_RESOURCE_SHAPE @@ -2428,6 +2443,7 @@ extern "C" SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type); SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index); SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name); + SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* type, SlangReflectionGeneric* generic); SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type); SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index); @@ -2544,6 +2560,7 @@ extern "C" SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name); SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar); SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var); + SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic); // Variable Layout Reflection @@ -2570,11 +2587,13 @@ extern "C" SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index); SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); + SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic); // Abstract Decl Reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl); SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index); + SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl); SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl); SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl); SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl); @@ -2597,6 +2616,7 @@ extern "C" SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic); SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam); SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam); + SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic); /** Get the stage that a variable belongs to (if any). @@ -2698,12 +2718,25 @@ extern "C" SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection); SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection); - SLANG_API SlangReflectionType* spReflection_specializeType( + SLANG_API SlangReflectionType* spReflection_specializeType( SlangReflection* reflection, SlangReflectionType* type, SlangInt specializationArgCount, SlangReflectionType* const* specializationArgs, ISlangBlob** outDiagnostics); + + SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric( + SlangReflection* inProgramLayout, + SlangReflectionGeneric* generic, + SlangInt argCount, + SlangReflectionGenericArgType const* argTypes, + SlangReflectionGenericArg const* args, + ISlangBlob** outDiagnostics); + + SLANG_API bool spReflection_isSubType( + SlangReflection * reflection, + SlangReflectionType* subType, + SlangReflectionType* superType); /// Get the number of hashed strings SLANG_API SlangUInt spReflection_getHashedStringCount( @@ -2750,6 +2783,13 @@ namespace slang struct FunctionReflection; struct GenericReflection; + union GenericArgReflection + { + TypeReflection* typeVal; + int64_t intVal; + bool boolVal; + }; + struct UserAttribute { char const* getName() @@ -2919,18 +2959,25 @@ namespace slang { return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this); } + UserAttribute* getUserAttributeByIndex(unsigned int index) { return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index); } + UserAttribute* findUserAttributeByName(char const* name) { return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name); } - SlangReflectionGeneric* getGenericContainer() + TypeReflection* applySpecializations(GenericReflection* generic) + { + return (TypeReflection*)spReflectionType_applySpecializations((SlangReflectionType*)this, (SlangReflectionGeneric*)generic); + } + + GenericReflection* getGenericContainer() { - return (SlangReflectionGeneric*) spReflectionType_GetGenericContainer((SlangReflectionType*) this); + return (GenericReflection*) spReflectionType_GetGenericContainer((SlangReflectionType*) this); } }; @@ -3399,6 +3446,11 @@ namespace slang { return (GenericReflection*)spReflectionVariable_GetGenericContainer((SlangReflectionVariable*)this); } + + VariableReflection* applySpecializations(GenericReflection* generic) + { + return (VariableReflection*)spReflectionVariable_applySpecializations((SlangReflectionVariable*)this, (SlangReflectionGeneric*)generic); + } }; struct VariableLayoutReflection @@ -3529,6 +3581,11 @@ namespace slang { return (GenericReflection*)spReflectionFunction_GetGenericContainer((SlangReflectionFunction*)this); } + + FunctionReflection* applySpecializations(GenericReflection* generic) + { + return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic); + } }; struct GenericReflection @@ -3599,6 +3656,10 @@ namespace slang return spReflectionGeneric_GetConcreteIntVal((SlangReflectionGeneric*)this, (SlangReflectionVariable*)valueParam); } + GenericReflection* applySpecializations(GenericReflection* generic) + { + return (GenericReflection*)spReflectionGeneric_applySpecializations((SlangReflectionGeneric*)this, (SlangReflectionGeneric*)generic); + } }; struct EntryPointReflection @@ -3701,6 +3762,7 @@ namespace slang }; typedef struct ShaderReflection ProgramLayout; + typedef enum SlangReflectionGenericArgType GenericArgType; struct ShaderReflection { @@ -3820,6 +3882,32 @@ namespace slang outDiagnostics); } + GenericReflection* specializeGeneric( + GenericReflection* generic, + SlangInt specializationArgCount, + GenericArgType const* specializationArgTypes, + GenericArgReflection const* specializationArgVals, + ISlangBlob** outDiagnostics) + { + return (GenericReflection*) spReflection_specializeGeneric( + (SlangReflection*) this, + (SlangReflectionGeneric*) generic, + specializationArgCount, + (SlangReflectionGenericArgType const*) specializationArgTypes, + (SlangReflectionGenericArg const*) specializationArgVals, + outDiagnostics); + } + + bool isSubType( + TypeReflection* subType, + TypeReflection* superType) + { + return spReflection_isSubType( + (SlangReflection*) this, + (SlangReflectionType*) subType, + (SlangReflectionType*) superType); + } + SlangUInt getHashedStringCount() const { return spReflection_getHashedStringCount((SlangReflection*)this); } const char* getHashedString(SlangUInt index, size_t* outCount) const @@ -3849,8 +3937,14 @@ namespace slang Module = SLANG_DECL_KIND_MODULE, Generic = SLANG_DECL_KIND_GENERIC, Variable = SLANG_DECL_KIND_VARIABLE, + Namespace = SLANG_DECL_KIND_NAMESPACE, }; + char const* getName() + { + return spReflectionDecl_getName((SlangReflectionDecl*) this); + } + Kind getKind() { return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this); |
