diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-08-27 16:47:05 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-27 16:47:05 -0400 |
| commit | 4aac22da6ae902eca1e7750f4e5b83ba238b5874 (patch) | |
| tree | f266e3c7c3a646473ac4af80ddbcd72702ced917 | |
| parent | d40c143eb4f19f1dfd0d0dcf9b718be6e495ca27 (diff) | |
Add ability to specialize generic references to functions, types and more (#4909)
* More reflection API features.
+ Lookup methods and members (by string) on types
+ Fix issue with looking up non-static members through the scope operator '::'
+ `GenericReflection`: Cast a decl to generic to access unspecialized generic parameter names and constraints
+ `GenericReflection`: Use `getGenericContainer()` from function, variable or type to access the 'nearest' generic parent along with specialization info
+ `GenericReflection::getConcreteType` and `GenericReflection::getConcreteIntVal`: to get the concrete type of a param in the context of the reflection object
+ `GenericReflection::getOuterGenericContainer` to go up one level and get the outer generic declarations (if there are more than one enclosing generic scopes)
+ `DeclReflection::getParent`: go to parent declaration.
+ Change `VariableReflection` to be a `DeclRef` rather than a decl (allows us to return properly substituted types for methods, members, and more)
* Fix Falcor issue
* Initial namespace reflection support
* FIx issue with specializing witness tables
* Add API method for specializing parameters of a generic decl
* Add ability to specialize generic references to functions, types and more
This PR adds the following end-points:
- `specializeGeneric()` method that can be called on a generic reflection to substitute arguments for generic type and value parameters. It returns another generic reflection, but this time with the appropriate substitution.
- `applySpecializations()` method to then copy these specializations onto an existing type or function reflection.
- `isSubType()` to check if a type is a subtype of another type (useful to check if a type is differentiable by checking `IDifferentiable`)
This PR also:
- Adds `DeclReflection::Kind::Namespace` so that namespace containers are correctly reflected when walking the decl-tree. the name can be obtained through `getName()` but there's no need to cast to a namespace (since there's nothing else we can do with a namespace decl)
- Fixes an issue with name-based lookups that fail if a type or function is referenced without specializations. Its helpful to be able to form a reference to a function with default substitutions, so that we can we can specialize it later (either directly, or via argument types).
* Update slang.h
* Fix up naming
* Update slang-compiler.h
* Update slang-reflection-api.cpp
* Update slang.cpp
* Update slang.cpp
* Update slang.cpp
* Use `checkGenericAppWithCheckedArgs` to do specialization
* Update slang-reflection-api.cpp
* Update slang-check-decl.cpp
| -rw-r--r-- | include/slang.h | 102 | ||||
| -rwxr-xr-x | source/slang/slang-compiler.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 188 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 7 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 58 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-decl-tree-reflection.cpp | 126 |
6 files changed, 486 insertions, 12 deletions
diff --git a/include/slang.h b/include/slang.h index 8f4d53a0d..7c417e74e 100644 --- a/include/slang.h +++ b/include/slang.h @@ -2110,6 +2110,20 @@ extern "C" typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute; typedef struct SlangReflectionFunction SlangReflectionFunction; typedef struct SlangReflectionGeneric SlangReflectionGeneric; + + union SlangReflectionGenericArg + { + SlangReflectionType* typeVal; + int64_t intVal; + bool boolVal; + }; + + enum SlangReflectionGenericArgType + { + SLANG_GENERIC_ARG_TYPE = 0, + SLANG_GENERIC_ARG_INT = 1, + SLANG_GENERIC_ARG_BOOL = 2 + }; /* Type aliases to maintain backward compatibility. @@ -2179,7 +2193,8 @@ extern "C" SLANG_DECL_KIND_FUNC, SLANG_DECL_KIND_MODULE, SLANG_DECL_KIND_GENERIC, - SLANG_DECL_KIND_VARIABLE + SLANG_DECL_KIND_VARIABLE, + SLANG_DECL_KIND_NAMESPACE }; #ifndef SLANG_RESOURCE_SHAPE @@ -2428,6 +2443,7 @@ extern "C" SLANG_API unsigned int spReflectionType_GetUserAttributeCount(SlangReflectionType* type); SLANG_API SlangReflectionUserAttribute* spReflectionType_GetUserAttribute(SlangReflectionType* type, unsigned int index); SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName(SlangReflectionType* type, char const* name); + SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* type, SlangReflectionGeneric* generic); SLANG_API unsigned int spReflectionType_GetFieldCount(SlangReflectionType* type); SLANG_API SlangReflectionVariable* spReflectionType_GetFieldByIndex(SlangReflectionType* type, unsigned index); @@ -2544,6 +2560,7 @@ extern "C" SLANG_API SlangReflectionUserAttribute* spReflectionVariable_FindUserAttributeByName(SlangReflectionVariable* var, SlangSession * globalSession, char const* name); SLANG_API bool spReflectionVariable_HasDefaultValue(SlangReflectionVariable* inVar); SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(SlangReflectionVariable* var); + SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic); // Variable Layout Reflection @@ -2570,11 +2587,13 @@ extern "C" SLANG_API SlangReflectionVariable* spReflectionFunction_GetParameter(SlangReflectionFunction* func, unsigned index); SLANG_API SlangReflectionType* spReflectionFunction_GetResultType(SlangReflectionFunction* func); SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(SlangReflectionFunction* func); + SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic); // Abstract Decl Reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl); SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* parentDecl, unsigned int index); + SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl); SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl); SLANG_API SlangReflectionFunction* spReflectionDecl_castToFunction(SlangReflectionDecl* decl); SLANG_API SlangReflectionVariable* spReflectionDecl_castToVariable(SlangReflectionDecl* decl); @@ -2597,6 +2616,7 @@ extern "C" SLANG_API SlangReflectionGeneric* spReflectionGeneric_GetOuterGenericContainer(SlangReflectionGeneric* generic); SLANG_API SlangReflectionType* spReflectionGeneric_GetConcreteType(SlangReflectionGeneric* generic, SlangReflectionVariable* typeParam); SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* generic, SlangReflectionVariable* valueParam); + SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic); /** Get the stage that a variable belongs to (if any). @@ -2698,12 +2718,25 @@ extern "C" SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection); SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection); - SLANG_API SlangReflectionType* spReflection_specializeType( + SLANG_API SlangReflectionType* spReflection_specializeType( SlangReflection* reflection, SlangReflectionType* type, SlangInt specializationArgCount, SlangReflectionType* const* specializationArgs, ISlangBlob** outDiagnostics); + + SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric( + SlangReflection* inProgramLayout, + SlangReflectionGeneric* generic, + SlangInt argCount, + SlangReflectionGenericArgType const* argTypes, + SlangReflectionGenericArg const* args, + ISlangBlob** outDiagnostics); + + SLANG_API bool spReflection_isSubType( + SlangReflection * reflection, + SlangReflectionType* subType, + SlangReflectionType* superType); /// Get the number of hashed strings SLANG_API SlangUInt spReflection_getHashedStringCount( @@ -2750,6 +2783,13 @@ namespace slang struct FunctionReflection; struct GenericReflection; + union GenericArgReflection + { + TypeReflection* typeVal; + int64_t intVal; + bool boolVal; + }; + struct UserAttribute { char const* getName() @@ -2919,18 +2959,25 @@ namespace slang { return spReflectionType_GetUserAttributeCount((SlangReflectionType*)this); } + UserAttribute* getUserAttributeByIndex(unsigned int index) { return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index); } + UserAttribute* findUserAttributeByName(char const* name) { return (UserAttribute*)spReflectionType_FindUserAttributeByName((SlangReflectionType*)this, name); } - SlangReflectionGeneric* getGenericContainer() + TypeReflection* applySpecializations(GenericReflection* generic) + { + return (TypeReflection*)spReflectionType_applySpecializations((SlangReflectionType*)this, (SlangReflectionGeneric*)generic); + } + + GenericReflection* getGenericContainer() { - return (SlangReflectionGeneric*) spReflectionType_GetGenericContainer((SlangReflectionType*) this); + return (GenericReflection*) spReflectionType_GetGenericContainer((SlangReflectionType*) this); } }; @@ -3399,6 +3446,11 @@ namespace slang { return (GenericReflection*)spReflectionVariable_GetGenericContainer((SlangReflectionVariable*)this); } + + VariableReflection* applySpecializations(GenericReflection* generic) + { + return (VariableReflection*)spReflectionVariable_applySpecializations((SlangReflectionVariable*)this, (SlangReflectionGeneric*)generic); + } }; struct VariableLayoutReflection @@ -3529,6 +3581,11 @@ namespace slang { return (GenericReflection*)spReflectionFunction_GetGenericContainer((SlangReflectionFunction*)this); } + + FunctionReflection* applySpecializations(GenericReflection* generic) + { + return (FunctionReflection*)spReflectionFunction_applySpecializations((SlangReflectionFunction*)this, (SlangReflectionGeneric*)generic); + } }; struct GenericReflection @@ -3599,6 +3656,10 @@ namespace slang return spReflectionGeneric_GetConcreteIntVal((SlangReflectionGeneric*)this, (SlangReflectionVariable*)valueParam); } + GenericReflection* applySpecializations(GenericReflection* generic) + { + return (GenericReflection*)spReflectionGeneric_applySpecializations((SlangReflectionGeneric*)this, (SlangReflectionGeneric*)generic); + } }; struct EntryPointReflection @@ -3701,6 +3762,7 @@ namespace slang }; typedef struct ShaderReflection ProgramLayout; + typedef enum SlangReflectionGenericArgType GenericArgType; struct ShaderReflection { @@ -3820,6 +3882,32 @@ namespace slang outDiagnostics); } + GenericReflection* specializeGeneric( + GenericReflection* generic, + SlangInt specializationArgCount, + GenericArgType const* specializationArgTypes, + GenericArgReflection const* specializationArgVals, + ISlangBlob** outDiagnostics) + { + return (GenericReflection*) spReflection_specializeGeneric( + (SlangReflection*) this, + (SlangReflectionGeneric*) generic, + specializationArgCount, + (SlangReflectionGenericArgType const*) specializationArgTypes, + (SlangReflectionGenericArg const*) specializationArgVals, + outDiagnostics); + } + + bool isSubType( + TypeReflection* subType, + TypeReflection* superType) + { + return spReflection_isSubType( + (SlangReflection*) this, + (SlangReflectionType*) subType, + (SlangReflectionType*) superType); + } + SlangUInt getHashedStringCount() const { return spReflection_getHashedStringCount((SlangReflection*)this); } const char* getHashedString(SlangUInt index, size_t* outCount) const @@ -3849,8 +3937,14 @@ namespace slang Module = SLANG_DECL_KIND_MODULE, Generic = SLANG_DECL_KIND_GENERIC, Variable = SLANG_DECL_KIND_VARIABLE, + Namespace = SLANG_DECL_KIND_NAMESPACE, }; + char const* getName() + { + return spReflectionDecl_getName((SlangReflectionDecl*) this); + } + Kind getKind() { return (Kind)spReflectionDecl_getKind((SlangReflectionDecl*)this); diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 5ebefb888..9d796f48d 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -426,6 +426,8 @@ namespace Slang String const& name, LookupMask mask, DiagnosticSink* sink); + + bool isSubType(Type* subType, Type* superType); Dictionary<String, IntVal*>& getMangledNameToIntValMap(); ConstantIntVal* tryFoldIntVal(IntVal* intVal); @@ -2162,6 +2164,11 @@ namespace Slang void setFileSystem(ISlangFileSystem* fileSystem); + DeclRef<Decl> specializeGeneric( + DeclRef<Decl> declRef, + List<Expr*> argExprs, + DiagnosticSink* sink); + DiagnosticSink::Flags diagnosticSinkFlags = 0; bool m_requireCacheFileSystem = false; @@ -3374,6 +3381,16 @@ SLANG_FORCE_INLINE slang::TypeReflection* asExternal(Type* type) return reinterpret_cast<slang::TypeReflection*>(type); } +SLANG_FORCE_INLINE DeclRef<Decl> asInternal(slang::GenericReflection* generic) +{ + return DeclRef<Decl>(reinterpret_cast<DeclRefBase*>(generic)); +} + +SLANG_FORCE_INLINE slang::GenericReflection* asExternal(DeclRef<Decl> generic) +{ + return reinterpret_cast<slang::GenericReflection*>(generic.declRefBase); +} + SLANG_FORCE_INLINE TypeLayout* asInternal(slang::TypeLayoutReflection* type) { return reinterpret_cast<TypeLayout*>(type); diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index edfd27489..efa9a20a9 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -670,6 +670,17 @@ SLANG_API SlangReflectionUserAttribute* spReflectionType_FindUserAttributeByName return 0; } +SLANG_API SlangReflectionType* spReflectionType_applySpecializations(SlangReflectionType* inType, SlangReflectionGeneric* generic) +{ + auto type = convert(inType); + auto genericDeclRef = convertGenericToDeclRef(generic); + + if (!type || !genericDeclRef) + return nullptr; + + return convert(substituteType(SubstitutionSet(genericDeclRef), type->getASTBuilderForReflection(), type)); +} + SLANG_API SlangResourceShape spReflectionType_GetResourceShape(SlangReflectionType* inType) { auto type = convert(inType); @@ -859,6 +870,21 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re try { Type* result = program->getTypeFromString(name, &sink); + + ASTBuilder* astBuilder = program->getLinkage()->getASTBuilder(); + + if (auto genericType = as<GenericDeclRefType>(result)) + { + auto genericDeclRef = genericType->getDeclRef(); + auto innerDeclRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); + return convert( + DeclRefType::create( + astBuilder, + createDefaultSubstitutionsIfNeeded( + astBuilder, nullptr, innerDeclRef))); + } + if (as<ErrorType>(result)) return nullptr; return (SlangReflectionType*)result; @@ -869,6 +895,35 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re } } + +SLANG_API bool spReflection_isSubType( + SlangReflection * reflection, + SlangReflectionType* subType, + SlangReflectionType* superType) +{ + auto programLayout = convert(reflection); + auto program = programLayout->getProgram(); + + // TODO: We should extend this API to support getting error messages + // when type lookup fails. + // + Slang::DiagnosticSink sink( + programLayout->getTargetReq()->getLinkage()->getSourceManager(), + Lexer::sourceLocationLexer); + + try + { + auto sub = convert(subType); + auto super = convert(superType); + + return program->isSubType(sub, super); + } + catch( ... ) + { + return false; + } +} + SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef) { auto decl = declRef.getDecl(); @@ -890,12 +945,17 @@ SlangReflectionGeneric* getInnermostGenericParent(DeclRef<Decl> declRef) SLANG_API SlangReflectionGeneric* spReflectionType_GetGenericContainer(SlangReflectionType* type) { - auto declRefType = as<DeclRefType>(convert(type)); - if (!declRefType) - return nullptr; + auto slangType = convert(type); + if (auto declRefType = as<DeclRefType>(slangType)) + { + return getInnermostGenericParent(declRefType->getDeclRef()); + } + else if (auto genericDeclRefType = as<GenericDeclRefType>(slangType)) + { + return getInnermostGenericParent(genericDeclRefType->getDeclRef()); + } - auto declRef = declRefType->getDeclRef(); - return getInnermostGenericParent(declRef); + return nullptr; } SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( @@ -2778,6 +2838,19 @@ SLANG_API SlangReflectionGeneric* spReflectionVariable_GetGenericContainer(Slang return getInnermostGenericParent(declRef); } +SLANG_API SlangReflectionVariable* spReflectionVariable_applySpecializations(SlangReflectionVariable* var, SlangReflectionGeneric* generic) +{ + auto declRef = convert(var); + auto genericDeclRef = convertGenericToDeclRef(generic); + if (!declRef || !genericDeclRef) + return nullptr; + + auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); + + auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef); + return convert(substDeclRef); +} + // Variable Layout Reflection SLANG_API SlangReflectionVariable* spReflectionVariableLayout_GetVariable(SlangReflectionVariableLayout* inVarLayout) @@ -3002,6 +3075,19 @@ SLANG_API SlangReflectionGeneric* spReflectionFunction_GetGenericContainer(Slang return getInnermostGenericParent(declRef); } +SLANG_API SlangReflectionFunction* spReflectionFunction_applySpecializations(SlangReflectionFunction* func, SlangReflectionGeneric* generic) +{ + auto declRef = convert(func); + auto genericDeclRef = convertGenericToDeclRef(generic); + if (!declRef || !genericDeclRef) + return nullptr; + + auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); + + auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef); + return convert(substDeclRef.as<FunctionDeclBase>()); +} + // Abstract decl reflection SLANG_API unsigned int spReflectionDecl_getChildrenCount(SlangReflectionDecl* parentDecl) @@ -3027,6 +3113,16 @@ SLANG_API SlangReflectionDecl* spReflectionDecl_getChild(SlangReflectionDecl* pa return nullptr; } +SLANG_API char const* spReflectionDecl_getName(SlangReflectionDecl* decl) +{ + Decl* slangDecl = (Decl*)decl; + + if (auto name = slangDecl->getName()) + return getText(name).getBuffer(); + + return nullptr; +} + SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl) { Decl* slangDecl = (Decl*)decl; @@ -3050,6 +3146,10 @@ SLANG_API SlangDeclKind spReflectionDecl_getKind(SlangReflectionDecl* decl) { return SLANG_DECL_KIND_MODULE; } + else if (as<NamespaceDecl>(slangDecl)) + { + return SLANG_DECL_KIND_NAMESPACE; + } else return SLANG_DECL_KIND_UNSUPPORTED_FOR_REFLECTION; } @@ -3276,6 +3376,19 @@ SLANG_API int64_t spReflectionGeneric_GetConcreteIntVal(SlangReflectionGeneric* return 0; } +SLANG_API SlangReflectionGeneric* spReflectionGeneric_applySpecializations(SlangReflectionGeneric* currGeneric, SlangReflectionGeneric* generic) +{ + auto declRef = convertGenericToDeclRef(currGeneric); + auto genericDeclRef = convertGenericToDeclRef(generic); + if (!declRef || !genericDeclRef) + return nullptr; + + auto astBuilder = getModule(declRef.getDecl())->getLinkage()->getASTBuilder(); + + auto substDeclRef = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, declRef); + return convertDeclToGeneric(substDeclRef); +} + // Shader Parameter Reflection @@ -3681,6 +3794,71 @@ SLANG_API SlangReflectionType* spReflection_specializeType( return convert(specializedType); } + +SLANG_API SlangReflectionGeneric* spReflection_specializeGeneric( + SlangReflection* inProgramLayout, + SlangReflectionGeneric* generic, + SlangInt argCount, + SlangReflectionGenericArgType const* argTypes, + SlangReflectionGenericArg const* args, + ISlangBlob** outDiagnostics) +{ + auto programLayout = convert(inProgramLayout); + auto slangGeneric = convertGenericToDeclRef(generic); + if (!slangGeneric) return nullptr; + auto astBuilder = getModule(slangGeneric.getDecl())->getLinkage()->getASTBuilder(); + + auto linkage = programLayout->getProgram()->getLinkage(); + + DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); + + List<Expr*> argExprs; + for (SlangInt i = 0; i < argCount; ++i) + { + auto argType = argTypes[i]; + auto arg = args[i]; + + switch (argType) + { + case SLANG_GENERIC_ARG_TYPE: + { + auto type = convert(arg.typeVal); + auto declRefType = as<DeclRefType>(type); + auto declRefExpr = astBuilder->create<DeclRefExpr>(); + declRefExpr->declRef = declRefType->getDeclRef(); + declRefExpr->type.type = astBuilder->getOrCreate<TypeType>(type); + argExprs.add(declRefExpr); + break; + } + case SLANG_GENERIC_ARG_INT: + { + auto literalExpr = astBuilder->create<IntegerLiteralExpr>(); + literalExpr->value = args[i].intVal; + literalExpr->type = astBuilder->getIntType(); + argExprs.add(literalExpr); + break; + } + case SLANG_GENERIC_ARG_BOOL: + { + auto literalExpr = astBuilder->create<BoolLiteralExpr>(); + literalExpr->value = args[i].boolVal; + literalExpr->type = astBuilder->getBoolType(); + argExprs.add(literalExpr); + break; + } + default: + // abort (TODO: throw a proper error) + return nullptr; + } + } + + auto specialized = linkage->specializeGeneric(slangGeneric, argExprs, &sink); + sink.getBlobIfNeeded(outDiagnostics); + + return convertDeclToGeneric(specialized); +} + + SLANG_API SlangUInt spReflection_getHashedStringCount( SlangReflection* reflection) { diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index cd8f426ea..ba7909aae 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -180,6 +180,13 @@ namespace Slang List<Val*> getDefaultSubstitutionArgs(ASTBuilder* astBuilder, SemanticsVisitor* semantics, GenericDecl* genericDecl); + SubstitutionSet makeSubstitutionFromIncompleteSet( + ASTBuilder* astBuilder, + SemanticsVisitor* semantics, + DeclRef<GenericDecl> genericDeclRef, + Dictionary<Decl*, Val*> paramArgMap, + DiagnosticSink* sink); + Val::OperandView<Val> findInnerMostGenericArgs(SubstitutionSet subst); ParameterDirection getParameterDirection(VarDeclBase* varDecl); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 3072ef0a7..aa114e44d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1348,6 +1348,43 @@ SLANG_NO_THROW slang::TypeReflection* SLANG_MCALL Linkage::specializeType( return asExternal(specializedType); } + +DeclRef<Decl> Linkage::specializeGeneric( + DeclRef<Decl> declRef, + List<Expr*> argExprs, + DiagnosticSink* sink) +{ + SLANG_AST_BUILDER_RAII(getASTBuilder()); + SLANG_ASSERT(declRef); + + SharedSemanticsContext sharedSemanticsContext(this, nullptr, sink); + SemanticsVisitor visitor(&sharedSemanticsContext); + + // Create substituted parent decl ref. + auto decl = declRef.getDecl(); + + while (!as<GenericDecl>(decl)) + { + decl = decl->parentDecl; + } + + auto genericDecl = as<GenericDecl>(decl); + auto genericDeclRef = createDefaultSubstitutionsIfNeeded(getASTBuilder(), &visitor, DeclRef(genericDecl)).as<GenericDecl>(); + genericDeclRef = substituteDeclRef(SubstitutionSet(declRef), getASTBuilder(), genericDeclRef).as<GenericDecl>(); + + + DeclRefExpr* declRefExpr = getASTBuilder()->create<DeclRefExpr>(); + declRefExpr->declRef = genericDeclRef; + + GenericAppExpr* genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); + genericAppExpr->functionExpr = declRefExpr; + genericAppExpr->arguments = argExprs; + + auto specializedDeclRef = as<DeclRefExpr>(visitor.checkGenericAppWithCheckedArgs(genericAppExpr))->declRef; + + return specializedDeclRef; +} + SLANG_NO_THROW slang::TypeLayoutReflection* SLANG_MCALL Linkage::getTypeLayout( slang::TypeReflection* inType, SlangInt targetIndex, @@ -2373,9 +2410,28 @@ DeclRef<Decl> ComponentType::findDeclFromStringInType( result = declRefExpr->declRef; } + if (auto genericDeclRef = result.as<GenericDecl>()) + { + result = createDefaultSubstitutionsIfNeeded( + astBuilder, &visitor, DeclRef(genericDeclRef.getDecl()->inner)); + result = substituteDeclRef(SubstitutionSet(genericDeclRef), astBuilder, result); + } + return result; } +bool ComponentType::isSubType(Type* subType, Type* superType) +{ + SharedSemanticsContext sharedSemanticsContext( + getLinkage(), + nullptr, + nullptr); + SemanticsContext context(&sharedSemanticsContext); + SemanticsVisitor visitor(context); + + return (visitor.isSubtype(subType, superType, IsSubTypeOptions::None) != nullptr); +} + static void collectExportedConstantInContainer( Dictionary<String, IntVal*>& dict, ASTBuilder* builder, @@ -6805,4 +6861,4 @@ SlangResult EndToEndCompileRequest::isParameterLocationUsed(Int entryPointIndex, return SLANG_OK; } -} // namespace Slang +} // namespace Slang
\ No newline at end of file diff --git a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp index fb35f323c..d98ea0423 100644 --- a/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp +++ b/tools/slang-unit-test/unit-test-decl-tree-reflection.cpp @@ -70,6 +70,15 @@ SLANG_UNIT_TEST(declTreeReflection) T j<let N : int>(T x, out int o) { o = N; return x; } } + + namespace MyNamespace + { + struct MyStruct + { + int x; + } + } + )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -101,7 +110,7 @@ SLANG_UNIT_TEST(declTreeReflection) auto moduleDeclReflection = module->getModuleReflection(); SLANG_CHECK(moduleDeclReflection != nullptr); SLANG_CHECK(moduleDeclReflection->getKind() == slang::DeclReflection::Kind::Module); - SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 7); + SLANG_CHECK(moduleDeclReflection->getChildrenCount() == 8); // First declaration should be a struct with 1 variable auto firstDecl = moduleDeclReflection->getChild(0); @@ -180,6 +189,11 @@ SLANG_UNIT_TEST(declTreeReflection) auto innerStruct = genericReflection->getInnerDecl(); SLANG_CHECK(innerStruct->getKind() == slang::DeclReflection::Kind::Struct); + // Check that the seventh declaration is a namespace + auto seventhDecl = moduleDeclReflection->getChild(6); + SLANG_CHECK(seventhDecl->getKind() == slang::DeclReflection::Kind::Namespace); + SLANG_CHECK(UnownedStringSlice(seventhDecl->getName()) == "MyNamespace"); + // Check type-lookup-by-name { @@ -262,7 +276,108 @@ SLANG_UNIT_TEST(declTreeReflection) SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); } + + // Check specializeGeneric() and applySpecializations() + { + auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + SLANG_CHECK(unspecializedType != nullptr); + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + slang::GenericReflection* genericContainer = unspecializedType->getGenericContainer(); + SLANG_CHECK(genericContainer != nullptr); + //auto typeParamT = genericContainer->getTypeParameter(0); + + List<slang::GenericArgType> argTypes; + List<slang::GenericArgReflection> args; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); + args.add({halfType}); + auto specializedContainer = compositeProgram->getLayout()->specializeGeneric( + genericContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + + SLANG_CHECK(specializedContainer != nullptr); + + auto specializedType = unspecializedType->applySpecializations(specializedContainer); + SLANG_CHECK(specializedType != nullptr); + SLANG_CHECK(getTypeFullName(specializedType) == "MyGenericType<half>"); + + } + + // Check specializeGeneric() and applySpecializations() on multiple levels (generic function nested in a generic struct) + { + auto unspecializedType = compositeProgram->getLayout()->findTypeByName("MyGenericType"); + auto unspecializedFunc = compositeProgram->getLayout()->findFunctionByNameInType(unspecializedType, "j"); + + SLANG_CHECK(unspecializedFunc != nullptr); + auto halfType = compositeProgram->getLayout()->findTypeByName("half"); + SLANG_CHECK(halfType != nullptr); + + slang::GenericReflection* genericFuncContainer = unspecializedFunc->getGenericContainer(); + SLANG_CHECK(genericFuncContainer != nullptr); + slang::GenericReflection* genericStructContainer = genericFuncContainer->getOuterGenericContainer(); + SLANG_CHECK(genericStructContainer != nullptr); + + // Specialize the outer container with half + List<slang::GenericArgType> argTypes; + List<slang::GenericArgReflection> args; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_TYPE); + args.add({halfType}); + auto specializedStructContainer = compositeProgram->getLayout()->specializeGeneric( + genericStructContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + SLANG_CHECK(specializedStructContainer != nullptr); + + // apply T=half. N is still left unspecialized. + genericFuncContainer = genericFuncContainer->applySpecializations(specializedStructContainer); + + // Specialize the inner container with 10 separately.. + argTypes.clear(); + args.clear(); + + slang::GenericArgReflection argN; + argN.intVal = 10; + argTypes.add(slang::GenericArgType::SLANG_GENERIC_ARG_INT); + args.add(argN); + + auto specializedFuncContainer = compositeProgram->getLayout()->specializeGeneric( + genericFuncContainer, argTypes.getCount(), argTypes.getBuffer(), args.getBuffer(), nullptr); + + auto specializedFunc = unspecializedFunc->applySpecializations(specializedFuncContainer); + SLANG_CHECK(specializedFunc != nullptr); + + // ------ check the specialized function + auto specializationInfo = specializedFunc->getGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "j"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + SLANG_CHECK(specializationInfo->getValueParameterCount() == 1); + auto valueParam = specializationInfo->getValueParameter(0); + SLANG_CHECK(UnownedStringSlice(valueParam->getName()) == "N"); // generic name + SLANG_CHECK(specializationInfo->getConcreteIntVal(valueParam) == 10); + + // check outer container + specializationInfo = specializationInfo->getOuterGenericContainer(); + SLANG_CHECK(specializationInfo != nullptr); + SLANG_CHECK(UnownedStringSlice(specializationInfo->getName()) == "MyGenericType"); + SLANG_CHECK(specializationInfo->asDecl()->getKind() == slang::DeclReflection::Kind::Generic); + // Check type parameters + SLANG_CHECK(specializationInfo->getTypeParameterCount() == 1); + auto typeParam = specializationInfo->getTypeParameter(0); + SLANG_CHECK(UnownedStringSlice(typeParam->getName()) == "T"); // generic name + SLANG_CHECK(getTypeFullName(specializationInfo->getConcreteType(typeParam)) == "half"); + } + + // Check sub-type relations + { + auto floatType = compositeProgram->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + auto diffType = compositeProgram->getLayout()->findTypeByName("IDifferentiable"); + SLANG_CHECK(diffType != nullptr); + SLANG_CHECK(compositeProgram->getLayout()->isSubType(floatType, diffType) == true); + + auto uintType = compositeProgram->getLayout()->findTypeByName("uint"); + SLANG_CHECK(compositeProgram->getLayout()->isSubType(uintType, diffType) == false); + } // Check iterators { @@ -271,7 +386,7 @@ SLANG_UNIT_TEST(declTreeReflection) { count++; } - SLANG_CHECK(count == 7); + SLANG_CHECK(count == 8); count = 0; for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Func>()) @@ -293,6 +408,13 @@ SLANG_UNIT_TEST(declTreeReflection) count++; } SLANG_CHECK(count == 1); + + count = 0; + for (auto* child : moduleDeclReflection->getChildrenOfKind<slang::DeclReflection::Kind::Namespace>()) + { + count++; + } + SLANG_CHECK(count == 1); } } |
