From f145e09a6dcbcf326f782b3e6a76dbf291c792cf Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 28 Jun 2017 13:34:38 -0700 Subject: Start to support cross-compilation via "lowering" pass - The big change here is the introduction of a "lowering" pass that takes an input AST from the semantic checker, and produces an output AST suitable for emitting. The intention is that he lowering pass is responsible for: - Stripping out unused code (when we have enough information to do so), by only outputting declarations that are transitively references from an entry point - When cross-compiling to GLSL, generating a suitable `void main()` entry point to wrap the user-written entry-point function - (Eventually) legalizing types in the program, by scalarizing aggregate types that mix uniform and resource types - (Eventually) instantiating generic declarations so that the resulting code only deals with fully specialized declarations - (Eventually) de-sugaring OOP constructs into basic "structs and functions" form - (Eventually) instantiating code that depends on interface types at the concrete types chosen - It is clear that there is still a lot of work to be done there, to this change is really about getting infrastructure in place without breaking the existing test cases. - One cleanup here is that we get rid of the idea of whole-translation-unit output, since that was specific to HLSL output, and there is really no strong reason for keeping it. Users should now just ask for the output for each entry point that they wanted to generate. - The biggest source of complexity for the lowering process is that it needs to produce the same AST structure as the input, to deal with the complexity of the rewriter case. That is, we need the output to be able to reproduce the input exactly in the case where we are rewriting and nothing needs to change, so the output format needs at least the degrees of freedom of the input. - As a result, we end up having to distinguish "rewriter" and "full" modes in both lowering and code-emit steps, so that we can react appropriately. - Generating a GLSL `main()` also adds a lot of complexity. Right now I'm using the simplest approach, where we always output the Slang/HLSL entry point as an ordinary function (as written) and then emit a simple GLSL `main()` to call it. I generate globals for all the shader inputs/outputs (these need to be scalarized and have explicit `location`s attached), and then collect these into the `struct` types of the original parameters as needed. - This approach will start to have some major down-sides once we have to deal with "arrayed" input/output - A long-term question here is how to replace entry-point parameter types with scalarized and/or "transposed" versions, while still letting the original code work as written (including copying those inputs to temporary arrays) - Split `BlockStatementSyntaxNode` into: - `BlockStmt` which just provides a scope around a `body` statement - `SeqStmt` which just allows multiple statements to be treated as one - Change how we emit `for` loops, to deal with the case where the initialization part might expand into multiple statements - Basically `for(A;B;C) {D}` becomes `{A; for(;B;C) {D}}`, so we can handle arbitrary statements for `A` - As an additional wrinkle, when we are rewriting HLSL, we just generate `A; for(;B;C) {D}` to deal with the broken scoping there - This change is needed because the lowering pass was sometimes expanding the original initialization statement `A` into a block `{A}`. Certainly if it declared multiple variables we'd need to handle it, and this seemed the easiest way - A more significant challenge for lowering would come if/when we ever wanted to support true short-circuiting behavior for `&&` and `||` - For right now I'm not changing the behavior of the "rewriter" mode, so we still have `UnparsedStmt` instances being generated, but it is clear that eventually we need to parse *all* input, even if we can't type-check 100% of it. This is required so that we can rewrite user code that might refer to a shader input with interface type. --- source/slang/lower.cpp | 1601 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1601 insertions(+) create mode 100644 source/slang/lower.cpp (limited to 'source/slang/lower.cpp') diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp new file mode 100644 index 000000000..674614cd3 --- /dev/null +++ b/source/slang/lower.cpp @@ -0,0 +1,1601 @@ +// lower.cpp +#include "lower.h" + +#include "visitor.h" + +namespace Slang +{ + +// + +template +struct StructuralTransformVisitorBase +{ + V* visitor; + + RefPtr transformDeclField(StatementSyntaxNode* stmt) + { + return visitor->translateStmtRef(stmt); + } + + RefPtr transformDeclField(Decl* decl) + { + return visitor->translateDeclRef(decl); + } + + template + DeclRef transformDeclField(DeclRef const& decl) + { + return visitor->translateDeclRef(decl).As(); + } + + TypeExp transformSyntaxField(TypeExp const& typeExp) + { + TypeExp result; + result.type = visitor->transformSyntaxField(typeExp.type); + return result; + } + + QualType transformSyntaxField(QualType const& qualType) + { + QualType result = qualType; + result.type = visitor->transformSyntaxField(qualType.type); + return result; + } + + RefPtr transformSyntaxField(ExpressionSyntaxNode* expr) + { + return visitor->transformSyntaxField(expr); + } + + RefPtr transformSyntaxField(StatementSyntaxNode* stmt) + { + return visitor->transformSyntaxField(stmt); + } + + RefPtr transformSyntaxField(DeclBase* decl) + { + return visitor->transformSyntaxField(decl); + } + + RefPtr transformSyntaxField(ScopeDecl* decl) + { + return visitor->transformSyntaxField(decl).As(); + } + + template + List transformSyntaxField(List const& list) + { + List result; + for (auto item : list) + { + result.Add(transformSyntaxField(item)); + } + return result; + } +}; + +template +struct StructuralTransformStmtVisitor + : StructuralTransformVisitorBase + , StmtVisitor, RefPtr> +{ + void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj) + { + } + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr visit(NAME* obj) { \ + RefPtr result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = this->transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = this->transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "stmt-defs.h" +#include "object-meta-end.h" + +}; + +template +RefPtr structuralTransform( + StatementSyntaxNode* stmt, + V* visitor) +{ + StructuralTransformStmtVisitor transformer; + transformer.visitor = visitor; + return transformer.dispatch(stmt); +} + +template +struct StructuralTransformExprVisitor + : StructuralTransformVisitorBase + , ExprVisitor, RefPtr> +{ + void transformFields(ExpressionSyntaxNode* result, ExpressionSyntaxNode* obj) + { + result->Type = transformSyntaxField(obj->Type); + } + + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr visit(NAME* obj) { \ + RefPtr result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "expr-defs.h" +#include "object-meta-end.h" +}; + + +template +RefPtr structuralTransform( + ExpressionSyntaxNode* expr, + V* visitor) +{ + StructuralTransformExprVisitor transformer; + transformer.visitor = visitor; + return transformer.dispatch(expr); +} + +// + +// Pseudo-syntax used during lowering +class TupleDecl : public VarDeclBase +{ +public: + virtual void accept(IDeclVisitor *, void *) override + { + throw "unexpected"; + } + + List> decls; +}; + +// Pseudo-syntax used during lowering: +// represents an ordered list of expressions as a single unit +class TupleExpr : public ExpressionSyntaxNode +{ +public: + virtual void accept(IExprVisitor *, void *) override + { + throw "unexpected"; + } + + List> exprs; +}; + +struct SharedLoweringContext +{ + ProgramLayout* programLayout; + CodeGenTarget target; + + RefPtr loweredProgram; + + Dictionary> loweredDecls; + Dictionary mapLoweredDeclToOriginal; + + bool isRewrite; +}; + +static void attachLayout( + ModifiableSyntaxNode* syntax, + Layout* layout) +{ + RefPtr modifier = new ComputedLayoutModifier(); + modifier->layout = layout; + + addModifier(syntax, modifier); +} + +struct LoweringVisitor + : ExprVisitor> + , StmtVisitor + , DeclVisitor> + , TypeVisitor> + , ValVisitor> +{ + // + SharedLoweringContext* shared; + RefPtr substitutions; + + bool isBuildingStmt = false; + RefPtr stmtBeingBuilt; + + // If we *aren't* building a statement, then this + // is the container we should be adding declarations to + RefPtr parentDecl; + + // If we are in a context where a `return` should be turned + // into assignment to a variable (followed by a `return`), + // then this will point to that variable. + RefPtr resultVariable; + + CodeGenTarget getTarget() { return shared->target; } + + // + // Values + // + + RefPtr lowerVal(Val* val) + { + if (!val) return nullptr; + return ValVisitor::dispatch(val); + } + + RefPtr visit(GenericParamIntVal* val) + { + return new GenericParamIntVal(translateDeclRef(DeclRef(val->declRef)).As()); + } + + RefPtr visit(ConstantIntVal* val) + { + return val; + } + + // + // Types + // + + RefPtr lowerType( + ExpressionType* type) + { + return TypeVisitor::dispatch(type); + } + + TypeExp lowerType( + TypeExp const& typeExp) + { + TypeExp result; + result.type = lowerType(typeExp.type); + return result; + } + + RefPtr visit(ErrorType* type) + { + return type; + } + + RefPtr visit(OverloadGroupType* type) + { + return type; + } + + RefPtr visit(InitializerListType* type) + { + return type; + } + + RefPtr visit(GenericDeclRefType* type) + { + return new GenericDeclRefType(translateDeclRef(DeclRef(type->declRef)).As()); + } + + RefPtr visit(FuncType* type) + { + RefPtr loweredType = new FuncType(); + loweredType->declRef = translateDeclRef(DeclRef(type->declRef)).As(); + return loweredType; + } + + RefPtr visit(DeclRefType* type) + { + auto loweredDeclRef = translateDeclRef(type->declRef); + return DeclRefType::Create(loweredDeclRef); + } + + RefPtr visit(NamedExpressionType* type) + { + if (shared->target == CodeGenTarget::GLSL) + { + // GLSL does not support `typedef`, so we will lower it out of existence here + return lowerType(GetType(type->declRef)); + } + + return new NamedExpressionType(translateDeclRef(DeclRef(type->declRef)).As()); + } + + RefPtr visit(TypeType* type) + { + return new TypeType(lowerType(type->type)); + } + + RefPtr visit(ArrayExpressionType* type) + { + RefPtr loweredType = new ArrayExpressionType(); + loweredType->BaseType = lowerType(type->BaseType); + loweredType->ArrayLength = lowerVal(type->ArrayLength).As(); + return loweredType; + } + + RefPtr transformSyntaxField(ExpressionType* type) + { + return lowerType(type); + } + + // + // Expressions + // + + RefPtr lowerExpr( + ExpressionSyntaxNode* expr) + { + if (!expr) return nullptr; + return ExprVisitor::dispatch(expr); + } + + // catch-all + RefPtr visit( + ExpressionSyntaxNode* expr) + { + return structuralTransform(expr, this); + } + + RefPtr transformSyntaxField(ExpressionSyntaxNode* expr) + { + return lowerExpr(expr); + } + + void lowerExprCommon( + RefPtr loweredExpr, + RefPtr expr) + { + loweredExpr->Position = expr->Position; + loweredExpr->Type.type = lowerType(expr->Type.type); + } + + RefPtr createVarRef( + CodePosition const& loc, + VarDeclBase* decl) + { + if (auto tupleDecl = dynamic_cast(decl)) + { + return createTupleRef(loc, tupleDecl); + } + else + { + RefPtr result = new VarExpressionSyntaxNode(); + result->Position = loc; + result->Type.type = decl->Type.type; + result->declRef = makeDeclRef(decl); + return result; + } + } + + RefPtr createTupleRef( + CodePosition const& loc, + TupleDecl* decl) + { + RefPtr result = new TupleExpr(); + result->Position = loc; + result->Type.type = decl->Type.type; + + for (auto dd : decl->decls) + { + auto expr = createVarRef(loc, dd); + result->exprs.Add(expr); + } + + return result; + } + + RefPtr visit( + VarExpressionSyntaxNode* expr) + { + // If the expression didn't get resolved, we can leave it as-is + if (!expr->declRef) + return expr; + + auto loweredDeclRef = translateDeclRef(expr->declRef); + auto loweredDecl = loweredDeclRef.getDecl(); + + if (auto tupleDecl = dynamic_cast(loweredDecl)) + { + // If we are referencing a declaration that got tuple-ified, + // then we need to produce a tuple expression as well. + + return createTupleRef(expr->Position, tupleDecl); + } + + RefPtr loweredExpr = new VarExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->declRef = loweredDeclRef; + return loweredExpr; + } + + RefPtr visit( + MemberExpressionSyntaxNode* expr) + { + auto loweredBase = lowerExpr(expr->BaseExpression); + + // Are we extracting an element from a tuple? + if (auto baseTuple = loweredBase.As()) + { + // We need to find the correct member expression, + // based on the actual tuple type. + + throw "unimplemented"; + } + + // Default handling: + auto loweredDeclRef = translateDeclRef(expr->declRef); + assert(!dynamic_cast(loweredDeclRef.getDecl())); + + RefPtr loweredExpr = new MemberExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = loweredBase; + loweredExpr->declRef = loweredDeclRef; + + return loweredExpr; + } + + // + // Statements + // + + StatementSyntaxNode* translateStmtRef( + StatementSyntaxNode* stmt) + { + throw "unimplemented"; + } + + RefPtr lowerStmt( + StatementSyntaxNode* stmt) + { + if(!stmt) + return nullptr; + + LoweringVisitor subVisitor = *this; + subVisitor.stmtBeingBuilt = nullptr; + + subVisitor.lowerStmtImpl(stmt); + + if( !subVisitor.stmtBeingBuilt ) + { + return new EmptyStatementSyntaxNode(); + } + else + { + return subVisitor.stmtBeingBuilt; + } + } + + void lowerStmtImpl( + StatementSyntaxNode* stmt) + { + StmtVisitor::dispatch(stmt); + } + + RefPtr visit(ScopeDecl* decl) + { + RefPtr loweredDecl = new ScopeDecl(); + lowerDeclCommon(loweredDecl, decl); + return loweredDecl; + } + + LoweringVisitor pushScope( + RefPtr loweredStmt, + RefPtr stmt) + { + loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As(); + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + subVisitor.parentDecl = loweredStmt->scopeDecl; + return subVisitor; + } + + void addStmtImpl( + RefPtr& dest, + StatementSyntaxNode* stmt) + { + // add a statement to the code we are building... + if( !dest ) + { + dest = stmt; + return; + } + + if (auto blockStmt = dest.As()) + { + addStmtImpl(blockStmt->body, stmt); + return; + } + + if (auto seqStmt = dest.As()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr newSeqStmt = new SeqStmt(); + + newSeqStmt->stmts.Add(dest); + newSeqStmt->stmts.Add(stmt); + + dest = newSeqStmt; + } + + } + + void addStmt( + StatementSyntaxNode* stmt) + { + addStmtImpl(stmtBeingBuilt, stmt); + } + + void addExprStmt( + RefPtr expr) + { + // TODO: handle cases where the `expr` cannot be directly + // represented as a single statement + + RefPtr stmt = new ExpressionStatementSyntaxNode(); + stmt->Expression = expr; + addStmt(stmt); + } + + void visit(BlockStmt* stmt) + { + RefPtr loweredStmt = new BlockStmt(); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + subVisitor.stmtBeingBuilt = loweredStmt; + + subVisitor.lowerStmtImpl(stmt->body); + + addStmt(loweredStmt); + } + + void visit(SeqStmt* stmt) + { + for( auto ss : stmt->stmts ) + { + lowerStmtImpl(ss); + } + } + + void visit(ExpressionStatementSyntaxNode* stmt) + { + addExprStmt(lowerExpr(stmt->Expression)); + } + + void visit(VarDeclrStatementSyntaxNode* stmt) + { + DeclVisitor::dispatch(stmt->decl); + } + + // catch-all + void visit(StatementSyntaxNode* stmt) + { + auto loweredStmt = structuralTransform(stmt, this); + addStmt(loweredStmt); + } + + RefPtr transformSyntaxField(StatementSyntaxNode* stmt) + { + return lowerStmt(stmt); + } + + void lowerStmtCommon(StatementSyntaxNode* loweredStmt, StatementSyntaxNode* stmt) + { + loweredStmt->modifiers = stmt->modifiers; + } + + void assign( + RefPtr destExpr, + RefPtr srcExpr) + { + RefPtr assignExpr = new AssignExpr(); + assignExpr->Position = destExpr->Position; + assignExpr->left = destExpr; + assignExpr->right = srcExpr; + + addExprStmt(assignExpr); + } + + void assign(VarDeclBase* varDecl, RefPtr expr) + { + assign(createVarRef(expr->Position, varDecl), expr); + } + + void assign(RefPtr expr, VarDeclBase* varDecl) + { + assign(expr, createVarRef(expr->Position, varDecl)); + } + + void visit(ReturnStatementSyntaxNode* stmt) + { + auto loweredStmt = new ReturnStatementSyntaxNode(); + lowerStmtCommon(loweredStmt, stmt); + + if (stmt->Expression) + { + if (resultVariable) + { + // Do it as an assignment + assign(resultVariable, lowerExpr(stmt->Expression)); + } + else + { + // Simple case + loweredStmt->Expression = lowerExpr(stmt->Expression); + } + } + + addStmt(loweredStmt); + } + + // + // Declarations + // + + RefPtr translateVal(Val* val) + { + if (auto type = dynamic_cast(val)) + return lowerType(type); + + if (auto litVal = dynamic_cast(val)) + return val; + + throw 99; + } + + RefPtr translateSubstitutions( + Substitutions* substitutions) + { + if (!substitutions) return nullptr; + + RefPtr result = new Substitutions(); + result->genericDecl = translateDeclRef(substitutions->genericDecl).As(); + for (auto arg : substitutions->args) + { + result->args.Add(translateVal(arg)); + } + return result; + } + + static Decl* getModifiedDecl(Decl* decl) + { + if (!decl) return nullptr; + if (auto genericDecl = dynamic_cast(decl->ParentDecl)) + return genericDecl; + return decl; + } + + DeclRef translateDeclRef( + DeclRef const& decl) + { + DeclRef result; + result.decl = translateDeclRef(decl.decl); + result.substitutions = translateSubstitutions(decl.substitutions); + return result; + } + + RefPtr translateDeclRef( + Decl* decl) + { + if (!decl) return nullptr; + + // We don't want to translate references to built-in declarations, + // since they won't be subtituted anyway. + if (getModifiedDecl(decl)->HasModifier()) + return decl; + + // If any parent of the declaration was in the stdlib, then + // we need to skip it. + for(auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier()) + return decl; + } + + if (getModifiedDecl(decl)->HasModifier()) + return decl; + + RefPtr loweredDecl; + if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + return loweredDecl; + + // Time to force it + return lowerDecl(decl); + } + + RefPtr translateDeclRef( + ContainerDecl* decl) + { + return translateDeclRef((Decl*)decl).As(); + } + + RefPtr lowerDeclBase( + DeclBase* declBase) + { + if (Decl* decl = dynamic_cast(declBase)) + { + return lowerDecl(decl); + } + else + { + DeclVisitor::dispatch(declBase); + } + + } + + RefPtr lowerDecl( + Decl* decl) + { + RefPtr loweredDecl = DeclVisitor::dispatch(decl).As(); + return loweredDecl; + } + + static void addMember( + RefPtr containerDecl, + RefPtr memberDecl) + { + containerDecl->Members.Add(memberDecl); + memberDecl->ParentDecl = containerDecl.Ptr(); + } + + void addDecl( + Decl* decl) + { + if(isBuildingStmt) + { + RefPtr declStmt = new VarDeclrStatementSyntaxNode(); + declStmt->Position = decl->Position; + declStmt->decl = decl; + addStmt(declStmt); + } + + + // We will add the declaration to the current container declaration being + // translated, which the user will maintain via pua/pop. + // + + assert(parentDecl); + addMember(parentDecl, decl); + } + + void registerLoweredDecl(Decl* loweredDecl, Decl* decl) + { + shared->loweredDecls.Add(decl, loweredDecl); + + shared->mapLoweredDeclToOriginal.Add(loweredDecl, decl); + } + + void lowerDeclCommon( + Decl* loweredDecl, + Decl* decl) + { + registerLoweredDecl(loweredDecl, decl); + + loweredDecl->Position = decl->Position; + loweredDecl->Name = decl->getNameToken(); + + // Lower modifiers as needed + + // HACK: just doing a shallow copy of modifiers, which will + // suffice for most of them, but we need to do something + // better soon. + loweredDecl->modifiers = decl->modifiers; + + // deal with layout stuff + + auto loweredParent = translateDeclRef(decl->ParentDecl); + if (loweredParent) + { + auto layoutMod = loweredParent->FindModifier(); + if (layoutMod) + { + auto parentLayout = layoutMod->layout; + if (auto structLayout = parentLayout.As()) + { + RefPtr fieldLayout; + if (structLayout->mapVarToLayout.TryGetValue(decl, fieldLayout)) + { + attachLayout(loweredDecl, fieldLayout); + } + } + + // TODO: are there other cases to handle here? + } + } + } + + // Catch-all + RefPtr visit( + Decl* decl) + { + assert(!"unimplemented"); + return decl; + } + + RefPtr visit(ImportDecl* decl) + { + // No need to translate things here if we are + // in "full" mode, because we will selectively + // translate the imported declarations at their + // use sites(s). + if (!shared->isRewrite) + return nullptr; + + for (auto dd : decl->importedModuleDecl->Members) + { + translateDeclRef(dd); + } + + // Don't actually include a representation of + // the import declaration in the output + return nullptr; + } + + RefPtr visit(EmptyDecl* decl) + { + // Empty declarations are really only useful in GLSL, + // where they are used to hold metadata that doesn't + // attach to any particular shader parameter. + // + // TODO: Only lower empty declarations if we are + // rewriting a GLSL file, and otherwise ignore them. + // + RefPtr loweredDecl = new EmptyDecl(); + lowerDeclCommon(loweredDecl, decl); + + addDecl(loweredDecl); + + return loweredDecl; + } + + RefPtr visit(AggTypeDecl* decl) + { + // We want to lower any aggregate type declaration + // to just a `struct` type that contains its fields. + // + // Any non-field members (e.g., methods) will be + // lowered separately. + + // TODO: also need to figure out how to handle fields + // with types that should not be allowed in a `struct` + // for the chosen target. + // (also: what to do if there are no fields left + // after removing invalid ones?) + + RefPtr loweredDecl = new StructSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + for (auto field : decl->getMembersOfType()) + { + // TODO: anything more to do than this? + addMember(loweredDecl, translateDeclRef(field)); + } + + addMember( + shared->loweredProgram, + loweredDecl); + + return loweredDecl; + } + + RefPtr lowerVarDeclCommon( + RefPtr loweredDecl, + VarDeclBase* decl) + { + lowerDeclCommon(loweredDecl, decl); + + loweredDecl->Type = lowerType(decl->Type); + loweredDecl->Expr = lowerExpr(decl->Expr); + + return loweredDecl; + } + + RefPtr visit( + Variable* decl) + { + auto loweredDecl = lowerVarDeclCommon(new Variable(), decl); + + // We need to add things to an appropriate scope, based on what + // we are referencing. + // + // If this is a global variable (program scope), then add it + // to the global scope. + RefPtr parentDecl = decl->ParentDecl; + if (auto parentModuleDecl = parentDecl.As()) + { + addMember( + translateDeclRef(parentModuleDecl), + loweredDecl); + } + // TODO: handle `static` function-scope variables + else + { + // A local variable declaration will get added to the + // statement scope we are currently processing. + addDecl(loweredDecl); + } + + return loweredDecl; + } + + RefPtr visit( + StructField* decl) + { + return lowerVarDeclCommon(new StructField(), decl); + } + + RefPtr visit( + ParameterSyntaxNode* decl) + { + return lowerVarDeclCommon(new ParameterSyntaxNode(), decl); + } + + RefPtr transformSyntaxField(DeclBase* decl) + { + return lowerDeclBase(decl); + } + + + RefPtr visit( + DeclGroup* group) + { + for (auto decl : group->decls) + { + lowerDecl(decl); + } + return nullptr; + } + + RefPtr visit( + FunctionDeclBase* decl) + { + // TODO: need to generate a name + + RefPtr loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + // TODO: push scope for parent decl here... + + // TODO: need to copy over relevant modifiers + + for (auto paramDecl : decl->GetParameters()) + { + addMember(loweredDecl, translateDeclRef(paramDecl)); + } + + auto loweredReturnType = lowerType(decl->ReturnType); + + loweredDecl->ReturnType = loweredReturnType; + + // If we are a being called recurisvely, then we need to + // be careful not to let the context get polluted + LoweringVisitor subVisitor = *this; + subVisitor.resultVariable = nullptr; + subVisitor.stmtBeingBuilt = nullptr; + + loweredDecl->Body = subVisitor.lowerStmt(decl->Body); + + // A lowered function always becomes a global-scope function, + // even if it had been a member function when declared. + addMember(shared->loweredProgram, loweredDecl); + + return loweredDecl; + } + + // + // Entry Points + // + + EntryPointLayout* findEntryPointLayout( + EntryPointRequest* entryPointRequest) + { + for( auto entryPointLayout : shared->programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + continue; + + if(entryPointLayout->profile != entryPointRequest->profile) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPointRequest` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; + } + + enum class VaryingParameterDirection + { + Input, + Output, + }; + + struct VaryingParameterArraySpec + { + VaryingParameterArraySpec* next = nullptr; + IntVal* elementCount; + }; + + struct VaryingParameterInfo + { + String name; + VaryingParameterDirection direction; + VaryingParameterArraySpec* arraySpecs = nullptr; + }; + + + void lowerSimpleShaderParameterToGLSLGlobal( + VaryingParameterInfo const& info, + RefPtr varType, + RefPtr varLayout, + RefPtr varExpr) + { + RefPtr type = varType; + + for (auto aa = info.arraySpecs; aa; aa = aa->next) + { + RefPtr arrayType = new ArrayExpressionType(); + arrayType->BaseType = type; + arrayType->ArrayLength = aa->elementCount; + + type = arrayType; + } + + // TODO: if we are declaring an SOA-ized array, + // this is where those array dimensions would need + // to be tacked on. + + RefPtr globalVarDecl = new Variable(); + globalVarDecl->Name.Content = info.name; + globalVarDecl->Type.type = type; + + addMember(shared->loweredProgram, globalVarDecl); + + // Add the layout information + RefPtr modifier = new ComputedLayoutModifier(); + modifier->layout = varLayout; + addModifier(globalVarDecl, modifier); + + // Need to generate an assignment in the right direction. + // + // TODO: for now I am just dealing with input: + + switch (info.direction) + { + case VaryingParameterDirection::Input: + addModifier(globalVarDecl, new InModifier()); + assign(varExpr, globalVarDecl); + break; + + case VaryingParameterDirection::Output: + addModifier(globalVarDecl, new OutModifier()); + + assign(globalVarDecl, varExpr); + break; + } + } + + void lowerShaderParameterToGLSLGLobalsRec( + VaryingParameterInfo const& info, + RefPtr varType, + RefPtr varLayout, + RefPtr varExpr) + { + assert(varLayout); + + if (auto basicType = varType->As()) + { + // handled below + } + else if (auto vectorType = varType->As()) + { + // handled below + } + else if (auto matrixType = varType->As()) + { + // handled below + } + else if (auto arrayType = varType->As()) + { + // We will accumulate information on the array + // types that were encoutnered on our walk down + // to the leaves, and then apply these array dimensions + // to any leaf parameters. + + VaryingParameterArraySpec arraySpec; + arraySpec.next = info.arraySpecs; + arraySpec.elementCount = arrayType->ArrayLength; + + VaryingParameterInfo arrayInfo = info; + arrayInfo.arraySpecs = &arraySpec; + + RefPtr subscriptExpr = new IndexExpressionSyntaxNode(); + subscriptExpr->Position = varExpr->Position; + subscriptExpr->BaseExpression = varExpr; + + // TODO: we need to construct syntax for a loop to initialize + // the array here... + throw "unimplemented"; + + // Note that we use the original `varLayout` that was passed in, + // since that is the layout that will ultimately need to be + // used on the array elements. + // + // TODO: That won't actually work if we ever had an array of + // heterogeneous stuff... + lowerShaderParameterToGLSLGLobalsRec( + arrayInfo, + arrayType->BaseType, + varLayout, + subscriptExpr); + + } + else if (auto declRefType = varType->As()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As()) + { + // The shader parameter had a structured type, so we need + // to destructure it into its constituent fields + + for (auto fieldDeclRef : getMembersOfType(aggTypeDeclRef)) + { + // Don't emit storage for `static` fields here, of course + if (fieldDeclRef.getDecl()->HasModifier()) + continue; + + RefPtr fieldExpr = new MemberExpressionSyntaxNode(); + fieldExpr->Position = varExpr->Position; + fieldExpr->Type.type = GetType(fieldDeclRef); + fieldExpr->declRef = fieldDeclRef; + fieldExpr->BaseExpression = varExpr; + + VaryingParameterInfo fieldInfo = info; + fieldInfo.name = info.name + "_" + fieldDeclRef.GetName(); + + // Need to find the layout for the given field... + Decl* originalFieldDecl = nullptr; + shared->mapLoweredDeclToOriginal.TryGetValue(fieldDeclRef.getDecl(), originalFieldDecl); + assert(originalFieldDecl); + + auto structTypeLayout = varLayout->typeLayout.As(); + assert(structTypeLayout); + + RefPtr fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + assert(fieldLayout); + + lowerShaderParameterToGLSLGLobalsRec( + fieldInfo, + GetType(fieldDeclRef), + fieldLayout, + fieldExpr); + } + + // Okay, we are done with this parameter + return; + } + } + + // Default case: just try to emit things as-is + lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr); + } + + void lowerShaderParameterToGLSLGLobals( + RefPtr localVarDecl, + RefPtr paramLayout, + VaryingParameterDirection direction) + { + auto name = localVarDecl->getName(); + auto declRef = makeDeclRef(localVarDecl.Ptr()); + + RefPtr expr = new VarExpressionSyntaxNode(); + expr->name = name; + expr->declRef = declRef; + expr->Type.type = GetType(declRef); + + VaryingParameterInfo info; + info.name = name; + info.direction = direction; + + lowerShaderParameterToGLSLGLobalsRec( + info, + localVarDecl->getType(), + paramLayout, + expr); + } + + struct EntryPointParamPair + { + RefPtr original; + RefPtr layout; + RefPtr lowered; + }; + + RefPtr lowerEntryPointToGLSL( + FunctionSyntaxNode* entryPointDecl, + RefPtr entryPointLayout) + { + // First, loer the entry-point function as an ordinary function: + auto loweredEntryPointFunc = visit(entryPointDecl); + + // Now we will generate a `void main() { ... }` function to call the lowered code. + RefPtr mainDecl = new FunctionSyntaxNode(); + mainDecl->ReturnType.type = ExpressionType::GetVoid(); + mainDecl->Name.Content = "main"; + + // If the user's entry point was called `main` then rename it here + if (loweredEntryPointFunc->getName() == "main") + loweredEntryPointFunc->Name.Content = "main_"; + + // We will want to generate declarations into the body of our new `main()` + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function will be translated to + // both a local variable (for passing to/from the entry point func), + // and to global variables (used for parameter passing) + + List params; + + // First generate declarations for the locals + for (auto paramDecl : entryPointDecl->GetParameters()) + { + RefPtr paramLayout; + entryPointLayout->mapVarToLayout.TryGetValue(paramDecl.Ptr(), paramLayout); + assert(paramLayout); + + RefPtr localVarDecl = new Variable(); + localVarDecl->Position = paramDecl->Position; + localVarDecl->Name.Content = paramDecl->getName(); + localVarDecl->Type = lowerType(paramDecl->Type); + + subVisitor.addDecl(localVarDecl); + + EntryPointParamPair paramPair; + paramPair.original = paramDecl; + paramPair.layout = paramLayout; + paramPair.lowered = localVarDecl; + + params.Add(paramPair); + } + + // Next generate globals for the inputs, and initialize them + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier() + || paramDecl->HasModifier() + || !paramDecl->HasModifier()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Input); + } + } + + // Generate a local variable for the result, if any + RefPtr resultVarDecl; + if (!loweredEntryPointFunc->ReturnType->Equals(ExpressionType::GetVoid())) + { + resultVarDecl = new Variable(); + resultVarDecl->Position = loweredEntryPointFunc->Position; + resultVarDecl->Name.Content = "_main_result"; + resultVarDecl->Type = TypeExp(loweredEntryPointFunc->ReturnType); + + subVisitor.addDecl(resultVarDecl); + } + + // Now generate a call to the entry-point function, using the local variables + auto entryPointDeclRef = makeDeclRef(loweredEntryPointFunc.Ptr()); + + RefPtr entryPointType = new FuncType(); + entryPointType->declRef = entryPointDeclRef; + + RefPtr entryPointRef = new VarExpressionSyntaxNode(); + entryPointRef->name = loweredEntryPointFunc->getName(); + entryPointRef->declRef = entryPointDeclRef; + entryPointRef->Type = QualType(entryPointType); + + RefPtr callExpr = new InvokeExpressionSyntaxNode(); + callExpr->FunctionExpr = entryPointRef; + callExpr->Type = QualType(loweredEntryPointFunc->ReturnType); + + // + for (auto paramPair : params) + { + auto localVarDecl = paramPair.lowered; + + RefPtr varRef = new VarExpressionSyntaxNode(); + varRef->name = localVarDecl->getName(); + varRef->declRef = makeDeclRef(localVarDecl.Ptr()); + varRef->Type = QualType(localVarDecl->getType()); + + callExpr->Arguments.Add(varRef); + } + + if (resultVarDecl) + { + // Non-`void` return type, so we need to store it + subVisitor.assign(resultVarDecl, callExpr); + } + else + { + // `void` return type: just call it + subVisitor.addExprStmt(callExpr); + } + + + // Finally, generate logic to copy the outputs to global parameters + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier() + || paramDecl->HasModifier()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Output); + } + } + if (resultVarDecl) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + resultVarDecl, + entryPointLayout->resultLayout, + VaryingParameterDirection::Output); + } + + mainDecl->Body = subVisitor.stmtBeingBuilt; + + + // Once we are done building the body, we append our new declaration to the program. + addMember(shared->loweredProgram, mainDecl); + return mainDecl; + +#if 0 + RefPtr loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, entryPointDecl); + + // We create a sub-context appropriate for lowering the function body + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function must be translated + // to global-scope declarations + for (auto paramDecl : entryPointDecl->GetParameters()) + { + subVisitor.lowerShaderParameterToGLSLGLobals(paramDecl); + } + + // The output of the function must also be translated into a + // global-scope declaration. + auto loweredReturnType = lowerType(entryPointDecl->ReturnType); + RefPtr resultGlobal; + if (!loweredReturnType->Equals(ExpressionType::GetVoid())) + { + resultGlobal = new Variable(); + // TODO: need a scheme for generating unique names + resultGlobal->Name.Content = "_main_result"; + resultGlobal->Type = loweredReturnType; + + addMember(shared->loweredProgram, resultGlobal); + } + + loweredDecl->Name.Content = "main"; + loweredDecl->ReturnType.type = ExpressionType::GetVoid(); + + // We will emit the body statement in a context where + // a `return` statmenet will generate writes to the + // result global that we declared. + subVisitor.resultVariable = resultGlobal; + + auto loweredBody = subVisitor.lowerStmt(entryPointDecl->Body); + subVisitor.addStmt(loweredBody); + + loweredDecl->Body = subVisitor.stmtBeingBuilt; + + // TODO: need to append writes for `out` parameters here... + + addMember(shared->loweredProgram, loweredDecl); + return loweredDecl; +#endif + } + + RefPtr lowerEntryPoint( + FunctionSyntaxNode* entryPointDecl, + RefPtr entryPointLayout) + { + switch( getTarget() ) + { + // Default case: lower an entry point just like any other function + default: + return visit(entryPointDecl); + + // For Slang->GLSL translation, we need to lower things from HLSL-style + // declarations over to GLSL conventions + case CodeGenTarget::GLSL: + return lowerEntryPointToGLSL(entryPointDecl, entryPointLayout); + } + } + + RefPtr lowerEntryPoint( + EntryPointRequest* entryPointRequest) + { + auto entryPointLayout = findEntryPointLayout(entryPointRequest); + auto entryPointDecl = entryPointLayout->entryPoint; + + return lowerEntryPoint( + entryPointDecl, + entryPointLayout); + } + + +}; + +static RefPtr getGlobalStructLayout( + ProgramLayout* programLayout) +{ + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = globalScopeLayout.As(); + if(globalStructLayout) + { } + else if(auto globalConstantBufferLayout = globalScopeLayout.As()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + return globalStructLayout; +} + + +// Determine if the user is just trying to "rewrite" their input file +// into an output file. This will affect the way we approach code +// generation, because we want to leave their code "as is" whenever +// possible. +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target) +{ + // TODO: we might only consider things to be a rewrite request + // in the specific case where checking is turned off... + + switch( target ) + { + default: + return false; + + case CodeGenTarget::HLSL: + return sourceLanguage == SourceLanguage::HLSL; + + case CodeGenTarget::GLSL: + return sourceLanguage == SourceLanguage::GLSL; + } +} + + + +LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target) +{ + SharedLoweringContext sharedContext; + sharedContext.programLayout = programLayout; + sharedContext.target = target; + + auto translationUnit = entryPoint->getTranslationUnit(); + + // Create a single module/program to hold all the lowered code + // (with the exception of instrinsic/stdlib declarations, which + // will be remain where they are) + RefPtr loweredProgram = new ProgramSyntaxNode(); + sharedContext.loweredProgram = loweredProgram; + + LoweringVisitor visitor; + visitor.shared = &sharedContext; + visitor.parentDecl = loweredProgram; + + // We need to register the lowered program as the lowered version + // of the existing translation unit declaration. + + visitor.registerLoweredDecl( + loweredProgram, + translationUnit->SyntaxNode); + + // We also need to register the lowered program as the lowered version + // of any imported modules (since we will be collecting everything into + // a single module for code generation). + for (auto rr : entryPoint->compileRequest->loadedModulesList) + { + sharedContext.loweredDecls.Add( + rr, + loweredProgram); + } + + // We also want to remember the layout information for + // that declaration, so that we can apply it during emission + attachLayout(loweredProgram, + getGlobalStructLayout(programLayout)); + + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + sharedContext.isRewrite = isRewrite; + + LoweredEntryPoint result; + if (isRewrite) + { + for (auto dd : translationUnit->SyntaxNode->Members) + { + visitor.translateDeclRef(dd); + } + } + else + { + auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); + result.entryPoint = loweredEntryPoint; + } + + result.program = sharedContext.loweredProgram; + + return result; +} +} -- cgit v1.2.3