diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-07-06 09:52:53 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-07-06 09:52:53 -0700 |
| commit | 21a14cb4e0d578bc4f8a460016269a1199cac0da (patch) | |
| tree | 88a04619ceaaa37b87199dd82334cc9d102c156d /source/slang | |
| parent | f313df379dd9e0d4395f072ffb87016a6f20d5a1 (diff) | |
| parent | f145e09a6dcbcf326f782b3e6a76dbf291c792cf (diff) | |
Merge pull request #53 from tfoleyNV/cross-compilation
Initial work on cross-compilation
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/check.cpp | 29 | ||||
| -rw-r--r-- | source/slang/compiler.cpp | 73 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 435 | ||||
| -rw-r--r-- | source/slang/emit.h | 9 | ||||
| -rw-r--r-- | source/slang/expr-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 1601 | ||||
| -rw-r--r-- | source/slang/lower.h | 38 | ||||
| -rw-r--r-- | source/slang/modifier-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 220 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 39 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 2 | ||||
| -rw-r--r-- | source/slang/stmt-defs.h | 10 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 24 | ||||
| -rw-r--r-- | source/slang/syntax.h | 9 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 31 | ||||
| -rw-r--r-- | source/slang/val-defs.h | 3 |
17 files changed, 2182 insertions, 352 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index a3e061a3b..a79af3f37 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -1581,11 +1581,16 @@ namespace Slang DeclVisitor::dispatch(stmt->decl); } - void visit(BlockStatementSyntaxNode *stmt) + void visit(BlockStmt* stmt) { - for (auto & node : stmt->Statements) + checkStmt(stmt->body); + } + + void visit(SeqStmt* stmt) + { + for(auto ss : stmt->stmts) { - checkStmt(node); + checkStmt(ss); } } @@ -2414,6 +2419,24 @@ namespace Slang return appExpr; } + // + + RefPtr<ExpressionSyntaxNode> visit(AssignExpr* expr) + { + expr->left = CheckExpr(expr->left); + + auto type = expr->left->Type; + + expr->right = Coerce(type, CheckTerm(expr->right)); + + if (!type.IsLeftValue) + { + getSink()->diagnose(expr, Diagnostics::assignNonLValue); + } + expr->Type = type; + return expr; + } + // diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 9132624f9..19623a7aa 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -55,10 +55,11 @@ namespace Slang // - String emitHLSLForTranslationUnit( - TranslationUnitRequest* translationUnit) + String emitHLSLForEntryPoint( + EntryPointRequest* entryPoint) { - auto compileRequest = translationUnit->compileRequest; + auto compileRequest = entryPoint->compileRequest; + auto translationUnit = entryPoint->getTranslationUnit(); if (compileRequest->passThrough != PassThroughMode::None) { // Generate a string that includes the content of @@ -92,8 +93,8 @@ namespace Slang } else { - return emitProgram( - translationUnit->SyntaxNode.Ptr(), + return emitEntryPoint( + entryPoint, compileRequest->layout.Ptr(), CodeGenTarget::HLSL); } @@ -137,8 +138,8 @@ namespace Slang { // TODO(tfoley): need to pass along the entry point // so that we properly emit it as the `main` function. - return emitProgram( - translationUnit->SyntaxNode.Ptr(), + return emitEntryPoint( + entryPoint, compileRequest->layout.Ptr(), CodeGenTarget::GLSL); } @@ -180,15 +181,7 @@ namespace Slang assert(D3DCompile_); } - // The HLSL compiler will try to "canonicalize" our input file path, - // and we don't want it to do that, because they it won't report - // the same locations on error messages that we would. - // - // To work around that, we prepend a custom `#line` directive. - - auto translationUnit = entryPoint->getTranslationUnit(); - - auto hlslCode = emitHLSLForTranslationUnit(translationUnit); + auto hlslCode = emitHLSLForEntryPoint(entryPoint); ID3DBlob* codeBlob; ID3DBlob* diagnosticsBlob; @@ -223,6 +216,7 @@ namespace Slang if (FAILED(hr)) { // TODO(tfoley): What to do on failure? + exit(1); } return data; } @@ -400,6 +394,13 @@ namespace Slang switch (compileRequest->Target) { + case CodeGenTarget::HLSL: + { + String code = emitHLSLForEntryPoint(entryPoint); + result.outputSource = code; + } + break; + case CodeGenTarget::GLSL: { String code = emitGLSLForEntryPoint(entryPoint); @@ -505,45 +506,7 @@ namespace Slang TranslationUnitResult emitTranslationUnit( TranslationUnitRequest* translationUnit) { - auto compileRequest = translationUnit->compileRequest; - - // Most of our code generation targets will require us - // to proceed through one entry point at a time, but - // in some cases we can emit an entire translation unit - // in one go. - - switch (compileRequest->Target) - { - default: - // The default behavior is going to loop over all the entry - // points, and then collect an aggregate result. - return emitTranslationUnitEntryPoints(translationUnit); - - case CodeGenTarget::HLSL: - // When targetting HLSL, we can emit the entire translation unit - // as a single HLSL program, and include all the entry points. - { - - String hlsl = emitHLSLForTranslationUnit(translationUnit); - - TranslationUnitResult result; - result.outputSource = hlsl; - - // Because the user might ask for per-entry-point source, - // we will just attach the same string as the result for - // each entry point. - for( auto& entryPoint : translationUnit->entryPoints ) - { - EntryPointResult entryPointResult; - entryPointResult.outputSource = hlsl; - - entryPoint->result = entryPointResult; - } - - return result; - } - break; - } + return emitTranslationUnitEntryPoints(translationUnit); } #if 0 diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index dc7c68aa4..71b39aabb 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -1,6 +1,7 @@ // emit.cpp #include "emit.h" +#include "lower.h" #include "syntax.h" #include "type-layout.h" @@ -13,8 +14,16 @@ namespace Slang { -struct EmitContext +// Shared state for an entire emit session +struct SharedEmitContext { + // The target language we want to generate code for + CodeGenTarget target; + + // A set of words reserved by the target + Dictionary<String, String> reservedWords; + + // The string of code we've built so far StringBuilder sb; // Current source position for tracking purposes... @@ -22,12 +31,6 @@ struct EmitContext CodePosition nextSourceLocation; bool needToUpdateSourceLocation; - // The target language we want to generate code for - CodeGenTarget target; - - // A set of words reserved by the target - Dictionary<String, String> reservedWords; - // For GLSL output, we can't emit traidtional `#line` directives // with a file path in them, so we maintain a map that associates // each path with a unique integer, and then we output those @@ -45,6 +48,18 @@ struct EmitContext // TODO: This will probably change if we represent imports // explicitly in the layout data. StructTypeLayout* globalStructLayout; + + ProgramLayout* programLayout; +}; + +struct EmitContext +{ + // The shared context that is in effect + SharedEmitContext* shared; + + // Are we in "rewrite" mode, where we are trying to reproduce the input + // code as closely as posible? + bool isRewrite; }; // @@ -79,7 +94,7 @@ static void emitRawTextSpan(EmitContext* context, char const* textBegin, char co // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... auto len = int(textEnd - textBegin); - context->sb.Append(textBegin, len); + context->shared->sb.Append(textBegin, len); } static void emitRawText(EmitContext* context, char const* text) @@ -99,7 +114,7 @@ static void emitTextSpan(EmitContext* context, char const* textBegin, char const // Update our logical position // TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things... auto len = int(textEnd - textBegin); - context->loc.Col += len; + context->shared->loc.Col += len; } static void Emit(EmitContext* context, char const* textBegin, char const* textEnd) @@ -123,8 +138,8 @@ static void Emit(EmitContext* context, char const* textBegin, char const* textEn // At the end of a line, we need to update our tracking // information on code positions emitTextSpan(context, spanBegin, spanEnd); - context->loc.Line++; - context->loc.Col = 1; + context->shared->loc.Line++; + context->shared->loc.Col = 1; // Start a new span for emit purposes spanBegin = spanEnd; @@ -144,7 +159,7 @@ static void emit(EmitContext* context, String const& text) static bool isReservedWord(EmitContext* context, String const& name) { - return context->reservedWords.TryGetValue(name) != nullptr; + return context->shared->reservedWords.TryGetValue(name) != nullptr; } static void emitName( @@ -449,7 +464,7 @@ static bool isTargetIntrinsicModifierApplicable( // we expect. auto const& targetName = targetToken.Content; - switch(context->target) + switch(context->shared->target) { default: assert(!"unexpected"); @@ -658,7 +673,7 @@ static void emitCallExpr( case IntrinsicOp::InnerProduct_Vector_Vector: // HLSL allows `mul()` to be used as a synonym for `dot()`, // so we need to translate to `dot` for GLSL - if (context->target == CodeGenTarget::GLSL) + if (context->shared->target == CodeGenTarget::GLSL) { Emit(context, "dot("); EmitExpr(context, callExpr->Arguments[0]); @@ -677,7 +692,7 @@ static void emitCallExpr( // // The other critical detail here is that the way we handle matrix // conventions requires that the operands to the product be swapped. - if (context->target == CodeGenTarget::GLSL) + if (context->shared->target == CodeGenTarget::GLSL) { Emit(context, "(("); EmitExpr(context, callExpr->Arguments[1]); @@ -825,6 +840,13 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax Emit(context, " : "); EmitExprWithPrecedence(context, selectExpr->Arguments[2], kPrecedence_Conditional); } + else if (auto assignExpr = expr.As<AssignExpr>()) + { + needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Assign); + EmitExprWithPrecedence(context, assignExpr->left, kPrecedence_Assign); + Emit(context, " = "); + EmitExprWithPrecedence(context, assignExpr->right, kPrecedence_Assign); + } else if (auto callExpr = expr.As<InvokeExpressionSyntaxNode>()) { emitCallExpr(context, callExpr, outerPrec); @@ -970,7 +992,7 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax } else if (auto castExpr = expr.As<TypeCastExpressionSyntaxNode>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: // GLSL requires constructor syntax for all conversions @@ -1227,7 +1249,7 @@ static void emitTextureType( EmitContext* context, RefPtr<TextureType> texType) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLTextureType(context, texType); @@ -1247,7 +1269,7 @@ static void emitTextureSamplerType( EmitContext* context, RefPtr<TextureSamplerType> type) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: emitGLSLTextureSamplerType(context, type); @@ -1263,7 +1285,7 @@ static void emitImageType( EmitContext* context, RefPtr<GLSLImageType> type) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLTextureType(context, type); @@ -1300,7 +1322,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto vecType = type->As<VectorExpressionType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: @@ -1331,7 +1353,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto matType = type->As<MatrixExpressionType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::GLSL: case CodeGenTarget::GLSL_Vulkan: @@ -1386,7 +1408,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara } else if (auto samplerStateType = type->As<SamplerStateType>()) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: default: @@ -1474,7 +1496,8 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type) static void EmitType(EmitContext* context, TypeExp const& typeExp, Token const& nameToken) { - EmitType(context, typeExp.type, typeExp.exp->Position, nameToken.Content, nameToken.Position); + EmitType(context, typeExp.type, + typeExp.exp ? typeExp.exp->Position : CodePosition(), nameToken.Content, nameToken.Position); } static void EmitType(EmitContext* context, TypeExp const& typeExp, String const& name) @@ -1497,12 +1520,9 @@ static void EmitBlockStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt { // TODO(tfoley): support indenting Emit(context, "{\n"); - if( auto blockStmt = stmt.As<BlockStatementSyntaxNode>() ) + if( auto blockStmt = stmt.As<BlockStmt>() ) { - for (auto s : blockStmt->Statements) - { - EmitStmt(context, s); - } + EmitStmt(context, blockStmt->body); } else { @@ -1544,7 +1564,7 @@ static void emitLineDirective( emitRawText(context, " "); - if(context->target == CodeGenTarget::GLSL) + if(context->shared->target == CodeGenTarget::GLSL) { auto path = sourceLocation.FileName; @@ -1556,10 +1576,10 @@ static void emitLineDirective( // extension and then emit a traditional line directive. int id = 0; - if(!context->mapGLSLSourcePathToID.TryGetValue(path, id)) + if(!context->shared->mapGLSLSourcePathToID.TryGetValue(path, id)) { - id = context->glslSourceIDCount++; - context->mapGLSLSourcePathToID.Add(path, id); + id = context->shared->glslSourceIDCount++; + context->shared->mapGLSLSourcePathToID.Add(path, id); } sprintf(buffer, "%d", id); @@ -1611,9 +1631,9 @@ static void emitLineDirectiveAndUpdateSourceLocation( { emitLineDirective(context, sourceLocation); - context->loc.FileName = sourceLocation.FileName; - context->loc.Line = sourceLocation.Line; - context->loc.Col = 1; + context->shared->loc.FileName = sourceLocation.FileName; + context->shared->loc.Line = sourceLocation.Line; + context->shared->loc.Col = 1; } static void emitLineDirectiveIfNeeded( @@ -1628,24 +1648,24 @@ static void emitLineDirectiveIfNeeded( // a differnet file or line, *or* if the source location is // somehow later on the line than what we want to emit, // then we need to emit a new `#line` directive. - if(sourceLocation.FileName != context->loc.FileName - || sourceLocation.Line != context->loc.Line - || sourceLocation.Col < context->loc.Col) + if(sourceLocation.FileName != context->shared->loc.FileName + || sourceLocation.Line != context->shared->loc.Line + || sourceLocation.Col < context->shared->loc.Col) { // Special case: if we are in the same file, and within a small number // of lines of the target location, then go ahead and output newlines // to get us caught up. enum { kSmallLineCount = 3 }; - auto lineDiff = sourceLocation.Line - context->loc.Line; - if(sourceLocation.FileName == context->loc.FileName - && sourceLocation.Line > context->loc.Line + auto lineDiff = sourceLocation.Line - context->shared->loc.Line; + if(sourceLocation.FileName == context->shared->loc.FileName + && sourceLocation.Line > context->shared->loc.Line && lineDiff <= kSmallLineCount) { for(int ii = 0; ii < lineDiff; ++ii ) { Emit(context, "\n"); } - assert(sourceLocation.Line == context->loc.Line); + assert(sourceLocation.Line == context->shared->loc.Line); } else { @@ -1661,14 +1681,14 @@ static void emitLineDirectiveIfNeeded( // came in as spaces or tabs, so there is necessarily going to be // coupling between how the downstream compiler counts columns, // and how we do. - if(sourceLocation.Col > context->loc.Col) + if(sourceLocation.Col > context->shared->loc.Col) { - int delta = sourceLocation.Col - context->loc.Col; + int delta = sourceLocation.Col - context->shared->loc.Col; for( int ii = 0; ii < delta; ++ii ) { emitRawText(context, " "); } - context->loc.Col = sourceLocation.Col; + context->shared->loc.Col = sourceLocation.Col; } } @@ -1680,22 +1700,22 @@ static void advanceToSourceLocation( if(sourceLocation.Line <= 0) return; - context->needToUpdateSourceLocation = true; - context->nextSourceLocation = sourceLocation; + context->shared->needToUpdateSourceLocation = true; + context->shared->nextSourceLocation = sourceLocation; } static void flushSourceLocationChange( EmitContext* context) { - if(!context->needToUpdateSourceLocation) + if(!context->shared->needToUpdateSourceLocation) return; // Note: the order matters here, because trying to update // the source location may involve outputting text that // advances the location, and outputting text is what // triggers this flush operation. - context->needToUpdateSourceLocation = false; - emitLineDirectiveIfNeeded(context, context->nextSourceLocation); + context->shared->needToUpdateSourceLocation = false; + emitLineDirectiveIfNeeded(context, context->shared->nextSourceLocation); } static void emitTokenWithLocation(EmitContext* context, Token const& token) @@ -1737,11 +1757,19 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) // Try to ensure that debugging can find the right location advanceToSourceLocation(context, stmt->Position); - if (auto blockStmt = stmt.As<BlockStatementSyntaxNode>()) + if (auto blockStmt = stmt.As<BlockStmt>()) { EmitBlockStmt(context, blockStmt); return; } + else if (auto seqStmt = stmt.As<SeqStmt>()) + { + for (auto ss : seqStmt->stmts) + { + EmitStmt(context, ss); + } + return; + } else if( auto unparsedStmt = stmt.As<UnparsedStmt>() ) { EmitUnparsedStmt(context, unparsedStmt); @@ -1784,17 +1812,43 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) } else if (auto forStmt = stmt.As<ForStatementSyntaxNode>()) { - EmitLoopAttributes(context, forStmt); + // We are going to always take a `for` loop like: + // + // for(A; B; C) { D } + // + // and emit it as: + // + // { A; for(; B; C) { D } } + // + // This ensures that we are robust against any kind + // of statement appearing in `A`, including things + // that might occur due to lowering steps. + // - Emit(context, "for("); - if (auto initStmt = forStmt->InitialStatement) + // The one wrinkle is that HLSL implements the + // bad approach to scoping a `for` loop variable, + // so we need to avoid those outer `{...}` when + // we are generating HLSL via "rewrite" (that is, + // without our semantic checks). + // + bool brokenScoping = false; + if (context->shared->target == CodeGenTarget::HLSL + && context->isRewrite) { - EmitStmt(context, initStmt); + brokenScoping = true; } - else + + auto initStmt = forStmt->InitialStatement; + if(initStmt) { - Emit(context, ";"); + if(!brokenScoping) + Emit(context, "{\n"); + EmitStmt(context, initStmt); } + + EmitLoopAttributes(context, forStmt); + + Emit(context, "for(;"); if (auto testExp = forStmt->PredicateExpression) { EmitExpr(context, testExp); @@ -1806,6 +1860,13 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt) } Emit(context, ")\n"); EmitBlockStmt(context, forStmt->Statement); + + if (initStmt) + { + if(!brokenScoping) + Emit(context, "}\n"); + } + return; } else if (auto whileStmt = stmt.As<WhileStatementSyntaxNode>()) @@ -2085,7 +2146,7 @@ static void EmitSemantic(EmitContext* context, RefPtr<HLSLSemantic> semantic, ES static void EmitSemantics(EmitContext* context, RefPtr<Decl> decl, ESemanticMask mask = kESemanticMask_Default ) { // Don't emit semantics if we aren't translating down to HLSL - switch (context->target) + switch (context->shared->target) { case CodeGenTarget::HLSL: break; @@ -2263,7 +2324,7 @@ static void emitHLSLRegisterSemantics( { if (!layout) return; - switch( context->target ) + switch( context->shared->target ) { default: return; @@ -2278,6 +2339,25 @@ static void emitHLSLRegisterSemantics( } } +static RefPtr<VarLayout> maybeFetchLayout( + RefPtr<Decl> decl, + RefPtr<VarLayout> layout) +{ + // If we have already found layout info, don't go searching + if (layout) return layout; + + // Otherwise, we need to look and see if computed layout + // information has been attached to the declaration. + auto modifier = decl->FindModifier<ComputedLayoutModifier>(); + if (!modifier) return nullptr; + + auto computedLayout = modifier->layout; + assert(computedLayout); + + auto varLayout = computedLayout.As<VarLayout>(); + return varLayout; +} + static void emitHLSLParameterBlockDecl( EmitContext* context, RefPtr<VarDeclBase> varDecl, @@ -2292,6 +2372,7 @@ static void emitHLSLParameterBlockDecl( assert(declRefType); // We expect to always have layout information + layout = maybeFetchLayout(varDecl, layout); assert(layout); // We expect the layout to be for a structured type... @@ -2325,13 +2406,16 @@ static void emitHLSLParameterBlockDecl( Emit(context, "\n{\n"); if (auto structRef = declRefType->declRef.As<StructSyntaxNode>()) { + int fieldCounter = 0; + for (auto field : getMembersOfType<StructField>(structRef)) { + int fieldIndex = fieldCounter++; + EmitVarDeclCommon(context, field); - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - assert(fieldLayout); + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + assert(fieldLayout->varDecl.GetName() == field.GetName()); // Emit explicit layout annotations for every field for( auto rr : fieldLayout->resourceInfos ) @@ -2415,7 +2499,7 @@ emitGLSLLayoutQualifiers( { if(!layout) return; - switch( context->target ) + switch( context->shared->target ) { default: return; @@ -2493,7 +2577,7 @@ static void emitGLSLParameterBlockDecl( { RefPtr<VarLayout> fieldLayout; structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - assert(fieldLayout); +// assert(fieldLayout); // TODO(tfoley): We may want to emit *some* of these, // some of the time... @@ -2521,7 +2605,7 @@ static void emitParameterBlockDecl( RefPtr<ParameterBlockType> parameterBlockType, RefPtr<VarLayout> layout) { - switch(context->target) + switch(context->shared->target) { case CodeGenTarget::HLSL: emitHLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout); @@ -2539,6 +2623,8 @@ static void emitParameterBlockDecl( static void EmitVarDecl(EmitContext* context, RefPtr<VarDeclBase> decl, RefPtr<VarLayout> layout) { + layout = maybeFetchLayout(decl, layout); + // As a special case, a variable using a parameter block type // will be translated into a declaration using the more primitive // language syntax. @@ -2606,7 +2692,7 @@ static void emitGLSLPreprocessorDirectives( EmitContext* context, RefPtr<ProgramSyntaxNode> program) { - switch(context->target) + switch(context->shared->target) { // Don't emit this stuff unless we are targetting GLSL default: @@ -2658,78 +2744,6 @@ static void emitGLSLPreprocessorDirectives( // TODO: handle other cases... } -static void EmitProgram( - EmitContext* context, - RefPtr<ProgramSyntaxNode> program, - RefPtr<ProgramLayout> programLayout) -{ - // There may be global-scope modifiers that we should emit now - emitGLSLPreprocessorDirectives(context, program); - - switch(context->target) - { - case CodeGenTarget::GLSL: - { - // TODO(tfoley): Need a plan for how to enable/disable these as needed... -// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n"); - } - break; - - default: - break; - } - - - // Layout information for the global scope is either an ordinary - // `struct` in the common case, or a constant buffer in the case - // where there were global-scope uniforms. - auto globalScopeLayout = programLayout->globalScopeLayout; - if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() ) - { - context->globalStructLayout = globalStructLayout.Ptr(); - - // The `struct` case is easy enough to handle: we just - // emit all the declarations directly, using their layout - // information as a guideline. - EmitDeclsInContainerUsingLayout(context, program, globalStructLayout); - } - else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) - { - // TODO: the `cbuffer` case really needs to be emitted very - // carefully, but that is beyond the scope of what a simple rewriter - // can easily do (without semantic analysis, etc.). - // - // The crux of the problem is that we need to collect all the - // global-scope uniforms (but not declarations that don't involve - // uniform storage...) and put them in a single `cbuffer` declaration, - // so that we can give it an explicit location. The fields in that - // declaration might use various type declarations, so we'd really - // need to emit all the type declarations first, and that involves - // some large scale reorderings. - // - // For now we will punt and just emit the declarations normally, - // and hope that the global-scope block (`$Globals`) gets auto-assigned - // the same location that we manually asigned it. - - auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; - auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); - - // We expect all constant buffers to contain `struct` types for now - assert(elementTypeStructLayout); - - context->globalStructLayout = elementTypeStructLayout.Ptr(); - - EmitDeclsInContainerUsingLayout( - context, - program, - elementTypeStructLayout); - } - else - { - assert(!"unexpected"); - } -} - static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout) { // Don't emit code for declarations that came from the stdlib. @@ -2783,18 +2797,18 @@ static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayo // We might import the same module along two different paths, // so we need to be careful to only emit each module once // per output. - if(!context->modulesAlreadyEmitted.Contains(moduleDecl)) + if(!context->shared->modulesAlreadyEmitted.Contains(moduleDecl)) { // Add the module to our set before emitting it, just // in case a circular reference would lead us to // infinite recursion (but that shouldn't be allowed // in the first place). - context->modulesAlreadyEmitted.Add(moduleDecl); + context->shared->modulesAlreadyEmitted.Add(moduleDecl); // TODO: do we need to modify the code generation environment at // all when doing this recursive emit? - EmitDeclsInContainerUsingLayout(context, moduleDecl, context->globalStructLayout); + EmitDeclsInContainerUsingLayout(context, moduleDecl, context->shared->globalStructLayout); } return; @@ -2839,7 +2853,7 @@ static void registerReservedWord( EmitContext* context, String const& name) { - context->reservedWords.Add(name, name); + context->shared->reservedWords.Add(name, name); } static void registerReservedWords( @@ -2847,7 +2861,7 @@ static void registerReservedWords( { #define WORD(NAME) registerReservedWord(context, #NAME) - switch (context->target) + switch (context->shared->target) { case CodeGenTarget::GLSL: WORD(attribute); @@ -2990,68 +3004,115 @@ static void registerReservedWords( } } -String emitProgram( - ProgramSyntaxNode* program, +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target); + +String emitEntryPoint( + EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target) { - // TODO(tfoley): only emit symbols on-demand, as needed by a particular entry point + auto translationUnit = entryPoint->getTranslationUnit(); + + SharedEmitContext sharedContext; + sharedContext.target = target; + + sharedContext.programLayout = programLayout; + + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = nullptr; + if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() ) + { + globalStructLayout = globalStructLayout.Ptr(); + } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + sharedContext.globalStructLayout = globalStructLayout; EmitContext context; - context.target = target; + context.shared = &sharedContext; + context.isRewrite = isRewriteRequest( + translationUnit->sourceLanguage, + target); + // TODO: this should only need to take the shared context registerReservedWords(&context); - EmitProgram(&context, program, programLayout); + auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr(); - String code = context.sb.ProduceString(); - return code; - -#if 0 - // HACK(tfoley): Invoke the D3D HLSL compiler on the result, to validate it + // There may be global-scope modifiers that we should emit now + emitGLSLPreprocessorDirectives(&context, translationUnitSyntax); -#ifdef _WIN32 + switch(target) { - HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); - assert(d3dCompiler); - - pD3DCompile D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); - assert(D3DCompile_); - - ID3DBlob* codeBlob; - ID3DBlob* diagnosticsBlob; - HRESULT hr = D3DCompile_( - code.begin(), - code.Length(), - "slang", - nullptr, - nullptr, - "main", - "ps_5_0", - 0, - 0, - &codeBlob, - &diagnosticsBlob); - if (codeBlob) codeBlob->Release(); - if (diagnosticsBlob) - { - String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); - fprintf(stderr, "%s", diagnostics.begin()); - OutputDebugStringA(diagnostics.begin()); - diagnosticsBlob->Release(); - } - if (FAILED(hr)) + case CodeGenTarget::GLSL: { - int f = 9; + // TODO(tfoley): Need a plan for how to enable/disable these as needed... +// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n"); } + break; + + default: + break; } - #include <d3dcompiler.h> -#endif + auto lowered = lowerEntryPoint(entryPoint, programLayout, target); + + EmitDeclsInContainer(&context, lowered.program.Ptr()); + +#if 0 + if( isRewrite ) + { + // In rewrite mode, we will just emit the text of the translation unit as given, + // and not pay attention to the specific entry point that was requested. + // + // It is a user error to request GLSL output and have an entry point name + // other than `main`. + EmitDeclsInContainerUsingLayout(&context, translationUnitSyntax, globalStructLayout); + } + else + { + // We are being asked to emit a single entry point in "full" mode. + emitEntryPoint(&context, entryPoint); + } #endif -} + String code = sharedContext.sb.ProduceString(); + return code; +} } // namespace Slang diff --git a/source/slang/emit.h b/source/slang/emit.h index 4f3e98c8d..da1ac9f08 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -8,11 +8,14 @@ namespace Slang { - class ProgramSyntaxNode; + class EntryPointRequest; class ProgramLayout; + class TranslationUnitRequest; - String emitProgram( - ProgramSyntaxNode* program, + // Emit code for a single entry point, based on + // the input translation unit. + String emitEntryPoint( + EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target); } diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h index 59bad37e7..ca5bfacb8 100644 --- a/source/slang/expr-defs.h +++ b/source/slang/expr-defs.h @@ -107,3 +107,8 @@ SYNTAX_CLASS(SharedTypeExpr, ExpressionSyntaxNode) SYNTAX_FIELD(TypeExp, base) END_SYNTAX_CLASS() +SYNTAX_CLASS(AssignExpr, ExpressionSyntaxNode) + SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, left); + SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, right); +END_SYNTAX_CLASS() + diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp new file mode 100644 index 000000000..674614cd3 --- /dev/null +++ b/source/slang/lower.cpp @@ -0,0 +1,1601 @@ +// lower.cpp +#include "lower.h" + +#include "visitor.h" + +namespace Slang +{ + +// + +template<typename V> +struct StructuralTransformVisitorBase +{ + V* visitor; + + RefPtr<StatementSyntaxNode> transformDeclField(StatementSyntaxNode* stmt) + { + return visitor->translateStmtRef(stmt); + } + + RefPtr<Decl> transformDeclField(Decl* decl) + { + return visitor->translateDeclRef(decl); + } + + template<typename T> + DeclRef<T> transformDeclField(DeclRef<T> const& decl) + { + return visitor->translateDeclRef(decl).As<T>(); + } + + TypeExp transformSyntaxField(TypeExp const& typeExp) + { + TypeExp result; + result.type = visitor->transformSyntaxField(typeExp.type); + return result; + } + + QualType transformSyntaxField(QualType const& qualType) + { + QualType result = qualType; + result.type = visitor->transformSyntaxField(qualType.type); + return result; + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return visitor->transformSyntaxField(expr); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return visitor->transformSyntaxField(stmt); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return visitor->transformSyntaxField(decl); + } + + RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl) + { + return visitor->transformSyntaxField(decl).As<ScopeDecl>(); + } + + template<typename T> + List<T> transformSyntaxField(List<T> const& list) + { + List<T> result; + for (auto item : list) + { + result.Add(transformSyntaxField(item)); + } + return result; + } +}; + +template<typename V> +struct StructuralTransformStmtVisitor + : StructuralTransformVisitorBase<V> + , StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>> +{ + void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj) + { + } + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<StatementSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = this->transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = this->transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "stmt-defs.h" +#include "object-meta-end.h" + +}; + +template<typename V> +RefPtr<StatementSyntaxNode> structuralTransform( + StatementSyntaxNode* stmt, + V* visitor) +{ + StructuralTransformStmtVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(stmt); +} + +template<typename V> +struct StructuralTransformExprVisitor + : StructuralTransformVisitorBase<V> + , ExprVisitor<StructuralTransformExprVisitor<V>, RefPtr<ExpressionSyntaxNode>> +{ + void transformFields(ExpressionSyntaxNode* result, ExpressionSyntaxNode* obj) + { + result->Type = transformSyntaxField(obj->Type); + } + + +#define SYNTAX_CLASS(NAME, BASE, ...) \ + RefPtr<ExpressionSyntaxNode> visit(NAME* obj) { \ + RefPtr<NAME> result = new NAME(*obj); \ + transformFields(result, obj); \ + return result; \ + } \ + void transformFields(NAME* result, NAME* obj) { \ + transformFields((BASE*) result, (BASE*) obj); \ + +#define SYNTAX_FIELD(TYPE, NAME) result->NAME = transformSyntaxField(obj->NAME); +#define DECL_FIELD(TYPE, NAME) result->NAME = transformDeclField(obj->NAME); + +#define FIELD(TYPE, NAME) /* empty */ + +#define END_SYNTAX_CLASS() \ + } + +#include "object-meta-begin.h" +#include "expr-defs.h" +#include "object-meta-end.h" +}; + + +template<typename V> +RefPtr<ExpressionSyntaxNode> structuralTransform( + ExpressionSyntaxNode* expr, + V* visitor) +{ + StructuralTransformExprVisitor<V> transformer; + transformer.visitor = visitor; + return transformer.dispatch(expr); +} + +// + +// Pseudo-syntax used during lowering +class TupleDecl : public VarDeclBase +{ +public: + virtual void accept(IDeclVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<VarDeclBase>> decls; +}; + +// Pseudo-syntax used during lowering: +// represents an ordered list of expressions as a single unit +class TupleExpr : public ExpressionSyntaxNode +{ +public: + virtual void accept(IExprVisitor *, void *) override + { + throw "unexpected"; + } + + List<RefPtr<ExpressionSyntaxNode>> exprs; +}; + +struct SharedLoweringContext +{ + ProgramLayout* programLayout; + CodeGenTarget target; + + RefPtr<ProgramSyntaxNode> loweredProgram; + + Dictionary<Decl*, RefPtr<Decl>> loweredDecls; + Dictionary<Decl*, Decl*> mapLoweredDeclToOriginal; + + bool isRewrite; +}; + +static void attachLayout( + ModifiableSyntaxNode* syntax, + Layout* layout) +{ + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = layout; + + addModifier(syntax, modifier); +} + +struct LoweringVisitor + : ExprVisitor<LoweringVisitor, RefPtr<ExpressionSyntaxNode>> + , StmtVisitor<LoweringVisitor, void> + , DeclVisitor<LoweringVisitor, RefPtr<Decl>> + , TypeVisitor<LoweringVisitor, RefPtr<ExpressionType>> + , ValVisitor<LoweringVisitor, RefPtr<Val>> +{ + // + SharedLoweringContext* shared; + RefPtr<Substitutions> substitutions; + + bool isBuildingStmt = false; + RefPtr<StatementSyntaxNode> stmtBeingBuilt; + + // If we *aren't* building a statement, then this + // is the container we should be adding declarations to + RefPtr<ContainerDecl> parentDecl; + + // If we are in a context where a `return` should be turned + // into assignment to a variable (followed by a `return`), + // then this will point to that variable. + RefPtr<Variable> resultVariable; + + CodeGenTarget getTarget() { return shared->target; } + + // + // Values + // + + RefPtr<Val> lowerVal(Val* val) + { + if (!val) return nullptr; + return ValVisitor::dispatch(val); + } + + RefPtr<Val> visit(GenericParamIntVal* val) + { + return new GenericParamIntVal(translateDeclRef(DeclRef<Decl>(val->declRef)).As<VarDeclBase>()); + } + + RefPtr<Val> visit(ConstantIntVal* val) + { + return val; + } + + // + // Types + // + + RefPtr<ExpressionType> lowerType( + ExpressionType* type) + { + return TypeVisitor::dispatch(type); + } + + TypeExp lowerType( + TypeExp const& typeExp) + { + TypeExp result; + result.type = lowerType(typeExp.type); + return result; + } + + RefPtr<ExpressionType> visit(ErrorType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(OverloadGroupType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(InitializerListType* type) + { + return type; + } + + RefPtr<ExpressionType> visit(GenericDeclRefType* type) + { + return new GenericDeclRefType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<GenericDecl>()); + } + + RefPtr<ExpressionType> visit(FuncType* type) + { + RefPtr<FuncType> loweredType = new FuncType(); + loweredType->declRef = translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>(); + return loweredType; + } + + RefPtr<ExpressionType> visit(DeclRefType* type) + { + auto loweredDeclRef = translateDeclRef(type->declRef); + return DeclRefType::Create(loweredDeclRef); + } + + RefPtr<ExpressionType> visit(NamedExpressionType* type) + { + if (shared->target == CodeGenTarget::GLSL) + { + // GLSL does not support `typedef`, so we will lower it out of existence here + return lowerType(GetType(type->declRef)); + } + + return new NamedExpressionType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); + } + + RefPtr<ExpressionType> visit(TypeType* type) + { + return new TypeType(lowerType(type->type)); + } + + RefPtr<ExpressionType> visit(ArrayExpressionType* type) + { + RefPtr<ArrayExpressionType> loweredType = new ArrayExpressionType(); + loweredType->BaseType = lowerType(type->BaseType); + loweredType->ArrayLength = lowerVal(type->ArrayLength).As<IntVal>(); + return loweredType; + } + + RefPtr<ExpressionType> transformSyntaxField(ExpressionType* type) + { + return lowerType(type); + } + + // + // Expressions + // + + RefPtr<ExpressionSyntaxNode> lowerExpr( + ExpressionSyntaxNode* expr) + { + if (!expr) return nullptr; + return ExprVisitor::dispatch(expr); + } + + // catch-all + RefPtr<ExpressionSyntaxNode> visit( + ExpressionSyntaxNode* expr) + { + return structuralTransform(expr, this); + } + + RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr) + { + return lowerExpr(expr); + } + + void lowerExprCommon( + RefPtr<ExpressionSyntaxNode> loweredExpr, + RefPtr<ExpressionSyntaxNode> expr) + { + loweredExpr->Position = expr->Position; + loweredExpr->Type.type = lowerType(expr->Type.type); + } + + RefPtr<ExpressionSyntaxNode> createVarRef( + CodePosition const& loc, + VarDeclBase* decl) + { + if (auto tupleDecl = dynamic_cast<TupleDecl*>(decl)) + { + return createTupleRef(loc, tupleDecl); + } + else + { + RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode(); + result->Position = loc; + result->Type.type = decl->Type.type; + result->declRef = makeDeclRef(decl); + return result; + } + } + + RefPtr<ExpressionSyntaxNode> createTupleRef( + CodePosition const& loc, + TupleDecl* decl) + { + RefPtr<TupleExpr> result = new TupleExpr(); + result->Position = loc; + result->Type.type = decl->Type.type; + + for (auto dd : decl->decls) + { + auto expr = createVarRef(loc, dd); + result->exprs.Add(expr); + } + + return result; + } + + RefPtr<ExpressionSyntaxNode> visit( + VarExpressionSyntaxNode* expr) + { + // If the expression didn't get resolved, we can leave it as-is + if (!expr->declRef) + return expr; + + auto loweredDeclRef = translateDeclRef(expr->declRef); + auto loweredDecl = loweredDeclRef.getDecl(); + + if (auto tupleDecl = dynamic_cast<TupleDecl*>(loweredDecl)) + { + // If we are referencing a declaration that got tuple-ified, + // then we need to produce a tuple expression as well. + + return createTupleRef(expr->Position, tupleDecl); + } + + RefPtr<VarExpressionSyntaxNode> loweredExpr = new VarExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->declRef = loweredDeclRef; + return loweredExpr; + } + + RefPtr<ExpressionSyntaxNode> visit( + MemberExpressionSyntaxNode* expr) + { + auto loweredBase = lowerExpr(expr->BaseExpression); + + // Are we extracting an element from a tuple? + if (auto baseTuple = loweredBase.As<TupleExpr>()) + { + // We need to find the correct member expression, + // based on the actual tuple type. + + throw "unimplemented"; + } + + // Default handling: + auto loweredDeclRef = translateDeclRef(expr->declRef); + assert(!dynamic_cast<TupleDecl*>(loweredDeclRef.getDecl())); + + RefPtr<MemberExpressionSyntaxNode> loweredExpr = new MemberExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = loweredBase; + loweredExpr->declRef = loweredDeclRef; + + return loweredExpr; + } + + // + // Statements + // + + StatementSyntaxNode* translateStmtRef( + StatementSyntaxNode* stmt) + { + throw "unimplemented"; + } + + RefPtr<StatementSyntaxNode> lowerStmt( + StatementSyntaxNode* stmt) + { + if(!stmt) + return nullptr; + + LoweringVisitor subVisitor = *this; + subVisitor.stmtBeingBuilt = nullptr; + + subVisitor.lowerStmtImpl(stmt); + + if( !subVisitor.stmtBeingBuilt ) + { + return new EmptyStatementSyntaxNode(); + } + else + { + return subVisitor.stmtBeingBuilt; + } + } + + void lowerStmtImpl( + StatementSyntaxNode* stmt) + { + StmtVisitor::dispatch(stmt); + } + + RefPtr<ScopeDecl> visit(ScopeDecl* decl) + { + RefPtr<ScopeDecl> loweredDecl = new ScopeDecl(); + lowerDeclCommon(loweredDecl, decl); + return loweredDecl; + } + + LoweringVisitor pushScope( + RefPtr<ScopeStmt> loweredStmt, + RefPtr<ScopeStmt> stmt) + { + loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>(); + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + subVisitor.parentDecl = loweredStmt->scopeDecl; + return subVisitor; + } + + void addStmtImpl( + RefPtr<StatementSyntaxNode>& dest, + StatementSyntaxNode* stmt) + { + // add a statement to the code we are building... + if( !dest ) + { + dest = stmt; + return; + } + + if (auto blockStmt = dest.As<BlockStmt>()) + { + addStmtImpl(blockStmt->body, stmt); + return; + } + + if (auto seqStmt = dest.As<SeqStmt>()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr<SeqStmt> newSeqStmt = new SeqStmt(); + + newSeqStmt->stmts.Add(dest); + newSeqStmt->stmts.Add(stmt); + + dest = newSeqStmt; + } + + } + + void addStmt( + StatementSyntaxNode* stmt) + { + addStmtImpl(stmtBeingBuilt, stmt); + } + + void addExprStmt( + RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: handle cases where the `expr` cannot be directly + // represented as a single statement + + RefPtr<ExpressionStatementSyntaxNode> stmt = new ExpressionStatementSyntaxNode(); + stmt->Expression = expr; + addStmt(stmt); + } + + void visit(BlockStmt* stmt) + { + RefPtr<BlockStmt> loweredStmt = new BlockStmt(); + + LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); + subVisitor.stmtBeingBuilt = loweredStmt; + + subVisitor.lowerStmtImpl(stmt->body); + + addStmt(loweredStmt); + } + + void visit(SeqStmt* stmt) + { + for( auto ss : stmt->stmts ) + { + lowerStmtImpl(ss); + } + } + + void visit(ExpressionStatementSyntaxNode* stmt) + { + addExprStmt(lowerExpr(stmt->Expression)); + } + + void visit(VarDeclrStatementSyntaxNode* stmt) + { + DeclVisitor::dispatch(stmt->decl); + } + + // catch-all + void visit(StatementSyntaxNode* stmt) + { + auto loweredStmt = structuralTransform(stmt, this); + addStmt(loweredStmt); + } + + RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt) + { + return lowerStmt(stmt); + } + + void lowerStmtCommon(StatementSyntaxNode* loweredStmt, StatementSyntaxNode* stmt) + { + loweredStmt->modifiers = stmt->modifiers; + } + + void assign( + RefPtr<ExpressionSyntaxNode> destExpr, + RefPtr<ExpressionSyntaxNode> srcExpr) + { + RefPtr<AssignExpr> assignExpr = new AssignExpr(); + assignExpr->Position = destExpr->Position; + assignExpr->left = destExpr; + assignExpr->right = srcExpr; + + addExprStmt(assignExpr); + } + + void assign(VarDeclBase* varDecl, RefPtr<ExpressionSyntaxNode> expr) + { + assign(createVarRef(expr->Position, varDecl), expr); + } + + void assign(RefPtr<ExpressionSyntaxNode> expr, VarDeclBase* varDecl) + { + assign(expr, createVarRef(expr->Position, varDecl)); + } + + void visit(ReturnStatementSyntaxNode* stmt) + { + auto loweredStmt = new ReturnStatementSyntaxNode(); + lowerStmtCommon(loweredStmt, stmt); + + if (stmt->Expression) + { + if (resultVariable) + { + // Do it as an assignment + assign(resultVariable, lowerExpr(stmt->Expression)); + } + else + { + // Simple case + loweredStmt->Expression = lowerExpr(stmt->Expression); + } + } + + addStmt(loweredStmt); + } + + // + // Declarations + // + + RefPtr<Val> translateVal(Val* val) + { + if (auto type = dynamic_cast<ExpressionType*>(val)) + return lowerType(type); + + if (auto litVal = dynamic_cast<ConstantIntVal*>(val)) + return val; + + throw 99; + } + + RefPtr<Substitutions> translateSubstitutions( + Substitutions* substitutions) + { + if (!substitutions) return nullptr; + + RefPtr<Substitutions> result = new Substitutions(); + result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>(); + for (auto arg : substitutions->args) + { + result->args.Add(translateVal(arg)); + } + return result; + } + + static Decl* getModifiedDecl(Decl* decl) + { + if (!decl) return nullptr; + if (auto genericDecl = dynamic_cast<GenericDecl*>(decl->ParentDecl)) + return genericDecl; + return decl; + } + + DeclRef<Decl> translateDeclRef( + DeclRef<Decl> const& decl) + { + DeclRef<Decl> result; + result.decl = translateDeclRef(decl.decl); + result.substitutions = translateSubstitutions(decl.substitutions); + return result; + } + + RefPtr<Decl> translateDeclRef( + Decl* decl) + { + if (!decl) return nullptr; + + // We don't want to translate references to built-in declarations, + // since they won't be subtituted anyway. + if (getModifiedDecl(decl)->HasModifier<FromStdLibModifier>()) + return decl; + + // If any parent of the declaration was in the stdlib, then + // we need to skip it. + for(auto pp = decl; pp; pp = pp->ParentDecl) + { + if (pp->HasModifier<FromStdLibModifier>()) + return decl; + } + + if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>()) + return decl; + + RefPtr<Decl> loweredDecl; + if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + return loweredDecl; + + // Time to force it + return lowerDecl(decl); + } + + RefPtr<ContainerDecl> translateDeclRef( + ContainerDecl* decl) + { + return translateDeclRef((Decl*)decl).As<ContainerDecl>(); + } + + RefPtr<DeclBase> lowerDeclBase( + DeclBase* declBase) + { + if (Decl* decl = dynamic_cast<Decl*>(declBase)) + { + return lowerDecl(decl); + } + else + { + DeclVisitor::dispatch(declBase); + } + + } + + RefPtr<Decl> lowerDecl( + Decl* decl) + { + RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl).As<Decl>(); + return loweredDecl; + } + + static void addMember( + RefPtr<ContainerDecl> containerDecl, + RefPtr<Decl> memberDecl) + { + containerDecl->Members.Add(memberDecl); + memberDecl->ParentDecl = containerDecl.Ptr(); + } + + void addDecl( + Decl* decl) + { + if(isBuildingStmt) + { + RefPtr<VarDeclrStatementSyntaxNode> declStmt = new VarDeclrStatementSyntaxNode(); + declStmt->Position = decl->Position; + declStmt->decl = decl; + addStmt(declStmt); + } + + + // We will add the declaration to the current container declaration being + // translated, which the user will maintain via pua/pop. + // + + assert(parentDecl); + addMember(parentDecl, decl); + } + + void registerLoweredDecl(Decl* loweredDecl, Decl* decl) + { + shared->loweredDecls.Add(decl, loweredDecl); + + shared->mapLoweredDeclToOriginal.Add(loweredDecl, decl); + } + + void lowerDeclCommon( + Decl* loweredDecl, + Decl* decl) + { + registerLoweredDecl(loweredDecl, decl); + + loweredDecl->Position = decl->Position; + loweredDecl->Name = decl->getNameToken(); + + // Lower modifiers as needed + + // HACK: just doing a shallow copy of modifiers, which will + // suffice for most of them, but we need to do something + // better soon. + loweredDecl->modifiers = decl->modifiers; + + // deal with layout stuff + + auto loweredParent = translateDeclRef(decl->ParentDecl); + if (loweredParent) + { + auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); + if (layoutMod) + { + auto parentLayout = layoutMod->layout; + if (auto structLayout = parentLayout.As<StructTypeLayout>()) + { + RefPtr<VarLayout> fieldLayout; + if (structLayout->mapVarToLayout.TryGetValue(decl, fieldLayout)) + { + attachLayout(loweredDecl, fieldLayout); + } + } + + // TODO: are there other cases to handle here? + } + } + } + + // Catch-all + RefPtr<Decl> visit( + Decl* decl) + { + assert(!"unimplemented"); + return decl; + } + + RefPtr<ImportDecl> visit(ImportDecl* decl) + { + // No need to translate things here if we are + // in "full" mode, because we will selectively + // translate the imported declarations at their + // use sites(s). + if (!shared->isRewrite) + return nullptr; + + for (auto dd : decl->importedModuleDecl->Members) + { + translateDeclRef(dd); + } + + // Don't actually include a representation of + // the import declaration in the output + return nullptr; + } + + RefPtr<EmptyDecl> visit(EmptyDecl* decl) + { + // Empty declarations are really only useful in GLSL, + // where they are used to hold metadata that doesn't + // attach to any particular shader parameter. + // + // TODO: Only lower empty declarations if we are + // rewriting a GLSL file, and otherwise ignore them. + // + RefPtr<EmptyDecl> loweredDecl = new EmptyDecl(); + lowerDeclCommon(loweredDecl, decl); + + addDecl(loweredDecl); + + return loweredDecl; + } + + RefPtr<Decl> visit(AggTypeDecl* decl) + { + // We want to lower any aggregate type declaration + // to just a `struct` type that contains its fields. + // + // Any non-field members (e.g., methods) will be + // lowered separately. + + // TODO: also need to figure out how to handle fields + // with types that should not be allowed in a `struct` + // for the chosen target. + // (also: what to do if there are no fields left + // after removing invalid ones?) + + RefPtr<StructSyntaxNode> loweredDecl = new StructSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + for (auto field : decl->getMembersOfType<VarDeclBase>()) + { + // TODO: anything more to do than this? + addMember(loweredDecl, translateDeclRef(field)); + } + + addMember( + shared->loweredProgram, + loweredDecl); + + return loweredDecl; + } + + RefPtr<VarDeclBase> lowerVarDeclCommon( + RefPtr<VarDeclBase> loweredDecl, + VarDeclBase* decl) + { + lowerDeclCommon(loweredDecl, decl); + + loweredDecl->Type = lowerType(decl->Type); + loweredDecl->Expr = lowerExpr(decl->Expr); + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + Variable* decl) + { + auto loweredDecl = lowerVarDeclCommon(new Variable(), decl); + + // We need to add things to an appropriate scope, based on what + // we are referencing. + // + // If this is a global variable (program scope), then add it + // to the global scope. + RefPtr<ContainerDecl> parentDecl = decl->ParentDecl; + if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>()) + { + addMember( + translateDeclRef(parentModuleDecl), + loweredDecl); + } + // TODO: handle `static` function-scope variables + else + { + // A local variable declaration will get added to the + // statement scope we are currently processing. + addDecl(loweredDecl); + } + + return loweredDecl; + } + + RefPtr<VarDeclBase> visit( + StructField* decl) + { + return lowerVarDeclCommon(new StructField(), decl); + } + + RefPtr<VarDeclBase> visit( + ParameterSyntaxNode* decl) + { + return lowerVarDeclCommon(new ParameterSyntaxNode(), decl); + } + + RefPtr<DeclBase> transformSyntaxField(DeclBase* decl) + { + return lowerDeclBase(decl); + } + + + RefPtr<Decl> visit( + DeclGroup* group) + { + for (auto decl : group->decls) + { + lowerDecl(decl); + } + return nullptr; + } + + RefPtr<FunctionSyntaxNode> visit( + FunctionDeclBase* decl) + { + // TODO: need to generate a name + + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, decl); + + // TODO: push scope for parent decl here... + + // TODO: need to copy over relevant modifiers + + for (auto paramDecl : decl->GetParameters()) + { + addMember(loweredDecl, translateDeclRef(paramDecl)); + } + + auto loweredReturnType = lowerType(decl->ReturnType); + + loweredDecl->ReturnType = loweredReturnType; + + // If we are a being called recurisvely, then we need to + // be careful not to let the context get polluted + LoweringVisitor subVisitor = *this; + subVisitor.resultVariable = nullptr; + subVisitor.stmtBeingBuilt = nullptr; + + loweredDecl->Body = subVisitor.lowerStmt(decl->Body); + + // A lowered function always becomes a global-scope function, + // even if it had been a member function when declared. + addMember(shared->loweredProgram, loweredDecl); + + return loweredDecl; + } + + // + // Entry Points + // + + EntryPointLayout* findEntryPointLayout( + EntryPointRequest* entryPointRequest) + { + for( auto entryPointLayout : shared->programLayout->entryPoints ) + { + if(entryPointLayout->entryPoint->getName() != entryPointRequest->name) + continue; + + if(entryPointLayout->profile != entryPointRequest->profile) + continue; + + // TODO: can't easily filter on translation unit here... + // Ideally the `EntryPointRequest` should get filled in with a pointer + // the specific function declaration that represents the entry point. + + return entryPointLayout.Ptr(); + } + + return nullptr; + } + + enum class VaryingParameterDirection + { + Input, + Output, + }; + + struct VaryingParameterArraySpec + { + VaryingParameterArraySpec* next = nullptr; + IntVal* elementCount; + }; + + struct VaryingParameterInfo + { + String name; + VaryingParameterDirection direction; + VaryingParameterArraySpec* arraySpecs = nullptr; + }; + + + void lowerSimpleShaderParameterToGLSLGlobal( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + RefPtr<ExpressionType> type = varType; + + for (auto aa = info.arraySpecs; aa; aa = aa->next) + { + RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->BaseType = type; + arrayType->ArrayLength = aa->elementCount; + + type = arrayType; + } + + // TODO: if we are declaring an SOA-ized array, + // this is where those array dimensions would need + // to be tacked on. + + RefPtr<Variable> globalVarDecl = new Variable(); + globalVarDecl->Name.Content = info.name; + globalVarDecl->Type.type = type; + + addMember(shared->loweredProgram, globalVarDecl); + + // Add the layout information + RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier(); + modifier->layout = varLayout; + addModifier(globalVarDecl, modifier); + + // Need to generate an assignment in the right direction. + // + // TODO: for now I am just dealing with input: + + switch (info.direction) + { + case VaryingParameterDirection::Input: + addModifier(globalVarDecl, new InModifier()); + assign(varExpr, globalVarDecl); + break; + + case VaryingParameterDirection::Output: + addModifier(globalVarDecl, new OutModifier()); + + assign(globalVarDecl, varExpr); + break; + } + } + + void lowerShaderParameterToGLSLGLobalsRec( + VaryingParameterInfo const& info, + RefPtr<ExpressionType> varType, + RefPtr<VarLayout> varLayout, + RefPtr<ExpressionSyntaxNode> varExpr) + { + assert(varLayout); + + if (auto basicType = varType->As<BasicExpressionType>()) + { + // handled below + } + else if (auto vectorType = varType->As<VectorExpressionType>()) + { + // handled below + } + else if (auto matrixType = varType->As<MatrixExpressionType>()) + { + // handled below + } + else if (auto arrayType = varType->As<ArrayExpressionType>()) + { + // We will accumulate information on the array + // types that were encoutnered on our walk down + // to the leaves, and then apply these array dimensions + // to any leaf parameters. + + VaryingParameterArraySpec arraySpec; + arraySpec.next = info.arraySpecs; + arraySpec.elementCount = arrayType->ArrayLength; + + VaryingParameterInfo arrayInfo = info; + arrayInfo.arraySpecs = &arraySpec; + + RefPtr<IndexExpressionSyntaxNode> subscriptExpr = new IndexExpressionSyntaxNode(); + subscriptExpr->Position = varExpr->Position; + subscriptExpr->BaseExpression = varExpr; + + // TODO: we need to construct syntax for a loop to initialize + // the array here... + throw "unimplemented"; + + // Note that we use the original `varLayout` that was passed in, + // since that is the layout that will ultimately need to be + // used on the array elements. + // + // TODO: That won't actually work if we ever had an array of + // heterogeneous stuff... + lowerShaderParameterToGLSLGLobalsRec( + arrayInfo, + arrayType->BaseType, + varLayout, + subscriptExpr); + + } + else if (auto declRefType = varType->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // The shader parameter had a structured type, so we need + // to destructure it into its constituent fields + + for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) + { + // Don't emit storage for `static` fields here, of course + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + RefPtr<MemberExpressionSyntaxNode> fieldExpr = new MemberExpressionSyntaxNode(); + fieldExpr->Position = varExpr->Position; + fieldExpr->Type.type = GetType(fieldDeclRef); + fieldExpr->declRef = fieldDeclRef; + fieldExpr->BaseExpression = varExpr; + + VaryingParameterInfo fieldInfo = info; + fieldInfo.name = info.name + "_" + fieldDeclRef.GetName(); + + // Need to find the layout for the given field... + Decl* originalFieldDecl = nullptr; + shared->mapLoweredDeclToOriginal.TryGetValue(fieldDeclRef.getDecl(), originalFieldDecl); + assert(originalFieldDecl); + + auto structTypeLayout = varLayout->typeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + assert(fieldLayout); + + lowerShaderParameterToGLSLGLobalsRec( + fieldInfo, + GetType(fieldDeclRef), + fieldLayout, + fieldExpr); + } + + // Okay, we are done with this parameter + return; + } + } + + // Default case: just try to emit things as-is + lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr); + } + + void lowerShaderParameterToGLSLGLobals( + RefPtr<Variable> localVarDecl, + RefPtr<VarLayout> paramLayout, + VaryingParameterDirection direction) + { + auto name = localVarDecl->getName(); + auto declRef = makeDeclRef(localVarDecl.Ptr()); + + RefPtr<VarExpressionSyntaxNode> expr = new VarExpressionSyntaxNode(); + expr->name = name; + expr->declRef = declRef; + expr->Type.type = GetType(declRef); + + VaryingParameterInfo info; + info.name = name; + info.direction = direction; + + lowerShaderParameterToGLSLGLobalsRec( + info, + localVarDecl->getType(), + paramLayout, + expr); + } + + struct EntryPointParamPair + { + RefPtr<ParameterSyntaxNode> original; + RefPtr<VarLayout> layout; + RefPtr<Variable> lowered; + }; + + RefPtr<FunctionSyntaxNode> lowerEntryPointToGLSL( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + // First, loer the entry-point function as an ordinary function: + auto loweredEntryPointFunc = visit(entryPointDecl); + + // Now we will generate a `void main() { ... }` function to call the lowered code. + RefPtr<FunctionSyntaxNode> mainDecl = new FunctionSyntaxNode(); + mainDecl->ReturnType.type = ExpressionType::GetVoid(); + mainDecl->Name.Content = "main"; + + // If the user's entry point was called `main` then rename it here + if (loweredEntryPointFunc->getName() == "main") + loweredEntryPointFunc->Name.Content = "main_"; + + // We will want to generate declarations into the body of our new `main()` + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function will be translated to + // both a local variable (for passing to/from the entry point func), + // and to global variables (used for parameter passing) + + List<EntryPointParamPair> params; + + // First generate declarations for the locals + for (auto paramDecl : entryPointDecl->GetParameters()) + { + RefPtr<VarLayout> paramLayout; + entryPointLayout->mapVarToLayout.TryGetValue(paramDecl.Ptr(), paramLayout); + assert(paramLayout); + + RefPtr<Variable> localVarDecl = new Variable(); + localVarDecl->Position = paramDecl->Position; + localVarDecl->Name.Content = paramDecl->getName(); + localVarDecl->Type = lowerType(paramDecl->Type); + + subVisitor.addDecl(localVarDecl); + + EntryPointParamPair paramPair; + paramPair.original = paramDecl; + paramPair.layout = paramLayout; + paramPair.lowered = localVarDecl; + + params.Add(paramPair); + } + + // Next generate globals for the inputs, and initialize them + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<InModifier>() + || paramDecl->HasModifier<InOutModifier>() + || !paramDecl->HasModifier<OutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Input); + } + } + + // Generate a local variable for the result, if any + RefPtr<Variable> resultVarDecl; + if (!loweredEntryPointFunc->ReturnType->Equals(ExpressionType::GetVoid())) + { + resultVarDecl = new Variable(); + resultVarDecl->Position = loweredEntryPointFunc->Position; + resultVarDecl->Name.Content = "_main_result"; + resultVarDecl->Type = TypeExp(loweredEntryPointFunc->ReturnType); + + subVisitor.addDecl(resultVarDecl); + } + + // Now generate a call to the entry-point function, using the local variables + auto entryPointDeclRef = makeDeclRef(loweredEntryPointFunc.Ptr()); + + RefPtr<FuncType> entryPointType = new FuncType(); + entryPointType->declRef = entryPointDeclRef; + + RefPtr<VarExpressionSyntaxNode> entryPointRef = new VarExpressionSyntaxNode(); + entryPointRef->name = loweredEntryPointFunc->getName(); + entryPointRef->declRef = entryPointDeclRef; + entryPointRef->Type = QualType(entryPointType); + + RefPtr<InvokeExpressionSyntaxNode> callExpr = new InvokeExpressionSyntaxNode(); + callExpr->FunctionExpr = entryPointRef; + callExpr->Type = QualType(loweredEntryPointFunc->ReturnType); + + // + for (auto paramPair : params) + { + auto localVarDecl = paramPair.lowered; + + RefPtr<VarExpressionSyntaxNode> varRef = new VarExpressionSyntaxNode(); + varRef->name = localVarDecl->getName(); + varRef->declRef = makeDeclRef(localVarDecl.Ptr()); + varRef->Type = QualType(localVarDecl->getType()); + + callExpr->Arguments.Add(varRef); + } + + if (resultVarDecl) + { + // Non-`void` return type, so we need to store it + subVisitor.assign(resultVarDecl, callExpr); + } + else + { + // `void` return type: just call it + subVisitor.addExprStmt(callExpr); + } + + + // Finally, generate logic to copy the outputs to global parameters + for (auto paramPair : params) + { + auto paramDecl = paramPair.original; + if (paramDecl->HasModifier<OutModifier>() + || paramDecl->HasModifier<InOutModifier>()) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.lowered, + paramPair.layout, + VaryingParameterDirection::Output); + } + } + if (resultVarDecl) + { + subVisitor.lowerShaderParameterToGLSLGLobals( + resultVarDecl, + entryPointLayout->resultLayout, + VaryingParameterDirection::Output); + } + + mainDecl->Body = subVisitor.stmtBeingBuilt; + + + // Once we are done building the body, we append our new declaration to the program. + addMember(shared->loweredProgram, mainDecl); + return mainDecl; + +#if 0 + RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode(); + lowerDeclCommon(loweredDecl, entryPointDecl); + + // We create a sub-context appropriate for lowering the function body + + LoweringVisitor subVisitor = *this; + subVisitor.isBuildingStmt = true; + subVisitor.stmtBeingBuilt = nullptr; + + // The parameters of the entry-point function must be translated + // to global-scope declarations + for (auto paramDecl : entryPointDecl->GetParameters()) + { + subVisitor.lowerShaderParameterToGLSLGLobals(paramDecl); + } + + // The output of the function must also be translated into a + // global-scope declaration. + auto loweredReturnType = lowerType(entryPointDecl->ReturnType); + RefPtr<Variable> resultGlobal; + if (!loweredReturnType->Equals(ExpressionType::GetVoid())) + { + resultGlobal = new Variable(); + // TODO: need a scheme for generating unique names + resultGlobal->Name.Content = "_main_result"; + resultGlobal->Type = loweredReturnType; + + addMember(shared->loweredProgram, resultGlobal); + } + + loweredDecl->Name.Content = "main"; + loweredDecl->ReturnType.type = ExpressionType::GetVoid(); + + // We will emit the body statement in a context where + // a `return` statmenet will generate writes to the + // result global that we declared. + subVisitor.resultVariable = resultGlobal; + + auto loweredBody = subVisitor.lowerStmt(entryPointDecl->Body); + subVisitor.addStmt(loweredBody); + + loweredDecl->Body = subVisitor.stmtBeingBuilt; + + // TODO: need to append writes for `out` parameters here... + + addMember(shared->loweredProgram, loweredDecl); + return loweredDecl; +#endif + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + FunctionSyntaxNode* entryPointDecl, + RefPtr<EntryPointLayout> entryPointLayout) + { + switch( getTarget() ) + { + // Default case: lower an entry point just like any other function + default: + return visit(entryPointDecl); + + // For Slang->GLSL translation, we need to lower things from HLSL-style + // declarations over to GLSL conventions + case CodeGenTarget::GLSL: + return lowerEntryPointToGLSL(entryPointDecl, entryPointLayout); + } + } + + RefPtr<FunctionSyntaxNode> lowerEntryPoint( + EntryPointRequest* entryPointRequest) + { + auto entryPointLayout = findEntryPointLayout(entryPointRequest); + auto entryPointDecl = entryPointLayout->entryPoint; + + return lowerEntryPoint( + entryPointDecl, + entryPointLayout); + } + + +}; + +static RefPtr<StructTypeLayout> getGlobalStructLayout( + ProgramLayout* programLayout) +{ + // Layout information for the global scope is either an ordinary + // `struct` in the common case, or a constant buffer in the case + // where there were global-scope uniforms. + auto globalScopeLayout = programLayout->globalScopeLayout; + StructTypeLayout* globalStructLayout = globalScopeLayout.As<StructTypeLayout>(); + if(globalStructLayout) + { } + else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>()) + { + // TODO: the `cbuffer` case really needs to be emitted very + // carefully, but that is beyond the scope of what a simple rewriter + // can easily do (without semantic analysis, etc.). + // + // The crux of the problem is that we need to collect all the + // global-scope uniforms (but not declarations that don't involve + // uniform storage...) and put them in a single `cbuffer` declaration, + // so that we can give it an explicit location. The fields in that + // declaration might use various type declarations, so we'd really + // need to emit all the type declarations first, and that involves + // some large scale reorderings. + // + // For now we will punt and just emit the declarations normally, + // and hope that the global-scope block (`$Globals`) gets auto-assigned + // the same location that we manually asigned it. + + auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout; + auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>(); + + // We expect all constant buffers to contain `struct` types for now + assert(elementTypeStructLayout); + + globalStructLayout = elementTypeStructLayout.Ptr(); + } + else + { + assert(!"unexpected"); + } + return globalStructLayout; +} + + +// Determine if the user is just trying to "rewrite" their input file +// into an output file. This will affect the way we approach code +// generation, because we want to leave their code "as is" whenever +// possible. +bool isRewriteRequest( + SourceLanguage sourceLanguage, + CodeGenTarget target) +{ + // TODO: we might only consider things to be a rewrite request + // in the specific case where checking is turned off... + + switch( target ) + { + default: + return false; + + case CodeGenTarget::HLSL: + return sourceLanguage == SourceLanguage::HLSL; + + case CodeGenTarget::GLSL: + return sourceLanguage == SourceLanguage::GLSL; + } +} + + + +LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target) +{ + SharedLoweringContext sharedContext; + sharedContext.programLayout = programLayout; + sharedContext.target = target; + + auto translationUnit = entryPoint->getTranslationUnit(); + + // Create a single module/program to hold all the lowered code + // (with the exception of instrinsic/stdlib declarations, which + // will be remain where they are) + RefPtr<ProgramSyntaxNode> loweredProgram = new ProgramSyntaxNode(); + sharedContext.loweredProgram = loweredProgram; + + LoweringVisitor visitor; + visitor.shared = &sharedContext; + visitor.parentDecl = loweredProgram; + + // We need to register the lowered program as the lowered version + // of the existing translation unit declaration. + + visitor.registerLoweredDecl( + loweredProgram, + translationUnit->SyntaxNode); + + // We also need to register the lowered program as the lowered version + // of any imported modules (since we will be collecting everything into + // a single module for code generation). + for (auto rr : entryPoint->compileRequest->loadedModulesList) + { + sharedContext.loweredDecls.Add( + rr, + loweredProgram); + } + + // We also want to remember the layout information for + // that declaration, so that we can apply it during emission + attachLayout(loweredProgram, + getGlobalStructLayout(programLayout)); + + + bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target); + sharedContext.isRewrite = isRewrite; + + LoweredEntryPoint result; + if (isRewrite) + { + for (auto dd : translationUnit->SyntaxNode->Members) + { + visitor.translateDeclRef(dd); + } + } + else + { + auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint); + result.entryPoint = loweredEntryPoint; + } + + result.program = sharedContext.loweredProgram; + + return result; +} +} diff --git a/source/slang/lower.h b/source/slang/lower.h new file mode 100644 index 000000000..c690ea025 --- /dev/null +++ b/source/slang/lower.h @@ -0,0 +1,38 @@ +// lower.h +#ifndef SLANG_LOWER_H_INCLUDED +#define SLANG_LOWER_H_INCLUDED + +// The "lowering" step takes an input AST written in the complete Slang +// language and turns it into a more minimal format (still using the +// same AST) suitable for emission into lower-level languages. + +#include "../core/basic.h" + +#include "compiler.h" +#include "syntax.h" + +namespace Slang +{ + class EntryPointRequest; + class ProgramLayout; + class TranslationUnitRequest; + + struct LoweredEntryPoint + { + // The actual lowered entry point + RefPtr<FunctionSyntaxNode> entryPoint; + + // The generated program AST that + // contains the entry point and any + // other declarations it uses + RefPtr<ProgramSyntaxNode> program; + }; + + // Emit code for a single entry point, based on + // the input translation unit. + LoweredEntryPoint lowerEntryPoint( + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target); +} +#endif diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index a1da47fb7..e50288600 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -269,3 +269,7 @@ SIMPLE_SYNTAX_CLASS(HLSLTriangleModifier , HLSLGeometryShaderInputPrimitiveT SIMPLE_SYNTAX_CLASS(HLSLLineAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) SIMPLE_SYNTAX_CLASS(HLSLTriangleAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier) +// A modifier to be attached to syntax after we've computed layout +SYNTAX_CLASS(ComputedLayoutModifier, Modifier) + FIELD(RefPtr<Layout>, layout) +END_SYNTAX_CLASS() diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 2a0c64892..8d008618f 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -767,86 +767,87 @@ SimpleSemanticInfo decomposeSimpleSemantic( return info; } -enum class EntryPointParameterDirection +enum EntryPointParameterDirection { - Input, - Output, + kEntryPointParameterDirection_Input = 0x1, + kEntryPointParameterDirection_Output = 0x2, }; +typedef unsigned int EntryPointParameterDirectionMask; struct EntryPointParameterState { - String* optSemanticName; - int* ioSemanticIndex; - EntryPointParameterDirection direction; - int semanticSlotCount; + String* optSemanticName; + int* ioSemanticIndex; + EntryPointParameterDirectionMask directionMask; + int semanticSlotCount; }; -static void processSimpleEntryPointInput( +static RefPtr<TypeLayout> processSimpleEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, - EntryPointParameterState const& state) + EntryPointParameterState const& inState, + int semanticSlotCount = 1) { - auto optSemanticName = state.optSemanticName; - auto semanticIndex = *state.ioSemanticIndex; - auto semanticSlotCount = state.semanticSlotCount; -} + EntryPointParameterState state = inState; + state.semanticSlotCount = semanticSlotCount; -static void processSimpleEntryPointOutput( - ParameterBindingContext* context, - RefPtr<ExpressionType> type, - EntryPointParameterState const& state) -{ auto optSemanticName = state.optSemanticName; auto semanticIndex = *state.ioSemanticIndex; - auto semanticSlotCount = state.semanticSlotCount; - - if(!optSemanticName) - return; - auto semanticName = *optSemanticName; + String semanticName = optSemanticName ? *optSemanticName : ""; + String sn = semanticName.ToLower(); - // Note: I'm just doing something expedient here and detecting `SV_Target` - // outputs and claiming the appropriate register range right away. - // - // TODO: we should really be building up some representation of all of this, - // once we've gone to the trouble of looking it all up... - if( semanticName.ToLower() == "sv_target" ) + RefPtr<TypeLayout> typeLayout = new TypeLayout(); + if (sn.StartsWith("sv_")) { - context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); - } -} + // System-value semantic. -static void processSimpleEntryPointParameter( - ParameterBindingContext* context, - RefPtr<ExpressionType> type, - EntryPointParameterState const& inState, - int semanticSlotCount = 1) -{ - EntryPointParameterState state = inState; - state.semanticSlotCount = semanticSlotCount; + if (state.directionMask & kEntryPointParameterDirection_Output) + { + // Note: I'm just doing something expedient here and detecting `SV_Target` + // outputs and claiming the appropriate register range right away. + // + // TODO: we should really be building up some representation of all of this, + // once we've gone to the trouble of looking it all up... + if( sn == "sv_target" ) + { + context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); + } + } - switch( state.direction ) + // TODO: add some kind of usage information for system input/output + } + else { - case EntryPointParameterDirection::Input: - processSimpleEntryPointInput(context, type, state); - break; + // user-defined semantic - case EntryPointParameterDirection::Output: - processSimpleEntryPointOutput(context, type, state); - break; + if (state.directionMask & kEntryPointParameterDirection_Input) + { + auto rules = context->layoutRules->getVaryingInputRules(); + SimpleLayoutInfo layout = GetLayout(type, rules); + typeLayout->addResourceUsage(layout.kind, layout.size); + } - SLANG_EXHAUSTIVE_SWITCH() + if (state.directionMask & kEntryPointParameterDirection_Output) + { + auto rules = context->layoutRules->getVaryingOutputRules(); + SimpleLayoutInfo layout = GetLayout(type, rules); + typeLayout->addResourceUsage(layout.kind, layout.size); + } } *state.ioSemanticIndex += state.semanticSlotCount; + typeLayout->type = type; + + return typeLayout; } -static void processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, EntryPointParameterState const& state); -static void processEntryPointParameterWithPossibleSemantic( +static RefPtr<TypeLayout> processEntryPointParameterWithPossibleSemantic( ParameterBindingContext* context, Decl* declForSemantic, RefPtr<ExpressionType> type, @@ -873,11 +874,11 @@ static void processEntryPointParameterWithPossibleSemantic( // *or* we couldn't find an explicit semantic to apply on the given // declaration, so we will just recursive with whatever we have at // the moment. - processEntryPointParameter(context, type, state); + return processEntryPointParameter(context, type, state); } -static void processEntryPointParameter( +static RefPtr<TypeLayout> processEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, EntryPointParameterState const& state) @@ -885,31 +886,50 @@ static void processEntryPointParameter( // Scalar and vector types are treated as outputs directly if(auto basicType = type->As<BasicExpressionType>()) { - processSimpleEntryPointParameter(context, basicType, state); + return processSimpleEntryPointParameter(context, basicType, state); } else if(auto basicType = type->As<VectorExpressionType>()) { - processSimpleEntryPointParameter(context, basicType, state); + return processSimpleEntryPointParameter(context, basicType, state); } // A matrix is processed as if it was an array of rows else if( auto matrixType = type->As<MatrixExpressionType>() ) { auto rowCount = GetIntVal(matrixType->getRowCount()); - processSimpleEntryPointParameter(context, basicType, state, (int) rowCount); + return processSimpleEntryPointParameter(context, basicType, state, (int) rowCount); } else if( auto arrayType = type->As<ArrayExpressionType>() ) { + // Note: Bad Things will happen if we have an array input + // without a semantic already being enforced. + auto elementCount = GetIntVal(arrayType->ArrayLength); - for( int ii = 0; ii < elementCount; ++ii ) + // We use the first element to derive the layout for the element type + auto elementTypeLayout = processEntryPointParameter(context, arrayType->BaseType, state); + + // We still walk over subsequent elements to make sure they consume resources + // as needed + for( int ii = 1; ii < elementCount; ++ii ) { processEntryPointParameter(context, arrayType->BaseType, state); } + + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; + + for (auto rr : elementTypeLayout->resourceInfos) + { + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount; + } + + return arrayTypeLayout; } // Ignore a bunch of types that don't make sense here... - else if(auto textureType = type->As<TextureType>()) {} - else if(auto samplerStateType = type->As<SamplerStateType>()) {} - else if(auto constantBufferType = type->As<ConstantBufferType>()) {} + else if (auto textureType = type->As<TextureType>()) { return nullptr; } + else if(auto samplerStateType = type->As<SamplerStateType>()) { return nullptr; } + else if(auto constantBufferType = type->As<ConstantBufferType>()) { return nullptr; } // Catch declaration-reference types late in the sequence, since // otherwise they will include all of the above cases... else if( auto declRefType = type->As<DeclRefType>() ) @@ -918,15 +938,36 @@ static void processEntryPointParameter( if (auto structDeclRef = declRef.As<StructSyntaxNode>()) { + RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); + structLayout->type = type; + // Need to recursively walk the fields of the structure now... for( auto field : GetFields(structDeclRef) ) { - processEntryPointParameterWithPossibleSemantic( + auto fieldTypeLayout = processEntryPointParameterWithPossibleSemantic( context, field.getDecl(), GetType(field), state); + + RefPtr<VarLayout> fieldVarLayout = new VarLayout(); + fieldVarLayout->varDecl = field; + fieldVarLayout->typeLayout = fieldTypeLayout; + + for (auto rr : fieldTypeLayout->resourceInfos) + { + assert(rr.count != 0); + + auto structRes = structLayout->findOrAddResourceInfo(rr.kind); + fieldVarLayout->findOrAddResourceInfo(rr.kind)->index = structRes->count; + structRes->count += rr.count; + } + + structLayout->fields.Add(fieldVarLayout); + structLayout->mapVarToLayout.Add(field.getDecl(), fieldVarLayout); } + + return structLayout; } else { @@ -937,6 +978,9 @@ static void processEntryPointParameter( { assert(!"unimplemented"); } + + assert(!"unexpected"); + return nullptr; } static void collectEntryPointParameters( @@ -998,42 +1042,68 @@ static void collectEntryPointParameters( // We have an entry-point parameter, and need to figure out what to do with it. + // TODO: need to handle `uniform`-qualified parameters here + if (paramDecl->HasModifier<UniformModifier>()) + continue; + + state.directionMask = 0; + // If it appears to be an input, process it as such. if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() ) { - state.direction = EntryPointParameterDirection::Input; - - processEntryPointParameterWithPossibleSemantic( - context, - paramDecl.Ptr(), - paramDecl->Type.type, - state); + state.directionMask |= kEntryPointParameterDirection_Input; } // If it appears to be an output, process it as such. if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) { - state.direction = EntryPointParameterDirection::Output; + state.directionMask |= kEntryPointParameterDirection_Output; + } + + auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic( + context, + paramDecl.Ptr(), + paramDecl->Type.type, + state); - processEntryPointParameterWithPossibleSemantic( - context, - paramDecl.Ptr(), - paramDecl->Type.type, - state); + RefPtr<VarLayout> paramVarLayout = new VarLayout(); + paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr()); + paramVarLayout->typeLayout = paramTypeLayout; + + for (auto rr : paramTypeLayout->resourceInfos) + { + auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); + paramVarLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count; + entryPointRes->count += rr.count; } + + entryPointLayout->fields.Add(paramVarLayout); + entryPointLayout->mapVarToLayout.Add(paramDecl, paramVarLayout); } // If we can find an output type for the entry point, then process it as // an output parameter. if( auto resultType = entryPointFuncDecl->ReturnType.type ) { - state.direction = EntryPointParameterDirection::Output; + state.directionMask = kEntryPointParameterDirection_Output; - processEntryPointParameterWithPossibleSemantic( + auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic( context, entryPointFuncDecl, resultType, state); + + RefPtr<VarLayout> resultLayout = new VarLayout(); + resultLayout->typeLayout = resultTypeLayout; + + for (auto rr : resultTypeLayout->resourceInfos) + { + auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind); + resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count; + entryPointRes->count += rr.count; + } + + entryPointLayout->resultLayout = resultLayout; } } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 4f622ada6..b134f9645 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2617,10 +2617,13 @@ namespace Slang RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode(); + RefPtr<BlockStmt> blockStatement = new BlockStmt(); blockStatement->scopeDecl = scopeDecl; PushScope(scopeDecl.Ptr()); ReadToken(TokenType::LBrace); + + RefPtr<StatementSyntaxNode> body; + if(!tokenReader.IsAtEnd()) { FillPosition(blockStatement.Ptr()); @@ -2630,11 +2633,29 @@ namespace Slang auto stmt = ParseStatement(); if(stmt) { - blockStatement->Statements.Add(stmt); + if (!body) + { + body = stmt; + } + else if (auto seqStmt = body.As<SeqStmt>()) + { + seqStmt->stmts.Add(stmt); + } + else + { + RefPtr<SeqStmt> newBody = new SeqStmt(); + newBody->Position = blockStatement->Position; + newBody->stmts.Add(body); + newBody->stmts.Add(stmt); + + body = newBody; + } } TryRecover(this); } PopScope(); + + blockStatement->body = body; return blockStatement; } @@ -3054,7 +3075,19 @@ namespace Slang right = parseInfixExprWithPrecedence(parser, right, nextOpPrec); } - expr = createInfixExpr(parser, expr, op, right); + if (opTokenType == TokenType::OpAssign) + { + RefPtr<AssignExpr> assignExpr = new AssignExpr(); + assignExpr->Position = op->Position; + assignExpr->left = expr; + assignExpr->right = right; + + expr = assignExpr; + } + else + { + expr = createInfixExpr(parser, expr, op, right); + } } return expr; } diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 3fde10e4b..441bc5e45 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -178,6 +178,7 @@ <ClInclude Include="modifier-defs.h" /> <ClInclude Include="object-meta-begin.h" /> <ClInclude Include="object-meta-end.h" /> + <ClInclude Include="lower.h" /> <ClInclude Include="parameter-binding.h" /> <ClInclude Include="parser.h" /> <ClInclude Include="preprocessor.h" /> @@ -205,6 +206,7 @@ <ClCompile Include="emit.cpp" /> <ClCompile Include="lexer.cpp" /> <ClCompile Include="lookup.cpp" /> + <ClCompile Include="lower.cpp" /> <ClCompile Include="options.cpp" /> <ClCompile Include="parameter-binding.cpp" /> <ClCompile Include="parser.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 0ed4457dc..21b533a10 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -37,6 +37,7 @@ <ClInclude Include="type-defs.h" /> <ClInclude Include="val-defs.h" /> <ClInclude Include="visitor.h" /> + <ClInclude Include="lower.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp" /> @@ -56,5 +57,6 @@ <ClCompile Include="token.cpp" /> <ClCompile Include="type-layout.cpp" /> <ClCompile Include="options.cpp" /> + <ClCompile Include="lower.cpp" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/source/slang/stmt-defs.h b/source/slang/stmt-defs.h index 9cc1978bd..c5bcda63b 100644 --- a/source/slang/stmt-defs.h +++ b/source/slang/stmt-defs.h @@ -6,8 +6,14 @@ ABSTRACT_SYNTAX_CLASS(ScopeStmt, StatementSyntaxNode) SYNTAX_FIELD(RefPtr<ScopeDecl>, scopeDecl) END_SYNTAX_CLASS() -SYNTAX_CLASS(BlockStatementSyntaxNode, ScopeStmt) - SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, Statements) +// A sequence of statements, treated as a single statement +SYNTAX_CLASS(SeqStmt, StatementSyntaxNode) + SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, stmts) +END_SYNTAX_CLASS() + +// The simplest kind of scope statement: just a `{...}` block +SYNTAX_CLASS(BlockStmt, ScopeStmt) + SYNTAX_FIELD(RefPtr<StatementSyntaxNode>, body); END_SYNTAX_CLASS() SYNTAX_CLASS(UnparsedStmt, StatementSyntaxNode) diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 983f18bb7..12d8c90bb 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1,8 +1,6 @@ #include "syntax.h" -#pragma warning AAA #include "visitor.h" -#pragma warning BBB #include <typeinfo> #include <assert.h> @@ -55,9 +53,6 @@ namespace Slang return res.ProduceString(); } -#pragma warning CCC - - // Generate dispatch logic and other definitions for all syntax classes #define SYNTAX_CLASS(NAME, BASE) /* empty */ #include "object-meta-begin.h" @@ -79,8 +74,6 @@ namespace Slang #include "object-meta-end.h" -#pragma warning DDD - void ExpressionType::accept(IValVisitor* visitor, void* extra) { accept((ITypeVisitor*)visitor, extra); @@ -881,9 +874,20 @@ void ExpressionType::accept(IValVisitor* visitor, void* extra) auto parentDecl = decl->ParentDecl; if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) { - // We need to strip away one layer of specialization - assert(substitutions); - return DeclRefBase(parentGeneric, substitutions->outer); + if (substitutions && substitutions->genericDecl == parentDecl) + { + // We strip away the specializations that were applied to + // the parent, since we were asked for a reference *to* the parent. + return DeclRefBase(parentGeneric, substitutions->outer); + } + else + { + // Either we don't have specializations, or the inner-most + // specializations didn't apply to the parent decl. This + // can happen if we are looking at an unspecialized + // declaration that is a child of a generic. + return DeclRefBase(parentGeneric, substitutions); + } } else { diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 47427b130..6587e18c3 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -14,6 +14,7 @@ namespace Slang class Substitutions; class SyntaxVisitor; class FunctionSyntaxNode; + class Layout; struct IExprVisitor; struct IDeclVisitor; @@ -347,9 +348,15 @@ namespace Slang { return DeclRef<ContainerDecl>::unsafeInit(DeclRefBase::GetParent()); } - }; + + template<typename T> + inline DeclRef<T> makeDeclRef(T* decl) + { + return DeclRef<T>(decl, nullptr); + } + template<typename T> struct FilteredMemberList { diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index 07867fddb..ce1f8864d 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -143,8 +143,13 @@ struct SimpleArrayLayoutInfo : SimpleLayoutInfo struct LayoutRulesImpl; +// Base class for things that store layout info +class Layout : public RefObject +{ +}; + // A reified reprsentation of a particular laid-out type -class TypeLayout : public RefObject +class TypeLayout : public Layout { public: // The type that was laid out @@ -215,7 +220,7 @@ enum VarLayoutFlag : VarLayoutFlags }; // A reified layout for a particular variable, field, etc. -class VarLayout : public RefObject +class VarLayout : public Layout { public: // The variable we are laying out @@ -324,7 +329,14 @@ public: // Layout information for a single shader entry point // within a program -class EntryPointLayout : public RefObject +// +// Treated as a subclass of `StructTypeLayout` becase +// it needs to include computed layout information +// for the parameters of the entry point. +// +// TODO: where to store layout info for the return +// type of the function? +class EntryPointLayout : public StructTypeLayout { public: // The corresponding function declaration @@ -332,10 +344,13 @@ public: // The shader profile that was used to compile the entry point Profile profile; + + // Layout for any results of the entry point + RefPtr<VarLayout> resultLayout; }; // Layout information for the global scope of a program -class ProgramLayout : public RefObject +class ProgramLayout : public Layout { public: // We store a layout for the declarations at the global @@ -359,14 +374,6 @@ public: List<RefPtr<EntryPointLayout>> entryPoints; }; -// A modifier to be attached to syntax after we've computed layout -class ComputedLayoutModifier : public Modifier -{ -public: - RefPtr<TypeLayout> typeLayout; -}; - - struct LayoutRulesFamilyImpl; // A delineation of shader parameter types into fine-grained diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index 316981842..9ebfc9872 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -3,7 +3,8 @@ // Syntax class definitions for compile-time values. // A compile-time integer (may not have a specific concrete value) -SIMPLE_SYNTAX_CLASS(IntVal, Val) +ABSTRACT_SYNTAX_CLASS(IntVal, Val) +END_SYNTAX_CLASS() // Trivial case of a value that is just a constant integer SYNTAX_CLASS(ConstantIntVal, IntVal) |
