diff options
| author | Tim Foley <tfoley@nvidia.com> | 2017-06-28 13:34:38 -0700 |
|---|---|---|
| committer | Tim Foley <tfoley@nvidia.com> | 2017-07-06 09:17:04 -0700 |
| commit | f145e09a6dcbcf326f782b3e6a76dbf291c792cf (patch) | |
| tree | 88a04619ceaaa37b87199dd82334cc9d102c156d /source/slang/lower.cpp | |
| parent | c0d2c17bc73bc2a8863e086af3ea395ad09465ee (diff) | |
Start to support cross-compilation via "lowering" pass
- The big change here is the introduction of a "lowering" pass that takes an input AST from the semantic checker, and produces an output AST suitable for emitting. The intention is that he lowering pass is responsible for:
- Stripping out unused code (when we have enough information to do so), by only outputting declarations that are transitively references from an entry point
- When cross-compiling to GLSL, generating a suitable `void main()` entry point to wrap the user-written entry-point function
- (Eventually) legalizing types in the program, by scalarizing aggregate types that mix uniform and resource types
- (Eventually) instantiating generic declarations so that the resulting code only deals with fully specialized declarations
- (Eventually) de-sugaring OOP constructs into basic "structs and functions" form
- (Eventually) instantiating code that depends on interface types at the concrete types chosen
- It is clear that there is still a lot of work to be done there, to this change is really about getting infrastructure in place without breaking the existing test cases.
- One cleanup here is that we get rid of the idea of whole-translation-unit output, since that was specific to HLSL output, and there is really no strong reason for keeping it. Users should now just ask for the output for each entry point that they wanted to generate.
- The biggest source of complexity for the lowering process is that it needs to produce the same AST structure as the input, to deal with the complexity of the rewriter case. That is, we need the output to be able to reproduce the input exactly in the case where we are rewriting and nothing needs to change, so the output format needs at least the degrees of freedom of the input.
- As a result, we end up having to distinguish "rewriter" and "full" modes in both lowering and code-emit steps, so that we can react appropriately.
- Generating a GLSL `main()` also adds a lot of complexity. Right now I'm using the simplest approach, where we always output the Slang/HLSL entry point as an ordinary function (as written) and then emit a simple GLSL `main()` to call it. I generate globals for all the shader inputs/outputs (these need to be scalarized and have explicit `location`s attached), and then collect these into the `struct` types of the original parameters as needed.
- This approach will start to have some major down-sides once we have to deal with "arrayed" input/output
- A long-term question here is how to replace entry-point parameter types with scalarized and/or "transposed" versions, while still letting the original code work as written (including copying those inputs to temporary arrays)
- Split `BlockStatementSyntaxNode` into:
- `BlockStmt` which just provides a scope around a `body` statement
- `SeqStmt` which just allows multiple statements to be treated as one
- Change how we emit `for` loops, to deal with the case where the initialization part might expand into multiple statements
- Basically `for(A;B;C) {D}` becomes `{A; for(;B;C) {D}}`, so we can handle arbitrary statements for `A`
- As an additional wrinkle, when we are rewriting HLSL, we just generate `A; for(;B;C) {D}` to deal with the broken scoping there
- This change is needed because the lowering pass was sometimes expanding the original initialization statement `A` into a block `{A}`. Certainly if it declared multiple variables we'd need to handle it, and this seemed the easiest way
- A more significant challenge for lowering would come if/when we ever wanted to support true short-circuiting behavior for `&&` and `||`
- For right now I'm not changing the behavior of the "rewriter" mode, so we still have `UnparsedStmt` instances being generated, but it is clear that eventually we need to parse *all* input, even if we can't type-check 100% of it. This is required so that we can rewrite user code that might refer to a shader input with interface type.
Diffstat (limited to 'source/slang/lower.cpp')
| -rw-r--r-- | source/slang/lower.cpp | 1601 |
1 files changed, 1601 insertions, 0 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp new file mode 100644 index 000000000..674614cd3 --- /dev/null +++ b/source/slang/lower.cpp @@ -0,0 +1,1601 @@ +// lower.cpp +#include "lower.h" + +#include "visitor.h" + +namespace Slang +{ + +// + +template<typename V> +struct StructuralTransformVisitorBase +{ + V* visitor; + + RefPtr<StatementSyntaxNode> transformDeclField(StatementSyntaxNode* stmt) + { + return visitor->translateStmtRef(stmt); + } + + RefPtr<Decl> transformDeclField(Decl* decl) + { + return visitor->translateDeclRef(decl); + } + + template<typename T> + DeclRef<T> transformDeclField(DeclRef<T> const& decl) + { + return visitor->translateDeclRef(decl).As<T>(); + } + + TypeExp transformSyntaxField(TypeExp const& typeExp) + { + TypeExp result; + result.type = visitor->transformSyntaxField(typeExp.type); + return result; + } + + QualType transformSyntaxField(QualType const& qualType) + { + QualType result = qualType; + result.type = visitor->transformSyntaxField(qualType.type); + return result; + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return visitor->transformSyntaxField(expr); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return visitor->transformSyntaxField(stmt); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return visitor->transformSyntaxField(decl); + } + + RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl) + { + return visitor->transformSyntaxField(decl).As<ScopeDecl>(); + } + + template<typename T> + List<T> transformSyntaxField(List<T> const& list) + { + List<T> result; + for (auto item : list) + { + result.Add(transformSyntaxField(item)); + } + return result; + } +}; + +template<typename V> +struct StructuralTransformStmtVisitor + : StructuralTransformVisitorBase<V> + , StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>> +{ + void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj) + { + } + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<StatementSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = this->transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = this->transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "stmt-defs.h" +#include "object-meta-end.h" + +}; + +template<typename V> +RefPtr<StatementSyntaxNode> structuralTransform( + StatementSyntaxNode* stmt, + V* visitor) +{ + StructuralTransformStmtVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(stmt); +} + +template<typename V> +struct StructuralTransformExprVisitor + : StructuralTransformVisitorBase<V> + , ExprVisitor<StructuralTransformExprVisitor<V>, RefPtr<ExpressionSyntaxNode>> +{ + void transformFields(ExpressionSyntaxNode* result, ExpressionSyntaxNode* obj) + { + result->Type = transformSyntaxField(obj->Type); + } + + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<ExpressionSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "expr-defs.h" +#include "object-meta-end.h" +}; + + +template<typename V> +RefPtr<ExpressionSyntaxNode> structuralTransform( + ExpressionSyntaxNode* expr, + V* visitor) +{ + StructuralTransformExprVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(expr); +} + +// + +// Pseudo-syntax used during lowering +class TupleDecl : public VarDeclBase +{ +public: + virtual void accept(IDeclVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<VarDeclBase>> decls; +}; + +// Pseudo-syntax used during lowering: +// represents an ordered list of expressions as a single unit +class TupleExpr : public ExpressionSyntaxNode +{ +public: + virtual void accept(IExprVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<ExpressionSyntaxNode>> exprs; +}; + +struct SharedLoweringContext +{ + ProgramLayout* programLayout; + CodeGenTarget target; + + RefPtr<ProgramSyntaxNode> loweredProgram; + + Dictionary<Decl*, RefPtr<Decl>> loweredDecls; + Dictionary<Decl*, Decl*> mapLoweredDeclToOriginal; + + bool isRewrite; +}; + +static void attachLayout( + ModifiableSyntaxNode* syntax, + Layout* layout) +{ + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = layout; + + addModifier(syntax, modifier); +} + +struct LoweringVisitor + : ExprVisitor<LoweringVisitor, RefPtr<ExpressionSyntaxNode>> + , StmtVisitor<LoweringVisitor, void> + , DeclVisitor<LoweringVisitor, RefPtr<Decl>> + , TypeVisitor<LoweringVisitor, RefPtr<ExpressionType>> + , ValVisitor<LoweringVisitor, RefPtr<Val>> +{ + // + SharedLoweringContext* shared; + RefPtr<Substitutions> substitutions; + + bool isBuildingStmt = false; + RefPtr<StatementSyntaxNode> stmtBeingBuilt; + + // If we *aren't* building a statement, then this + // is the container we should be adding declarations to + RefPtr<ContainerDecl> parentDecl; + + // If we are in a context where a `return` should be turned + // into assignment to a variable (followed by a `return`), + // then this will point to that variable. + RefPtr<Variable> resultVariable; + + CodeGenTarget getTarget() { return shared->target; } + + // + // Values + // + + RefPtr<Val> lowerVal(Val* val) + { + if (!val) return nullptr; + return ValVisitor::dispatch(val); + } + + RefPtr<Val> visit(GenericParamIntVal* val) + { + return new GenericParamIntVal(translateDeclRef(DeclRef<Decl>(val->declRef)).As<VarDeclBase>()); + } + + RefPtr<Val> visit(ConstantIntVal* val) + { + return val; + } + + // + // Types + // + + RefPtr<ExpressionType> lowerType( + ExpressionType* type) + { + return TypeVisitor::dispatch(type); + } + + TypeExp lowerType( + TypeExp const& typeExp) + { + TypeExp result; + result.type = lowerType(typeExp.type); + return result; + } + + RefPtr<ExpressionType> visit(ErrorType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(OverloadGroupType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(InitializerListType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(GenericDeclRefType* type) + { + return new GenericDeclRefType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<GenericDecl>()); + } + + RefPtr<ExpressionType> visit(FuncType* type) + { + RefPtr<FuncType> loweredType = new FuncType(); + loweredType->declRef = translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>(); + return loweredType; + } + + RefPtr<ExpressionType> visit(DeclRefType* type) + { + auto loweredDeclRef = translateDeclRef(type->declRef); + return DeclRefType::Create(loweredDeclRef); + } + + RefPtr<ExpressionType> visit(NamedExpressionType* type) + { + if (shared->target == CodeGenTarget::GLSL) + { + // GLSL does not support `typedef`, so we will lower it out of existence here + return lowerType(GetType(type->declRef)); + } + + return new NamedExpressionType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); + } + + RefPtr<ExpressionType> visit(TypeType* type) + { + return new TypeType(lowerType(type->type)); + } + + RefPtr<ExpressionType> visit(ArrayExpressionType* type) + { + RefPtr<ArrayExpressionType> loweredType = new ArrayExpressionType(); + loweredType->BaseType = lowerType(type->BaseType); + loweredType->ArrayLength = lowerVal(type->ArrayLength).As<IntVal>(); + return loweredType; + } + + RefPtr<ExpressionType> transformSyntaxField(ExpressionType* type) + { + return lowerType(type); + } + + // + // Expressions + // + + RefPtr<ExpressionSyntaxNode> lowerExpr( + ExpressionSyntaxNode* expr) + { + if (!expr) return nullptr; + return ExprVisitor::dispatch(expr); + } + + // catch-all + RefPtr<ExpressionSyntaxNode> visit( + ExpressionSyntaxNode* expr) + { + return structuralTransform(expr, this); + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return lowerExpr(expr); + } + + void lowerExprCommon( + RefPtr<ExpressionSyntaxNode> loweredExpr, + RefPtr<ExpressionSyntaxNode> expr) + { + loweredExpr->Position = expr->Position; + loweredExpr->Type.type = lowerType(expr->Type.type); + } + + RefPtr<ExpressionSyntaxNode> createVarRef( + CodePosition const& loc, + VarDeclBase* decl) + { + if (auto tupleDecl = dynamic_cast<TupleDecl*>(decl)) + { + return createTupleRef(loc, tupleDecl); + } + else + { + RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode(); + result->Position = loc; + result->Type.type = decl->Type.type; + result->declRef = makeDeclRef(decl); + return result; + } + } + + RefPtr<ExpressionSyntaxNode> createTupleRef( + CodePosition const& loc, + TupleDecl* decl) + { + RefPtr<TupleExpr> result = new TupleExpr(); + result->Position = loc; + result->Type.type = decl->Type.type; + + for (auto dd : decl->decls) + { + auto expr = createVarRef(loc, dd); + result->exprs.Add(expr); + } + + return result; + } + + RefPtr<ExpressionSyntaxNode> visit( + VarExpressionSyntaxNode* expr) + { + // If the expression didn't get resolved, we can leave it as-is + if (!expr->declRef) + return expr; + + auto loweredDeclRef = translateDeclRef(expr->declRef); + auto loweredDecl = loweredDeclRef.getDecl(); + + if (auto tupleDecl = dynamic_cast<TupleDecl*>(loweredDecl)) + { + // If we are referencing a declaration that got tuple-ified, + // then we need to produce a tuple expression as well. + + return createTupleRef(expr->Position, tupleDecl); + } + + RefPtr<VarExpressionSyntaxNode> loweredExpr = new VarExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->declRef = loweredDeclRef; + return loweredExpr; + } + + RefPtr<ExpressionSyntaxNode> visit( + MemberExpressionSyntaxNode* expr) + { + auto loweredBase = lowerExpr(expr->BaseExpression); + + // Are we extracting an element from a tuple? + if (auto baseTuple = loweredBase.As<TupleExpr>()) + { + // We need to find the correct member expression, + // based on the actual tuple type. + + throw "unimplemented"; + } + + // Default handling: + auto loweredDeclRef = translateDeclRef(expr->declRef); + assert(!dynamic_cast<TupleDecl*>(loweredDeclRef.getDecl())); + + RefPtr<MemberExpressionSyntaxNode> loweredExpr = new MemberExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = loweredBase; + loweredExpr->declRef = loweredDeclRef; + + return loweredExpr; + } + + // + // Statements + // + + StatementSyntaxNode* translateStmtRef( + StatementSyntaxNode* stmt) + { + throw "unimplemented"; + } + + RefPtr<StatementSyntaxNode> lowerStmt( + StatementSyntaxNode* stmt) + { + if(!stmt) + return nullptr; + + LoweringVisitor subVisitor = *this; + subVisitor.stmtBeingBuilt = nullptr; + + subVisitor.lowerStmtImpl(stmt); + + if( !subVisitor.stmtBeingBuilt ) + { + return new EmptyStatementSyntaxNode(); + } + else + { + return subVisitor.stmtBeingBuilt; + } + } + + void lowerStmtImpl( + StatementSyntaxNode* stmt) + { + StmtVisitor::dispatch(stmt); + } + + RefPtr<ScopeDecl> visit(ScopeDecl* decl) + { + RefPtr<ScopeDecl> loweredDecl = new ScopeDecl(); + lowerDeclCommon(loweredDecl, decl); + return loweredDecl; + } + + LoweringVisitor pushScope( + RefPtr<ScopeStmt> loweredStmt, + RefPtr<ScopeStmt> stmt) + { + loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>(); + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + subVisitor.parentDecl = loweredStmt->scopeDecl; + return subVisitor; + } + + void addStmtImpl( + RefPtr<StatementSyntaxNode>& dest, + StatementSyntaxNode* stmt) + { + // add a statement to the code we are building... + if( !dest ) + { + dest = stmt; + return; + } + + if (auto blockStmt = dest.As<BlockStmt>()) + { + addStmtImpl(blockStmt->body, stmt); + return; + } + + if (auto seqStmt = dest.As<SeqStmt>()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr<SeqStmt> newSeqStmt = new SeqStmt(); + + newSeqStmt->stmts.Add(dest); + newSeqStmt->stmts.Add(stmt); + + dest = newSeqStmt; + } + + } + + void addStmt( + StatementSyntaxNode* stmt) + { + addStmtImpl(stmtBeingBuilt, stmt); + } + + void addExprStmt( + RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: handle cases where the `expr` cannot be directly + // represented as a single statement + + RefPtr<ExpressionStatementSyntaxNode> stmt = new ExpressionStatementSyntaxNode(); + stmt->Expression = expr; + addStmt(stmt); + } + + void visit(BlockStmt* stmt) + { + RefPtr<BlockStmt> loweredStmt = new BlockStmt(); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + subVisitor.stmtBeingBuilt = loweredStmt; + + subVisitor.lowerStmtImpl(stmt->body); + + addStmt(loweredStmt); + } + + void visit(SeqStmt* stmt) + { + for( auto ss : stmt->stmts ) + { + lowerStmtImpl(ss); + } + } + + void visit(ExpressionStatementSyntaxNode* stmt) + { + addExprStmt(lowerExpr(stmt->Expression)); + } + + void visit(VarDeclrStatementSyntaxNode* stmt) + { + DeclVisitor::dispatch(stmt->decl); + } + + // catch-all + void visit(StatementSyntaxNode* stmt) + { + auto loweredStmt = structuralTransform(stmt, this); + addStmt(loweredStmt); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return lowerStmt(stmt); + } + + void lowerStmtCommon(StatementSyntaxNode* loweredStmt, StatementSyntaxNode* stmt) + { + loweredStmt->modifiers = stmt->modifiers; + } + + void assign( + RefPtr<ExpressionSyntaxNode> destExpr, + RefPtr<ExpressionSyntaxNode> srcExpr) + { + RefPtr<AssignExpr> assignExpr = new AssignExpr(); + assignExpr->Position = destExpr->Position; + assignExpr->left = destExpr; + assignExpr->right = srcExpr; + + addExprStmt(assignExpr); + } + + void assign(VarDeclBase* varDecl, RefPtr<ExpressionSyntaxNode> expr) + { + assign(createVarRef(expr->Position, varDecl), expr); + } + + void assign(RefPtr<ExpressionSyntaxNode> expr, VarDeclBase* varDecl) + { + assign(expr, createVarRef(expr->Position, varDecl)); + } + + void visit(ReturnStatementSyntaxNode* stmt) + { + auto loweredStmt = new ReturnStatementSyntaxNode(); + lowerStmtCommon(loweredStmt, stmt); + + if (stmt->Expression) + { + if (resultVariable) + { + // Do it as an assignment + assign(resultVariable, lowerExpr(stmt->Expression)); + } + else + { + // Simple case + loweredStmt->Expression = lowerExpr(stmt->Expression); + } + } + + addStmt(loweredStmt); + } + + // + // Declarations + // + + RefPtr<Val> translateVal(Val* val) + { + if (auto type = dynamic_cast<ExpressionType*>(val)) + return lowerType(type); + + if (auto litVal = dynamic_cast<ConstantIntVal*>(val)) + return val; + + throw 99; + } + + RefPtr<Substitutions> translateSubstitutions( + Substitutions* substitutions) + { + if (!substitutions) return nullptr; + + RefPtr<Substitutions> result = new Substitutions(); + result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>(); + for (auto arg : substitutions->args) + { + result->args.Add(translateVal(arg)); + } + return result; + } + + static Decl* getModifiedDecl(Decl* decl) + { + if (!decl) return nullptr; + if (auto genericDecl = dynamic_cast<GenericDecl*>(decl->ParentDecl)) + return genericDecl; + return decl; + } + + DeclRef<Decl> translateDeclRef( + DeclRef<Decl> const& decl) + { + DeclRef<Decl> result; + result.decl = translateDeclRef(decl.decl); + result.substitutions = translateSubstitutions(decl.substitutions); + return result; + } + + RefPtr<Decl> translateDeclRef( + Decl* decl) + { + if (!decl) return nullptr; + + // We don't want to translate references to built-in declarations, + // since they won't be subtituted anyway. + if (getModifiedDecl(decl)->HasModifier<FromStdLibModifier>()) + return decl; + + // If any parent of the declaration was in the stdlib, then + // we need to skip it. + for(auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier<FromStdLibModifier>()) + return decl; + } + + if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>()) + return decl; + + RefPtr<Decl> loweredDecl; + if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + return loweredDecl; + + // Time to force it + return lowerDecl(decl); + } + + RefPtr<ContainerDecl> translateDeclRef( + ContainerDecl* decl) + { + return translateDeclRef((Decl*)decl).As<ContainerDecl>(); + } + + RefPtr<DeclBase> lowerDeclBase( + DeclBase* declBase) + { + if (Decl* decl = dynamic_cast<Decl*>(declBase)) + { + return lowerDecl(decl); + } + else + { + DeclVisitor::dispatch(declBase); + } + + } + + RefPtr<Decl> lowerDecl( + Decl* decl) + { + RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl).As<Decl>(); + return loweredDecl; + } + + static void addMember( + RefPtr<ContainerDecl> containerDecl, + RefPtr<Decl> memberDecl) + { + containerDecl->Members.Add(memberDecl); + memberDecl->ParentDecl = containerDecl.Ptr(); + } + + void addDecl( + Decl* decl) + { + if(isBuildingStmt) + { + RefPtr<VarDeclrStatementSyntaxNode> declStmt = new VarDeclrStatementSyntaxNode(); + declStmt->Position = decl->Position; + declStmt->decl = decl; + addStmt(declStmt); + } + + + // We will add the declaration to the current container declaration being + // translated, which the user will maintain via pua/pop. + // + + assert(parentDecl); + addMember(parentDecl, decl); + } + + void registerLoweredDecl(Decl* loweredDecl, Decl* decl) + { + shared->loweredDecls.Add(decl, loweredDecl); + + shared->mapLoweredDeclToOriginal.Add(loweredDecl, decl); + } + + void lowerDeclCommon( + Decl* loweredDecl, + Decl* decl) + { + registerLoweredDecl(loweredDecl, decl); + + loweredDecl->Position = decl->Position; + loweredDecl->Name = decl->getNameToken(); + + // Lower modifiers as needed + + // HACK: just doing a shallow copy of modifiers, which will + // suffice for most of them, but we need to do something + // better soon. + loweredDecl->modifiers = decl->modifiers; + + // deal with layout stuff + + auto loweredParent = translateDeclRef(decl->ParentDecl); + if (loweredParent) + { + auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); + if (layoutMod) + { + auto parentLayout = layoutMod->layout; + if (auto structLayout = parentLayout.As<StructTypeLayout>()) + { + RefPtr<VarLayout> fieldLayout; + if (structLayout->mapVarToLayout.TryGetValue(decl, fieldLayout)) + { + attachLayout(loweredDecl, fieldLayout); + } + } + + // TODO: are there other cases to handle here? + } + } + } + + // Catch-all + RefPtr<Decl> visit( + Decl* decl) + { + assert(!"unimplemented"); + return decl; + } + + RefPtr<ImportDecl> visit(ImportDecl* decl) + { + // No need to translate things here if we are + // in "full" mode, because we will selectively + // translate the imported declarations at their + // use sites(s). + if (!shared->isRewrite) + return nullptr; + + for (auto dd : decl->importedModuleDecl->Members) + { + translateDeclRef(dd); + } + + // Don't actually include a representation of + // the import declaration in the output + return nullptr; + } + + RefPtr<EmptyDecl> visit(EmptyDecl* decl) + { + // Empty declarations are really only useful in GLSL, + // where they are used to hold metadata that doesn't + // attach to any particular shader parameter. + // + // TODO: Only lower empty declarations if we are + // rewriting a GLSL file, and otherwise ignore them. + // + RefPtr<EmptyDecl> loweredDecl = new EmptyDecl(); + lowerDeclCommon(loweredDecl, decl); + + addDecl(loweredDecl); + + return loweredDecl; + } + + RefPtr<Decl> visit(AggTypeDecl* decl) + { + // We want to lower any aggregate type declaration + // to just a `struct` type that contains its fields. + // + // Any non-field members (e.g., methods) will be + // lowered separately. + + // TODO: also need to figure out how to handle fields + // with types that should not be allowed in a `struct` + // for the chosen target. + // (also: what to do if there are no fields left + // after removing invalid ones?) + + RefPtr<StructSyntaxNode> loweredDecl = new StructSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + for (auto field : decl->getMembersOfType<VarDeclBase>()) + { + // TODO: anything more to do than this? + addMember(loweredDecl, translateDeclRef(field)); + } + + addMember( + shared->loweredProgram, + loweredDecl); + + return loweredDecl; + } + + RefPtr<VarDeclBase> lowerVarDeclCommon( + RefPtr<VarDeclBase> loweredDecl, + VarDeclBase* decl) + { + lowerDeclCommon(loweredDecl, decl); + + loweredDecl->Type = lowerType(decl->Type); + loweredDecl->Expr = lowerExpr(decl->Expr); + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + Variable* decl) + { + auto loweredDecl = lowerVarDeclCommon(new Variable(), decl); + + // We need to add things to an appropriate scope, based on what + // we are referencing. + // + // If this is a global variable (program scope), then add it + // to the global scope. + RefPtr<ContainerDecl> parentDecl = decl->ParentDecl; + if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>()) + { + addMember( + translateDeclRef(parentModuleDecl), + loweredDecl); + } + // TODO: handle `static` function-scope variables + else + { + // A local variable declaration will get added to the + // statement scope we are currently processing. + addDecl(loweredDecl); + } + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + StructField* decl) + { + return lowerVarDeclCommon(new StructField(), decl); + } + + RefPtr<VarDeclBase> visit( + ParameterSyntaxNode* decl) + { + return lowerVarDeclCommon(new ParameterSyntaxNode(), decl); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return lowerDeclBase(decl); + } + + + RefPtr<Decl> visit( + DeclGroup* group) + { + for (auto decl : group->decls) + { + lowerDecl(decl); + } + return nullptr; + } + + RefPtr<FunctionSyntaxNode> visit( + FunctionDeclBase* decl) + { + // TODO: need to generate a name + + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + // TODO: push scope for parent decl here... + + // TODO: need to copy over relevant modifiers + + for (auto paramDecl : decl->GetParameters()) + { + addMember(loweredDecl, translateDeclRef(paramDecl)); + } + + auto loweredReturnType = lowerType(decl->ReturnType); + + loweredDecl->ReturnType = loweredReturnType; + + // If we are a being called recurisvely, then we need to + // be careful not to let the context get polluted + LoweringVisitor subVisitor = *this; + subVisitor.resultVariable = nullptr; + subVisitor.stmtBeingBuilt = nullptr; + + loweredDecl->Body = subVisitor.lowerStmt(decl->Body); + + // A lowered function always becomes a global-scope function, + // even if it had been a member function when declared. + addMember(shared->loweredProgram, loweredDecl); + + return loweredDecl; + } + + // + // Entry Points + // + + EntryPointLayout* findEntryPointLayout( + EntryPointRequest* entryPointRequest) + { + for( auto entryPointLayout : shared->programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + continue; + + if(entryPointLayout->profile != entryPointRequest->profile) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPointRequest` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; + } + + enum class VaryingParameterDirection + { + Input, + Output, + }; + + struct VaryingParameterArraySpec + { + VaryingParameterArraySpec* next = nullptr; + IntVal* elementCount; + }; + + struct VaryingParameterInfo + { + String name; + VaryingParameterDirection direction; + VaryingParameterArraySpec* arraySpecs = nullptr; + }; + + + void lowerSimpleShaderParameterToGLSLGlobal( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + RefPtr<ExpressionType> type = varType; + + for (auto aa = info.arraySpecs; aa; aa = aa->next) + { + RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->BaseType = type; + arrayType->ArrayLength = aa->elementCount; + + type = arrayType; + } + + // TODO: if we are declaring an SOA-ized array, + // this is where those array dimensions would need + // to be tacked on. + + RefPtr<Variable> globalVarDecl = new Variable(); + globalVarDecl->Name.Content = info.name; + globalVarDecl->Type.type = type; + + addMember(shared->loweredProgram, globalVarDecl); + + // Add the layout information + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = varLayout; + addModifier(globalVarDecl, modifier); + + // Need to generate an assignment in the right direction. + // + // TODO: for now I am just dealing with input: + + switch (info.direction) + { + case VaryingParameterDirection::Input: + addModifier(globalVarDecl, new InModifier()); + assign(varExpr, globalVarDecl); + break; + + case VaryingParameterDirection::Output: + addModifier(globalVarDecl, new OutModifier()); + + assign(globalVarDecl, varExpr); + break; + } + } + + void lowerShaderParameterToGLSLGLobalsRec( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + assert(varLayout); + + if (auto basicType = varType->As<BasicExpressionType>()) + { + // handled below + } + else if (auto vectorType = varType->As<VectorExpressionType>()) + { + // handled below + } + else if (auto matrixType = varType->As<MatrixExpressionType>()) + { + // handled below + } + else if (auto arrayType = varType->As<ArrayExpressionType>()) + { + // We will accumulate information on the array + // types that were encoutnered on our walk down + // to the leaves, and then apply these array dimensions + // to any leaf parameters. + + VaryingParameterArraySpec arraySpec; + arraySpec.next = info.arraySpecs; + arraySpec.elementCount = arrayType->ArrayLength; + + VaryingParameterInfo arrayInfo = info; + arrayInfo.arraySpecs = &arraySpec; + + RefPtr<IndexExpressionSyntaxNode> subscriptExpr = new IndexExpressionSyntaxNode(); + subscriptExpr->Position = varExpr->Position; + subscriptExpr->BaseExpression = varExpr; + + // TODO: we need to construct syntax for a loop to initialize + // the array here... + throw "unimplemented"; + + // Note that we use the original `varLayout` that was passed in, + // since that is the layout that will ultimately need to be + // used on the array elements. + // + // TODO: That won't actually work if we ever had an array of + // heterogeneous stuff... + lowerShaderParameterToGLSLGLobalsRec( + arrayInfo, + arrayType->BaseType, + varLayout, + subscriptExpr); + + } + else if (auto declRefType = varType->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // The shader parameter had a structured type, so we need + // to destructure it into its constituent fields + + for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) + { + // Don't emit storage for `static` fields here, of course + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + RefPtr<MemberExpressionSyntaxNode> fieldExpr = new MemberExpressionSyntaxNode(); + fieldExpr->Position = varExpr->Position; + fieldExpr->Type.type = GetType(fieldDeclRef); + fieldExpr->declRef = fieldDeclRef; + fieldExpr->BaseExpression = varExpr; + + VaryingParameterInfo fieldInfo = info; + fieldInfo.name = info.name + "_" + fieldDeclRef.GetName(); + + // Need to find the layout for the given field... + Decl* originalFieldDecl = nullptr; + shared->mapLoweredDeclToOriginal.TryGetValue(fieldDeclRef.getDecl(), originalFieldDecl); + assert(originalFieldDecl); + + auto structTypeLayout = varLayout->typeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + assert(fieldLayout); + + lowerShaderParameterToGLSLGLobalsRec( + fieldInfo, + GetType(fieldDeclRef), + fieldLayout, + fieldExpr); + } + + // Okay, we are done with this parameter + return; + } + } + + // Default case: just try to emit things as-is + lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr); + } + + void lowerShaderParameterToGLSLGLobals( + RefPtr<Variable> localVarDecl, + RefPtr<VarLayout> paramLayout, + VaryingParameterDirection direction) + { + auto name = localVarDecl->getName(); + auto declRef = makeDeclRef(localVarDecl.Ptr()); + + RefPtr<VarExpressionSyntaxNode> expr = new VarExpressionSyntaxNode(); + expr->name = name; + expr->declRef = declRef; + expr->Type.type = GetType(declRef); + + VaryingParameterInfo info; + info.name = name; + info.direction = direction; + + lowerShaderParameterToGLSLGLobalsRec( + info, + localVarDecl->getType(), + paramLayout, + expr); + } + + struct EntryPointParamPair + { + RefPtr<ParameterSyntaxNode> original; + RefPtr<VarLayout> layout; + RefPtr<Variable> lowered; + }; + + RefPtr<FunctionSyntaxNode> lowerEntryPointToGLSL( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + // First, loer the entry-point function as an ordinary function: + auto loweredEntryPointFunc = visit(entryPointDecl); + + // Now we will generate a `void main() { ... }` function to call the lowered code. + RefPtr<FunctionSyntaxNode> mainDecl = new FunctionSyntaxNode(); + mainDecl->ReturnType.type = ExpressionType::GetVoid(); + mainDecl->Name.Content = "main"; + + // If the user's entry point was called `main` then rename it here + if (loweredEntryPointFunc->getName() == "main") + loweredEntryPointFunc->Name.Content = "main_"; + + // We will want to generate declarations into the body of our new `main()` + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function will be translated to + // both a local variable (for passing to/from the entry point func), + // and to global variables (used for parameter passing) + + List<EntryPointParamPair> params; + + // First generate declarations for the locals + for (auto paramDecl : entryPointDecl->GetParameters()) + { + RefPtr<VarLayout> paramLayout; + entryPointLayout->mapVarToLayout.TryGetValue(paramDecl.Ptr(), paramLayout); + assert(paramLayout); + + RefPtr<Variable> localVarDecl = new Variable(); + localVarDecl->Position = paramDecl->Position; + localVarDecl->Name.Content = paramDecl->getName(); + localVarDecl->Type = lowerType(paramDecl->Type); + + subVisitor.addDecl(localVarDecl); + + EntryPointParamPair paramPair; + paramPair.original = paramDecl; + paramPair.layout = paramLayout; + paramPair.lowered = localVarDecl; + + params.Add(paramPair); + } + + // Next generate globals for the inputs, and initialize them + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<InModifier>() + || paramDecl->HasModifier<InOutModifier>() + || !paramDecl->HasModifier<OutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Input); + } + } + + // Generate a local variable for the result, if any + RefPtr<Variable> resultVarDecl; + if (!loweredEntryPointFunc->ReturnType->Equals(ExpressionType::GetVoid())) + { + resultVarDecl = new Variable(); + resultVarDecl->Position = loweredEntryPointFunc->Position; + resultVarDecl->Name.Content = "_main_result"; + resultVarDecl->Type = TypeExp(loweredEntryPointFunc->ReturnType); + + subVisitor.addDecl(resultVarDecl); + } + + // Now generate a call to the entry-point function, using the local variables + auto entryPointDeclRef = makeDeclRef(loweredEntryPointFunc.Ptr()); + + RefPtr<FuncType> entryPointType = new FuncType(); + entryPointType->declRef = entryPointDeclRef; + + RefPtr<VarExpressionSyntaxNode> entryPointRef = new VarExpressionSyntaxNode(); + entryPointRef->name = loweredEntryPointFunc->getName(); + entryPointRef->declRef = entryPointDeclRef; + entryPointRef->Type = QualType(entryPointType); + + RefPtr<InvokeExpressionSyntaxNode> callExpr = new InvokeExpressionSyntaxNode(); + callExpr->FunctionExpr = entryPointRef; + callExpr->Type = QualType(loweredEntryPointFunc->ReturnType); + + // + for (auto paramPair : params) + { + auto localVarDecl = paramPair.lowered; + + RefPtr<VarExpressionSyntaxNode> varRef = new VarExpressionSyntaxNode(); + varRef->name = localVarDecl->getName(); + varRef->declRef = makeDeclRef(localVarDecl.Ptr()); + varRef->Type = QualType(localVarDecl->getType()); + + callExpr->Arguments.Add(varRef); + } + + if (resultVarDecl) + { + // Non-`void` return type, so we need to store it + subVisitor.assign(resultVarDecl, callExpr); + } + else + { + // `void` return type: just call it + subVisitor.addExprStmt(callExpr); + } + + + // Finally, generate logic to copy the outputs to global parameters + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<OutModifier>() + || paramDecl->HasModifier<InOutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Output); + } + } + if (resultVarDecl) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + resultVarDecl, + entryPointLayout->resultLayout, + VaryingParameterDirection::Output); + } + + mainDecl->Body = subVisitor.stmtBeingBuilt; + + + // Once we are done building the body, we append our new declaration to the program. + addMember(shared->loweredProgram, mainDecl); + return mainDecl; + +#if 0 + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, entryPointDecl); + + // We create a sub-context appropriate for lowering the function body + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function must be translated + // to global-scope declarations + for (auto paramDecl : entryPointDecl->GetParameters()) + { + subVisitor.lowerShaderParameterToGLSLGLobals(paramDecl); + } + + // The output of the function must also be translated into a + // global-scope declaration. + auto loweredReturnType = lowerType(entryPointDecl->ReturnType); + RefPtr<Variable> resultGlobal; + if (!loweredReturnType->Equals(ExpressionType::GetVoid())) + { + resultGlobal = new Variable(); + // TODO: need a scheme for generating unique names + resultGlobal->Name.Content = "_main_result"; + resultGlobal->Type = loweredReturnType; + + addMember(shared->loweredProgram, resultGlobal); + } + + loweredDecl->Name.Content = "main"; + loweredDecl->ReturnType.type = ExpressionType::GetVoid(); + + // We will emit the body statement in a context where + // a `return` statmenet will generate writes to the + // result global that we declared. + subVisitor.resultVariable = resultGlobal; + + auto loweredBody = subVisitor.lowerStmt(entryPointDecl->Body); + subVisitor.addStmt(loweredBody); + + loweredDecl->Body = subVisitor.stmtBeingBuilt; + + // TODO: need to append writes for `out` parameters here... + + addMember(shared->loweredProgram, loweredDecl); + return loweredDecl; +#endif + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + switch( getTarget() ) + { + // Default case: lower an entry point just like any other function + default: + return visit(entryPointDecl); + + // For Slang->GLSL translation, we need to lower things from HLSL-style + // declarations over to GLSL conventions + case CodeGenTarget::GLSL: + return lowerEntryPointToGLSL(entryPointDecl, entryPointLayout); + } + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + EntryPointRequest* entryPointRequest) + { + auto entryPointLayout = findEntryPointLayout(entryPointRequest); + auto entryPointDecl = entryPointLayout->entryPoint; + + return lowerEntryPoint( + entryPointDecl, + entryPointLayout); + } + + +}; + +static RefPtr<StructTypeLayout> getGlobalStructLayout( + ProgramLayout* programLayout) +{ + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = globalScopeLayout.As<StructTypeLayout>(); + if(globalStructLayout) + { } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + return globalStructLayout; +} + + +// Determine if the user is just trying to "rewrite" their input file +// into an output file. This will affect the way we approach code +// generation, because we want to leave their code "as is" whenever +// possible. +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target) +{ + // TODO: we might only consider things to be a rewrite request + // in the specific case where checking is turned off... + + switch( target ) + { + default: + return false; + + case CodeGenTarget::HLSL: + return sourceLanguage == SourceLanguage::HLSL; + + case CodeGenTarget::GLSL: + return sourceLanguage == SourceLanguage::GLSL; + } +} + + + +LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target) +{ + SharedLoweringContext sharedContext; + sharedContext.programLayout = programLayout; + sharedContext.target = target; + + auto translationUnit = entryPoint->getTranslationUnit(); + + // Create a single module/program to hold all the lowered code + // (with the exception of instrinsic/stdlib declarations, which + // will be remain where they are) + RefPtr<ProgramSyntaxNode> loweredProgram = new ProgramSyntaxNode(); + sharedContext.loweredProgram = loweredProgram; + + LoweringVisitor visitor; + visitor.shared = &sharedContext; + visitor.parentDecl = loweredProgram; + + // We need to register the lowered program as the lowered version + // of the existing translation unit declaration. + + visitor.registerLoweredDecl( + loweredProgram, + translationUnit->SyntaxNode); + + // We also need to register the lowered program as the lowered version + // of any imported modules (since we will be collecting everything into + // a single module for code generation). + for (auto rr : entryPoint->compileRequest->loadedModulesList) + { + sharedContext.loweredDecls.Add( + rr, + loweredProgram); + } + + // We also want to remember the layout information for + // that declaration, so that we can apply it during emission + attachLayout(loweredProgram, + getGlobalStructLayout(programLayout)); + + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + sharedContext.isRewrite = isRewrite; + + LoweredEntryPoint result; + if (isRewrite) + { + for (auto dd : translationUnit->SyntaxNode->Members) + { + visitor.translateDeclRef(dd); + } + } + else + { + auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); + result.entryPoint = loweredEntryPoint; + } + + result.program = sharedContext.loweredProgram; + + return result; +} +} |
