summaryrefslogtreecommitdiffstats
path: root/source/slang/lower.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/lower.cpp')
-rw-r--r--source/slang/lower.cpp292
1 files changed, 262 insertions, 30 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 674614cd3..f9bf97107 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -60,6 +60,7 @@ struct StructuralTransformVisitorBase
RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl)
{
+ if(!decl) return nullptr;
return visitor->transformSyntaxField(decl).As<ScopeDecl>();
}
@@ -80,7 +81,7 @@ struct StructuralTransformStmtVisitor
: StructuralTransformVisitorBase<V>
, StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>>
{
- void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj)
+ void transformFields(StatementSyntaxNode*, StatementSyntaxNode*)
{
}
@@ -263,6 +264,7 @@ struct LoweringVisitor
RefPtr<ExpressionType> lowerType(
ExpressionType* type)
{
+ if(!type) return nullptr;
return TypeVisitor::dispatch(type);
}
@@ -456,12 +458,11 @@ struct LoweringVisitor
// Statements
//
- StatementSyntaxNode* translateStmtRef(
- StatementSyntaxNode* stmt)
- {
- throw "unimplemented";
- }
-
+ // Lowering one statement to another.
+ // The source statement might desugar into multiple statements,
+ // (or event to none), and in such a case this function wraps
+ // the result up as a `SeqStmt` or `EmptyStmt` as appropriate.
+ //
RefPtr<StatementSyntaxNode> lowerStmt(
StatementSyntaxNode* stmt)
{
@@ -483,6 +484,37 @@ struct LoweringVisitor
}
}
+
+ // Structure to track "outer" statements during lowering
+ struct StmtLoweringState
+ {
+ // The next "outer" statement entry
+ StmtLoweringState* parent = nullptr;
+
+ // The outer statement (both lowered and original)
+ StatementSyntaxNode* loweredStmt = nullptr;
+ StatementSyntaxNode* originalStmt = nullptr;
+ };
+ StmtLoweringState stmtLoweringState;
+
+ // Translate a reference from one statement to an outer statement
+ StatementSyntaxNode* translateStmtRef(
+ StatementSyntaxNode* originalStmt)
+ {
+ if(!originalStmt) return nullptr;
+
+ for( auto state = &stmtLoweringState; state; state = state->parent )
+ {
+ if(state->originalStmt == originalStmt)
+ return state->loweredStmt;
+ }
+
+ assert(!"unexepcted");
+
+ return nullptr;
+ }
+
+ // Expand a statement to be lowered into one or more statements
void lowerStmtImpl(
StatementSyntaxNode* stmt)
{
@@ -498,14 +530,17 @@ struct LoweringVisitor
LoweringVisitor pushScope(
RefPtr<ScopeStmt> loweredStmt,
- RefPtr<ScopeStmt> stmt)
+ RefPtr<ScopeStmt> originalStmt)
{
- loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>();
+ loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>();
LoweringVisitor subVisitor = *this;
subVisitor.isBuildingStmt = true;
subVisitor.stmtBeingBuilt = nullptr;
subVisitor.parentDecl = loweredStmt->scopeDecl;
+ subVisitor.stmtLoweringState.parent = &stmtLoweringState;
+ subVisitor.stmtLoweringState.originalStmt = originalStmt;
+ subVisitor.stmtLoweringState.loweredStmt = loweredStmt;
return subVisitor;
}
@@ -562,11 +597,11 @@ struct LoweringVisitor
void visit(BlockStmt* stmt)
{
RefPtr<BlockStmt> loweredStmt = new BlockStmt();
+ lowerScopeStmtFields(loweredStmt, stmt);
LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
- subVisitor.stmtBeingBuilt = loweredStmt;
- subVisitor.lowerStmtImpl(stmt->body);
+ loweredStmt->body = subVisitor.lowerStmt(stmt->body);
addStmt(loweredStmt);
}
@@ -589,10 +624,152 @@ struct LoweringVisitor
DeclVisitor::dispatch(stmt->decl);
}
- // catch-all
- void visit(StatementSyntaxNode* stmt)
+ void lowerStmtFields(
+ StatementSyntaxNode* loweredStmt,
+ StatementSyntaxNode* originalStmt)
+ {
+ loweredStmt->Position = originalStmt->Position;
+ loweredStmt->modifiers = originalStmt->modifiers;
+ }
+
+ void lowerScopeStmtFields(
+ ScopeStmt* loweredStmt,
+ ScopeStmt* originalStmt)
+ {
+ lowerStmtFields(loweredStmt, originalStmt);
+ loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).As<ScopeDecl>();
+ }
+
+ // Child statements reference their parent statement,
+ // so we need to translate that cross-reference
+ void lowerChildStmtFields(
+ ChildStmt* loweredStmt,
+ ChildStmt* originalStmt)
+ {
+ lowerStmtFields(loweredStmt, originalStmt);
+
+ loweredStmt->parentStmt = translateStmtRef(originalStmt->parentStmt);
+ }
+
+ void visit(ContinueStatementSyntaxNode* stmt)
+ {
+ RefPtr<ContinueStatementSyntaxNode> loweredStmt = new ContinueStatementSyntaxNode();
+ lowerChildStmtFields(loweredStmt, stmt);
+ addStmt(loweredStmt);
+ }
+
+ void visit(BreakStatementSyntaxNode* stmt)
+ {
+ RefPtr<BreakStatementSyntaxNode> loweredStmt = new BreakStatementSyntaxNode();
+ lowerChildStmtFields(loweredStmt, stmt);
+ addStmt(loweredStmt);
+ }
+
+ void visit(DefaultStmt* stmt)
+ {
+ RefPtr<DefaultStmt> loweredStmt = new DefaultStmt();
+ lowerChildStmtFields(loweredStmt, stmt);
+ addStmt(loweredStmt);
+ }
+
+ void visit(DiscardStatementSyntaxNode* stmt)
+ {
+ RefPtr<DiscardStatementSyntaxNode> loweredStmt = new DiscardStatementSyntaxNode();
+ lowerStmtFields(loweredStmt, stmt);
+ addStmt(loweredStmt);
+ }
+
+ void visit(EmptyStatementSyntaxNode* stmt)
+ {
+ RefPtr<EmptyStatementSyntaxNode> loweredStmt = new EmptyStatementSyntaxNode();
+ lowerStmtFields(loweredStmt, stmt);
+ addStmt(loweredStmt);
+ }
+
+ void visit(UnparsedStmt* stmt)
+ {
+ RefPtr<UnparsedStmt> loweredStmt = new UnparsedStmt();
+ lowerStmtFields(loweredStmt, stmt);
+
+ loweredStmt->tokens = stmt->tokens;
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(CaseStmt* stmt)
+ {
+ RefPtr<CaseStmt> loweredStmt = new CaseStmt();
+ lowerChildStmtFields(loweredStmt, stmt);
+
+ loweredStmt->expr = lowerExpr(stmt->expr);
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(IfStatementSyntaxNode* stmt)
+ {
+ RefPtr<IfStatementSyntaxNode> loweredStmt = new IfStatementSyntaxNode();
+ lowerStmtFields(loweredStmt, stmt);
+
+ loweredStmt->Predicate = lowerExpr(stmt->Predicate);
+ loweredStmt->PositiveStatement = lowerStmt(stmt->PositiveStatement);
+ loweredStmt->NegativeStatement = lowerStmt(stmt->NegativeStatement);
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(SwitchStmt* stmt)
+ {
+ RefPtr<SwitchStmt> loweredStmt = new SwitchStmt();
+ lowerScopeStmtFields(loweredStmt, stmt);
+
+ LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
+
+ loweredStmt->condition = subVisitor.lowerExpr(stmt->condition);
+ loweredStmt->body = subVisitor.lowerStmt(stmt->body);
+
+ addStmt(loweredStmt);
+ }
+
+
+ void visit(ForStatementSyntaxNode* stmt)
{
- auto loweredStmt = structuralTransform(stmt, this);
+ RefPtr<ForStatementSyntaxNode> loweredStmt = new ForStatementSyntaxNode();
+ lowerScopeStmtFields(loweredStmt, stmt);
+
+ LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
+
+ loweredStmt->InitialStatement = subVisitor.lowerStmt(stmt->InitialStatement);
+ loweredStmt->SideEffectExpression = subVisitor.lowerExpr(stmt->SideEffectExpression);
+ loweredStmt->PredicateExpression = subVisitor.lowerExpr(stmt->PredicateExpression);
+ loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement);
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(WhileStatementSyntaxNode* stmt)
+ {
+ RefPtr<WhileStatementSyntaxNode> loweredStmt = new WhileStatementSyntaxNode();
+ lowerScopeStmtFields(loweredStmt, stmt);
+
+ LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
+
+ loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate);
+ loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement);
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(DoWhileStatementSyntaxNode* stmt)
+ {
+ RefPtr<DoWhileStatementSyntaxNode> loweredStmt = new DoWhileStatementSyntaxNode();
+ lowerScopeStmtFields(loweredStmt, stmt);
+
+ LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
+
+ loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement);
+ loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate);
+
addStmt(loweredStmt);
}
@@ -666,13 +843,13 @@ struct LoweringVisitor
}
RefPtr<Substitutions> translateSubstitutions(
- Substitutions* substitutions)
+ Substitutions* inSubstitutions)
{
- if (!substitutions) return nullptr;
+ if (!inSubstitutions) return nullptr;
RefPtr<Substitutions> result = new Substitutions();
- result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>();
- for (auto arg : substitutions->args)
+ result->genericDecl = translateDeclRef(inSubstitutions->genericDecl).As<GenericDecl>();
+ for (auto arg : inSubstitutions->args)
{
result->args.Add(translateVal(arg));
}
@@ -740,9 +917,8 @@ struct LoweringVisitor
}
else
{
- DeclVisitor::dispatch(declBase);
+ return DeclVisitor::dispatch(declBase);
}
-
}
RefPtr<Decl> lowerDecl(
@@ -827,11 +1003,68 @@ struct LoweringVisitor
}
// Catch-all
- RefPtr<Decl> visit(
- Decl* decl)
+
+ RefPtr<Decl> visit(ModifierDecl*)
{
- assert(!"unimplemented");
- return decl;
+ // should not occur in user code
+ SLANG_UNEXPECTED("modifiers shouldn't occur in user code");
+ }
+
+ RefPtr<Decl> visit(GenericValueParamDecl*)
+ {
+ SLANG_UNEXPECTED("generics should be lowered to specialized decls");
+ }
+
+ RefPtr<Decl> visit(GenericTypeParamDecl*)
+ {
+ SLANG_UNEXPECTED("generics should be lowered to specialized decls");
+ }
+
+ RefPtr<Decl> visit(GenericTypeConstraintDecl*)
+ {
+ SLANG_UNEXPECTED("generics should be lowered to specialized decls");
+ }
+
+ RefPtr<Decl> visit(GenericDecl*)
+ {
+ SLANG_UNEXPECTED("generics should be lowered to specialized decls");
+ }
+
+ RefPtr<Decl> visit(ProgramSyntaxNode*)
+ {
+ SLANG_UNEXPECTED("module decls should be lowered explicitly");
+ }
+
+ RefPtr<Decl> visit(SubscriptDecl*)
+ {
+ // We don't expect to find direct references to a subscript
+ // declaration, but rather to the underlying accessors
+ return nullptr;
+ }
+
+ RefPtr<Decl> visit(InheritanceDecl*)
+ {
+ // We should deal with these explicitly, as part of lowering
+ // the type that contains them.
+ return nullptr;
+ }
+
+ RefPtr<Decl> visit(ExtensionDecl*)
+ {
+ // Extensions won't exist in the lowered code: their members
+ // will turn into ordinary functions that get called explicitly
+ return nullptr;
+ }
+
+ RefPtr<Decl> visit(TypeDefDecl* decl)
+ {
+ RefPtr<TypeDefDecl> loweredDecl = new TypeDefDecl();
+ lowerDeclCommon(loweredDecl, decl);
+
+ loweredDecl->Type = lowerType(decl->Type);
+
+ addMember(shared->loweredProgram, loweredDecl);
+ return loweredDecl;
}
RefPtr<ImportDecl> visit(ImportDecl* decl)
@@ -922,8 +1155,8 @@ struct LoweringVisitor
//
// If this is a global variable (program scope), then add it
// to the global scope.
- RefPtr<ContainerDecl> parentDecl = decl->ParentDecl;
- if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>())
+ RefPtr<ContainerDecl> pp = decl->ParentDecl;
+ if (auto parentModuleDecl = pp.As<ProgramSyntaxNode>())
{
addMember(
translateDeclRef(parentModuleDecl),
@@ -1138,10 +1371,6 @@ struct LoweringVisitor
subscriptExpr->Position = varExpr->Position;
subscriptExpr->BaseExpression = varExpr;
- // TODO: we need to construct syntax for a loop to initialize
- // the array here...
- throw "unimplemented";
-
// Note that we use the original `varLayout` that was passed in,
// since that is the layout that will ultimately need to be
// used on the array elements.
@@ -1154,6 +1383,9 @@ struct LoweringVisitor
varLayout,
subscriptExpr);
+ // TODO: we need to construct syntax for a loop to initialize
+ // the array here...
+ throw "unimplemented";
}
else if (auto declRefType = varType->As<DeclRefType>())
{