summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-07-06 09:52:53 -0700
committerGitHub <noreply@github.com>2017-07-06 09:52:53 -0700
commit21a14cb4e0d578bc4f8a460016269a1199cac0da (patch)
tree88a04619ceaaa37b87199dd82334cc9d102c156d
parentf313df379dd9e0d4395f072ffb87016a6f20d5a1 (diff)
parentf145e09a6dcbcf326f782b3e6a76dbf291c792cf (diff)
Merge pull request #53 from tfoleyNV/cross-compilation
Initial work on cross-compilation
-rw-r--r--slang.h1
-rw-r--r--source/slang/check.cpp29
-rw-r--r--source/slang/compiler.cpp73
-rw-r--r--source/slang/emit.cpp435
-rw-r--r--source/slang/emit.h9
-rw-r--r--source/slang/expr-defs.h5
-rw-r--r--source/slang/lower.cpp1601
-rw-r--r--source/slang/lower.h38
-rw-r--r--source/slang/modifier-defs.h4
-rw-r--r--source/slang/parameter-binding.cpp220
-rw-r--r--source/slang/parser.cpp39
-rw-r--r--source/slang/slang.vcxproj2
-rw-r--r--source/slang/slang.vcxproj.filters2
-rw-r--r--source/slang/stmt-defs.h10
-rw-r--r--source/slang/syntax.cpp24
-rw-r--r--source/slang/syntax.h9
-rw-r--r--source/slang/type-layout.h31
-rw-r--r--source/slang/val-defs.h3
-rw-r--r--tests/render/cross-compile-entry-point.slang89
-rw-r--r--tools/render-test/render-d3d11.cpp4
20 files changed, 2274 insertions, 354 deletions
diff --git a/slang.h b/slang.h
index 91b9f5931..93fdefab7 100644
--- a/slang.h
+++ b/slang.h
@@ -901,6 +901,7 @@ namespace slang
#include "source/slang/parser.cpp"
#include "source/slang/preprocessor.cpp"
#include "source/slang/lookup.cpp"
+#include "source/slang/lower.cpp"
#include "source/slang/check.cpp"
#include "source/slang/compiler.cpp"
#include "source/slang/slang-stdlib.cpp"
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index a3e061a3b..a79af3f37 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -1581,11 +1581,16 @@ namespace Slang
DeclVisitor::dispatch(stmt->decl);
}
- void visit(BlockStatementSyntaxNode *stmt)
+ void visit(BlockStmt* stmt)
{
- for (auto & node : stmt->Statements)
+ checkStmt(stmt->body);
+ }
+
+ void visit(SeqStmt* stmt)
+ {
+ for(auto ss : stmt->stmts)
{
- checkStmt(node);
+ checkStmt(ss);
}
}
@@ -2414,6 +2419,24 @@ namespace Slang
return appExpr;
}
+ //
+
+ RefPtr<ExpressionSyntaxNode> visit(AssignExpr* expr)
+ {
+ expr->left = CheckExpr(expr->left);
+
+ auto type = expr->left->Type;
+
+ expr->right = Coerce(type, CheckTerm(expr->right));
+
+ if (!type.IsLeftValue)
+ {
+ getSink()->diagnose(expr, Diagnostics::assignNonLValue);
+ }
+ expr->Type = type;
+ return expr;
+ }
+
//
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 9132624f9..19623a7aa 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -55,10 +55,11 @@ namespace Slang
//
- String emitHLSLForTranslationUnit(
- TranslationUnitRequest* translationUnit)
+ String emitHLSLForEntryPoint(
+ EntryPointRequest* entryPoint)
{
- auto compileRequest = translationUnit->compileRequest;
+ auto compileRequest = entryPoint->compileRequest;
+ auto translationUnit = entryPoint->getTranslationUnit();
if (compileRequest->passThrough != PassThroughMode::None)
{
// Generate a string that includes the content of
@@ -92,8 +93,8 @@ namespace Slang
}
else
{
- return emitProgram(
- translationUnit->SyntaxNode.Ptr(),
+ return emitEntryPoint(
+ entryPoint,
compileRequest->layout.Ptr(),
CodeGenTarget::HLSL);
}
@@ -137,8 +138,8 @@ namespace Slang
{
// TODO(tfoley): need to pass along the entry point
// so that we properly emit it as the `main` function.
- return emitProgram(
- translationUnit->SyntaxNode.Ptr(),
+ return emitEntryPoint(
+ entryPoint,
compileRequest->layout.Ptr(),
CodeGenTarget::GLSL);
}
@@ -180,15 +181,7 @@ namespace Slang
assert(D3DCompile_);
}
- // The HLSL compiler will try to "canonicalize" our input file path,
- // and we don't want it to do that, because they it won't report
- // the same locations on error messages that we would.
- //
- // To work around that, we prepend a custom `#line` directive.
-
- auto translationUnit = entryPoint->getTranslationUnit();
-
- auto hlslCode = emitHLSLForTranslationUnit(translationUnit);
+ auto hlslCode = emitHLSLForEntryPoint(entryPoint);
ID3DBlob* codeBlob;
ID3DBlob* diagnosticsBlob;
@@ -223,6 +216,7 @@ namespace Slang
if (FAILED(hr))
{
// TODO(tfoley): What to do on failure?
+ exit(1);
}
return data;
}
@@ -400,6 +394,13 @@ namespace Slang
switch (compileRequest->Target)
{
+ case CodeGenTarget::HLSL:
+ {
+ String code = emitHLSLForEntryPoint(entryPoint);
+ result.outputSource = code;
+ }
+ break;
+
case CodeGenTarget::GLSL:
{
String code = emitGLSLForEntryPoint(entryPoint);
@@ -505,45 +506,7 @@ namespace Slang
TranslationUnitResult emitTranslationUnit(
TranslationUnitRequest* translationUnit)
{
- auto compileRequest = translationUnit->compileRequest;
-
- // Most of our code generation targets will require us
- // to proceed through one entry point at a time, but
- // in some cases we can emit an entire translation unit
- // in one go.
-
- switch (compileRequest->Target)
- {
- default:
- // The default behavior is going to loop over all the entry
- // points, and then collect an aggregate result.
- return emitTranslationUnitEntryPoints(translationUnit);
-
- case CodeGenTarget::HLSL:
- // When targetting HLSL, we can emit the entire translation unit
- // as a single HLSL program, and include all the entry points.
- {
-
- String hlsl = emitHLSLForTranslationUnit(translationUnit);
-
- TranslationUnitResult result;
- result.outputSource = hlsl;
-
- // Because the user might ask for per-entry-point source,
- // we will just attach the same string as the result for
- // each entry point.
- for( auto& entryPoint : translationUnit->entryPoints )
- {
- EntryPointResult entryPointResult;
- entryPointResult.outputSource = hlsl;
-
- entryPoint->result = entryPointResult;
- }
-
- return result;
- }
- break;
- }
+ return emitTranslationUnitEntryPoints(translationUnit);
}
#if 0
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index dc7c68aa4..71b39aabb 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -1,6 +1,7 @@
// emit.cpp
#include "emit.h"
+#include "lower.h"
#include "syntax.h"
#include "type-layout.h"
@@ -13,8 +14,16 @@
namespace Slang {
-struct EmitContext
+// Shared state for an entire emit session
+struct SharedEmitContext
{
+ // The target language we want to generate code for
+ CodeGenTarget target;
+
+ // A set of words reserved by the target
+ Dictionary<String, String> reservedWords;
+
+ // The string of code we've built so far
StringBuilder sb;
// Current source position for tracking purposes...
@@ -22,12 +31,6 @@ struct EmitContext
CodePosition nextSourceLocation;
bool needToUpdateSourceLocation;
- // The target language we want to generate code for
- CodeGenTarget target;
-
- // A set of words reserved by the target
- Dictionary<String, String> reservedWords;
-
// For GLSL output, we can't emit traidtional `#line` directives
// with a file path in them, so we maintain a map that associates
// each path with a unique integer, and then we output those
@@ -45,6 +48,18 @@ struct EmitContext
// TODO: This will probably change if we represent imports
// explicitly in the layout data.
StructTypeLayout* globalStructLayout;
+
+ ProgramLayout* programLayout;
+};
+
+struct EmitContext
+{
+ // The shared context that is in effect
+ SharedEmitContext* shared;
+
+ // Are we in "rewrite" mode, where we are trying to reproduce the input
+ // code as closely as posible?
+ bool isRewrite;
};
//
@@ -79,7 +94,7 @@ static void emitRawTextSpan(EmitContext* context, char const* textBegin, char co
// TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things...
auto len = int(textEnd - textBegin);
- context->sb.Append(textBegin, len);
+ context->shared->sb.Append(textBegin, len);
}
static void emitRawText(EmitContext* context, char const* text)
@@ -99,7 +114,7 @@ static void emitTextSpan(EmitContext* context, char const* textBegin, char const
// Update our logical position
// TODO(tfoley): Need to make "corelib" not use `int` for pointer-sized things...
auto len = int(textEnd - textBegin);
- context->loc.Col += len;
+ context->shared->loc.Col += len;
}
static void Emit(EmitContext* context, char const* textBegin, char const* textEnd)
@@ -123,8 +138,8 @@ static void Emit(EmitContext* context, char const* textBegin, char const* textEn
// At the end of a line, we need to update our tracking
// information on code positions
emitTextSpan(context, spanBegin, spanEnd);
- context->loc.Line++;
- context->loc.Col = 1;
+ context->shared->loc.Line++;
+ context->shared->loc.Col = 1;
// Start a new span for emit purposes
spanBegin = spanEnd;
@@ -144,7 +159,7 @@ static void emit(EmitContext* context, String const& text)
static bool isReservedWord(EmitContext* context, String const& name)
{
- return context->reservedWords.TryGetValue(name) != nullptr;
+ return context->shared->reservedWords.TryGetValue(name) != nullptr;
}
static void emitName(
@@ -449,7 +464,7 @@ static bool isTargetIntrinsicModifierApplicable(
// we expect.
auto const& targetName = targetToken.Content;
- switch(context->target)
+ switch(context->shared->target)
{
default:
assert(!"unexpected");
@@ -658,7 +673,7 @@ static void emitCallExpr(
case IntrinsicOp::InnerProduct_Vector_Vector:
// HLSL allows `mul()` to be used as a synonym for `dot()`,
// so we need to translate to `dot` for GLSL
- if (context->target == CodeGenTarget::GLSL)
+ if (context->shared->target == CodeGenTarget::GLSL)
{
Emit(context, "dot(");
EmitExpr(context, callExpr->Arguments[0]);
@@ -677,7 +692,7 @@ static void emitCallExpr(
//
// The other critical detail here is that the way we handle matrix
// conventions requires that the operands to the product be swapped.
- if (context->target == CodeGenTarget::GLSL)
+ if (context->shared->target == CodeGenTarget::GLSL)
{
Emit(context, "((");
EmitExpr(context, callExpr->Arguments[1]);
@@ -825,6 +840,13 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax
Emit(context, " : ");
EmitExprWithPrecedence(context, selectExpr->Arguments[2], kPrecedence_Conditional);
}
+ else if (auto assignExpr = expr.As<AssignExpr>())
+ {
+ needClose = MaybeEmitParens(context, outerPrec, kPrecedence_Assign);
+ EmitExprWithPrecedence(context, assignExpr->left, kPrecedence_Assign);
+ Emit(context, " = ");
+ EmitExprWithPrecedence(context, assignExpr->right, kPrecedence_Assign);
+ }
else if (auto callExpr = expr.As<InvokeExpressionSyntaxNode>())
{
emitCallExpr(context, callExpr, outerPrec);
@@ -970,7 +992,7 @@ static void EmitExprWithPrecedence(EmitContext* context, RefPtr<ExpressionSyntax
}
else if (auto castExpr = expr.As<TypeCastExpressionSyntaxNode>())
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::GLSL:
// GLSL requires constructor syntax for all conversions
@@ -1227,7 +1249,7 @@ static void emitTextureType(
EmitContext* context,
RefPtr<TextureType> texType)
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::HLSL:
emitHLSLTextureType(context, texType);
@@ -1247,7 +1269,7 @@ static void emitTextureSamplerType(
EmitContext* context,
RefPtr<TextureSamplerType> type)
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::GLSL:
emitGLSLTextureSamplerType(context, type);
@@ -1263,7 +1285,7 @@ static void emitImageType(
EmitContext* context,
RefPtr<GLSLImageType> type)
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::HLSL:
emitHLSLTextureType(context, type);
@@ -1300,7 +1322,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara
}
else if (auto vecType = type->As<VectorExpressionType>())
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::GLSL:
case CodeGenTarget::GLSL_Vulkan:
@@ -1331,7 +1353,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara
}
else if (auto matType = type->As<MatrixExpressionType>())
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::GLSL:
case CodeGenTarget::GLSL_Vulkan:
@@ -1386,7 +1408,7 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type, EDeclara
}
else if (auto samplerStateType = type->As<SamplerStateType>())
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::HLSL:
default:
@@ -1474,7 +1496,8 @@ static void EmitType(EmitContext* context, RefPtr<ExpressionType> type)
static void EmitType(EmitContext* context, TypeExp const& typeExp, Token const& nameToken)
{
- EmitType(context, typeExp.type, typeExp.exp->Position, nameToken.Content, nameToken.Position);
+ EmitType(context, typeExp.type,
+ typeExp.exp ? typeExp.exp->Position : CodePosition(), nameToken.Content, nameToken.Position);
}
static void EmitType(EmitContext* context, TypeExp const& typeExp, String const& name)
@@ -1497,12 +1520,9 @@ static void EmitBlockStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt
{
// TODO(tfoley): support indenting
Emit(context, "{\n");
- if( auto blockStmt = stmt.As<BlockStatementSyntaxNode>() )
+ if( auto blockStmt = stmt.As<BlockStmt>() )
{
- for (auto s : blockStmt->Statements)
- {
- EmitStmt(context, s);
- }
+ EmitStmt(context, blockStmt->body);
}
else
{
@@ -1544,7 +1564,7 @@ static void emitLineDirective(
emitRawText(context, " ");
- if(context->target == CodeGenTarget::GLSL)
+ if(context->shared->target == CodeGenTarget::GLSL)
{
auto path = sourceLocation.FileName;
@@ -1556,10 +1576,10 @@ static void emitLineDirective(
// extension and then emit a traditional line directive.
int id = 0;
- if(!context->mapGLSLSourcePathToID.TryGetValue(path, id))
+ if(!context->shared->mapGLSLSourcePathToID.TryGetValue(path, id))
{
- id = context->glslSourceIDCount++;
- context->mapGLSLSourcePathToID.Add(path, id);
+ id = context->shared->glslSourceIDCount++;
+ context->shared->mapGLSLSourcePathToID.Add(path, id);
}
sprintf(buffer, "%d", id);
@@ -1611,9 +1631,9 @@ static void emitLineDirectiveAndUpdateSourceLocation(
{
emitLineDirective(context, sourceLocation);
- context->loc.FileName = sourceLocation.FileName;
- context->loc.Line = sourceLocation.Line;
- context->loc.Col = 1;
+ context->shared->loc.FileName = sourceLocation.FileName;
+ context->shared->loc.Line = sourceLocation.Line;
+ context->shared->loc.Col = 1;
}
static void emitLineDirectiveIfNeeded(
@@ -1628,24 +1648,24 @@ static void emitLineDirectiveIfNeeded(
// a differnet file or line, *or* if the source location is
// somehow later on the line than what we want to emit,
// then we need to emit a new `#line` directive.
- if(sourceLocation.FileName != context->loc.FileName
- || sourceLocation.Line != context->loc.Line
- || sourceLocation.Col < context->loc.Col)
+ if(sourceLocation.FileName != context->shared->loc.FileName
+ || sourceLocation.Line != context->shared->loc.Line
+ || sourceLocation.Col < context->shared->loc.Col)
{
// Special case: if we are in the same file, and within a small number
// of lines of the target location, then go ahead and output newlines
// to get us caught up.
enum { kSmallLineCount = 3 };
- auto lineDiff = sourceLocation.Line - context->loc.Line;
- if(sourceLocation.FileName == context->loc.FileName
- && sourceLocation.Line > context->loc.Line
+ auto lineDiff = sourceLocation.Line - context->shared->loc.Line;
+ if(sourceLocation.FileName == context->shared->loc.FileName
+ && sourceLocation.Line > context->shared->loc.Line
&& lineDiff <= kSmallLineCount)
{
for(int ii = 0; ii < lineDiff; ++ii )
{
Emit(context, "\n");
}
- assert(sourceLocation.Line == context->loc.Line);
+ assert(sourceLocation.Line == context->shared->loc.Line);
}
else
{
@@ -1661,14 +1681,14 @@ static void emitLineDirectiveIfNeeded(
// came in as spaces or tabs, so there is necessarily going to be
// coupling between how the downstream compiler counts columns,
// and how we do.
- if(sourceLocation.Col > context->loc.Col)
+ if(sourceLocation.Col > context->shared->loc.Col)
{
- int delta = sourceLocation.Col - context->loc.Col;
+ int delta = sourceLocation.Col - context->shared->loc.Col;
for( int ii = 0; ii < delta; ++ii )
{
emitRawText(context, " ");
}
- context->loc.Col = sourceLocation.Col;
+ context->shared->loc.Col = sourceLocation.Col;
}
}
@@ -1680,22 +1700,22 @@ static void advanceToSourceLocation(
if(sourceLocation.Line <= 0)
return;
- context->needToUpdateSourceLocation = true;
- context->nextSourceLocation = sourceLocation;
+ context->shared->needToUpdateSourceLocation = true;
+ context->shared->nextSourceLocation = sourceLocation;
}
static void flushSourceLocationChange(
EmitContext* context)
{
- if(!context->needToUpdateSourceLocation)
+ if(!context->shared->needToUpdateSourceLocation)
return;
// Note: the order matters here, because trying to update
// the source location may involve outputting text that
// advances the location, and outputting text is what
// triggers this flush operation.
- context->needToUpdateSourceLocation = false;
- emitLineDirectiveIfNeeded(context, context->nextSourceLocation);
+ context->shared->needToUpdateSourceLocation = false;
+ emitLineDirectiveIfNeeded(context, context->shared->nextSourceLocation);
}
static void emitTokenWithLocation(EmitContext* context, Token const& token)
@@ -1737,11 +1757,19 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt)
// Try to ensure that debugging can find the right location
advanceToSourceLocation(context, stmt->Position);
- if (auto blockStmt = stmt.As<BlockStatementSyntaxNode>())
+ if (auto blockStmt = stmt.As<BlockStmt>())
{
EmitBlockStmt(context, blockStmt);
return;
}
+ else if (auto seqStmt = stmt.As<SeqStmt>())
+ {
+ for (auto ss : seqStmt->stmts)
+ {
+ EmitStmt(context, ss);
+ }
+ return;
+ }
else if( auto unparsedStmt = stmt.As<UnparsedStmt>() )
{
EmitUnparsedStmt(context, unparsedStmt);
@@ -1784,17 +1812,43 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt)
}
else if (auto forStmt = stmt.As<ForStatementSyntaxNode>())
{
- EmitLoopAttributes(context, forStmt);
+ // We are going to always take a `for` loop like:
+ //
+ // for(A; B; C) { D }
+ //
+ // and emit it as:
+ //
+ // { A; for(; B; C) { D } }
+ //
+ // This ensures that we are robust against any kind
+ // of statement appearing in `A`, including things
+ // that might occur due to lowering steps.
+ //
- Emit(context, "for(");
- if (auto initStmt = forStmt->InitialStatement)
+ // The one wrinkle is that HLSL implements the
+ // bad approach to scoping a `for` loop variable,
+ // so we need to avoid those outer `{...}` when
+ // we are generating HLSL via "rewrite" (that is,
+ // without our semantic checks).
+ //
+ bool brokenScoping = false;
+ if (context->shared->target == CodeGenTarget::HLSL
+ && context->isRewrite)
{
- EmitStmt(context, initStmt);
+ brokenScoping = true;
}
- else
+
+ auto initStmt = forStmt->InitialStatement;
+ if(initStmt)
{
- Emit(context, ";");
+ if(!brokenScoping)
+ Emit(context, "{\n");
+ EmitStmt(context, initStmt);
}
+
+ EmitLoopAttributes(context, forStmt);
+
+ Emit(context, "for(;");
if (auto testExp = forStmt->PredicateExpression)
{
EmitExpr(context, testExp);
@@ -1806,6 +1860,13 @@ static void EmitStmt(EmitContext* context, RefPtr<StatementSyntaxNode> stmt)
}
Emit(context, ")\n");
EmitBlockStmt(context, forStmt->Statement);
+
+ if (initStmt)
+ {
+ if(!brokenScoping)
+ Emit(context, "}\n");
+ }
+
return;
}
else if (auto whileStmt = stmt.As<WhileStatementSyntaxNode>())
@@ -2085,7 +2146,7 @@ static void EmitSemantic(EmitContext* context, RefPtr<HLSLSemantic> semantic, ES
static void EmitSemantics(EmitContext* context, RefPtr<Decl> decl, ESemanticMask mask = kESemanticMask_Default )
{
// Don't emit semantics if we aren't translating down to HLSL
- switch (context->target)
+ switch (context->shared->target)
{
case CodeGenTarget::HLSL:
break;
@@ -2263,7 +2324,7 @@ static void emitHLSLRegisterSemantics(
{
if (!layout) return;
- switch( context->target )
+ switch( context->shared->target )
{
default:
return;
@@ -2278,6 +2339,25 @@ static void emitHLSLRegisterSemantics(
}
}
+static RefPtr<VarLayout> maybeFetchLayout(
+ RefPtr<Decl> decl,
+ RefPtr<VarLayout> layout)
+{
+ // If we have already found layout info, don't go searching
+ if (layout) return layout;
+
+ // Otherwise, we need to look and see if computed layout
+ // information has been attached to the declaration.
+ auto modifier = decl->FindModifier<ComputedLayoutModifier>();
+ if (!modifier) return nullptr;
+
+ auto computedLayout = modifier->layout;
+ assert(computedLayout);
+
+ auto varLayout = computedLayout.As<VarLayout>();
+ return varLayout;
+}
+
static void emitHLSLParameterBlockDecl(
EmitContext* context,
RefPtr<VarDeclBase> varDecl,
@@ -2292,6 +2372,7 @@ static void emitHLSLParameterBlockDecl(
assert(declRefType);
// We expect to always have layout information
+ layout = maybeFetchLayout(varDecl, layout);
assert(layout);
// We expect the layout to be for a structured type...
@@ -2325,13 +2406,16 @@ static void emitHLSLParameterBlockDecl(
Emit(context, "\n{\n");
if (auto structRef = declRefType->declRef.As<StructSyntaxNode>())
{
+ int fieldCounter = 0;
+
for (auto field : getMembersOfType<StructField>(structRef))
{
+ int fieldIndex = fieldCounter++;
+
EmitVarDeclCommon(context, field);
- RefPtr<VarLayout> fieldLayout;
- structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout);
- assert(fieldLayout);
+ RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex];
+ assert(fieldLayout->varDecl.GetName() == field.GetName());
// Emit explicit layout annotations for every field
for( auto rr : fieldLayout->resourceInfos )
@@ -2415,7 +2499,7 @@ emitGLSLLayoutQualifiers(
{
if(!layout) return;
- switch( context->target )
+ switch( context->shared->target )
{
default:
return;
@@ -2493,7 +2577,7 @@ static void emitGLSLParameterBlockDecl(
{
RefPtr<VarLayout> fieldLayout;
structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout);
- assert(fieldLayout);
+// assert(fieldLayout);
// TODO(tfoley): We may want to emit *some* of these,
// some of the time...
@@ -2521,7 +2605,7 @@ static void emitParameterBlockDecl(
RefPtr<ParameterBlockType> parameterBlockType,
RefPtr<VarLayout> layout)
{
- switch(context->target)
+ switch(context->shared->target)
{
case CodeGenTarget::HLSL:
emitHLSLParameterBlockDecl(context, varDecl, parameterBlockType, layout);
@@ -2539,6 +2623,8 @@ static void emitParameterBlockDecl(
static void EmitVarDecl(EmitContext* context, RefPtr<VarDeclBase> decl, RefPtr<VarLayout> layout)
{
+ layout = maybeFetchLayout(decl, layout);
+
// As a special case, a variable using a parameter block type
// will be translated into a declaration using the more primitive
// language syntax.
@@ -2606,7 +2692,7 @@ static void emitGLSLPreprocessorDirectives(
EmitContext* context,
RefPtr<ProgramSyntaxNode> program)
{
- switch(context->target)
+ switch(context->shared->target)
{
// Don't emit this stuff unless we are targetting GLSL
default:
@@ -2658,78 +2744,6 @@ static void emitGLSLPreprocessorDirectives(
// TODO: handle other cases...
}
-static void EmitProgram(
- EmitContext* context,
- RefPtr<ProgramSyntaxNode> program,
- RefPtr<ProgramLayout> programLayout)
-{
- // There may be global-scope modifiers that we should emit now
- emitGLSLPreprocessorDirectives(context, program);
-
- switch(context->target)
- {
- case CodeGenTarget::GLSL:
- {
- // TODO(tfoley): Need a plan for how to enable/disable these as needed...
-// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n");
- }
- break;
-
- default:
- break;
- }
-
-
- // Layout information for the global scope is either an ordinary
- // `struct` in the common case, or a constant buffer in the case
- // where there were global-scope uniforms.
- auto globalScopeLayout = programLayout->globalScopeLayout;
- if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() )
- {
- context->globalStructLayout = globalStructLayout.Ptr();
-
- // The `struct` case is easy enough to handle: we just
- // emit all the declarations directly, using their layout
- // information as a guideline.
- EmitDeclsInContainerUsingLayout(context, program, globalStructLayout);
- }
- else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
- {
- // TODO: the `cbuffer` case really needs to be emitted very
- // carefully, but that is beyond the scope of what a simple rewriter
- // can easily do (without semantic analysis, etc.).
- //
- // The crux of the problem is that we need to collect all the
- // global-scope uniforms (but not declarations that don't involve
- // uniform storage...) and put them in a single `cbuffer` declaration,
- // so that we can give it an explicit location. The fields in that
- // declaration might use various type declarations, so we'd really
- // need to emit all the type declarations first, and that involves
- // some large scale reorderings.
- //
- // For now we will punt and just emit the declarations normally,
- // and hope that the global-scope block (`$Globals`) gets auto-assigned
- // the same location that we manually asigned it.
-
- auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout;
- auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>();
-
- // We expect all constant buffers to contain `struct` types for now
- assert(elementTypeStructLayout);
-
- context->globalStructLayout = elementTypeStructLayout.Ptr();
-
- EmitDeclsInContainerUsingLayout(
- context,
- program,
- elementTypeStructLayout);
- }
- else
- {
- assert(!"unexpected");
- }
-}
-
static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayout> layout)
{
// Don't emit code for declarations that came from the stdlib.
@@ -2783,18 +2797,18 @@ static void EmitDeclImpl(EmitContext* context, RefPtr<Decl> decl, RefPtr<VarLayo
// We might import the same module along two different paths,
// so we need to be careful to only emit each module once
// per output.
- if(!context->modulesAlreadyEmitted.Contains(moduleDecl))
+ if(!context->shared->modulesAlreadyEmitted.Contains(moduleDecl))
{
// Add the module to our set before emitting it, just
// in case a circular reference would lead us to
// infinite recursion (but that shouldn't be allowed
// in the first place).
- context->modulesAlreadyEmitted.Add(moduleDecl);
+ context->shared->modulesAlreadyEmitted.Add(moduleDecl);
// TODO: do we need to modify the code generation environment at
// all when doing this recursive emit?
- EmitDeclsInContainerUsingLayout(context, moduleDecl, context->globalStructLayout);
+ EmitDeclsInContainerUsingLayout(context, moduleDecl, context->shared->globalStructLayout);
}
return;
@@ -2839,7 +2853,7 @@ static void registerReservedWord(
EmitContext* context,
String const& name)
{
- context->reservedWords.Add(name, name);
+ context->shared->reservedWords.Add(name, name);
}
static void registerReservedWords(
@@ -2847,7 +2861,7 @@ static void registerReservedWords(
{
#define WORD(NAME) registerReservedWord(context, #NAME)
- switch (context->target)
+ switch (context->shared->target)
{
case CodeGenTarget::GLSL:
WORD(attribute);
@@ -2990,68 +3004,115 @@ static void registerReservedWords(
}
}
-String emitProgram(
- ProgramSyntaxNode* program,
+bool isRewriteRequest(
+ SourceLanguage sourceLanguage,
+ CodeGenTarget target);
+
+String emitEntryPoint(
+ EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
CodeGenTarget target)
{
- // TODO(tfoley): only emit symbols on-demand, as needed by a particular entry point
+ auto translationUnit = entryPoint->getTranslationUnit();
+
+ SharedEmitContext sharedContext;
+ sharedContext.target = target;
+
+ sharedContext.programLayout = programLayout;
+
+ // Layout information for the global scope is either an ordinary
+ // `struct` in the common case, or a constant buffer in the case
+ // where there were global-scope uniforms.
+ auto globalScopeLayout = programLayout->globalScopeLayout;
+ StructTypeLayout* globalStructLayout = nullptr;
+ if( auto globalStructLayout = globalScopeLayout.As<StructTypeLayout>() )
+ {
+ globalStructLayout = globalStructLayout.Ptr();
+ }
+ else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
+ {
+ // TODO: the `cbuffer` case really needs to be emitted very
+ // carefully, but that is beyond the scope of what a simple rewriter
+ // can easily do (without semantic analysis, etc.).
+ //
+ // The crux of the problem is that we need to collect all the
+ // global-scope uniforms (but not declarations that don't involve
+ // uniform storage...) and put them in a single `cbuffer` declaration,
+ // so that we can give it an explicit location. The fields in that
+ // declaration might use various type declarations, so we'd really
+ // need to emit all the type declarations first, and that involves
+ // some large scale reorderings.
+ //
+ // For now we will punt and just emit the declarations normally,
+ // and hope that the global-scope block (`$Globals`) gets auto-assigned
+ // the same location that we manually asigned it.
+
+ auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout;
+ auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>();
+
+ // We expect all constant buffers to contain `struct` types for now
+ assert(elementTypeStructLayout);
+
+ globalStructLayout = elementTypeStructLayout.Ptr();
+ }
+ else
+ {
+ assert(!"unexpected");
+ }
+ sharedContext.globalStructLayout = globalStructLayout;
EmitContext context;
- context.target = target;
+ context.shared = &sharedContext;
+ context.isRewrite = isRewriteRequest(
+ translationUnit->sourceLanguage,
+ target);
+ // TODO: this should only need to take the shared context
registerReservedWords(&context);
- EmitProgram(&context, program, programLayout);
+ auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr();
- String code = context.sb.ProduceString();
- return code;
-
-#if 0
- // HACK(tfoley): Invoke the D3D HLSL compiler on the result, to validate it
+ // There may be global-scope modifiers that we should emit now
+ emitGLSLPreprocessorDirectives(&context, translationUnitSyntax);
-#ifdef _WIN32
+ switch(target)
{
- HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47");
- assert(d3dCompiler);
-
- pD3DCompile D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile");
- assert(D3DCompile_);
-
- ID3DBlob* codeBlob;
- ID3DBlob* diagnosticsBlob;
- HRESULT hr = D3DCompile_(
- code.begin(),
- code.Length(),
- "slang",
- nullptr,
- nullptr,
- "main",
- "ps_5_0",
- 0,
- 0,
- &codeBlob,
- &diagnosticsBlob);
- if (codeBlob) codeBlob->Release();
- if (diagnosticsBlob)
- {
- String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer();
- fprintf(stderr, "%s", diagnostics.begin());
- OutputDebugStringA(diagnostics.begin());
- diagnosticsBlob->Release();
- }
- if (FAILED(hr))
+ case CodeGenTarget::GLSL:
{
- int f = 9;
+ // TODO(tfoley): Need a plan for how to enable/disable these as needed...
+// Emit(context, "#extension GL_GOOGLE_cpp_style_line_directive : require\n");
}
+ break;
+
+ default:
+ break;
}
- #include <d3dcompiler.h>
-#endif
+ auto lowered = lowerEntryPoint(entryPoint, programLayout, target);
+
+ EmitDeclsInContainer(&context, lowered.program.Ptr());
+
+#if 0
+ if( isRewrite )
+ {
+ // In rewrite mode, we will just emit the text of the translation unit as given,
+ // and not pay attention to the specific entry point that was requested.
+ //
+ // It is a user error to request GLSL output and have an entry point name
+ // other than `main`.
+ EmitDeclsInContainerUsingLayout(&context, translationUnitSyntax, globalStructLayout);
+ }
+ else
+ {
+ // We are being asked to emit a single entry point in "full" mode.
+ emitEntryPoint(&context, entryPoint);
+ }
#endif
-}
+ String code = sharedContext.sb.ProduceString();
+ return code;
+}
} // namespace Slang
diff --git a/source/slang/emit.h b/source/slang/emit.h
index 4f3e98c8d..da1ac9f08 100644
--- a/source/slang/emit.h
+++ b/source/slang/emit.h
@@ -8,11 +8,14 @@
namespace Slang
{
- class ProgramSyntaxNode;
+ class EntryPointRequest;
class ProgramLayout;
+ class TranslationUnitRequest;
- String emitProgram(
- ProgramSyntaxNode* program,
+ // Emit code for a single entry point, based on
+ // the input translation unit.
+ String emitEntryPoint(
+ EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
CodeGenTarget target);
}
diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h
index 59bad37e7..ca5bfacb8 100644
--- a/source/slang/expr-defs.h
+++ b/source/slang/expr-defs.h
@@ -107,3 +107,8 @@ SYNTAX_CLASS(SharedTypeExpr, ExpressionSyntaxNode)
SYNTAX_FIELD(TypeExp, base)
END_SYNTAX_CLASS()
+SYNTAX_CLASS(AssignExpr, ExpressionSyntaxNode)
+ SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, left);
+ SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, right);
+END_SYNTAX_CLASS()
+
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
new file mode 100644
index 000000000..674614cd3
--- /dev/null
+++ b/source/slang/lower.cpp
@@ -0,0 +1,1601 @@
+// lower.cpp
+#include "lower.h"
+
+#include "visitor.h"
+
+namespace Slang
+{
+
+//
+
+template<typename V>
+struct StructuralTransformVisitorBase
+{
+ V* visitor;
+
+ RefPtr<StatementSyntaxNode> transformDeclField(StatementSyntaxNode* stmt)
+ {
+ return visitor->translateStmtRef(stmt);
+ }
+
+ RefPtr<Decl> transformDeclField(Decl* decl)
+ {
+ return visitor->translateDeclRef(decl);
+ }
+
+ template<typename T>
+ DeclRef<T> transformDeclField(DeclRef<T> const& decl)
+ {
+ return visitor->translateDeclRef(decl).As<T>();
+ }
+
+ TypeExp transformSyntaxField(TypeExp const& typeExp)
+ {
+ TypeExp result;
+ result.type = visitor->transformSyntaxField(typeExp.type);
+ return result;
+ }
+
+ QualType transformSyntaxField(QualType const& qualType)
+ {
+ QualType result = qualType;
+ result.type = visitor->transformSyntaxField(qualType.type);
+ return result;
+ }
+
+ RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr)
+ {
+ return visitor->transformSyntaxField(expr);
+ }
+
+ RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt)
+ {
+ return visitor->transformSyntaxField(stmt);
+ }
+
+ RefPtr<DeclBase> transformSyntaxField(DeclBase* decl)
+ {
+ return visitor->transformSyntaxField(decl);
+ }
+
+ RefPtr<ScopeDecl> transformSyntaxField(ScopeDecl* decl)
+ {
+ return visitor->transformSyntaxField(decl).As<ScopeDecl>();
+ }
+
+ template<typename T>
+ List<T> transformSyntaxField(List<T> const& list)
+ {
+ List<T> result;
+ for (auto item : list)
+ {
+ result.Add(transformSyntaxField(item));
+ }
+ return result;
+ }
+};
+
+template<typename V>
+struct StructuralTransformStmtVisitor
+ : StructuralTransformVisitorBase<V>
+ , StmtVisitor<StructuralTransformStmtVisitor<V>, RefPtr<StatementSyntaxNode>>
+{
+ void transformFields(StatementSyntaxNode* result, StatementSyntaxNode* obj)
+ {
+ }
+
+#define SYNTAX_CLASS(NAME, BASE, ...) \
+ RefPtr<StatementSyntaxNode> visit(NAME* obj) { \
+ RefPtr<NAME> result = new NAME(*obj); \
+ transformFields(result, obj); \
+ return result; \
+ } \
+ void transformFields(NAME* result, NAME* obj) { \
+ transformFields((BASE*) result, (BASE*) obj); \
+
+#define SYNTAX_FIELD(TYPE, NAME) result->NAME = this->transformSyntaxField(obj->NAME);
+#define DECL_FIELD(TYPE, NAME) result->NAME = this->transformDeclField(obj->NAME);
+
+#define FIELD(TYPE, NAME) /* empty */
+
+#define END_SYNTAX_CLASS() \
+ }
+
+#include "object-meta-begin.h"
+#include "stmt-defs.h"
+#include "object-meta-end.h"
+
+};
+
+template<typename V>
+RefPtr<StatementSyntaxNode> structuralTransform(
+ StatementSyntaxNode* stmt,
+ V* visitor)
+{
+ StructuralTransformStmtVisitor<V> transformer;
+ transformer.visitor = visitor;
+ return transformer.dispatch(stmt);
+}
+
+template<typename V>
+struct StructuralTransformExprVisitor
+ : StructuralTransformVisitorBase<V>
+ , ExprVisitor<StructuralTransformExprVisitor<V>, RefPtr<ExpressionSyntaxNode>>
+{
+ void transformFields(ExpressionSyntaxNode* result, ExpressionSyntaxNode* obj)
+ {
+ result->Type = transformSyntaxField(obj->Type);
+ }
+
+
+#define SYNTAX_CLASS(NAME, BASE, ...) \
+ RefPtr<ExpressionSyntaxNode> visit(NAME* obj) { \
+ RefPtr<NAME> result = new NAME(*obj); \
+ transformFields(result, obj); \
+ return result; \
+ } \
+ void transformFields(NAME* result, NAME* obj) { \
+ transformFields((BASE*) result, (BASE*) obj); \
+
+#define SYNTAX_FIELD(TYPE, NAME) result->NAME = transformSyntaxField(obj->NAME);
+#define DECL_FIELD(TYPE, NAME) result->NAME = transformDeclField(obj->NAME);
+
+#define FIELD(TYPE, NAME) /* empty */
+
+#define END_SYNTAX_CLASS() \
+ }
+
+#include "object-meta-begin.h"
+#include "expr-defs.h"
+#include "object-meta-end.h"
+};
+
+
+template<typename V>
+RefPtr<ExpressionSyntaxNode> structuralTransform(
+ ExpressionSyntaxNode* expr,
+ V* visitor)
+{
+ StructuralTransformExprVisitor<V> transformer;
+ transformer.visitor = visitor;
+ return transformer.dispatch(expr);
+}
+
+//
+
+// Pseudo-syntax used during lowering
+class TupleDecl : public VarDeclBase
+{
+public:
+ virtual void accept(IDeclVisitor *, void *) override
+ {
+ throw "unexpected";
+ }
+
+ List<RefPtr<VarDeclBase>> decls;
+};
+
+// Pseudo-syntax used during lowering:
+// represents an ordered list of expressions as a single unit
+class TupleExpr : public ExpressionSyntaxNode
+{
+public:
+ virtual void accept(IExprVisitor *, void *) override
+ {
+ throw "unexpected";
+ }
+
+ List<RefPtr<ExpressionSyntaxNode>> exprs;
+};
+
+struct SharedLoweringContext
+{
+ ProgramLayout* programLayout;
+ CodeGenTarget target;
+
+ RefPtr<ProgramSyntaxNode> loweredProgram;
+
+ Dictionary<Decl*, RefPtr<Decl>> loweredDecls;
+ Dictionary<Decl*, Decl*> mapLoweredDeclToOriginal;
+
+ bool isRewrite;
+};
+
+static void attachLayout(
+ ModifiableSyntaxNode* syntax,
+ Layout* layout)
+{
+ RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier();
+ modifier->layout = layout;
+
+ addModifier(syntax, modifier);
+}
+
+struct LoweringVisitor
+ : ExprVisitor<LoweringVisitor, RefPtr<ExpressionSyntaxNode>>
+ , StmtVisitor<LoweringVisitor, void>
+ , DeclVisitor<LoweringVisitor, RefPtr<Decl>>
+ , TypeVisitor<LoweringVisitor, RefPtr<ExpressionType>>
+ , ValVisitor<LoweringVisitor, RefPtr<Val>>
+{
+ //
+ SharedLoweringContext* shared;
+ RefPtr<Substitutions> substitutions;
+
+ bool isBuildingStmt = false;
+ RefPtr<StatementSyntaxNode> stmtBeingBuilt;
+
+ // If we *aren't* building a statement, then this
+ // is the container we should be adding declarations to
+ RefPtr<ContainerDecl> parentDecl;
+
+ // If we are in a context where a `return` should be turned
+ // into assignment to a variable (followed by a `return`),
+ // then this will point to that variable.
+ RefPtr<Variable> resultVariable;
+
+ CodeGenTarget getTarget() { return shared->target; }
+
+ //
+ // Values
+ //
+
+ RefPtr<Val> lowerVal(Val* val)
+ {
+ if (!val) return nullptr;
+ return ValVisitor::dispatch(val);
+ }
+
+ RefPtr<Val> visit(GenericParamIntVal* val)
+ {
+ return new GenericParamIntVal(translateDeclRef(DeclRef<Decl>(val->declRef)).As<VarDeclBase>());
+ }
+
+ RefPtr<Val> visit(ConstantIntVal* val)
+ {
+ return val;
+ }
+
+ //
+ // Types
+ //
+
+ RefPtr<ExpressionType> lowerType(
+ ExpressionType* type)
+ {
+ return TypeVisitor::dispatch(type);
+ }
+
+ TypeExp lowerType(
+ TypeExp const& typeExp)
+ {
+ TypeExp result;
+ result.type = lowerType(typeExp.type);
+ return result;
+ }
+
+ RefPtr<ExpressionType> visit(ErrorType* type)
+ {
+ return type;
+ }
+
+ RefPtr<ExpressionType> visit(OverloadGroupType* type)
+ {
+ return type;
+ }
+
+ RefPtr<ExpressionType> visit(InitializerListType* type)
+ {
+ return type;
+ }
+
+ RefPtr<ExpressionType> visit(GenericDeclRefType* type)
+ {
+ return new GenericDeclRefType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<GenericDecl>());
+ }
+
+ RefPtr<ExpressionType> visit(FuncType* type)
+ {
+ RefPtr<FuncType> loweredType = new FuncType();
+ loweredType->declRef = translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>();
+ return loweredType;
+ }
+
+ RefPtr<ExpressionType> visit(DeclRefType* type)
+ {
+ auto loweredDeclRef = translateDeclRef(type->declRef);
+ return DeclRefType::Create(loweredDeclRef);
+ }
+
+ RefPtr<ExpressionType> visit(NamedExpressionType* type)
+ {
+ if (shared->target == CodeGenTarget::GLSL)
+ {
+ // GLSL does not support `typedef`, so we will lower it out of existence here
+ return lowerType(GetType(type->declRef));
+ }
+
+ return new NamedExpressionType(translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>());
+ }
+
+ RefPtr<ExpressionType> visit(TypeType* type)
+ {
+ return new TypeType(lowerType(type->type));
+ }
+
+ RefPtr<ExpressionType> visit(ArrayExpressionType* type)
+ {
+ RefPtr<ArrayExpressionType> loweredType = new ArrayExpressionType();
+ loweredType->BaseType = lowerType(type->BaseType);
+ loweredType->ArrayLength = lowerVal(type->ArrayLength).As<IntVal>();
+ return loweredType;
+ }
+
+ RefPtr<ExpressionType> transformSyntaxField(ExpressionType* type)
+ {
+ return lowerType(type);
+ }
+
+ //
+ // Expressions
+ //
+
+ RefPtr<ExpressionSyntaxNode> lowerExpr(
+ ExpressionSyntaxNode* expr)
+ {
+ if (!expr) return nullptr;
+ return ExprVisitor::dispatch(expr);
+ }
+
+ // catch-all
+ RefPtr<ExpressionSyntaxNode> visit(
+ ExpressionSyntaxNode* expr)
+ {
+ return structuralTransform(expr, this);
+ }
+
+ RefPtr<ExpressionSyntaxNode> transformSyntaxField(ExpressionSyntaxNode* expr)
+ {
+ return lowerExpr(expr);
+ }
+
+ void lowerExprCommon(
+ RefPtr<ExpressionSyntaxNode> loweredExpr,
+ RefPtr<ExpressionSyntaxNode> expr)
+ {
+ loweredExpr->Position = expr->Position;
+ loweredExpr->Type.type = lowerType(expr->Type.type);
+ }
+
+ RefPtr<ExpressionSyntaxNode> createVarRef(
+ CodePosition const& loc,
+ VarDeclBase* decl)
+ {
+ if (auto tupleDecl = dynamic_cast<TupleDecl*>(decl))
+ {
+ return createTupleRef(loc, tupleDecl);
+ }
+ else
+ {
+ RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode();
+ result->Position = loc;
+ result->Type.type = decl->Type.type;
+ result->declRef = makeDeclRef(decl);
+ return result;
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> createTupleRef(
+ CodePosition const& loc,
+ TupleDecl* decl)
+ {
+ RefPtr<TupleExpr> result = new TupleExpr();
+ result->Position = loc;
+ result->Type.type = decl->Type.type;
+
+ for (auto dd : decl->decls)
+ {
+ auto expr = createVarRef(loc, dd);
+ result->exprs.Add(expr);
+ }
+
+ return result;
+ }
+
+ RefPtr<ExpressionSyntaxNode> visit(
+ VarExpressionSyntaxNode* expr)
+ {
+ // If the expression didn't get resolved, we can leave it as-is
+ if (!expr->declRef)
+ return expr;
+
+ auto loweredDeclRef = translateDeclRef(expr->declRef);
+ auto loweredDecl = loweredDeclRef.getDecl();
+
+ if (auto tupleDecl = dynamic_cast<TupleDecl*>(loweredDecl))
+ {
+ // If we are referencing a declaration that got tuple-ified,
+ // then we need to produce a tuple expression as well.
+
+ return createTupleRef(expr->Position, tupleDecl);
+ }
+
+ RefPtr<VarExpressionSyntaxNode> loweredExpr = new VarExpressionSyntaxNode();
+ lowerExprCommon(loweredExpr, expr);
+ loweredExpr->declRef = loweredDeclRef;
+ return loweredExpr;
+ }
+
+ RefPtr<ExpressionSyntaxNode> visit(
+ MemberExpressionSyntaxNode* expr)
+ {
+ auto loweredBase = lowerExpr(expr->BaseExpression);
+
+ // Are we extracting an element from a tuple?
+ if (auto baseTuple = loweredBase.As<TupleExpr>())
+ {
+ // We need to find the correct member expression,
+ // based on the actual tuple type.
+
+ throw "unimplemented";
+ }
+
+ // Default handling:
+ auto loweredDeclRef = translateDeclRef(expr->declRef);
+ assert(!dynamic_cast<TupleDecl*>(loweredDeclRef.getDecl()));
+
+ RefPtr<MemberExpressionSyntaxNode> loweredExpr = new MemberExpressionSyntaxNode();
+ lowerExprCommon(loweredExpr, expr);
+ loweredExpr->BaseExpression = loweredBase;
+ loweredExpr->declRef = loweredDeclRef;
+
+ return loweredExpr;
+ }
+
+ //
+ // Statements
+ //
+
+ StatementSyntaxNode* translateStmtRef(
+ StatementSyntaxNode* stmt)
+ {
+ throw "unimplemented";
+ }
+
+ RefPtr<StatementSyntaxNode> lowerStmt(
+ StatementSyntaxNode* stmt)
+ {
+ if(!stmt)
+ return nullptr;
+
+ LoweringVisitor subVisitor = *this;
+ subVisitor.stmtBeingBuilt = nullptr;
+
+ subVisitor.lowerStmtImpl(stmt);
+
+ if( !subVisitor.stmtBeingBuilt )
+ {
+ return new EmptyStatementSyntaxNode();
+ }
+ else
+ {
+ return subVisitor.stmtBeingBuilt;
+ }
+ }
+
+ void lowerStmtImpl(
+ StatementSyntaxNode* stmt)
+ {
+ StmtVisitor::dispatch(stmt);
+ }
+
+ RefPtr<ScopeDecl> visit(ScopeDecl* decl)
+ {
+ RefPtr<ScopeDecl> loweredDecl = new ScopeDecl();
+ lowerDeclCommon(loweredDecl, decl);
+ return loweredDecl;
+ }
+
+ LoweringVisitor pushScope(
+ RefPtr<ScopeStmt> loweredStmt,
+ RefPtr<ScopeStmt> stmt)
+ {
+ loweredStmt->scopeDecl = translateDeclRef(stmt->scopeDecl).As<ScopeDecl>();
+
+ LoweringVisitor subVisitor = *this;
+ subVisitor.isBuildingStmt = true;
+ subVisitor.stmtBeingBuilt = nullptr;
+ subVisitor.parentDecl = loweredStmt->scopeDecl;
+ return subVisitor;
+ }
+
+ void addStmtImpl(
+ RefPtr<StatementSyntaxNode>& dest,
+ StatementSyntaxNode* stmt)
+ {
+ // add a statement to the code we are building...
+ if( !dest )
+ {
+ dest = stmt;
+ return;
+ }
+
+ if (auto blockStmt = dest.As<BlockStmt>())
+ {
+ addStmtImpl(blockStmt->body, stmt);
+ return;
+ }
+
+ if (auto seqStmt = dest.As<SeqStmt>())
+ {
+ seqStmt->stmts.Add(stmt);
+ }
+ else
+ {
+ RefPtr<SeqStmt> newSeqStmt = new SeqStmt();
+
+ newSeqStmt->stmts.Add(dest);
+ newSeqStmt->stmts.Add(stmt);
+
+ dest = newSeqStmt;
+ }
+
+ }
+
+ void addStmt(
+ StatementSyntaxNode* stmt)
+ {
+ addStmtImpl(stmtBeingBuilt, stmt);
+ }
+
+ void addExprStmt(
+ RefPtr<ExpressionSyntaxNode> expr)
+ {
+ // TODO: handle cases where the `expr` cannot be directly
+ // represented as a single statement
+
+ RefPtr<ExpressionStatementSyntaxNode> stmt = new ExpressionStatementSyntaxNode();
+ stmt->Expression = expr;
+ addStmt(stmt);
+ }
+
+ void visit(BlockStmt* stmt)
+ {
+ RefPtr<BlockStmt> loweredStmt = new BlockStmt();
+
+ LoweringVisitor subVisitor = pushScope(loweredStmt, stmt);
+ subVisitor.stmtBeingBuilt = loweredStmt;
+
+ subVisitor.lowerStmtImpl(stmt->body);
+
+ addStmt(loweredStmt);
+ }
+
+ void visit(SeqStmt* stmt)
+ {
+ for( auto ss : stmt->stmts )
+ {
+ lowerStmtImpl(ss);
+ }
+ }
+
+ void visit(ExpressionStatementSyntaxNode* stmt)
+ {
+ addExprStmt(lowerExpr(stmt->Expression));
+ }
+
+ void visit(VarDeclrStatementSyntaxNode* stmt)
+ {
+ DeclVisitor::dispatch(stmt->decl);
+ }
+
+ // catch-all
+ void visit(StatementSyntaxNode* stmt)
+ {
+ auto loweredStmt = structuralTransform(stmt, this);
+ addStmt(loweredStmt);
+ }
+
+ RefPtr<StatementSyntaxNode> transformSyntaxField(StatementSyntaxNode* stmt)
+ {
+ return lowerStmt(stmt);
+ }
+
+ void lowerStmtCommon(StatementSyntaxNode* loweredStmt, StatementSyntaxNode* stmt)
+ {
+ loweredStmt->modifiers = stmt->modifiers;
+ }
+
+ void assign(
+ RefPtr<ExpressionSyntaxNode> destExpr,
+ RefPtr<ExpressionSyntaxNode> srcExpr)
+ {
+ RefPtr<AssignExpr> assignExpr = new AssignExpr();
+ assignExpr->Position = destExpr->Position;
+ assignExpr->left = destExpr;
+ assignExpr->right = srcExpr;
+
+ addExprStmt(assignExpr);
+ }
+
+ void assign(VarDeclBase* varDecl, RefPtr<ExpressionSyntaxNode> expr)
+ {
+ assign(createVarRef(expr->Position, varDecl), expr);
+ }
+
+ void assign(RefPtr<ExpressionSyntaxNode> expr, VarDeclBase* varDecl)
+ {
+ assign(expr, createVarRef(expr->Position, varDecl));
+ }
+
+ void visit(ReturnStatementSyntaxNode* stmt)
+ {
+ auto loweredStmt = new ReturnStatementSyntaxNode();
+ lowerStmtCommon(loweredStmt, stmt);
+
+ if (stmt->Expression)
+ {
+ if (resultVariable)
+ {
+ // Do it as an assignment
+ assign(resultVariable, lowerExpr(stmt->Expression));
+ }
+ else
+ {
+ // Simple case
+ loweredStmt->Expression = lowerExpr(stmt->Expression);
+ }
+ }
+
+ addStmt(loweredStmt);
+ }
+
+ //
+ // Declarations
+ //
+
+ RefPtr<Val> translateVal(Val* val)
+ {
+ if (auto type = dynamic_cast<ExpressionType*>(val))
+ return lowerType(type);
+
+ if (auto litVal = dynamic_cast<ConstantIntVal*>(val))
+ return val;
+
+ throw 99;
+ }
+
+ RefPtr<Substitutions> translateSubstitutions(
+ Substitutions* substitutions)
+ {
+ if (!substitutions) return nullptr;
+
+ RefPtr<Substitutions> result = new Substitutions();
+ result->genericDecl = translateDeclRef(substitutions->genericDecl).As<GenericDecl>();
+ for (auto arg : substitutions->args)
+ {
+ result->args.Add(translateVal(arg));
+ }
+ return result;
+ }
+
+ static Decl* getModifiedDecl(Decl* decl)
+ {
+ if (!decl) return nullptr;
+ if (auto genericDecl = dynamic_cast<GenericDecl*>(decl->ParentDecl))
+ return genericDecl;
+ return decl;
+ }
+
+ DeclRef<Decl> translateDeclRef(
+ DeclRef<Decl> const& decl)
+ {
+ DeclRef<Decl> result;
+ result.decl = translateDeclRef(decl.decl);
+ result.substitutions = translateSubstitutions(decl.substitutions);
+ return result;
+ }
+
+ RefPtr<Decl> translateDeclRef(
+ Decl* decl)
+ {
+ if (!decl) return nullptr;
+
+ // We don't want to translate references to built-in declarations,
+ // since they won't be subtituted anyway.
+ if (getModifiedDecl(decl)->HasModifier<FromStdLibModifier>())
+ return decl;
+
+ // If any parent of the declaration was in the stdlib, then
+ // we need to skip it.
+ for(auto pp = decl; pp; pp = pp->ParentDecl)
+ {
+ if (pp->HasModifier<FromStdLibModifier>())
+ return decl;
+ }
+
+ if (getModifiedDecl(decl)->HasModifier<BuiltinModifier>())
+ return decl;
+
+ RefPtr<Decl> loweredDecl;
+ if (shared->loweredDecls.TryGetValue(decl, loweredDecl))
+ return loweredDecl;
+
+ // Time to force it
+ return lowerDecl(decl);
+ }
+
+ RefPtr<ContainerDecl> translateDeclRef(
+ ContainerDecl* decl)
+ {
+ return translateDeclRef((Decl*)decl).As<ContainerDecl>();
+ }
+
+ RefPtr<DeclBase> lowerDeclBase(
+ DeclBase* declBase)
+ {
+ if (Decl* decl = dynamic_cast<Decl*>(declBase))
+ {
+ return lowerDecl(decl);
+ }
+ else
+ {
+ DeclVisitor::dispatch(declBase);
+ }
+
+ }
+
+ RefPtr<Decl> lowerDecl(
+ Decl* decl)
+ {
+ RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl).As<Decl>();
+ return loweredDecl;
+ }
+
+ static void addMember(
+ RefPtr<ContainerDecl> containerDecl,
+ RefPtr<Decl> memberDecl)
+ {
+ containerDecl->Members.Add(memberDecl);
+ memberDecl->ParentDecl = containerDecl.Ptr();
+ }
+
+ void addDecl(
+ Decl* decl)
+ {
+ if(isBuildingStmt)
+ {
+ RefPtr<VarDeclrStatementSyntaxNode> declStmt = new VarDeclrStatementSyntaxNode();
+ declStmt->Position = decl->Position;
+ declStmt->decl = decl;
+ addStmt(declStmt);
+ }
+
+
+ // We will add the declaration to the current container declaration being
+ // translated, which the user will maintain via pua/pop.
+ //
+
+ assert(parentDecl);
+ addMember(parentDecl, decl);
+ }
+
+ void registerLoweredDecl(Decl* loweredDecl, Decl* decl)
+ {
+ shared->loweredDecls.Add(decl, loweredDecl);
+
+ shared->mapLoweredDeclToOriginal.Add(loweredDecl, decl);
+ }
+
+ void lowerDeclCommon(
+ Decl* loweredDecl,
+ Decl* decl)
+ {
+ registerLoweredDecl(loweredDecl, decl);
+
+ loweredDecl->Position = decl->Position;
+ loweredDecl->Name = decl->getNameToken();
+
+ // Lower modifiers as needed
+
+ // HACK: just doing a shallow copy of modifiers, which will
+ // suffice for most of them, but we need to do something
+ // better soon.
+ loweredDecl->modifiers = decl->modifiers;
+
+ // deal with layout stuff
+
+ auto loweredParent = translateDeclRef(decl->ParentDecl);
+ if (loweredParent)
+ {
+ auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>();
+ if (layoutMod)
+ {
+ auto parentLayout = layoutMod->layout;
+ if (auto structLayout = parentLayout.As<StructTypeLayout>())
+ {
+ RefPtr<VarLayout> fieldLayout;
+ if (structLayout->mapVarToLayout.TryGetValue(decl, fieldLayout))
+ {
+ attachLayout(loweredDecl, fieldLayout);
+ }
+ }
+
+ // TODO: are there other cases to handle here?
+ }
+ }
+ }
+
+ // Catch-all
+ RefPtr<Decl> visit(
+ Decl* decl)
+ {
+ assert(!"unimplemented");
+ return decl;
+ }
+
+ RefPtr<ImportDecl> visit(ImportDecl* decl)
+ {
+ // No need to translate things here if we are
+ // in "full" mode, because we will selectively
+ // translate the imported declarations at their
+ // use sites(s).
+ if (!shared->isRewrite)
+ return nullptr;
+
+ for (auto dd : decl->importedModuleDecl->Members)
+ {
+ translateDeclRef(dd);
+ }
+
+ // Don't actually include a representation of
+ // the import declaration in the output
+ return nullptr;
+ }
+
+ RefPtr<EmptyDecl> visit(EmptyDecl* decl)
+ {
+ // Empty declarations are really only useful in GLSL,
+ // where they are used to hold metadata that doesn't
+ // attach to any particular shader parameter.
+ //
+ // TODO: Only lower empty declarations if we are
+ // rewriting a GLSL file, and otherwise ignore them.
+ //
+ RefPtr<EmptyDecl> loweredDecl = new EmptyDecl();
+ lowerDeclCommon(loweredDecl, decl);
+
+ addDecl(loweredDecl);
+
+ return loweredDecl;
+ }
+
+ RefPtr<Decl> visit(AggTypeDecl* decl)
+ {
+ // We want to lower any aggregate type declaration
+ // to just a `struct` type that contains its fields.
+ //
+ // Any non-field members (e.g., methods) will be
+ // lowered separately.
+
+ // TODO: also need to figure out how to handle fields
+ // with types that should not be allowed in a `struct`
+ // for the chosen target.
+ // (also: what to do if there are no fields left
+ // after removing invalid ones?)
+
+ RefPtr<StructSyntaxNode> loweredDecl = new StructSyntaxNode();
+ lowerDeclCommon(loweredDecl, decl);
+
+ for (auto field : decl->getMembersOfType<VarDeclBase>())
+ {
+ // TODO: anything more to do than this?
+ addMember(loweredDecl, translateDeclRef(field));
+ }
+
+ addMember(
+ shared->loweredProgram,
+ loweredDecl);
+
+ return loweredDecl;
+ }
+
+ RefPtr<VarDeclBase> lowerVarDeclCommon(
+ RefPtr<VarDeclBase> loweredDecl,
+ VarDeclBase* decl)
+ {
+ lowerDeclCommon(loweredDecl, decl);
+
+ loweredDecl->Type = lowerType(decl->Type);
+ loweredDecl->Expr = lowerExpr(decl->Expr);
+
+ return loweredDecl;
+ }
+
+ RefPtr<VarDeclBase> visit(
+ Variable* decl)
+ {
+ auto loweredDecl = lowerVarDeclCommon(new Variable(), decl);
+
+ // We need to add things to an appropriate scope, based on what
+ // we are referencing.
+ //
+ // If this is a global variable (program scope), then add it
+ // to the global scope.
+ RefPtr<ContainerDecl> parentDecl = decl->ParentDecl;
+ if (auto parentModuleDecl = parentDecl.As<ProgramSyntaxNode>())
+ {
+ addMember(
+ translateDeclRef(parentModuleDecl),
+ loweredDecl);
+ }
+ // TODO: handle `static` function-scope variables
+ else
+ {
+ // A local variable declaration will get added to the
+ // statement scope we are currently processing.
+ addDecl(loweredDecl);
+ }
+
+ return loweredDecl;
+ }
+
+ RefPtr<VarDeclBase> visit(
+ StructField* decl)
+ {
+ return lowerVarDeclCommon(new StructField(), decl);
+ }
+
+ RefPtr<VarDeclBase> visit(
+ ParameterSyntaxNode* decl)
+ {
+ return lowerVarDeclCommon(new ParameterSyntaxNode(), decl);
+ }
+
+ RefPtr<DeclBase> transformSyntaxField(DeclBase* decl)
+ {
+ return lowerDeclBase(decl);
+ }
+
+
+ RefPtr<Decl> visit(
+ DeclGroup* group)
+ {
+ for (auto decl : group->decls)
+ {
+ lowerDecl(decl);
+ }
+ return nullptr;
+ }
+
+ RefPtr<FunctionSyntaxNode> visit(
+ FunctionDeclBase* decl)
+ {
+ // TODO: need to generate a name
+
+ RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode();
+ lowerDeclCommon(loweredDecl, decl);
+
+ // TODO: push scope for parent decl here...
+
+ // TODO: need to copy over relevant modifiers
+
+ for (auto paramDecl : decl->GetParameters())
+ {
+ addMember(loweredDecl, translateDeclRef(paramDecl));
+ }
+
+ auto loweredReturnType = lowerType(decl->ReturnType);
+
+ loweredDecl->ReturnType = loweredReturnType;
+
+ // If we are a being called recurisvely, then we need to
+ // be careful not to let the context get polluted
+ LoweringVisitor subVisitor = *this;
+ subVisitor.resultVariable = nullptr;
+ subVisitor.stmtBeingBuilt = nullptr;
+
+ loweredDecl->Body = subVisitor.lowerStmt(decl->Body);
+
+ // A lowered function always becomes a global-scope function,
+ // even if it had been a member function when declared.
+ addMember(shared->loweredProgram, loweredDecl);
+
+ return loweredDecl;
+ }
+
+ //
+ // Entry Points
+ //
+
+ EntryPointLayout* findEntryPointLayout(
+ EntryPointRequest* entryPointRequest)
+ {
+ for( auto entryPointLayout : shared->programLayout->entryPoints )
+ {
+ if(entryPointLayout->entryPoint->getName() != entryPointRequest->name)
+ continue;
+
+ if(entryPointLayout->profile != entryPointRequest->profile)
+ continue;
+
+ // TODO: can't easily filter on translation unit here...
+ // Ideally the `EntryPointRequest` should get filled in with a pointer
+ // the specific function declaration that represents the entry point.
+
+ return entryPointLayout.Ptr();
+ }
+
+ return nullptr;
+ }
+
+ enum class VaryingParameterDirection
+ {
+ Input,
+ Output,
+ };
+
+ struct VaryingParameterArraySpec
+ {
+ VaryingParameterArraySpec* next = nullptr;
+ IntVal* elementCount;
+ };
+
+ struct VaryingParameterInfo
+ {
+ String name;
+ VaryingParameterDirection direction;
+ VaryingParameterArraySpec* arraySpecs = nullptr;
+ };
+
+
+ void lowerSimpleShaderParameterToGLSLGlobal(
+ VaryingParameterInfo const& info,
+ RefPtr<ExpressionType> varType,
+ RefPtr<VarLayout> varLayout,
+ RefPtr<ExpressionSyntaxNode> varExpr)
+ {
+ RefPtr<ExpressionType> type = varType;
+
+ for (auto aa = info.arraySpecs; aa; aa = aa->next)
+ {
+ RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType();
+ arrayType->BaseType = type;
+ arrayType->ArrayLength = aa->elementCount;
+
+ type = arrayType;
+ }
+
+ // TODO: if we are declaring an SOA-ized array,
+ // this is where those array dimensions would need
+ // to be tacked on.
+
+ RefPtr<Variable> globalVarDecl = new Variable();
+ globalVarDecl->Name.Content = info.name;
+ globalVarDecl->Type.type = type;
+
+ addMember(shared->loweredProgram, globalVarDecl);
+
+ // Add the layout information
+ RefPtr<ComputedLayoutModifier> modifier = new ComputedLayoutModifier();
+ modifier->layout = varLayout;
+ addModifier(globalVarDecl, modifier);
+
+ // Need to generate an assignment in the right direction.
+ //
+ // TODO: for now I am just dealing with input:
+
+ switch (info.direction)
+ {
+ case VaryingParameterDirection::Input:
+ addModifier(globalVarDecl, new InModifier());
+ assign(varExpr, globalVarDecl);
+ break;
+
+ case VaryingParameterDirection::Output:
+ addModifier(globalVarDecl, new OutModifier());
+
+ assign(globalVarDecl, varExpr);
+ break;
+ }
+ }
+
+ void lowerShaderParameterToGLSLGLobalsRec(
+ VaryingParameterInfo const& info,
+ RefPtr<ExpressionType> varType,
+ RefPtr<VarLayout> varLayout,
+ RefPtr<ExpressionSyntaxNode> varExpr)
+ {
+ assert(varLayout);
+
+ if (auto basicType = varType->As<BasicExpressionType>())
+ {
+ // handled below
+ }
+ else if (auto vectorType = varType->As<VectorExpressionType>())
+ {
+ // handled below
+ }
+ else if (auto matrixType = varType->As<MatrixExpressionType>())
+ {
+ // handled below
+ }
+ else if (auto arrayType = varType->As<ArrayExpressionType>())
+ {
+ // We will accumulate information on the array
+ // types that were encoutnered on our walk down
+ // to the leaves, and then apply these array dimensions
+ // to any leaf parameters.
+
+ VaryingParameterArraySpec arraySpec;
+ arraySpec.next = info.arraySpecs;
+ arraySpec.elementCount = arrayType->ArrayLength;
+
+ VaryingParameterInfo arrayInfo = info;
+ arrayInfo.arraySpecs = &arraySpec;
+
+ RefPtr<IndexExpressionSyntaxNode> subscriptExpr = new IndexExpressionSyntaxNode();
+ subscriptExpr->Position = varExpr->Position;
+ subscriptExpr->BaseExpression = varExpr;
+
+ // TODO: we need to construct syntax for a loop to initialize
+ // the array here...
+ throw "unimplemented";
+
+ // Note that we use the original `varLayout` that was passed in,
+ // since that is the layout that will ultimately need to be
+ // used on the array elements.
+ //
+ // TODO: That won't actually work if we ever had an array of
+ // heterogeneous stuff...
+ lowerShaderParameterToGLSLGLobalsRec(
+ arrayInfo,
+ arrayType->BaseType,
+ varLayout,
+ subscriptExpr);
+
+ }
+ else if (auto declRefType = varType->As<DeclRefType>())
+ {
+ auto declRef = declRefType->declRef;
+ if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>())
+ {
+ // The shader parameter had a structured type, so we need
+ // to destructure it into its constituent fields
+
+ for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef))
+ {
+ // Don't emit storage for `static` fields here, of course
+ if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ RefPtr<MemberExpressionSyntaxNode> fieldExpr = new MemberExpressionSyntaxNode();
+ fieldExpr->Position = varExpr->Position;
+ fieldExpr->Type.type = GetType(fieldDeclRef);
+ fieldExpr->declRef = fieldDeclRef;
+ fieldExpr->BaseExpression = varExpr;
+
+ VaryingParameterInfo fieldInfo = info;
+ fieldInfo.name = info.name + "_" + fieldDeclRef.GetName();
+
+ // Need to find the layout for the given field...
+ Decl* originalFieldDecl = nullptr;
+ shared->mapLoweredDeclToOriginal.TryGetValue(fieldDeclRef.getDecl(), originalFieldDecl);
+ assert(originalFieldDecl);
+
+ auto structTypeLayout = varLayout->typeLayout.As<StructTypeLayout>();
+ assert(structTypeLayout);
+
+ RefPtr<VarLayout> fieldLayout;
+ structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout);
+ assert(fieldLayout);
+
+ lowerShaderParameterToGLSLGLobalsRec(
+ fieldInfo,
+ GetType(fieldDeclRef),
+ fieldLayout,
+ fieldExpr);
+ }
+
+ // Okay, we are done with this parameter
+ return;
+ }
+ }
+
+ // Default case: just try to emit things as-is
+ lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr);
+ }
+
+ void lowerShaderParameterToGLSLGLobals(
+ RefPtr<Variable> localVarDecl,
+ RefPtr<VarLayout> paramLayout,
+ VaryingParameterDirection direction)
+ {
+ auto name = localVarDecl->getName();
+ auto declRef = makeDeclRef(localVarDecl.Ptr());
+
+ RefPtr<VarExpressionSyntaxNode> expr = new VarExpressionSyntaxNode();
+ expr->name = name;
+ expr->declRef = declRef;
+ expr->Type.type = GetType(declRef);
+
+ VaryingParameterInfo info;
+ info.name = name;
+ info.direction = direction;
+
+ lowerShaderParameterToGLSLGLobalsRec(
+ info,
+ localVarDecl->getType(),
+ paramLayout,
+ expr);
+ }
+
+ struct EntryPointParamPair
+ {
+ RefPtr<ParameterSyntaxNode> original;
+ RefPtr<VarLayout> layout;
+ RefPtr<Variable> lowered;
+ };
+
+ RefPtr<FunctionSyntaxNode> lowerEntryPointToGLSL(
+ FunctionSyntaxNode* entryPointDecl,
+ RefPtr<EntryPointLayout> entryPointLayout)
+ {
+ // First, loer the entry-point function as an ordinary function:
+ auto loweredEntryPointFunc = visit(entryPointDecl);
+
+ // Now we will generate a `void main() { ... }` function to call the lowered code.
+ RefPtr<FunctionSyntaxNode> mainDecl = new FunctionSyntaxNode();
+ mainDecl->ReturnType.type = ExpressionType::GetVoid();
+ mainDecl->Name.Content = "main";
+
+ // If the user's entry point was called `main` then rename it here
+ if (loweredEntryPointFunc->getName() == "main")
+ loweredEntryPointFunc->Name.Content = "main_";
+
+ // We will want to generate declarations into the body of our new `main()`
+ LoweringVisitor subVisitor = *this;
+ subVisitor.isBuildingStmt = true;
+ subVisitor.stmtBeingBuilt = nullptr;
+
+ // The parameters of the entry-point function will be translated to
+ // both a local variable (for passing to/from the entry point func),
+ // and to global variables (used for parameter passing)
+
+ List<EntryPointParamPair> params;
+
+ // First generate declarations for the locals
+ for (auto paramDecl : entryPointDecl->GetParameters())
+ {
+ RefPtr<VarLayout> paramLayout;
+ entryPointLayout->mapVarToLayout.TryGetValue(paramDecl.Ptr(), paramLayout);
+ assert(paramLayout);
+
+ RefPtr<Variable> localVarDecl = new Variable();
+ localVarDecl->Position = paramDecl->Position;
+ localVarDecl->Name.Content = paramDecl->getName();
+ localVarDecl->Type = lowerType(paramDecl->Type);
+
+ subVisitor.addDecl(localVarDecl);
+
+ EntryPointParamPair paramPair;
+ paramPair.original = paramDecl;
+ paramPair.layout = paramLayout;
+ paramPair.lowered = localVarDecl;
+
+ params.Add(paramPair);
+ }
+
+ // Next generate globals for the inputs, and initialize them
+ for (auto paramPair : params)
+ {
+ auto paramDecl = paramPair.original;
+ if (paramDecl->HasModifier<InModifier>()
+ || paramDecl->HasModifier<InOutModifier>()
+ || !paramDecl->HasModifier<OutModifier>())
+ {
+ subVisitor.lowerShaderParameterToGLSLGLobals(
+ paramPair.lowered,
+ paramPair.layout,
+ VaryingParameterDirection::Input);
+ }
+ }
+
+ // Generate a local variable for the result, if any
+ RefPtr<Variable> resultVarDecl;
+ if (!loweredEntryPointFunc->ReturnType->Equals(ExpressionType::GetVoid()))
+ {
+ resultVarDecl = new Variable();
+ resultVarDecl->Position = loweredEntryPointFunc->Position;
+ resultVarDecl->Name.Content = "_main_result";
+ resultVarDecl->Type = TypeExp(loweredEntryPointFunc->ReturnType);
+
+ subVisitor.addDecl(resultVarDecl);
+ }
+
+ // Now generate a call to the entry-point function, using the local variables
+ auto entryPointDeclRef = makeDeclRef(loweredEntryPointFunc.Ptr());
+
+ RefPtr<FuncType> entryPointType = new FuncType();
+ entryPointType->declRef = entryPointDeclRef;
+
+ RefPtr<VarExpressionSyntaxNode> entryPointRef = new VarExpressionSyntaxNode();
+ entryPointRef->name = loweredEntryPointFunc->getName();
+ entryPointRef->declRef = entryPointDeclRef;
+ entryPointRef->Type = QualType(entryPointType);
+
+ RefPtr<InvokeExpressionSyntaxNode> callExpr = new InvokeExpressionSyntaxNode();
+ callExpr->FunctionExpr = entryPointRef;
+ callExpr->Type = QualType(loweredEntryPointFunc->ReturnType);
+
+ //
+ for (auto paramPair : params)
+ {
+ auto localVarDecl = paramPair.lowered;
+
+ RefPtr<VarExpressionSyntaxNode> varRef = new VarExpressionSyntaxNode();
+ varRef->name = localVarDecl->getName();
+ varRef->declRef = makeDeclRef(localVarDecl.Ptr());
+ varRef->Type = QualType(localVarDecl->getType());
+
+ callExpr->Arguments.Add(varRef);
+ }
+
+ if (resultVarDecl)
+ {
+ // Non-`void` return type, so we need to store it
+ subVisitor.assign(resultVarDecl, callExpr);
+ }
+ else
+ {
+ // `void` return type: just call it
+ subVisitor.addExprStmt(callExpr);
+ }
+
+
+ // Finally, generate logic to copy the outputs to global parameters
+ for (auto paramPair : params)
+ {
+ auto paramDecl = paramPair.original;
+ if (paramDecl->HasModifier<OutModifier>()
+ || paramDecl->HasModifier<InOutModifier>())
+ {
+ subVisitor.lowerShaderParameterToGLSLGLobals(
+ paramPair.lowered,
+ paramPair.layout,
+ VaryingParameterDirection::Output);
+ }
+ }
+ if (resultVarDecl)
+ {
+ subVisitor.lowerShaderParameterToGLSLGLobals(
+ resultVarDecl,
+ entryPointLayout->resultLayout,
+ VaryingParameterDirection::Output);
+ }
+
+ mainDecl->Body = subVisitor.stmtBeingBuilt;
+
+
+ // Once we are done building the body, we append our new declaration to the program.
+ addMember(shared->loweredProgram, mainDecl);
+ return mainDecl;
+
+#if 0
+ RefPtr<FunctionSyntaxNode> loweredDecl = new FunctionSyntaxNode();
+ lowerDeclCommon(loweredDecl, entryPointDecl);
+
+ // We create a sub-context appropriate for lowering the function body
+
+ LoweringVisitor subVisitor = *this;
+ subVisitor.isBuildingStmt = true;
+ subVisitor.stmtBeingBuilt = nullptr;
+
+ // The parameters of the entry-point function must be translated
+ // to global-scope declarations
+ for (auto paramDecl : entryPointDecl->GetParameters())
+ {
+ subVisitor.lowerShaderParameterToGLSLGLobals(paramDecl);
+ }
+
+ // The output of the function must also be translated into a
+ // global-scope declaration.
+ auto loweredReturnType = lowerType(entryPointDecl->ReturnType);
+ RefPtr<Variable> resultGlobal;
+ if (!loweredReturnType->Equals(ExpressionType::GetVoid()))
+ {
+ resultGlobal = new Variable();
+ // TODO: need a scheme for generating unique names
+ resultGlobal->Name.Content = "_main_result";
+ resultGlobal->Type = loweredReturnType;
+
+ addMember(shared->loweredProgram, resultGlobal);
+ }
+
+ loweredDecl->Name.Content = "main";
+ loweredDecl->ReturnType.type = ExpressionType::GetVoid();
+
+ // We will emit the body statement in a context where
+ // a `return` statmenet will generate writes to the
+ // result global that we declared.
+ subVisitor.resultVariable = resultGlobal;
+
+ auto loweredBody = subVisitor.lowerStmt(entryPointDecl->Body);
+ subVisitor.addStmt(loweredBody);
+
+ loweredDecl->Body = subVisitor.stmtBeingBuilt;
+
+ // TODO: need to append writes for `out` parameters here...
+
+ addMember(shared->loweredProgram, loweredDecl);
+ return loweredDecl;
+#endif
+ }
+
+ RefPtr<FunctionSyntaxNode> lowerEntryPoint(
+ FunctionSyntaxNode* entryPointDecl,
+ RefPtr<EntryPointLayout> entryPointLayout)
+ {
+ switch( getTarget() )
+ {
+ // Default case: lower an entry point just like any other function
+ default:
+ return visit(entryPointDecl);
+
+ // For Slang->GLSL translation, we need to lower things from HLSL-style
+ // declarations over to GLSL conventions
+ case CodeGenTarget::GLSL:
+ return lowerEntryPointToGLSL(entryPointDecl, entryPointLayout);
+ }
+ }
+
+ RefPtr<FunctionSyntaxNode> lowerEntryPoint(
+ EntryPointRequest* entryPointRequest)
+ {
+ auto entryPointLayout = findEntryPointLayout(entryPointRequest);
+ auto entryPointDecl = entryPointLayout->entryPoint;
+
+ return lowerEntryPoint(
+ entryPointDecl,
+ entryPointLayout);
+ }
+
+
+};
+
+static RefPtr<StructTypeLayout> getGlobalStructLayout(
+ ProgramLayout* programLayout)
+{
+ // Layout information for the global scope is either an ordinary
+ // `struct` in the common case, or a constant buffer in the case
+ // where there were global-scope uniforms.
+ auto globalScopeLayout = programLayout->globalScopeLayout;
+ StructTypeLayout* globalStructLayout = globalScopeLayout.As<StructTypeLayout>();
+ if(globalStructLayout)
+ { }
+ else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
+ {
+ // TODO: the `cbuffer` case really needs to be emitted very
+ // carefully, but that is beyond the scope of what a simple rewriter
+ // can easily do (without semantic analysis, etc.).
+ //
+ // The crux of the problem is that we need to collect all the
+ // global-scope uniforms (but not declarations that don't involve
+ // uniform storage...) and put them in a single `cbuffer` declaration,
+ // so that we can give it an explicit location. The fields in that
+ // declaration might use various type declarations, so we'd really
+ // need to emit all the type declarations first, and that involves
+ // some large scale reorderings.
+ //
+ // For now we will punt and just emit the declarations normally,
+ // and hope that the global-scope block (`$Globals`) gets auto-assigned
+ // the same location that we manually asigned it.
+
+ auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout;
+ auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>();
+
+ // We expect all constant buffers to contain `struct` types for now
+ assert(elementTypeStructLayout);
+
+ globalStructLayout = elementTypeStructLayout.Ptr();
+ }
+ else
+ {
+ assert(!"unexpected");
+ }
+ return globalStructLayout;
+}
+
+
+// Determine if the user is just trying to "rewrite" their input file
+// into an output file. This will affect the way we approach code
+// generation, because we want to leave their code "as is" whenever
+// possible.
+bool isRewriteRequest(
+ SourceLanguage sourceLanguage,
+ CodeGenTarget target)
+{
+ // TODO: we might only consider things to be a rewrite request
+ // in the specific case where checking is turned off...
+
+ switch( target )
+ {
+ default:
+ return false;
+
+ case CodeGenTarget::HLSL:
+ return sourceLanguage == SourceLanguage::HLSL;
+
+ case CodeGenTarget::GLSL:
+ return sourceLanguage == SourceLanguage::GLSL;
+ }
+}
+
+
+
+LoweredEntryPoint lowerEntryPoint(
+ EntryPointRequest* entryPoint,
+ ProgramLayout* programLayout,
+ CodeGenTarget target)
+{
+ SharedLoweringContext sharedContext;
+ sharedContext.programLayout = programLayout;
+ sharedContext.target = target;
+
+ auto translationUnit = entryPoint->getTranslationUnit();
+
+ // Create a single module/program to hold all the lowered code
+ // (with the exception of instrinsic/stdlib declarations, which
+ // will be remain where they are)
+ RefPtr<ProgramSyntaxNode> loweredProgram = new ProgramSyntaxNode();
+ sharedContext.loweredProgram = loweredProgram;
+
+ LoweringVisitor visitor;
+ visitor.shared = &sharedContext;
+ visitor.parentDecl = loweredProgram;
+
+ // We need to register the lowered program as the lowered version
+ // of the existing translation unit declaration.
+
+ visitor.registerLoweredDecl(
+ loweredProgram,
+ translationUnit->SyntaxNode);
+
+ // We also need to register the lowered program as the lowered version
+ // of any imported modules (since we will be collecting everything into
+ // a single module for code generation).
+ for (auto rr : entryPoint->compileRequest->loadedModulesList)
+ {
+ sharedContext.loweredDecls.Add(
+ rr,
+ loweredProgram);
+ }
+
+ // We also want to remember the layout information for
+ // that declaration, so that we can apply it during emission
+ attachLayout(loweredProgram,
+ getGlobalStructLayout(programLayout));
+
+
+ bool isRewrite = isRewriteRequest(translationUnit->sourceLanguage, target);
+ sharedContext.isRewrite = isRewrite;
+
+ LoweredEntryPoint result;
+ if (isRewrite)
+ {
+ for (auto dd : translationUnit->SyntaxNode->Members)
+ {
+ visitor.translateDeclRef(dd);
+ }
+ }
+ else
+ {
+ auto loweredEntryPoint = visitor.lowerEntryPoint(entryPoint);
+ result.entryPoint = loweredEntryPoint;
+ }
+
+ result.program = sharedContext.loweredProgram;
+
+ return result;
+}
+}
diff --git a/source/slang/lower.h b/source/slang/lower.h
new file mode 100644
index 000000000..c690ea025
--- /dev/null
+++ b/source/slang/lower.h
@@ -0,0 +1,38 @@
+// lower.h
+#ifndef SLANG_LOWER_H_INCLUDED
+#define SLANG_LOWER_H_INCLUDED
+
+// The "lowering" step takes an input AST written in the complete Slang
+// language and turns it into a more minimal format (still using the
+// same AST) suitable for emission into lower-level languages.
+
+#include "../core/basic.h"
+
+#include "compiler.h"
+#include "syntax.h"
+
+namespace Slang
+{
+ class EntryPointRequest;
+ class ProgramLayout;
+ class TranslationUnitRequest;
+
+ struct LoweredEntryPoint
+ {
+ // The actual lowered entry point
+ RefPtr<FunctionSyntaxNode> entryPoint;
+
+ // The generated program AST that
+ // contains the entry point and any
+ // other declarations it uses
+ RefPtr<ProgramSyntaxNode> program;
+ };
+
+ // Emit code for a single entry point, based on
+ // the input translation unit.
+ LoweredEntryPoint lowerEntryPoint(
+ EntryPointRequest* entryPoint,
+ ProgramLayout* programLayout,
+ CodeGenTarget target);
+}
+#endif
diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h
index a1da47fb7..e50288600 100644
--- a/source/slang/modifier-defs.h
+++ b/source/slang/modifier-defs.h
@@ -269,3 +269,7 @@ SIMPLE_SYNTAX_CLASS(HLSLTriangleModifier , HLSLGeometryShaderInputPrimitiveT
SIMPLE_SYNTAX_CLASS(HLSLLineAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier)
SIMPLE_SYNTAX_CLASS(HLSLTriangleAdjModifier , HLSLGeometryShaderInputPrimitiveTypeModifier)
+// A modifier to be attached to syntax after we've computed layout
+SYNTAX_CLASS(ComputedLayoutModifier, Modifier)
+ FIELD(RefPtr<Layout>, layout)
+END_SYNTAX_CLASS()
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index 2a0c64892..8d008618f 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -767,86 +767,87 @@ SimpleSemanticInfo decomposeSimpleSemantic(
return info;
}
-enum class EntryPointParameterDirection
+enum EntryPointParameterDirection
{
- Input,
- Output,
+ kEntryPointParameterDirection_Input = 0x1,
+ kEntryPointParameterDirection_Output = 0x2,
};
+typedef unsigned int EntryPointParameterDirectionMask;
struct EntryPointParameterState
{
- String* optSemanticName;
- int* ioSemanticIndex;
- EntryPointParameterDirection direction;
- int semanticSlotCount;
+ String* optSemanticName;
+ int* ioSemanticIndex;
+ EntryPointParameterDirectionMask directionMask;
+ int semanticSlotCount;
};
-static void processSimpleEntryPointInput(
+static RefPtr<TypeLayout> processSimpleEntryPointParameter(
ParameterBindingContext* context,
RefPtr<ExpressionType> type,
- EntryPointParameterState const& state)
+ EntryPointParameterState const& inState,
+ int semanticSlotCount = 1)
{
- auto optSemanticName = state.optSemanticName;
- auto semanticIndex = *state.ioSemanticIndex;
- auto semanticSlotCount = state.semanticSlotCount;
-}
+ EntryPointParameterState state = inState;
+ state.semanticSlotCount = semanticSlotCount;
-static void processSimpleEntryPointOutput(
- ParameterBindingContext* context,
- RefPtr<ExpressionType> type,
- EntryPointParameterState const& state)
-{
auto optSemanticName = state.optSemanticName;
auto semanticIndex = *state.ioSemanticIndex;
- auto semanticSlotCount = state.semanticSlotCount;
-
- if(!optSemanticName)
- return;
- auto semanticName = *optSemanticName;
+ String semanticName = optSemanticName ? *optSemanticName : "";
+ String sn = semanticName.ToLower();
- // Note: I'm just doing something expedient here and detecting `SV_Target`
- // outputs and claiming the appropriate register range right away.
- //
- // TODO: we should really be building up some representation of all of this,
- // once we've gone to the trouble of looking it all up...
- if( semanticName.ToLower() == "sv_target" )
+ RefPtr<TypeLayout> typeLayout = new TypeLayout();
+ if (sn.StartsWith("sv_"))
{
- context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount);
- }
-}
+ // System-value semantic.
-static void processSimpleEntryPointParameter(
- ParameterBindingContext* context,
- RefPtr<ExpressionType> type,
- EntryPointParameterState const& inState,
- int semanticSlotCount = 1)
-{
- EntryPointParameterState state = inState;
- state.semanticSlotCount = semanticSlotCount;
+ if (state.directionMask & kEntryPointParameterDirection_Output)
+ {
+ // Note: I'm just doing something expedient here and detecting `SV_Target`
+ // outputs and claiming the appropriate register range right away.
+ //
+ // TODO: we should really be building up some representation of all of this,
+ // once we've gone to the trouble of looking it all up...
+ if( sn == "sv_target" )
+ {
+ context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount);
+ }
+ }
- switch( state.direction )
+ // TODO: add some kind of usage information for system input/output
+ }
+ else
{
- case EntryPointParameterDirection::Input:
- processSimpleEntryPointInput(context, type, state);
- break;
+ // user-defined semantic
- case EntryPointParameterDirection::Output:
- processSimpleEntryPointOutput(context, type, state);
- break;
+ if (state.directionMask & kEntryPointParameterDirection_Input)
+ {
+ auto rules = context->layoutRules->getVaryingInputRules();
+ SimpleLayoutInfo layout = GetLayout(type, rules);
+ typeLayout->addResourceUsage(layout.kind, layout.size);
+ }
- SLANG_EXHAUSTIVE_SWITCH()
+ if (state.directionMask & kEntryPointParameterDirection_Output)
+ {
+ auto rules = context->layoutRules->getVaryingOutputRules();
+ SimpleLayoutInfo layout = GetLayout(type, rules);
+ typeLayout->addResourceUsage(layout.kind, layout.size);
+ }
}
*state.ioSemanticIndex += state.semanticSlotCount;
+ typeLayout->type = type;
+
+ return typeLayout;
}
-static void processEntryPointParameter(
+static RefPtr<TypeLayout> processEntryPointParameter(
ParameterBindingContext* context,
RefPtr<ExpressionType> type,
EntryPointParameterState const& state);
-static void processEntryPointParameterWithPossibleSemantic(
+static RefPtr<TypeLayout> processEntryPointParameterWithPossibleSemantic(
ParameterBindingContext* context,
Decl* declForSemantic,
RefPtr<ExpressionType> type,
@@ -873,11 +874,11 @@ static void processEntryPointParameterWithPossibleSemantic(
// *or* we couldn't find an explicit semantic to apply on the given
// declaration, so we will just recursive with whatever we have at
// the moment.
- processEntryPointParameter(context, type, state);
+ return processEntryPointParameter(context, type, state);
}
-static void processEntryPointParameter(
+static RefPtr<TypeLayout> processEntryPointParameter(
ParameterBindingContext* context,
RefPtr<ExpressionType> type,
EntryPointParameterState const& state)
@@ -885,31 +886,50 @@ static void processEntryPointParameter(
// Scalar and vector types are treated as outputs directly
if(auto basicType = type->As<BasicExpressionType>())
{
- processSimpleEntryPointParameter(context, basicType, state);
+ return processSimpleEntryPointParameter(context, basicType, state);
}
else if(auto basicType = type->As<VectorExpressionType>())
{
- processSimpleEntryPointParameter(context, basicType, state);
+ return processSimpleEntryPointParameter(context, basicType, state);
}
// A matrix is processed as if it was an array of rows
else if( auto matrixType = type->As<MatrixExpressionType>() )
{
auto rowCount = GetIntVal(matrixType->getRowCount());
- processSimpleEntryPointParameter(context, basicType, state, (int) rowCount);
+ return processSimpleEntryPointParameter(context, basicType, state, (int) rowCount);
}
else if( auto arrayType = type->As<ArrayExpressionType>() )
{
+ // Note: Bad Things will happen if we have an array input
+ // without a semantic already being enforced.
+
auto elementCount = GetIntVal(arrayType->ArrayLength);
- for( int ii = 0; ii < elementCount; ++ii )
+ // We use the first element to derive the layout for the element type
+ auto elementTypeLayout = processEntryPointParameter(context, arrayType->BaseType, state);
+
+ // We still walk over subsequent elements to make sure they consume resources
+ // as needed
+ for( int ii = 1; ii < elementCount; ++ii )
{
processEntryPointParameter(context, arrayType->BaseType, state);
}
+
+ RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout();
+ arrayTypeLayout->elementTypeLayout = elementTypeLayout;
+ arrayTypeLayout->type = arrayType;
+
+ for (auto rr : elementTypeLayout->resourceInfos)
+ {
+ arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count * elementCount;
+ }
+
+ return arrayTypeLayout;
}
// Ignore a bunch of types that don't make sense here...
- else if(auto textureType = type->As<TextureType>()) {}
- else if(auto samplerStateType = type->As<SamplerStateType>()) {}
- else if(auto constantBufferType = type->As<ConstantBufferType>()) {}
+ else if (auto textureType = type->As<TextureType>()) { return nullptr; }
+ else if(auto samplerStateType = type->As<SamplerStateType>()) { return nullptr; }
+ else if(auto constantBufferType = type->As<ConstantBufferType>()) { return nullptr; }
// Catch declaration-reference types late in the sequence, since
// otherwise they will include all of the above cases...
else if( auto declRefType = type->As<DeclRefType>() )
@@ -918,15 +938,36 @@ static void processEntryPointParameter(
if (auto structDeclRef = declRef.As<StructSyntaxNode>())
{
+ RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
+ structLayout->type = type;
+
// Need to recursively walk the fields of the structure now...
for( auto field : GetFields(structDeclRef) )
{
- processEntryPointParameterWithPossibleSemantic(
+ auto fieldTypeLayout = processEntryPointParameterWithPossibleSemantic(
context,
field.getDecl(),
GetType(field),
state);
+
+ RefPtr<VarLayout> fieldVarLayout = new VarLayout();
+ fieldVarLayout->varDecl = field;
+ fieldVarLayout->typeLayout = fieldTypeLayout;
+
+ for (auto rr : fieldTypeLayout->resourceInfos)
+ {
+ assert(rr.count != 0);
+
+ auto structRes = structLayout->findOrAddResourceInfo(rr.kind);
+ fieldVarLayout->findOrAddResourceInfo(rr.kind)->index = structRes->count;
+ structRes->count += rr.count;
+ }
+
+ structLayout->fields.Add(fieldVarLayout);
+ structLayout->mapVarToLayout.Add(field.getDecl(), fieldVarLayout);
}
+
+ return structLayout;
}
else
{
@@ -937,6 +978,9 @@ static void processEntryPointParameter(
{
assert(!"unimplemented");
}
+
+ assert(!"unexpected");
+ return nullptr;
}
static void collectEntryPointParameters(
@@ -998,42 +1042,68 @@ static void collectEntryPointParameters(
// We have an entry-point parameter, and need to figure out what to do with it.
+ // TODO: need to handle `uniform`-qualified parameters here
+ if (paramDecl->HasModifier<UniformModifier>())
+ continue;
+
+ state.directionMask = 0;
+
// If it appears to be an input, process it as such.
if( paramDecl->HasModifier<InModifier>() || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>() )
{
- state.direction = EntryPointParameterDirection::Input;
-
- processEntryPointParameterWithPossibleSemantic(
- context,
- paramDecl.Ptr(),
- paramDecl->Type.type,
- state);
+ state.directionMask |= kEntryPointParameterDirection_Input;
}
// If it appears to be an output, process it as such.
if(paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>())
{
- state.direction = EntryPointParameterDirection::Output;
+ state.directionMask |= kEntryPointParameterDirection_Output;
+ }
+
+ auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic(
+ context,
+ paramDecl.Ptr(),
+ paramDecl->Type.type,
+ state);
- processEntryPointParameterWithPossibleSemantic(
- context,
- paramDecl.Ptr(),
- paramDecl->Type.type,
- state);
+ RefPtr<VarLayout> paramVarLayout = new VarLayout();
+ paramVarLayout->varDecl = makeDeclRef(paramDecl.Ptr());
+ paramVarLayout->typeLayout = paramTypeLayout;
+
+ for (auto rr : paramTypeLayout->resourceInfos)
+ {
+ auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind);
+ paramVarLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count;
+ entryPointRes->count += rr.count;
}
+
+ entryPointLayout->fields.Add(paramVarLayout);
+ entryPointLayout->mapVarToLayout.Add(paramDecl, paramVarLayout);
}
// If we can find an output type for the entry point, then process it as
// an output parameter.
if( auto resultType = entryPointFuncDecl->ReturnType.type )
{
- state.direction = EntryPointParameterDirection::Output;
+ state.directionMask = kEntryPointParameterDirection_Output;
- processEntryPointParameterWithPossibleSemantic(
+ auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic(
context,
entryPointFuncDecl,
resultType,
state);
+
+ RefPtr<VarLayout> resultLayout = new VarLayout();
+ resultLayout->typeLayout = resultTypeLayout;
+
+ for (auto rr : resultTypeLayout->resourceInfos)
+ {
+ auto entryPointRes = entryPointLayout->findOrAddResourceInfo(rr.kind);
+ resultLayout->findOrAddResourceInfo(rr.kind)->index = entryPointRes->count;
+ entryPointRes->count += rr.count;
+ }
+
+ entryPointLayout->resultLayout = resultLayout;
}
}
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 4f622ada6..b134f9645 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -2617,10 +2617,13 @@ namespace Slang
RefPtr<ScopeDecl> scopeDecl = new ScopeDecl();
- RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode();
+ RefPtr<BlockStmt> blockStatement = new BlockStmt();
blockStatement->scopeDecl = scopeDecl;
PushScope(scopeDecl.Ptr());
ReadToken(TokenType::LBrace);
+
+ RefPtr<StatementSyntaxNode> body;
+
if(!tokenReader.IsAtEnd())
{
FillPosition(blockStatement.Ptr());
@@ -2630,11 +2633,29 @@ namespace Slang
auto stmt = ParseStatement();
if(stmt)
{
- blockStatement->Statements.Add(stmt);
+ if (!body)
+ {
+ body = stmt;
+ }
+ else if (auto seqStmt = body.As<SeqStmt>())
+ {
+ seqStmt->stmts.Add(stmt);
+ }
+ else
+ {
+ RefPtr<SeqStmt> newBody = new SeqStmt();
+ newBody->Position = blockStatement->Position;
+ newBody->stmts.Add(body);
+ newBody->stmts.Add(stmt);
+
+ body = newBody;
+ }
}
TryRecover(this);
}
PopScope();
+
+ blockStatement->body = body;
return blockStatement;
}
@@ -3054,7 +3075,19 @@ namespace Slang
right = parseInfixExprWithPrecedence(parser, right, nextOpPrec);
}
- expr = createInfixExpr(parser, expr, op, right);
+ if (opTokenType == TokenType::OpAssign)
+ {
+ RefPtr<AssignExpr> assignExpr = new AssignExpr();
+ assignExpr->Position = op->Position;
+ assignExpr->left = expr;
+ assignExpr->right = right;
+
+ expr = assignExpr;
+ }
+ else
+ {
+ expr = createInfixExpr(parser, expr, op, right);
+ }
}
return expr;
}
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index 3fde10e4b..441bc5e45 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -178,6 +178,7 @@
<ClInclude Include="modifier-defs.h" />
<ClInclude Include="object-meta-begin.h" />
<ClInclude Include="object-meta-end.h" />
+ <ClInclude Include="lower.h" />
<ClInclude Include="parameter-binding.h" />
<ClInclude Include="parser.h" />
<ClInclude Include="preprocessor.h" />
@@ -205,6 +206,7 @@
<ClCompile Include="emit.cpp" />
<ClCompile Include="lexer.cpp" />
<ClCompile Include="lookup.cpp" />
+ <ClCompile Include="lower.cpp" />
<ClCompile Include="options.cpp" />
<ClCompile Include="parameter-binding.cpp" />
<ClCompile Include="parser.cpp" />
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index 0ed4457dc..21b533a10 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -37,6 +37,7 @@
<ClInclude Include="type-defs.h" />
<ClInclude Include="val-defs.h" />
<ClInclude Include="visitor.h" />
+ <ClInclude Include="lower.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="check.cpp" />
@@ -56,5 +57,6 @@
<ClCompile Include="token.cpp" />
<ClCompile Include="type-layout.cpp" />
<ClCompile Include="options.cpp" />
+ <ClCompile Include="lower.cpp" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/source/slang/stmt-defs.h b/source/slang/stmt-defs.h
index 9cc1978bd..c5bcda63b 100644
--- a/source/slang/stmt-defs.h
+++ b/source/slang/stmt-defs.h
@@ -6,8 +6,14 @@ ABSTRACT_SYNTAX_CLASS(ScopeStmt, StatementSyntaxNode)
SYNTAX_FIELD(RefPtr<ScopeDecl>, scopeDecl)
END_SYNTAX_CLASS()
-SYNTAX_CLASS(BlockStatementSyntaxNode, ScopeStmt)
- SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, Statements)
+// A sequence of statements, treated as a single statement
+SYNTAX_CLASS(SeqStmt, StatementSyntaxNode)
+ SYNTAX_FIELD(List<RefPtr<StatementSyntaxNode>>, stmts)
+END_SYNTAX_CLASS()
+
+// The simplest kind of scope statement: just a `{...}` block
+SYNTAX_CLASS(BlockStmt, ScopeStmt)
+ SYNTAX_FIELD(RefPtr<StatementSyntaxNode>, body);
END_SYNTAX_CLASS()
SYNTAX_CLASS(UnparsedStmt, StatementSyntaxNode)
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 983f18bb7..12d8c90bb 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -1,8 +1,6 @@
#include "syntax.h"
-#pragma warning AAA
#include "visitor.h"
-#pragma warning BBB
#include <typeinfo>
#include <assert.h>
@@ -55,9 +53,6 @@ namespace Slang
return res.ProduceString();
}
-#pragma warning CCC
-
-
// Generate dispatch logic and other definitions for all syntax classes
#define SYNTAX_CLASS(NAME, BASE) /* empty */
#include "object-meta-begin.h"
@@ -79,8 +74,6 @@ namespace Slang
#include "object-meta-end.h"
-#pragma warning DDD
-
void ExpressionType::accept(IValVisitor* visitor, void* extra)
{
accept((ITypeVisitor*)visitor, extra);
@@ -881,9 +874,20 @@ void ExpressionType::accept(IValVisitor* visitor, void* extra)
auto parentDecl = decl->ParentDecl;
if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl))
{
- // We need to strip away one layer of specialization
- assert(substitutions);
- return DeclRefBase(parentGeneric, substitutions->outer);
+ if (substitutions && substitutions->genericDecl == parentDecl)
+ {
+ // We strip away the specializations that were applied to
+ // the parent, since we were asked for a reference *to* the parent.
+ return DeclRefBase(parentGeneric, substitutions->outer);
+ }
+ else
+ {
+ // Either we don't have specializations, or the inner-most
+ // specializations didn't apply to the parent decl. This
+ // can happen if we are looking at an unspecialized
+ // declaration that is a child of a generic.
+ return DeclRefBase(parentGeneric, substitutions);
+ }
}
else
{
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 47427b130..6587e18c3 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -14,6 +14,7 @@ namespace Slang
class Substitutions;
class SyntaxVisitor;
class FunctionSyntaxNode;
+ class Layout;
struct IExprVisitor;
struct IDeclVisitor;
@@ -347,9 +348,15 @@ namespace Slang
{
return DeclRef<ContainerDecl>::unsafeInit(DeclRefBase::GetParent());
}
-
};
+
+ template<typename T>
+ inline DeclRef<T> makeDeclRef(T* decl)
+ {
+ return DeclRef<T>(decl, nullptr);
+ }
+
template<typename T>
struct FilteredMemberList
{
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index 07867fddb..ce1f8864d 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -143,8 +143,13 @@ struct SimpleArrayLayoutInfo : SimpleLayoutInfo
struct LayoutRulesImpl;
+// Base class for things that store layout info
+class Layout : public RefObject
+{
+};
+
// A reified reprsentation of a particular laid-out type
-class TypeLayout : public RefObject
+class TypeLayout : public Layout
{
public:
// The type that was laid out
@@ -215,7 +220,7 @@ enum VarLayoutFlag : VarLayoutFlags
};
// A reified layout for a particular variable, field, etc.
-class VarLayout : public RefObject
+class VarLayout : public Layout
{
public:
// The variable we are laying out
@@ -324,7 +329,14 @@ public:
// Layout information for a single shader entry point
// within a program
-class EntryPointLayout : public RefObject
+//
+// Treated as a subclass of `StructTypeLayout` becase
+// it needs to include computed layout information
+// for the parameters of the entry point.
+//
+// TODO: where to store layout info for the return
+// type of the function?
+class EntryPointLayout : public StructTypeLayout
{
public:
// The corresponding function declaration
@@ -332,10 +344,13 @@ public:
// The shader profile that was used to compile the entry point
Profile profile;
+
+ // Layout for any results of the entry point
+ RefPtr<VarLayout> resultLayout;
};
// Layout information for the global scope of a program
-class ProgramLayout : public RefObject
+class ProgramLayout : public Layout
{
public:
// We store a layout for the declarations at the global
@@ -359,14 +374,6 @@ public:
List<RefPtr<EntryPointLayout>> entryPoints;
};
-// A modifier to be attached to syntax after we've computed layout
-class ComputedLayoutModifier : public Modifier
-{
-public:
- RefPtr<TypeLayout> typeLayout;
-};
-
-
struct LayoutRulesFamilyImpl;
// A delineation of shader parameter types into fine-grained
diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h
index 316981842..9ebfc9872 100644
--- a/source/slang/val-defs.h
+++ b/source/slang/val-defs.h
@@ -3,7 +3,8 @@
// Syntax class definitions for compile-time values.
// A compile-time integer (may not have a specific concrete value)
-SIMPLE_SYNTAX_CLASS(IntVal, Val)
+ABSTRACT_SYNTAX_CLASS(IntVal, Val)
+END_SYNTAX_CLASS()
// Trivial case of a value that is just a constant integer
SYNTAX_CLASS(ConstantIntVal, IntVal)
diff --git a/tests/render/cross-compile-entry-point.slang b/tests/render/cross-compile-entry-point.slang
new file mode 100644
index 000000000..018947228
--- /dev/null
+++ b/tests/render/cross-compile-entry-point.slang
@@ -0,0 +1,89 @@
+//TEST(render):COMPARE_HLSL_CROSS_COMPILE_RENDER:
+
+// This is a test to ensure that we can cross-compile a complete entry point.
+
+float3 transformColor(float3 color)
+{
+ float3 result;
+
+ result.x = sin(20.0 * (color.x + color.y));
+ result.y = saturate(cos(color.z * 30.0));
+ result.z = sin(color.x * color.y * color.z * 100.0);
+
+ result = 0.5 * (result + 1);
+
+ return result;
+}
+
+cbuffer Uniforms
+{
+ float4x4 modelViewProjection;
+};
+
+struct AssembledVertex
+{
+ float3 position;
+ float3 color;
+};
+
+struct CoarseVertex
+{
+ float3 color;
+};
+
+struct Fragment
+{
+ float4 color;
+};
+
+// Vertex Shader
+
+struct VertexStageInput
+{
+ AssembledVertex assembledVertex : A;
+};
+
+struct VertexStageOutput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+ float4 sv_position : SV_Position;
+};
+
+VertexStageOutput vertexMain(VertexStageInput input)
+{
+ VertexStageOutput output;
+
+ float3 position = input.assembledVertex.position;
+ float3 color = input.assembledVertex.color;
+
+ output.coarseVertex.color = color;
+ output.sv_position = mul(modelViewProjection, float4(position, 1.0));
+
+ return output;
+
+}
+
+// Fragment Shader
+
+struct FragmentStageInput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+};
+
+struct FragmentStageOutput
+{
+ Fragment fragment : SV_Target;
+};
+
+FragmentStageOutput fragmentMain(FragmentStageInput input)
+{
+ FragmentStageOutput output;
+
+ float3 color = input.coarseVertex.color;
+
+ color = transformColor(color);
+
+ output.fragment.color = float4(color, 1.0);
+
+ return output;
+}
diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp
index 19795d685..5b25714c0 100644
--- a/tools/render-test/render-d3d11.cpp
+++ b/tools/render-test/render-d3d11.cpp
@@ -889,10 +889,10 @@ public:
virtual ShaderProgram* compileProgram(ShaderCompileRequest const& request) override
{
- auto dxVertexShaderBlob = compileHLSLShader(request.source.path, request.source.text, request.vertexShader .name, request.vertexShader .profile);
+ auto dxVertexShaderBlob = compileHLSLShader(request.vertexShader.source.path, request.vertexShader.source.text, request.vertexShader .name, request.vertexShader .profile);
if(!dxVertexShaderBlob) return nullptr;
- auto dxFragmentShaderBlob = compileHLSLShader(request.source.path, request.source.text, request.fragmentShader .name, request.fragmentShader .profile);
+ auto dxFragmentShaderBlob = compileHLSLShader(request.fragmentShader.source.path, request.fragmentShader.source.text, request.fragmentShader .name, request.fragmentShader .profile);
if(!dxFragmentShaderBlob) return nullptr;
ID3D11VertexShader* dxVertexShader;