From cbc1eff56057f199183bb7c17d8a360326512367 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 1 Nov 2022 08:46:57 -0700 Subject: Make `DifferentialPair` able to nest. (#2477) --- source/core/slang-list.h | 2 +- source/slang/core.meta.slang | 2 +- source/slang/diff.meta.slang | 63 ++++- source/slang/slang-ast-base.h | 8 + source/slang/slang-ast-builder.cpp | 19 +- source/slang/slang-ast-builder.h | 9 +- source/slang/slang-ast-decl.h | 3 + source/slang/slang-ast-expr.h | 6 + source/slang/slang-ast-modifier.h | 4 +- source/slang/slang-ast-support-types.h | 12 + source/slang/slang-ast-synthesis.cpp | 175 ++++++++++++ source/slang/slang-ast-synthesis.h | 147 +++++++++++ source/slang/slang-ast-type.cpp | 57 ++-- source/slang/slang-ast-type.h | 14 + source/slang/slang-ast-val.cpp | 35 +++ source/slang/slang-ast-val.h | 18 ++ source/slang/slang-check-conformance.cpp | 42 ++- source/slang/slang-check-conversion.cpp | 3 - source/slang/slang-check-decl.cpp | 438 ++++++++++++++++++++++--------- source/slang/slang-check-expr.cpp | 114 ++++++-- source/slang/slang-check-impl.h | 20 +- source/slang/slang-check-overload.cpp | 6 +- source/slang/slang-ir-diff-jvp.cpp | 72 +---- source/slang/slang-ir-inst-defs.h | 4 +- source/slang/slang-ir.cpp | 14 +- source/slang/slang-lower-to-ir.cpp | 35 ++- source/slang/slang-parser.cpp | 2 +- source/slang/slang-syntax.cpp | 67 ++++- source/slang/slang.natvis | 80 +++++- 29 files changed, 1176 insertions(+), 295 deletions(-) create mode 100644 source/slang/slang-ast-synthesis.cpp create mode 100644 source/slang/slang-ast-synthesis.h (limited to 'source') diff --git a/source/core/slang-list.h b/source/core/slang-list.h index 25687d129..187ff3109 100644 --- a/source/core/slang-list.h +++ b/source/core/slang-list.h @@ -269,7 +269,7 @@ namespace Slang void insertRange(Index id, const List& list) { insertRange(id, list.m_buffer, list.m_count); } - void addRange(ArrayView list) { insertRange(m_count, list.getBuffer(), list.Count()); } + void addRange(ArrayView list) { insertRange(m_count, list.getBuffer(), list.getCount()); } void addRange(const T* vals, Index n) { insertRange(m_count, vals, n); } diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 75bc65562..5df9d01fe 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2434,7 +2434,7 @@ __generic [__unsafeForceInlineEarly] bool operator >=(T v0, T v1) { - return v1.lessThanOrEquals(v1); + return v1.lessThan(v1); } __generic [__unsafeForceInlineEarly] diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ad3dfe77c..674531048 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -128,13 +128,37 @@ extension vector : IDifferentiable } } +__magic_type(DifferentialBottomType) +__intrinsic_type($(kIROp_DifferentialBottomType)) +struct __DifferentialBottom : IDifferentiable +{ + typedef __DifferentialBottom Differential; + + __intrinsic_op($(kIROp_DifferentialBottomValue)) + static __DifferentialBottom dzero(); + + [__unsafeForceInlineEarly] + static __DifferentialBottom dadd(Differential a, Differential b) + { + return dzero(); + } + + [__unsafeForceInlineEarly] + static __DifferentialBottom dmul(This a, Differential b) + { + return dzero(); + } +} + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic __magic_type(DifferentialPairType) __intrinsic_type($(kIROp_DifferentialPairType)) -struct __DifferentialPair +struct DifferentialPair : IDifferentiable { + typedef DifferentialPair Differential; + typedef T.Differential DifferentialElementType; __intrinsic_op($(kIROp_MakeDifferentialPair)) __init(T _primal, T.Differential _differential); @@ -154,6 +178,31 @@ struct __DifferentialPair { return p(); } + + [__unsafeForceInlineEarly] + static Differential dzero() + { + return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dadd(Differential a, Differential b) + { + return Differential( + T.dadd( + a.p(), + b.p() + ), + Differential.DifferentialElementType.dzero()); + } + + [__unsafeForceInlineEarly] + static Differential dmul(This a, Differential b) + { + return Differential( + T.dmul(a.p(), b.p()), + Differential.DifferentialElementType.dzero()); + } }; typealias IDFloat = IFloat & IDifferentiable; @@ -171,9 +220,9 @@ namespace dstd T exp(T x); __generic - __DifferentialPair d_exp(__DifferentialPair dpx) + DifferentialPair d_exp(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( exp(dpx.p()), T.dmul(exp(dpx.p()), dpx.d())); } @@ -189,9 +238,9 @@ namespace dstd T sin(T x); __generic - __DifferentialPair d_sin(__DifferentialPair dpx) + DifferentialPair d_sin(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( sin(dpx.p()), T.dmul(cos(dpx.p()), dpx.d())); } @@ -207,9 +256,9 @@ namespace dstd T cos(T x); __generic - __DifferentialPair d_cos(__DifferentialPair dpx) + DifferentialPair d_cos(DifferentialPair dpx) { - return __DifferentialPair( + return DifferentialPair( cos(dpx.p()), T.dmul(-sin(dpx.p()), dpx.d())); } diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index dea02afbb..627a56152 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -254,7 +254,9 @@ private: // The actual values of the arguments List args; public: + List& getArgs() { return args; } const List& getArgs() const { return args; } + // Overrides should be public so base classes can access Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff); bool _equalsOverride(Substitutions* subst); @@ -265,6 +267,12 @@ public: genericDecl = decl; } + GenericSubstitution(GenericDecl* decl, ArrayView argVals) + { + genericDecl = decl; + args.addRange(argVals); + } + template GenericSubstitution(GenericDecl* decl, TArgs... inArgs) { diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index f6c550d69..beee16f9c 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -141,6 +141,16 @@ Type* SharedASTBuilder::getNoneType() return m_noneType; } +Type* SharedASTBuilder::getDifferentialBottomType() +{ + if (!m_diffBottomType) + { + auto diffBottomTypeDecl = findMagicDecl("DifferentialBottomType"); + m_diffBottomType = DeclRefType::create(m_astBuilder, makeDeclRef(diffBottomTypeDecl)); + } + return m_diffBottomType; +} + SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. @@ -299,13 +309,18 @@ VectorExpressionType* ASTBuilder::getVectorType( return result; } -DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness) +DifferentialPairType* ASTBuilder::getDifferentialPairType( + Type* valueType, + Witness* primalIsDifferentialWitness) { auto genericDecl = dynamicCast(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); auto typeDecl = genericDecl->inner; - auto substitutions = getOrCreate(genericDecl, valueType, conformanceWitness); + auto substitutions = getOrCreate( + genericDecl, + valueType, + primalIsDifferentialWitness); auto declRef = DeclRef(typeDecl, substitutions); auto rsType = DeclRefType::create(this, declRef); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index e4ea872a0..235bebfaa 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -35,6 +35,8 @@ public: Type* getNullPtrType(); /// Get the NullPtr type Type* getNoneType(); + /// Get the DifferentialBottom type. + Type* getDifferentialBottomType(); const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass findSyntaxClass(Name* name); @@ -79,7 +81,7 @@ protected: Type* m_dynamicType = nullptr; Type* m_nullPtrType = nullptr; Type* m_noneType = nullptr; - + Type* m_diffBottomType = nullptr; Type* m_builtinTypes[Index(BaseType::CountOf)]; Dictionary m_magicDecls; @@ -297,6 +299,7 @@ public: Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; } Type* getErrorType() { return m_sharedASTBuilder->m_errorType; } Type* getBottomType() { return m_sharedASTBuilder->m_bottomType; } + Type* getDifferentialBottomType() { return m_sharedASTBuilder->getDifferentialBottomType(); } Type* getStringType() { return m_sharedASTBuilder->getStringType(); } Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); } Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); } @@ -326,7 +329,9 @@ public: VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); - DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* conformanceWitness); + DifferentialPairType* getDifferentialPairType( + Type* valueType, + Witness* primalIsDifferentialWitness); DeclRef getDifferentiableInterface(); diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 87d696927..90175dd9d 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -29,6 +29,9 @@ class ContainerDecl: public Decl List members; SourceLoc closingSourceLoc; + // The associated scope owned by this decl. + Scope* ownedScope = nullptr; + template FilteredMemberList getMembersOfType() { diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index f2a72703e..ef6a05c71 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -126,6 +126,12 @@ class InitializerListExpr : public Expr List args; }; +class GetArrayLengthExpr : public Expr +{ + SLANG_AST_CLASS(GetArrayLengthExpr) + Expr* arrayExpr = nullptr; +}; + // A base class for expressions with arguments class ExprWithArgsBase : public Expr { diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 76106074f..0c1eb8d49 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1016,14 +1016,14 @@ class RequiresNVAPIAttribute : public Attribute }; /// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated. -class ForwardDifferentiableAttribute : public Attribute +class ForwardDifferentiableAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDifferentiableAttribute) }; /// The `[ForwardDerivative(function)]` attribute specifies a custom function that should /// be used as the derivative for the decorated function. -class ForwardDerivativeAttribute : public Attribute +class ForwardDerivativeAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDerivativeAttribute) diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 4c92810d9..d4a781846 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -384,6 +384,12 @@ namespace Slang /// ReadyForConformances, + /// Any DeclRefTypes with substitutions have been fully resolved + /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`. + /// We need a separate pass to resolve these types because `A.X` + /// maybe synthesized and made available only after conformance checking. + TypesFullyResolved, + /// The declaration is fully checked. /// /// This step includes any validation of the declaration that is @@ -779,6 +785,12 @@ namespace Slang void toText(StringBuilder& out) const; }; + // If this is a declref to an associatedtype with a ThisTypeSubsitution, + // try to find the concrete decl that satisfies the associatedtype requirement from the + // concrete type supplied by ThisTypeSubstittution. + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef declRef); + + template struct DeclRef : DeclRefBase { diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp new file mode 100644 index 000000000..e5a9ff75f --- /dev/null +++ b/source/slang/slang-ast-synthesis.cpp @@ -0,0 +1,175 @@ +#include "slang-ast-synthesis.h" + +namespace Slang +{ +Expr* ASTSynthesizer::emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right) +{ + auto infixExpr = m_builder->create(); + infixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + infixExpr->arguments.add(left); + infixExpr->arguments.add(right); + return infixExpr; +} + +Expr* ASTSynthesizer::emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base) +{ + auto prefixExpr = m_builder->create(); + prefixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + prefixExpr->arguments.add(base); + return prefixExpr; +} + +Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base) +{ + auto postfixExpr = m_builder->create(); + postfixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));; + postfixExpr->arguments.add(base); + return postfixExpr; +} + +ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar) +{ + auto scopeDecl = pushVarScope()->containerDecl; + auto stmt = m_builder->create(); + stmt->scopeDecl = (ScopeDecl*)scopeDecl; + auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal); + stmt->initialStatement = declStmt; + outIndexVar = (VarDecl*)declStmt->decl; + auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal); + stmt->predicateExpression = predicateExpr; + stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar)); + getCurrentScope().m_parentSeqStmt->stmts.add(stmt); + return stmt; +} + +Expr* ASTSynthesizer::emitVarExpr(Name* name) +{ + auto scope = getCurrentScope(); + SLANG_RELEASE_ASSERT(scope.m_scope); + auto varExpr = m_builder->create(); + varExpr->name = name; + varExpr->scope = scope.m_scope; + return varExpr; +} + +Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl) +{ + auto varExpr = m_builder->create(); + varExpr->declRef = makeDeclRef(varDecl); + varExpr->type = varDecl->type.type; + return varExpr; +} + +Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type) +{ + auto expr = m_builder->create(); + expr->declRef = DeclRef(var, nullptr); + expr->type.type = type; + expr->type.isLeftValue = true; + return expr; +} + +Expr* ASTSynthesizer::emitVarExpr(DeclStmt* varStmt, Type* type) +{ + auto expr = m_builder->create(); + expr->declRef = DeclRef(as(varStmt->decl), nullptr); + expr->type.type = type; + expr->type.isLeftValue = true; + return expr; +} + +Expr* ASTSynthesizer::emitIntConst(int value) +{ + auto expr = m_builder->create(); + expr->type.type = m_builder->getIntType(); + expr->value = value; + return expr; +} + +Expr* ASTSynthesizer::emitGetArrayLengthExpr(Expr* arrayExpr) +{ + auto expr = m_builder->create(); + expr->arrayExpr = arrayExpr; + expr->type = m_builder->getIntType(); + return expr; +} + +Expr* ASTSynthesizer::emitMemberExpr(Expr* arrayExpr, Name* name) +{ + auto rs = m_builder->create(); + rs->baseExpression = arrayExpr; + rs->name = name; + return rs; +} + +Expr* ASTSynthesizer::emitAssignExpr(Expr* left, Expr* right) +{ + auto rs = m_builder->create(); + rs->left = left; + rs->right = right; + return rs; +} + +Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List&& args) +{ + auto rs = m_builder->create(); + rs->functionExpr = callee; + rs->arguments = _Move(args); + return rs; +} + +Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name) +{ + auto rs = m_builder->create(); + auto typeExpr = m_builder->create(); + auto typetype = m_builder->create(); + typetype->type = type; + typeExpr->type = typetype; + rs->baseExpression = typeExpr; + rs->name = name; + return rs; +} + +Expr* ASTSynthesizer::emitIndexExpr(Expr* base, Expr* index) +{ + auto rs = m_builder->create(); + rs->baseExpression = base; + rs->indexExprs.add(index); + return rs; +} + +ExpressionStmt* ASTSynthesizer::emitExprStmt(Expr* expr) +{ + auto rs = m_builder->create(); + _addStmtToScope(rs); + rs->expression = expr; + return rs; +} + +ReturnStmt* ASTSynthesizer::emitReturnStmt(Expr* expr) +{ + auto rs = m_builder->create(); + rs->expression = expr; + _addStmtToScope(rs); + return rs; +} + +DeclStmt* ASTSynthesizer::emitVarDeclStmt(Type* type, Name* name, Expr* initVal) +{ + auto scope = getCurrentScope(); + SLANG_RELEASE_ASSERT(scope.m_parentSeqStmt); + SLANG_RELEASE_ASSERT(scope.m_scope); + SLANG_RELEASE_ASSERT(scope.m_scope->containerDecl); + auto varDecl = m_builder->create(); + varDecl->type.type = type; + varDecl->nameAndLoc.name = name; + varDecl->initExpr = initVal; + varDecl->parentDecl = scope.m_scope->containerDecl; + varDecl->parentDecl->members.add(varDecl); + auto stmt = m_builder->create(); + stmt->decl = varDecl; + _addStmtToScope(stmt); + return stmt; +} + +} diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h new file mode 100644 index 000000000..2af890d34 --- /dev/null +++ b/source/slang/slang-ast-synthesis.h @@ -0,0 +1,147 @@ +// slang-ast-synthesis.h + +#pragma once + +#include "slang-syntax.h" + +namespace Slang { + +struct ASTEmitScope +{ + ContainerDecl* m_parent = nullptr; + SeqStmt* m_parentSeqStmt = nullptr; + Scope* m_scope = nullptr; +}; +class ASTSynthesizer +{ +private: + ASTBuilder* m_builder; + NamePool* m_namePool; + List m_scopeStack; +public: + ASTSynthesizer(ASTBuilder* builder, NamePool* namePool) + : m_builder(builder) + , m_namePool(namePool) + { + } + + Scope* getScope(ContainerDecl* decl) + { + for (auto container = decl; container; container = container->parentDecl) + { + if (container->ownedScope) + { + return container->ownedScope; + } + } + return nullptr; + } + + // Create a scope for `decl` and push it to scope stack + void pushScopeForContainer(ContainerDecl* decl) + { + if (decl->ownedScope) + { + // if decl already owns a scope, don't create a new one. + pushContainerScope(decl); + return; + } + + auto parentScope = getScope(decl); + decl->ownedScope = m_builder->create(); + decl->ownedScope->parent = parentScope; + pushContainerScope(decl); + } + + // Push `decl` and its associated scope to scope stack + void pushContainerScope(ContainerDecl* decl) + { + ASTEmitScope scope = getCurrentScope(); + scope.m_parent = decl; + scope.m_scope = getScope(decl); + m_scopeStack.add(scope); + } + + Scope* pushVarScope() + { + ASTEmitScope scope = getCurrentScope(); + auto scopeDecl = m_builder->create(); + auto newScope = m_builder->create(); + scopeDecl->parentDecl = scope.m_parent; + if (scope.m_parent) + scope.m_parent->members.add(scopeDecl); + newScope->parent = scope.m_scope; + newScope->containerDecl = scopeDecl; + scope.m_scope = newScope; + m_scopeStack.add(scope); + return newScope; + } + + void _addStmtToScope(Stmt* stmt) + { + auto scope = getCurrentScope(); + if (scope.m_parentSeqStmt) + { + scope.m_parentSeqStmt->stmts.add(stmt); + } + } + + SeqStmt* pushSeqStmtScope() + { + ASTEmitScope scope = getCurrentScope(); + scope.m_parentSeqStmt = m_builder->create(); + m_scopeStack.add(scope); + return scope.m_parentSeqStmt; + } + + void popScope() + { + m_scopeStack.removeLast(); + } + + ASTEmitScope getCurrentScope() + { + if (m_scopeStack.getCount()) + return m_scopeStack.getLast(); + return ASTEmitScope(); + } + + ForStmt* emitFor(Expr* initVal, Expr* finalVal, VarDecl*& outIndexVar); + + Expr* emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right); + + Expr* emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base); + + Expr* emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base); + + Expr* emitVarExpr(Name* name); + Expr* emitVarExpr(VarDecl* var); + Expr* emitVarExpr(VarDecl* var, Type* type); + Expr* emitVarExpr(DeclStmt* varStmt, Type* type); + + Expr* emitIntConst(int value); + + Expr* emitGetArrayLengthExpr(Expr* arrayExpr); + + Expr* emitMemberExpr(Expr* base, Name* name); + Expr* emitMemberExpr(Type* base, Name* name); + + Expr* emitIndexExpr(Expr* base, Expr* index); + + Expr* emitAssignExpr(Expr* left, Expr* right); + ExpressionStmt* emitAssignStmt(Expr* left, Expr* right) + { + return emitExprStmt(emitAssignExpr(left, right)); + } + + Expr* emitInvokeExpr(Expr* callee, List&& args); + + DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr); + + ExpressionStmt* emitExprStmt(Expr* expr); + + ReturnStmt* emitReturnStmt(Expr* expr); + +}; + +} // namespace Slang diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 480589af4..a869c95a7 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -169,6 +169,27 @@ Val* BottomType::_substituteImplOverride( HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DifferentialBottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +void DifferentialBottomType::_toTextOverride(StringBuilder& out) { out << toSlice("diff_bottom"); } + +bool DifferentialBottomType::_equalsImplOverride(Type* type) +{ + if (auto bottomType = as(type)) + return true; + return false; +} + +Type* DifferentialBottomType::_createCanonicalTypeOverride() { return this; } + +Val* DifferentialBottomType::_substituteImplOverride( + ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/) +{ + return this; +} + +HashCode DifferentialBottomType::_getHashCodeOverride() { return HashCode(size_t(this)); } + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void DeclRefType::_toTextOverride(StringBuilder& out) @@ -193,6 +214,7 @@ bool DeclRefType::_equalsImplOverride(Type * type) Type* DeclRefType::_createCanonicalTypeOverride() { // A declaration reference is already canonical + declRef.substitute(this->getASTBuilder(), this); return this; } @@ -223,39 +245,8 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe // the outer interface, then try to replace the type with the // actual value of the associated type for the given implementation. // - if (auto substAssocTypeDecl = as(substDeclRef.decl)) - { - for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) - { - auto thisSubst = as(s); - if (!thisSubst) - continue; - - if (auto interfaceDecl = as(substAssocTypeDecl->parentDecl)) - { - if (thisSubst->interfaceDecl == interfaceDecl) - { - // We need to look up the declaration that satisfies - // the requirement named by the associated type. - Decl* requirementKey = substAssocTypeDecl; - RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey); - switch (requirementWitness.getFlavor()) - { - default: - // No usable value was found, so there is nothing we can do. - break; - - case RequirementWitness::Flavor::val: - { - auto satisfyingVal = requirementWitness.getVal(); - return satisfyingVal; - } - break; - } - } - } - } - } + if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef)) + return satisfyingVal; // Re-construct the type in case we are using a specialized sub-class return DeclRefType::create(astBuilder, substDeclRef); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 895b64f35..c7ce21cb0 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -86,6 +86,20 @@ protected: {} }; +// The bottom/empty type as a result of Differentiating a Differential. +class DifferentialBottomType : public DeclRefType +{ + SLANG_AST_CLASS(DifferentialBottomType) + + // Overrides should be public so base classes can access + void _toTextOverride(StringBuilder& out); + Type* _createCanonicalTypeOverride(); + bool _equalsImplOverride(Type* type); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + + // Base class for types that can be used in arithmetic expressions class ArithmeticExpressionType : public DeclRefType { diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index a8ceaa716..87e89ef18 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -618,6 +618,41 @@ Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, return substWitness; } +bool DifferentialBottomSubtypeWitness::_equalsValOverride(Val* val) +{ + auto otherDiffBottomWitness = as(val); + if (!otherDiffBottomWitness) + return false; + + return otherDiffBottomWitness->sub && otherDiffBottomWitness->sub->equals(sub); +} + +void DifferentialBottomSubtypeWitness::_toTextOverride(StringBuilder& out) +{ + out << "DifferentialBottomSubtypeWitness(" << sub << ")"; +} + +HashCode DifferentialBottomSubtypeWitness::_getHashCodeOverride() +{ + return combineHash(3892, sub->getHashCode()); +} + +Val* DifferentialBottomSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + + auto substSub = as(sub->substituteImpl(astBuilder, subst, &diff)); + auto substSup = as(sup->substituteImpl(astBuilder, subst, &diff)); + if (!diff) + return this; + + *ioDiff += diff; + + DifferentialBottomSubtypeWitness* substWitness = + astBuilder->create(substSub, substSup); + return substWitness; +} + bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val) { if (auto other = as(val)) diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 5f72e58c8..b52984f8b 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -399,6 +399,24 @@ class DynamicSubtypeWitness : public SubtypeWitness SLANG_AST_CLASS(DynamicSubtypeWitness) }; + /// A witness of the fact that any type can be viewed as a subtype of DifferentialBottom. +class DifferentialBottomSubtypeWitness : public SubtypeWitness +{ + SLANG_AST_CLASS(DifferentialBottomSubtypeWitness) + + DifferentialBottomSubtypeWitness(Type* inSub, Type* inSup) + { + sub = inSub; + sup = inSup; + } + + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + /// A witness that `T : L & R` because `T : L` and `T : R` class ConjunctionSubtypeWitness : public SubtypeWitness { diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 2c9977082..eb072e9dd 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -33,7 +33,6 @@ namespace Slang else return conjunction->rightWitness; } - ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create(); simplExtractFromConjunction->sub = extractFromConjunction->sub; simplExtractFromConjunction->sup = extractFromConjunction->sup; @@ -145,7 +144,7 @@ namespace Slang m_astBuilder->getOrCreate( bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions); - TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor(subType, bb->sup, declaredWitness); + TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor(); transitiveWitness->sub = subType; transitiveWitness->sup = bb->sup; transitiveWitness->midToSup = declaredWitness; @@ -379,6 +378,45 @@ namespace Slang } } } + + // If a generic type parameter does not declare itself to conform to `IDifferentiable`, + // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`. + // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but + // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`. + if (m_astBuilder->isDifferentiableInterfaceAvailable() && + subType == originalSubType && + superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface()) + { + if (as(declRefType->declRef.getDecl()) || + as(declRefType->declRef.getDecl())) + { + auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef); + auto differentialBottomType = as(m_astBuilder->getDifferentialBottomType()); + auto container = differentialBottomType->declRef.as().getDecl(); + SLANG_RELEASE_ASSERT(container); + auto inheritanceDecl = container->getMembersOfType().getFirst(); + auto witnessDifferentialBottomIsIDifferentiable = + m_astBuilder->getOrCreate( + m_astBuilder->getDifferentialBottomType(), + sup, + inheritanceDecl, + nullptr); + + auto witnessSubIsDifferentialBottom = + m_astBuilder->getOrCreate( + subType, differentialBottomType); + + TransitiveSubtypeWitness* transitiveWitness = + m_astBuilder->getOrCreateWithDefaultCtor( + witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable); + transitiveWitness->sub = subType; + transitiveWitness->sup = sup; + transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable; + transitiveWitness->subToMid = witnessSubIsDifferentialBottom; + *outWitness = transitiveWitness; + return true; + } + } } else if (auto extractExistentialType = as(subType)) { diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index e56d63f91..5e84e170b 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1169,9 +1169,6 @@ namespace Slang fromExpr); } - // If we coerced to a differentiable type, log it. - maybeRegisterDifferentiableType(m_astBuilder, expr->type); - return expr; } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 457ae229b..f60fbcc2c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -11,9 +11,8 @@ // and when things get checked. #include "slang-lookup.h" - #include "slang-syntax.h" - +#include "slang-ast-synthesis.h" #include namespace Slang @@ -166,6 +165,65 @@ namespace Slang void visitExtensionDecl(ExtensionDecl* decl); }; + struct SemanticsDeclTypeResolutionVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor + { + SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + Val* resolveVal(Val* val); + Type* resolveType(Type* type) + { + return (Type*)resolveVal(type); + } + + void visitTypeExp(TypeExp& exp) + { + exp.type = resolveType(exp.type); + } + + void visitVarDeclBase(VarDeclBase* varDecl) + { + visitTypeExp(varDecl->type); + } + + void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl) + { + visitTypeExp(decl->sup); + } + + void visitTypeDefDecl(TypeDefDecl* decl) + { + visitTypeExp(decl->type); + } + + void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl) + { + visitTypeExp(paramDecl->initType); + } + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + visitTypeExp(inheritanceDecl->base); + } + + void visitCallableDecl(CallableDecl* decl) + { + visitTypeExp(decl->returnType); + visitTypeExp(decl->errorType); + } + + void visitPropertyDecl(PropertyDecl* decl) + { + visitTypeExp(decl->type); + } + }; + struct SemanticsDeclBodyVisitor : public SemanticsDeclVisitorBase , public DeclVisitor @@ -1363,27 +1421,30 @@ namespace Slang bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef requirementDeclRef, + DeclRef requirementDeclRef, RefPtr witnessTable) { - // We currently can't handle generic types. - if (GetOuterGeneric(context->parentDecl) != nullptr) - { - return false; - } - + ASTSynthesizer synth(m_astBuilder, getNamePool()); Decl* existingDecl = nullptr; AggTypeDecl* aggTypeDecl = nullptr; if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl)) { - aggTypeDecl = as(existingDecl); - SLANG_RELEASE_ASSERT(aggTypeDecl); - // Remove the `ToBeSynthesizedModifier`. - if (as(aggTypeDecl->modifiers.first)) + if (as(existingDecl->modifiers.first)) { - aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next; + existingDecl->modifiers.first = existingDecl->modifiers.first->next; } + else + { + // The user has defined an associatedtype explicitly but that we reach here because + // that type failed to satisfy the `IDifferential` requirement. + // We stop the synthesis and let the follow-up logic to report a diagnostic. + return false; + } + + aggTypeDecl = as(existingDecl); + SLANG_RELEASE_ASSERT(aggTypeDecl); + synth.pushContainerScope(aggTypeDecl); } else { @@ -1393,15 +1454,12 @@ namespace Slang aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName(); aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc; context->parentDecl->invalidateMemberDictionary(); + synth.pushScopeForContainer(aggTypeDecl); } - // TODO: if we want to make the synthesized type itself to be differentiable, - // add an inheritance decl here. Need to be careful to avoid infinite recursion - // trying to synthesize the higher order differential types. - // Helper function to add a `diffType` field into the synthesized type for the original // `member`. - auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc); + auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); auto addDiffMember = [&](Decl* member, Type* diffMemberType) { // If the field is differentiable, add a corresponding field in the associated Differential type. @@ -1452,12 +1510,35 @@ namespace Slang addDiffMember(member, diffType); } - // In the future when the Differential type itself needs to conform to some interface, - // this is the place to synthesize requirements for them. addModifier(aggTypeDecl, m_astBuilder->create()); - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr); - witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); - return true; + + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution + // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + + if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) + { + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType)); + return true; + } + + // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed. + // If not, there is something wrong with the code synthesis logic. For now we just return false + // instead of crashing so the user can work around the issues. + return false; } void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*) @@ -2242,22 +2323,8 @@ namespace Slang witnessTable); } - bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( - Type* satisfyingType, - DeclRef requiredAssociatedTypeDeclRef, - RefPtr witnessTable) + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef requiredAssociatedTypeDeclRef, RefPtr witnessTable) { - if (auto declRefType = as(satisfyingType)) - { - // If we are seeing a placeholder that awaits synthesis, return false now to trigger - // auto synthesis. - if (declRefType->declRef.getDecl()->hasModifier()) - return false; - } - // We need to confirm that the chosen type `satisfyingType`, - // meets all the constraints placed on the associated type - // requirement `requiredAssociatedTypeDeclRef`. - // // We will enumerate the type constraints placed on the // associated type and see if they can be satisfied. // @@ -2269,7 +2336,7 @@ namespace Slang // Perform a search for a witness to the subtype relationship. auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); - if(witness) + if (witness) { // If a subtype witness was found, then the conformance // appears to hold, and we can satisfy that requirement. @@ -2282,6 +2349,30 @@ namespace Slang conformance = false; } } + return conformance; + } + + bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable) + { + if (auto declRefType = as(satisfyingType)) + { + // If we are seeing a placeholder that awaits synthesis, return false now to trigger + // auto synthesis. + if (declRefType->declRef.getDecl()->hasModifier()) + return false; + } + // We need to confirm that the chosen type `satisfyingType`, + // meets all the constraints placed on the associated type + // requirement `requiredAssociatedTypeDeclRef`. + // + // We will enumerate the type constraints placed on the + // associated type and see if they can be satisfied. + // + bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement( + satisfyingType, requiredAssociatedTypeDeclRef, witnessTable); // TODO: if any conformance check failed, we should probably include // that in an error message produced about not satisfying the requirement. @@ -3122,12 +3213,43 @@ namespace Slang return false; } + Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List&& args, int nestingLevel = 0) + { + if (nestingLevel > 16) + return nullptr; + + // If field type is an array, assign each element individually. + if (auto arrayType = as(leftType)) + { + VarDecl* indexVar = nullptr; + auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar); + auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar)); + for (auto& arg : args) + { + arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar)); + } + auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1); + synth.popScope(); + if (!assignStmt) + return nullptr; + forStmt->statement = assignStmt; + return forStmt; + } + + auto callee = synth.emitMemberExpr(leftType, funcName); + return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); + } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef requirementDeclRef, RefPtr witnessTable) { - // This method implements a general code synthesis pattern. + // We support two cases of synthesis here. + // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`. + // In this case we just trivially return `DifferentialBottom` in all synthesized methods. + // Case 2 is that the `Differential` type contains members corresponding to each primal member. + // We will apply a general code synthesis pattern to reflect that structure. // For requirement of the form: // ``` // static TResult requiredMethod(TParam1 p0, TParam2 p1, ...) @@ -3145,104 +3267,123 @@ namespace Slang // return result; // } // ``` + + // First we need to make sure the associated `Differential` type requirement is satisfied. + bool hasDifferentialAssocType = false; + for (auto existingEntry : witnessTable->requirementList) + { + if (auto builtinReqAttr = existingEntry.Key->findModifier()) + { + if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && + existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none) + { + hasDifferentialAssocType = true; + } + } + } + if (!hasDifferentialAssocType) + return false; + + ASTSynthesizer synth(m_astBuilder, getNamePool()); List synArgs; ThisExpr* synThis = nullptr; auto synFunc = synthesizeMethodSignatureForRequirementWitness( context, requirementDeclRef.as(), synArgs, synThis); - + synFunc->parentDecl = context->parentDecl; + synth.pushContainerScope(synFunc); auto blockStmt = m_astBuilder->create(); synFunc->body = blockStmt; - auto seqStmt = m_astBuilder->create(); + auto seqStmt = synth.pushSeqStmtScope(); blockStmt->body = seqStmt; - // Create a variable for return value. - auto scopeDecl = m_astBuilder->create(); - synFunc->members.add(scopeDecl); - scopeDecl->parentDecl = synFunc; - auto varStmt = m_astBuilder->create(); - seqStmt->stmts.add(varStmt); - - auto returnVar = m_astBuilder->create(); - returnVar->parentDecl = scopeDecl; - scopeDecl->members.add(returnVar); - - returnVar->type.type = synFunc->returnType.type; - returnVar->nameAndLoc.name = getName("result"); - varStmt->decl = returnVar; - auto resultVarExpr = m_astBuilder->create(); - resultVarExpr->declRef = makeDeclRef(returnVar); - resultVarExpr->type.type = synFunc->returnType.type; - resultVarExpr->type.isLeftValue = true; - - for (auto member : context->parentDecl->members) - { - auto derivativeAttr = member->findModifier(); - if (!derivativeAttr) - continue; - auto varMember = as(member); - if (!varMember) - continue; - ensureDecl(varMember, DeclCheckState::ReadyForReference); - auto memberType = varMember->getType(); - auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); - if (!diffMemberType) - continue; + if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType())) + { + // Trivial case, the `Differential` type is `DifferentialBottom`. + // We will just return `DifferentialBottom.dzero()`. + auto resultExpr = m_astBuilder->create(); + auto dzeroMember = m_astBuilder->create(); + auto base = m_astBuilder->create(); + auto typetype = m_astBuilder->create(); + typetype->type = m_astBuilder->getDifferentialBottomType(); + base->type.type = typetype; + dzeroMember->baseExpression = base; + dzeroMember->name = getName("dzero"); + resultExpr->functionExpr = dzeroMember; + auto synReturn = m_astBuilder->create(); + synReturn->expression = resultExpr; + seqStmt->stmts.add(synReturn); + } + else + { + // The general case. + // Create a variable for return value. + synth.pushVarScope(); + auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); + auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); - // Construct reference exprs to the member's corresponding fields in each parameter. - List paramFields; - int paramIndex = 0; - for (auto arg : synArgs) - { - auto memberExpr = m_astBuilder->create(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; - } - - // Invoke the method for the field. - auto callee = m_astBuilder->create(); - auto baseSharedType = m_astBuilder->create(); - auto baseSharedTypeType = m_astBuilder->create(); - baseSharedTypeType->type = memberType; - baseSharedType->type = baseSharedTypeType; - baseSharedType->base.type = memberType; - callee->baseExpression = baseSharedType; - callee->name = requirementDeclRef.getName(); - callee->loc = synFunc->loc; - auto invokeExpr = m_astBuilder->create(); - invokeExpr->functionExpr = callee; - invokeExpr->arguments = _Move(paramFields); - - // Assign the value to resultVar. - auto leftVal = m_astBuilder->create(); - leftVal->baseExpression = resultVarExpr; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - leftVal->name = varMember->getName(); - - auto assignExpr = m_astBuilder->create(); - assignExpr->left = leftVal; - assignExpr->right = invokeExpr; - auto assignStmt = m_astBuilder->create(); - assignStmt->expression = assignExpr; - seqStmt->stmts.add(assignStmt); - } - - // TODO: synthesize assignments for inherited members here. - - auto synReturn = m_astBuilder->create(); - synReturn->expression = resultVarExpr; - seqStmt->stmts.add(synReturn); + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier(); + if (!derivativeAttr) + continue; + auto varMember = as(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; - synFunc->parentDecl = context->parentDecl; + // Construct reference exprs to the member's corresponding fields in each parameter. + List paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field and assign the value to resultVar. + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + return false; + } + + // TODO: synthesize assignments for inherited members here. + + auto synReturn = m_astBuilder->create(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); + } + context->parentDecl->members.add(synFunc); context->parentDecl->invalidateMemberDictionary(); addModifier(synFunc, m_astBuilder->create()); - witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc))); + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized method here in order to fill into the witness table. + // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the + // generic substitution for outer generic parameters, and apply it here. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + + witnessTable->add(requirementDeclRef, RequirementWitness(DeclRef(synFunc, substSet))); return true; } @@ -3801,7 +3942,10 @@ namespace Slang // be required to implement all interface requirements, // just with `abstract` methods that replicate things? // (That's what C# does). - for (auto inheritanceDecl : decl->getMembersOfType()) + + // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members. + auto inheritanceDecls = decl->getMembersOfType().toList(); + for (auto inheritanceDecl : inheritanceDecls) { checkConformance(type, inheritanceDecl, decl); } @@ -5230,7 +5374,7 @@ namespace Slang void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { - if (decl->findModifier()) + if (decl->findModifier()) { this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); } @@ -6274,6 +6418,10 @@ namespace Slang SemanticsDeclConformancesVisitor(shared).dispatch(decl); break; + case DeclCheckState::TypesFullyResolved: + SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + break; + case DeclCheckState::Checked: SemanticsDeclBodyVisitor(shared).dispatch(decl); break; @@ -6325,4 +6473,40 @@ namespace Slang return result; } + Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val) + { + if (auto declRefType = as(val)) + { + if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef)) + return as(concreteType); + for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer) + { + if (auto genericSubst = as(subst)) + { + ShortList newArgs; + for (auto& arg : genericSubst->getArgs()) + { + arg = resolveVal(arg); + SLANG_RELEASE_ASSERT(arg); + } + } + } + } + else if (auto subtypeWitness = as(val)) + { + auto sub = as(resolveVal(subtypeWitness->sub)); + auto sup = as(resolveVal(subtypeWitness->sup)); + if (sub && sup) + { + if (sub != subtypeWitness->sub || sup != subtypeWitness->sup) + { + auto newVal = tryGetSubtypeWitness(as(sub), as(sup)); + if (newVal) + val = newVal; + } + } + } + return val; + } + } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index fe37f5099..251849ede 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1410,6 +1410,19 @@ namespace Slang return m_astBuilder->getOrCreate(m_astBuilder->getBoolType(), value); } + if (auto arrayLengthExpr = expr.as()) + { + if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type) + { + auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts()); + if (auto arrayType = as(type)) + { + if (auto val = as(arrayType->arrayLength)) + return val; + } + } + } + // it is possible that we are referring to a generic value param if (auto declRefExpr = expr.as()) { @@ -1871,14 +1884,42 @@ namespace Slang arg = CheckTerm(arg); } - return CheckInvokeExprWithCheckedOperands(expr); + // If we are in a differentiable function, register differential witness tables involved in + // this call. + if (m_parentFunc && m_parentFunc->hasModifier()) + { + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + + auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr); + + if (m_parentFunc && m_parentFunc->hasModifier()) + { + if (auto checkedInvokeExpr = as(checkedExpr)) + { + // Register types for final resolved invoke arguments again. + for (auto& arg : expr->arguments) + { + maybeRegisterDifferentiableType(m_astBuilder, arg->type.type); + } + } + maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type); + } + return checkedExpr; } Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr) { // If we've already resolved this expression, don't try again. if (expr->declRef) + { + if (!expr->type) + expr->type = GetTypeForDeclRef(expr->declRef, expr->loc); return expr; + } expr->type = QualType(m_astBuilder->getErrorType()); auto lookupResult = lookUp( @@ -1908,63 +1949,56 @@ namespace Slang return expr; } - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) + Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the // nested type. // if (auto primalOutType = as(primalType)) { - return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType())); + return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType())); } else if (auto primalInOutType = as(primalType)) { - return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType())); + return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType())); } + return getDifferentialPairType(primalType); + } + Type* SemanticsVisitor::getDifferentialPairType(Type* primalType) + { // Get a reference to the builtin 'IDifferentiable' interface - auto differentiableInterface = builder->getDifferentiableInterface(); + auto differentiableInterface = m_astBuilder->getDifferentiableInterface(); + auto conformanceWitness = as(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (auto conformanceWitness = as(tryGetInterfaceConformanceWitness(primalType, differentiableInterface))) - return builder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness) + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); else return primalType; - } - Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType) - { - if (auto conformanceWitness = - as(tryGetInterfaceConformanceWitness( - primalType, - builder->getDifferentiableInterface()))) - return builder->getDifferentialPairType(primalType, conformanceWitness); - else - return primalType; - } - - Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) + Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType) { // Resolve JVP type here. // Note that this type checking needs to be in sync with // the auto-generation logic in slang-ir-jvp-diff.cpp - FuncType* jvpType = builder->create(); + FuncType* jvpType = m_astBuilder->create(); // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + jvpType->resultType = getDifferentialPairType(originalType->getResultType()); // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); jvpType->errorType = originalType->errorType; for (UInt i = 0; i < originalType->getParamCount(); i++) { - if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i))) jvpType->paramTypes.add(jvpParamType); } @@ -1978,6 +2012,15 @@ namespace Slang // Check/Resolve inner function declaration. expr->baseFunction = CheckTerm(expr->baseFunction); + // Register parameter types. + if (auto funcType = as(expr->baseFunction->type.type)) + { + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i)); + } + } + // For now we only support using higher order expr as callee in an invoke expr. // The actual type of the higher order function will be derived during resolve invoke. expr->type = m_astBuilder->getBottomType(); @@ -1985,6 +2028,29 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) + { + expr->arrayExpr = CheckTerm(expr->arrayExpr); + if (auto arrType = as(expr->arrayExpr->type)) + { + expr->type = m_astBuilder->getIntType(); + if (!arrType->arrayLength) + { + getSink()->diagnose(expr, Diagnostics::invalidArraySize); + } + } + else + { + if (!as(expr->arrayExpr->type)) + { + getSink()->diagnose( + expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type); + } + expr->type = m_astBuilder->getErrorType(); + } + return expr; + } + Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr) { // Check the term we are applying first diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 821f785f5..33455e42d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -794,15 +794,12 @@ namespace Slang bool shouldSkipChecking(Decl* decl, DeclCheckState state); // Auto-diff convenience functions for translating primal types to differential types. - Type* _toDifferentialParamType(ASTBuilder* builder, Type* primalType); + Type* _toDifferentialParamType(Type* primalType); + + Type* getDifferentialPairType(Type* primalType); - // Translate a return type to the return type of a forward-mode differentiated - // function. - // - Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType); - // Convert a function's original type to it's JVP type. - Type* processJVPFuncType(ASTBuilder* builder, FuncType* originalType); + Type* processJVPFuncType(FuncType* originalType); // Check and register a type if it is differentiable. void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); @@ -1038,6 +1035,11 @@ namespace Slang DeclRef requirementGenDecl, RefPtr witnessTable); + bool doesTypeSatisfyAssociatedTypeConstraintRequirement( + Type* satisfyingType, + DeclRef requiredAssociatedTypeDeclRef, + RefPtr witnessTable); + bool doesTypeSatisfyAssociatedTypeRequirement( Type* satisfyingType, DeclRef requiredAssociatedTypeDeclRef, @@ -1124,7 +1126,7 @@ namespace Slang /// Otherwise, returns `false`. bool trySynthesizeDifferentialAssociatedTypeRequirementWitness( ConformanceCheckingContext* context, - DeclRef requirementDeclRef, + DeclRef requirementDeclRef, RefPtr witnessTable); /// Registers a type as differentiable in the currrent semantic context, if the declaration represents @@ -1989,6 +1991,8 @@ namespace Slang Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); + Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); + /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index ef067c06c..42fab94a6 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1559,7 +1559,7 @@ namespace Slang OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as(processJVPFuncType(this->getASTBuilder(), origFuncType)); + candidate.funcType = as(processJVPFuncType(origFuncType)); candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); @@ -1576,7 +1576,6 @@ namespace Slang OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; candidate.funcType = as(processJVPFuncType( - this->getASTBuilder(), as(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)))); candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(item.declRef); @@ -1606,7 +1605,7 @@ namespace Slang auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as()); // Process func type to generate JVP func type. - auto jvpFuncType = as(processJVPFuncType(this->getASTBuilder(), funcType)); + auto jvpFuncType = as(processJVPFuncType(funcType)); // Extract parameter list from processed type. List paramTypes; @@ -1631,7 +1630,6 @@ namespace Slang // This could potentially be a declRef.substitute(jvpFuncType) // candidate.funcType = as(processJVPFuncType( - this->getASTBuilder(), getFuncType(this->getASTBuilder(), innerRef.as()))); candidate.resultType = candidate.funcType->getResultType(); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 73818dbb1..3d02d4fc0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -205,7 +205,7 @@ struct DifferentiableTypeConformanceContext { if (as(inst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly three fields. + // Assume for now that IDifferentiable has exactly four fields. SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); if (auto entry = as(differentiableInterfaceType->getOperand(index))) return as(entry->getRequirementKey()); @@ -462,45 +462,6 @@ struct DifferentialPairTypeBuilder } } - void _createGenericDiffPairType(IRBuilder* builder) - { - // Insert directly at top level (skip any generic scopes etc.) - auto insertLoc = builder->getInsertLoc(); - builder->setInsertInto(builder->getModule()->getModuleInst()); - - // Make a generic version of the pair struct. - auto irGeneric = builder->emitGeneric(); - irGeneric->setFullType(builder->getTypeKind()); - builder->setInsertInto(irGeneric); - - generatedTypeList.add(irGeneric); - - auto irBlock = builder->emitBlock(); - builder->setInsertInto(irBlock); - - auto pTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT")); - - auto dTypeParam = builder->emitParam(builder->getTypeType()); - builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT")); - - auto irStructType = builder->createStructType(); - builder->emitReturn(irStructType); - - auto primalKey = _getOrCreatePrimalStructKey(builder); - builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal")); - builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam); - - auto diffKey = _getOrCreateDiffStructKey(builder); - builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); - builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam); - - // Reset cursor when done. - builder->setInsertLoc(insertLoc); - - this->genericDiffPairType = irGeneric; - } - IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder) { if (!this->globalDiffKey) @@ -535,17 +496,6 @@ struct DifferentialPairTypeBuilder return this->globalPrimalKey; } - IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder) - { - if (!this->genericDiffPairType) - { - _createGenericDiffPairType(builder); - } - - SLANG_ASSERT(this->genericDiffPairType); - return this->genericDiffPairType; - } - IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) { if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) @@ -1383,22 +1333,10 @@ struct JVPTranscriber } else { - // We special case a few non-differentiable types that sometimes appear in places - // where we're forced to provide a differential zero value. For instance, - // float3(float, float, int) is accepted by the compiler, but is tricky in the context - // of differentiation since int is non-differentiable, and should be cast to float first. - // In the absence of such casts, this piece of code generates appropriate zero values. - // - switch (primalType->getOp()) - { - case kIROp_IntType: - return builder->getIntValue(primalType, 0); - default: - getSink()->diagnose(primalType->sourceLoc, - Diagnostics::internalCompilerError, - "could not generate zero value for given type"); - return nullptr; - } + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; } } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 1d1db14f9..61aa28bbe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0) INST(OptionalType, Optional, 1, 0) INST(DifferentialPairType, DiffPair, 1, 0) + INST(DifferentialBottomType, DiffBottomType, 0, 0) /* BindExistentialsTypeBase */ @@ -277,7 +278,6 @@ INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(GetSequentialID, GetSequentialID, 1, 0) INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) - INST(Construct, construct, 0, 0) INST(AllocObj, allocObj, 0, 0) @@ -297,6 +297,7 @@ INST(GetOptionalValue, getOptionalValue, 1, 0) INST(OptionalHasValue, optionalHasValue, 1, 0) INST(MakeOptionalValue, makeOptionalValue, 1, 0) INST(MakeOptionalNone, makeOptionalNone, 1, 0) +INST(DifferentialBottomValue, differentialBottomVal, 0, 0) INST(Call, call, 1, 0) INST(RTTIObject, rtti_object, 0, 0) @@ -759,6 +760,7 @@ INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(IsType, IsType, 3, 0) INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) +INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3a59eb6c9..382f7be5e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3629,6 +3629,7 @@ namespace Slang IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope) { + IRInst* foundResult = nullptr; for (auto child = scope->getFirstChild(); child; child = child->getNextInst()) { if (child->getOp() == kIROp_DifferentiableTypeDictionary) @@ -3640,13 +3641,20 @@ namespace Slang if (irType == entryType) { - return entryConformanceWitness; + foundResult = entryConformanceWitness; + // If the found witness table is not a trivial one (i.e. DifferentialBottom:IDifferential), + // return immediately. Otherwise, continue the search to see if we can find a better one. + if (auto witness = as(foundResult)) + { + if (witness->getConcreteType()->getOp() != kIROp_DifferentialBottomType) + return foundResult; + } } } } } - return nullptr; + return foundResult; } IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType) @@ -6282,6 +6290,8 @@ namespace Slang case kIROp_MakeOptionalNone: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: + case kIROp_MakeTuple: + case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_ImageSubscript: case kIROp_FieldExtract: diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ae0590105..8f00253f5 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1366,12 +1366,29 @@ struct ValLoweringVisitor : ValVisitormidToSup); + IRInst* midToSup = lowerSimpleVal(context, val->midToSup); + + if (!baseWitnessTable) + { + // If we don't have a valid baseWitnessTable, + // we are in a situation that `subToMid` is a `DifferentialBottomSubtypeWitness` + // that applies for all non-differentiable types. + // In this case `midToSup` will give us the `DifferentialBottom:IDifferentiable` + // witness table and we can just use that as the final result of + // this transitive witness. + SLANG_RELEASE_ASSERT(midToSup && as(midToSup->getDataType())); + return LoweredValInfo::simple(midToSup); + } return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst( getBuilder()->getWitnessTableType(lowerType(context, val->sup)), baseWitnessTable, - requirementKey)); + midToSup)); + } + + LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) + { + return LoweredValInfo(); } LoweredValInfo visitTaggedUnionSubtypeWitness( @@ -3053,6 +3070,15 @@ struct ExprLoweringVisitorBase : ExprVisitor baseVal.val)); } + LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) + { + auto baseVal = lowerSubExpr(expr->arrayExpr); + auto type = lowerType(context, expr->arrayExpr->type); + auto arrayType = as(type); + SLANG_ASSERT(arrayType); + return LoweredValInfo::simple(arrayType->getElementCount()); + } + LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/) { SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); @@ -5857,7 +5883,9 @@ struct DeclLoweringVisitor : DeclVisitor // add an entry to the context. // if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType)) + { getBuilder()->addDifferentiableTypeEntry(irType, irWitness); + } } else if (auto importEntry = as(member)) { @@ -6777,7 +6805,7 @@ struct DeclLoweringVisitor : DeclVisitor IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - setValue(subContext, decl, LoweredValInfo::simple(irInterface)); + setValue(context, decl, LoweredValInfo::simple(irInterface)); // Setup subContext for proper lowering `ThisType`, associated types and // the interface decl's self reference. @@ -7084,6 +7112,7 @@ struct DeclLoweringVisitor : DeclVisitor void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) { + ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl); auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; SLANG_RELEASE_ASSERT(as(key)); auto builder = getBuilder(); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 93f2fcdcb..980a1d0bc 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -132,8 +132,8 @@ namespace Slang Scope* newScope = astBuilder->create(); newScope->containerDecl = containerDecl; newScope->parent = currentScope; - currentScope = newScope; + containerDecl->ownedScope = newScope; } void pushScopeAndSetParent(ContainerDecl* containerDecl) diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index c779b4510..8cd443438 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -319,7 +319,31 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } } } - + else if (auto transitiveTypeWitness = as(subtypeWitness)) + { + // Hard code witness entry that `T.Differential = DifferentialBottom` for `T` that + // coerce to `DifferentialBottom`. + if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup)) + { + if (auto builtinAttr = requirementKey->findModifier()) + { + if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType) + { + return RequirementWitness(astBuilder->getDifferentialBottomType()); + } + } + } + } + else if (auto extractFromConjunctionTypeWitness = as(subtypeWitness)) + { + if (auto conjunctionTypeWitness = as(extractFromConjunctionTypeWitness->conjunctionWitness)) + { + if (extractFromConjunctionTypeWitness->indexInConjunction == 0) + return tryLookUpRequirementWitness(astBuilder, as(conjunctionTypeWitness->leftWitness), requirementKey); + else + return tryLookUpRequirementWitness(astBuilder, as(conjunctionTypeWitness->rightWitness), requirementKey); + } + } // TODO: should handle the transitive case here too return RequirementWitness(); @@ -1140,7 +1164,46 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt return nullptr; } - // + Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef declRef) + { + auto substDeclRef = declRef.as(); + if (!substDeclRef) + return nullptr; + + auto substAssocTypeDecl = substDeclRef.getDecl(); + + for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer) + { + auto thisSubst = as(s); + if (!thisSubst) + continue; + + if (auto interfaceDecl = as(substAssocTypeDecl->parentDecl)) + { + if (thisSubst->interfaceDecl == interfaceDecl) + { + // We need to look up the declaration that satisfies + // the requirement named by the associated type. + Decl* requirementKey = substAssocTypeDecl; + RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey); + switch (requirementWitness.getFlavor()) + { + default: + // No usable value was found, so there is nothing we can do. + break; + + case RequirementWitness::Flavor::val: + { + auto satisfyingVal = requirementWitness.getVal(); + return satisfyingVal; + } + break; + } + } + } + } + return nullptr; + } String DeclRefBase::toString() const { diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index ee868be62..b38bd358f 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -233,7 +233,7 @@ (Slang::LetExpr*)&astNodeType (Slang::ExtractExistentialValueExpr*)&astNodeType (Slang::OpenRefExpr*)&astNodeType - (Slang::JVPDifferentiateExpr*)&astNodeType + (Slang::ForwardDifferentiateExpr*)&astNodeType (Slang::TaggedUnionTypeExpr*)&astNodeType (Slang::ThisTypeExpr*)&astNodeType (Slang::AndTypeExpr*)&astNodeType @@ -384,6 +384,7 @@ (Slang::ErrorType*)&astNodeType (Slang::BottomType*)&astNodeType (Slang::DeclRefType*)&astNodeType + (Slang::DeclRefType*)&astNodeType (Slang::ArithmeticExpressionType*)&astNodeType (Slang::BasicExpressionType*)&astNodeType (Slang::VectorExpressionType*)&astNodeType @@ -456,7 +457,7 @@ - + {astNodeType} (Slang::GenericSubstitution*)&astNodeType @@ -469,7 +470,7 @@ substitutions outer - this + (Slang::Substitutions*)this @@ -487,6 +488,79 @@ (Slang::GenericParamIntVal*)&astNodeType (Slang::DeclaredSubtypeWitness*)&astNodeType (Slang::TransitiveSubtypeWitness*)&astNodeType + (Slang::OverloadGroupType*)&astNodeType + (Slang::InitializerListType*)&astNodeType + (Slang::ErrorType*)&astNodeType + (Slang::BottomType*)&astNodeType + (Slang::DeclRefType*)&astNodeType + (Slang::DeclRefType*)&astNodeType + (Slang::ArithmeticExpressionType*)&astNodeType + (Slang::BasicExpressionType*)&astNodeType + (Slang::VectorExpressionType*)&astNodeType + (Slang::MatrixExpressionType*)&astNodeType + (Slang::BuiltinType*)&astNodeType + (Slang::FeedbackType*)&astNodeType + (Slang::ResourceType*)&astNodeType + (Slang::TextureTypeBase*)&astNodeType + (Slang::TextureType*)&astNodeType + (Slang::TextureSamplerType*)&astNodeType + (Slang::GLSLImageType*)&astNodeType + (Slang::SamplerStateType*)&astNodeType + (Slang::BuiltinGenericType*)&astNodeType + (Slang::PointerLikeType*)&astNodeType + (Slang::ParameterGroupType*)&astNodeType + (Slang::UniformParameterGroupType*)&astNodeType + (Slang::ConstantBufferType*)&astNodeType + (Slang::TextureBufferType*)&astNodeType + (Slang::GLSLShaderStorageBufferType*)&astNodeType + (Slang::ParameterBlockType*)&astNodeType + (Slang::VaryingParameterGroupType*)&astNodeType + (Slang::GLSLInputParameterGroupType*)&astNodeType + (Slang::GLSLOutputParameterGroupType*)&astNodeType + (Slang::HLSLStructuredBufferTypeBase*)&astNodeType + (Slang::HLSLStructuredBufferType*)&astNodeType + (Slang::HLSLRWStructuredBufferType*)&astNodeType + (Slang::HLSLRasterizerOrderedStructuredBufferType*)&astNodeType + (Slang::HLSLAppendStructuredBufferType*)&astNodeType + (Slang::HLSLConsumeStructuredBufferType*)&astNodeType + (Slang::HLSLStreamOutputType*)&astNodeType + (Slang::HLSLPointStreamType*)&astNodeType + (Slang::HLSLLineStreamType*)&astNodeType + (Slang::HLSLTriangleStreamType*)&astNodeType + (Slang::UntypedBufferResourceType*)&astNodeType + (Slang::HLSLByteAddressBufferType*)&astNodeType + (Slang::HLSLRWByteAddressBufferType*)&astNodeType + (Slang::HLSLRasterizerOrderedByteAddressBufferType*)&astNodeType + (Slang::RaytracingAccelerationStructureType*)&astNodeType + (Slang::HLSLPatchType*)&astNodeType + (Slang::HLSLInputPatchType*)&astNodeType + (Slang::HLSLOutputPatchType*)&astNodeType + (Slang::GLSLInputAttachmentType*)&astNodeType + (Slang::StringTypeBase*)&astNodeType + (Slang::StringType*)&astNodeType + (Slang::NativeStringType*)&astNodeType + (Slang::DynamicType*)&astNodeType + (Slang::EnumTypeType*)&astNodeType + (Slang::PtrTypeBase*)&astNodeType + (Slang::PtrType*)&astNodeType + (Slang::ParamDirectionType*)&astNodeType + (Slang::OutTypeBase*)&astNodeType + (Slang::OutType*)&astNodeType + (Slang::InOutType*)&astNodeType + (Slang::RefType*)&astNodeType + (Slang::NullPtrType*)&astNodeType + (Slang::ArrayExpressionType*)&astNodeType + (Slang::TypeType*)&astNodeType + (Slang::NamedExpressionType*)&astNodeType + (Slang::FuncType*)&astNodeType + (Slang::GenericDeclRefType*)&astNodeType + (Slang::NamespaceType*)&astNodeType + (Slang::ExtractExistentialType*)&astNodeType + (Slang::TaggedUnionType*)&astNodeType + (Slang::ExistentialSpecializedType*)&astNodeType + (Slang::ThisType*)&astNodeType + (Slang::AndType*)&astNodeType + (Slang::ModifiedType*)&astNodeType \ No newline at end of file -- cgit v1.2.3