diff options
| author | Theresa Foley <10618364+tangent-vector@users.noreply.github.com> | 2022-05-25 08:14:28 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-05-25 08:14:28 -0700 |
| commit | 24e60fd14bd957a69fb054d20e43e2c0580d57f2 (patch) | |
| tree | f58200bb82d5fc5fa45f92652ab264ce82b856c4 | |
| parent | 5c2e3a841fe7fc98cfa7c135596b4eef278f3a56 (diff) | |
Allow [mutating] methods on existential values (#2245)
The problematic case is when an `interface` has a `[mutating]` method:
interface ICounter
{
[mutating] void increment();
}
and code tries to invoke that method on a value of existential type:
ICounter c = ...;
c.increment();
We know that the existential value `c` is conceptually a tuple of:
* A concrete type `X`
* A witness that `X : ICounter`
* A value `v` of type `X`
We simply want to invoke `increment()` on the `v` part, using the `X : ICounter` witness table.
The catch that the compiler faces is that the variable `c` is mutable, so we need to be careful that we "snapshot" its value (the tuple `X, X:ICounter, v`) at a single point.
The snapshotting behavior is important when invoking a method that involves `This` or associated types in its signature, so we cannot get rid of it.
The snapshotting we do relies on the idea of a `LetExpr` AST node, which cannot be written in the input syntax.
A `LetExpr` introduces a variable binding (with an initial-value expression) and then evaluates a body expression in the context of that binding.
For a call site like `c.increment()` the front-end makes an intermediate copy of `c` and then "opens" that immutable value to get at the elements of the tuple `X`, `X : ICounter`, `v`.
The resulting AST after checking looks something like:
ICounter c = ...;
(let tmp = c in extractExistentialValue(tmp)).increment();
In that form it is more clear why the attempt to call `increment()` fails:
1. The binding `tmp` sure looks immutable
2. There is no logic in the compiler to make `extractExistentialValue(x)` be an l-value if `x` is
3. There is seemingly no logic to write back from `tmp` to `c` when the operation completes
Let us walk through those problems in order.
Item (1) turns out to be a bit of a non-issue.
Despite the way that I've written out `let` expressions above, the logic in `moveTemp()` in the compiler actually introduces a *mutable* binding.
Item (2) can be fixed for the purposes of semantic checking by modifying `openExistential()`.
Simplistically, we make the overall expression be an l-value if the operand is.
Item (3) is handled at the level of AST->IR lowering. Each kind of expression that can form an l-value needs to have a way to represent the "location" of that l-value in the `LoweredValInfo` type.
This change adds a case to handle the `extractExistentialVal` operation, by tracking both the extract value (of concrete type) and the underlying l-value (of existential type).
Where all of this comes crashing against reality a bit is that the scoping I've drawn for the `let` expressions above kind of doesn't work once we look at types.
The basic problem is that the *type* of the `(let tmp = c in ...)` expression is the concrete type `X` that was extracted from the existential.
That type can conceptually be written as `ExtractExistentialType(tmp)` which, notably, references `tmp`.
That means that we end up with AST expression nodes that reference the variable `tmp` *outside* of its scope.
Furthermore, those references to `tmp` can end up being lowered to IR *before* we have lowered the `let ...` expression itself.
Fixing the scoping issue turns out to be a major undertaking.
The first (and more obvious) issue is needing to address the scoping problem.
The solution I implemented includes a bit of refactoring to make all the `SemanticsVisitor` types better able to pass around the contextual scope-dependent state that might be needed during semantic checking, but really only adds a single piece of state.
The semantic-checking state used for checking expressions is bottlenecked so that there will (or at least *should*) always be an explicit representation of a "scope" that surrounds a complete expression (as opposed to a sub-expression).
When a `LetExpr` needs to be introduced, it is added to a pending list on the active scope, rather than being added locally.
Once the complete expression is checked, the resulting expression is wrapped up in the pending `LetExpr`s so that their scope is as broad as possible.
Technically this solution doesn't cover all cases. For example:
interface ICell { associatedtype Content; Content getContent(); }
...
ICell cell = ...;
let content = cell.getContent();
In this case the type of `content` refers to the binding introduced by a `LetExpr` in the initial-value expression.
I am leaving such issues as a piece of future work, in the hopes that we can get at least a partial fix for the problem in place.
A future fix probably nees to extend the scoping even wider (e.g., by unwrapping the `LetExpr`s from the initial-value expression and turning them into distinct temporaries).
The second piece of the fix is that we need a way for the modified value of the extracted existential to be "written back" to the original location.
Well...
We are actually being a little slippery here, based on some logic in the compiler codebase that I guess Just Works.
When AST->IR lowering encounters a `LetExpr` that binds an l-value to a name, it actually ends up binding that name more or less as a *reference* to that l-value.
At this point the `let`-ness of `LetExpr` is very much in doubt: the binding can be mutable, and it can even be an *alias* of some location?!?
In any case, the result is that the AST->IR codegen logic implicitly handles the "write-back" because the `let`-bound temporary is actually an alias for the original location.
A more complete future fix might need to introduce a distinct case in `LoweredValInfo` to handle the case of copy of a mutable temporary.
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 74 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 118 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 146 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-check.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 160 |
6 files changed, 430 insertions, 93 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6c414b292..6d579e1fb 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -21,8 +21,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclModifiersVisitor> { - SemanticsDeclModifiersVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclModifiersVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDeclGroup(DeclGroup*) {} @@ -37,8 +37,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclHeaderVisitor> { - SemanticsDeclHeaderVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclHeaderVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDecl(Decl*) {} @@ -106,8 +106,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclRedeclarationVisitor> { - SemanticsDeclRedeclarationVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclRedeclarationVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDecl(Decl*) {} @@ -127,8 +127,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclBasesVisitor> { - SemanticsDeclBasesVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclBasesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDecl(Decl*) {} @@ -161,8 +161,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclBodyVisitor> { - SemanticsDeclBodyVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclBodyVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDecl(Decl*) {} @@ -668,7 +668,7 @@ namespace Slang /// This call does *not* handle updating the state of `decl`; the /// caller takes responsibility for doing so. /// - static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SharedSemanticsContext* shared); + static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SemanticsContext const& shared); // Make sure a declaration has been checked, so we can refer to it. // Note that this may lead to us recursively invoking checking, @@ -729,7 +729,12 @@ namespace Slang // We now dispatch an appropriate visitor based on `nextState`. // - _dispatchDeclCheckingVisitor(decl, nextState, getShared()); + // Note that we always dispatch the visitor in a "fresh" semantic-checking + // context, so that the state at the point where a declaration is *referenced* + // cannot affect the state in which the declaration is *checked*. + // + SemanticsContext subContext(getShared()); + _dispatchDeclCheckingVisitor(decl, nextState, subContext); // In the common case, the visitor will have done the necessary // checking, but will *not* have updated the `checkState` on @@ -1177,8 +1182,8 @@ namespace Slang : public SemanticsDeclVisitorBase , public DeclVisitor<SemanticsDeclConformancesVisitor> { - SemanticsDeclConformancesVisitor(SharedSemanticsContext* shared) - : SemanticsDeclVisitorBase(shared) + SemanticsDeclConformancesVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) {} void visitDecl(Decl*) {} @@ -2088,31 +2093,24 @@ namespace Slang // to the user as some kind of overload-resolution failure. // // In order to protect the user from whatever errors might - // occur, we will swap out the current diagnostic sink for - // a temporary one. + // occur, we will perform the checking in the context of + // a temporary diagnostic sink. // - DiagnosticSink* savedSink = m_shared->m_sink; - DiagnosticSink tempSink(savedSink->getSourceManager(), nullptr); - m_shared->m_sink = &tempSink; + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); // With our temporary diagnostic sink soaking up any messages // from overload resolution, we can now try to resolve // the call to see what happens. // - auto checkedCall = ResolveInvoke(synCall); + auto checkedCall = subVisitor.ResolveInvoke(synCall); // Of course, it is possible that the call went through fine, // but the result isn't of the type we expect/require, // so we also need to coerce the result of the call to // the expected type. // - auto coercedCall = coerce(resultType, checkedCall); - - // Once we are done making our semantic checks, we can - // restore the original sink, so that subsequent operations - // report diagnostics as usual. - // - m_shared->m_sink = savedSink; + auto coercedCall = subVisitor.coerce(resultType, checkedCall); // If our overload resolution or type coercion failed, // then we have not been able to synthesize a witness @@ -2380,9 +2378,8 @@ namespace Slang // `SemanticsVisitor` so that code can push/pop the emission // of diagnostics more easily. // - DiagnosticSink* savedSink = m_shared->m_sink; - DiagnosticSink tempSink(savedSink->getSourceManager(), nullptr); - m_shared->m_sink = &tempSink; + DiagnosticSink tempSink(getSourceManager(), nullptr); + SemanticsVisitor subVisitor(withSink(&tempSink)); // We start by constructing an expression that represents // `this.name` where `name` is the name of the required @@ -2415,7 +2412,7 @@ namespace Slang // general-purpose language features is unlikely to be as efficient // as special-case logic. // - auto synMemberRef = createLookupResultExpr( + auto synMemberRef = subVisitor.createLookupResultExpr( requiredMemberDeclRef.getName(), lookupResult, synThis, @@ -2434,7 +2431,7 @@ namespace Slang // which involves coercing the member access `this.name` to // the expected type of the property. // - auto coercedMemberRef = coerce(propertyType, synMemberRef); + auto coercedMemberRef = subVisitor.coerce(propertyType, synMemberRef); auto synReturn = m_astBuilder->create<ReturnStmt>(); synReturn->expression = coercedMemberRef; @@ -2461,7 +2458,7 @@ namespace Slang synAssign->left = synMemberRef; synAssign->right = synArgs[0]; - auto synCheckedAssign = checkAssignWithCheckedOperands(synAssign); + auto synCheckedAssign = subVisitor.checkAssignWithCheckedOperands(synAssign); auto synExprStmt = m_astBuilder->create<ExpressionStmt>(); synExprStmt->expression = synCheckedAssign; @@ -2477,10 +2474,8 @@ namespace Slang return false; } - // We restore the semantic checking state that was in place before - // we checked the synthesized accessor body, and then bail out - // if we ran into any errors (meaning that the synthesized accessor - // is not usable). + // We bail out if we ran into any errors (meaning that the synthesized + // accessor is not usable). // // TODO: If there were *warnings* emitted to the sink, it would probably // be good to show those warnings to the user, since they might indicate @@ -2488,7 +2483,6 @@ namespace Slang // satisfying an `int` property requirement, but the user would probably // want to be warned when they do such a thing. // - m_shared->m_sink = savedSink; if(tempSink.getErrorCount() != 0) return false; @@ -5046,7 +5040,7 @@ namespace Slang name, decl->moduleNameAndLoc.loc, getSink(), - m_shared->m_environmentModules); + getShared()->m_environmentModules); // If we didn't find a matching module, then bail out if (!importedModule) @@ -5324,7 +5318,7 @@ namespace Slang } - static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SharedSemanticsContext* shared) + static void _dispatchDeclCheckingVisitor(Decl* decl, DeclCheckState state, SemanticsContext const& shared) { switch(state) { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 2290936d8..317ab6a1a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -25,6 +25,31 @@ namespace Slang return as<DeclRefType>(expr->type); } + void SemanticsContext::ExprLocalScope::addBinding(LetExpr* binding) + { + if (!m_innerMostBinding) + { + SLANG_ASSERT(!m_outerMostBinding); + + // If we haven't added any bindings, then `binding` + // becomes both the inner-most and outer most. + // + m_innerMostBinding = binding; + m_outerMostBinding = binding; + } + else + { + SLANG_ASSERT(m_outerMostBinding); + + // If we already have bindings, then `binding` + // will become the new inner-most binding. + // + m_innerMostBinding->body = binding; + m_innerMostBinding = binding; + } + } + + /// Move `expr` into a temporary variable and execute `func` on that variable. /// /// Returns an expression that wraps both the creation and initialization of @@ -46,11 +71,30 @@ namespace Slang letExpr->decl = varDecl; auto body = func(varDeclRef); + Expr* result = body; + if (auto exprLocalScope = getExprLocalScope()) + { + // We want to add the `LetExpr` to the set of such expressions + // in the local scope, so that it can be emitted properly. + // + exprLocalScope->addBinding(letExpr); + } + else + { + // If we somehow got in here and there wasn't an expression-local + // scope established yet, it almost certainly represents an error. + // + SLANG_ASSERT(exprLocalScope); - letExpr->body = body; - letExpr->type = body->type; + // As a fallback, though, we will try to wire up the `letExpr` + // to surround the body directly and return that. + // + letExpr->body = body; + letExpr->type = body->type; - return letExpr; + result = letExpr; + } + return result; } /// Execute `func` on a variable with the value of `expr`. @@ -62,6 +106,11 @@ namespace Slang template<typename F> Expr* SemanticsVisitor::maybeMoveTemp(Expr* const& expr, F const& func) { + // TODO: Eventually this operation could consider any case where the + // input `expr` names an immutable "path": one that starts at an + // immutable binding and follows a (possibly empty) chain of accesses + // to immutable members. + if(auto varExpr = as<VarExpr>(expr)) { auto declRef = varExpr->declRef; @@ -114,6 +163,26 @@ namespace Slang openedValue->declRef = varDeclRef; openedValue->type = QualType(openedType); + // The result of opening an existential is an l-value + // if the original existential is an l-value. + // + if(expr->type.isLeftValue) + { + // Marking the opened value as an l-value is the easy part. + // + openedValue->type.isLeftValue = true; + + // The more challenging bit is that in this case the `maybeMoveTemp()` + // operation will have copied the original existential value into + // a temporary. + // + // If this expression is used in an l-value context, then we need + // to be able to generate code to "write back" the modified value + // (which will be of `openedType`) to the original location named + // by `expr` (an existential for `interfaceDeclRef`). + // + } + return openedValue; }); } @@ -606,8 +675,47 @@ namespace Slang { if (!term) return nullptr; - SemanticsExprVisitor exprVisitor(getShared()); - return exprVisitor.dispatch(term); + // The process of checking a term/expression can end up introducing + // temporaries that need to be added to an outer scope. When jumping + // into expression checking, we want to check if we already have such + // a scope in place. If we do, we will re-use it for any sub-expressions. + // If not, we need to create one. + // + if(getExprLocalScope()) + { + return dispatchExpr(term, *this); + } + + ExprLocalScope exprLocalScope; + + Expr* checkedTerm = dispatchExpr(term, withExprLocalScope(&exprLocalScope)); + + LetExpr* outerMostBinding = exprLocalScope.getOuterMostBinding(); + if(!outerMostBinding) + { + return checkedTerm; + } + + LetExpr* binding = outerMostBinding; + auto type = checkedTerm->type; + while (binding) + { + binding->type = type; + + if (auto body = binding->body) + { + binding = as<LetExpr>(binding->body); + SLANG_ASSERT(binding); + continue; + } + else + { + binding->body = checkedTerm; + break; + } + } + + return outerMostBinding; } Expr* SemanticsVisitor::CreateErrorExpr(Expr* expr) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index f93db894f..e6088ccca 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -272,26 +272,51 @@ namespace Slang void _addCandidateExtensionsFromModule(ModuleDecl* moduleDecl); }; - struct SemanticsVisitor + /// Local/scoped state of the semantic-checking system + /// + /// This type is kept distinct from `SharedSemanticsContext` so that we + /// can avoid unncessary mutable state being propagated through the + /// checking process. + /// + /// Semantic-checking code should make a new local `SemanticsContext` + /// in cases where it want to check a sub-entity (expression, statement, + /// declaration, etc.) in a modified or extended context. + /// + struct SemanticsContext { - SemanticsVisitor( + public: + explicit SemanticsContext( SharedSemanticsContext* shared) - : m_shared(shared), - m_astBuilder(shared->getLinkage()->getASTBuilder()) + : m_shared(shared) + , m_sink(shared->getSink()) + , m_astBuilder(shared->getLinkage()->getASTBuilder()) {} - SharedSemanticsContext* m_shared = nullptr; - ASTBuilder* m_astBuilder = nullptr; - SharedSemanticsContext* getShared() { return m_shared; } - ASTBuilder* getASTBuilder() { return m_astBuilder;} + ASTBuilder* getASTBuilder() { return m_astBuilder; } - DiagnosticSink* getSink() { return m_shared->getSink(); } + DiagnosticSink* getSink() { return m_sink; } Session* getSession() { return m_shared->getSession(); } Linkage* getLinkage() { return m_shared->m_linkage; } NamePool* getNamePool() { return getLinkage()->getNamePool(); } + SourceManager* getSourceManager() { return getLinkage()->getSourceManager(); } + + SemanticsContext withSink(DiagnosticSink* sink) + { + SemanticsContext result(*this); + result.m_sink = sink; + return result; + } + + SemanticsContext withParentFunc(FunctionDeclBase* parentFunc) + { + SemanticsContext result(*this); + result.m_parentFunc = parentFunc; + result.m_outerStmts = nullptr; + return result; + } /// Information for tracking one or more outer statements. /// @@ -307,10 +332,81 @@ namespace Slang /// struct OuterStmtInfo { - Stmt* stmt = nullptr; - OuterStmtInfo* next; + Stmt* stmt = nullptr; + OuterStmtInfo* next; + }; + + OuterStmtInfo* getOuterStmts() { return m_outerStmts; } + + SemanticsContext withOuterStmts(OuterStmtInfo* outerStmts) + { + SemanticsContext result(*this); + result.m_outerStmts = outerStmts; + return result; + } + + /// A scope that is local to a particular expression, and + /// that can be used to allocate temporary bindings that + /// might be needed by that expression or its sub-expressions. + /// + /// The scope is represented as a sequence of nested `LetExpr`s + /// that introduce the bindings needed in the scope. + /// + struct ExprLocalScope + { + public: + void addBinding(LetExpr* binding); + + LetExpr* getOuterMostBinding() const { return m_outerMostBinding; } + + private: + LetExpr* m_outerMostBinding = nullptr; + LetExpr* m_innerMostBinding = nullptr; }; + ExprLocalScope* getExprLocalScope() { return m_exprLocalScope; } + + SemanticsContext withExprLocalScope(ExprLocalScope* exprLocalScope) + { + SemanticsContext result(*this); + result.m_exprLocalScope = exprLocalScope; + return result; + } + + private: + SharedSemanticsContext* m_shared = nullptr; + + DiagnosticSink* m_sink = nullptr; + + ExprLocalScope* m_exprLocalScope = nullptr; + + protected: + // TODO: consider making more of this state `private`... + + /// The parent function (if any) that surrounds the statement being checked. + FunctionDeclBase* m_parentFunc = nullptr; + + /// The linked list of lexically surrounding statements. + OuterStmtInfo* m_outerStmts = nullptr; + + ASTBuilder* m_astBuilder = nullptr; + }; + + struct SemanticsVisitor : public SemanticsContext + { + typedef SemanticsContext Super; + + explicit SemanticsVisitor( + SharedSemanticsContext* shared) + : Super(shared) + {} + + SemanticsVisitor( + SemanticsContext const& context) + : Super(context) + {} + + public: // Translate Types @@ -449,8 +545,8 @@ namespace Slang // so that we can add some quality-of-life features for users // in cases where the compiler crashes // - void dispatchStmt(Stmt* stmt, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts); - void dispatchExpr(Expr* expr); + void dispatchStmt(Stmt* stmt, SemanticsContext const& context); + Expr* dispatchExpr(Expr* expr, SemanticsContext const& context); /// Ensure that a declaration has been checked up to some state /// (aka, a phase of semantic checking) so that we can safely @@ -899,7 +995,7 @@ namespace Slang // as the tag type for an `enum` void validateEnumTagType(Type* type, SourceLoc const& loc); - void checkStmt(Stmt* stmt, FunctionDeclBase* outerFunction, OuterStmtInfo* outerStmts); + void checkStmt(Stmt* stmt, SemanticsContext const& context); void getGenericParams( GenericDecl* decl, @@ -1536,8 +1632,8 @@ namespace Slang , ExprVisitor<SemanticsExprVisitor, Expr*> { public: - SemanticsExprVisitor(SharedSemanticsContext* shared) - : SemanticsVisitor(shared) + SemanticsExprVisitor(SemanticsContext const& outer) + : SemanticsVisitor(outer) {} Expr* visitBoolLiteralExpr(BoolLiteralExpr* expr); @@ -1610,18 +1706,10 @@ namespace Slang : public SemanticsVisitor , StmtVisitor<SemanticsStmtVisitor> { - SemanticsStmtVisitor(SharedSemanticsContext* shared, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts) - : SemanticsVisitor(shared) - , m_parentFunc(parentFunc) - , m_outerStmts(outerStmts) + SemanticsStmtVisitor(SemanticsContext const& outer) + : SemanticsVisitor(outer) {} - /// The parent function (if any) that surrounds the statement being checked. - FunctionDeclBase* m_parentFunc = nullptr; - - /// The linked list of lexically surrounding statements. - OuterStmtInfo* m_outerStmts = nullptr; - FunctionDeclBase* getParentFunc() { return m_parentFunc; } void checkStmt(Stmt* stmt); @@ -1671,13 +1759,13 @@ namespace Slang struct SemanticsDeclVisitorBase : public SemanticsVisitor { - SemanticsDeclVisitorBase(SharedSemanticsContext* shared) - : SemanticsVisitor(shared) + SemanticsDeclVisitorBase(SemanticsContext const& outer) + : SemanticsVisitor(outer) {} void checkBodyStmt(Stmt* stmt, FunctionDeclBase* parentDecl) { - checkStmt(stmt, parentDecl, nullptr); + checkStmt(stmt, withParentFunc(parentDecl)); } void checkModule(ModuleDecl* programNode); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 6f8ed9ff5..45be1d662 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -21,13 +21,10 @@ namespace Slang { public: WithOuterStmt(SemanticsStmtVisitor* visitor, Stmt* outerStmt) - : SemanticsStmtVisitor(*visitor) + : SemanticsStmtVisitor(visitor->withOuterStmts(&m_outerStmt)) { - m_parentFunc = visitor->m_parentFunc; - - m_outerStmt.next = visitor->m_outerStmts; + m_outerStmt.next = visitor->getOuterStmts(); m_outerStmt.stmt = outerStmt; - m_outerStmts = &m_outerStmt; } private: @@ -35,10 +32,10 @@ namespace Slang }; } - void SemanticsVisitor::checkStmt(Stmt* stmt, FunctionDeclBase* parentDecl, OuterStmtInfo* outerStmts) + void SemanticsVisitor::checkStmt(Stmt* stmt, SemanticsContext const& context) { if (!stmt) return; - dispatchStmt(stmt, parentDecl, outerStmts); + dispatchStmt(stmt, context); checkModifiers(stmt); } @@ -72,7 +69,7 @@ namespace Slang void SemanticsStmtVisitor::checkStmt(Stmt* stmt) { - SemanticsVisitor::checkStmt(stmt, m_parentFunc, m_outerStmts); + SemanticsVisitor::checkStmt(stmt, *this); } template<typename T> diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 56e7bd379..633dec215 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -170,7 +170,7 @@ namespace Slang translationUnit->compileRequest->getSink(), &loadedModules); - SemanticsDeclVisitorBase visitor(&sharedSemanticsContext); + SemanticsDeclVisitorBase visitor( (SemanticsContext(&sharedSemanticsContext)) ); // Apply the visitor to do the main semantic // checking that is required on all declarations @@ -181,9 +181,9 @@ namespace Slang translationUnit->getModule()->_collectShaderParams(); } - void SemanticsVisitor::dispatchStmt(Stmt* stmt, FunctionDeclBase* parentFunc, OuterStmtInfo* outerStmts) + void SemanticsVisitor::dispatchStmt(Stmt* stmt, SemanticsContext const& context) { - SemanticsStmtVisitor visitor(getShared(), parentFunc, outerStmts); + SemanticsStmtVisitor visitor(context); try { visitor.dispatch(stmt); @@ -196,12 +196,12 @@ namespace Slang } } - void SemanticsVisitor::dispatchExpr(Expr* expr) + Expr* SemanticsVisitor::dispatchExpr(Expr* expr, SemanticsContext const& context) { - SemanticsExprVisitor visitor(getShared()); + SemanticsExprVisitor visitor(context); try { - visitor.dispatch(expr); + return visitor.dispatch(expr); } catch(const AbortCompilationException&) { throw; } catch(...) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 4ff43b697..5cfb07c1c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -80,7 +80,8 @@ struct SubscriptInfo : ExtendedValueInfo struct BoundStorageInfo; struct BoundMemberInfo; struct SwizzledLValueInfo; - +struct CopiedValInfo; +struct ExtractedExistentialValInfo; // This type is our core representation of lowered values. // In the simple case, it just wraps an `IRInst*`. @@ -113,6 +114,9 @@ struct LoweredValInfo // The result of applying swizzling to an l-value SwizzledLValue, + + // The value extracted from an opened existential + ExtractedExistential, }; union @@ -185,6 +189,15 @@ struct LoweredValInfo SLANG_ASSERT(flavor == Flavor::SwizzledLValue); return (SwizzledLValueInfo*)ext; } + + static LoweredValInfo extractedExistential( + ExtractedExistentialValInfo* extInfo); + + ExtractedExistentialValInfo* getExtractedExistentialValInfo() + { + SLANG_ASSERT(flavor == Flavor::ExtractedExistential); + return (ExtractedExistentialValInfo*)ext; + } }; // This case is used to indicate a reference to an AST-level @@ -272,6 +285,27 @@ struct SwizzledLValueInfo : ExtendedValueInfo UInt elementIndices[4]; }; +// Represents the results of extractng a value of +// some (statically unknown) concrete type from +// an existential, in an l-value context. +// +struct ExtractedExistentialValInfo : ExtendedValueInfo +{ + // The extracted value + IRInst* extractedVal; + + // The original existential value + LoweredValInfo existentialVal; + + // The type of `existentialVal` + IRType* existentialType; + + // The IR witness table for the conformance of + // the type of `extractedVal` to `existentialType` + // + IRInst* witnessTable; +}; + LoweredValInfo LoweredValInfo::boundMember( BoundMemberInfo* boundMemberInfo) { @@ -308,6 +342,16 @@ LoweredValInfo LoweredValInfo::swizzledLValue( return info; } +LoweredValInfo LoweredValInfo::extractedExistential( + ExtractedExistentialValInfo* extInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::ExtractedExistential; + info.ext = extInfo; + return info; +} + + // An "environment" for mapping AST declarations to IR values. // // This is required because in some cases we might lower the @@ -935,6 +979,13 @@ top: swizzleInfo->elementIndices)); } + case LoweredValInfo::Flavor::ExtractedExistential: + { + auto info = lowered.getExtractedExistentialValInfo(); + + return LoweredValInfo::simple(info->extractedVal); + } + default: SLANG_UNEXPECTED("unhandled value flavor"); UNREACHABLE_RETURN(LoweredValInfo()); @@ -2108,6 +2159,7 @@ void addInArg( case LoweredValInfo::Flavor::SwizzledLValue: case LoweredValInfo::Flavor::BoundStorage: case LoweredValInfo::Flavor::BoundMember: + case LoweredValInfo::Flavor::ExtractedExistential: args.add(getSimpleVal(context, argVal)); break; @@ -2784,6 +2836,8 @@ static LoweredValInfo _emitCallToAccessor( template<typename Derived> struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { + static bool isLValueContext() { return Derived::_isLValueContext(); } + IRGenContext* context; IRBuilder* getBuilder() { return context->irBuilder; } @@ -2797,6 +2851,13 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return this->dispatch(expr); } + LoweredValInfo lowerSubExpr(Expr* expr, IRGenContext* subContext) + { + IRBuilderSourceLocRAII sourceLocInfo(getBuilder(), expr->loc); + Derived d; + d.context = subContext; + return d.dispatch(expr); + } LoweredValInfo visitVarExpr(VarExpr* expr) { @@ -3802,8 +3863,20 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitLetExpr(LetExpr* expr) { - // TODO: deal with the case where we might want to capture - // a reference to the bound value... + // Note: The semantics here are annoyingly subtle. + // + // If `expr->decl->initExpr` is an l-value, then we will set things up + // so that `expr->decl` is bound as an *alias* for that l-value. + // + // Otherwise, `expr->decl` will simply be bound to the r-value. + // + // The first case is necessary to make `maybeMoveTemp` operations that + // produce l-value results work correctly, but seems slippery. + // + // TODO: We should probably have two AST node types to cover the two + // different use cases of `LetExpr`: the definitely-immutable case that + // actually behaves like a `let`, and this other mutable-alias case that + // feels kind of messy and gross. auto initVal = lowerLValueExpr(context, expr->decl->initExpr); setGlobalValue(context, expr->decl, initVal); @@ -3813,17 +3886,66 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitExtractExistentialValueExpr(ExtractExistentialValueExpr* expr) { + // We are being asked to extract the value from an existential, which + // is itself a single IR op. However, we also need to handle the case + // where `expr` might be used as an l-value, in which case we need + // additional information to allow any mutations through the extracted + // value to be written back. + auto existentialType = lowerType(context, getType(getASTBuilder(), expr->declRef)); - auto existentialVal = getSimpleVal(context, emitDeclRef(context, expr->declRef, existentialType)); + auto existentialVal = emitDeclRef(context, expr->declRef, existentialType); + + // Note that we make a *copy* of the existential value that is definitely + // a simple r-value. This ensures that all the `extractExistential*()` operations + // below work on the same consistent IR value. + // + auto existentialValCopy = getSimpleVal(context, existentialVal); auto openedType = lowerType(context, expr->type); - return LoweredValInfo::simple(getBuilder()->emitExtractExistentialValue(openedType, existentialVal)); + auto extractedVal = getBuilder()->emitExtractExistentialValue( + openedType, existentialValCopy); + + if(!isLValueContext()) + { + // If we are in an r-value context, we can directly use the `extractExistentialValue` + // instruction as the result, and life is simple. + // + return LoweredValInfo::simple(extractedVal); + } + + // In an l-value context, we need to track the information necessary so that + // if a new/modified value of `openedType` was produced, we could write it + // back into the original `existentialVal`'s location. + // + // The write-back is actually pretty simple: it is just a `makeExisential` op. + // In order to be able to emit that op later, we need to track the operands + // that it would use. The first operand would be the new concrete value (which + // would implicitly encode the concrete type via its IR type) while the second + // is the witness table for the conformance to the existential. + // + // Note: We are assuming/requiring here that any value "written back" must have + // the exact same concrete type as `extractedVal`, so taht it can use the same + // IR witness table. The front-end should be enforcing that constraint, and we + // have no way to check or enforce it at this point. + + auto witnessTable = getBuilder()->emitExtractExistentialWitnessTable(existentialValCopy); + + RefPtr<ExtractedExistentialValInfo> info = new ExtractedExistentialValInfo(); + info->extractedVal = extractedVal; + info->existentialVal = existentialVal; + info->existentialType = existentialType; + info->witnessTable = witnessTable; + + context->shared->extValues.add(info); + return LoweredValInfo::extractedExistential(info); } }; struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVisitor> { + static bool _isLValueContext() { return true; } + // When visiting a swizzle expression in an l-value context, // we need to construct a "swizzled l-value." LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr*) @@ -3901,6 +4023,8 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor> { + static bool _isLValueContext() { return false; } + // A matrix swizzle in an r-value context can save time by just // emitting the matrix swizzle instructions directly. LoweredValInfo visitMatrixSwizzleExpr(MatrixSwizzleExpr* expr) @@ -5289,6 +5413,32 @@ top: } break; + case LoweredValInfo::Flavor::ExtractedExistential: + { + // The `left` value is the result of opening an existential. + // + auto leftInfo = left.getExtractedExistentialValInfo(); + auto existentialVal = leftInfo->existentialVal; + + // The actual desitnation we need to store into is the + // existential value itself. + // + left = existentialVal; + + // The `right` value must be of the same concrete type as + // the opened value, but the new destination is of the + // original existential type, so we need to wrap it up + // appropriately. + // + right = LoweredValInfo::simple(builder->emitMakeExistential( + leftInfo->existentialType, + getSimpleVal(context, right), + leftInfo->witnessTable)); + + goto top; + } + break; + default: SLANG_UNIMPLEMENTED_X("assignment"); break; |
