diff options
| author | jsmall-nvidia <jsmall@nvidia.com> | 2020-06-03 17:22:48 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-03 17:22:48 -0400 |
| commit | 1b8731c809761c4e2dbec81dcee207f8a4621903 (patch) | |
| tree | b8c67d97a71df2a8ba776b6d1a39bc13138aeaf0 /source/slang | |
| parent | 4e3e7f2a8f032c3f8fc4c530023aa80973598502 (diff) | |
Devirtualize AST types (#1368)
* Make getSup work with more general non-virtual 'virtual' mechanism.
* WIP: Non virtual AST types.
* Project change.
* Type doesn't implement equalsImpl
* Fix macro invocation
Make Overridden functions public to make simply accessible by base types.
* Use SLANG_UNEXPECTED.
* GetScalarType -> getScalarType
Use SLANG_UNEXPECTED instead on ASSERT in NamedExpressionType and TypeType
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-base.h | 91 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 27 | ||||
| -rw-r--r-- | source/slang/slang-ast-reflect.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-ast-substitutions.cpp | 237 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 972 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 194 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 552 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 74 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 1598 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 11 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 4 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 12 |
13 files changed, 2026 insertions, 1791 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 5f40cba49..05d2ded69 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -98,15 +98,21 @@ class Val : public NodeBase // integer parameter that should be incremented when // returning a modified value (this can help the caller // decide whether they need to do anything). - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); - virtual bool equalsVal(Val* val) = 0; - virtual String toString() = 0; - virtual HashCode getHashCode() = 0; + bool equalsVal(Val* val); + String toString(); + HashCode getHashCode(); bool operator == (const Val & v) { return equalsVal(const_cast<Val*>(&v)); } + + // Overrides should be public so base classes can access + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); }; class Type; @@ -144,22 +150,22 @@ class Type: public Val /// Get the ASTBuilder that was used to construct this Type SLANG_FORCE_INLINE ASTBuilder* getASTBuilder() const { return m_astBuilder; } - //Session* getSession() { return this->session; } - bool equals(Type* type); Type* getCanonicalType(); - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - - virtual bool equalsVal(Val* val) override; - ~Type(); + // Overrides should be public so base classes can access + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + bool _equalsValOverride(Val* val); + bool _equalsImplOverride(Type* type); + RefPtr<Type> _createCanonicalTypeOverride(); + protected: - virtual bool equalsImpl(Type* type) = 0; + bool equalsImpl(Type* type); + RefPtr<Type> createCanonicalType(); - virtual RefPtr<Type> createCanonicalType() = 0; Type* canonicalType = nullptr; SLANG_UNREFLECTED @@ -181,11 +187,16 @@ class Substitutions: public NodeBase RefPtr<Substitutions> outer; // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) = 0; + RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff); // Check if these are equivalent substitutions to another set - virtual bool equals(Substitutions* subst) = 0; - virtual HashCode getHashCode() const = 0; + bool equals(Substitutions* subst); + HashCode getHashCode() const; + + // Overrides should be public so base classes can access + RefPtr<Substitutions> _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff); + bool _equalsOverride(Substitutions* subst); + HashCode _getHashCodeOverride() const; }; class GenericSubstitution : public Substitutions @@ -199,22 +210,10 @@ class GenericSubstitution : public Substitutions // The actual values of the arguments List<RefPtr<Val> > args; - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutions to another set - virtual bool equals(Substitutions* subst) override; - - virtual HashCode getHashCode() const override - { - HashCode rs = 0; - for (auto && v : args) - { - rs ^= v->getHashCode(); - rs *= 16777619; - } - return rs; - } + // Overrides should be public so base classes can access + RefPtr<Substitutions> _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff); + bool _equalsOverride(Substitutions* subst); + HashCode _getHashCodeOverride() const; }; class ThisTypeSubstitution : public Substitutions @@ -228,14 +227,11 @@ class ThisTypeSubstitution : public Substitutions // specialize the interface conforms to the interface. RefPtr<SubtypeWitness> witness; + // Overrides should be public so base classes can access // The actual type that provides the lookup scope for an associated type - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutions to another set - virtual bool equals(Substitutions* subst) override; - - virtual HashCode getHashCode() const override; + RefPtr<Substitutions> _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff); + bool _equalsOverride(Substitutions* subst); + HashCode _getHashCodeOverride() const; }; class GlobalGenericParamSubstitution : public Substitutions @@ -256,21 +252,10 @@ class GlobalGenericParamSubstitution : public Substitutions // the values that satisfy any constraints on the type parameter List<ConstraintArg> constraintArgs; - // Apply a set of substitutions to the bindings in this substitution - virtual RefPtr<Substitutions> applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) override; - - // Check if these are equivalent substitutions to another set - virtual bool equals(Substitutions* subst) override; - - virtual HashCode getHashCode() const override - { - HashCode rs = actualType->getHashCode(); - for (auto && a : constraintArgs) - { - rs = combineHash(rs, a.val->getHashCode()); - } - return rs; - } + // Overrides should be public so base classes can access + RefPtr<Substitutions> _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff); + bool _equalsOverride(Substitutions* subst); + HashCode _getHashCodeOverride() const; }; class SyntaxNode : public SyntaxNodeBase diff --git a/source/slang/slang-ast-decl.cpp b/source/slang/slang-ast-decl.cpp new file mode 100644 index 000000000..a10411ebb --- /dev/null +++ b/source/slang/slang-ast-decl.cpp @@ -0,0 +1,21 @@ +// slang-ast-decl.cpp +#include "slang-ast-builder.h" +#include <assert.h> + +#include "slang-ast-generated-macro.h" + +namespace Slang { + +const TypeExp& TypeConstraintDecl::getSup() const +{ + SLANG_AST_NODE_CONST_VIRTUAL_CALL(TypeConstraintDecl, getSup, ()) +} + +const TypeExp& TypeConstraintDecl::_getSupOverride() const +{ + SLANG_UNEXPECTED("TypeConstraintDecl::_getSupOverride not overridden"); + //return TypeExp::empty; +} + + +} // namespace Slang diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 54e8ac18b..bdb9b8ad8 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -171,8 +171,11 @@ class InterfaceDecl : public AggTypeDecl class TypeConstraintDecl : public Decl { SLANG_ABSTRACT_CLASS(TypeConstraintDecl) - - SLANG_INLINE const TypeExp& getSup() const; + + const TypeExp& getSup() const; + // Overrides should be public so base classes can access + // Implement _getSupOverride on derived classes to change behavior of getSup, as if getSup is virtual + const TypeExp& _getSupOverride() const; }; // A kind of pseudo-member that represents an explicit @@ -190,6 +193,9 @@ class InheritanceDecl : public TypeConstraintDecl // implementations in the type that contains // this inheritance declaration. RefPtr<WitnessTable> witnessTable; + + // Overrides should be public so base classes can access + const TypeExp& _getSupOverride() const { return base; } }; // TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance @@ -403,6 +409,9 @@ class GenericTypeConstraintDecl : public TypeConstraintDecl // think of these fields as the sub-type and super-type, respectively. TypeExp sub; TypeExp sup; + + // Overrides should be public so base classes can access + const TypeExp& _getSupOverride() const { return sup; } }; class GenericValueParamDecl : public VarDeclBase @@ -450,18 +459,4 @@ class AttributeDecl : public ContainerDecl SyntaxClass<RefObject> syntaxClass; }; -// ------------------------------------------------------------------------ - -const TypeExp& TypeConstraintDecl::getSup() const -{ - ASTNodeType type = ASTNodeType(getClassInfo().m_classId); - switch (type) - { - case ASTNodeType::InheritanceDecl: return static_cast<const InheritanceDecl*>(this)->base; - case ASTNodeType::GenericTypeConstraintDecl: return static_cast<const GenericTypeConstraintDecl*>(this)->sup; - default: SLANG_ASSERT(!"getSup not implemented for this type!"); return TypeExp::empty; - } -} - - } // namespace Slang diff --git a/source/slang/slang-ast-reflect.h b/source/slang/slang-ast-reflect.h index 2c629e839..82fb80b22 100644 --- a/source/slang/slang-ast-reflect.h +++ b/source/slang/slang-ast-reflect.h @@ -16,7 +16,7 @@ static const ReflectClassInfo kReflectClassInfo; \ SLANG_FORCE_INLINE static bool isDerivedFrom(ASTNodeType type) { return int(type) >= int(kType) && int(type) <= int(ASTNodeType::LAST); } \ friend class ASTBuilder; \ - friend struct ASTConstructAccess; + friend struct ASTConstructAccess; // Macro definitions - use the SLANG_ASTNode_ definitions to invoke the IMPL to produce the code // injected into AST classes @@ -28,4 +28,26 @@ #define SLANG_REFLECTED #define SLANG_UNREFLECTED +// Macros for simulating virtual methods without virtual methods + +#define SLANG_AST_NODE_INVOKE(method, methodParams) _##method##Override methodParams + +#define SLANG_AST_NODE_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) case ASTNodeType::NAME: return static_cast<NAME*>(this)-> SLANG_AST_NODE_INVOKE param; + +#define SLANG_AST_NODE_VIRTUAL_CALL(base, methodName, methodParams) \ + switch (astNodeType) \ + { \ + SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CASE, (methodName, methodParams)) \ + default: return SLANG_AST_NODE_INVOKE (methodName, methodParams); \ + } + +// Same but for a method that's const +#define SLANG_AST_NODE_CONST_CASE(NAME, SUPER, ORIGIN, LAST, MARKER, TYPE, param) case ASTNodeType::NAME: return static_cast<const NAME*>(this)-> SLANG_AST_NODE_INVOKE param; +#define SLANG_AST_NODE_CONST_VIRTUAL_CALL(base, methodName, methodParams) \ + switch (astNodeType) \ + { \ + SLANG_ALL_ASTNode_##base(SLANG_AST_NODE_CONST_CASE, (methodName, methodParams)) \ + default: return SLANG_AST_NODE_INVOKE (methodName, methodParams); \ + } + #endif // SLANG_AST_REFLECT_H diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp new file mode 100644 index 000000000..05865fe8f --- /dev/null +++ b/source/slang/slang-ast-substitutions.cpp @@ -0,0 +1,237 @@ +// slang-ast-substitutions.cpp +#include "slang-ast-builder.h" +#include <assert.h> + +#include "slang-ast-generated-macro.h" + +namespace Slang { + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Substitutions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +RefPtr<Substitutions> Substitutions::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) +{ + SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, applySubstitutionsShallow, (astBuilder, substSet, substOuter, ioDiff)) +} + +bool Substitutions::equals(Substitutions* subst) +{ + SLANG_AST_NODE_VIRTUAL_CALL(Substitutions, equals, (subst)) +} + +HashCode Substitutions::getHashCode() const +{ + SLANG_AST_NODE_CONST_VIRTUAL_CALL(Substitutions, getHashCode, ()) +} + +RefPtr<Substitutions> Substitutions::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) +{ + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(substSet); + SLANG_UNUSED(substOuter); + SLANG_UNUSED(ioDiff); + SLANG_UNEXPECTED("Substitutions::_applySubstitutionsShallowOverride not overridden"); + //return RefPtr<Substitutions>(); +} + +bool Substitutions::_equalsOverride(Substitutions* subst) +{ + SLANG_UNUSED(subst); + SLANG_UNEXPECTED("Substitutions::_equalsOverride not overridden"); + //return false; +} + +HashCode Substitutions::_getHashCodeOverride() const +{ + SLANG_UNEXPECTED("Substitutions::_getHashCodeOverride not overridden"); + //return HashCode(0); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +RefPtr<Substitutions> GenericSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) +{ + int diff = 0; + + if (substOuter != outer) diff++; + + List<RefPtr<Val>> substArgs; + for (auto a : args) + { + substArgs.add(a->substituteImpl(astBuilder, substSet, &diff)); + } + + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = astBuilder->create<GenericSubstitution>(); + substSubst->genericDecl = genericDecl; + substSubst->args = substArgs; + substSubst->outer = substOuter; + return substSubst; +} + +bool GenericSubstitution::_equalsOverride(Substitutions* subst) +{ + // both must be NULL, or non-NULL + if (subst == nullptr) + return false; + if (this == subst) + return true; + + auto genericSubst = as<GenericSubstitution>(subst); + if (!genericSubst) + return false; + if (genericDecl != genericSubst->genericDecl) + return false; + + Index argCount = args.getCount(); + SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount()); + for (Index aa = 0; aa < argCount; ++aa) + { + if (!args[aa]->equalsVal(genericSubst->args[aa].Ptr())) + return false; + } + + if (!outer) + return !genericSubst->outer; + + if (!outer->equals(genericSubst->outer.Ptr())) + return false; + + return true; +} + +HashCode GenericSubstitution::_getHashCodeOverride() const +{ + HashCode rs = 0; + for (auto && v : args) + { + rs ^= v->getHashCode(); + rs *= 16777619; + } + return rs; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisTypeSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +RefPtr<Substitutions> ThisTypeSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) +{ + int diff = 0; + + if (substOuter != outer) diff++; + + // NOTE: Must use .as because we must have a smart pointer here to keep in scope. + auto substWitness = witness->substituteImpl(astBuilder, substSet, &diff).as<SubtypeWitness>(); + + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = astBuilder->create<ThisTypeSubstitution>(); + substSubst->interfaceDecl = interfaceDecl; + substSubst->witness = substWitness; + substSubst->outer = substOuter; + return substSubst; +} + +bool ThisTypeSubstitution::_equalsOverride(Substitutions* subst) +{ + if (!subst) + return false; + if (subst == this) + return true; + + if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) + { + // For our purposes, two this-type substitutions are + // equivalent if they have the same type as `This`, + // even if the specific witness values they use + // might differ. + // + if (this->interfaceDecl != thisTypeSubst->interfaceDecl) + return false; + + if (!this->witness->sub->equals(thisTypeSubst->witness->sub)) + return false; + + return true; + } + return false; +} + +HashCode ThisTypeSubstitution::_getHashCodeOverride() const +{ + return witness->getHashCode(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GlobalGenericParamSubstitution !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +RefPtr<Substitutions> GlobalGenericParamSubstitution::_applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) +{ + // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl + // return a copy of that GlobalGenericParamSubstitution + int diff = 0; + + if (substOuter != outer) diff++; + + auto substActualType = actualType->substituteImpl(astBuilder, substSet, &diff).as<Type>(); + + List<ConstraintArg> substConstraintArgs; + for (auto constraintArg : constraintArgs) + { + ConstraintArg substConstraintArg; + substConstraintArg.decl = constraintArg.decl; + substConstraintArg.val = constraintArg.val->substituteImpl(astBuilder, substSet, &diff); + + substConstraintArgs.add(substConstraintArg); + } + + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<GlobalGenericParamSubstitution> substSubst = astBuilder->create<GlobalGenericParamSubstitution>(); + substSubst->paramDecl = paramDecl; + substSubst->actualType = substActualType; + substSubst->constraintArgs = substConstraintArgs; + substSubst->outer = substOuter; + return substSubst; +} + +bool GlobalGenericParamSubstitution::_equalsOverride(Substitutions* subst) +{ + if (!subst) + return false; + if (subst == this) + return true; + + if (auto genSubst = as<GlobalGenericParamSubstitution>(subst)) + { + if (paramDecl != genSubst->paramDecl) + return false; + if (!actualType->equalsVal(genSubst->actualType)) + return false; + if (constraintArgs.getCount() != genSubst->constraintArgs.getCount()) + return false; + for (Index i = 0; i < constraintArgs.getCount(); i++) + { + if (!constraintArgs[i].val->equalsVal(genSubst->constraintArgs[i].val)) + return false; + } + return true; + } + return false; +} + +HashCode GlobalGenericParamSubstitution::_getHashCodeOverride() const +{ + HashCode rs = actualType->getHashCode(); + for (auto && a : constraintArgs) + { + rs = combineHash(rs, a.val->getHashCode()); + } + return rs; +} + + +} // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp new file mode 100644 index 000000000..48e95562e --- /dev/null +++ b/source/slang/slang-ast-type.cpp @@ -0,0 +1,972 @@ +// slang-ast-type.cpp +#include "slang-ast-builder.h" +#include <assert.h> +#include <typeinfo> + +#include "slang-syntax.h" + +#include "slang-ast-generated-macro.h" + +namespace Slang { + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! Type !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Type::~Type() +{ + // If the canonicalType !=nullptr AND it is not set to this (ie the canonicalType is another object) + // then it needs to be released because it's owned by this object. + if (canonicalType && canonicalType != this) + { + canonicalType->releaseReference(); + } +} + +RefPtr<Type> Type::createCanonicalType() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Type, createCanonicalType, ()) +} + +bool Type::equals(Type* type) +{ + return getCanonicalType()->equalsImpl(type->getCanonicalType()); +} + +bool Type::equalsImpl(Type* type) +{ + SLANG_AST_NODE_VIRTUAL_CALL(Type, equalsImpl, (type)) +} + +bool Type::_equalsImplOverride(Type* type) +{ + SLANG_UNUSED(type) + SLANG_UNEXPECTED("Type::_equalsImplOverride not overridden"); + //return false; +} + +RefPtr<Type> Type::_createCanonicalTypeOverride() +{ + SLANG_UNEXPECTED("Type::_createCanonicalTypeOverride not overridden"); + //return RefPtr<Type>(); +} + +bool Type::_equalsValOverride(Val* val) +{ + if (auto type = dynamicCast<Type>(val)) + return const_cast<Type*>(this)->equals(type); + return false; +} + +RefPtr<Val> Type::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto canSubst = getCanonicalType()->substituteImpl(astBuilder, subst, &diff); + + // If nothing changed, then don't drop any sugar that is applied + if (!diff) + return this; + + // If the canonical type changed, then we return a canonical type, + // rather than try to re-construct any amount of sugar + (*ioDiff)++; + return canSubst; +} + +Type* Type::getCanonicalType() +{ + Type* et = const_cast<Type*>(this); + if (!et->canonicalType) + { + // TODO(tfoley): worry about thread safety here? + auto canType = et->createCanonicalType(); + et->canonicalType = canType; + + // TODO(js): That this detachs when canType == this is a little surprising. It would seem + // as if this would create a circular reference on the object, but in practice there are + // no leaks so appears correct. + // That the dtor only releases if != this, also makes it surprising. + canType.detach(); + + SLANG_ASSERT(et->canonicalType); + } + return et->canonicalType; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! OverloadGroupType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String OverloadGroupType::_toStringOverride() +{ + return "overload group"; +} + +bool OverloadGroupType::_equalsImplOverride(Type * /*type*/) +{ + return false; +} + +RefPtr<Type> OverloadGroupType::_createCanonicalTypeOverride() +{ + return this; +} + +HashCode OverloadGroupType::_getHashCodeOverride() +{ + return (HashCode)(size_t(this)); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! InitializerListType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String InitializerListType::_toStringOverride() +{ + return "initializer list"; +} + +bool InitializerListType::_equalsImplOverride(Type * /*type*/) +{ + return false; +} + +RefPtr<Type> InitializerListType::_createCanonicalTypeOverride() +{ + return this; +} + +HashCode InitializerListType::_getHashCodeOverride() +{ + return (HashCode)(size_t(this)); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String ErrorType::_toStringOverride() +{ + return "error"; +} + +bool ErrorType::_equalsImplOverride(Type* type) +{ + if (auto errorType = as<ErrorType>(type)) + return true; + return false; +} + +RefPtr<Type> ErrorType::_createCanonicalTypeOverride() +{ + return this; +} + +RefPtr<Val> ErrorType::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) +{ + return this; +} + +HashCode ErrorType::_getHashCodeOverride() +{ + return HashCode(size_t(this)); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String DeclRefType::_toStringOverride() +{ + return declRef.toString(); +} + +HashCode DeclRefType::_getHashCodeOverride() +{ + return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code()); +} + +bool DeclRefType::_equalsImplOverride(Type * type) +{ + if (auto declRefType = as<DeclRefType>(type)) + { + return declRef.equals(declRefType->declRef); + } + return false; +} + +RefPtr<Type> DeclRefType::_createCanonicalTypeOverride() +{ + // A declaration reference is already canonical + return this; +} + +RefPtr<Val> DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + if (!subst) return this; + + // the case we especially care about is when this type references a declaration + // of a generic parameter, since that is what we might be substituting... + if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst.substitutions; s; s = s->outer) + { + auto genericSubst = s.as<GenericSubstitution>(); + if (!genericSubst) + continue; + + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genericTypeParamDecl->parentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->members) + { + if (m.Ptr() == genericTypeParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genericSubst->args[index]; + } + else if (auto typeParam = as<GenericTypeParamDecl>(m)) + { + index++; + } + else if (auto valParam = as<GenericValueParamDecl>(m)) + { + index++; + } + else + { + } + } + } + } + else if (auto globalGenParam = as<GlobalGenericParamDecl>(declRef.getDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst.substitutions; s; s = s->outer) + { + auto genericSubst = as<GlobalGenericParamSubstitution>(s); + if (!genericSubst) + continue; + + if (genericSubst->paramDecl == globalGenParam) + { + (*ioDiff)++; + return genericSubst->actualType; + } + } + } + int diff = 0; + DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + + if (!diff) + return this; + + // Make sure to record the difference! + *ioDiff += diff; + + // If this type is a reference to an associated type declaration, + // and the substitutions provide a "this type" substitution for + // the outer interface, then try to replace the type with the + // actual value of the associated type for the given implementation. + // + if (auto substAssocTypeDecl = as<AssocTypeDecl>(substDeclRef.decl)) + { + for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + { + auto thisSubst = s.as<ThisTypeSubstitution>(); + if (!thisSubst) + continue; + + if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) + { + if (thisSubst->interfaceDecl == interfaceDecl) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substAssocTypeDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey); + switch (requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + } + } + } + } + + // Re-construct the type in case we are using a specialized sub-class + return DeclRefType::create(astBuilder, substDeclRef); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArithmeticExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + + +BasicExpressionType* ArithmeticExpressionType::getScalarType() +{ + SLANG_AST_NODE_VIRTUAL_CALL(ArithmeticExpressionType, getScalarType, ()) +} + +BasicExpressionType* ArithmeticExpressionType::_getScalarTypeOverride() +{ + SLANG_UNEXPECTED("ArithmeticExpressionType::_getScalarTypeOverride not overridden"); + //return nullptr; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool BasicExpressionType::_equalsImplOverride(Type * type) +{ + auto basicType = as<BasicExpressionType>(type); + return basicType && basicType->baseType == this->baseType; +} + +RefPtr<Type> BasicExpressionType::_createCanonicalTypeOverride() +{ + // A basic type is already canonical, in our setup + return this; +} + +BasicExpressionType* BasicExpressionType::_getScalarTypeOverride() +{ + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! VectorExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String VectorExpressionType::_toStringOverride() +{ + StringBuilder sb; + sb << "vector<" << elementType->toString() << "," << elementCount->toString() << ">"; + return sb.ProduceString(); +} + +BasicExpressionType* VectorExpressionType::_getScalarTypeOverride() +{ + return as<BasicExpressionType>(elementType); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! MatrixExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String MatrixExpressionType::_toStringOverride() +{ + StringBuilder sb; + sb << "matrix<" << getElementType()->toString() << "," << getRowCount()->toString() << "," << getColumnCount()->toString() << ">"; + return sb.ProduceString(); +} + +BasicExpressionType* MatrixExpressionType::_getScalarTypeOverride() +{ + return as<BasicExpressionType>(getElementType()); +} + +Type* MatrixExpressionType::getElementType() +{ + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); +} + +IntVal* MatrixExpressionType::getRowCount() +{ + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); +} + +IntVal* MatrixExpressionType::getColumnCount() +{ + return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); +} + +RefPtr<Type> MatrixExpressionType::getRowType() +{ + if (!rowType) + { + rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount()); + } + return rowType; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ArrayExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool ArrayExpressionType::_equalsImplOverride(Type* type) +{ + auto arrType = as<ArrayExpressionType>(type); + if (!arrType) + return false; + return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType.Ptr())); +} + +RefPtr<Val> ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto elementType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto arrlen = arrayLength->substituteImpl(astBuilder, subst, &diff).as<IntVal>(); + SLANG_ASSERT(arrlen); + if (diff) + { + *ioDiff = 1; + auto rsType = getArrayType( + astBuilder, + elementType, + arrlen); + return rsType; + } + return this; +} + +RefPtr<Type> ArrayExpressionType::_createCanonicalTypeOverride() +{ + auto canonicalElementType = baseType->getCanonicalType(); + auto canonicalArrayType = getASTBuilder()->getArrayType( + canonicalElementType, + arrayLength); + return canonicalArrayType; +} + +HashCode ArrayExpressionType::_getHashCodeOverride() +{ + if (arrayLength) + return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode(); + else + return baseType->getHashCode(); +} + +Slang::String ArrayExpressionType::_toStringOverride() +{ + if (arrayLength) + return baseType->toString() + "[" + arrayLength->toString() + "]"; + else + return baseType->toString() + "[]"; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String TypeType::_toStringOverride() +{ + StringBuilder sb; + sb << "typeof(" << type->toString() << ")"; + return sb.ProduceString(); +} + +bool TypeType::_equalsImplOverride(Type * t) +{ + if (auto typeType = as<TypeType>(t)) + { + return t->equals(typeType->type); + } + return false; +} + +RefPtr<Type> TypeType::_createCanonicalTypeOverride() +{ + return getASTBuilder()->getTypeType(type->getCanonicalType()); +} + +HashCode TypeType::_getHashCodeOverride() +{ + SLANG_UNEXPECTED("TypeType::_getHashCodeOverride should be unreachable"); + //return HashCode(0); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericDeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String GenericDeclRefType::_toStringOverride() +{ + // TODO: what is appropriate here? + return "<DeclRef<GenericDecl>>"; +} + +bool GenericDeclRefType::_equalsImplOverride(Type * type) +{ + if (auto genericDeclRefType = as<GenericDeclRefType>(type)) + { + return declRef.equals(genericDeclRefType->declRef); + } + return false; +} + +HashCode GenericDeclRefType::_getHashCodeOverride() +{ + return declRef.getHashCode(); +} + +RefPtr<Type> GenericDeclRefType::_createCanonicalTypeOverride() +{ + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamespaceType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String NamespaceType::_toStringOverride() +{ + String result; + result.append("namespace "); + result.append(declRef.toString()); + return result; +} + +bool NamespaceType::_equalsImplOverride(Type * type) +{ + if (auto namespaceType = as<NamespaceType>(type)) + { + return declRef.equals(namespaceType->declRef); + } + return false; +} + +HashCode NamespaceType::_getHashCodeOverride() +{ + return declRef.getHashCode(); +} + +RefPtr<Type> NamespaceType::_createCanonicalTypeOverride() +{ + return this; +} + + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! PtrTypeBase !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +Type* PtrTypeBase::getValueType() +{ + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! NamedExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String NamedExpressionType::_toStringOverride() +{ + return getText(declRef.getName()); +} + +bool NamedExpressionType::_equalsImplOverride(Type * /*type*/) +{ + SLANG_UNEXPECTED("NamedExpressionType::_equalsImplOverride should be unreachable"); + //return false; +} + +RefPtr<Type> NamedExpressionType::_createCanonicalTypeOverride() +{ + if (!innerType) + innerType = getType(m_astBuilder, declRef); + return innerType->getCanonicalType(); +} + +HashCode NamedExpressionType::_getHashCodeOverride() +{ + // Type equality is based on comparing canonical types, + // so the hash code for a type needs to come from the + // canonical version of the type. This really means + // that `Type::getHashCode()` should dispatch out to + // something like `Type::getHashCodeImpl()` on the + // canonical version of a type, but it is less invasive + // for now (and hopefully equivalent) to just have any + // named types automaticlaly route hash-code requests + // to their canonical type. + return getCanonicalType()->getHashCode(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! FuncType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String FuncType::_toStringOverride() +{ + StringBuilder sb; + sb << "("; + UInt paramCount = getParamCount(); + for (UInt pp = 0; pp < paramCount; ++pp) + { + if (pp != 0) sb << ", "; + sb << getParamType(pp)->toString(); + } + sb << ") -> "; + sb << getResultType()->toString(); + return sb.ProduceString(); +} + +bool FuncType::_equalsImplOverride(Type * type) +{ + if (auto funcType = as<FuncType>(type)) + { + auto paramCount = getParamCount(); + auto otherParamCount = funcType->getParamCount(); + if (paramCount != otherParamCount) + return false; + + for (UInt pp = 0; pp < paramCount; ++pp) + { + auto paramType = getParamType(pp); + auto otherParamType = funcType->getParamType(pp); + if (!paramType->equals(otherParamType)) + return false; + } + + if (!resultType->equals(funcType->resultType)) + return false; + + // TODO: if we ever introduce other kinds + // of qualification on function types, we'd + // want to consider it here. + return true; + } + return false; +} + +RefPtr<Val> FuncType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + // result type + RefPtr<Type> substResultType = resultType->substituteImpl(astBuilder, subst, &diff).as<Type>(); + + // parameter types + List<RefPtr<Type>> substParamTypes; + for (auto pp : paramTypes) + { + substParamTypes.add(pp->substituteImpl(astBuilder, subst, &diff).as<Type>()); + } + + // early exit for no change... + if (!diff) + return this; + + (*ioDiff)++; + RefPtr<FuncType> substType = astBuilder->create<FuncType>(); + substType->resultType = substResultType; + substType->paramTypes = substParamTypes; + return substType; +} + +RefPtr<Type> FuncType::_createCanonicalTypeOverride() +{ + // result type + RefPtr<Type> canResultType = resultType->getCanonicalType(); + + // parameter types + List<RefPtr<Type>> canParamTypes; + for (auto pp : paramTypes) + { + canParamTypes.add(pp->getCanonicalType()); + } + + RefPtr<FuncType> canType = getASTBuilder()->create<FuncType>(); + canType->resultType = resultType; + canType->paramTypes = canParamTypes; + + return canType; +} + +HashCode FuncType::_getHashCodeOverride() +{ + HashCode hashCode = getResultType()->getHashCode(); + UInt paramCount = getParamCount(); + hashCode = combineHash(hashCode, Slang::getHashCode(paramCount)); + for (UInt pp = 0; pp < paramCount; ++pp) + { + hashCode = combineHash( + hashCode, + getParamType(pp)->getHashCode()); + } + return hashCode; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String ExtractExistentialType::_toStringOverride() +{ + String result; + result.append(declRef.toString()); + result.append(".This"); + return result; +} + +bool ExtractExistentialType::_equalsImplOverride(Type* type) +{ + if (auto extractExistential = as<ExtractExistentialType>(type)) + { + return declRef.equals(extractExistential->declRef); + } + return false; +} + +HashCode ExtractExistentialType::_getHashCodeOverride() +{ + return declRef.getHashCode(); +} + +RefPtr<Type> ExtractExistentialType::_createCanonicalTypeOverride() +{ + return this; +} + +RefPtr<Val> ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<ExtractExistentialType> substValue = astBuilder->create<ExtractExistentialType>(); + substValue->declRef = declRef; + return substValue; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String TaggedUnionType::_toStringOverride() +{ + String result; + result.append("__TaggedUnion("); + bool first = true; + for (auto caseType : caseTypes) + { + if (!first) result.append(", "); + first = false; + + result.append(caseType->toString()); + } + result.append(")"); + return result; +} + +bool TaggedUnionType::_equalsImplOverride(Type* type) +{ + auto taggedUnion = as<TaggedUnionType>(type); + if (!taggedUnion) + return false; + + auto caseCount = caseTypes.getCount(); + if (caseCount != taggedUnion->caseTypes.getCount()) + return false; + + for (Index ii = 0; ii < caseCount; ++ii) + { + if (!caseTypes[ii]->equals(taggedUnion->caseTypes[ii])) + return false; + } + return true; +} + +HashCode TaggedUnionType::_getHashCodeOverride() +{ + HashCode hashCode = 0; + for (auto caseType : caseTypes) + { + hashCode = combineHash(hashCode, caseType->getHashCode()); + } + return hashCode; +} + +RefPtr<Type> TaggedUnionType::_createCanonicalTypeOverride() +{ + RefPtr<TaggedUnionType> canType = m_astBuilder->create<TaggedUnionType>(); + + for (auto caseType : caseTypes) + { + auto canCaseType = caseType->getCanonicalType(); + canType->caseTypes.add(canCaseType); + } + + return canType; +} + +RefPtr<Val> TaggedUnionType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + List<RefPtr<Type>> substCaseTypes; + for (auto caseType : caseTypes) + { + substCaseTypes.add(caseType->substituteImpl(astBuilder, subst, &diff).as<Type>()); + } + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<TaggedUnionType> substType = astBuilder->create<TaggedUnionType>(); + substType->caseTypes.swapWith(substCaseTypes); + return substType; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExistentialSpecializedType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String ExistentialSpecializedType::_toStringOverride() +{ + String result; + result.append("__ExistentialSpecializedType("); + result.append(baseType->toString()); + for (auto arg : args) + { + result.append(", "); + result.append(arg.val->toString()); + } + result.append(")"); + return result; +} + +bool ExistentialSpecializedType::_equalsImplOverride(Type * type) +{ + auto other = as<ExistentialSpecializedType>(type); + if (!other) + return false; + + if (!baseType->equals(other->baseType)) + return false; + + auto argCount = args.getCount(); + if (argCount != other->args.getCount()) + return false; + + for (Index ii = 0; ii < argCount; ++ii) + { + auto arg = args[ii]; + auto otherArg = other->args[ii]; + + if (!arg.val->equalsVal(otherArg.val)) + return false; + + if (!areValsEqual(arg.witness, otherArg.witness)) + return false; + } + return true; +} + +HashCode ExistentialSpecializedType::_getHashCodeOverride() +{ + Hasher hasher; + hasher.hashObject(baseType); + for (auto arg : args) + { + hasher.hashObject(arg.val); + if (auto witness = arg.witness) + hasher.hashObject(witness); + } + return hasher.getResult(); +} + +static RefPtr<Val> _getCanonicalValue(Val* val) +{ + if (!val) + return nullptr; + if (auto type = as<Type>(val)) + { + return type->getCanonicalType(); + } + // TODO: We may eventually need/want some sort of canonicalization + // for non-type values, but for now there is nothing to do. + return val; +} + +RefPtr<Type> ExistentialSpecializedType::_createCanonicalTypeOverride() +{ + RefPtr<ExistentialSpecializedType> canType = m_astBuilder->create<ExistentialSpecializedType>(); + + canType->baseType = baseType->getCanonicalType(); + for (auto arg : args) + { + ExpandedSpecializationArg canArg; + canArg.val = _getCanonicalValue(arg.val); + canArg.witness = _getCanonicalValue(arg.witness); + canType->args.add(canArg); + } + return canType; +} + +static RefPtr<Val> _substituteImpl(ASTBuilder* astBuilder, Val* val, SubstitutionSet subst, int* ioDiff) +{ + if (!val) return nullptr; + return val->substituteImpl(astBuilder, subst, ioDiff); +} + +RefPtr<Val> ExistentialSpecializedType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substBaseType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); + + ExpandedSpecializationArgs substArgs; + for (auto arg : args) + { + ExpandedSpecializationArg substArg; + substArg.val = _substituteImpl(astBuilder, arg.val, subst, &diff); + substArg.witness = _substituteImpl(astBuilder, arg.witness, subst, &diff); + substArgs.add(substArg); + } + + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<ExistentialSpecializedType> substType = astBuilder->create<ExistentialSpecializedType>(); + substType->baseType = substBaseType; + substType->args = substArgs; + return substType; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ThisType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +String ThisType::_toStringOverride() +{ + String result; + result.append(interfaceDeclRef.toString()); + result.append(".This"); + return result; +} + +bool ThisType::_equalsImplOverride(Type * type) +{ + auto other = as<ThisType>(type); + if (!other) + return false; + + if (!interfaceDeclRef.equals(other->interfaceDeclRef)) + return false; + + return true; +} + +HashCode ThisType::_getHashCodeOverride() +{ + return combineHash( + HashCode(typeid(*this).hash_code()), + interfaceDeclRef.getHashCode()); +} + +RefPtr<Type> ThisType::_createCanonicalTypeOverride() +{ + RefPtr<ThisType> canType = m_astBuilder->create<ThisType>(); + + // TODO: need to canonicalize the decl-ref + canType->interfaceDeclRef = interfaceDeclRef; + return canType; +} + +RefPtr<Val> ThisType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); + + auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl()); + if (thisTypeSubst) + { + return thisTypeSubst->witness->sub; + } + + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<ThisType> substType = m_astBuilder->create<ThisType>(); + substType->interfaceDeclRef = substInterfaceDeclRef; + return substType; +} + + +} // namespace Slang diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 68935108c..14028eff6 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -13,12 +13,11 @@ class OverloadGroupType : public Type { SLANG_CLASS(OverloadGroupType) - virtual String toString() override; - -protected: - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); }; // The type of an initializer-list expression (before it has @@ -27,12 +26,12 @@ class InitializerListType : public Type { SLANG_CLASS(InitializerListType) - virtual String toString() override; - -protected: - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; + + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); }; // The type of an expression that was erroneous @@ -40,13 +39,12 @@ class ErrorType : public Type { SLANG_CLASS(ErrorType) - virtual String toString() override; - -protected: - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - virtual HashCode getHashCode() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A type that takes the form of a reference to some declaration @@ -56,19 +54,20 @@ class DeclRefType : public Type DeclRef<Decl> declRef; - virtual String toString() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - + static RefPtr<DeclRefType> create(ASTBuilder* astBuilder, DeclRef<Decl> declRef); + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + protected: DeclRefType( DeclRef<Decl> declRef) : declRef(declRef) {} - - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; }; // Base class for types that can be used in arithmetic expressions @@ -76,8 +75,10 @@ class ArithmeticExpressionType : public DeclRefType { SLANG_ABSTRACT_CLASS(ArithmeticExpressionType) -public: - virtual BasicExpressionType* GetScalarType() = 0; + BasicExpressionType* getScalarType(); + + // Overrides should be public so base classes can access + BasicExpressionType* _getScalarTypeOverride(); }; class BasicExpressionType : public ArithmeticExpressionType @@ -86,16 +87,16 @@ class BasicExpressionType : public ArithmeticExpressionType BaseType baseType; + // Overrides should be public so base classes can access + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + BasicExpressionType* _getScalarTypeOverride(); + protected: BasicExpressionType( Slang::BaseType baseType) : baseType(baseType) {} - - virtual BasicExpressionType* GetScalarType() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - }; // Base type for things that are built in to the compiler, @@ -375,13 +376,12 @@ class ArrayExpressionType : public Type RefPtr<Type> baseType; RefPtr<IntVal> arrayLength; - virtual String toString() override; - -protected: - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - virtual HashCode getHashCode() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + HashCode _getHashCodeOverride(); }; // The "type" of an expression that resolves to a type. @@ -394,16 +394,18 @@ class TypeType : public Type // The type that this is the type of... RefPtr<Type> type; - virtual String toString() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); protected: TypeType(RefPtr<Type> type) : type(type) {} - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; + }; // A vector type, e.g., `vector<T,N>` @@ -418,10 +420,9 @@ class VectorExpressionType : public ArithmeticExpressionType // The number of elements RefPtr<IntVal> elementCount; - virtual String toString() override; - -protected: - virtual BasicExpressionType* GetScalarType() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + BasicExpressionType* _getScalarTypeOverride(); }; // A matrix type, e.g., `matrix<T,R,C>` @@ -435,10 +436,9 @@ class MatrixExpressionType : public ArithmeticExpressionType RefPtr<Type> getRowType(); - virtual String toString() override; - -protected: - virtual BasicExpressionType* GetScalarType() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + BasicExpressionType* _getScalarTypeOverride(); private: RefPtr<Type> rowType; @@ -465,7 +465,7 @@ class PtrTypeBase : public BuiltinType SLANG_CLASS(PtrTypeBase) // Get the type of the pointed-to value. - Type* getValueType(); + Type* getValueType(); }; // A true (user-visible) pointer type, e.g., `T*` @@ -508,7 +508,11 @@ class NamedExpressionType : public Type DeclRef<TypeDefDecl> declRef; RefPtr<Type> innerType; - virtual String toString() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); protected: NamedExpressionType( @@ -516,9 +520,7 @@ protected: : declRef(declRef) {} - virtual RefPtr<Type> createCanonicalType() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; + }; // A function type is defined by its parameter types @@ -540,13 +542,12 @@ class FuncType : public Type Type* getParamType(UInt index) { return paramTypes[index]; } Type* getResultType() { return resultType; } - virtual String toString() override; - -protected: - virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); }; // The "type" of an expression that names a generic declaration. @@ -558,17 +559,17 @@ class GenericDeclRefType : public Type DeclRef<GenericDecl> const& getDeclRef() const { return declRef; } - virtual String toString() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); protected: GenericDeclRefType( DeclRef<GenericDecl> declRef) : declRef(declRef) {} - - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; }; // The "type" of a reference to a module or namespace @@ -580,12 +581,11 @@ class NamespaceType : public Type DeclRef<NamespaceDeclBase> const& getDeclRef() const { return declRef; } - virtual String toString() override; - -protected: - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); }; // The concrete type for a value wrapped in an existential, accessible @@ -596,11 +596,12 @@ class ExtractExistentialType : public Type DeclRef<VarDeclBase> declRef; - virtual String toString() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; /// A tagged union of zero or more other types. @@ -615,11 +616,12 @@ class TaggedUnionType : public Type /// List<RefPtr<Type>> caseTypes; - virtual String toString() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; class ExistentialSpecializedType : public Type @@ -629,11 +631,12 @@ class ExistentialSpecializedType : public Type RefPtr<Type> baseType; ExpandedSpecializationArgs args; - virtual String toString() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; /// The type of `this` within a polymorphic declaration @@ -643,11 +646,12 @@ class ThisType : public Type DeclRef<InterfaceDecl> interfaceDeclRef; - virtual String toString() override; - virtual bool equalsImpl(Type* type) override; - virtual HashCode getHashCode() override; - virtual RefPtr<Type> createCanonicalType() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + String _toStringOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + RefPtr<Type> _createCanonicalTypeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; } // namespace Slang diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp new file mode 100644 index 000000000..cb711a653 --- /dev/null +++ b/source/slang/slang-ast-val.cpp @@ -0,0 +1,552 @@ +// slang-ast-type.cpp +#include "slang-ast-builder.h" +#include <assert.h> +#include <typeinfo> + +#include "slang-ast-generated-macro.h" + +#include "slang-syntax.h" + +namespace Slang { + +RefPtr<Val> Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) +{ + if (!subst) return this; + int diff = 0; + return substituteImpl(astBuilder, subst, &diff); +} + +RefPtr<Val> Val::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, substituteImpl, (astBuilder, subst, ioDiff)) +} + +bool Val::equalsVal(Val* val) +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, equalsVal, (val)) +} + +String Val::toString() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, toString, ()) +} + +HashCode Val::getHashCode() +{ + SLANG_AST_NODE_VIRTUAL_CALL(Val, getHashCode, ()) +} + +RefPtr<Val> Val::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(subst); + SLANG_UNUSED(ioDiff); + // Default behavior is to not substitute at all + return this; +} + +bool Val::_equalsValOverride(Val* val) +{ + SLANG_UNUSED(val); + SLANG_UNEXPECTED("Val::_equalsValOverride not overridden"); + //return false; +} + +String Val::_toStringOverride() +{ + SLANG_UNEXPECTED("Val::_toStringOverride not overridden"); + //return String(); +} + +HashCode Val::_getHashCodeOverride() +{ + SLANG_UNEXPECTED("Val::_getHashCodeOverride not overridden"); + //return HashCode(0); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ConstantIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool ConstantIntVal::_equalsValOverride(Val* val) +{ + if (auto intVal = as<ConstantIntVal>(val)) + return value == intVal->value; + return false; +} + +String ConstantIntVal::_toStringOverride() +{ + return String(value); +} + +HashCode ConstantIntVal::_getHashCodeOverride() +{ + return (HashCode)value; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! GenericParamIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool GenericParamIntVal::_equalsValOverride(Val* val) +{ + if (auto genericParamVal = as<GenericParamIntVal>(val)) + { + return declRef.equals(genericParamVal->declRef); + } + return false; +} + +String GenericParamIntVal::_toStringOverride() +{ + return getText(declRef.getName()); +} + +HashCode GenericParamIntVal::_getHashCodeOverride() +{ + return declRef.getHashCode() ^ HashCode(0xFFFF); +} + +RefPtr<Val> GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) +{ + // search for a substitution that might apply to us + for (auto s = subst.substitutions; s; s = s->outer) + { + auto genSubst = s.as<GenericSubstitution>(); + if (!genSubst) + continue; + + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genSubst->genericDecl; + if (genericDecl != declRef.getDecl()->parentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->members) + { + if (m.Ptr() == declRef.getDecl()) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genSubst->args[index]; + } + else if (auto typeParam = as<GenericTypeParamDecl>(m)) + { + index++; + } + else if (auto valParam = as<GenericValueParamDecl>(m)) + { + index++; + } + else + { + } + } + } + + // Nothing found: don't substitute. + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool ErrorIntVal::_equalsValOverride(Val* val) +{ + if (auto errorIntVal = as<ErrorIntVal>(val)) + { + return true; + } + return false; +} + +String ErrorIntVal::_toStringOverride() +{ + return "<error>"; +} + +HashCode ErrorIntVal::_getHashCodeOverride() +{ + return HashCode(typeid(this).hash_code()); +} + +RefPtr<Val> ErrorIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + SLANG_UNUSED(astBuilder); + SLANG_UNUSED(subst); + SLANG_UNUSED(ioDiff); + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +// TODO: should really have a `type.cpp` and a `witness.cpp` + +bool TypeEqualityWitness::_equalsValOverride(Val* val) +{ + auto otherWitness = as<TypeEqualityWitness>(val); + if (!otherWitness) + return false; + return sub->equals(otherWitness->sub); +} + +RefPtr<Val> TypeEqualityWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + RefPtr<TypeEqualityWitness> rs = astBuilder->create<TypeEqualityWitness>(); + rs->sub = sub->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); + rs->sup = sup->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); + return rs; +} + +String TypeEqualityWitness::_toStringOverride() +{ + return "TypeEqualityWitness(" + sub->toString() + ")"; +} + +HashCode TypeEqualityWitness::_getHashCodeOverride() +{ + return sub->getHashCode(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclaredSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool DeclaredSubtypeWitness::_equalsValOverride(Val* val) +{ + auto otherWitness = as<DeclaredSubtypeWitness>(val); + if (!otherWitness) + return false; + + return sub->equals(otherWitness->sub) + && sup->equals(otherWitness->sup) + && declRef.equals(otherWitness->declRef); +} + +RefPtr<Val> DeclaredSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) + { + auto genConstraintDecl = genConstraintDeclRef.getDecl(); + + // search for a substitution that might apply to us + for (auto s = subst.substitutions; s; s = s->outer) + { + if (auto genericSubst = as<GenericSubstitution>(s)) + { + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genConstraintDecl->parentDecl) + continue; + + bool found = false; + Index index = 0; + for (auto m : genericDecl->members) + { + if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) + { + if (constraintParam == declRef.getDecl()) + { + found = true; + break; + } + index++; + } + } + if (found) + { + (*ioDiff)++; + auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + + genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); + SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount()); + return genericSubst->args[index + ordinaryParamCount]; + } + } + else if (auto globalGenericSubst = s.as<GlobalGenericParamSubstitution>()) + { + // check if the substitution is really about this global generic type parameter + if (globalGenericSubst->paramDecl != genConstraintDecl->parentDecl) + continue; + + for (auto constraintArg : globalGenericSubst->constraintArgs) + { + if (constraintArg.decl.Ptr() != genConstraintDecl) + continue; + + (*ioDiff)++; + return constraintArg.val; + } + } + } + } + + // Perform substitution on the constituent elements. + int diff = 0; + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + if (!diff) + return this; + + (*ioDiff)++; + + // If we have a reference to a type constraint for an + // associated type declaration, then we can replace it + // with the concrete conformance witness for a concrete + // type implementing the outer interface. + // + // TODO: It is a bit gross that we use `GenericTypeConstraintDecl` for + // associated types, when they aren't really generic type *parameters*, + // so we'll need to change this location in the code if we ever clean + // up the hierarchy. + // + if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.decl)) + { + if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl)) + { + if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) + { + // At this point we have a constraint decl for an associated type, + // and we nee to see if we are dealing with a concrete substitution + // for the interface around that associated type. + if (auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substTypeConstraintDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); + switch (requirementWitness.getFlavor()) + { + default: + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + } + } + } + } + } + + RefPtr<DeclaredSubtypeWitness> rs = astBuilder->create<DeclaredSubtypeWitness>(); + rs->sub = substSub; + rs->sup = substSup; + rs->declRef = substDeclRef; + return rs; +} + +String DeclaredSubtypeWitness::_toStringOverride() +{ + StringBuilder sb; + sb << "DeclaredSubtypeWitness("; + sb << this->sub->toString(); + sb << ", "; + sb << this->sup->toString(); + sb << ", "; + sb << this->declRef.toString(); + sb << ")"; + return sb.ProduceString(); +} + +HashCode DeclaredSubtypeWitness::_getHashCodeOverride() +{ + return declRef.getHashCode(); +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TransitiveSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool TransitiveSubtypeWitness::_equalsValOverride(Val* val) +{ + auto otherWitness = as<TransitiveSubtypeWitness>(val); + if (!otherWitness) + return false; + + return sub->equals(otherWitness->sub) + && sup->equals(otherWitness->sup) + && subToMid->equalsVal(otherWitness->subToMid) + && midToSup.equals(otherWitness->midToSup); +} + +RefPtr<Val> TransitiveSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) +{ + int diff = 0; + + RefPtr<Type> substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + RefPtr<Type> substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + RefPtr<SubtypeWitness> substSubToMid = subToMid->substituteImpl(astBuilder, subst, &diff).as<SubtypeWitness>(); + DeclRef<Decl> substMidToSup = midToSup.substituteImpl(astBuilder, subst, &diff); + + // If nothing changed, then we can bail out early. + if (!diff) + return this; + + // Something changes, so let the caller know. + (*ioDiff)++; + + // TODO: are there cases where we can simplify? + // + // In principle, if either `subToMid` or `midToSub` turns into + // a reflexive subtype witness, then we could drop that side, + // and just return the other one (this would imply that `sub == mid` + // or `mid == sup` after substitutions). + // + // In the long run, is it also possible that if `sub` gets resolved + // to a concrete type *and* we decide to flatten out the inheritance + // graph into a linearized "class precedence list" stored in any + // aggregate type, then we could potentially just redirect to point + // to the appropriate inheritance decl in the original type. + // + // For now I'm going to ignore those possibilities and hope for the best. + + // In the simple case, we just construct a new transitive subtype + // witness, and we move on with life. + RefPtr<TransitiveSubtypeWitness> result = astBuilder->create<TransitiveSubtypeWitness>(); + result->sub = substSub; + result->sup = substSup; + result->subToMid = substSubToMid; + result->midToSup = substMidToSup; + return result; +} + +String TransitiveSubtypeWitness::_toStringOverride() +{ + // Note: we only print the constituent + // witnesses, and rely on them to print + // the starting and ending types. + StringBuilder sb; + sb << "TransitiveSubtypeWitness("; + sb << this->subToMid->toString(); + sb << ", "; + sb << this->midToSup.toString(); + sb << ")"; + return sb.ProduceString(); +} + +HashCode TransitiveSubtypeWitness::_getHashCodeOverride() +{ + auto hash = sub->getHashCode(); + hash = combineHash(hash, sup->getHashCode()); + hash = combineHash(hash, subToMid->getHashCode()); + hash = combineHash(hash, midToSup.getHashCode()); + return hash; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ExtractExistentialSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool ExtractExistentialSubtypeWitness::_equalsValOverride(Val* val) +{ + if (auto extractWitness = as<ExtractExistentialSubtypeWitness>(val)) + { + return declRef.equals(extractWitness->declRef); + } + return false; +} + +String ExtractExistentialSubtypeWitness::_toStringOverride() +{ + String result; + result.append("extractExistentialValue("); + result.append(declRef.toString()); + result.append(")"); + return result; +} + +HashCode ExtractExistentialSubtypeWitness::_getHashCodeOverride() +{ + return declRef.getHashCode(); +} + +RefPtr<Val> ExtractExistentialSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<ExtractExistentialSubtypeWitness> substValue = astBuilder->create<ExtractExistentialSubtypeWitness>(); + substValue->declRef = declRef; + substValue->sub = substSub; + substValue->sup = substSup; + return substValue; +} + + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TaggedUnionSubtypeWitness !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool TaggedUnionSubtypeWitness::_equalsValOverride(Val* val) +{ + auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val); + if (!taggedUnionWitness) + return false; + + auto caseCount = caseWitnesses.getCount(); + if (caseCount != taggedUnionWitness->caseWitnesses.getCount()) + return false; + + for (Index ii = 0; ii < caseCount; ++ii) + { + if (!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii])) + return false; + } + + return true; +} + +String TaggedUnionSubtypeWitness::_toStringOverride() +{ + String result; + result.append("TaggedUnionSubtypeWitness("); + bool first = true; + for (auto caseWitness : caseWitnesses) + { + if (!first) result.append(", "); + first = false; + + result.append(caseWitness->toString()); + } + return result; +} + +HashCode TaggedUnionSubtypeWitness::_getHashCodeOverride() +{ + HashCode hash = 0; + for (auto caseWitness : caseWitnesses) + { + hash = combineHash(hash, caseWitness->getHashCode()); + } + return hash; +} + +RefPtr<Val> TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); + auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); + + List<RefPtr<Val>> substCaseWitnesses; + for (auto caseWitness : caseWitnesses) + { + substCaseWitnesses.add(caseWitness->substituteImpl(astBuilder, subst, &diff)); + } + + if (!diff) + return this; + + (*ioDiff)++; + + RefPtr<TaggedUnionSubtypeWitness> substWitness = astBuilder->create<TaggedUnionSubtypeWitness>(); + substWitness->sub = substSub; + substWitness->sup = substSup; + substWitness->caseWitnesses.swapWith(substCaseWitnesses); + return substWitness; +} + + + +} // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 5345a389f..91c20e3b2 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -21,14 +21,16 @@ class ConstantIntVal : public IntVal IntegerLiteralValue value; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + protected: ConstantIntVal(IntegerLiteralValue inValue) : value(inValue) {} - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; }; // The logical "value" of a reference to a generic value parameter @@ -38,15 +40,16 @@ class GenericParamIntVal : public IntVal DeclRef<VarDeclBase> declRef; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + protected: GenericParamIntVal(DeclRef<VarDeclBase> inDeclRef) : declRef(inDeclRef) {} - - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; }; /// An unknown integer value indicating an erroneous sub-expression @@ -58,10 +61,11 @@ class ErrorIntVal : public IntVal // and have all `Val`s that represent ordinary values hold their // `Type` so that we can have an `ErrorVal` of any type. - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A witness to the fact that some proposition is true, encoded @@ -119,10 +123,11 @@ class TypeEqualityWitness : public SubtypeWitness { SLANG_CLASS(TypeEqualityWitness) - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A witness that one type is a subtype of another @@ -133,10 +138,11 @@ class DeclaredSubtypeWitness : public SubtypeWitness DeclRef<Decl> declRef; - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A witness that `sub : sup` because `sub : mid` and `mid : sup` @@ -150,10 +156,11 @@ class TransitiveSubtypeWitness : public SubtypeWitness // Witness that `mid : sup` DeclRef<Decl> midToSup; - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A witness taht `sub : sup` because `sub` was wrapped into @@ -165,10 +172,11 @@ class ExtractExistentialSubtypeWitness : public SubtypeWitness // The declaration of the existential value that has been opened DeclRef<VarDeclBase> declRef; - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; // A witness that `sub : sup`, because `sub` is a tagged union @@ -184,10 +192,12 @@ class TaggedUnionSubtypeWitness : public SubtypeWitness // List<RefPtr<Val>> caseWitnesses; - virtual bool equalsVal(Val* val) override; - virtual String toString() override; - virtual HashCode getHashCode() override; - virtual RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) override; + + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + String _toStringOverride(); + HashCode _getHashCodeOverride(); + RefPtr<Val> _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; } // namespace Slang diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 1b96aec98..c78de91a5 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -51,19 +51,6 @@ SourceLoc const& getDiagnosticPos(TypeExp const& typeExp) return typeExp.exp->loc; } -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! BasicExpressionType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -bool BasicExpressionType::equalsImpl(Type * type) -{ - auto basicType = as<BasicExpressionType>(type); - return basicType && basicType->baseType == this->baseType; -} - -RefPtr<Type> BasicExpressionType::createCanonicalType() -{ - // A basic type is already canonical, in our setup - return this; -} // !!!!!!!!!!!!!!!!!!!!!!!!!!!!! Free functions !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -213,149 +200,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return type->equals(other); } - // BasicExpressionType - - BasicExpressionType* BasicExpressionType::GetScalarType() - { - return this; - } - - // - - Type::~Type() - { - // If the canonicalType !=nullptr AND it is not set to this (ie the canonicalType is another object) - // then it needs to be released because it's owned by this object. - if (canonicalType && canonicalType != this) - { - canonicalType->releaseReference(); - } - } - - bool Type::equals(Type* type) - { - return getCanonicalType()->equalsImpl(type->getCanonicalType()); - } - - bool Type::equalsVal(Val* val) - { - if (auto type = dynamicCast<Type>(val)) - return const_cast<Type*>(this)->equals(type); - return false; - } - - RefPtr<Val> Type::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto canSubst = getCanonicalType()->substituteImpl(astBuilder, subst, &diff); - - // If nothing changed, then don't drop any sugar that is applied - if (!diff) - return this; - - // If the canonical type changed, then we return a canonical type, - // rather than try to re-construct any amount of sugar - (*ioDiff)++; - return canSubst; - } - - Type* Type::getCanonicalType() - { - Type* et = const_cast<Type*>(this); - if (!et->canonicalType) - { - // TODO(tfoley): worry about thread safety here? - auto canType = et->createCanonicalType(); - et->canonicalType = canType; - - // TODO(js): That this detachs when canType == this is a little surprising. It would seem - // as if this would create a circular reference on the object, but in practice there are - // no leaks so appears correct. - // That the dtor only releases if != this, also makes it surprising. - canType.detach(); - - SLANG_ASSERT(et->canonicalType); - } - return et->canonicalType; - } - - bool ArrayExpressionType::equalsImpl(Type* type) - { - auto arrType = as<ArrayExpressionType>(type); - if (!arrType) - return false; - return (areValsEqual(arrayLength, arrType->arrayLength) && baseType->equals(arrType->baseType.Ptr())); - } - - RefPtr<Val> ArrayExpressionType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto elementType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); - auto arrlen = arrayLength->substituteImpl(astBuilder, subst, &diff).as<IntVal>(); - SLANG_ASSERT(arrlen); - if (diff) - { - *ioDiff = 1; - auto rsType = getArrayType( - astBuilder, - elementType, - arrlen); - return rsType; - } - return this; - } - - RefPtr<Type> ArrayExpressionType::createCanonicalType() - { - auto canonicalElementType = baseType->getCanonicalType(); - auto canonicalArrayType = getASTBuilder()->getArrayType( - canonicalElementType, - arrayLength); - return canonicalArrayType; - } - - HashCode ArrayExpressionType::getHashCode() - { - if (arrayLength) - return (baseType->getHashCode() * 16777619) ^ arrayLength->getHashCode(); - else - return baseType->getHashCode(); - } - Slang::String ArrayExpressionType::toString() - { - if (arrayLength) - return baseType->toString() + "[" + arrayLength->toString() + "]"; - else - return baseType->toString() + "[]"; - } - - // DeclRefType - - String DeclRefType::toString() - { - return declRef.toString(); - } - - HashCode DeclRefType::getHashCode() - { - return (declRef.getHashCode() * 16777619) ^ (HashCode)(typeid(this).hash_code()); - } - - bool DeclRefType::equalsImpl(Type * type) - { - if (auto declRefType = as<DeclRefType>(type)) - { - return declRef.equals(declRefType->declRef); - } - return false; - } - - RefPtr<Type> DeclRefType::createCanonicalType() - { - // A declaration reference is already canonical - return this; - } - // // RequirementWitness // @@ -478,117 +322,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return RequirementWitness(); } - RefPtr<Val> DeclRefType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - if (!subst) return this; - - // the case we especially care about is when this type references a declaration - // of a generic parameter, since that is what we might be substituting... - if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl())) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genericSubst = s.as<GenericSubstitution>(); - if(!genericSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->genericDecl; - if (genericDecl != genericTypeParamDecl->parentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->members) - { - if (m.Ptr() == genericTypeParamDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genericSubst->args[index]; - } - else if (auto typeParam = as<GenericTypeParamDecl>(m)) - { - index++; - } - else if (auto valParam = as<GenericValueParamDecl>(m)) - { - index++; - } - else - { - } - } - } - } - else if (auto globalGenParam = as<GlobalGenericParamDecl>(declRef.getDecl())) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genericSubst = as<GlobalGenericParamSubstitution>(s); - if(!genericSubst) - continue; - - if (genericSubst->paramDecl == globalGenParam) - { - (*ioDiff)++; - return genericSubst->actualType; - } - } - } - int diff = 0; - DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - - if (!diff) - return this; - - // Make sure to record the difference! - *ioDiff += diff; - - // If this type is a reference to an associated type declaration, - // and the substitutions provide a "this type" substitution for - // the outer interface, then try to replace the type with the - // actual value of the associated type for the given implementation. - // - if(auto substAssocTypeDecl = as<AssocTypeDecl>(substDeclRef.decl)) - { - for(auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) - { - auto thisSubst = s.as<ThisTypeSubstitution>(); - if(!thisSubst) - continue; - - if(auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) - { - if(thisSubst->interfaceDecl == interfaceDecl) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey); - switch(requirementWitness.getFlavor()) - { - default: - // No usable value was found, so there is nothing we can do. - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - break; - } - } - } - } - } - - // Re-construct the type in case we are using a specialized sub-class - return DeclRefType::create(astBuilder, substDeclRef); - } + static RefPtr<Type> ExtractGenericArgType(RefPtr<Val> val) { @@ -848,318 +582,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } } - // OverloadGroupType - - String OverloadGroupType::toString() - { - return "overload group"; - } - - bool OverloadGroupType::equalsImpl(Type * /*type*/) - { - return false; - } - - RefPtr<Type> OverloadGroupType::createCanonicalType() - { - return this; - } - - HashCode OverloadGroupType::getHashCode() - { - return (HashCode)(size_t(this)); - } - - // InitializerListType - - String InitializerListType::toString() - { - return "initializer list"; - } - - bool InitializerListType::equalsImpl(Type * /*type*/) - { - return false; - } - - RefPtr<Type> InitializerListType::createCanonicalType() - { - return this; - } - - HashCode InitializerListType::getHashCode() - { - return (HashCode)(size_t(this)); - } - - // ErrorType - - String ErrorType::toString() - { - return "error"; - } - - bool ErrorType::equalsImpl(Type* type) - { - if (auto errorType = as<ErrorType>(type)) - return true; - return false; - } - - RefPtr<Type> ErrorType::createCanonicalType() - { - return this; - } - - RefPtr<Val> ErrorType::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) - { - return this; - } - - HashCode ErrorType::getHashCode() - { - return HashCode(size_t(this)); - } - - - // NamedExpressionType - - String NamedExpressionType::toString() - { - return getText(declRef.getName()); - } - - bool NamedExpressionType::equalsImpl(Type * /*type*/) - { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(false); - } - - RefPtr<Type> NamedExpressionType::createCanonicalType() - { - if (!innerType) - innerType = getType(m_astBuilder, declRef); - return innerType->getCanonicalType(); - } - - HashCode NamedExpressionType::getHashCode() - { - // Type equality is based on comparing canonical types, - // so the hash code for a type needs to come from the - // canonical version of the type. This really means - // that `Type::getHashCode()` should dispatch out to - // something like `Type::getHashCodeImpl()` on the - // canonical version of a type, but it is less invasive - // for now (and hopefully equivalent) to just have any - // named types automaticlaly route hash-code requests - // to their canonical type. - return getCanonicalType()->getHashCode(); - } - - // FuncType - - String FuncType::toString() - { - StringBuilder sb; - sb << "("; - UInt paramCount = getParamCount(); - for (UInt pp = 0; pp < paramCount; ++pp) - { - if (pp != 0) sb << ", "; - sb << getParamType(pp)->toString(); - } - sb << ") -> "; - sb << getResultType()->toString(); - return sb.ProduceString(); - } - - bool FuncType::equalsImpl(Type * type) - { - if (auto funcType = as<FuncType>(type)) - { - auto paramCount = getParamCount(); - auto otherParamCount = funcType->getParamCount(); - if (paramCount != otherParamCount) - return false; - - for (UInt pp = 0; pp < paramCount; ++pp) - { - auto paramType = getParamType(pp); - auto otherParamType = funcType->getParamType(pp); - if (!paramType->equals(otherParamType)) - return false; - } - - if(!resultType->equals(funcType->resultType)) - return false; - - // TODO: if we ever introduce other kinds - // of qualification on function types, we'd - // want to consider it here. - return true; - } - return false; - } - - RefPtr<Val> FuncType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - // result type - RefPtr<Type> substResultType = resultType->substituteImpl(astBuilder, subst, &diff).as<Type>(); - - // parameter types - List<RefPtr<Type>> substParamTypes; - for( auto pp : paramTypes ) - { - substParamTypes.add(pp->substituteImpl(astBuilder, subst, &diff).as<Type>()); - } - - // early exit for no change... - if(!diff) - return this; - - (*ioDiff)++; - RefPtr<FuncType> substType = astBuilder->create<FuncType>(); - substType->resultType = substResultType; - substType->paramTypes = substParamTypes; - return substType; - } - - RefPtr<Type> FuncType::createCanonicalType() - { - // result type - RefPtr<Type> canResultType = resultType->getCanonicalType(); - - // parameter types - List<RefPtr<Type>> canParamTypes; - for( auto pp : paramTypes ) - { - canParamTypes.add(pp->getCanonicalType()); - } - - RefPtr<FuncType> canType = getASTBuilder()->create<FuncType>(); - canType->resultType = resultType; - canType->paramTypes = canParamTypes; - - return canType; - } - - HashCode FuncType::getHashCode() - { - HashCode hashCode = getResultType()->getHashCode(); - UInt paramCount = getParamCount(); - hashCode = combineHash(hashCode, Slang::getHashCode(paramCount)); - for (UInt pp = 0; pp < paramCount; ++pp) - { - hashCode = combineHash( - hashCode, - getParamType(pp)->getHashCode()); - } - return hashCode; - } - - // TypeType - - String TypeType::toString() - { - StringBuilder sb; - sb << "typeof(" << type->toString() << ")"; - return sb.ProduceString(); - } - - bool TypeType::equalsImpl(Type * t) - { - if (auto typeType = as<TypeType>(t)) - { - return t->equals(typeType->type); - } - return false; - } - - RefPtr<Type> TypeType::createCanonicalType() - { - return getASTBuilder()->getTypeType(type->getCanonicalType()); - } - - HashCode TypeType::getHashCode() - { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(0); - } - - // GenericDeclRefType - - String GenericDeclRefType::toString() - { - // TODO: what is appropriate here? - return "<DeclRef<GenericDecl>>"; - } - - bool GenericDeclRefType::equalsImpl(Type * type) - { - if (auto genericDeclRefType = as<GenericDeclRefType>(type)) - { - return declRef.equals(genericDeclRefType->declRef); - } - return false; - } - - HashCode GenericDeclRefType::getHashCode() - { - return declRef.getHashCode(); - } - - RefPtr<Type> GenericDeclRefType::createCanonicalType() - { - return this; - } - - // NamespaceType - - String NamespaceType::toString() - { - String result; - result.append("namespace "); - result.append(declRef.toString()); - return result; - } - - bool NamespaceType::equalsImpl(Type * type) - { - if (auto namespaceType = as<NamespaceType>(type)) - { - return declRef.equals(namespaceType->declRef); - } - return false; - } - - HashCode NamespaceType::getHashCode() - { - return declRef.getHashCode(); - } - - RefPtr<Type> NamespaceType::createCanonicalType() - { - return this; - } - - // ArithmeticExpressionType - - // VectorExpressionType - - String VectorExpressionType::toString() - { - StringBuilder sb; - sb << "vector<" << elementType->toString() << "," << elementCount->toString() << ">"; - return sb.ProduceString(); - } - - BasicExpressionType* VectorExpressionType::GetScalarType() - { - return as<BasicExpressionType>(elementType); - } - // RefPtr<GenericSubstitution> findInnerMostGenericSubstitution(Substitutions* subst) @@ -1172,309 +594,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - // MatrixExpressionType - - String MatrixExpressionType::toString() - { - StringBuilder sb; - sb << "matrix<" << getElementType()->toString() << "," << getRowCount()->toString() << "," << getColumnCount()->toString() << ">"; - return sb.ProduceString(); - } - - BasicExpressionType* MatrixExpressionType::GetScalarType() - { - return as<BasicExpressionType>(getElementType()); - } - - Type* MatrixExpressionType::getElementType() - { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); - } - - IntVal* MatrixExpressionType::getRowCount() - { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[1]); - } - - IntVal* MatrixExpressionType::getColumnCount() - { - return as<IntVal>(findInnerMostGenericSubstitution(declRef.substitutions)->args[2]); - } - - RefPtr<Type> MatrixExpressionType::getRowType() - { - if( !rowType ) - { - rowType = m_astBuilder->getVectorType(getElementType(), getColumnCount()); - } - return rowType; - } - - - - - // PtrTypeBase - - Type* PtrTypeBase::getValueType() - { - return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->args[0]); - } - - // GenericParamIntVal - - bool GenericParamIntVal::equalsVal(Val* val) - { - if (auto genericParamVal = as<GenericParamIntVal>(val)) - { - return declRef.equals(genericParamVal->declRef); - } - return false; - } - - String GenericParamIntVal::toString() - { - return getText(declRef.getName()); - } - - HashCode GenericParamIntVal::getHashCode() - { - return declRef.getHashCode() ^ HashCode(0xFFFF); - } - - RefPtr<Val> GenericParamIntVal::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet subst, int* ioDiff) - { - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - auto genSubst = s.as<GenericSubstitution>(); - if(!genSubst) - continue; - - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genSubst->genericDecl; - if (genericDecl != declRef.getDecl()->parentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->members) - { - if (m.Ptr() == declRef.getDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return genSubst->args[index]; - } - else if (auto typeParam = as<GenericTypeParamDecl>(m)) - { - index++; - } - else if (auto valParam = as<GenericValueParamDecl>(m)) - { - index++; - } - else - { - } - } - } - - // Nothing found: don't substitute. - return this; - } - - // ErrorIntVal - - bool ErrorIntVal::equalsVal(Val* val) - { - if( auto errorIntVal = as<ErrorIntVal>(val) ) - { - return true; - } - return false; - } - - String ErrorIntVal::toString() - { - return "<error>"; - } - - HashCode ErrorIntVal::getHashCode() - { - return HashCode(typeid(this).hash_code()); - } - - RefPtr<Val> ErrorIntVal::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - SLANG_UNUSED(astBuilder); - SLANG_UNUSED(subst); - SLANG_UNUSED(ioDiff); - return this; - } - - // Substitutions - - RefPtr<Substitutions> GenericSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) - { - int diff = 0; - - if(substOuter != outer) diff++; - - List<RefPtr<Val>> substArgs; - for (auto a : args) - { - substArgs.add(a->substituteImpl(astBuilder, substSet, &diff)); - } - - if (!diff) return this; - - (*ioDiff)++; - auto substSubst = astBuilder->create<GenericSubstitution>(); - substSubst->genericDecl = genericDecl; - substSubst->args = substArgs; - substSubst->outer = substOuter; - return substSubst; - } - - bool GenericSubstitution::equals(Substitutions* subst) - { - // both must be NULL, or non-NULL - if (subst == nullptr) - return false; - if (this == subst) - return true; - - auto genericSubst = as<GenericSubstitution>(subst); - if (!genericSubst) - return false; - if (genericDecl != genericSubst->genericDecl) - return false; - - Index argCount = args.getCount(); - SLANG_RELEASE_ASSERT(args.getCount() == genericSubst->args.getCount()); - for (Index aa = 0; aa < argCount; ++aa) - { - if (!args[aa]->equalsVal(genericSubst->args[aa].Ptr())) - return false; - } - - if (!outer) - return !genericSubst->outer; - - if (!outer->equals(genericSubst->outer.Ptr())) - return false; - - return true; - } - - RefPtr<Substitutions> ThisTypeSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) - { - int diff = 0; - - if(substOuter != outer) diff++; - - // NOTE: Must use .as because we must have a smart pointer here to keep in scope. - auto substWitness = witness->substituteImpl(astBuilder, substSet, &diff).as<SubtypeWitness>(); - - if (!diff) return this; - - (*ioDiff)++; - auto substSubst = astBuilder->create<ThisTypeSubstitution>(); - substSubst->interfaceDecl = interfaceDecl; - substSubst->witness = substWitness; - substSubst->outer = substOuter; - return substSubst; - } - - bool ThisTypeSubstitution::equals(Substitutions* subst) - { - if (!subst) - return false; - if (subst == this) - return true; - - if (auto thisTypeSubst = as<ThisTypeSubstitution>(subst)) - { - // For our purposes, two this-type substitutions are - // equivalent if they have the same type as `This`, - // even if the specific witness values they use - // might differ. - // - if(this->interfaceDecl != thisTypeSubst->interfaceDecl) - return false; - - if(!this->witness->sub->equals(thisTypeSubst->witness->sub)) - return false; - - return true; - } - return false; - } - - HashCode ThisTypeSubstitution::getHashCode() const - { - return witness->getHashCode(); - } - - RefPtr<Substitutions> GlobalGenericParamSubstitution::applySubstitutionsShallow(ASTBuilder* astBuilder, SubstitutionSet substSet, RefPtr<Substitutions> substOuter, int* ioDiff) - { - // if we find a GlobalGenericParamSubstitution in subst that references the same type_param decl - // return a copy of that GlobalGenericParamSubstitution - int diff = 0; - - if(substOuter != outer) diff++; - - auto substActualType = actualType->substituteImpl(astBuilder, substSet, &diff).as<Type>(); - - List<ConstraintArg> substConstraintArgs; - for(auto constraintArg : constraintArgs) - { - ConstraintArg substConstraintArg; - substConstraintArg.decl = constraintArg.decl; - substConstraintArg.val = constraintArg.val->substituteImpl(astBuilder, substSet, &diff); - - substConstraintArgs.add(substConstraintArg); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<GlobalGenericParamSubstitution> substSubst = astBuilder->create<GlobalGenericParamSubstitution>(); - substSubst->paramDecl = paramDecl; - substSubst->actualType = substActualType; - substSubst->constraintArgs = substConstraintArgs; - substSubst->outer = substOuter; - return substSubst; - } - - bool GlobalGenericParamSubstitution::equals(Substitutions* subst) - { - if (!subst) - return false; - if (subst == this) - return true; - - if (auto genSubst = as<GlobalGenericParamSubstitution>(subst)) - { - if (paramDecl != genSubst->paramDecl) - return false; - if (!actualType->equalsVal(genSubst->actualType)) - return false; - if (constraintArgs.getCount() != genSubst->constraintArgs.getCount()) - return false; - for (Index i = 0; i < constraintArgs.getCount(); i++) - { - if (!constraintArgs[i].val->equalsVal(genSubst->constraintArgs[i].val)) - return false; - } - return true; - } - return false; - } - - + // DeclRefBase RefPtr<Type> DeclRefBase::substitute(ASTBuilder* astBuilder, RefPtr<Type> type) const @@ -1957,21 +1077,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return combineHash(PointerHash<1>::getHashCode(decl), substitutions.getHashCode()); } - // Val - - RefPtr<Val> Val::substitute(ASTBuilder* astBuilder, SubstitutionSet subst) - { - if (!subst) return this; - int diff = 0; - return substituteImpl(astBuilder, subst, &diff); - } - - RefPtr<Val> Val::substituteImpl(ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) - { - // Default behavior is to not substitute at all - return this; - } - // IntVal IntegerLiteralValue getIntVal(RefPtr<IntVal> val) @@ -1984,27 +1089,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return 0; } - // ConstantIntVal - - bool ConstantIntVal::equalsVal(Val* val) - { - if (auto intVal = as<ConstantIntVal>(val)) - return value == intVal->value; - return false; - } - - String ConstantIntVal::toString() - { - return String(value); - } - - HashCode ConstantIntVal::getHashCode() - { - return (HashCode) value; - } - - - // // HLSLPatchType @@ -2105,45 +1189,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return astBuilder->create<SamplerStateType>(); } - // TODO: should really have a `type.cpp` and a `witness.cpp` - - bool TypeEqualityWitness::equalsVal(Val* val) - { - auto otherWitness = as<TypeEqualityWitness>(val); - if (!otherWitness) - return false; - return sub->equals(otherWitness->sub); - } - - RefPtr<Val> TypeEqualityWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) - { - RefPtr<TypeEqualityWitness> rs = astBuilder->create<TypeEqualityWitness>(); - rs->sub = sub->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); - rs->sup = sup->substituteImpl(astBuilder, subst, ioDiff).as<Type>(); - return rs; - } - - String TypeEqualityWitness::toString() - { - return "TypeEqualityWitness(" + sub->toString() + ")"; - } - - HashCode TypeEqualityWitness::getHashCode() - { - return sub->getHashCode(); - } - - bool DeclaredSubtypeWitness::equalsVal(Val* val) - { - auto otherWitness = as<DeclaredSubtypeWitness>(val); - if(!otherWitness) - return false; - - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && declRef.equals(otherWitness->declRef); - } - RefPtr<ThisTypeSubstitution> findThisTypeSubstitution( Substitutions* substs, InterfaceDecl* interfaceDecl) @@ -2163,221 +1208,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - RefPtr<Val> DeclaredSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) - { - if (auto genConstraintDeclRef = declRef.as<GenericTypeConstraintDecl>()) - { - auto genConstraintDecl = genConstraintDeclRef.getDecl(); - - // search for a substitution that might apply to us - for(auto s = subst.substitutions; s; s = s->outer) - { - if(auto genericSubst = as<GenericSubstitution>(s)) - { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = genericSubst->genericDecl; - if (genericDecl != genConstraintDecl->parentDecl) - continue; - - bool found = false; - Index index = 0; - for (auto m : genericDecl->members) - { - if (auto constraintParam = as<GenericTypeConstraintDecl>(m)) - { - if (constraintParam == declRef.getDecl()) - { - found = true; - break; - } - index++; - } - } - if (found) - { - (*ioDiff)++; - auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().getCount() + - genericDecl->getMembersOfType<GenericValueParamDecl>().getCount(); - SLANG_ASSERT(index + ordinaryParamCount < genericSubst->args.getCount()); - return genericSubst->args[index + ordinaryParamCount]; - } - } - else if(auto globalGenericSubst = s.as<GlobalGenericParamSubstitution>()) - { - // check if the substitution is really about this global generic type parameter - if (globalGenericSubst->paramDecl != genConstraintDecl->parentDecl) - continue; - - for(auto constraintArg : globalGenericSubst->constraintArgs) - { - if(constraintArg.decl.Ptr() != genConstraintDecl) - continue; - - (*ioDiff)++; - return constraintArg.val; - } - } - } - } - - // Perform substitution on the constituent elements. - int diff = 0; - auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - if (!diff) - return this; - - (*ioDiff)++; - - // If we have a reference to a type constraint for an - // associated type declaration, then we can replace it - // with the concrete conformance witness for a concrete - // type implementing the outer interface. - // - // TODO: It is a bit gross that we use `GenericTypeConstraintDecl` for - // associated types, when they aren't really generic type *parameters*, - // so we'll need to change this location in the code if we ever clean - // up the hierarchy. - // - if (auto substTypeConstraintDecl = as<GenericTypeConstraintDecl>(substDeclRef.decl)) - { - if (auto substAssocTypeDecl = as<AssocTypeDecl>(substTypeConstraintDecl->parentDecl)) - { - if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) - { - // At this point we have a constraint decl for an associated type, - // and we nee to see if we are dealing with a concrete substitution - // for the interface around that associated type. - if(auto thisTypeSubst = findThisTypeSubstitution(substDeclRef.substitutions, interfaceDecl)) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substTypeConstraintDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisTypeSubst->witness, requirementKey); - switch(requirementWitness.getFlavor()) - { - default: - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - } - } - } - } - } - - - - - RefPtr<DeclaredSubtypeWitness> rs = astBuilder->create<DeclaredSubtypeWitness>(); - rs->sub = substSub; - rs->sup = substSup; - rs->declRef = substDeclRef; - return rs; - } - - String DeclaredSubtypeWitness::toString() - { - StringBuilder sb; - sb << "DeclaredSubtypeWitness("; - sb << this->sub->toString(); - sb << ", "; - sb << this->sup->toString(); - sb << ", "; - sb << this->declRef.toString(); - sb << ")"; - return sb.ProduceString(); - } - - HashCode DeclaredSubtypeWitness::getHashCode() - { - return declRef.getHashCode(); - } - - // TransitiveSubtypeWitness - - bool TransitiveSubtypeWitness::equalsVal(Val* val) - { - auto otherWitness = as<TransitiveSubtypeWitness>(val); - if(!otherWitness) - return false; - - return sub->equals(otherWitness->sub) - && sup->equals(otherWitness->sup) - && subToMid->equalsVal(otherWitness->subToMid) - && midToSup.equals(otherWitness->midToSup); - } - - RefPtr<Val> TransitiveSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int * ioDiff) - { - int diff = 0; - - RefPtr<Type> substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); - RefPtr<Type> substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); - RefPtr<SubtypeWitness> substSubToMid = subToMid->substituteImpl(astBuilder, subst, &diff).as<SubtypeWitness>(); - DeclRef<Decl> substMidToSup = midToSup.substituteImpl(astBuilder, subst, &diff); - - // If nothing changed, then we can bail out early. - if (!diff) - return this; - - // Something changes, so let the caller know. - (*ioDiff)++; - - // TODO: are there cases where we can simplify? - // - // In principle, if either `subToMid` or `midToSub` turns into - // a reflexive subtype witness, then we could drop that side, - // and just return the other one (this would imply that `sub == mid` - // or `mid == sup` after substitutions). - // - // In the long run, is it also possible that if `sub` gets resolved - // to a concrete type *and* we decide to flatten out the inheritance - // graph into a linearized "class precedence list" stored in any - // aggregate type, then we could potentially just redirect to point - // to the appropriate inheritance decl in the original type. - // - // For now I'm going to ignore those possibilities and hope for the best. - - // In the simple case, we just construct a new transitive subtype - // witness, and we move on with life. - RefPtr<TransitiveSubtypeWitness> result = astBuilder->create<TransitiveSubtypeWitness>(); - result->sub = substSub; - result->sup = substSup; - result->subToMid = substSubToMid; - result->midToSup = substMidToSup; - return result; - } - - String TransitiveSubtypeWitness::toString() - { - // Note: we only print the constituent - // witnesses, and rely on them to print - // the starting and ending types. - StringBuilder sb; - sb << "TransitiveSubtypeWitness("; - sb << this->subToMid->toString(); - sb << ", "; - sb << this->midToSup.toString(); - sb << ")"; - return sb.ProduceString(); - } - - HashCode TransitiveSubtypeWitness::getHashCode() - { - auto hash = sub->getHashCode(); - hash = combineHash(hash, sup->getHashCode()); - hash = combineHash(hash, subToMid->getHashCode()); - hash = combineHash(hash, midToSup.getHashCode()); - return hash; - } - // String DeclRefBase::toString() const @@ -2412,248 +1242,8 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return rs; } - // ExtractExistentialType - - String ExtractExistentialType::toString() - { - String result; - result.append(declRef.toString()); - result.append(".This"); - return result; - } - - bool ExtractExistentialType::equalsImpl(Type* type) - { - if( auto extractExistential = as<ExtractExistentialType>(type) ) - { - return declRef.equals(extractExistential->declRef); - } - return false; - } - - HashCode ExtractExistentialType::getHashCode() - { - return declRef.getHashCode(); - } - - RefPtr<Type> ExtractExistentialType::createCanonicalType() - { - return this; - } - - RefPtr<Val> ExtractExistentialType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<ExtractExistentialType> substValue = astBuilder->create<ExtractExistentialType>(); - substValue->declRef = declRef; - return substValue; - } - - // ExtractExistentialSubtypeWitness - - bool ExtractExistentialSubtypeWitness::equalsVal(Val* val) - { - if( auto extractWitness = as<ExtractExistentialSubtypeWitness>(val) ) - { - return declRef.equals(extractWitness->declRef); - } - return false; - } - - String ExtractExistentialSubtypeWitness::toString() - { - String result; - result.append("extractExistentialValue("); - result.append(declRef.toString()); - result.append(")"); - return result; - } - - HashCode ExtractExistentialSubtypeWitness::getHashCode() - { - return declRef.getHashCode(); - } - - RefPtr<Val> ExtractExistentialSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); - auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<ExtractExistentialSubtypeWitness> substValue = astBuilder->create<ExtractExistentialSubtypeWitness>(); - substValue->declRef = declRef; - substValue->sub = substSub; - substValue->sup = substSup; - return substValue; - } - - // - // TaggedUnionType - // - - String TaggedUnionType::toString() - { - String result; - result.append("__TaggedUnion("); - bool first = true; - for( auto caseType : caseTypes ) - { - if(!first) result.append(", "); - first = false; - - result.append(caseType->toString()); - } - result.append(")"); - return result; - } - - bool TaggedUnionType::equalsImpl(Type* type) - { - auto taggedUnion = as<TaggedUnionType>(type); - if(!taggedUnion) - return false; - - auto caseCount = caseTypes.getCount(); - if(caseCount != taggedUnion->caseTypes.getCount()) - return false; - - for( Index ii = 0; ii < caseCount; ++ii ) - { - if(!caseTypes[ii]->equals(taggedUnion->caseTypes[ii])) - return false; - } - return true; - } - - HashCode TaggedUnionType::getHashCode() - { - HashCode hashCode = 0; - for( auto caseType : caseTypes ) - { - hashCode = combineHash(hashCode, caseType->getHashCode()); - } - return hashCode; - } - - RefPtr<Type> TaggedUnionType::createCanonicalType() - { - RefPtr<TaggedUnionType> canType = m_astBuilder->create<TaggedUnionType>(); - - for( auto caseType : caseTypes ) - { - auto canCaseType = caseType->getCanonicalType(); - canType->caseTypes.add(canCaseType); - } - - return canType; - } - - RefPtr<Val> TaggedUnionType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) - { - int diff = 0; - - List<RefPtr<Type>> substCaseTypes; - for( auto caseType : caseTypes ) - { - substCaseTypes.add(caseType->substituteImpl(astBuilder, subst, &diff).as<Type>()); - } - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<TaggedUnionType> substType = astBuilder->create<TaggedUnionType>(); - substType->caseTypes.swapWith(substCaseTypes); - return substType; - } - -// -// TaggedUnionSubtypeWitness -// - - -bool TaggedUnionSubtypeWitness::equalsVal(Val* val) -{ - auto taggedUnionWitness = as<TaggedUnionSubtypeWitness>(val); - if(!taggedUnionWitness) - return false; - - auto caseCount = caseWitnesses.getCount(); - if(caseCount != taggedUnionWitness->caseWitnesses.getCount()) - return false; - - for(Index ii = 0; ii < caseCount; ++ii) - { - if(!caseWitnesses[ii]->equalsVal(taggedUnionWitness->caseWitnesses[ii])) - return false; - } - - return true; -} - -String TaggedUnionSubtypeWitness::toString() -{ - String result; - result.append("TaggedUnionSubtypeWitness("); - bool first = true; - for( auto caseWitness : caseWitnesses ) - { - if(!first) result.append(", "); - first = false; - - result.append(caseWitness->toString()); - } - return result; -} - -HashCode TaggedUnionSubtypeWitness::getHashCode() -{ - HashCode hash = 0; - for( auto caseWitness : caseWitnesses ) - { - hash = combineHash(hash, caseWitness->getHashCode()); - } - return hash; -} - -RefPtr<Val> TaggedUnionSubtypeWitness::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substSub = sub->substituteImpl(astBuilder, subst, &diff).as<Type>(); - auto substSup = sup->substituteImpl(astBuilder, subst, &diff).as<Type>(); - - List<RefPtr<Val>> substCaseWitnesses; - for( auto caseWitness : caseWitnesses ) - { - substCaseWitnesses.add(caseWitness->substituteImpl(astBuilder, subst, &diff)); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<TaggedUnionSubtypeWitness> substWitness = astBuilder->create<TaggedUnionSubtypeWitness>(); - substWitness->sub = substSub; - substWitness->sup = substSup; - substWitness->caseWitnesses.swapWith(substCaseWitnesses); - return substWitness; -} - + + Module* getModule(Decl* decl) { for( auto dd = decl; dd; dd = dd->parentDecl ) @@ -2698,184 +1288,4 @@ char const* getGLSLNameForImageFormat(ImageFormat format) } } -// -// ExistentialSpecializedType -// - -String ExistentialSpecializedType::toString() -{ - String result; - result.append("__ExistentialSpecializedType("); - result.append(baseType->toString()); - for( auto arg : args ) - { - result.append(", "); - result.append(arg.val->toString()); - } - result.append(")"); - return result; -} - -bool ExistentialSpecializedType::equalsImpl(Type * type) -{ - auto other = as<ExistentialSpecializedType>(type); - if(!other) - return false; - - if(!baseType->equals(other->baseType)) - return false; - - auto argCount = args.getCount(); - if(argCount != other->args.getCount()) - return false; - - for( Index ii = 0; ii < argCount; ++ii ) - { - auto arg = args[ii]; - auto otherArg = other->args[ii]; - - if(!arg.val->equalsVal(otherArg.val)) - return false; - - if(!areValsEqual(arg.witness, otherArg.witness)) - return false; - } - return true; -} - -HashCode ExistentialSpecializedType::getHashCode() -{ - Hasher hasher; - hasher.hashObject(baseType); - for(auto arg : args) - { - hasher.hashObject(arg.val); - if(auto witness = arg.witness) - hasher.hashObject(witness); - } - return hasher.getResult(); -} - -RefPtr<Val> getCanonicalValue(Val* val) -{ - if(!val) - return nullptr; - if(auto type = as<Type>(val)) - { - return type->getCanonicalType(); - } - // TODO: We may eventually need/want some sort of canonicalization - // for non-type values, but for now there is nothing to do. - return val; -} - -RefPtr<Type> ExistentialSpecializedType::createCanonicalType() -{ - RefPtr<ExistentialSpecializedType> canType = m_astBuilder->create<ExistentialSpecializedType>(); - - canType->baseType = baseType->getCanonicalType(); - for( auto arg : args ) - { - ExpandedSpecializationArg canArg; - canArg.val = getCanonicalValue(arg.val); - canArg.witness = getCanonicalValue(arg.witness); - canType->args.add(canArg); - } - return canType; -} - -RefPtr<Val> substituteImpl(ASTBuilder* astBuilder, Val* val, SubstitutionSet subst, int* ioDiff) -{ - if(!val) return nullptr; - return val->substituteImpl(astBuilder, subst, ioDiff); -} - -RefPtr<Val> ExistentialSpecializedType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substBaseType = baseType->substituteImpl(astBuilder, subst, &diff).as<Type>(); - - ExpandedSpecializationArgs substArgs; - for( auto arg : args ) - { - ExpandedSpecializationArg substArg; - substArg.val = Slang::substituteImpl(astBuilder, arg.val, subst, &diff); - substArg.witness = Slang::substituteImpl(astBuilder, arg.witness, subst, &diff); - substArgs.add(substArg); - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<ExistentialSpecializedType> substType = astBuilder->create<ExistentialSpecializedType>(); - substType->baseType = substBaseType; - substType->args = substArgs; - return substType; -} - -// -// ThisType -// - -String ThisType::toString() -{ - String result; - result.append(interfaceDeclRef.toString()); - result.append(".This"); - return result; -} - -bool ThisType::equalsImpl(Type * type) -{ - auto other = as<ThisType>(type); - if(!other) - return false; - - if(!interfaceDeclRef.equals(other->interfaceDeclRef)) - return false; - - return true; -} - -HashCode ThisType::getHashCode() -{ - return combineHash( - HashCode(typeid(*this).hash_code()), - interfaceDeclRef.getHashCode()); -} - -RefPtr<Type> ThisType::createCanonicalType() -{ - RefPtr<ThisType> canType = m_astBuilder->create<ThisType>(); - - // TODO: need to canonicalize the decl-ref - canType->interfaceDeclRef = interfaceDeclRef; - return canType; -} - -RefPtr<Val> ThisType::substituteImpl(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) -{ - int diff = 0; - - auto substInterfaceDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); - - auto thisTypeSubst = findThisTypeSubstitution(subst.substitutions, substInterfaceDeclRef.getDecl()); - if( thisTypeSubst ) - { - return thisTypeSubst->witness->sub; - } - - if(!diff) - return this; - - (*ioDiff)++; - - RefPtr<ThisType> substType = m_astBuilder->create<ThisType>(); - substType->interfaceDeclRef = substInterfaceDeclRef; - return substType; -} - } // namespace Slang diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index dff24435d..53ec57daa 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -210,6 +210,17 @@ namespace Slang } } + // + + RefPtr<ThisTypeSubstitution> findThisTypeSubstitution( + Substitutions* substs, + InterfaceDecl* interfaceDecl); + + RequirementWitness tryLookUpRequirementWitness( + ASTBuilder* astBuilder, + SubtypeWitness* subtypeWitness, + Decl* requirementKey); + // TODO: where should this live? SubstitutionSet createDefaultSubstitutions( ASTBuilder* astBuilder, diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index c3e936f8a..353da833d 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -275,8 +275,12 @@ </ItemGroup> <ItemGroup> <ClCompile Include="slang-ast-builder.cpp" /> + <ClCompile Include="slang-ast-decl.cpp" /> <ClCompile Include="slang-ast-dump.cpp" /> <ClCompile Include="slang-ast-reflect.cpp" /> + <ClCompile Include="slang-ast-substitutions.cpp" /> + <ClCompile Include="slang-ast-type.cpp" /> + <ClCompile Include="slang-ast-val.cpp" /> <ClCompile Include="slang-check-conformance.cpp" /> <ClCompile Include="slang-check-constraint.cpp" /> <ClCompile Include="slang-check-conversion.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index bec8547e1..589bd374c 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -272,12 +272,24 @@ <ClCompile Include="slang-ast-builder.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="slang-ast-decl.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-ast-dump.cpp"> <Filter>Source Files</Filter> </ClCompile> <ClCompile Include="slang-ast-reflect.cpp"> <Filter>Source Files</Filter> </ClCompile> + <ClCompile Include="slang-ast-substitutions.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="slang-ast-type.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + <ClCompile Include="slang-ast-val.cpp"> + <Filter>Source Files</Filter> + </ClCompile> <ClCompile Include="slang-check-conformance.cpp"> <Filter>Source Files</Filter> </ClCompile> |
