diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-01 08:46:57 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-01 08:46:57 -0700 |
| commit | cbc1eff56057f199183bb7c17d8a360326512367 (patch) | |
| tree | 487865e928cd2ceecbb509f0bfd06aa8d9584411 /source | |
| parent | b707a07b1de3535cb0a8ccb6fe2ed4afa4a016d1 (diff) | |
Make `DifferentialPair` able to nest. (#2477)
Diffstat (limited to 'source')
29 files changed, 1176 insertions, 295 deletions
diff --git a/source/core/slang-list.h b/source/core/slang-list.h index 25687d129..187ff3109 100644 --- a/source/core/slang-list.h +++ b/source/core/slang-list.h @@ -269,7 +269,7 @@ namespace Slang void insertRange(Index id, const List<T>& list) { insertRange(id, list.m_buffer, list.m_count); } - void addRange(ArrayView<T> list) { insertRange(m_count, list.getBuffer(), list.Count()); } + void addRange(ArrayView<T> list) { insertRange(m_count, list.getBuffer(), list.getCount()); } void addRange(const T* vals, Index n) { insertRange(m_count, vals, n); } diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 75bc65562..5df9d01fe 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2434,7 +2434,7 @@ __generic<T : IComparable> [__unsafeForceInlineEarly] bool operator >=(T v0, T v1) { - return v1.lessThanOrEquals(v1); + return v1.lessThan(v1); } __generic<T : IComparable> [__unsafeForceInlineEarly] diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ad3dfe77c..674531048 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -128,13 +128,37 @@ extension vector<float, 4> : IDifferentiable } } +__magic_type(DifferentialBottomType) +__intrinsic_type($(kIROp_DifferentialBottomType)) +struct __DifferentialBottom : IDifferentiable +{ + typedef __DifferentialBottom Differential; + + __intrinsic_op($(kIROp_DifferentialBottomValue)) + static __DifferentialBottom dzero(); + + [__unsafeForceInlineEarly] + static __DifferentialBottom dadd(Differential a, Differential b) + { + return dzero(); + } + + [__unsafeForceInlineEarly] + static __DifferentialBottom dmul(This a, Differential b) + { + return dzero(); + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic<T : IDifferentiable> __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) -struct __DifferentialPair +struct DifferentialPair : IDifferentiable { + typedef DifferentialPair<T.Differential> Differential; + typedef T.Differential DifferentialElementType; __intrinsic_op($(kIROp_MakeDifferentialPair)) __init(T _primal, T.Differential _differential); @@ -154,6 +178,31 @@ struct __DifferentialPair { return p(); } + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return Differential( + T.dadd( + a.p(), + b.p() + ), + Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return Differential( + T.dmul(a.p(), b.p()), + Differential.DifferentialElementType.dzero()); + } }; typealias IDFloat = IFloat & IDifferentiable; @@ -171,9 +220,9 @@ namespace dstd T exp(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + DifferentialPair<T> d_exp(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( exp(dpx.p()), T.dmul(exp(dpx.p()), dpx.d())); } @@ -189,9 +238,9 @@ namespace dstd T sin(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + DifferentialPair<T> d_sin(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( sin(dpx.p()), T.dmul(cos(dpx.p()), dpx.d())); } @@ -207,9 +256,9 @@ namespace dstd T cos(T x); __generic<T : IDFloat> - __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + DifferentialPair<T> d_cos(DifferentialPair<T> dpx) { - return __DifferentialPair<T>( + return DifferentialPair<T>( cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index dea02afbb..627a56152 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -254,7 +254,9 @@ private: // The actual values of the arguments List<Val* > args; public: + List<Val*>& getArgs() { return args; } const List<Val*>& getArgs() const { return args; } + // Overrides should be public so base classes can access Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); bool _equalsOverride(Substitutions* subst); @@ -265,6 +267,12 @@ public: genericDecl = decl; } + GenericSubstitution(GenericDecl* decl, ArrayView<Val*> argVals) + { + genericDecl = decl; + args.addRange(argVals); + } + template<typename... TArgs> GenericSubstitution(GenericDecl* decl, TArgs... inArgs) { diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index f6c550d69..beee16f9c 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -141,6 +141,16 @@ Type* SharedASTBuilder::getNoneType() return m_noneType; } +Type* SharedASTBuilder::getDifferentialBottomType() +{ + if (!m_diffBottomType) + { + auto diffBottomTypeDecl = findMagicDecl("DifferentialBottomType"); + m_diffBottomType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(diffBottomTypeDecl)); + } + return m_diffBottomType; +} + SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. @@ -299,13 +309,18 @@ VectorExpressionType* ASTBuilder::getVectorType( return result; } -DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness) +DifferentialPairType* ASTBuilder::getDifferentialPairType( + Type* valueType, + Witness* primalIsDifferentialWitness) { auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); auto typeDecl = genericDecl->inner; - auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness); + auto substitutions = getOrCreate<GenericSubstitution>( + genericDecl, + valueType, + primalIsDifferentialWitness); auto declRef = DeclRef<Decl>(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index e4ea872a0..235bebfaa 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -35,6 +35,8 @@ public: Type* getNullPtrType(); /// Get the NullPtr type Type* getNoneType(); + /// Get the DifferentialBottom type. + Type* getDifferentialBottomType(); const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass<NodeBase> findSyntaxClass(Name* name); @@ -79,7 +81,7 @@ protected: Type* m_dynamicType = nullptr; Type* m_nullPtrType = nullptr; Type* m_noneType = nullptr; - + Type* m_diffBottomType = nullptr; Type* m_builtinTypes[Index(BaseType::CountOf)]; Dictionary<String, Decl*> m_magicDecls; @@ -297,6 +299,7 @@ public: Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; } Type* getErrorType() { return m_sharedASTBuilder->m_errorType; } Type* getBottomType() { return m_sharedASTBuilder->m_bottomType; } + Type* getDifferentialBottomType() { return m_sharedASTBuilder->getDifferentialBottomType(); } Type* getStringType() { return m_sharedASTBuilder->getStringType(); } Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); } Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); } @@ -326,7 +329,9 @@ public: VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); - DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* conformanceWitness); + DifferentialPairType* getDifferentialPairType( + Type* valueType, + Witness* primalIsDifferentialWitness); DeclRef<InterfaceDecl> getDifferentiableInterface(); diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 87d696927..90175dd9d 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -29,6 +29,9 @@ class ContainerDecl: public Decl List<Decl*> members; SourceLoc closingSourceLoc; + // The associated scope owned by this decl. + Scope* ownedScope = nullptr; + template<typename T> FilteredMemberList<T> getMembersOfType() { diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index f2a72703e..ef6a05c71 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -126,6 +126,12 @@ class InitializerListExpr : public Expr List<Expr*> args; }; +class GetArrayLengthExpr : public Expr +{ + SLANG_AST_CLASS(GetArrayLengthExpr) + Expr* arrayExpr = nullptr; +}; + // A base class for expressions with arguments class ExprWithArgsBase : public Expr { diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 76106074f..0c1eb8d49 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1016,14 +1016,14 @@ class RequiresNVAPIAttribute : public Attribute }; /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. -class ForwardDifferentiableAttribute : public Attribute +class ForwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDifferentiableAttribute) }; /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. -class ForwardDerivativeAttribute : public Attribute +class ForwardDerivativeAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDerivativeAttribute) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 4c92810d9..d4a781846 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -384,6 +384,12 @@ namespace Slang /// ReadyForConformances, + /// Any DeclRefTypes with substitutions have been fully resolved + /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. + /// We need a separate pass to resolve these types because `A.X` + /// maybe synthesized and made available only after conformance checking. + TypesFullyResolved, + /// The declaration is fully checked. /// /// This step includes any validation of the declaration that is @@ -779,6 +785,12 @@ namespace Slang void toText(StringBuilder& out) const; }; + // If this is a declref to an associatedtype with a ThisTypeSubsitution, + // try to find the concrete decl that satisfies the associatedtype requirement from the + // concrete type supplied by ThisTypeSubstittution. + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); + + template<typename T> struct DeclRef : DeclRefBase { diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp new file mode 100644 index 000000000..e5a9ff75f --- /dev/null +++ b/source/slang/slang-ast-synthesis.cpp @@ -0,0 +1,175 @@ +#include "slang-ast-synthesis.h" + +namespace Slang +{ +Expr* ASTSynthesizer::emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right) +{ + auto infixExpr = m_builder->create<InfixExpr>(); + infixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + infixExpr->arguments.add(left); + infixExpr->arguments.add(right); + return infixExpr; +} + +Expr* ASTSynthesizer::emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base) +{ + auto prefixExpr = m_builder->create<PrefixExpr>(); + prefixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + prefixExpr->arguments.add(base); + return prefixExpr; +} + +Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base) +{ + auto postfixExpr = m_builder->create<PostfixExpr>(); + postfixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + postfixExpr->arguments.add(base); + return postfixExpr; +} + +ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar) +{ + auto scopeDecl = pushVarScope()->containerDecl; + auto stmt = m_builder->create<ForStmt>(); + stmt->scopeDecl = (ScopeDecl*)scopeDecl; + auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal); + stmt->initialStatement = declStmt; + outIndexVar = (VarDecl*)declStmt->decl; + auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal); + stmt->predicateExpression = predicateExpr; + stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar)); + getCurrentScope().m_parentSeqStmt->stmts.add(stmt); + return stmt; +} + +Expr* ASTSynthesizer::emitVarExpr(Name* name) +{ + auto scope = getCurrentScope(); + SLANG_RELEASE_ASSERT(scope.m_scope); + auto varExpr = m_builder->create<VarExpr>(); + varExpr->name = name; + varExpr->scope = scope.m_scope; + return varExpr; +} + +Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) +{ + auto varExpr = m_builder->create<VarExpr>(); + varExpr->declRef = makeDeclRef(varDecl); + varExpr->type = varDecl->type.type; + return varExpr; +} + +Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type) +{ + auto expr = m_builder->create<VarExpr>(); + expr->declRef = DeclRef<Decl>(var, nullptr); + expr->type.type = type; + expr->type.isLeftValue = true; + return expr; +} + +Expr* ASTSynthesizer::emitVarExpr(DeclStmt* varStmt, Type* type) +{ + auto expr = m_builder->create<VarExpr>(); + expr->declRef = DeclRef<Decl>(as<Decl>(varStmt->decl), nullptr); + expr->type.type = type; + expr->type.isLeftValue = true; + return expr; +} + +Expr* ASTSynthesizer::emitIntConst(int value) +{ + auto expr = m_builder->create<IntegerLiteralExpr>(); + expr->type.type = m_builder->getIntType(); + expr->value = value; + return expr; +} + +Expr* ASTSynthesizer::emitGetArrayLengthExpr(Expr* arrayExpr) +{ + auto expr = m_builder->create<GetArrayLengthExpr>(); + expr->arrayExpr = arrayExpr; + expr->type = m_builder->getIntType(); + return expr; +} + +Expr* ASTSynthesizer::emitMemberExpr(Expr* arrayExpr, Name* name) +{ + auto rs = m_builder->create<MemberExpr>(); + rs->baseExpression = arrayExpr; + rs->name = name; + return rs; +} + +Expr* ASTSynthesizer::emitAssignExpr(Expr* left, Expr* right) +{ + auto rs = m_builder->create<AssignExpr>(); + rs->left = left; + rs->right = right; + return rs; +} + +Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List<Expr*>&& args) +{ + auto rs = m_builder->create<InvokeExpr>(); + rs->functionExpr = callee; + rs->arguments = _Move(args); + return rs; +} + +Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) +{ + auto rs = m_builder->create<StaticMemberExpr>(); + auto typeExpr = m_builder->create<SharedTypeExpr>(); + auto typetype = m_builder->create<TypeType>(); + typetype->type = type; + typeExpr->type = typetype; + rs->baseExpression = typeExpr; + rs->name = name; + return rs; +} + +Expr* ASTSynthesizer::emitIndexExpr(Expr* base, Expr* index) +{ + auto rs = m_builder->create<IndexExpr>(); + rs->baseExpression = base; + rs->indexExprs.add(index); + return rs; +} + +ExpressionStmt* ASTSynthesizer::emitExprStmt(Expr* expr) +{ + auto rs = m_builder->create<ExpressionStmt>(); + _addStmtToScope(rs); + rs->expression = expr; + return rs; +} + +ReturnStmt* ASTSynthesizer::emitReturnStmt(Expr* expr) +{ + auto rs = m_builder->create<ReturnStmt>(); + rs->expression = expr; + _addStmtToScope(rs); + return rs; +} + +DeclStmt* ASTSynthesizer::emitVarDeclStmt(Type* type, Name* name, Expr* initVal) +{ + auto scope = getCurrentScope(); + SLANG_RELEASE_ASSERT(scope.m_parentSeqStmt); + SLANG_RELEASE_ASSERT(scope.m_scope); + SLANG_RELEASE_ASSERT(scope.m_scope->containerDecl); + auto varDecl = m_builder->create<VarDecl>(); + varDecl->type.type = type; + varDecl->nameAndLoc.name = name; + varDecl->initExpr = initVal; + varDecl->parentDecl = scope.m_scope->containerDecl; + varDecl->parentDecl->members.add(varDecl); + auto stmt = m_builder->create<DeclStmt>(); + stmt->decl = varDecl; + _addStmtToScope(stmt); + return stmt; +} + +} diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h new file mode 100644 index 000000000..2af890d34 --- /dev/null +++ b/source/slang/slang-ast-synthesis.h @@ -0,0 +1,147 @@ +// slang-ast-synthesis.h + +#pragma once + +#include "slang-syntax.h" + +namespace Slang { + +struct ASTEmitScope +{ + ContainerDecl* m_parent = nullptr; + SeqStmt* m_parentSeqStmt = nullptr; + Scope* m_scope = nullptr; +}; +class ASTSynthesizer +{ +private: + ASTBuilder* m_builder; + NamePool* m_namePool; + List<ASTEmitScope> m_scopeStack; +public: + ASTSynthesizer(ASTBuilder* builder, NamePool* namePool) + : m_builder(builder) + , m_namePool(namePool) + { + } + + Scope* getScope(ContainerDecl* decl) + { + for (auto container = decl; container; container = container->parentDecl) + { + if (container->ownedScope) + { + return container->ownedScope; + } + } + return nullptr; + } + + // Create a scope for `decl` and push it to scope stack + void pushScopeForContainer(ContainerDecl* decl) + { + if (decl->ownedScope) + { + // if decl already owns a scope, don't create a new one. + pushContainerScope(decl); + return; + } + + auto parentScope = getScope(decl); + decl->ownedScope = m_builder->create<Scope>(); + decl->ownedScope->parent = parentScope; + pushContainerScope(decl); + } + + // Push `decl` and its associated scope to scope stack + void pushContainerScope(ContainerDecl* decl) + { + ASTEmitScope scope = getCurrentScope(); + scope.m_parent = decl; + scope.m_scope = getScope(decl); + m_scopeStack.add(scope); + } + + Scope* pushVarScope() + { + ASTEmitScope scope = getCurrentScope(); + auto scopeDecl = m_builder->create<ScopeDecl>(); + auto newScope = m_builder->create<Scope>(); + scopeDecl->parentDecl = scope.m_parent; + if (scope.m_parent) + scope.m_parent->members.add(scopeDecl); + newScope->parent = scope.m_scope; + newScope->containerDecl = scopeDecl; + scope.m_scope = newScope; + m_scopeStack.add(scope); + return newScope; + } + + void _addStmtToScope(Stmt* stmt) + { + auto scope = getCurrentScope(); + if (scope.m_parentSeqStmt) + { + scope.m_parentSeqStmt->stmts.add(stmt); + } + } + + SeqStmt* pushSeqStmtScope() + { + ASTEmitScope scope = getCurrentScope(); + scope.m_parentSeqStmt = m_builder->create<SeqStmt>(); + m_scopeStack.add(scope); + return scope.m_parentSeqStmt; + } + + void popScope() + { + m_scopeStack.removeLast(); + } + + ASTEmitScope getCurrentScope() + { + if (m_scopeStack.getCount()) + return m_scopeStack.getLast(); + return ASTEmitScope(); + } + + ForStmt* emitFor(Expr* initVal, Expr* finalVal, VarDecl*& outIndexVar); + + Expr* emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right); + + Expr* emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base); + + Expr* emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base); + + Expr* emitVarExpr(Name* name); + Expr* emitVarExpr(VarDecl* var); + Expr* emitVarExpr(VarDecl* var, Type* type); + Expr* emitVarExpr(DeclStmt* varStmt, Type* type); + + Expr* emitIntConst(int value); + + Expr* emitGetArrayLengthExpr(Expr* arrayExpr); + + Expr* emitMemberExpr(Expr* base, Name* name); + Expr* emitMemberExpr(Type* base, Name* name); + + Expr* emitIndexExpr(Expr* base, Expr* index); + + Expr* emitAssignExpr(Expr* left, Expr* right); + ExpressionStmt* emitAssignStmt(Expr* left, Expr* right) + { + return emitExprStmt(emitAssignExpr(left, right)); + } + + Expr* emitInvokeExpr(Expr* callee, List<Expr*>&& args); + + DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr); + + ExpressionStmt* emitExprStmt(Expr* expr); + + ReturnStmt* emitReturnStmt(Expr* expr); + +}; + +} // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 480589af4..a869c95a7 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -169,6 +169,27 @@ Val* BottomType::_substituteImplOverride( HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DifferentialBottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void DifferentialBottomType::_toTextOverride(StringBuilder& out) { out << toSlice("diff_bottom"); } + +bool DifferentialBottomType::_equalsImplOverride(Type* type) +{ + if (auto bottomType = as<DifferentialBottomType>(type)) + return true; + return false; +} + +Type* DifferentialBottomType::_createCanonicalTypeOverride() { return this; } + +Val* DifferentialBottomType::_substituteImplOverride( + ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) +{ + return this; +} + +HashCode DifferentialBottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void DeclRefType::_toTextOverride(StringBuilder& out) @@ -193,6 +214,7 @@ bool DeclRefType::_equalsImplOverride(Type * type) Type* DeclRefType::_createCanonicalTypeOverride() { // A declaration reference is already canonical + declRef.substitute(this->getASTBuilder(), this); return this; } @@ -223,39 +245,8 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe // the outer interface, then try to replace the type with the // actual value of the associated type for the given implementation. // - if (auto substAssocTypeDecl = as<AssocTypeDecl>(substDeclRef.decl)) - { - for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) - { - auto thisSubst = as<ThisTypeSubstitution>(s); - if (!thisSubst) - continue; - - if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) - { - if (thisSubst->interfaceDecl == interfaceDecl) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey); - switch (requirementWitness.getFlavor()) - { - default: - // No usable value was found, so there is nothing we can do. - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - break; - } - } - } - } - } + if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef)) + return satisfyingVal; // Re-construct the type in case we are using a specialized sub-class return DeclRefType::create(astBuilder, substDeclRef); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 895b64f35..c7ce21cb0 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -86,6 +86,20 @@ protected: {} }; +// The bottom/empty type as a result of Differentiating a Differential. +class DifferentialBottomType : public DeclRefType +{ + SLANG_AST_CLASS(DifferentialBottomType) + + // Overrides should be public so base classes can access + void _toTextOverride(StringBuilder& out); + Type* _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + + // Base class for types that can be used in arithmetic expressions class ArithmeticExpressionType : public DeclRefType { diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index a8ceaa716..87e89ef18 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -618,6 +618,41 @@ Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, return substWitness; } +bool DifferentialBottomSubtypeWitness::_equalsValOverride(Val* val) +{ + auto otherDiffBottomWitness = as<DifferentialBottomSubtypeWitness>(val); + if (!otherDiffBottomWitness) + return false; + + return otherDiffBottomWitness->sub && otherDiffBottomWitness->sub->equals(sub); +} + +void DifferentialBottomSubtypeWitness::_toTextOverride(StringBuilder& out) +{ + out << "DifferentialBottomSubtypeWitness(" << sub << ")"; +} + +HashCode DifferentialBottomSubtypeWitness::_getHashCodeOverride() +{ + return combineHash(3892, sub->getHashCode()); +} + +Val* DifferentialBottomSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff)); + if (!diff) + return this; + + *ioDiff += diff; + + DifferentialBottomSubtypeWitness* substWitness = + astBuilder->create<DifferentialBottomSubtypeWitness>(substSub, substSup); + return substWitness; +} + bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) { if (auto other = as<ConjunctionSubtypeWitness>(val)) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 5f72e58c8..b52984f8b 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -399,6 +399,24 @@ class DynamicSubtypeWitness : public SubtypeWitness SLANG_AST_CLASS(DynamicSubtypeWitness) }; + /// A witness of the fact that any type can be viewed as a subtype of DifferentialBottom. +class DifferentialBottomSubtypeWitness : public SubtypeWitness +{ + SLANG_AST_CLASS(DifferentialBottomSubtypeWitness) + + DifferentialBottomSubtypeWitness(Type* inSub, Type* inSup) + { + sub = inSub; + sup = inSup; + } + + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + /// A witness that `T : L & R` because `T : L` and `T : R` class ConjunctionSubtypeWitness : public SubtypeWitness { diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 2c9977082..eb072e9dd 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -33,7 +33,6 @@ namespace Slang else return conjunction->rightWitness; } - ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create<ExtractFromConjunctionSubtypeWitness>(); simplExtractFromConjunction->sub = extractFromConjunction->sub; simplExtractFromConjunction->sup = extractFromConjunction->sup; @@ -145,7 +144,7 @@ namespace Slang m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); - TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness); + TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(); transitiveWitness->sub = subType; transitiveWitness->sup = bb->sup; transitiveWitness->midToSup = declaredWitness; @@ -379,6 +378,45 @@ namespace Slang } } } + + // If a generic type parameter does not declare itself to conform to `IDifferentiable`, + // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`. + // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but + // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`. + if (m_astBuilder->isDifferentiableInterfaceAvailable() && + subType == originalSubType && + superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface()) + { + if (as<GenericTypeParamDecl>(declRefType->declRef.getDecl()) || + as<AssocTypeDecl>(declRefType->declRef.getDecl())) + { + auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + auto differentialBottomType = as<DeclRefType>(m_astBuilder->getDifferentialBottomType()); + auto container = differentialBottomType->declRef.as<ContainerDecl>().getDecl(); + SLANG_RELEASE_ASSERT(container); + auto inheritanceDecl = container->getMembersOfType<InheritanceDecl>().getFirst(); + auto witnessDifferentialBottomIsIDifferentiable = + m_astBuilder->getOrCreate<DeclaredSubtypeWitness>( + m_astBuilder->getDifferentialBottomType(), + sup, + inheritanceDecl, + nullptr); + + auto witnessSubIsDifferentialBottom = + m_astBuilder->getOrCreate<DifferentialBottomSubtypeWitness>( + subType, differentialBottomType); + + TransitiveSubtypeWitness* transitiveWitness = + m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>( + witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable); + transitiveWitness->sub = subType; + transitiveWitness->sup = sup; + transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable; + transitiveWitness->subToMid = witnessSubIsDifferentialBottom; + *outWitness = transitiveWitness; + return true; + } + } } else if (auto extractExistentialType = as<ExtractExistentialType>(subType)) { diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index e56d63f91..5e84e170b 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1169,9 +1169,6 @@ namespace Slang fromExpr); } - // If we coerced to a differentiable type, log it. - maybeRegisterDifferentiableType(m_astBuilder, expr->type); - return expr; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 457ae229b..f60fbcc2c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -11,9 +11,8 @@ // and when things get checked. #include "slang-lookup.h" - #include "slang-syntax.h" - +#include "slang-ast-synthesis.h" #include <limits> namespace Slang @@ -166,6 +165,65 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* decl); }; + struct SemanticsDeclTypeResolutionVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclTypeResolutionVisitor> + { + SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + Val* resolveVal(Val* val); + Type* resolveType(Type* type) + { + return (Type*)resolveVal(type); + } + + void visitTypeExp(TypeExp& exp) + { + exp.type = resolveType(exp.type); + } + + void visitVarDeclBase(VarDeclBase* varDecl) + { + visitTypeExp(varDecl->type); + } + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + { + visitTypeExp(decl->sup); + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + visitTypeExp(decl->type); + } + + void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl) + { + visitTypeExp(paramDecl->initType); + } + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + visitTypeExp(inheritanceDecl->base); + } + + void visitCallableDecl(CallableDecl* decl) + { + visitTypeExp(decl->returnType); + visitTypeExp(decl->errorType); + } + + void visitPropertyDecl(PropertyDecl* decl) + { + visitTypeExp(decl->type); + } + }; + struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclBodyVisitor> @@ -1363,27 +1421,30 @@ namespace Slang bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef<Decl> requirementDeclRef, + DeclRef<AssocTypeDecl> requirementDeclRef, RefPtr<WitnessTable> witnessTable) { - // We currently can't handle generic types. - if (GetOuterGeneric(context->parentDecl) != nullptr) - { - return false; - } - + ASTSynthesizer synth(m_astBuilder, getNamePool()); Decl* existingDecl = nullptr; AggTypeDecl* aggTypeDecl = nullptr; if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl)) { - aggTypeDecl = as<AggTypeDecl>(existingDecl); - SLANG_RELEASE_ASSERT(aggTypeDecl); - // Remove the `ToBeSynthesizedModifier`. - if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first)) + if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first)) { - aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next; + existingDecl->modifiers.first = existingDecl->modifiers.first->next; } + else + { + // The user has defined an associatedtype explicitly but that we reach here because + // that type failed to satisfy the `IDifferential` requirement. + // We stop the synthesis and let the follow-up logic to report a diagnostic. + return false; + } + + aggTypeDecl = as<AggTypeDecl>(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + synth.pushContainerScope(aggTypeDecl); } else { @@ -1393,15 +1454,12 @@ namespace Slang aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; context->parentDecl->invalidateMemberDictionary(); + synth.pushScopeForContainer(aggTypeDecl); } - // TODO: if we want to make the synthesized type itself to be differentiable, - // add an inheritance decl here. Need to be careful to avoid infinite recursion - // trying to synthesize the higher order differential types. - // Helper function to add a `diffType` field into the synthesized type for the original // `member`. - auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc); + auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); auto addDiffMember = [&](Decl* member, Type* diffMemberType) { // If the field is differentiable, add a corresponding field in the associated Differential type. @@ -1452,12 +1510,35 @@ namespace Slang addDiffMember(member, diffType); } - // In the future when the Differential type itself needs to conform to some interface, - // this is the place to synthesize requirements for them. addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr); - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - return true; + + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution + // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + + if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) + { + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + return true; + } + + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. + // If not, there is something wrong with the code synthesis logic. For now we just return false + // instead of crashing so the user can work around the issues. + return false; } void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) @@ -2242,22 +2323,8 @@ namespace Slang witnessTable); } - bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( - Type* satisfyingType, - DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, - RefPtr<WitnessTable> witnessTable) + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable) { - if (auto declRefType = as<DeclRefType>(satisfyingType)) - { - // If we are seeing a placeholder that awaits synthesis, return false now to trigger - // auto synthesis. - if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) - return false; - } - // We need to confirm that the chosen type `satisfyingType`, - // meets all the constraints placed on the associated type - // requirement `requiredAssociatedTypeDeclRef`. - // // We will enumerate the type constraints placed on the // associated type and see if they can be satisfied. // @@ -2269,7 +2336,7 @@ namespace Slang // Perform a search for a witness to the subtype relationship. auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); - if(witness) + if (witness) { // If a subtype witness was found, then the conformance // appears to hold, and we can satisfy that requirement. @@ -2282,6 +2349,30 @@ namespace Slang conformance = false; } } + return conformance; + } + + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( + Type* satisfyingType, + DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, + RefPtr<WitnessTable> witnessTable) + { + if (auto declRefType = as<DeclRefType>(satisfyingType)) + { + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>()) + return false; + } + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement( + satisfyingType, requiredAssociatedTypeDeclRef, witnessTable); // TODO: if any conformance check failed, we should probably include // that in an error message produced about not satisfying the requirement. @@ -3122,12 +3213,43 @@ namespace Slang return false; } + Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0) + { + if (nestingLevel > 16) + return nullptr; + + // If field type is an array, assign each element individually. + if (auto arrayType = as<ArrayExpressionType>(leftType)) + { + VarDecl* indexVar = nullptr; + auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); + auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); + for (auto& arg : args) + { + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + } + auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1); + synth.popScope(); + if (!assignStmt) + return nullptr; + forStmt->statement = assignStmt; + return forStmt; + } + + auto callee = synth.emitMemberExpr(leftType, funcName); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); + } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, RefPtr<WitnessTable> witnessTable) { - // This method implements a general code synthesis pattern. + // We support two cases of synthesis here. + // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. + // In this case we just trivially return `DifferentialBottom` in all synthesized methods. + // Case 2 is that the `Differential` type contains members corresponding to each primal member. + // We will apply a general code synthesis pattern to reflect that structure. // For requirement of the form: // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3145,104 +3267,123 @@ namespace Slang // return result; // } // ``` + + // First we need to make sure the associated `Differential` type requirement is satisfied. + bool hasDifferentialAssocType = false; + for (auto existingEntry : witnessTable->requirementList) + { + if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>()) + { + if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && + existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none) + { + hasDifferentialAssocType = true; + } + } + } + if (!hasDifferentialAssocType) + return false; + + ASTSynthesizer synth(m_astBuilder, getNamePool()); List<Expr*> synArgs; ThisExpr* synThis = nullptr; auto synFunc = synthesizeMethodSignatureForRequirementWitness( context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis); - + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create<BlockStmt>(); synFunc->body = blockStmt; - auto seqStmt = m_astBuilder->create<SeqStmt>(); + auto seqStmt = synth.pushSeqStmtScope(); blockStmt->body = seqStmt; - // Create a variable for return value. - auto scopeDecl = m_astBuilder->create<ScopeDecl>(); - synFunc->members.add(scopeDecl); - scopeDecl->parentDecl = synFunc; - auto varStmt = m_astBuilder->create<DeclStmt>(); - seqStmt->stmts.add(varStmt); - - auto returnVar = m_astBuilder->create<VarDecl>(); - returnVar->parentDecl = scopeDecl; - scopeDecl->members.add(returnVar); - - returnVar->type.type = synFunc->returnType.type; - returnVar->nameAndLoc.name = getName("result"); - varStmt->decl = returnVar; - auto resultVarExpr = m_astBuilder->create<VarExpr>(); - resultVarExpr->declRef = makeDeclRef(returnVar); - resultVarExpr->type.type = synFunc->returnType.type; - resultVarExpr->type.isLeftValue = true; - - for (auto member : context->parentDecl->members) - { - auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); - if (!derivativeAttr) - continue; - auto varMember = as<VarDeclBase>(member); - if (!varMember) - continue; - ensureDecl(varMember, DeclCheckState::ReadyForReference); - auto memberType = varMember->getType(); - auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); - if (!diffMemberType) - continue; + if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType())) + { + // Trivial case, the `Differential` type is `DifferentialBottom`. + // We will just return `DifferentialBottom.dzero()`. + auto resultExpr = m_astBuilder->create<InvokeExpr>(); + auto dzeroMember = m_astBuilder->create<StaticMemberExpr>(); + auto base = m_astBuilder->create<SharedTypeExpr>(); + auto typetype = m_astBuilder->create<TypeType>(); + typetype->type = m_astBuilder->getDifferentialBottomType(); + base->type.type = typetype; + dzeroMember->baseExpression = base; + dzeroMember->name = getName("dzero"); + resultExpr->functionExpr = dzeroMember; + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultExpr; + seqStmt->stmts.add(synReturn); + } + else + { + // The general case. + // Create a variable for return value. + synth.pushVarScope(); + auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); + auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); - // Construct reference exprs to the member's corresponding fields in each parameter. - List<Expr*> paramFields; - int paramIndex = 0; - for (auto arg : synArgs) - { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; - } - - // Invoke the method for the field. - auto callee = m_astBuilder->create<StaticMemberExpr>(); - auto baseSharedType = m_astBuilder->create<SharedTypeExpr>(); - auto baseSharedTypeType = m_astBuilder->create<TypeType>(); - baseSharedTypeType->type = memberType; - baseSharedType->type = baseSharedTypeType; - baseSharedType->base.type = memberType; - callee->baseExpression = baseSharedType; - callee->name = requirementDeclRef.getName(); - callee->loc = synFunc->loc; - auto invokeExpr = m_astBuilder->create<InvokeExpr>(); - invokeExpr->functionExpr = callee; - invokeExpr->arguments = _Move(paramFields); - - // Assign the value to resultVar. - auto leftVal = m_astBuilder->create<MemberExpr>(); - leftVal->baseExpression = resultVarExpr; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - leftVal->name = varMember->getName(); - - auto assignExpr = m_astBuilder->create<AssignExpr>(); - assignExpr->left = leftVal; - assignExpr->right = invokeExpr; - auto assignStmt = m_astBuilder->create<ExpressionStmt>(); - assignStmt->expression = assignExpr; - seqStmt->stmts.add(assignStmt); - } - - // TODO: synthesize assignments for inherited members here. - - auto synReturn = m_astBuilder->create<ReturnStmt>(); - synReturn->expression = resultVarExpr; - seqStmt->stmts.add(synReturn); + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeAttr) + continue; + auto varMember = as<VarDeclBase>(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; - synFunc->parentDecl = context->parentDecl; + // Construct reference exprs to the member's corresponding fields in each parameter. + List<Expr*> paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field and assign the value to resultVar. + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + return false; + } + + // TODO: synthesize assignments for inherited members here. + + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); + } + context->parentDecl->members.add(synFunc); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>()); - witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc))); + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized method here in order to fill into the witness table. + // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the + // generic substitution for outer generic parameters, and apply it here. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + witnessTable->add(requirementDeclRef, RequirementWitness(DeclRef<Decl>(synFunc, substSet))); return true; } @@ -3801,7 +3942,10 @@ namespace Slang // be required to implement all interface requirements, // just with `abstract` methods that replicate things? // (That's what C# does). - for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>()) + + // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members. + auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList(); + for (auto inheritanceDecl : inheritanceDecls) { checkConformance(type, inheritanceDecl, decl); } @@ -5230,7 +5374,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier<ForwardDifferentiableAttribute>()) + if (decl->findModifier<DifferentiableAttribute>()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } @@ -6274,6 +6418,10 @@ namespace Slang SemanticsDeclConformancesVisitor(shared).dispatch(decl); break; + case DeclCheckState::TypesFullyResolved: + SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -6325,4 +6473,40 @@ namespace Slang return result; } + Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val) + { + if (auto declRefType = as<DeclRefType>(val)) + { + if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef)) + return as<Type>(concreteType); + for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer) + { + if (auto genericSubst = as<GenericSubstitution>(subst)) + { + ShortList<Val*> newArgs; + for (auto& arg : genericSubst->getArgs()) + { + arg = resolveVal(arg); + SLANG_RELEASE_ASSERT(arg); + } + } + } + } + else if (auto subtypeWitness = as<SubtypeWitness>(val)) + { + auto sub = as<Type>(resolveVal(subtypeWitness->sub)); + auto sup = as<Type>(resolveVal(subtypeWitness->sup)); + if (sub && sup) + { + if (sub != subtypeWitness->sub || sup != subtypeWitness->sup) + { + auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup)); + if (newVal) + val = newVal; + } + } + } + return val; + } + } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe37f5099..251849ede 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1410,6 +1410,19 @@ namespace Slang return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value); } + if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>()) + { + if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type) + { + auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); + if (auto arrayType = as<ArrayExpressionType>(type)) + { + if (auto val = as<IntVal>(arrayType->arrayLength)) + return val; + } + } + } + // it is possible that we are referring to a generic value param if (auto declRefExpr = expr.as<DeclRefExpr>()) { @@ -1871,14 +1884,42 @@ namespace Slang arg = CheckTerm(arg); } - return CheckInvokeExprWithCheckedOperands(expr); + // If we are in a differentiable function, register differential witness tables involved in + // this call. + if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + { + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + + auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); + + if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>()) + { + if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr)) + { + // Register types for final resolved invoke arguments again. + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); + } + return checkedExpr; } Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr) { // If we've already resolved this expression, don't try again. if (expr->declRef) + { + if (!expr->type) + expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); return expr; + } expr->type = QualType(m_astBuilder->getErrorType()); auto lookupResult = lookUp( @@ -1908,63 +1949,56 @@ namespace Slang return expr; } - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) + Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the // nested type. // if (auto primalOutType = as<OutType>(primalType)) { - return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType())); + return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); } else if (auto primalInOutType = as<InOutType>(primalType)) { - return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType())); + return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType())); } + return getDifferentialPairType(primalType); + } + Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) + { // Get a reference to the builtin 'IDifferentiable' interface - auto differentiableInterface = builder->getDifferentiableInterface(); + auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); + auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface))) - return builder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness) + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); else return primalType; - } - Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType) - { - if (auto conformanceWitness = - as<Witness>(tryGetInterfaceConformanceWitness( - primalType, - builder->getDifferentiableInterface()))) - return builder->getDifferentialPairType(primalType, conformanceWitness); - else - return primalType; - } - - Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) + Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType) { // Resolve JVP type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - FuncType* jvpType = builder->create<FuncType>(); + FuncType* jvpType = m_astBuilder->create<FuncType>(); // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + jvpType->resultType = getDifferentialPairType(originalType->getResultType()); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); jvpType->errorType = originalType->errorType; for (UInt i = 0; i < originalType->getParamCount(); i++) { - if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) jvpType->paramTypes.add(jvpParamType); } @@ -1978,6 +2012,15 @@ namespace Slang // Check/Resolve inner function declaration. expr->baseFunction = CheckTerm(expr->baseFunction); + // Register parameter types. + if (auto funcType = as<FuncType>(expr->baseFunction->type.type)) + { + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i)); + } + } + // For now we only support using higher order expr as callee in an invoke expr. // The actual type of the higher order function will be derived during resolve invoke. expr->type = m_astBuilder->getBottomType(); @@ -1985,6 +2028,29 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) + { + expr->arrayExpr = CheckTerm(expr->arrayExpr); + if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type)) + { + expr->type = m_astBuilder->getIntType(); + if (!arrType->arrayLength) + { + getSink()->diagnose(expr, Diagnostics::invalidArraySize); + } + } + else + { + if (!as<ErrorType>(expr->arrayExpr->type)) + { + getSink()->diagnose( + expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type); + } + expr->type = m_astBuilder->getErrorType(); + } + return expr; + } + Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) { // Check the term we are applying first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 821f785f5..33455e42d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -794,15 +794,12 @@ namespace Slang bool shouldSkipChecking(Decl* decl, DeclCheckState state); // Auto-diff convenience functions for translating primal types to differential types. - Type* _toDifferentialParamType(ASTBuilder* builder, Type* primalType); + Type* _toDifferentialParamType(Type* primalType); + + Type* getDifferentialPairType(Type* primalType); - // Translate a return type to the return type of a forward-mode differentiated - // function. - // - Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType); - // Convert a function's original type to it's JVP type. - Type* processJVPFuncType(ASTBuilder* builder, FuncType* originalType); + Type* processJVPFuncType(FuncType* originalType); // Check and register a type if it is differentiable. void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); @@ -1038,6 +1035,11 @@ namespace Slang DeclRef<GenericDecl> requirementGenDecl, RefPtr<WitnessTable> witnessTable); + bool doesTypeSatisfyAssociatedTypeConstraintRequirement( + Type* satisfyingType, + DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, + RefPtr<WitnessTable> witnessTable); + bool doesTypeSatisfyAssociatedTypeRequirement( Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, @@ -1124,7 +1126,7 @@ namespace Slang /// Otherwise, returns `false`. bool trySynthesizeDifferentialAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef<Decl> requirementDeclRef, + DeclRef<AssocTypeDecl> requirementDeclRef, RefPtr<WitnessTable> witnessTable); /// Registers a type as differentiable in the currrent semantic context, if the declaration represents @@ -1989,6 +1991,8 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); + Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); + /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index ef067c06c..42fab94a6 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1559,7 +1559,7 @@ namespace Slang OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType)); + candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); @@ -1576,7 +1576,6 @@ namespace Slang OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; candidate.funcType = as<FuncType>(processJVPFuncType( - this->getASTBuilder(), as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(item.declRef); @@ -1606,7 +1605,7 @@ namespace Slang auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); // Process func type to generate JVP func type. - auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType)); + auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType)); // Extract parameter list from processed type. List<Type*> paramTypes; @@ -1631,7 +1630,6 @@ namespace Slang // This could potentially be a declRef.substitute(jvpFuncType) // candidate.funcType = as<FuncType>(processJVPFuncType( - this->getASTBuilder(), getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); candidate.resultType = candidate.funcType->getResultType(); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 73818dbb1..3d02d4fc0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -205,7 +205,7 @@ struct DifferentiableTypeConformanceContext { if (as<IRModuleInst>(inst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly three fields. + // Assume for now that IDifferentiable has exactly four fields. SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); @@ -462,45 +462,6 @@ struct DifferentialPairTypeBuilder } } - void _createGenericDiffPairType(IRBuilder* builder) - { - // Insert directly at top level (skip any generic scopes etc.) - auto insertLoc = builder->getInsertLoc(); - builder->setInsertInto(builder->getModule()->getModuleInst()); - - // Make a generic version of the pair struct. - auto irGeneric = builder->emitGeneric(); - irGeneric->setFullType(builder->getTypeKind()); - builder->setInsertInto(irGeneric); - - generatedTypeList.add(irGeneric); - - auto irBlock = builder->emitBlock(); - builder->setInsertInto(irBlock); - - auto pTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT")); - - auto dTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT")); - - auto irStructType = builder->createStructType(); - builder->emitReturn(irStructType); - - auto primalKey = _getOrCreatePrimalStructKey(builder); - builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal")); - builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam); - - auto diffKey = _getOrCreateDiffStructKey(builder); - builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); - builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam); - - // Reset cursor when done. - builder->setInsertLoc(insertLoc); - - this->genericDiffPairType = irGeneric; - } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) { if (!this->globalDiffKey) @@ -535,17 +496,6 @@ struct DifferentialPairTypeBuilder return this->globalPrimalKey; } - IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder) - { - if (!this->genericDiffPairType) - { - _createGenericDiffPairType(builder); - } - - SLANG_ASSERT(this->genericDiffPairType); - return this->genericDiffPairType; - } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) @@ -1383,22 +1333,10 @@ struct JVPTranscriber } else { - // We special case a few non-differentiable types that sometimes appear in places - // where we're forced to provide a differential zero value. For instance, - // float3(float, float, int) is accepted by the compiler, but is tricky in the context - // of differentiation since int is non-differentiable, and should be cast to float first. - // In the absence of such casts, this piece of code generates appropriate zero values. - // - switch (primalType->getOp()) - { - case kIROp_IntType: - return builder->getIntValue(primalType, 0); - default: - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 1d1db14f9..61aa28bbe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0) INST(OptionalType, Optional, 1, 0) INST(DifferentialPairType, DiffPair, 1, 0) + INST(DifferentialBottomType, DiffBottomType, 0, 0) /* BindExistentialsTypeBase */ @@ -277,7 +278,6 @@ INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(GetSequentialID, GetSequentialID, 1, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) - INST(Construct, construct, 0, 0) INST(AllocObj, allocObj, 0, 0) @@ -297,6 +297,7 @@ INST(GetOptionalValue, getOptionalValue, 1, 0) INST(OptionalHasValue, optionalHasValue, 1, 0) INST(MakeOptionalValue, makeOptionalValue, 1, 0) INST(MakeOptionalNone, makeOptionalNone, 1, 0) +INST(DifferentialBottomValue, differentialBottomVal, 0, 0) INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) @@ -759,6 +760,7 @@ INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(IsType, IsType, 3, 0) INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) +INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3a59eb6c9..382f7be5e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3629,6 +3629,7 @@ namespace Slang IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope) { + IRInst* foundResult = nullptr; for (auto child = scope->getFirstChild(); child; child = child->getNextInst()) { if (child->getOp() == kIROp_DifferentiableTypeDictionary) @@ -3640,13 +3641,20 @@ namespace Slang if (irType == entryType) { - return entryConformanceWitness; + foundResult = entryConformanceWitness; + // If the found witness table is not a trivial one (i.e. DifferentialBottom:IDifferential), + // return immediately. Otherwise, continue the search to see if we can find a better one. + if (auto witness = as<IRWitnessTable>(foundResult)) + { + if (witness->getConcreteType()->getOp() != kIROp_DifferentialBottomType) + return foundResult; + } } } } } - return nullptr; + return foundResult; } IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType) @@ -6282,6 +6290,8 @@ namespace Slang case kIROp_MakeOptionalNone: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_ImageSubscript: case kIROp_FieldExtract: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ae0590105..8f00253f5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1366,12 +1366,29 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // produce transitive witnesses in shapes that will cuase us // problems here. // - IRInst* requirementKey = lowerSimpleVal(context, val->midToSup); + IRInst* midToSup = lowerSimpleVal(context, val->midToSup); + + if (!baseWitnessTable) + { + // If we don't have a valid baseWitnessTable, + // we are in a situation that `subToMid` is a `DifferentialBottomSubtypeWitness` + // that applies for all non-differentiable types. + // In this case `midToSup` will give us the `DifferentialBottom:IDifferentiable` + // witness table and we can just use that as the final result of + // this transitive witness. + SLANG_RELEASE_ASSERT(midToSup && as<IRWitnessTableType>(midToSup->getDataType())); + return LoweredValInfo::simple(midToSup); + } return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( getBuilder()->getWitnessTableType(lowerType(context, val->sup)), baseWitnessTable, - requirementKey)); + midToSup)); + } + + LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) + { + return LoweredValInfo(); } LoweredValInfo visitTaggedUnionSubtypeWitness( @@ -3053,6 +3070,15 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) + { + auto baseVal = lowerSubExpr(expr->arrayExpr); + auto type = lowerType(context, expr->arrayExpr->type); + auto arrayType = as<IRArrayType>(type); + SLANG_ASSERT(arrayType); + return LoweredValInfo::simple(arrayType->getElementCount()); + } + LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); @@ -5857,7 +5883,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // add an entry to the context. // if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType)) + { getBuilder()->addDifferentiableTypeEntry(irType, irWitness); + } } else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member)) { @@ -6777,7 +6805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - setValue(subContext, decl, LoweredValInfo::simple(irInterface)); + setValue(context, decl, LoweredValInfo::simple(irInterface)); // Setup subContext for proper lowering `ThisType`, associated types and // the interface decl's self reference. @@ -7084,6 +7112,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) { + ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; SLANG_RELEASE_ASSERT(as<IRStructKey>(key)); auto builder = getBuilder(); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 93f2fcdcb..980a1d0bc 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -132,8 +132,8 @@ namespace Slang Scope* newScope = astBuilder->create<Scope>(); newScope->containerDecl = containerDecl; newScope->parent = currentScope; - currentScope = newScope; + containerDecl->ownedScope = newScope; } void pushScopeAndSetParent(ContainerDecl* containerDecl) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index c779b4510..8cd443438 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -319,7 +319,31 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } } } - + else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness)) + { + // Hard code witness entry that `T.Differential = DifferentialBottom` for `T` that + // coerce to `DifferentialBottom`. + if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup)) + { + if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementAttribute>()) + { + if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType) + { + return RequirementWitness(astBuilder->getDifferentialBottomType()); + } + } + } + } + else if (auto extractFromConjunctionTypeWitness = as<ExtractFromConjunctionSubtypeWitness>(subtypeWitness)) + { + if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->conjunctionWitness)) + { + if (extractFromConjunctionTypeWitness->indexInConjunction == 0) + return tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(conjunctionTypeWitness->leftWitness), requirementKey); + else + return tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(conjunctionTypeWitness->rightWitness), requirementKey); + } + } // TODO: should handle the transitive case here too return RequirementWitness(); @@ -1140,7 +1164,46 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - // + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef) + { + auto substDeclRef = declRef.as<AssocTypeDecl>(); + if (!substDeclRef) + return nullptr; + + auto substAssocTypeDecl = substDeclRef.getDecl(); + + for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + { + auto thisSubst = as<ThisTypeSubstitution>(s); + if (!thisSubst) + continue; + + if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl)) + { + if (thisSubst->interfaceDecl == interfaceDecl) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substAssocTypeDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey); + switch (requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + } + } + } + return nullptr; + } String DeclRefBase::toString() const { diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index ee868be62..b38bd358f 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -233,7 +233,7 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::LetExpr">(Slang::LetExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialValueExpr">(Slang::ExtractExistentialValueExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&astNodeType</ExpandedItem> - <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::JVPDifferentiateExpr">(Slang::JVPDifferentiateExpr*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&astNodeType</ExpandedItem> @@ -384,6 +384,7 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&astNodeType</ExpandedItem> @@ -456,7 +457,7 @@ </Expand> </Type> - <Type Name="Slang::Substitutions"> + <Type Name="Slang::Substitutions" Inheritable="false"> <DisplayString>{astNodeType}</DisplayString> <Expand> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">(Slang::GenericSubstitution*)&astNodeType</ExpandedItem> @@ -469,7 +470,7 @@ <LinkedListItems> <HeadPointer>substitutions</HeadPointer> <NextPointer>outer</NextPointer> - <ValueNode>this</ValueNode> + <ValueNode>(Slang::Substitutions*)this</ValueNode> </LinkedListItems> </Expand> </Type> @@ -487,6 +488,79 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">(Slang::GenericParamIntVal*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness">(Slang::DeclaredSubtypeWitness*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">(Slang::TransitiveSubtypeWitness*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">(Slang::OverloadGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InitializerListType">(Slang::InitializerListType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::MatrixExpressionType">(Slang::MatrixExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BuiltinType">(Slang::BuiltinType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FeedbackType">(Slang::FeedbackType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ResourceType">(Slang::ResourceType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureTypeBase">(Slang::TextureTypeBase*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureType">(Slang::TextureType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureSamplerType">(Slang::TextureSamplerType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLImageType">(Slang::GLSLImageType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SamplerStateType">(Slang::SamplerStateType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BuiltinGenericType">(Slang::BuiltinGenericType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerLikeType">(Slang::PointerLikeType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParameterGroupType">(Slang::ParameterGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::UniformParameterGroupType">(Slang::UniformParameterGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ConstantBufferType">(Slang::ConstantBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureBufferType">(Slang::TextureBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLShaderStorageBufferType">(Slang::GLSLShaderStorageBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParameterBlockType">(Slang::ParameterBlockType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VaryingParameterGroupType">(Slang::VaryingParameterGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLInputParameterGroupType">(Slang::GLSLInputParameterGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLOutputParameterGroupType">(Slang::GLSLOutputParameterGroupType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferTypeBase">(Slang::HLSLStructuredBufferTypeBase*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferType">(Slang::HLSLStructuredBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRWStructuredBufferType">(Slang::HLSLRWStructuredBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedStructuredBufferType">(Slang::HLSLRasterizerOrderedStructuredBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLAppendStructuredBufferType">(Slang::HLSLAppendStructuredBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLConsumeStructuredBufferType">(Slang::HLSLConsumeStructuredBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStreamOutputType">(Slang::HLSLStreamOutputType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLPointStreamType">(Slang::HLSLPointStreamType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLLineStreamType">(Slang::HLSLLineStreamType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLTriangleStreamType">(Slang::HLSLTriangleStreamType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::UntypedBufferResourceType">(Slang::UntypedBufferResourceType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLByteAddressBufferType">(Slang::HLSLByteAddressBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRWByteAddressBufferType">(Slang::HLSLRWByteAddressBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedByteAddressBufferType">(Slang::HLSLRasterizerOrderedByteAddressBufferType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::RaytracingAccelerationStructureType">(Slang::RaytracingAccelerationStructureType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLPatchType">(Slang::HLSLPatchType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLInputPatchType">(Slang::HLSLInputPatchType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLOutputPatchType">(Slang::HLSLOutputPatchType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLInputAttachmentType">(Slang::GLSLInputAttachmentType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StringTypeBase">(Slang::StringTypeBase*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StringType">(Slang::StringType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NativeStringType">(Slang::NativeStringType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DynamicType">(Slang::DynamicType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EnumTypeType">(Slang::EnumTypeType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PtrTypeBase">(Slang::PtrTypeBase*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PtrType">(Slang::PtrType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParamDirectionType">(Slang::ParamDirectionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OutTypeBase">(Slang::OutTypeBase*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OutType">(Slang::OutType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InOutType">(Slang::InOutType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::RefType">(Slang::RefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NullPtrType">(Slang::NullPtrType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">(Slang::ArrayExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TypeType">(Slang::TypeType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">(Slang::NamedExpressionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&astNodeType</ExpandedItem> </Expand> </Type> </AutoVisualizer>
\ No newline at end of file |
