diff options
Diffstat (limited to 'source/slang/lower.cpp')
| -rw-r--r-- | source/slang/lower.cpp | 292 |
1 files changed, 262 insertions, 30 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index 674614cd3..f9bf97107 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -60,6 +60,7 @@ struct StructuralTransformVisitorBase RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl) { + if(!decl) return nullptr; return visitor->transformSyntaxField(decl).As<ScopeDecl>(); } @@ -80,7 +81,7 @@ struct StructuralTransformStmtVisitor : StructuralTransformVisitorBase<V> , StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>> { - void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj) + void transformFields(StatementSyntaxNode*, StatementSyntaxNode*) { } @@ -263,6 +264,7 @@ struct LoweringVisitor RefPtr<ExpressionType> lowerType( ExpressionType* type) { + if(!type) return nullptr; return TypeVisitor::dispatch(type); } @@ -456,12 +458,11 @@ struct LoweringVisitor // Statements // - StatementSyntaxNode* translateStmtRef( - StatementSyntaxNode* stmt) - { - throw "unimplemented"; - } - + // Lowering one statement to another. + // The source statement might desugar into multiple statements, + // (or event to none), and in such a case this function wraps + // the result up as a `SeqStmt` or `EmptyStmt` as appropriate. + // RefPtr<StatementSyntaxNode> lowerStmt( StatementSyntaxNode* stmt) { @@ -483,6 +484,37 @@ struct LoweringVisitor } } + + // Structure to track "outer" statements during lowering + struct StmtLoweringState + { + // The next "outer" statement entry + StmtLoweringState* parent = nullptr; + + // The outer statement (both lowered and original) + StatementSyntaxNode* loweredStmt = nullptr; + StatementSyntaxNode* originalStmt = nullptr; + }; + StmtLoweringState stmtLoweringState; + + // Translate a reference from one statement to an outer statement + StatementSyntaxNode* translateStmtRef( + StatementSyntaxNode* originalStmt) + { + if(!originalStmt) return nullptr; + + for( auto state = &stmtLoweringState; state; state = state->parent ) + { + if(state->originalStmt == originalStmt) + return state->loweredStmt; + } + + assert(!"unexepcted"); + + return nullptr; + } + + // Expand a statement to be lowered into one or more statements void lowerStmtImpl( StatementSyntaxNode* stmt) { @@ -498,14 +530,17 @@ struct LoweringVisitor LoweringVisitor pushScope( RefPtr<ScopeStmt> loweredStmt, - RefPtr<ScopeStmt> stmt) + RefPtr<ScopeStmt> originalStmt) { - loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); LoweringVisitor subVisitor = *this; subVisitor.isBuildingStmt = true; subVisitor.stmtBeingBuilt = nullptr; subVisitor.parentDecl = loweredStmt->scopeDecl; + subVisitor.stmtLoweringState.parent = &stmtLoweringState; + subVisitor.stmtLoweringState.originalStmt = originalStmt; + subVisitor.stmtLoweringState.loweredStmt = loweredStmt; return subVisitor; } @@ -562,11 +597,11 @@ struct LoweringVisitor void visit(BlockStmt* stmt) { RefPtr<BlockStmt> loweredStmt = new BlockStmt(); + lowerScopeStmtFields(loweredStmt, stmt); LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); - subVisitor.stmtBeingBuilt = loweredStmt; - subVisitor.lowerStmtImpl(stmt->body); + loweredStmt->body = subVisitor.lowerStmt(stmt->body); addStmt(loweredStmt); } @@ -589,10 +624,152 @@ struct LoweringVisitor DeclVisitor::dispatch(stmt->decl); } - // catch-all - void visit(StatementSyntaxNode* stmt) + void lowerStmtFields( + StatementSyntaxNode* loweredStmt, + StatementSyntaxNode* originalStmt) + { + loweredStmt->Position = originalStmt->Position; + loweredStmt->modifiers = originalStmt->modifiers; + } + + void lowerScopeStmtFields( + ScopeStmt* loweredStmt, + ScopeStmt* originalStmt) + { + lowerStmtFields(loweredStmt, originalStmt); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>(); + } + + // Child statements reference their parent statement, + // so we need to translate that cross-reference + void lowerChildStmtFields( + ChildStmt* loweredStmt, + ChildStmt* originalStmt) + { + lowerStmtFields(loweredStmt, originalStmt); + + loweredStmt->parentStmt = translateStmtRef(originalStmt->parentStmt); + } + + void visit(ContinueStatementSyntaxNode* stmt) + { + RefPtr<ContinueStatementSyntaxNode> loweredStmt = new ContinueStatementSyntaxNode(); + lowerChildStmtFields(loweredStmt, stmt); + addStmt(loweredStmt); + } + + void visit(BreakStatementSyntaxNode* stmt) + { + RefPtr<BreakStatementSyntaxNode> loweredStmt = new BreakStatementSyntaxNode(); + lowerChildStmtFields(loweredStmt, stmt); + addStmt(loweredStmt); + } + + void visit(DefaultStmt* stmt) + { + RefPtr<DefaultStmt> loweredStmt = new DefaultStmt(); + lowerChildStmtFields(loweredStmt, stmt); + addStmt(loweredStmt); + } + + void visit(DiscardStatementSyntaxNode* stmt) + { + RefPtr<DiscardStatementSyntaxNode> loweredStmt = new DiscardStatementSyntaxNode(); + lowerStmtFields(loweredStmt, stmt); + addStmt(loweredStmt); + } + + void visit(EmptyStatementSyntaxNode* stmt) + { + RefPtr<EmptyStatementSyntaxNode> loweredStmt = new EmptyStatementSyntaxNode(); + lowerStmtFields(loweredStmt, stmt); + addStmt(loweredStmt); + } + + void visit(UnparsedStmt* stmt) + { + RefPtr<UnparsedStmt> loweredStmt = new UnparsedStmt(); + lowerStmtFields(loweredStmt, stmt); + + loweredStmt->tokens = stmt->tokens; + + addStmt(loweredStmt); + } + + void visit(CaseStmt* stmt) + { + RefPtr<CaseStmt> loweredStmt = new CaseStmt(); + lowerChildStmtFields(loweredStmt, stmt); + + loweredStmt->expr = lowerExpr(stmt->expr); + + addStmt(loweredStmt); + } + + void visit(IfStatementSyntaxNode* stmt) + { + RefPtr<IfStatementSyntaxNode> loweredStmt = new IfStatementSyntaxNode(); + lowerStmtFields(loweredStmt, stmt); + + loweredStmt->Predicate = lowerExpr(stmt->Predicate); + loweredStmt->PositiveStatement = lowerStmt(stmt->PositiveStatement); + loweredStmt->NegativeStatement = lowerStmt(stmt->NegativeStatement); + + addStmt(loweredStmt); + } + + void visit(SwitchStmt* stmt) + { + RefPtr<SwitchStmt> loweredStmt = new SwitchStmt(); + lowerScopeStmtFields(loweredStmt, stmt); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + + loweredStmt->condition = subVisitor.lowerExpr(stmt->condition); + loweredStmt->body = subVisitor.lowerStmt(stmt->body); + + addStmt(loweredStmt); + } + + + void visit(ForStatementSyntaxNode* stmt) { - auto loweredStmt = structuralTransform(stmt, this); + RefPtr<ForStatementSyntaxNode> loweredStmt = new ForStatementSyntaxNode(); + lowerScopeStmtFields(loweredStmt, stmt); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + + loweredStmt->InitialStatement = subVisitor.lowerStmt(stmt->InitialStatement); + loweredStmt->SideEffectExpression = subVisitor.lowerExpr(stmt->SideEffectExpression); + loweredStmt->PredicateExpression = subVisitor.lowerExpr(stmt->PredicateExpression); + loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); + + addStmt(loweredStmt); + } + + void visit(WhileStatementSyntaxNode* stmt) + { + RefPtr<WhileStatementSyntaxNode> loweredStmt = new WhileStatementSyntaxNode(); + lowerScopeStmtFields(loweredStmt, stmt); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + + loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); + + addStmt(loweredStmt); + } + + void visit(DoWhileStatementSyntaxNode* stmt) + { + RefPtr<DoWhileStatementSyntaxNode> loweredStmt = new DoWhileStatementSyntaxNode(); + lowerScopeStmtFields(loweredStmt, stmt); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + + loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); + loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + addStmt(loweredStmt); } @@ -666,13 +843,13 @@ struct LoweringVisitor } RefPtr<Substitutions> translateSubstitutions( - Substitutions* substitutions) + Substitutions* inSubstitutions) { - if (!substitutions) return nullptr; + if (!inSubstitutions) return nullptr; RefPtr<Substitutions> result = new Substitutions(); - result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>(); - for (auto arg : substitutions->args) + result->genericDecl = translateDeclRef(inSubstitutions->genericDecl).As<GenericDecl>(); + for (auto arg : inSubstitutions->args) { result->args.Add(translateVal(arg)); } @@ -740,9 +917,8 @@ struct LoweringVisitor } else { - DeclVisitor::dispatch(declBase); + return DeclVisitor::dispatch(declBase); } - } RefPtr<Decl> lowerDecl( @@ -827,11 +1003,68 @@ struct LoweringVisitor } // Catch-all - RefPtr<Decl> visit( - Decl* decl) + + RefPtr<Decl> visit(ModifierDecl*) { - assert(!"unimplemented"); - return decl; + // should not occur in user code + SLANG_UNEXPECTED("modifiers shouldn't occur in user code"); + } + + RefPtr<Decl> visit(GenericValueParamDecl*) + { + SLANG_UNEXPECTED("generics should be lowered to specialized decls"); + } + + RefPtr<Decl> visit(GenericTypeParamDecl*) + { + SLANG_UNEXPECTED("generics should be lowered to specialized decls"); + } + + RefPtr<Decl> visit(GenericTypeConstraintDecl*) + { + SLANG_UNEXPECTED("generics should be lowered to specialized decls"); + } + + RefPtr<Decl> visit(GenericDecl*) + { + SLANG_UNEXPECTED("generics should be lowered to specialized decls"); + } + + RefPtr<Decl> visit(ProgramSyntaxNode*) + { + SLANG_UNEXPECTED("module decls should be lowered explicitly"); + } + + RefPtr<Decl> visit(SubscriptDecl*) + { + // We don't expect to find direct references to a subscript + // declaration, but rather to the underlying accessors + return nullptr; + } + + RefPtr<Decl> visit(InheritanceDecl*) + { + // We should deal with these explicitly, as part of lowering + // the type that contains them. + return nullptr; + } + + RefPtr<Decl> visit(ExtensionDecl*) + { + // Extensions won't exist in the lowered code: their members + // will turn into ordinary functions that get called explicitly + return nullptr; + } + + RefPtr<Decl> visit(TypeDefDecl* decl) + { + RefPtr<TypeDefDecl> loweredDecl = new TypeDefDecl(); + lowerDeclCommon(loweredDecl, decl); + + loweredDecl->Type = lowerType(decl->Type); + + addMember(shared->loweredProgram, loweredDecl); + return loweredDecl; } RefPtr<ImportDecl> visit(ImportDecl* decl) @@ -922,8 +1155,8 @@ struct LoweringVisitor // // If this is a global variable (program scope), then add it // to the global scope. - RefPtr<ContainerDecl> parentDecl = decl->ParentDecl; - if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>()) + RefPtr<ContainerDecl> pp = decl->ParentDecl; + if (auto parentModuleDecl = pp.As<ProgramSyntaxNode>()) { addMember( translateDeclRef(parentModuleDecl), @@ -1138,10 +1371,6 @@ struct LoweringVisitor subscriptExpr->Position = varExpr->Position; subscriptExpr->BaseExpression = varExpr; - // TODO: we need to construct syntax for a loop to initialize - // the array here... - throw "unimplemented"; - // Note that we use the original `varLayout` that was passed in, // since that is the layout that will ultimately need to be // used on the array elements. @@ -1154,6 +1383,9 @@ struct LoweringVisitor varLayout, subscriptExpr); + // TODO: we need to construct syntax for a loop to initialize + // the array here... + throw "unimplemented"; } else if (auto declRefType = varType->As<DeclRefType>()) { |
