summaryrefslogtreecommitdiffstats
path: root/source/slang/lower.cpp
diff options
context:
space:
mode:
authorTim Foley <tfoley@nvidia.com>2017-07-07 09:57:26 -0700
committerTim Foley <tfoley@nvidia.com>2017-07-07 09:57:26 -0700
commit56b44cbf582fac32e31601fd2a7ae1d6cb8f71b2 (patch)
tree2612e7862a4b57acff0e6223caf365a88c63e62a /source/slang/lower.cpp
parent975e4b326cd2ef3ef0341d1fb7509315b9dee555 (diff)
Fix up visitor approach.
The existing code used a catch-all `visit()` method, and then relied on overloading to find the right version (allowing fallback to a `visit()` method taking a base-class parameter). This approach works, but has some big down-sides: - When browsing the code, you have a bunch of identically-named methods, and it can be hard to find the one you want. - It is impossible to use inheritance to implement fallback for `visit()` methods, because *any* method in the derived class with that name hides *all* methods with the same name in a base class This change makes the `visit()` methods use the name of the corresponding syntax class, and then has visitors inherit the fallback methods they need from the base visitor template class.
Diffstat (limited to 'source/slang/lower.cpp')
-rw-r--r--source/slang/lower.cpp111
1 files changed, 56 insertions, 55 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index f9bf97107..8bc9619f3 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -76,6 +76,7 @@ struct StructuralTransformVisitorBase
}
};
+#if 0
template<typename V>
struct StructuralTransformStmtVisitor
: StructuralTransformVisitorBase<V>
@@ -107,6 +108,7 @@ struct StructuralTransformStmtVisitor
#include "object-meta-end.h"
};
+#endif
template<typename V>
RefPtr<StatementSyntaxNode> structuralTransform(
@@ -130,7 +132,7 @@ struct StructuralTransformExprVisitor
#define SYNTAX_CLASS(NAME, BASE, ...) \
- RefPtr<ExpressionSyntaxNode> visit(NAME* obj) { \
+ RefPtr<ExpressionSyntaxNode> visit##NAME(NAME* obj) { \
RefPtr<NAME> result = new NAME(*obj); \
transformFields(result, obj); \
return result; \
@@ -216,8 +218,7 @@ struct LoweringVisitor
: ExprVisitor<LoweringVisitor, RefPtr<ExpressionSyntaxNode>>
, StmtVisitor<LoweringVisitor, void>
, DeclVisitor<LoweringVisitor, RefPtr<Decl>>
- , TypeVisitor<LoweringVisitor, RefPtr<ExpressionType>>
- , ValVisitor<LoweringVisitor, RefPtr<Val>>
+ , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<ExpressionType>>
{
//
SharedLoweringContext* shared;
@@ -247,12 +248,12 @@ struct LoweringVisitor
return ValVisitor::dispatch(val);
}
- RefPtr<Val> visit(GenericParamIntVal* val)
+ RefPtr<Val> visitGenericParamIntVal(GenericParamIntVal* val)
{
return new GenericParamIntVal(translateDeclRef(DeclRef<Decl>(val->declRef)).As<VarDeclBase>());
}
- RefPtr<Val> visit(ConstantIntVal* val)
+ RefPtr<Val> visitConstantIntVal(ConstantIntVal* val)
{
return val;
}
@@ -276,40 +277,40 @@ struct LoweringVisitor
return result;
}
- RefPtr<ExpressionType> visit(ErrorType* type)
+ RefPtr<ExpressionType> visitErrorType(ErrorType* type)
{
return type;
}
- RefPtr<ExpressionType> visit(OverloadGroupType* type)
+ RefPtr<ExpressionType> visitOverloadGroupType(OverloadGroupType* type)
{
return type;
}
- RefPtr<ExpressionType> visit(InitializerListType* type)
+ RefPtr<ExpressionType> visitInitializerListType(InitializerListType* type)
{
return type;
}
- RefPtr<ExpressionType> visit(GenericDeclRefType* type)
+ RefPtr<ExpressionType> visitGenericDeclRefType(GenericDeclRefType* type)
{
return new GenericDeclRefType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<GenericDecl>());
}
- RefPtr<ExpressionType> visit(FuncType* type)
+ RefPtr<ExpressionType> visitFuncType(FuncType* type)
{
RefPtr<FuncType> loweredType = new FuncType();
loweredType->declRef = translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>();
return loweredType;
}
- RefPtr<ExpressionType> visit(DeclRefType* type)
+ RefPtr<ExpressionType> visitDeclRefType(DeclRefType* type)
{
auto loweredDeclRef = translateDeclRef(type->declRef);
return DeclRefType::Create(loweredDeclRef);
}
- RefPtr<ExpressionType> visit(NamedExpressionType* type)
+ RefPtr<ExpressionType> visitNamedExpressionType(NamedExpressionType* type)
{
if (shared->target == CodeGenTarget::GLSL)
{
@@ -320,12 +321,12 @@ struct LoweringVisitor
return new NamedExpressionType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>());
}
- RefPtr<ExpressionType> visit(TypeType* type)
+ RefPtr<ExpressionType> visitTypeType(TypeType* type)
{
return new TypeType(lowerType(type->type));
}
- RefPtr<ExpressionType> visit(ArrayExpressionType* type)
+ RefPtr<ExpressionType> visitArrayExpressionType(ArrayExpressionType* type)
{
RefPtr<ArrayExpressionType> loweredType = new ArrayExpressionType();
loweredType->BaseType = lowerType(type->BaseType);
@@ -350,7 +351,7 @@ struct LoweringVisitor
}
// catch-all
- RefPtr<ExpressionSyntaxNode> visit(
+ RefPtr<ExpressionSyntaxNode> visitExpressionSyntaxNode(
ExpressionSyntaxNode* expr)
{
return structuralTransform(expr, this);
@@ -404,7 +405,7 @@ struct LoweringVisitor
return result;
}
- RefPtr<ExpressionSyntaxNode> visit(
+ RefPtr<ExpressionSyntaxNode> visitVarExpressionSyntaxNode(
VarExpressionSyntaxNode* expr)
{
// If the expression didn't get resolved, we can leave it as-is
@@ -428,7 +429,7 @@ struct LoweringVisitor
return loweredExpr;
}
- RefPtr<ExpressionSyntaxNode> visit(
+ RefPtr<ExpressionSyntaxNode> visitMemberExpressionSyntaxNode(
MemberExpressionSyntaxNode* expr)
{
auto loweredBase = lowerExpr(expr->BaseExpression);
@@ -521,7 +522,7 @@ struct LoweringVisitor
StmtVisitor::dispatch(stmt);
}
- RefPtr<ScopeDecl> visit(ScopeDecl* decl)
+ RefPtr<ScopeDecl> visitScopeDecl(ScopeDecl* decl)
{
RefPtr<ScopeDecl> loweredDecl = new ScopeDecl();
lowerDeclCommon(loweredDecl, decl);
@@ -594,7 +595,7 @@ struct LoweringVisitor
addStmt(stmt);
}
- void visit(BlockStmt* stmt)
+ void visitBlockStmt(BlockStmt* stmt)
{
RefPtr<BlockStmt> loweredStmt = new BlockStmt();
lowerScopeStmtFields(loweredStmt, stmt);
@@ -606,7 +607,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(SeqStmt* stmt)
+ void visitSeqStmt(SeqStmt* stmt)
{
for( auto ss : stmt->stmts )
{
@@ -614,12 +615,12 @@ struct LoweringVisitor
}
}
- void visit(ExpressionStatementSyntaxNode* stmt)
+ void visitExpressionStatementSyntaxNode(ExpressionStatementSyntaxNode* stmt)
{
addExprStmt(lowerExpr(stmt->Expression));
}
- void visit(VarDeclrStatementSyntaxNode* stmt)
+ void visitVarDeclrStatementSyntaxNode(VarDeclrStatementSyntaxNode* stmt)
{
DeclVisitor::dispatch(stmt->decl);
}
@@ -651,42 +652,42 @@ struct LoweringVisitor
loweredStmt->parentStmt = translateStmtRef(originalStmt->parentStmt);
}
- void visit(ContinueStatementSyntaxNode* stmt)
+ void visitContinueStatementSyntaxNode(ContinueStatementSyntaxNode* stmt)
{
RefPtr<ContinueStatementSyntaxNode> loweredStmt = new ContinueStatementSyntaxNode();
lowerChildStmtFields(loweredStmt, stmt);
addStmt(loweredStmt);
}
- void visit(BreakStatementSyntaxNode* stmt)
+ void visitBreakStatementSyntaxNode(BreakStatementSyntaxNode* stmt)
{
RefPtr<BreakStatementSyntaxNode> loweredStmt = new BreakStatementSyntaxNode();
lowerChildStmtFields(loweredStmt, stmt);
addStmt(loweredStmt);
}
- void visit(DefaultStmt* stmt)
+ void visitDefaultStmt(DefaultStmt* stmt)
{
RefPtr<DefaultStmt> loweredStmt = new DefaultStmt();
lowerChildStmtFields(loweredStmt, stmt);
addStmt(loweredStmt);
}
- void visit(DiscardStatementSyntaxNode* stmt)
+ void visitDiscardStatementSyntaxNode(DiscardStatementSyntaxNode* stmt)
{
RefPtr<DiscardStatementSyntaxNode> loweredStmt = new DiscardStatementSyntaxNode();
lowerStmtFields(loweredStmt, stmt);
addStmt(loweredStmt);
}
- void visit(EmptyStatementSyntaxNode* stmt)
+ void visitEmptyStatementSyntaxNode(EmptyStatementSyntaxNode* stmt)
{
RefPtr<EmptyStatementSyntaxNode> loweredStmt = new EmptyStatementSyntaxNode();
lowerStmtFields(loweredStmt, stmt);
addStmt(loweredStmt);
}
- void visit(UnparsedStmt* stmt)
+ void visitUnparsedStmt(UnparsedStmt* stmt)
{
RefPtr<UnparsedStmt> loweredStmt = new UnparsedStmt();
lowerStmtFields(loweredStmt, stmt);
@@ -696,7 +697,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(CaseStmt* stmt)
+ void visitCaseStmt(CaseStmt* stmt)
{
RefPtr<CaseStmt> loweredStmt = new CaseStmt();
lowerChildStmtFields(loweredStmt, stmt);
@@ -706,7 +707,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(IfStatementSyntaxNode* stmt)
+ void visitIfStatementSyntaxNode(IfStatementSyntaxNode* stmt)
{
RefPtr<IfStatementSyntaxNode> loweredStmt = new IfStatementSyntaxNode();
lowerStmtFields(loweredStmt, stmt);
@@ -718,7 +719,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(SwitchStmt* stmt)
+ void visitSwitchStmt(SwitchStmt* stmt)
{
RefPtr<SwitchStmt> loweredStmt = new SwitchStmt();
lowerScopeStmtFields(loweredStmt, stmt);
@@ -732,7 +733,7 @@ struct LoweringVisitor
}
- void visit(ForStatementSyntaxNode* stmt)
+ void visitForStatementSyntaxNode(ForStatementSyntaxNode* stmt)
{
RefPtr<ForStatementSyntaxNode> loweredStmt = new ForStatementSyntaxNode();
lowerScopeStmtFields(loweredStmt, stmt);
@@ -747,7 +748,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(WhileStatementSyntaxNode* stmt)
+ void visitWhileStatementSyntaxNode(WhileStatementSyntaxNode* stmt)
{
RefPtr<WhileStatementSyntaxNode> loweredStmt = new WhileStatementSyntaxNode();
lowerScopeStmtFields(loweredStmt, stmt);
@@ -760,7 +761,7 @@ struct LoweringVisitor
addStmt(loweredStmt);
}
- void visit(DoWhileStatementSyntaxNode* stmt)
+ void visitDoWhileStatementSyntaxNode(DoWhileStatementSyntaxNode* stmt)
{
RefPtr<DoWhileStatementSyntaxNode> loweredStmt = new DoWhileStatementSyntaxNode();
lowerScopeStmtFields(loweredStmt, stmt);
@@ -805,7 +806,7 @@ struct LoweringVisitor
assign(expr, createVarRef(expr->Position, varDecl));
}
- void visit(ReturnStatementSyntaxNode* stmt)
+ void visitReturnStatementSyntaxNode(ReturnStatementSyntaxNode* stmt)
{
auto loweredStmt = new ReturnStatementSyntaxNode();
lowerStmtCommon(loweredStmt, stmt);
@@ -1004,59 +1005,59 @@ struct LoweringVisitor
// Catch-all
- RefPtr<Decl> visit(ModifierDecl*)
+ RefPtr<Decl> visitModifierDecl(ModifierDecl*)
{
// should not occur in user code
SLANG_UNEXPECTED("modifiers shouldn't occur in user code");
}
- RefPtr<Decl> visit(GenericValueParamDecl*)
+ RefPtr<Decl> visitGenericValueParamDecl(GenericValueParamDecl*)
{
SLANG_UNEXPECTED("generics should be lowered to specialized decls");
}
- RefPtr<Decl> visit(GenericTypeParamDecl*)
+ RefPtr<Decl> visitGenericTypeParamDecl(GenericTypeParamDecl*)
{
SLANG_UNEXPECTED("generics should be lowered to specialized decls");
}
- RefPtr<Decl> visit(GenericTypeConstraintDecl*)
+ RefPtr<Decl> visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*)
{
SLANG_UNEXPECTED("generics should be lowered to specialized decls");
}
- RefPtr<Decl> visit(GenericDecl*)
+ RefPtr<Decl> visitGenericDecl(GenericDecl*)
{
SLANG_UNEXPECTED("generics should be lowered to specialized decls");
}
- RefPtr<Decl> visit(ProgramSyntaxNode*)
+ RefPtr<Decl> visitProgramSyntaxNode(ProgramSyntaxNode*)
{
SLANG_UNEXPECTED("module decls should be lowered explicitly");
}
- RefPtr<Decl> visit(SubscriptDecl*)
+ RefPtr<Decl> visitSubscriptDecl(SubscriptDecl*)
{
// We don't expect to find direct references to a subscript
// declaration, but rather to the underlying accessors
return nullptr;
}
- RefPtr<Decl> visit(InheritanceDecl*)
+ RefPtr<Decl> visitInheritanceDecl(InheritanceDecl*)
{
// We should deal with these explicitly, as part of lowering
// the type that contains them.
return nullptr;
}
- RefPtr<Decl> visit(ExtensionDecl*)
+ RefPtr<Decl> visitExtensionDecl(ExtensionDecl*)
{
// Extensions won't exist in the lowered code: their members
// will turn into ordinary functions that get called explicitly
return nullptr;
}
- RefPtr<Decl> visit(TypeDefDecl* decl)
+ RefPtr<Decl> visitTypeDefDecl(TypeDefDecl* decl)
{
RefPtr<TypeDefDecl> loweredDecl = new TypeDefDecl();
lowerDeclCommon(loweredDecl, decl);
@@ -1067,7 +1068,7 @@ struct LoweringVisitor
return loweredDecl;
}
- RefPtr<ImportDecl> visit(ImportDecl* decl)
+ RefPtr<ImportDecl> visitImportDecl(ImportDecl* decl)
{
// No need to translate things here if we are
// in "full" mode, because we will selectively
@@ -1086,7 +1087,7 @@ struct LoweringVisitor
return nullptr;
}
- RefPtr<EmptyDecl> visit(EmptyDecl* decl)
+ RefPtr<EmptyDecl> visitEmptyDecl(EmptyDecl* decl)
{
// Empty declarations are really only useful in GLSL,
// where they are used to hold metadata that doesn't
@@ -1103,7 +1104,7 @@ struct LoweringVisitor
return loweredDecl;
}
- RefPtr<Decl> visit(AggTypeDecl* decl)
+ RefPtr<Decl> visitAggTypeDecl(AggTypeDecl* decl)
{
// We want to lower any aggregate type declaration
// to just a `struct` type that contains its fields.
@@ -1145,7 +1146,7 @@ struct LoweringVisitor
return loweredDecl;
}
- RefPtr<VarDeclBase> visit(
+ RefPtr<VarDeclBase> visitVariable(
Variable* decl)
{
auto loweredDecl = lowerVarDeclCommon(new Variable(), decl);
@@ -1173,13 +1174,13 @@ struct LoweringVisitor
return loweredDecl;
}
- RefPtr<VarDeclBase> visit(
+ RefPtr<VarDeclBase> visitStructField(
StructField* decl)
{
return lowerVarDeclCommon(new StructField(), decl);
}
- RefPtr<VarDeclBase> visit(
+ RefPtr<VarDeclBase> visitParameterSyntaxNode(
ParameterSyntaxNode* decl)
{
return lowerVarDeclCommon(new ParameterSyntaxNode(), decl);
@@ -1191,7 +1192,7 @@ struct LoweringVisitor
}
- RefPtr<Decl> visit(
+ RefPtr<Decl> visitDeclGroup(
DeclGroup* group)
{
for (auto decl : group->decls)
@@ -1201,7 +1202,7 @@ struct LoweringVisitor
return nullptr;
}
- RefPtr<FunctionSyntaxNode> visit(
+ RefPtr<FunctionSyntaxNode> visitFunctionDeclBase(
FunctionDeclBase* decl)
{
// TODO: need to generate a name
@@ -1474,7 +1475,7 @@ struct LoweringVisitor
RefPtr<EntryPointLayout> entryPointLayout)
{
// First, loer the entry-point function as an ordinary function:
- auto loweredEntryPointFunc = visit(entryPointDecl);
+ auto loweredEntryPointFunc = visitFunctionDeclBase(entryPointDecl);
// Now we will generate a `void main() { ... }` function to call the lowered code.
RefPtr<FunctionSyntaxNode> mainDecl = new FunctionSyntaxNode();
@@ -1672,7 +1673,7 @@ struct LoweringVisitor
{
// Default case: lower an entry point just like any other function
default:
- return visit(entryPointDecl);
+ return visitFunctionDeclBase(entryPointDecl);
// For Slang->GLSL translation, we need to lower things from HLSL-style
// declarations over to GLSL conventions