diff options
| author | Tim Foley <tfoley@nvidia.com> | 2017-06-28 13:34:38 -0700 |
|---|---|---|
| committer | Tim Foley <tfoley@nvidia.com> | 2017-07-06 09:17:04 -0700 |
| commit | f145e09a6dcbcf326f782b3e6a76dbf291c792cf (patch) | |
| tree | 88a04619ceaaa37b87199dd82334cc9d102c156d /source/slang | |
| parent | c0d2c17bc73bc2a8863e086af3ea395ad09465ee (diff) | |
Start to support cross-compilation via "lowering" pass
- The big change here is the introduction of a "lowering" pass that takes an input AST from the semantic checker, and produces an output AST suitable for emitting. The intention is that he lowering pass is responsible for:
- Stripping out unused code (when we have enough information to do so), by only outputting declarations that are transitively references from an entry point
- When cross-compiling to GLSL, generating a suitable `void main()` entry point to wrap the user-written entry-point function
- (Eventually) legalizing types in the program, by scalarizing aggregate types that mix uniform and resource types
- (Eventually) instantiating generic declarations so that the resulting code only deals with fully specialized declarations
- (Eventually) de-sugaring OOP constructs into basic "structs and functions" form
- (Eventually) instantiating code that depends on interface types at the concrete types chosen
- It is clear that there is still a lot of work to be done there, to this change is really about getting infrastructure in place without breaking the existing test cases.
- One cleanup here is that we get rid of the idea of whole-translation-unit output, since that was specific to HLSL output, and there is really no strong reason for keeping it. Users should now just ask for the output for each entry point that they wanted to generate.
- The biggest source of complexity for the lowering process is that it needs to produce the same AST structure as the input, to deal with the complexity of the rewriter case. That is, we need the output to be able to reproduce the input exactly in the case where we are rewriting and nothing needs to change, so the output format needs at least the degrees of freedom of the input.
- As a result, we end up having to distinguish "rewriter" and "full" modes in both lowering and code-emit steps, so that we can react appropriately.
- Generating a GLSL `main()` also adds a lot of complexity. Right now I'm using the simplest approach, where we always output the Slang/HLSL entry point as an ordinary function (as written) and then emit a simple GLSL `main()` to call it. I generate globals for all the shader inputs/outputs (these need to be scalarized and have explicit `location`s attached), and then collect these into the `struct` types of the original parameters as needed.
- This approach will start to have some major down-sides once we have to deal with "arrayed" input/output
- A long-term question here is how to replace entry-point parameter types with scalarized and/or "transposed" versions, while still letting the original code work as written (including copying those inputs to temporary arrays)
- Split `BlockStatementSyntaxNode` into:
- `BlockStmt` which just provides a scope around a `body` statement
- `SeqStmt` which just allows multiple statements to be treated as one
- Change how we emit `for` loops, to deal with the case where the initialization part might expand into multiple statements
- Basically `for(A;B;C) {D}` becomes `{A; for(;B;C) {D}}`, so we can handle arbitrary statements for `A`
- As an additional wrinkle, when we are rewriting HLSL, we just generate `A; for(;B;C) {D}` to deal with the broken scoping there
- This change is needed because the lowering pass was sometimes expanding the original initialization statement `A` into a block `{A}`. Certainly if it declared multiple variables we'd need to handle it, and this seemed the easiest way
- A more significant challenge for lowering would come if/when we ever wanted to support true short-circuiting behavior for `&&` and `||`
- For right now I'm not changing the behavior of the "rewriter" mode, so we still have `UnparsedStmt` instances being generated, but it is clear that eventually we need to parse *all* input, even if we can't type-check 100% of it. This is required so that we can rewrite user code that might refer to a shader input with interface type.
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/check.cpp | 29 | ||||
| -rw-r--r-- | source/slang/compiler.cpp | 73 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 435 | ||||
| -rw-r--r-- | source/slang/emit.h | 9 | ||||
| -rw-r--r-- | source/slang/expr-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 1601 | ||||
| -rw-r--r-- | source/slang/lower.h | 38 | ||||
| -rw-r--r-- | source/slang/modifier-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 220 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 39 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 2 | ||||
| -rw-r--r-- | source/slang/stmt-defs.h | 10 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 24 | ||||
| -rw-r--r-- | source/slang/syntax.h | 9 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 31 | ||||
| -rw-r--r-- | source/slang/val-defs.h | 3 |
17 files changed, 2182 insertions, 352 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index a3e061a3b..a79af3f37 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -1581,11 +1581,16 @@ namespace Slang DeclVisitor::dispatch(stmt->decl); } - void visit(BlockStatementSyntaxNode *stmt) + void visit(BlockStmt* stmt) { - for (auto & node : stmt->Statements) + checkStmt(stmt->body); + } + + void visit(SeqStmt* stmt) + { + for(auto ss : stmt->stmts) { - checkStmt(node); + checkStmt(ss); } } @@ -2414,6 +2419,24 @@ namespace Slang return appExpr; } + // + + RefPtr<ExpressionSyntaxNode> visit(AssignExpr* expr) + { + expr->left = CheckExpr(expr->left); + + auto type = expr->left->Type; + + expr->right = Coerce(type, CheckTerm(expr->right)); + + if (!type.IsLeftValue) + { + getSink()->diagnose(expr, Diagnostics::assignNonLValue); + } + expr->Type = type; + return expr; + } + // diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 9132624f9..19623a7aa 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -55,10 +55,11 @@ namespace Slang // - String emitHLSLForTranslationUnit( - TranslationUnitRequest* translationUnit) + String emitHLSLForEntryPoint( + EntryPointRequest* entryPoint) { - auto compileRequest = translationUnit->compileRequest; + auto compileRequest = entryPoint->compileRequest; + auto translationUnit = entryPoint->getTranslationUnit(); if (compileRequest->passThrough != PassThroughMode::None) { // Generate a string that includes the content of @@ -92,8 +93,8 @@ namespace Slang } else { - return emitProgram( - translationUnit->SyntaxNode.Ptr(), + return emitEntryPoint( + entryPoint, compileRequest->layout.Ptr(), CodeGenTarget::HLSL); } @@ -137,8 +138,8 @@ namespace Slang { // TODO(tfoley): need to pass along the entry point // so that we properly emit it as the `main` function. - return emitProgram( - translationUnit->SyntaxNode.Ptr(), + return emitEntryPoint( + entryPoint, compileRequest->layout.Ptr(), CodeGenTarget::GLSL); } @@ -180,15 +181,7 @@ namespace Slang assert(D3DCompile_); } - // The HLSL compiler will try to "canonicalize" our input file path, - // and we don't want it to do that, because they it won't report - // the same locations on error messages that we would. - // - // To work around that, we prepend a custom `#line` directive. - - auto translationUnit = entryPoint->getTranslationUnit(); - - auto hlslCode = emitHLSLForTranslationUnit(translationUnit); + auto hlslCode = emitHLSLForEntryPoint(entryPoint); ID3DBlob* codeBlob; ID3DBlob* diagnosticsBlob; @@ -223,6 +216,7 @@ namespace Slang if (FAILED(hr)) { // TODO(tfoley): What to do on failure? + exit(1); } return data; } @@ -400,6 +394,13 @@ namespace Slang switch (compileRequest->Target) { + case CodeGenTarget::HLSL: + { + String code = emitHLSLForEntryPoint(entryPoint); + result.outputSource = code; + } + break; + case CodeGenTarget::GLSL: { String code = emitGLSLForEntryPoint(entryPoint); @@ -505,45 +506,7 @@ namespace Slang TranslationUnitResult emitTranslationUnit( TranslationUnitRequest* translationUnit) { - auto compileRequest = translationUnit->compileRequest; - - // Most of our code generation targets will require us - // to proceed through one entry point at a time, but - // in some cases we can emit an entire translation unit - // in one go. - - switch (compileRequest->Target) - { - default: - // The default behavior is going to loop over all the entry - // points, and then collect an aggregate result. - return emitTranslationUnitEntryPoints(translationUnit); - - case CodeGenTarget::HLSL: - // When targetting HLSL, we can emit the entire translation unit - // as a single HLSL program, and include all the entry points. - { - - String hlsl = emitHLSLForTranslationUnit(translationUnit); - - TranslationUnitResult result; - result.outputSource = hlsl; - - // Because the user might ask for per-entry-point source, - // we will just attach the same string as the result for - // each entry point. - for( auto& entryPoint : translationUnit->entryPoints ) - { - EntryPointResult entryPointResult; - entryPointResult.outputSource = hlsl; - - entryPoint->result = entryPointResult; - } - - return result; - } - break; - } + return emitTranslationUnitEntryPoints(translationUnit); } #if 0 diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index dc7c68aa4..71b39aabb 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -1,6 +1,7 @@ // emit.cpp #include "emit.h" +#include "lower.h" #include "syntax.h" #include "type-layout.h" @@ -13,8 +14,16 @@ namespace Slang { -struct EmitContext +// Shared state for an entire emit session +struct SharedEmitContext { + // The target language we want to generate code for + CodeGenTarget target; + + // A set of words reserved by the target + Dictionary<String, String> reservedWords; + + // The string of code we've built so far StringBuilder sb; // Current source position for tracking purposes... @@ -22,12 +31,6 @@ struct EmitContext CodePosition nextSourceLocation; bool needToUpdateSourceLocation; - // The target language we want to generate code for - CodeGenTarget target; - - // A set of words reserved by the target - Dictionary<String, String> reservedWords; - // For GLSL output, we can't emit traidtional `#line` directives // with a file path in them, so we maintain a map that associates // each path with a unique integer, and then we output those @@ -45,6 +48,18 @@ struct EmitContext // TODO: This will probably change if we represent imports // explicitly in the layout data. StructTypeLayout* globalStructLayout; + + ProgramLayout* programLayout; +}; + +struct EmitContext +{ + // The shared context that is in effect + SharedEmitContext* shared; + + // Are we in "rewrite" mode, where we are trying to reproduce the input + // code as closely as posible? + bool isRewrite; }; // @@ -79,7 +94,7 @@ static void emitRawTextSpan(EmitContext* context, char const* textBegin, char co // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... auto len = int(textEnd - textBegin); - context->sb.Append(textBegin, len); + context->shared->sb.Append(textBegin, len); } static void emitRawText(EmitContext* context, char const* text) @@ -99,7 +114,7 @@ static void emitTextSpan(EmitContext* context, char const* textBegin, char const // Update our logical position // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... auto len = int(textEnd - textBegin); - context->loc.Col += len; + context->shared->loc.Col += len; } static void Emit(EmitContext* context, char const* textBegin, char const* textEnd) @@ -123,8 +138,8 @@ static void Emit(EmitContext* context, char const* textBegin, char const* textEn // At the end of a line, we need to update our tracking // information on code positions emitTextSpan(context, spanBegin, spanEnd); - context->loc.Line++; - context->loc.Col = 1; + context->shared->loc.Line++; + context->shared->loc.Col = 1; // Start a new span for emit purposes spanBegin = spanEnd; @@ -144,7 +159,7 @@ static void emit(EmitContext* context, String const& text) static bool isReservedWord(EmitContext* context, String const& name) { - return context->reservedWords.TryGetValue(name) != nullptr; + return context->shared->reservedWords.TryGetValue(name) != nullptr; } static void emitName( @@ -449,7 +464,7 @@ static bool isTargetIntrinsicModifierApplicable( // we expect. auto const& targetName = targetToken.Content; - switch(context->target) + switch(context->shared->target) { default: assert(!"unexpected"); @@ -658,7 +673,7 @@ static void emitCallExpr( case IntrinsicOp::InnerProduct_Vector_Vector: // HLSL allows `mul()` to be used as a synonym for `dot()`, // so we need to translate to `dot` for GLSL - if (context->target == CodeGenTarget::GLSL) + if (context->shared->target == CodeGenTarget::GLSL) { Emit(context, "dot("); EmitExpr(context, callExpr->Arguments[0]); @@ -677,7 +692,7 @@ static void emitCallExpr( // // The other critical detail here is that the way we handle matrix // conventions requires that the operands to the product be swapped. - if (context->target == CodeGenTarget::GLSL) + if (context->shared->target == CodeGenTarget::GLSL) { Emit(context, "(("); EmitExpr(context, callExpr->Arguments[1]); @@ -825,6 +840,13 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax Emit(context, " : "); EmitExprWithPrecedence(context, selectExpr->Arguments[2], kPrecedence_Conditional); } + else if (auto assignExpr = expr.As<AssignExpr>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Assign); + EmitExprWithPrecedence(context, assignExpr->left, kPrecedence_Assign); + Emit(context, " = "); + EmitExprWithPrecedence(context, assignExpr->right, kPrecedence_Assign); + } else if (auto callExpr = expr.As<InvokeExpressionSyntaxNode>()) { emitCallExpr(context, callExpr, outerPrec); @@ -970,7 +992,7 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax } else if (auto castExpr = expr.As<TypeCastExpressionSyntaxNode>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: // GLSL requires constructor syntax for all conversions @@ -1227,7 +1249,7 @@ static void emitTextureType( EmitContext* context, RefPtr<TextureType> texType) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLTextureType(context, texType); @@ -1247,7 +1269,7 @@ static void emitTextureSamplerType( EmitContext* context, RefPtr<TextureSamplerType> type) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: emitGLSLTextureSamplerType(context, type); @@ -1263,7 +1285,7 @@ static void emitImageType( EmitContext* context, RefPtr<GLSLImageType> type) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLTextureType(context, type); @@ -1300,7 +1322,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto vecType = type->As<VectorExpressionType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: @@ -1331,7 +1353,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto matType = type->As<MatrixExpressionType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: @@ -1386,7 +1408,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto samplerStateType = type->As<SamplerStateType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: default: @@ -1474,7 +1496,8 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type) static void EmitType(EmitContext* context, TypeExp const& typeExp, Token const& nameToken) { - EmitType(context, typeExp.type, typeExp.exp->Position, nameToken.Content, nameToken.Position); + EmitType(context, typeExp.type, + typeExp.exp ? typeExp.exp->Position : CodePosition(), nameToken.Content, nameToken.Position); } static void EmitType(EmitContext* context, TypeExp const& typeExp, String const& name) @@ -1497,12 +1520,9 @@ static void EmitBlockStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt { // TODO(tfoley): support indenting Emit(context, "{\n"); - if( auto blockStmt = stmt.As<BlockStatementSyntaxNode>() ) + if( auto blockStmt = stmt.As<BlockStmt>() ) { - for (auto s : blockStmt->Statements) - { - EmitStmt(context, s); - } + EmitStmt(context, blockStmt->body); } else { @@ -1544,7 +1564,7 @@ static void emitLineDirective( emitRawText(context, " "); - if(context->target == CodeGenTarget::GLSL) + if(context->shared->target == CodeGenTarget::GLSL) { auto path = sourceLocation.FileName; @@ -1556,10 +1576,10 @@ static void emitLineDirective( // extension and then emit a traditional line directive. int id = 0; - if(!context->mapGLSLSourcePathToID.TryGetValue(path, id)) + if(!context->shared->mapGLSLSourcePathToID.TryGetValue(path, id)) { - id = context->glslSourceIDCount++; - context->mapGLSLSourcePathToID.Add(path, id); + id = context->shared->glslSourceIDCount++; + context->shared->mapGLSLSourcePathToID.Add(path, id); } sprintf(buffer, "%d", id); @@ -1611,9 +1631,9 @@ static void emitLineDirectiveAndUpdateSourceLocation( { emitLineDirective(context, sourceLocation); - context->loc.FileName = sourceLocation.FileName; - context->loc.Line = sourceLocation.Line; - context->loc.Col = 1; + context->shared->loc.FileName = sourceLocation.FileName; + context->shared->loc.Line = sourceLocation.Line; + context->shared->loc.Col = 1; } static void emitLineDirectiveIfNeeded( @@ -1628,24 +1648,24 @@ static void emitLineDirectiveIfNeeded( // a differnet file or line, *or* if the source location is // somehow later on the line than what we want to emit, // then we need to emit a new `#line` directive. - if(sourceLocation.FileName != context->loc.FileName - || sourceLocation.Line != context->loc.Line - || sourceLocation.Col < context->loc.Col) + if(sourceLocation.FileName != context->shared->loc.FileName + || sourceLocation.Line != context->shared->loc.Line + || sourceLocation.Col < context->shared->loc.Col) { // Special case: if we are in the same file, and within a small number // of lines of the target location, then go ahead and output newlines // to get us caught up. enum { kSmallLineCount = 3 }; - auto lineDiff = sourceLocation.Line - context->loc.Line; - if(sourceLocation.FileName == context->loc.FileName - && sourceLocation.Line > context->loc.Line + auto lineDiff = sourceLocation.Line - context->shared->loc.Line; + if(sourceLocation.FileName == context->shared->loc.FileName + && sourceLocation.Line > context->shared->loc.Line && lineDiff <= kSmallLineCount) { for(int ii = 0; ii < lineDiff; ++ii ) { Emit(context, "\n"); } - assert(sourceLocation.Line == context->loc.Line); + assert(sourceLocation.Line == context->shared->loc.Line); } else { @@ -1661,14 +1681,14 @@ static void emitLineDirectiveIfNeeded( // came in as spaces or tabs, so there is necessarily going to be // coupling between how the downstream compiler counts columns, // and how we do. - if(sourceLocation.Col > context->loc.Col) + if(sourceLocation.Col > context->shared->loc.Col) { - int delta = sourceLocation.Col - context->loc.Col; + int delta = sourceLocation.Col - context->shared->loc.Col; for( int ii = 0; ii < delta; ++ii ) { emitRawText(context, " "); } - context->loc.Col = sourceLocation.Col; + context->shared->loc.Col = sourceLocation.Col; } } @@ -1680,22 +1700,22 @@ static void advanceToSourceLocation( if(sourceLocation.Line <= 0) return; - context->needToUpdateSourceLocation = true; - context->nextSourceLocation = sourceLocation; + context->shared->needToUpdateSourceLocation = true; + context->shared->nextSourceLocation = sourceLocation; } static void flushSourceLocationChange( EmitContext* context) { - if(!context->needToUpdateSourceLocation) + if(!context->shared->needToUpdateSourceLocation) return; // Note: the order matters here, because trying to update // the source location may involve outputting text that // advances the location, and outputting text is what // triggers this flush operation. - context->needToUpdateSourceLocation = false; - emitLineDirectiveIfNeeded(context, context->nextSourceLocation); + context->shared->needToUpdateSourceLocation = false; + emitLineDirectiveIfNeeded(context, context->shared->nextSourceLocation); } static void emitTokenWithLocation(EmitContext* context, Token const& token) @@ -1737,11 +1757,19 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) // Try to ensure that debugging can find the right location advanceToSourceLocation(context, stmt->Position); - if (auto blockStmt = stmt.As<BlockStatementSyntaxNode>()) + if (auto blockStmt = stmt.As<BlockStmt>()) { EmitBlockStmt(context, blockStmt); return; } + else if (auto seqStmt = stmt.As<SeqStmt>()) + { + for (auto ss : seqStmt->stmts) + { + EmitStmt(context, ss); + } + return; + } else if( auto unparsedStmt = stmt.As<UnparsedStmt>() ) { EmitUnparsedStmt(context, unparsedStmt); @@ -1784,17 +1812,43 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) } else if (auto forStmt = stmt.As<ForStatementSyntaxNode>()) { - EmitLoopAttributes(context, forStmt); + // We are going to always take a `for` loop like: + // + // for(A; B; C) { D } + // + // and emit it as: + // + // { A; for(; B; C) { D } } + // + // This ensures that we are robust against any kind + // of statement appearing in `A`, including things + // that might occur due to lowering steps. + // - Emit(context, "for("); - if (auto initStmt = forStmt->InitialStatement) + // The one wrinkle is that HLSL implements the + // bad approach to scoping a `for` loop variable, + // so we need to avoid those outer `{...}` when + // we are generating HLSL via "rewrite" (that is, + // without our semantic checks). + // + bool brokenScoping = false; + if (context->shared->target == CodeGenTarget::HLSL + && context->isRewrite) { - EmitStmt(context, initStmt); + brokenScoping = true; } - else + + auto initStmt = forStmt->InitialStatement; + if(initStmt) { - Emit(context, ";"); + if(!brokenScoping) + Emit(context, "{\n"); + EmitStmt(context, initStmt); } + + EmitLoopAttributes(context, forStmt); + + Emit(context, "for(;"); if (auto testExp = forStmt->PredicateExpression) { EmitExpr(context, testExp); @@ -1806,6 +1860,13 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) } Emit(context, ")\n"); EmitBlockStmt(context, forStmt->Statement); + + if (initStmt) + { + if(!brokenScoping) + Emit(context, "}\n"); + } + return; } else if (auto whileStmt = stmt.As<WhileStatementSyntaxNode>()) @@ -2085,7 +2146,7 @@ static void EmitSemantic(EmitContext* context, RefPtr<HLSLSemantic> semantic, ES static void EmitSemantics(EmitContext* context, RefPtr<Decl> decl, ESemanticMask mask = kESemanticMask_Default ) { // Don't emit semantics if we aren't translating down to HLSL - switch (context->target) + switch (context->shared->target) { case CodeGenTarget::HLSL: break; @@ -2263,7 +2324,7 @@ static void emitHLSLRegisterSemantics( { if (!layout) return; - switch( context->target ) + switch( context->shared->target ) { default: return; @@ -2278,6 +2339,25 @@ static void emitHLSLRegisterSemantics( } } +static RefPtr<VarLayout> maybeFetchLayout( + RefPtr<Decl> decl, + RefPtr<VarLayout> layout) +{ + // If we have already found layout info, don't go searching + if (layout) return layout; + + // Otherwise, we need to look and see if computed layout + // information has been attached to the declaration. + auto modifier = decl->FindModifier<ComputedLayoutModifier>(); + if (!modifier) return nullptr; + + auto computedLayout = modifier->layout; + assert(computedLayout); + + auto varLayout = computedLayout.As<VarLayout>(); + return varLayout; +} + static void emitHLSLParameterBlockDecl( EmitContext* context, RefPtr<VarDeclBase> varDecl, @@ -2292,6 +2372,7 @@ static void emitHLSLParameterBlockDecl( assert(declRefType); // We expect to always have layout information + layout = maybeFetchLayout(varDecl, layout); assert(layout); // We expect the layout to be for a structured type... @@ -2325,13 +2406,16 @@ static void emitHLSLParameterBlockDecl( Emit(context, "\n{\n"); if (auto structRef = declRefType->declRef.As<StructSyntaxNode>()) { + int fieldCounter = 0; + for (auto field : getMembersOfType<StructField>(structRef)) { + int fieldIndex = fieldCounter++; + EmitVarDeclCommon(context, field); - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - assert(fieldLayout); + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + assert(fieldLayout->varDecl.GetName() == field.GetName()); // Emit explicit layout annotations for every field for( auto rr : fieldLayout->resourceInfos ) @@ -2415,7 +2499,7 @@ emitGLSLLayoutQualifiers( { if(!layout) return; - switch( context->target ) + switch( context->shared->target ) { default: return; @@ -2493,7 +2577,7 @@ static void emitGLSLParameterBlockDecl( { RefPtr<VarLayout> fieldLayout; structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - assert(fieldLayout); +// assert(fieldLayout); // TODO(tfoley): We may want to emit *some* of these, // some of the time... @@ -2521,7 +2605,7 @@ static void emitParameterBlockDecl( RefPtr<ParameterBlockType> parameterBlockType, RefPtr<VarLayout> layout) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout); @@ -2539,6 +2623,8 @@ static void emitParameterBlockDecl( static void EmitVarDecl(EmitContext* context, RefPtr<VarDeclBase> decl, RefPtr<VarLayout> layout) { + layout = maybeFetchLayout(decl, layout); + // As a special case, a variable using a parameter block type // will be translated into a declaration using the more primitive // language syntax. @@ -2606,7 +2692,7 @@ static void emitGLSLPreprocessorDirectives( EmitContext* context, RefPtr<ProgramSyntaxNode> program) { - switch(context->target) + switch(context->shared->target) { // Don't emit this stuff unless we are targetting GLSL default: @@ -2658,78 +2744,6 @@ static void emitGLSLPreprocessorDirectives( // TODO: handle other cases... } -static void EmitProgram( - EmitContext* context, - RefPtr<ProgramSyntaxNode> program, - RefPtr<ProgramLayout> programLayout) -{ - // There may be global-scope modifiers that we should emit now - emitGLSLPreprocessorDirectives(context, program); - - switch(context->target) - { - case CodeGenTarget::GLSL: - { - // TODO(tfoley): Need a plan for how to enable/disable these as needed... -// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n"); - } - break; - - default: - break; - } - - - // Layout information for the global scope is either an ordinary - // `struct` in the common case, or a constant buffer in the case - // where there were global-scope uniforms. - auto globalScopeLayout = programLayout->globalScopeLayout; - if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() ) - { - context->globalStructLayout = globalStructLayout.Ptr(); - - // The `struct` case is easy enough to handle: we just - // emit all the declarations directly, using their layout - // information as a guideline. - EmitDeclsInContainerUsingLayout(context, program, globalStructLayout); - } - else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) - { - // TODO: the `cbuffer` case really needs to be emitted very - // carefully, but that is beyond the scope of what a simple rewriter - // can easily do (without semantic analysis, etc.). - // - // The crux of the problem is that we need to collect all the - // global-scope uniforms (but not declarations that don't involve - // uniform storage...) and put them in a single `cbuffer` declaration, - // so that we can give it an explicit location. The fields in that - // declaration might use various type declarations, so we'd really - // need to emit all the type declarations first, and that involves - // some large scale reorderings. - // - // For now we will punt and just emit the declarations normally, - // and hope that the global-scope block (`$Globals`) gets auto-assigned - // the same location that we manually asigned it. - - auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; - auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); - - // We expect all constant buffers to contain `struct` types for now - assert(elementTypeStructLayout); - - context->globalStructLayout = elementTypeStructLayout.Ptr(); - - EmitDeclsInContainerUsingLayout( - context, - program, - elementTypeStructLayout); - } - else - { - assert(!"unexpected"); - } -} - static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout) { // Don't emit code for declarations that came from the stdlib. @@ -2783,18 +2797,18 @@ static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayo // We might import the same module along two different paths, // so we need to be careful to only emit each module once // per output. - if(!context->modulesAlreadyEmitted.Contains(moduleDecl)) + if(!context->shared->modulesAlreadyEmitted.Contains(moduleDecl)) { // Add the module to our set before emitting it, just // in case a circular reference would lead us to // infinite recursion (but that shouldn't be allowed // in the first place). - context->modulesAlreadyEmitted.Add(moduleDecl); + context->shared->modulesAlreadyEmitted.Add(moduleDecl); // TODO: do we need to modify the code generation environment at // all when doing this recursive emit? - EmitDeclsInContainerUsingLayout(context, moduleDecl, context->globalStructLayout); + EmitDeclsInContainerUsingLayout(context, moduleDecl, context->shared->globalStructLayout); } return; @@ -2839,7 +2853,7 @@ static void registerReservedWord( EmitContext* context, String const& name) { - context->reservedWords.Add(name, name); + context->shared->reservedWords.Add(name, name); } static void registerReservedWords( @@ -2847,7 +2861,7 @@ static void registerReservedWords( { #define WORD(NAME) registerReservedWord(context, #NAME) - switch (context->target) + switch (context->shared->target) { case CodeGenTarget::GLSL: WORD(attribute); @@ -2990,68 +3004,115 @@ static void registerReservedWords( } } -String emitProgram( - ProgramSyntaxNode* program, +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target); + +String emitEntryPoint( + EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target) { - // TODO(tfoley): only emit symbols on-demand, as needed by a particular entry point + auto translationUnit = entryPoint->getTranslationUnit(); + + SharedEmitContext sharedContext; + sharedContext.target = target; + + sharedContext.programLayout = programLayout; + + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = nullptr; + if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() ) + { + globalStructLayout = globalStructLayout.Ptr(); + } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + sharedContext.globalStructLayout = globalStructLayout; EmitContext context; - context.target = target; + context.shared = &sharedContext; + context.isRewrite = isRewriteRequest( + translationUnit->sourceLanguage, + target); + // TODO: this should only need to take the shared context registerReservedWords(&context); - EmitProgram(&context, program, programLayout); + auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); - String code = context.sb.ProduceString(); - return code; - -#if 0 - // HACK(tfoley): Invoke the D3D HLSL compiler on the result, to validate it + // There may be global-scope modifiers that we should emit now + emitGLSLPreprocessorDirectives(&context, translationUnitSyntax); -#ifdef _WIN32 + switch(target) { - HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); - assert(d3dCompiler); - - pD3DCompile D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); - assert(D3DCompile_); - - ID3DBlob* codeBlob; - ID3DBlob* diagnosticsBlob; - HRESULT hr = D3DCompile_( - code.begin(), - code.Length(), - "slang", - nullptr, - nullptr, - "main", - "ps_5_0", - 0, - 0, - &codeBlob, - &diagnosticsBlob); - if (codeBlob) codeBlob->Release(); - if (diagnosticsBlob) - { - String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); - fprintf(stderr, "%s", diagnostics.begin()); - OutputDebugStringA(diagnostics.begin()); - diagnosticsBlob->Release(); - } - if (FAILED(hr)) + case CodeGenTarget::GLSL: { - int f = 9; + // TODO(tfoley): Need a plan for how to enable/disable these as needed... +// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n"); } + break; + + default: + break; } - #include <d3dcompiler.h> -#endif + auto lowered = lowerEntryPoint(entryPoint, programLayout, target); + + EmitDeclsInContainer(&context, lowered.program.Ptr()); + +#if 0 + if( isRewrite ) + { + // In rewrite mode, we will just emit the text of the translation unit as given, + // and not pay attention to the specific entry point that was requested. + // + // It is a user error to request GLSL output and have an entry point name + // other than `main`. + EmitDeclsInContainerUsingLayout(&context, translationUnitSyntax, globalStructLayout); + } + else + { + // We are being asked to emit a single entry point in "full" mode. + emitEntryPoint(&context, entryPoint); + } #endif -} + String code = sharedContext.sb.ProduceString(); + return code; +} } // namespace Slang diff --git a/source/slang/emit.h b/source/slang/emit.h index 4f3e98c8d..da1ac9f08 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -8,11 +8,14 @@ namespace Slang { - class ProgramSyntaxNode; + class EntryPointRequest; class ProgramLayout; + class TranslationUnitRequest; - String emitProgram( - ProgramSyntaxNode* program, + // Emit code for a single entry point, based on + // the input translation unit. + String emitEntryPoint( + EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target); } diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h index 59bad37e7..ca5bfacb8 100644 --- a/source/slang/expr-defs.h +++ b/source/slang/expr-defs.h @@ -107,3 +107,8 @@ SYNTAX_CLASS(SharedTypeExpr, ExpressionSyntaxNode) SYNTAX_FIELD(TypeExp, base) END_SYNTAX_CLASS() +SYNTAX_CLASS(AssignExpr, ExpressionSyntaxNode) + SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, left); + SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, right); +END_SYNTAX_CLASS() + diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp new file mode 100644 index 000000000..674614cd3 --- /dev/null +++ b/source/slang/lower.cpp @@ -0,0 +1,1601 @@ +// lower.cpp +#include "lower.h" + +#include "visitor.h" + +namespace Slang +{ + +// + +template<typename V> +struct StructuralTransformVisitorBase +{ + V* visitor; + + RefPtr<StatementSyntaxNode> transformDeclField(StatementSyntaxNode* stmt) + { + return visitor->translateStmtRef(stmt); + } + + RefPtr<Decl> transformDeclField(Decl* decl) + { + return visitor->translateDeclRef(decl); + } + + template<typename T> + DeclRef<T> transformDeclField(DeclRef<T> const& decl) + { + return visitor->translateDeclRef(decl).As<T>(); + } + + TypeExp transformSyntaxField(TypeExp const& typeExp) + { + TypeExp result; + result.type = visitor->transformSyntaxField(typeExp.type); + return result; + } + + QualType transformSyntaxField(QualType const& qualType) + { + QualType result = qualType; + result.type = visitor->transformSyntaxField(qualType.type); + return result; + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return visitor->transformSyntaxField(expr); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return visitor->transformSyntaxField(stmt); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return visitor->transformSyntaxField(decl); + } + + RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl) + { + return visitor->transformSyntaxField(decl).As<ScopeDecl>(); + } + + template<typename T> + List<T> transformSyntaxField(List<T> const& list) + { + List<T> result; + for (auto item : list) + { + result.Add(transformSyntaxField(item)); + } + return result; + } +}; + +template<typename V> +struct StructuralTransformStmtVisitor + : StructuralTransformVisitorBase<V> + , StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>> +{ + void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj) + { + } + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<StatementSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = this->transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = this->transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "stmt-defs.h" +#include "object-meta-end.h" + +}; + +template<typename V> +RefPtr<StatementSyntaxNode> structuralTransform( + StatementSyntaxNode* stmt, + V* visitor) +{ + StructuralTransformStmtVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(stmt); +} + +template<typename V> +struct StructuralTransformExprVisitor + : StructuralTransformVisitorBase<V> + , ExprVisitor<StructuralTransformExprVisitor<V>, RefPtr<ExpressionSyntaxNode>> +{ + void transformFields(ExpressionSyntaxNode* result, ExpressionSyntaxNode* obj) + { + result->Type = transformSyntaxField(obj->Type); + } + + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<ExpressionSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "expr-defs.h" +#include "object-meta-end.h" +}; + + +template<typename V> +RefPtr<ExpressionSyntaxNode> structuralTransform( + ExpressionSyntaxNode* expr, + V* visitor) +{ + StructuralTransformExprVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(expr); +} + +// + +// Pseudo-syntax used during lowering +class TupleDecl : public VarDeclBase +{ +public: + virtual void accept(IDeclVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<VarDeclBase>> decls; +}; + +// Pseudo-syntax used during lowering: +// represents an ordered list of expressions as a single unit +class TupleExpr : public ExpressionSyntaxNode +{ +public: + virtual void accept(IExprVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<ExpressionSyntaxNode>> exprs; +}; + +struct SharedLoweringContext +{ + ProgramLayout* programLayout; + CodeGenTarget target; + + RefPtr<ProgramSyntaxNode> loweredProgram; + + Dictionary<Decl*, RefPtr<Decl>> loweredDecls; + Dictionary<Decl*, Decl*> mapLoweredDeclToOriginal; + + bool isRewrite; +}; + +static void attachLayout( + ModifiableSyntaxNode* syntax, + Layout* layout) +{ + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = layout; + + addModifier(syntax, modifier); +} + +struct LoweringVisitor + : ExprVisitor<LoweringVisitor, RefPtr<ExpressionSyntaxNode>> + , StmtVisitor<LoweringVisitor, void> + , DeclVisitor<LoweringVisitor, RefPtr<Decl>> + , TypeVisitor<LoweringVisitor, RefPtr<ExpressionType>> + , ValVisitor<LoweringVisitor, RefPtr<Val>> +{ + // + SharedLoweringContext* shared; + RefPtr<Substitutions> substitutions; + + bool isBuildingStmt = false; + RefPtr<StatementSyntaxNode> stmtBeingBuilt; + + // If we *aren't* building a statement, then this + // is the container we should be adding declarations to + RefPtr<ContainerDecl> parentDecl; + + // If we are in a context where a `return` should be turned + // into assignment to a variable (followed by a `return`), + // then this will point to that variable. + RefPtr<Variable> resultVariable; + + CodeGenTarget getTarget() { return shared->target; } + + // + // Values + // + + RefPtr<Val> lowerVal(Val* val) + { + if (!val) return nullptr; + return ValVisitor::dispatch(val); + } + + RefPtr<Val> visit(GenericParamIntVal* val) + { + return new GenericParamIntVal(translateDeclRef(DeclRef<Decl>(val->declRef)).As<VarDeclBase>()); + } + + RefPtr<Val> visit(ConstantIntVal* val) + { + return val; + } + + // + // Types + // + + RefPtr<ExpressionType> lowerType( + ExpressionType* type) + { + return TypeVisitor::dispatch(type); + } + + TypeExp lowerType( + TypeExp const& typeExp) + { + TypeExp result; + result.type = lowerType(typeExp.type); + return result; + } + + RefPtr<ExpressionType> visit(ErrorType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(OverloadGroupType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(InitializerListType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(GenericDeclRefType* type) + { + return new GenericDeclRefType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<GenericDecl>()); + } + + RefPtr<ExpressionType> visit(FuncType* type) + { + RefPtr<FuncType> loweredType = new FuncType(); + loweredType->declRef = translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>(); + return loweredType; + } + + RefPtr<ExpressionType> visit(DeclRefType* type) + { + auto loweredDeclRef = translateDeclRef(type->declRef); + return DeclRefType::Create(loweredDeclRef); + } + + RefPtr<ExpressionType> visit(NamedExpressionType* type) + { + if (shared->target == CodeGenTarget::GLSL) + { + // GLSL does not support `typedef`, so we will lower it out of existence here + return lowerType(GetType(type->declRef)); + } + + return new NamedExpressionType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); + } + + RefPtr<ExpressionType> visit(TypeType* type) + { + return new TypeType(lowerType(type->type)); + } + + RefPtr<ExpressionType> visit(ArrayExpressionType* type) + { + RefPtr<ArrayExpressionType> loweredType = new ArrayExpressionType(); + loweredType->BaseType = lowerType(type->BaseType); + loweredType->ArrayLength = lowerVal(type->ArrayLength).As<IntVal>(); + return loweredType; + } + + RefPtr<ExpressionType> transformSyntaxField(ExpressionType* type) + { + return lowerType(type); + } + + // + // Expressions + // + + RefPtr<ExpressionSyntaxNode> lowerExpr( + ExpressionSyntaxNode* expr) + { + if (!expr) return nullptr; + return ExprVisitor::dispatch(expr); + } + + // catch-all + RefPtr<ExpressionSyntaxNode> visit( + ExpressionSyntaxNode* expr) + { + return structuralTransform(expr, this); + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return lowerExpr(expr); + } + + void lowerExprCommon( + RefPtr<ExpressionSyntaxNode> loweredExpr, + RefPtr<ExpressionSyntaxNode> expr) + { + loweredExpr->Position = expr->Position; + loweredExpr->Type.type = lowerType(expr->Type.type); + } + + RefPtr<ExpressionSyntaxNode> createVarRef( + CodePosition const& loc, + VarDeclBase* decl) + { + if (auto tupleDecl = dynamic_cast<TupleDecl*>(decl)) + { + return createTupleRef(loc, tupleDecl); + } + else + { + RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode(); + result->Position = loc; + result->Type.type = decl->Type.type; + result->declRef = makeDeclRef(decl); + return result; + } + } + + RefPtr<ExpressionSyntaxNode> createTupleRef( + CodePosition const& loc, + TupleDecl* decl) + { + RefPtr<TupleExpr> result = new TupleExpr(); + result->Position = loc; + result->Type.type = decl->Type.type; + + for (auto dd : decl->decls) + { + auto expr = createVarRef(loc, dd); + result->exprs.Add(expr); + } + + return result; + } + + RefPtr<ExpressionSyntaxNode> visit( + VarExpressionSyntaxNode* expr) + { + // If the expression didn't get resolved, we can leave it as-is + if (!expr->declRef) + return expr; + + auto loweredDeclRef = translateDeclRef(expr->declRef); + auto loweredDecl = loweredDeclRef.getDecl(); + + if (auto tupleDecl = dynamic_cast<TupleDecl*>(loweredDecl)) + { + // If we are referencing a declaration that got tuple-ified, + // then we need to produce a tuple expression as well. + + return createTupleRef(expr->Position, tupleDecl); + } + + RefPtr<VarExpressionSyntaxNode> loweredExpr = new VarExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->declRef = loweredDeclRef; + return loweredExpr; + } + + RefPtr<ExpressionSyntaxNode> visit( + MemberExpressionSyntaxNode* expr) + { + auto loweredBase = lowerExpr(expr->BaseExpression); + + // Are we extracting an element from a tuple? + if (auto baseTuple = loweredBase.As<TupleExpr>()) + { + // We need to find the correct member expression, + // based on the actual tuple type. + + throw "unimplemented"; + } + + // Default handling: + auto loweredDeclRef = translateDeclRef(expr->declRef); + assert(!dynamic_cast<TupleDecl*>(loweredDeclRef.getDecl())); + + RefPtr<MemberExpressionSyntaxNode> loweredExpr = new MemberExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = loweredBase; + loweredExpr->declRef = loweredDeclRef; + + return loweredExpr; + } + + // + // Statements + // + + StatementSyntaxNode* translateStmtRef( + StatementSyntaxNode* stmt) + { + throw "unimplemented"; + } + + RefPtr<StatementSyntaxNode> lowerStmt( + StatementSyntaxNode* stmt) + { + if(!stmt) + return nullptr; + + LoweringVisitor subVisitor = *this; + subVisitor.stmtBeingBuilt = nullptr; + + subVisitor.lowerStmtImpl(stmt); + + if( !subVisitor.stmtBeingBuilt ) + { + return new EmptyStatementSyntaxNode(); + } + else + { + return subVisitor.stmtBeingBuilt; + } + } + + void lowerStmtImpl( + StatementSyntaxNode* stmt) + { + StmtVisitor::dispatch(stmt); + } + + RefPtr<ScopeDecl> visit(ScopeDecl* decl) + { + RefPtr<ScopeDecl> loweredDecl = new ScopeDecl(); + lowerDeclCommon(loweredDecl, decl); + return loweredDecl; + } + + LoweringVisitor pushScope( + RefPtr<ScopeStmt> loweredStmt, + RefPtr<ScopeStmt> stmt) + { + loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>(); + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + subVisitor.parentDecl = loweredStmt->scopeDecl; + return subVisitor; + } + + void addStmtImpl( + RefPtr<StatementSyntaxNode>& dest, + StatementSyntaxNode* stmt) + { + // add a statement to the code we are building... + if( !dest ) + { + dest = stmt; + return; + } + + if (auto blockStmt = dest.As<BlockStmt>()) + { + addStmtImpl(blockStmt->body, stmt); + return; + } + + if (auto seqStmt = dest.As<SeqStmt>()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr<SeqStmt> newSeqStmt = new SeqStmt(); + + newSeqStmt->stmts.Add(dest); + newSeqStmt->stmts.Add(stmt); + + dest = newSeqStmt; + } + + } + + void addStmt( + StatementSyntaxNode* stmt) + { + addStmtImpl(stmtBeingBuilt, stmt); + } + + void addExprStmt( + RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: handle cases where the `expr` cannot be directly + // represented as a single statement + + RefPtr<ExpressionStatementSyntaxNode> stmt = new ExpressionStatementSyntaxNode(); + stmt->Expression = expr; + addStmt(stmt); + } + + void visit(BlockStmt* stmt) + { + RefPtr<BlockStmt> loweredStmt = new BlockStmt(); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + subVisitor.stmtBeingBuilt = loweredStmt; + + subVisitor.lowerStmtImpl(stmt->body); + + addStmt(loweredStmt); + } + + void visit(SeqStmt* stmt) + { + for( auto ss : stmt->stmts ) + { + lowerStmtImpl(ss); + } + } + + void visit(ExpressionStatementSyntaxNode* stmt) + { + addExprStmt(lowerExpr(stmt->Expression)); + } + + void visit(VarDeclrStatementSyntaxNode* stmt) + { + DeclVisitor::dispatch(stmt->decl); + } + + // catch-all + void visit(StatementSyntaxNode* stmt) + { + auto loweredStmt = structuralTransform(stmt, this); + addStmt(loweredStmt); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return lowerStmt(stmt); + } + + void lowerStmtCommon(StatementSyntaxNode* loweredStmt, StatementSyntaxNode* stmt) + { + loweredStmt->modifiers = stmt->modifiers; + } + + void assign( + RefPtr<ExpressionSyntaxNode> destExpr, + RefPtr<ExpressionSyntaxNode> srcExpr) + { + RefPtr<AssignExpr> assignExpr = new AssignExpr(); + assignExpr->Position = destExpr->Position; + assignExpr->left = destExpr; + assignExpr->right = srcExpr; + + addExprStmt(assignExpr); + } + + void assign(VarDeclBase* varDecl, RefPtr<ExpressionSyntaxNode> expr) + { + assign(createVarRef(expr->Position, varDecl), expr); + } + + void assign(RefPtr<ExpressionSyntaxNode> expr, VarDeclBase* varDecl) + { + assign(expr, createVarRef(expr->Position, varDecl)); + } + + void visit(ReturnStatementSyntaxNode* stmt) + { + auto loweredStmt = new ReturnStatementSyntaxNode(); + lowerStmtCommon(loweredStmt, stmt); + + if (stmt->Expression) + { + if (resultVariable) + { + // Do it as an assignment + assign(resultVariable, lowerExpr(stmt->Expression)); + } + else + { + // Simple case + loweredStmt->Expression = lowerExpr(stmt->Expression); + } + } + + addStmt(loweredStmt); + } + + // + // Declarations + // + + RefPtr<Val> translateVal(Val* val) + { + if (auto type = dynamic_cast<ExpressionType*>(val)) + return lowerType(type); + + if (auto litVal = dynamic_cast<ConstantIntVal*>(val)) + return val; + + throw 99; + } + + RefPtr<Substitutions> translateSubstitutions( + Substitutions* substitutions) + { + if (!substitutions) return nullptr; + + RefPtr<Substitutions> result = new Substitutions(); + result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>(); + for (auto arg : substitutions->args) + { + result->args.Add(translateVal(arg)); + } + return result; + } + + static Decl* getModifiedDecl(Decl* decl) + { + if (!decl) return nullptr; + if (auto genericDecl = dynamic_cast<GenericDecl*>(decl->ParentDecl)) + return genericDecl; + return decl; + } + + DeclRef<Decl> translateDeclRef( + DeclRef<Decl> const& decl) + { + DeclRef<Decl> result; + result.decl = translateDeclRef(decl.decl); + result.substitutions = translateSubstitutions(decl.substitutions); + return result; + } + + RefPtr<Decl> translateDeclRef( + Decl* decl) + { + if (!decl) return nullptr; + + // We don't want to translate references to built-in declarations, + // since they won't be subtituted anyway. + if (getModifiedDecl(decl)->HasModifier<FromStdLibModifier>()) + return decl; + + // If any parent of the declaration was in the stdlib, then + // we need to skip it. + for(auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier<FromStdLibModifier>()) + return decl; + } + + if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>()) + return decl; + + RefPtr<Decl> loweredDecl; + if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + return loweredDecl; + + // Time to force it + return lowerDecl(decl); + } + + RefPtr<ContainerDecl> translateDeclRef( + ContainerDecl* decl) + { + return translateDeclRef((Decl*)decl).As<ContainerDecl>(); + } + + RefPtr<DeclBase> lowerDeclBase( + DeclBase* declBase) + { + if (Decl* decl = dynamic_cast<Decl*>(declBase)) + { + return lowerDecl(decl); + } + else + { + DeclVisitor::dispatch(declBase); + } + + } + + RefPtr<Decl> lowerDecl( + Decl* decl) + { + RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl).As<Decl>(); + return loweredDecl; + } + + static void addMember( + RefPtr<ContainerDecl> containerDecl, + RefPtr<Decl> memberDecl) + { + containerDecl->Members.Add(memberDecl); + memberDecl->ParentDecl = containerDecl.Ptr(); + } + + void addDecl( + Decl* decl) + { + if(isBuildingStmt) + { + RefPtr<VarDeclrStatementSyntaxNode> declStmt = new VarDeclrStatementSyntaxNode(); + declStmt->Position = decl->Position; + declStmt->decl = decl; + addStmt(declStmt); + } + + + // We will add the declaration to the current container declaration being + // translated, which the user will maintain via pua/pop. + // + + assert(parentDecl); + addMember(parentDecl, decl); + } + + void registerLoweredDecl(Decl* loweredDecl, Decl* decl) + { + shared->loweredDecls.Add(decl, loweredDecl); + + shared->mapLoweredDeclToOriginal.Add(loweredDecl, decl); + } + + void lowerDeclCommon( + Decl* loweredDecl, + Decl* decl) + { + registerLoweredDecl(loweredDecl, decl); + + loweredDecl->Position = decl->Position; + loweredDecl->Name = decl->getNameToken(); + + // Lower modifiers as needed + + // HACK: just doing a shallow copy of modifiers, which will + // suffice for most of them, but we need to do something + // better soon. + loweredDecl->modifiers = decl->modifiers; + + // deal with layout stuff + + auto loweredParent = translateDeclRef(decl->ParentDecl); + if (loweredParent) + { + auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); + if (layoutMod) + { + auto parentLayout = layoutMod->layout; + if (auto structLayout = parentLayout.As<StructTypeLayout>()) + { + RefPtr<VarLayout> fieldLayout; + if (structLayout->mapVarToLayout.TryGetValue(decl, fieldLayout)) + { + attachLayout(loweredDecl, fieldLayout); + } + } + + // TODO: are there other cases to handle here? + } + } + } + + // Catch-all + RefPtr<Decl> visit( + Decl* decl) + { + assert(!"unimplemented"); + return decl; + } + + RefPtr<ImportDecl> visit(ImportDecl* decl) + { + // No need to translate things here if we are + // in "full" mode, because we will selectively + // translate the imported declarations at their + // use sites(s). + if (!shared->isRewrite) + return nullptr; + + for (auto dd : decl->importedModuleDecl->Members) + { + translateDeclRef(dd); + } + + // Don't actually include a representation of + // the import declaration in the output + return nullptr; + } + + RefPtr<EmptyDecl> visit(EmptyDecl* decl) + { + // Empty declarations are really only useful in GLSL, + // where they are used to hold metadata that doesn't + // attach to any particular shader parameter. + // + // TODO: Only lower empty declarations if we are + // rewriting a GLSL file, and otherwise ignore them. + // + RefPtr<EmptyDecl> loweredDecl = new EmptyDecl(); + lowerDeclCommon(loweredDecl, decl); + + addDecl(loweredDecl); + + return loweredDecl; + } + + RefPtr<Decl> visit(AggTypeDecl* decl) + { + // We want to lower any aggregate type declaration + // to just a `struct` type that contains its fields. + // + // Any non-field members (e.g., methods) will be + // lowered separately. + + // TODO: also need to figure out how to handle fields + // with types that should not be allowed in a `struct` + // for the chosen target. + // (also: what to do if there are no fields left + // after removing invalid ones?) + + RefPtr<StructSyntaxNode> loweredDecl = new StructSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + for (auto field : decl->getMembersOfType<VarDeclBase>()) + { + // TODO: anything more to do than this? + addMember(loweredDecl, translateDeclRef(field)); + } + + addMember( + shared->loweredProgram, + loweredDecl); + + return loweredDecl; + } + + RefPtr<VarDeclBase> lowerVarDeclCommon( + RefPtr<VarDeclBase> loweredDecl, + VarDeclBase* decl) + { + lowerDeclCommon(loweredDecl, decl); + + loweredDecl->Type = lowerType(decl->Type); + loweredDecl->Expr = lowerExpr(decl->Expr); + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + Variable* decl) + { + auto loweredDecl = lowerVarDeclCommon(new Variable(), decl); + + // We need to add things to an appropriate scope, based on what + // we are referencing. + // + // If this is a global variable (program scope), then add it + // to the global scope. + RefPtr<ContainerDecl> parentDecl = decl->ParentDecl; + if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>()) + { + addMember( + translateDeclRef(parentModuleDecl), + loweredDecl); + } + // TODO: handle `static` function-scope variables + else + { + // A local variable declaration will get added to the + // statement scope we are currently processing. + addDecl(loweredDecl); + } + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + StructField* decl) + { + return lowerVarDeclCommon(new StructField(), decl); + } + + RefPtr<VarDeclBase> visit( + ParameterSyntaxNode* decl) + { + return lowerVarDeclCommon(new ParameterSyntaxNode(), decl); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return lowerDeclBase(decl); + } + + + RefPtr<Decl> visit( + DeclGroup* group) + { + for (auto decl : group->decls) + { + lowerDecl(decl); + } + return nullptr; + } + + RefPtr<FunctionSyntaxNode> visit( + FunctionDeclBase* decl) + { + // TODO: need to generate a name + + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + // TODO: push scope for parent decl here... + + // TODO: need to copy over relevant modifiers + + for (auto paramDecl : decl->GetParameters()) + { + addMember(loweredDecl, translateDeclRef(paramDecl)); + } + + auto loweredReturnType = lowerType(decl->ReturnType); + + loweredDecl->ReturnType = loweredReturnType; + + // If we are a being called recurisvely, then we need to + // be careful not to let the context get polluted + LoweringVisitor subVisitor = *this; + subVisitor.resultVariable = nullptr; + subVisitor.stmtBeingBuilt = nullptr; + + loweredDecl->Body = subVisitor.lowerStmt(decl->Body); + + // A lowered function always becomes a global-scope function, + // even if it had been a member function when declared. + addMember(shared->loweredProgram, loweredDecl); + + return loweredDecl; + } + + // + // Entry Points + // + + EntryPointLayout* findEntryPointLayout( + EntryPointRequest* entryPointRequest) + { + for( auto entryPointLayout : shared->programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + continue; + + if(entryPointLayout->profile != entryPointRequest->profile) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPointRequest` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; + } + + enum class VaryingParameterDirection + { + Input, + Output, + }; + + struct VaryingParameterArraySpec + { + VaryingParameterArraySpec* next = nullptr; + IntVal* elementCount; + }; + + struct VaryingParameterInfo + { + String name; + VaryingParameterDirection direction; + VaryingParameterArraySpec* arraySpecs = nullptr; + }; + + + void lowerSimpleShaderParameterToGLSLGlobal( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + RefPtr<ExpressionType> type = varType; + + for (auto aa = info.arraySpecs; aa; aa = aa->next) + { + RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->BaseType = type; + arrayType->ArrayLength = aa->elementCount; + + type = arrayType; + } + + // TODO: if we are declaring an SOA-ized array, + // this is where those array dimensions would need + // to be tacked on. + + RefPtr<Variable> globalVarDecl = new Variable(); + globalVarDecl->Name.Content = info.name; + globalVarDecl->Type.type = type; + + addMember(shared->loweredProgram, globalVarDecl); + + // Add the layout information + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = varLayout; + addModifier(globalVarDecl, modifier); + + // Need to generate an assignment in the right direction. + // + // TODO: for now I am just dealing with input: + + switch (info.direction) + { + case VaryingParameterDirection::Input: + addModifier(globalVarDecl, new InModifier()); + assign(varExpr, globalVarDecl); + break; + + case VaryingParameterDirection::Output: + addModifier(globalVarDecl, new OutModifier()); + + assign(globalVarDecl, varExpr); + break; + } + } + + void lowerShaderParameterToGLSLGLobalsRec( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + assert(varLayout); + + if (auto basicType = varType->As<BasicExpressionType>()) + { + // handled below + } + else if (auto vectorType = varType->As<VectorExpressionType>()) + { + // handled below + } + else if (auto matrixType = varType->As<MatrixExpressionType>()) + { + // handled below + } + else if (auto arrayType = varType->As<ArrayExpressionType>()) + { + // We will accumulate information on the array + // types that were encoutnered on our walk down + // to the leaves, and then apply these array dimensions + // to any leaf parameters. + + VaryingParameterArraySpec arraySpec; + arraySpec.next = info.arraySpecs; + arraySpec.elementCount = arrayType->ArrayLength; + + VaryingParameterInfo arrayInfo = info; + arrayInfo.arraySpecs = &arraySpec; + + RefPtr<IndexExpressionSyntaxNode> subscriptExpr = new IndexExpressionSyntaxNode(); + subscriptExpr->Position = varExpr->Position; + subscriptExpr->BaseExpression = varExpr; + + // TODO: we need to construct syntax for a loop to initialize + // the array here... + throw "unimplemented"; + + // Note that we use the original `varLayout` that was passed in, + // since that is the layout that will ultimately need to be + // used on the array elements. + // + // TODO: That won't actually work if we ever had an array of + // heterogeneous stuff... + lowerShaderParameterToGLSLGLobalsRec( + arrayInfo, + arrayType->BaseType, + varLayout, + subscriptExpr); + + } + else if (auto declRefType = varType->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // The shader parameter had a structured type, so we need + // to destructure it into its constituent fields + + for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) + { + // Don't emit storage for `static` fields here, of course + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + RefPtr<MemberExpressionSyntaxNode> fieldExpr = new MemberExpressionSyntaxNode(); + fieldExpr->Position = varExpr->Position; + fieldExpr->Type.type = GetType(fieldDeclRef); + fieldExpr->declRef = fieldDeclRef; + fieldExpr->BaseExpression = varExpr; + + VaryingParameterInfo fieldInfo = info; + fieldInfo.name = info.name + "_" + fieldDeclRef.GetName(); + + // Need to find the layout for the given field... + Decl* originalFieldDecl = nullptr; + shared->mapLoweredDeclToOriginal.TryGetValue(fieldDeclRef.getDecl(), originalFieldDecl); + assert(originalFieldDecl); + + auto structTypeLayout = varLayout->typeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + assert(fieldLayout); + + lowerShaderParameterToGLSLGLobalsRec( + fieldInfo, + GetType(fieldDeclRef), + fieldLayout, + fieldExpr); + } + + // Okay, we are done with this parameter + return; + } + } + + // Default case: just try to emit things as-is + lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr); + } + + void lowerShaderParameterToGLSLGLobals( + RefPtr<Variable> localVarDecl, + RefPtr<VarLayout> paramLayout, + VaryingParameterDirection direction) + { + auto name = localVarDecl->getName(); + auto declRef = makeDeclRef(localVarDecl.Ptr()); + + RefPtr<VarExpressionSyntaxNode> expr = new VarExpressionSyntaxNode(); + expr->name = name; + expr->declRef = declRef; + expr->Type.type = GetType(declRef); + + VaryingParameterInfo info; + info.name = name; + info.direction = direction; + + lowerShaderParameterToGLSLGLobalsRec( + info, + localVarDecl->getType(), + paramLayout, + expr); + } + + struct EntryPointParamPair + { + RefPtr<ParameterSyntaxNode> original; + RefPtr<VarLayout> layout; + RefPtr<Variable> lowered; + }; + + RefPtr<FunctionSyntaxNode> lowerEntryPointToGLSL( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + // First, loer the entry-point function as an ordinary function: + auto loweredEntryPointFunc = visit(entryPointDecl); + + // Now we will generate a `void main() { ... }` function to call the lowered code. + RefPtr<FunctionSyntaxNode> mainDecl = new FunctionSyntaxNode(); + mainDecl->ReturnType.type = ExpressionType::GetVoid(); + mainDecl->Name.Content = "main"; + + // If the user's entry point was called `main` then rename it here + if (loweredEntryPointFunc->getName() == "main") + loweredEntryPointFunc->Name.Content = "main_"; + + // We will want to generate declarations into the body of our new `main()` + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function will be translated to + // both a local variable (for passing to/from the entry point func), + // and to global variables (used for parameter passing) + + List<EntryPointParamPair> params; + + // First generate declarations for the locals + for (auto paramDecl : entryPointDecl->GetParameters()) + { + RefPtr<VarLayout> paramLayout; + entryPointLayout->mapVarToLayout.TryGetValue(paramDecl.Ptr(), paramLayout); + assert(paramLayout); + + RefPtr<Variable> localVarDecl = new Variable(); + localVarDecl->Position = paramDecl->Position; + localVarDecl->Name.Content = paramDecl->getName(); + localVarDecl->Type = lowerType(paramDecl->Type); + + subVisitor.addDecl(localVarDecl); + + EntryPointParamPair paramPair; + paramPair.original = paramDecl; + paramPair.layout = paramLayout; + paramPair.lowered = localVarDecl; + + params.Add(paramPair); + } + + // Next generate globals for the inputs, and initialize them + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<InModifier>() + || paramDecl->HasModifier<InOutModifier>() + || !paramDecl->HasModifier<OutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Input); + } + } + + // Generate a local variable for the result, if any + RefPtr<Variable> resultVarDecl; + if (!loweredEntryPointFunc->ReturnType->Equals(ExpressionType::GetVoid())) + { + resultVarDecl = new Variable(); + resultVarDecl->Position = loweredEntryPointFunc->Position; + resultVarDecl->Name.Content = "_main_result"; + resultVarDecl->Type = TypeExp(loweredEntryPointFunc->ReturnType); + + subVisitor.addDecl(resultVarDecl); + } + + // Now generate a call to the entry-point function, using the local variables + auto entryPointDeclRef = makeDeclRef(loweredEntryPointFunc.Ptr()); + + RefPtr<FuncType> entryPointType = new FuncType(); + entryPointType->declRef = entryPointDeclRef; + + RefPtr<VarExpressionSyntaxNode> entryPointRef = new VarExpressionSyntaxNode(); + entryPointRef->name = loweredEntryPointFunc->getName(); + entryPointRef->declRef = entryPointDeclRef; + entryPointRef->Type = QualType(entryPointType); + + RefPtr<InvokeExpressionSyntaxNode> callExpr = new InvokeExpressionSyntaxNode(); + callExpr->FunctionExpr = entryPointRef; + callExpr->Type = QualType(loweredEntryPointFunc->ReturnType); + + // + for (auto paramPair : params) + { + auto localVarDecl = paramPair.lowered; + + RefPtr<VarExpressionSyntaxNode> varRef = new VarExpressionSyntaxNode(); + varRef->name = localVarDecl->getName(); + varRef->declRef = makeDeclRef(localVarDecl.Ptr()); + varRef->Type = QualType(localVarDecl->getType()); + + callExpr->Arguments.Add(varRef); + } + + if (resultVarDecl) + { + // Non-`void` return type, so we need to store it + subVisitor.assign(resultVarDecl, callExpr); + } + else + { + // `void` return type: just call it + subVisitor.addExprStmt(callExpr); + } + + + // Finally, generate logic to copy the outputs to global parameters + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<OutModifier>() + || paramDecl->HasModifier<InOutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Output); + } + } + if (resultVarDecl) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + resultVarDecl, + entryPointLayout->resultLayout, + VaryingParameterDirection::Output); + } + + mainDecl->Body = subVisitor.stmtBeingBuilt; + + + // Once we are done building the body, we append our new declaration to the program. + addMember(shared->loweredProgram, mainDecl); + return mainDecl; + +#if 0 + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, entryPointDecl); + + // We create a sub-context appropriate for lowering the function body + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function must be translated + // to global-scope declarations + for (auto paramDecl : entryPointDecl->GetParameters()) + { + subVisitor.lowerShaderParameterToGLSLGLobals(paramDecl); + } + + // The output of the function must also be translated into a + // global-scope declaration. + auto loweredReturnType = lowerType(entryPointDecl->ReturnType); + RefPtr<Variable> resultGlobal; + if (!loweredReturnType->Equals(ExpressionType::GetVoid())) + { + resultGlobal = new Variable(); + // TODO: need a scheme for generating unique names + resultGlobal->Name.Content = "_main_result"; + resultGlobal->Type = loweredReturnType; + + addMember(shared->loweredProgram, resultGlobal); + } + + loweredDecl->Name.Content = "main"; + loweredDecl->ReturnType.type = ExpressionType::GetVoid(); + + // We will emit the body statement in a context where + // a `return` statmenet will generate writes to the + // result global that we declared. + subVisitor.resultVariable = resultGlobal; + + auto loweredBody = subVisitor.lowerStmt(entryPointDecl->Body); + subVisitor.addStmt(loweredBody); + + loweredDecl->Body = subVisitor.stmtBeingBuilt; + + // TODO: need to append writes for `out` parameters here... + + addMember(shared->loweredProgram, loweredDecl); + return loweredDecl; +#endif + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + switch( getTarget() ) + { + // Default case: lower an entry point just like any other function + default: + return visit(entryPointDecl); + + // For Slang->GLSL translation, we need to lower things from HLSL-style + // declarations over to GLSL conventions + case CodeGenTarget::GLSL: + return lowerEntryPointToGLSL(entryPointDecl, entryPointLayout); + } + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + EntryPointRequest* entryPointRequest) + { + auto entryPointLayout = findEntryPointLayout(entryPointRequest); + auto entryPointDecl = entryPointLayout->entryPoint; + + return lowerEntryPoint( + entryPointDecl, + entryPointLayout); + } + + +}; + +static RefPtr<StructTypeLayout> getGlobalStructLayout( + ProgramLayout* programLayout) +{ + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = globalScopeLayout.As<StructTypeLayout>(); + if(globalStructLayout) + { } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + return globalStructLayout; +} + + +// Determine if the user is just trying to "rewrite" their input file +// into an output file. This will affect the way we approach code +// generation, because we want to leave their code "as is" whenever +// possible. +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target) +{ + // TODO: we might only consider things to be a rewrite request + // in the specific case where checking is turned off... + + switch( target ) + { + default: + return false; + + case CodeGenTarget::HLSL: + return sourceLanguage == SourceLanguage::HLSL; + + case CodeGenTarget::GLSL: + return sourceLanguage == SourceLanguage::GLSL; + } +} + + + +LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target) +{ + SharedLoweringContext sharedContext; + sharedContext.programLayout = programLayout; + sharedContext.target = target; + + auto translationUnit = entryPoint->getTranslationUnit(); + + // Create a single module/program to hold all the lowered code + // (with the exception of instrinsic/stdlib declarations, which + // will be remain where they are) + RefPtr<ProgramSyntaxNode> loweredProgram = new ProgramSyntaxNode(); + sharedContext.loweredProgram = loweredProgram; + + LoweringVisitor visitor; + visitor.shared = &sharedContext; + visitor.parentDecl = loweredProgram; + + // We need to register the lowered program as the lowered version + // of the existing translation unit declaration. + + visitor.registerLoweredDecl( + loweredProgram, + translationUnit->SyntaxNode); + + // We also need to register the lowered program as the lowered version + // of any imported modules (since we will be collecting everything into + // a single module for code generation). + for (auto rr : entryPoint->compileRequest->loadedModulesList) + { + sharedContext.loweredDecls.Add( + rr, + loweredProgram); + } + + // We also want to remember the layout information for + // that declaration, so that we can apply it during emission + attachLayout(loweredProgram, + getGlobalStructLayout(programLayout)); + + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + sharedContext.isRewrite = isRewrite; + + LoweredEntryPoint result; + if (isRewrite) + { + for (auto dd : translationUnit->SyntaxNode->Members) + { + visitor.translateDeclRef(dd); + } + } + else + { + auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); + result.entryPoint = loweredEntryPoint; + } + + result.program = sharedContext.loweredProgram; + + return result; +} +} diff --git a/source/slang/lower.h b/source/slang/lower.h new file mode 100644 index 000000000..c690ea025 --- /dev/null +++ b/source/slang/lower.h @@ -0,0 +1,38 @@ +// lower.h +#ifndef SLANG_LOWER_H_INCLUDED +#define SLANG_LOWER_H_INCLUDED + +// The "lowering" step takes an input AST written in the complete Slang +// language and turns it into a more minimal format (still using the +// same AST) suitable for emission into lower-level languages. + +#include "../core/basic.h" + +#include "compiler.h" +#include "syntax.h" + +namespace Slang +{ + class EntryPointRequest; + class ProgramLayout; + class TranslationUnitRequest; + + struct LoweredEntryPoint + { + // The actual lowered entry point + RefPtr<FunctionSyntaxNode> entryPoint; + + // The generated program AST that + // contains the entry point and any + // other declarations it uses + RefPtr<ProgramSyntaxNode> program; + }; + + // Emit code for a single entry point, based on + // the input translation unit. + LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target); +} +#endif diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index a1da47fb7..e50288600 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -269,3 +269,7 @@ SIMPLE_SYNTAX_CLASS(HLSLTriangleModifier , HLSLGeometryShaderInputPrimitiveT SIMPLE_SYNTAX_CLASS(HLSLLineAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) SIMPLE_SYNTAX_CLASS(HLSLTriangleAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +// A modifier to be attached to syntax after we've computed layout +SYNTAX_CLASS(ComputedLayoutModifier, Modifier) + FIELD(RefPtr<Layout>, layout) +END_SYNTAX_CLASS() diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 2a0c64892..8d008618f 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -767,86 +767,87 @@ SimpleSemanticInfo decomposeSimpleSemantic( return info; } -enum class EntryPointParameterDirection +enum EntryPointParameterDirection { - Input, - Output, + kEntryPointParameterDirection_Input = 0x1, + kEntryPointParameterDirection_Output = 0x2, }; +typedef unsigned int EntryPointParameterDirectionMask; struct EntryPointParameterState { - String* optSemanticName; - int* ioSemanticIndex; - EntryPointParameterDirection direction; - int semanticSlotCount; + String* optSemanticName; + int* ioSemanticIndex; + EntryPointParameterDirectionMask directionMask; + int semanticSlotCount; }; -static void processSimpleEntryPointInput( +static RefPtr<TypeLayout> processSimpleEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, - EntryPointParameterState const& state) + EntryPointParameterState const& inState, + int semanticSlotCount = 1) { - auto optSemanticName = state.optSemanticName; - auto semanticIndex = *state.ioSemanticIndex; - auto semanticSlotCount = state.semanticSlotCount; -} + EntryPointParameterState state = inState; + state.semanticSlotCount = semanticSlotCount; -static void processSimpleEntryPointOutput( - ParameterBindingContext* context, - RefPtr<ExpressionType> type, - EntryPointParameterState const& state) -{ auto optSemanticName = state.optSemanticName; auto semanticIndex = *state.ioSemanticIndex; - auto semanticSlotCount = state.semanticSlotCount; - - if(!optSemanticName) - return; - auto semanticName = *optSemanticName; + String semanticName = optSemanticName ? *optSemanticName : ""; + String sn = semanticName.ToLower(); - // Note: I'm just doing something expedient here and detecting `SV_Target` - // outputs and claiming the appropriate register range right away. - // - // TODO: we should really be building up some representation of all of this, - // once we've gone to the trouble of looking it all up... - if( semanticName.ToLower() == "sv_target" ) + RefPtr<TypeLayout> typeLayout = new TypeLayout(); + if (sn.StartsWith("sv_")) { - context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); - } -} + // System-value semantic. -static void processSimpleEntryPointParameter( - ParameterBindingContext* context, - RefPtr<ExpressionType> type, - EntryPointParameterState const& inState, - int semanticSlotCount = 1) -{ - EntryPointParameterState state = inState; - state.semanticSlotCount = semanticSlotCount; + if (state.directionMask & kEntryPointParameterDirection_Output) + { + // Note: I'm just doing something expedient here and detecting `SV_Target` + // outputs and claiming the appropriate register range right away. + // + // TODO: we should really be building up some representation of all of this, + // once we've gone to the trouble of looking it all up... + if( sn == "sv_target" ) + { + context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); + } + } - switch( state.direction ) + // TODO: add some kind of usage information for system input/output + } + else { - case EntryPointParameterDirection::Input: - processSimpleEntryPointInput(context, type, state); - break; + // user-defined semantic - case EntryPointParameterDirection::Output: - processSimpleEntryPointOutput(context, type, state); - break; + if (state.directionMask & kEntryPointParameterDirection_Input) + { + auto rules = context->layoutRules->getVaryingInputRules(); + SimpleLayoutInfo layout = GetLayout(type, rules); + typeLayout->addResourceUsage(layout.kind, layout.size); + } - SLANG_EXHAUSTIVE_SWITCH() + if (state.directionMask & kEntryPointParameterDirection_Output) + { + auto rules = context->layoutRules->getVaryingOutputRules(); + SimpleLayoutInfo layout = GetLayout(type, rules); + typeLayout->addResourceUsage(layout.kind, layout.size); + } } *state.ioSemanticIndex += state.semanticSlotCount; + typeLayout->type = type; + + return typeLayout; } -static void processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, EntryPointParameterState const& state); -static void processEntryPointParameterWithPossibleSemantic( +static RefPtr<TypeLayout> processEntryPointParameterWithPossibleSemantic( ParameterBindingContext* context, Decl* declForSemantic, RefPtr<ExpressionType> type, @@ -873,11 +874,11 @@ static void processEntryPointParameterWithPossibleSemantic( // *or* we couldn't find an explicit semantic to apply on the given // declaration, so we will just recursive with whatever we have at // the moment. - processEntryPointParameter(context, type, state); + return processEntryPointParameter(context, type, state); } -static void processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, EntryPointParameterState const& state) @@ -885,31 +886,50 @@ static void processEntryPointParameter( // Scalar and vector types are treated as outputs directly if(auto basicType = type->As<BasicExpressionType>()) { - processSimpleEntryPointParameter(context, basicType, state); + return processSimpleEntryPointParameter(context, basicType, state); } else if(auto basicType = type->As<VectorExpressionType>()) { - processSimpleEntryPointParameter(context, basicType, state); + return processSimpleEntryPointParameter(context, basicType, state); } // A matrix is processed as if it was an array of rows else if( auto matrixType = type->As<MatrixExpressionType>() ) { auto rowCount = GetIntVal(matrixType->getRowCount()); - processSimpleEntryPointParameter(context, basicType, state, (int) rowCount); + return processSimpleEntryPointParameter(context, basicType, state, (int) rowCount); } else if( auto arrayType = type->As<ArrayExpressionType>() ) { + // Note: Bad Things will happen if we have an array input + // without a semantic already being enforced. + auto elementCount = GetIntVal(arrayType->ArrayLength); - for( int ii = 0; ii < elementCount; ++ii ) + // We use the first element to derive the layout for the element type + auto elementTypeLayout = processEntryPointParameter(context, arrayType->BaseType, state); + + // We still walk over subsequent elements to make sure they consume resources + // as needed + for( int ii = 1; ii < elementCount; ++ii ) { processEntryPointParameter(context, arrayType->BaseType, state); } + + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; + + for (auto rr : elementTypeLayout->resourceInfos) + { + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; + } + + return arrayTypeLayout; } // Ignore a bunch of types that don't make sense here... - else if(auto textureType = type->As<TextureType>()) {} - else if(auto samplerStateType = type->As<SamplerStateType>()) {} - else if(auto constantBufferType = type->As<ConstantBufferType>()) {} + else if (auto textureType = type->As<TextureType>()) { return nullptr; } + else if(auto samplerStateType = type->As<SamplerStateType>()) { return nullptr; } + else if(auto constantBufferType = type->As<ConstantBufferType>()) { return nullptr; } // Catch declaration-reference types late in the sequence, since // otherwise they will include all of the above cases... else if( auto declRefType = type->As<DeclRefType>() ) @@ -918,15 +938,36 @@ static void processEntryPointParameter( if (auto structDeclRef = declRef.As<StructSyntaxNode>()) { + RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); + structLayout->type = type; + // Need to recursively walk the fields of the structure now... for( auto field : GetFields(structDeclRef) ) { - processEntryPointParameterWithPossibleSemantic( + auto fieldTypeLayout = processEntryPointParameterWithPossibleSemantic( context, field.getDecl(), GetType(field), state); + + RefPtr<VarLayout> fieldVarLayout = new VarLayout(); + fieldVarLayout->varDecl = field; + fieldVarLayout->typeLayout = fieldTypeLayout; + + for (auto rr : fieldTypeLayout->resourceInfos) + { + assert(rr.count != 0); + + auto structRes = structLayout->findOrAddResourceInfo(rr.kind); + fieldVarLayout->findOrAddResourceInfo(rr.kind)->index = structRes->count; + structRes->count += rr.count; + } + + structLayout->fields.Add(fieldVarLayout); + structLayout->mapVarToLayout.Add(field.getDecl(), fieldVarLayout); } + + return structLayout; } else { @@ -937,6 +978,9 @@ static void processEntryPointParameter( { assert(!"unimplemented"); } + + assert(!"unexpected"); + return nullptr; } static void collectEntryPointParameters( @@ -998,42 +1042,68 @@ static void collectEntryPointParameters( // We have an entry-point parameter, and need to figure out what to do with it. + // TODO: need to handle `uniform`-qualified parameters here + if (paramDecl->HasModifier<UniformModifier>()) + continue; + + state.directionMask = 0; + // If it appears to be an input, process it as such. if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) { - state.direction = EntryPointParameterDirection::Input; - - processEntryPointParameterWithPossibleSemantic( - context, - paramDecl.Ptr(), - paramDecl->Type.type, - state); + state.directionMask |= kEntryPointParameterDirection_Input; } // If it appears to be an output, process it as such. if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) { - state.direction = EntryPointParameterDirection::Output; + state.directionMask |= kEntryPointParameterDirection_Output; + } + + auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic( + context, + paramDecl.Ptr(), + paramDecl->Type.type, + state); - processEntryPointParameterWithPossibleSemantic( - context, - paramDecl.Ptr(), - paramDecl->Type.type, - state); + RefPtr<VarLayout> paramVarLayout = new VarLayout(); + paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr()); + paramVarLayout->typeLayout = paramTypeLayout; + + for (auto rr : paramTypeLayout->resourceInfos) + { + auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); + paramVarLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count; + entryPointRes->count += rr.count; } + + entryPointLayout->fields.Add(paramVarLayout); + entryPointLayout->mapVarToLayout.Add(paramDecl, paramVarLayout); } // If we can find an output type for the entry point, then process it as // an output parameter. if( auto resultType = entryPointFuncDecl->ReturnType.type ) { - state.direction = EntryPointParameterDirection::Output; + state.directionMask = kEntryPointParameterDirection_Output; - processEntryPointParameterWithPossibleSemantic( + auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic( context, entryPointFuncDecl, resultType, state); + + RefPtr<VarLayout> resultLayout = new VarLayout(); + resultLayout->typeLayout = resultTypeLayout; + + for (auto rr : resultTypeLayout->resourceInfos) + { + auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); + resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count; + entryPointRes->count += rr.count; + } + + entryPointLayout->resultLayout = resultLayout; } } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 4f622ada6..b134f9645 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2617,10 +2617,13 @@ namespace Slang RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode(); + RefPtr<BlockStmt> blockStatement = new BlockStmt(); blockStatement->scopeDecl = scopeDecl; PushScope(scopeDecl.Ptr()); ReadToken(TokenType::LBrace); + + RefPtr<StatementSyntaxNode> body; + if(!tokenReader.IsAtEnd()) { FillPosition(blockStatement.Ptr()); @@ -2630,11 +2633,29 @@ namespace Slang auto stmt = ParseStatement(); if(stmt) { - blockStatement->Statements.Add(stmt); + if (!body) + { + body = stmt; + } + else if (auto seqStmt = body.As<SeqStmt>()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr<SeqStmt> newBody = new SeqStmt(); + newBody->Position = blockStatement->Position; + newBody->stmts.Add(body); + newBody->stmts.Add(stmt); + + body = newBody; + } } TryRecover(this); } PopScope(); + + blockStatement->body = body; return blockStatement; } @@ -3054,7 +3075,19 @@ namespace Slang right = parseInfixExprWithPrecedence(parser, right, nextOpPrec); } - expr = createInfixExpr(parser, expr, op, right); + if (opTokenType == TokenType::OpAssign) + { + RefPtr<AssignExpr> assignExpr = new AssignExpr(); + assignExpr->Position = op->Position; + assignExpr->left = expr; + assignExpr->right = right; + + expr = assignExpr; + } + else + { + expr = createInfixExpr(parser, expr, op, right); + } } return expr; } diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 3fde10e4b..441bc5e45 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -178,6 +178,7 @@ <ClInclude Include="modifier-defs.h" /> <ClInclude Include="object-meta-begin.h" /> <ClInclude Include="object-meta-end.h" /> + <ClInclude Include="lower.h" /> <ClInclude Include="parameter-binding.h" /> <ClInclude Include="parser.h" /> <ClInclude Include="preprocessor.h" /> @@ -205,6 +206,7 @@ <ClCompile Include="emit.cpp" /> <ClCompile Include="lexer.cpp" /> <ClCompile Include="lookup.cpp" /> + <ClCompile Include="lower.cpp" /> <ClCompile Include="options.cpp" /> <ClCompile Include="parameter-binding.cpp" /> <ClCompile Include="parser.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 0ed4457dc..21b533a10 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -37,6 +37,7 @@ <ClInclude Include="type-defs.h" /> <ClInclude Include="val-defs.h" /> <ClInclude Include="visitor.h" /> + <ClInclude Include="lower.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp" /> @@ -56,5 +57,6 @@ <ClCompile Include="token.cpp" /> <ClCompile Include="type-layout.cpp" /> <ClCompile Include="options.cpp" /> + <ClCompile Include="lower.cpp" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/source/slang/stmt-defs.h b/source/slang/stmt-defs.h index 9cc1978bd..c5bcda63b 100644 --- a/source/slang/stmt-defs.h +++ b/source/slang/stmt-defs.h @@ -6,8 +6,14 @@ ABSTRACT_SYNTAX_CLASS(ScopeStmt, StatementSyntaxNode) SYNTAX_FIELD(RefPtr<ScopeDecl>, scopeDecl) END_SYNTAX_CLASS() -SYNTAX_CLASS(BlockStatementSyntaxNode, ScopeStmt) - SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, Statements) +// A sequence of statements, treated as a single statement +SYNTAX_CLASS(SeqStmt, StatementSyntaxNode) + SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, stmts) +END_SYNTAX_CLASS() + +// The simplest kind of scope statement: just a `{...}` block +SYNTAX_CLASS(BlockStmt, ScopeStmt) + SYNTAX_FIELD(RefPtr<StatementSyntaxNode>, body); END_SYNTAX_CLASS() SYNTAX_CLASS(UnparsedStmt, StatementSyntaxNode) diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 983f18bb7..12d8c90bb 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1,8 +1,6 @@ #include "syntax.h" -#pragma warning AAA #include "visitor.h" -#pragma warning BBB #include <typeinfo> #include <assert.h> @@ -55,9 +53,6 @@ namespace Slang return res.ProduceString(); } -#pragma warning CCC - - // Generate dispatch logic and other definitions for all syntax classes #define SYNTAX_CLASS(NAME, BASE) /* empty */ #include "object-meta-begin.h" @@ -79,8 +74,6 @@ namespace Slang #include "object-meta-end.h" -#pragma warning DDD - void ExpressionType::accept(IValVisitor* visitor, void* extra) { accept((ITypeVisitor*)visitor, extra); @@ -881,9 +874,20 @@ void ExpressionType::accept(IValVisitor* visitor, void* extra) auto parentDecl = decl->ParentDecl; if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) { - // We need to strip away one layer of specialization - assert(substitutions); - return DeclRefBase(parentGeneric, substitutions->outer); + if (substitutions && substitutions->genericDecl == parentDecl) + { + // We strip away the specializations that were applied to + // the parent, since we were asked for a reference *to* the parent. + return DeclRefBase(parentGeneric, substitutions->outer); + } + else + { + // Either we don't have specializations, or the inner-most + // specializations didn't apply to the parent decl. This + // can happen if we are looking at an unspecialized + // declaration that is a child of a generic. + return DeclRefBase(parentGeneric, substitutions); + } } else { diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 47427b130..6587e18c3 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -14,6 +14,7 @@ namespace Slang class Substitutions; class SyntaxVisitor; class FunctionSyntaxNode; + class Layout; struct IExprVisitor; struct IDeclVisitor; @@ -347,9 +348,15 @@ namespace Slang { return DeclRef<ContainerDecl>::unsafeInit(DeclRefBase::GetParent()); } - }; + + template<typename T> + inline DeclRef<T> makeDeclRef(T* decl) + { + return DeclRef<T>(decl, nullptr); + } + template<typename T> struct FilteredMemberList { diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index 07867fddb..ce1f8864d 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -143,8 +143,13 @@ struct SimpleArrayLayoutInfo : SimpleLayoutInfo struct LayoutRulesImpl; +// Base class for things that store layout info +class Layout : public RefObject +{ +}; + // A reified reprsentation of a particular laid-out type -class TypeLayout : public RefObject +class TypeLayout : public Layout { public: // The type that was laid out @@ -215,7 +220,7 @@ enum VarLayoutFlag : VarLayoutFlags }; // A reified layout for a particular variable, field, etc. -class VarLayout : public RefObject +class VarLayout : public Layout { public: // The variable we are laying out @@ -324,7 +329,14 @@ public: // Layout information for a single shader entry point // within a program -class EntryPointLayout : public RefObject +// +// Treated as a subclass of `StructTypeLayout` becase +// it needs to include computed layout information +// for the parameters of the entry point. +// +// TODO: where to store layout info for the return +// type of the function? +class EntryPointLayout : public StructTypeLayout { public: // The corresponding function declaration @@ -332,10 +344,13 @@ public: // The shader profile that was used to compile the entry point Profile profile; + + // Layout for any results of the entry point + RefPtr<VarLayout> resultLayout; }; // Layout information for the global scope of a program -class ProgramLayout : public RefObject +class ProgramLayout : public Layout { public: // We store a layout for the declarations at the global @@ -359,14 +374,6 @@ public: List<RefPtr<EntryPointLayout>> entryPoints; }; -// A modifier to be attached to syntax after we've computed layout -class ComputedLayoutModifier : public Modifier -{ -public: - RefPtr<TypeLayout> typeLayout; -}; - - struct LayoutRulesFamilyImpl; // A delineation of shader parameter types into fine-grained diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index 316981842..9ebfc9872 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -3,7 +3,8 @@ // Syntax class definitions for compile-time values. // A compile-time integer (may not have a specific concrete value) -SIMPLE_SYNTAX_CLASS(IntVal, Val) +ABSTRACT_SYNTAX_CLASS(IntVal, Val) +END_SYNTAX_CLASS() // Trivial case of a value that is just a constant integer SYNTAX_CLASS(ConstantIntVal, IntVal) |
