diff options
| author | Tim Foley <tfoley@nvidia.com> | 2017-07-18 07:49:33 -0700 |
|---|---|---|
| committer | Tim Foley <tfoley@nvidia.com> | 2017-07-18 12:58:48 -0700 |
| commit | 1c022e2c3654de868c45658683f9e04cf4d68cc0 (patch) | |
| tree | d4a5f0cefd50c96aaf22921f9fef715b6359c0c5 | |
| parent | 361e29572ff8e2cdd1e4ffe2cb62599e9ef06461 (diff) | |
Support scalarization of varying input/output for GLSL
GLSL technically supports varying (`in`, `out`) parameters of `struct` type, but there are some annoying constraints (not allowed for VS input), and it doesn't work with how an HLSL user would usually put "system-value" inputs/outputs into a `struct` together with ordinary inputs/outputs.
To work around this, this change adds support for using an imported Slang `struct` type for an `in` or `out` parameter, in which case it will (1) be scalarized and (2) will have system-value semantics mapped appropriately, just as for an entry-point parameter when cross-compiling an HLSL-style `main()`.
Changes:
- Add a notion of a `VaryingTupleExpr` and `VaryingTupleVarDecl`, similar to those for the resources-in-structs case
- Trigger use of these when we have a global-scope varying in/out using an imported `struct` type
- Also use these in the cross-compilation case for ordinary varying input/output (since this approach seems like it should be more general, and can hopefully handle stuff like GS input/output some day)
- When generating parameter binding information, special case global-scope input/output, and treat it the same as entry-point-parameter input/output
- Revamp how used resource ranges are computed so that we can eventually make this specific to an entry point
- Actually implement first signs of life for `maybeMoveTemp` so that assignments to the tuple-ified outputs will work better
- Add first test case that actually seems to work
- Add diagnostics for conflicting explicit bindings on a parameter
- Add diagnostic for different parameters with overlapping bindings
- Make global-scope varying input/output use a tracking data structure specific to the translation unit for computing locations (so that they are independent of other TUs)
| -rw-r--r-- | source/slang/check.cpp | 6 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 55 | ||||
| -rw-r--r-- | source/slang/expr-defs.h | 14 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 633 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 417 | ||||
| -rw-r--r-- | tests/reflection/thread-group-size.hlsl.expected | 4 | ||||
| -rw-r--r-- | tests/rewriter/varying-struct.slang | 21 | ||||
| -rw-r--r-- | tests/rewriter/varying-struct.vert | 54 |
9 files changed, 1021 insertions, 186 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index b175b7f86..1c1c6d5d3 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -4801,6 +4801,12 @@ namespace Slang return expr; } + RefPtr<ExpressionSyntaxNode> visitAggTypeCtorExpr(AggTypeCtorExpr* expr) + { + assert(!"unexpected"); + return expr; + } + // // // diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index c85270e08..17291d1ad 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -287,6 +287,9 @@ DIAGNOSTIC(39999, Error, tooManyArguments, "too many arguments to call (got $0, DIAGNOSTIC(39999, Error, invalidIntegerLiteralSuffix, "invalid suffix '$0' on integer literal") DIAGNOSTIC(39999, Error, invalidFloatingPOintLiteralSuffix, "invalid suffix '$0' on floating-point literal") +DIAGNOSTIC(39999, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") +DIAGNOSTIC(39999, Error, parameterBindingsOverlap, "explicit parameter bindings overlap for parameters '$0' and '$1'") + // // 4xxxx - IL code generation. // diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index ef922c418..76dc9c75f 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -1991,7 +1991,25 @@ struct EmitVisitor emitSimpleCallExpr(callExpr, outerPrec); } + void visitAggTypeCtorExpr(AggTypeCtorExpr* expr, ExprEmitArg const& arg) + { + auto prec = kEOp_Postfix; + auto outerPrec = arg.outerPrec; + bool needClose = MaybeEmitParens(outerPrec, prec); + emitTypeExp(expr->base); + Emit("("); + bool first = true; + for (auto aa : expr->Arguments) + { + if (!first) Emit(", "); + EmitExpr(aa); + first = false; + } + Emit(")"); + + if(needClose) Emit(")"); + } void visitMemberExpressionSyntaxNode(MemberExpressionSyntaxNode* memberExpr, ExprEmitArg const& arg) { @@ -3245,17 +3263,23 @@ struct EmitVisitor auto declRefType = dataType->As<DeclRefType>(); assert(declRefType); - // We expect to always have layout information - assert(layout); + // We expect the layout, if present, to be for a structured type... + RefPtr<StructTypeLayout> structTypeLayout; + if (layout) + { - // We expect the layout to be for a structured type... - RefPtr<ParameterBlockTypeLayout> bufferLayout = layout->typeLayout.As<ParameterBlockTypeLayout>(); - assert(bufferLayout); + auto typeLayout = layout->typeLayout; + if (auto bufferLayout = typeLayout.As<ParameterBlockTypeLayout>()) + { + typeLayout = bufferLayout->elementTypeLayout; + } - RefPtr<StructTypeLayout> structTypeLayout = bufferLayout->elementTypeLayout.As<StructTypeLayout>(); - assert(structTypeLayout); + structTypeLayout = typeLayout.As<StructTypeLayout>(); + assert(structTypeLayout); + + emitGLSLLayoutQualifiers(layout); + } - emitGLSLLayoutQualifiers(layout); EmitModifiers(varDecl); @@ -3293,13 +3317,16 @@ struct EmitVisitor { for (auto field : getMembersOfType<StructField>(structRef)) { - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - // assert(fieldLayout); + if (structTypeLayout) + { + RefPtr<VarLayout> fieldLayout; + structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); + // assert(fieldLayout); - // TODO(tfoley): We may want to emit *some* of these, - // some of the time... - // emitGLSLLayoutQualifiers(fieldLayout); + // TODO(tfoley): We may want to emit *some* of these, + // some of the time... + // emitGLSLLayoutQualifiers(fieldLayout); + } EmitVarDeclCommon(field); diff --git a/source/slang/expr-defs.h b/source/slang/expr-defs.h index 0dac324b9..dc93407e0 100644 --- a/source/slang/expr-defs.h +++ b/source/slang/expr-defs.h @@ -56,11 +56,21 @@ SYNTAX_CLASS(InitializerListExpr, ExpressionSyntaxNode) SYNTAX_FIELD(List<RefPtr<ExpressionSyntaxNode>>, args) END_SYNTAX_CLASS() +// A base class for expressions with arguments +ABSTRACT_SYNTAX_CLASS(ExprWithArgsBase, ExpressionSyntaxNode) + SYNTAX_FIELD(List<RefPtr<ExpressionSyntaxNode>>, Arguments) +END_SYNTAX_CLASS() + +// An aggregate type constructor +SYNTAX_CLASS(AggTypeCtorExpr, ExprWithArgsBase) + SYNTAX_FIELD(TypeExp, base); +END_SYNTAX_CLASS() + + // A base expression being applied to arguments: covers // both ordinary `()` function calls and `<>` generic application -ABSTRACT_SYNTAX_CLASS(AppExprBase, ExpressionSyntaxNode) +ABSTRACT_SYNTAX_CLASS(AppExprBase, ExprWithArgsBase) SYNTAX_FIELD(RefPtr<ExpressionSyntaxNode>, FunctionExpr) - SYNTAX_FIELD(List<RefPtr<ExpressionSyntaxNode>>, Arguments) END_SYNTAX_CLASS() SIMPLE_SYNTAX_CLASS(InvokeExpressionSyntaxNode, AppExprBase) diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index b6a19ab34..bbbfe724b 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -219,6 +219,35 @@ public: List<Element> tupleElements; }; +// Pseudo-syntax used during lowering +class VaryingTupleVarDecl : public VarDeclBase +{ +public: + virtual void accept(IDeclVisitor *, void *) override + { + throw "unexpected"; + } +}; + +// Pseudo-syntax used during lowering: +// represents an ordered list of expressions as a single unit +class VaryingTupleExpr : public ExpressionSyntaxNode +{ +public: + virtual void accept(IExprVisitor *, void *) override + { + throw "unexpected"; + } + + struct Element + { + DeclRef<VarDeclBase> originalFieldDeclRef; + RefPtr<ExpressionSyntaxNode> expr; + }; + + List<Element> elements; +}; + struct SharedLoweringContext { CompileRequest* compileRequest; @@ -242,6 +271,13 @@ struct SharedLoweringContext Dictionary<Decl*, RefPtr<Decl>> loweredDecls; Dictionary<Decl*, Decl*> mapLoweredDeclToOriginal; + // Work to be done at the very start and end of the entry point + RefPtr<StatementSyntaxNode> entryPointInitializeStmt; + RefPtr<StatementSyntaxNode> entryPointFinalizeStmt; + + // Counter used for generating unique temporary names + int nameCounter = 0; + bool isRewrite; }; @@ -544,13 +580,22 @@ struct LoweringVisitor // Expressions // - RefPtr<ExpressionSyntaxNode> lowerExpr( + RefPtr<ExpressionSyntaxNode> lowerExprOrTuple( ExpressionSyntaxNode* expr) { if (!expr) return nullptr; return ExprVisitor::dispatch(expr); } + RefPtr<ExpressionSyntaxNode> lowerExpr( + ExpressionSyntaxNode* expr) + { + if (!expr) return nullptr; + + auto result = lowerExprOrTuple(expr); + return maybeReifyTuple(result); + } + // catch-all RefPtr<ExpressionSyntaxNode> visitExpressionSyntaxNode( ExpressionSyntaxNode* expr) @@ -571,6 +616,14 @@ struct LoweringVisitor loweredExpr->Type.type = lowerType(expr->Type.type); } + RefPtr<ExpressionSyntaxNode> createUncheckedVarRef( + char const* name) + { + RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode(); + result->name = name; + return result; + } + RefPtr<ExpressionSyntaxNode> createVarRef( CodePosition const& loc, VarDeclBase* decl) @@ -579,12 +632,17 @@ struct LoweringVisitor { return createTupleRef(loc, tupleDecl); } + else if (auto varyingTupleDecl = dynamic_cast<VaryingTupleVarDecl*>(decl)) + { + return createVaryingTupleRef(loc, varyingTupleDecl); + } else { RefPtr<VarExpressionSyntaxNode> result = new VarExpressionSyntaxNode(); result->Position = loc; result->Type.type = decl->Type.type; result->declRef = makeDeclRef(decl); + result->name = decl->getName(); return result; } } @@ -616,7 +674,14 @@ struct LoweringVisitor result->tupleElements.Add(elem); } -return result; + return result; + } + + RefPtr<ExpressionSyntaxNode> createVaryingTupleRef( + CodePosition const& loc, + VaryingTupleVarDecl* decl) + { + return decl->Expr; } RefPtr<ExpressionSyntaxNode> visitVarExpressionSyntaxNode( @@ -638,13 +703,85 @@ return result; return createTupleRef(expr->Position, tupleVarDecl); } + else if (auto varyingTupleVarDecl = dynamic_cast<VaryingTupleVarDecl*>(loweredDecl)) + { + return createVaryingTupleRef(expr->Position, varyingTupleVarDecl); + } RefPtr<VarExpressionSyntaxNode> loweredExpr = new VarExpressionSyntaxNode(); lowerExprCommon(loweredExpr, expr); loweredExpr->declRef = loweredDeclRef; + loweredExpr->name = expr->name; return loweredExpr; } + String generateName() + { + int id = shared->nameCounter++; + + String result; + result.append("SLANG_tmp_"); + result.append(id); + return result; + } + + // The idea of this function is to take an expression that we plan to + // use/evaluate more than once, and if needed replace it with a + // reference to a temporary (initialized with the expr) so that it + // can safely be re-evaluated. + RefPtr<ExpressionSyntaxNode> maybeMoveTemp( + RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: actually implement this properly! + + // Certain expressions are already in a form we can directly re-use, + // so there is no reason to move them. + if (expr.As<VarExpressionSyntaxNode>()) + return expr; + if (expr.As<ConstantExpressionSyntaxNode>()) + return expr; + + if (auto varyingTupleExpr = expr.As<VaryingTupleExpr>()) + { + RefPtr<VaryingTupleExpr> resultExpr = new VaryingTupleExpr(); + resultExpr->Position = expr->Position; + resultExpr->Type = expr->Type; + for (auto ee : varyingTupleExpr->elements) + { + VaryingTupleExpr::Element elem; + elem.originalFieldDeclRef = ee.originalFieldDeclRef; + elem.expr = maybeMoveTemp(ee.expr); + + resultExpr->elements.Add(elem); + } + + return resultExpr; + } + + // TODO: handle the tuple cases here... + + // In the general case, though, we need to introduce a temporary + RefPtr<Variable> varDecl = new Variable(); + varDecl->Name.Content = generateName(); + varDecl->Type.type = expr->Type.type; + varDecl->Expr = expr; + + addDecl(varDecl); + + return createVarRef(expr->Position, varDecl); + } + + // Similar to the above, this ensures that an l-value expression + // is safe to re-evaluate, by recursively moving things off + // to temporaries where needed. + RefPtr<ExpressionSyntaxNode> ensureSimpleLValue( + RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: actually implement this properly! + + return expr; + } + RefPtr<ExpressionSyntaxNode> createAssignExpr( RefPtr<ExpressionSyntaxNode> leftExpr, RefPtr<ExpressionSyntaxNode> rightExpr) @@ -689,6 +826,133 @@ return result; assert(!leftTuple && !rightTuple); } + auto leftVaryingTuple = leftExpr.As<VaryingTupleExpr>(); + auto rightVaryingTuple = rightExpr.As<VaryingTupleExpr>(); + if (leftVaryingTuple && rightVaryingTuple) + { + RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); + resultTuple->Type.type = lowerType(leftExpr->Type.type); + resultTuple->Position = leftExpr->Position; + + assert(resultTuple->Type.type); + + UInt elementCount = leftVaryingTuple->elements.Count(); + assert(elementCount == rightVaryingTuple->elements.Count()); + + for (UInt ee = 0; ee < elementCount; ++ee) + { + auto leftElem = leftVaryingTuple->elements[ee]; + auto rightElem = rightVaryingTuple->elements[ee]; + + VaryingTupleExpr::Element elem; + elem.originalFieldDeclRef = leftElem.originalFieldDeclRef; + elem.expr = createAssignExpr( + leftElem.expr, + rightElem.expr); + } + } + else if (leftVaryingTuple) + { + // Assigning from ordinary expression on RHS to tuple. + // This will naturally yield a tuple expression. + // + // TODO: need to be careful about side-effects, or + // about dropping sub-expressions after the assignment. + // For now this will really only work directly in + // a statement context. + + UInt elementCount = leftVaryingTuple->elements.Count(); + + // Move everything into temps if we can + + rightExpr = maybeMoveTemp(rightExpr); + for (UInt ee = 0; ee < elementCount; ++ee) + { + auto& leftElem = leftVaryingTuple->elements[ee]; + leftElem.expr = ensureSimpleLValue(leftElem.expr); + } + + // We need to combine the sub-expressions into a giant sequence expression. + // + // We will procede through thigns from last to first, to build a bunch + // of "operator comma" expressions bottom-up. + RefPtr<ExpressionSyntaxNode> resultExpr = leftExpr; + + for (UInt ee = 0; ee < elementCount; ++ee) + { + auto leftElem = leftVaryingTuple->elements[elementCount - ee - 1]; + + RefPtr<MemberExpressionSyntaxNode> rightElemExpr = new MemberExpressionSyntaxNode(); + rightElemExpr->Position = rightExpr->Position; + rightElemExpr->Type = leftElem.expr->Type; + rightElemExpr->declRef = leftElem.originalFieldDeclRef; + rightElemExpr->name = leftElem.originalFieldDeclRef.GetName(); + rightElemExpr->BaseExpression = rightExpr; + + auto subExpr = createAssignExpr( + leftElem.expr, + rightElemExpr); + + RefPtr<InfixExpr> seqExpr = new InfixExpr(); + seqExpr->FunctionExpr = createUncheckedVarRef(","); + seqExpr->Arguments.Add(subExpr); + seqExpr->Arguments.Add(resultExpr); + + resultExpr = seqExpr; + } + + return resultExpr; + } + else if (rightVaryingTuple) + { + // Pretty much the same as the above case, and we should + // probably try to share code eventually. + + UInt elementCount = rightVaryingTuple->elements.Count(); + + // Move everything into temps if we can + + leftExpr = ensureSimpleLValue(leftExpr); + for (UInt ee = 0; ee < elementCount; ++ee) + { + auto& rightElem = rightVaryingTuple->elements[ee]; + rightElem.expr = maybeMoveTemp(rightElem.expr); + } + + // We need to combine the sub-expressions into a giant sequence expression. + // + // We will procede through thigns from last to first, to build a bunch + // of "operator comma" expressions bottom-up. + RefPtr<ExpressionSyntaxNode> resultExpr = leftExpr; + + for (UInt ee = 0; ee < elementCount; ++ee) + { + auto rightElem = rightVaryingTuple->elements[elementCount - ee - 1]; + + RefPtr<MemberExpressionSyntaxNode> leftElemExpr = new MemberExpressionSyntaxNode(); + leftElemExpr->Position = leftExpr->Position; + leftElemExpr->Type = rightElem.expr->Type; + leftElemExpr->declRef = rightElem.originalFieldDeclRef; + leftElemExpr->name = rightElem.originalFieldDeclRef.GetName(); + leftElemExpr->BaseExpression = leftExpr; + + auto subExpr = createAssignExpr( + leftElemExpr, + rightElem.expr); + + RefPtr<InfixExpr> seqExpr = new InfixExpr(); + seqExpr->FunctionExpr = createUncheckedVarRef(","); + seqExpr->Arguments.Add(subExpr); + seqExpr->Arguments.Add(resultExpr); + + resultExpr = seqExpr; + } + + return resultExpr; + } + + // Default case: no tuples of any kind... + RefPtr<AssignExpr> loweredExpr = new AssignExpr(); loweredExpr->Type = leftExpr->Type; @@ -700,8 +964,8 @@ return result; RefPtr<ExpressionSyntaxNode> visitAssignExpr( AssignExpr* expr) { - auto leftExpr = lowerExpr(expr->left); - auto rightExpr = lowerExpr(expr->right); + auto leftExpr = lowerExprOrTuple(expr->left); + auto rightExpr = lowerExprOrTuple(expr->right); auto loweredExpr = createAssignExpr(leftExpr, rightExpr); lowerExprCommon(loweredExpr, expr); @@ -728,6 +992,8 @@ return result; if (auto baseTuple = baseExpr.As<TupleExpr>()) { + indexExpr = maybeMoveTemp(indexExpr); + auto loweredExpr = new TupleExpr(); loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type); @@ -750,6 +1016,26 @@ return result; return loweredExpr; } + else if (auto baseVaryingTuple = baseExpr.As<VaryingTupleExpr>()) + { + indexExpr = maybeMoveTemp(indexExpr); + + auto loweredExpr = new VaryingTupleExpr(); + loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type); + + assert(loweredExpr->Type.type); + + for (auto elem : baseVaryingTuple->elements) + { + VaryingTupleExpr::Element loweredElem; + loweredElem.originalFieldDeclRef = elem.originalFieldDeclRef; + loweredElem.expr = createSubscriptExpr( + elem.expr, + indexExpr); + } + + return loweredExpr; + } else { // Default case: just reconstrut a subscript expr @@ -766,7 +1052,7 @@ return result; RefPtr<ExpressionSyntaxNode> visitIndexExpressionSyntaxNode( IndexExpressionSyntaxNode* subscriptExpr) { - auto baseExpr = lowerExpr(subscriptExpr->BaseExpression); + auto baseExpr = lowerExprOrTuple(subscriptExpr->BaseExpression); auto indexExpr = lowerExpr(subscriptExpr->IndexExpression); // An attempt to subscript a tuple must be turned into a @@ -775,6 +1061,10 @@ return result; { return createSubscriptExpr(baseExpr, indexExpr); } + else if (auto baseVaryingTuple = baseExpr.As<VaryingTupleExpr>()) + { + return createSubscriptExpr(baseExpr, indexExpr); + } else { // Default case: just reconstrut a subscript expr @@ -786,8 +1076,42 @@ return result; } } + RefPtr<ExpressionSyntaxNode> maybeReifyTuple( + RefPtr<ExpressionSyntaxNode> expr) + { + if (auto tupleExpr = expr.As<TupleExpr>()) + { + // TODO: need to diagnose + return expr; + } + else if (auto varyingTupleExpr = expr.As<VaryingTupleExpr>()) + { + // Need to pass an ordinary (non-tuple) expression of + // the corresponding type here. + + // TODO(tfoley): This won't work at all for an `out` or `inout` + // function argument, so we'll need to figure out a plan + // to handle that case... + + RefPtr<AggTypeCtorExpr> resultExpr = new AggTypeCtorExpr(); + resultExpr->Type = varyingTupleExpr->Type; + resultExpr->base.type = varyingTupleExpr->Type.type; + assert(resultExpr->Type.type); + + for (auto elem : varyingTupleExpr->elements) + { + addArgs(resultExpr, elem.expr); + } + + return resultExpr; + } + + // Default case: nothing special to this expression + return expr; + } + void addArgs( - InvokeExpressionSyntaxNode* callExpr, + ExprWithArgsBase* callExpr, RefPtr<ExpressionSyntaxNode> argExpr) { if (auto argTuple = argExpr.As<TupleExpr>()) @@ -801,6 +1125,13 @@ return result; addArgs(callExpr, elem.expr); } } + else if (auto varyingArgTuple = argExpr.As<VaryingTupleExpr>()) + { + // Need to pass an ordinary (non-tuple) expression of + // the corresponding type here. + + callExpr->Arguments.Add(maybeReifyTuple(argExpr)); + } else { callExpr->Arguments.Add(argExpr); @@ -817,7 +1148,7 @@ return result; for (auto arg : expr->Arguments) { - auto loweredArg = lowerExpr(arg); + auto loweredArg = lowerExprOrTuple(arg); addArgs(loweredExpr, loweredArg); } @@ -859,7 +1190,7 @@ return result; RefPtr<ExpressionSyntaxNode> visitDerefExpr( DerefExpr* expr) { - auto loweredBase = lowerExpr(expr->base); + auto loweredBase = lowerExprOrTuple(expr->base); if (auto baseTuple = loweredBase.As<TupleExpr>()) { @@ -882,7 +1213,7 @@ return result; RefPtr<ExpressionSyntaxNode> visitMemberExpressionSyntaxNode( MemberExpressionSyntaxNode* expr) { - auto loweredBase = lowerExpr(expr->BaseExpression); + auto loweredBase = lowerExprOrTuple(expr->BaseExpression); auto loweredDeclRef = translateDeclRef(expr->declRef); @@ -927,6 +1258,20 @@ return result; // simply fall through to the ordinary case below. loweredBase = baseTuple->primaryExpr; } + else if (auto baseVaryingTuple = loweredBase.As<VaryingTupleExpr>()) + { + // Search for the element corresponding to this field + for(auto elem : baseVaryingTuple->elements) + { + if (expr->declRef.getDecl() == elem.originalFieldDeclRef.getDecl()) + { + // We found the field! + return elem.expr; + } + } + + assert(!"unexpected"); + } // Default handling: assert(!dynamic_cast<TupleVarDecl*>(loweredDeclRef.getDecl())); @@ -1085,8 +1430,29 @@ return result; } return; } - - // TODO: could also desugar "operator comma" here + else if (auto varyingTupleExpr = expr.As<VaryingTupleExpr>()) + { + for (auto ee : varyingTupleExpr->elements) + { + addExprStmt(ee.expr); + } + return; + } + else if (auto infixExpr = expr.As<InfixExpr>()) + { + if (auto varExpr = infixExpr->FunctionExpr.As<VarExpressionSyntaxNode>()) + { + if (varExpr->name == ",") + { + // Call to "operator comma" + for (auto aa : infixExpr->Arguments) + { + addExprStmt(aa); + } + return; + } + } + } RefPtr<ExpressionStatementSyntaxNode> stmt = new ExpressionStatementSyntaxNode(); stmt->Expression = expr; @@ -1115,7 +1481,7 @@ return result; void visitExpressionStatementSyntaxNode(ExpressionStatementSyntaxNode* stmt) { - addExprStmt(lowerExpr(stmt->Expression)); + addExprStmt(lowerExprOrTuple(stmt->Expression)); } void visitVarDeclrStatementSyntaxNode(VarDeclrStatementSyntaxNode* stmt) @@ -1302,11 +1668,7 @@ return result; RefPtr<ExpressionSyntaxNode> destExpr, RefPtr<ExpressionSyntaxNode> srcExpr) { - RefPtr<AssignExpr> assignExpr = new AssignExpr(); - assignExpr->Position = destExpr->Position; - assignExpr->left = destExpr; - assignExpr->right = srcExpr; - + auto assignExpr = createAssignExpr(destExpr, srcExpr); addExprStmt(assignExpr); } @@ -1330,7 +1692,7 @@ return result; if (resultVariable) { // Do it as an assignment - assign(resultVariable, lowerExpr(stmt->Expression)); + assign(resultVariable, lowerExprOrTuple(stmt->Expression)); } else { @@ -2293,6 +2655,32 @@ return result; } } + bool isImportedStructType(RefPtr<ExpressionType> type) + { + if (type->As<BasicExpressionType>()) return false; + else if (type->As<VectorExpressionType>()) return false; + else if (type->As<MatrixExpressionType>()) return false; + else if (type->As<ResourceType>()) return false; + else if (type->As<BuiltinGenericType>()) return false; + else if (auto declRefType = type->As<DeclRefType>()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDecl>()) + { + Decl* pp = aggTypeDeclRef.getDecl(); + while (pp->ParentDecl) + pp = pp->ParentDecl; + + // Did the declaration come from this translation unit? + if (pp == shared->entryPointRequest->getTranslationUnit()->SyntaxNode.Ptr()) + return false; + + return true; + } + } + + return false; + } + RefPtr<VarDeclBase> visitVariable( Variable* decl) { @@ -2303,6 +2691,49 @@ return result; { doSampleRateInputCheck(decl); } + + auto varLayout = tryToFindLayout(decl); + if (varLayout) + { + auto inRes = varLayout->FindResourceInfo(LayoutResourceKind::VertexInput); + auto outRes = varLayout->FindResourceInfo(LayoutResourceKind::FragmentOutput); + + if( (inRes || outRes) && isImportedStructType(decl->Type.type)) + { + // We are seemingly looking at a GLSL global-scope varying + // of an aggregate type which was imported from library + // code. We should destructure that into individual + // declarations. + + // We can't easily support `in out` declarations with this approach + assert(!(inRes && outRes)); + + RefPtr<ExpressionSyntaxNode> loweredExpr; + if (inRes) + { + loweredExpr = lowerShaderParameterToGLSLGLobals( + decl, + varLayout, + VaryingParameterDirection::Input); + } + + if (outRes) + { + loweredExpr = lowerShaderParameterToGLSLGLobals( + decl, + varLayout, + VaryingParameterDirection::Output); + } + + assert(loweredExpr); + auto loweredDecl = createVaryingTupleVarDecl( + decl, + loweredExpr); + + registerLoweredDecl(loweredDecl, decl); + return loweredDecl; + } + } } auto loweredDecl = lowerVarDeclCommon(decl, getClass<Variable>()); @@ -2511,11 +2942,10 @@ return result; } } - void lowerSimpleShaderParameterToGLSLGlobal( + RefPtr<ExpressionSyntaxNode> lowerSimpleShaderParameterToGLSLGlobal( VaryingParameterInfo const& info, RefPtr<ExpressionType> varType, - RefPtr<VarLayout> varLayout, - RefPtr<ExpressionSyntaxNode> varExpr) + RefPtr<VarLayout> varLayout) { RefPtr<ExpressionType> type = varType; @@ -2720,8 +3150,16 @@ return result; // Otherwise, check if we need to add one: else if (isIntegralType(varType)) { - auto mod = new HLSLNoInterpolationModifier(); - addModifier(globalVarDecl, mod); + if (info.direction == VaryingParameterDirection::Input + && shared->entryPointRequest->profile.GetStage() == Stage::Vertex) + { + // Don't add extra qualification to VS inputs + } + else + { + auto mod = new HLSLNoInterpolationModifier(); + addModifier(globalVarDecl, mod); + } } @@ -2733,33 +3171,13 @@ return result; globalVarExpr = globalVarRef; } - // TODO: if we are declaring an SOA-ized array, - // this is where those array dimensions would need - // to be tacked on. - // - // That is, this logic should be getting collected into a loop, - // and so we need to have a loop variable we can use to - // index into the two different expressions. - - - // Need to generate an assignment in the right direction. - switch (info.direction) - { - case VaryingParameterDirection::Input: - assign(varExpr, globalVarExpr); - break; - - case VaryingParameterDirection::Output: - assign(globalVarExpr, varExpr); - break; - } + return globalVarExpr; } - void lowerShaderParameterToGLSLGLobalsRec( + RefPtr<ExpressionSyntaxNode> lowerShaderParameterToGLSLGLobalsRec( VaryingParameterInfo const& info, RefPtr<ExpressionType> varType, - RefPtr<VarLayout> varLayout, - RefPtr<ExpressionSyntaxNode> varExpr) + RefPtr<VarLayout> varLayout) { assert(varLayout); @@ -2789,25 +3207,16 @@ return result; VaryingParameterInfo arrayInfo = info; arrayInfo.arraySpecs = &arraySpec; - RefPtr<IndexExpressionSyntaxNode> subscriptExpr = new IndexExpressionSyntaxNode(); - subscriptExpr->Position = varExpr->Position; - subscriptExpr->BaseExpression = varExpr; - // Note that we use the original `varLayout` that was passed in, // since that is the layout that will ultimately need to be // used on the array elements. // // TODO: That won't actually work if we ever had an array of // heterogeneous stuff... - lowerShaderParameterToGLSLGLobalsRec( + return lowerShaderParameterToGLSLGLobalsRec( arrayInfo, arrayType->BaseType, - varLayout, - subscriptExpr); - - // TODO: we need to construct syntax for a loop to initialize - // the array here... - throw "unimplemented"; + varLayout); } else if (auto declRefType = varType->As<DeclRefType>()) { @@ -2817,18 +3226,17 @@ return result; // The shader parameter had a structured type, so we need // to destructure it into its constituent fields + RefPtr<VaryingTupleExpr> tupleExpr = new VaryingTupleExpr(); + tupleExpr->Type.type = varType; + + assert(tupleExpr->Type.type); + for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) { // Don't emit storage for `static` fields here, of course if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) continue; - RefPtr<MemberExpressionSyntaxNode> fieldExpr = new MemberExpressionSyntaxNode(); - fieldExpr->Position = varExpr->Position; - fieldExpr->Type.type = GetType(fieldDeclRef); - fieldExpr->declRef = fieldDeclRef; - fieldExpr->BaseExpression = varExpr; - VaryingParameterVarChain fieldVarChain; fieldVarChain.next = info.varChain; fieldVarChain.varDecl = fieldDeclRef.getDecl(); @@ -2849,38 +3257,38 @@ return result; structTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); assert(fieldLayout); - lowerShaderParameterToGLSLGLobalsRec( + auto loweredFieldExpr = lowerShaderParameterToGLSLGLobalsRec( fieldInfo, GetType(fieldDeclRef), - fieldLayout, - fieldExpr); + fieldLayout); + + VaryingTupleExpr::Element elem; + elem.originalFieldDeclRef = makeDeclRef(originalFieldDecl).As<VarDeclBase>(); + elem.expr = loweredFieldExpr; + + tupleExpr->elements.Add(elem); } // Okay, we are done with this parameter - return; + return tupleExpr; } } // Default case: just try to emit things as-is - lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout, varExpr); + return lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout); } - void lowerShaderParameterToGLSLGLobals( - RefPtr<Variable> localVarDecl, + RefPtr<ExpressionSyntaxNode> lowerShaderParameterToGLSLGLobals( + RefPtr<VarDeclBase> originalVarDecl, RefPtr<VarLayout> paramLayout, VaryingParameterDirection direction) { - auto name = localVarDecl->getName(); - auto declRef = makeDeclRef(localVarDecl.Ptr()); - - RefPtr<VarExpressionSyntaxNode> expr = new VarExpressionSyntaxNode(); - expr->name = name; - expr->declRef = declRef; - expr->Type.type = GetType(declRef); + auto name = originalVarDecl->getName(); + auto declRef = makeDeclRef(originalVarDecl.Ptr()); VaryingParameterVarChain varChain; varChain.next = nullptr; - varChain.varDecl = localVarDecl; + varChain.varDecl = originalVarDecl; VaryingParameterInfo info; info.name = name; @@ -2899,11 +3307,45 @@ return result; break; } - lowerShaderParameterToGLSLGLobalsRec( + auto loweredType = lowerType(originalVarDecl->Type); + + auto loweredExpr = lowerShaderParameterToGLSLGLobalsRec( info, - localVarDecl->getType(), - paramLayout, - expr); + loweredType.type, + paramLayout); + +#if 0 + RefPtr<VaryingTupleVarDecl> loweredDecl = createVaryingTupleVarDecl( + originalVarDecl, + loweredType, + loweredExpr); + + registerLoweredDecl(loweredDecl, originalVarDecl); + addDecl(loweredDecl); +#endif + + return loweredExpr; + } + + RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( + RefPtr<VarDeclBase> originalVarDecl, + TypeExp const& loweredType, + RefPtr<ExpressionSyntaxNode> loweredExpr) + { + RefPtr<VaryingTupleVarDecl> loweredDecl = new VaryingTupleVarDecl(); + loweredDecl->Name = originalVarDecl->Name; + loweredDecl->Type = loweredType; + loweredDecl->Expr = loweredExpr; + + return loweredDecl; + } + + RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( + RefPtr<VarDeclBase> originalVarDecl, + RefPtr<ExpressionSyntaxNode> loweredExpr) + { + auto loweredType = lowerType(originalVarDecl->Type); + return createVaryingTupleVarDecl(originalVarDecl, loweredType, loweredExpr); } struct EntryPointParamPair @@ -2976,10 +3418,12 @@ return result; || paramDecl->HasModifier<InOutModifier>() || !paramDecl->HasModifier<OutModifier>()) { - subVisitor.lowerShaderParameterToGLSLGLobals( - paramPair.lowered, + auto loweredExpr = subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.original, paramPair.layout, VaryingParameterDirection::Input); + + subVisitor.assign(paramPair.lowered, loweredExpr); } } @@ -3044,18 +3488,27 @@ return result; if (paramDecl->HasModifier<OutModifier>() || paramDecl->HasModifier<InOutModifier>()) { - subVisitor.lowerShaderParameterToGLSLGLobals( - paramPair.lowered, + auto loweredExpr = subVisitor.lowerShaderParameterToGLSLGLobals( + paramPair.original, paramPair.layout, VaryingParameterDirection::Output); + + subVisitor.assign(loweredExpr, paramPair.lowered); } } if (resultVarDecl) { - subVisitor.lowerShaderParameterToGLSLGLobals( - resultVarDecl, - entryPointLayout->resultLayout, - VaryingParameterDirection::Output); + VaryingParameterInfo info; + info.name = "SLANG_out_" + resultVarDecl->getName(); + info.direction = VaryingParameterDirection::Output; + info.varChain = nullptr; + + auto loweredExpr = lowerShaderParameterToGLSLGLobalsRec( + info, + resultVarDecl->Type.type, + entryPointLayout->resultLayout); + + subVisitor.assign(loweredExpr, resultVarDecl); } bodyStmt->body = subVisitor.stmtBeingBuilt; diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index ca135f1ba..15025f1e4 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -11,9 +11,15 @@ namespace Slang { +struct ParameterInfo; + // Information on ranges of registers already claimed/used struct UsedRange { + // What parameter has claimed this range? + ParameterInfo* parameter = nullptr; + + // Begin/end of the range (half-open interval) UInt begin; UInt end; }; @@ -26,42 +32,78 @@ bool operator<(UsedRange left, UsedRange right) return false; } +static bool rangesOverlap(UsedRange const& x, UsedRange const& y) +{ + assert(x.begin <= x.end); + assert(y.begin <= y.end); + + // If they don't overlap, then one must be earlier than the other, + // and that one must therefore *end* before the other *begins* + + if (x.end <= y.begin) return false; + if (y.end <= x.begin) return false; + + // Otherwise they must overlap + return true; +} + + struct UsedRanges { List<UsedRange> ranges; // Add a range to the set, either by extending // an existing range, or by adding a new one... - void Add(UsedRange const& range) + // + // If we find that the new range overlaps with + // an existing range for a *different* parameter + // then we return that parameter so that the + // caller can issue an error. + ParameterInfo* Add(UsedRange const& range) { + ParameterInfo* newParam = range.parameter; + ParameterInfo* existingParam = nullptr; + for (auto& rr : ranges) + { + if (rangesOverlap(rr, range) + && rr.parameter + && rr.parameter != newParam) + { + // there was an overlap! + existingParam = rr.parameter; + } + } + for (auto& rr : ranges) { if (rr.begin == range.end) { rr.begin = range.begin; - return; + return existingParam; } else if (rr.end == range.begin) { rr.end = range.end; - return; + return existingParam; } } ranges.Add(range); ranges.Sort(); + return existingParam; } - void Add(UInt begin, UInt end) + ParameterInfo* Add(ParameterInfo* param, UInt begin, UInt end) { UsedRange range; + range.parameter = param; range.begin = begin; range.end = end; - Add(range); + return Add(range); } // Try to find space for `count` entries - UInt Allocate(UInt count) + UInt Allocate(ParameterInfo* param, UInt count) { UInt begin = 0; @@ -76,7 +118,7 @@ struct UsedRanges if (end >= begin + count) { // ... then claim it and be done - Add(begin, begin + count); + Add(param, begin, begin + count); return begin; } @@ -87,7 +129,7 @@ struct UsedRanges // We've run out of ranges to check, so we // can safely go after the last one! - Add(begin, begin + count); + Add(param, begin, begin + count); return begin; } }; @@ -104,6 +146,13 @@ enum kLayoutResourceKindCount = SLANG_PARAMETER_CATEGORY_COUNT, }; +struct UsedRangeSet : RefObject +{ + // Information on what ranges of "registers" have already + // been claimed, for each resource type + UsedRanges usedResourceRanges[kLayoutResourceKindCount]; +}; + // Information on a single parameter struct ParameterInfo : RefObject { @@ -115,6 +164,9 @@ struct ParameterInfo : RefObject // The next parameter that has the same name... ParameterInfo* nextOfSameName; + // The translation unit this parameter is specific to, if any + TranslationUnitRequest* translationUnit = nullptr; + ParameterInfo() { // Make sure we aren't claiming any resources yet @@ -125,10 +177,19 @@ struct ParameterInfo : RefObject } }; +struct EntryPointParameterBindingContext +{ + // What ranges of resources bindings are already claimed for this translation unit + UsedRangeSet usedRangeSet; +}; + // State that is shared during parameter binding, // across all translation units struct SharedParameterBindingContext { + // The base compile request + CompileRequest* compileRequest; + LayoutRulesFamilyImpl* defaultLayoutRules; // All shader parameters we've discovered so far, and started to lay out... @@ -137,15 +198,29 @@ struct SharedParameterBindingContext // The program layout we are trying to construct RefPtr<ProgramLayout> programLayout; - // Information on what ranges of "registers" have already - // been claimed, for each resource type - UsedRanges usedResourceRanges[kLayoutResourceKindCount]; + // What ranges of resources bindings are already claimed at the global scope? + // We store one of these for each declared binding space/set. + // + Dictionary<UInt, RefPtr<UsedRangeSet>> globalSpaceUsedRangeSets; + + // What ranges of resource bindings are claimed for particular translation unit? + // This is only used for varying input/output. + // + Dictionary<TranslationUnitRequest*, RefPtr<UsedRangeSet>> translationUnitUsedRangeSets; }; +static DiagnosticSink* getSink(SharedParameterBindingContext* shared) +{ + return &shared->compileRequest->mSink; +} + // State that might be specific to a single translation unit // or event to an entry point. struct ParameterBindingContext { + // The translation unit we are processing right now + TranslationUnitRequest* translationUnit; + // All the shared state needs to be available SharedParameterBindingContext* shared; @@ -162,6 +237,12 @@ struct ParameterBindingContext SourceLanguage sourceLanguage; }; +static DiagnosticSink* getSink(ParameterBindingContext* context) +{ + return getSink(context->shared); +} + + struct LayoutSemanticInfo { LayoutResourceKind kind; // the register kind @@ -298,37 +379,19 @@ static bool findLayoutArg( // -RefPtr<TypeLayout> -getTypeLayoutForGlobalShaderParameter_GLSL( +static bool isGLSLBuiltinName(VarDeclBase* varDecl) +{ + return varDecl->getName().StartsWith("gl_"); +} + +RefPtr<ExpressionType> tryGetEffectiveTypeForGLSLVaryingInput( ParameterBindingContext* context, VarDeclBase* varDecl) { - auto rules = context->layoutRules; - auto type = varDecl->getType(); - - // A GLSL shader parameter will be marked with - // a qualifier to match the boundary it uses - // - // In the case of a parameter block, we will have - // consumed this qualifier as part of parsing, - // so that it won't be present on the declaration - // any more. As such we also inspect the type - // of the variable. - - // We want to check for a constant-buffer type with a `push_constant` layout - // qualifier before we move on to anything else. - if (varDecl->HasModifier<GLSLPushConstantLayoutModifier>() && type->As<ConstantBufferType>()) - return CreateTypeLayout(type, rules->getPushConstantBufferRules()); - - // TODO(tfoley): We have multiple variations of - // the `uniform` modifier right now, and that - // needs to get fixed... - if(varDecl->HasModifier<HLSLUniformModifier>() || type->As<ConstantBufferType>()) - return CreateTypeLayout(type, rules->getConstantBufferRules()); - - if(varDecl->HasModifier<GLSLBufferModifier>() || type->As<GLSLShaderStorageBufferType>()) - return CreateTypeLayout(type, rules->getShaderStorageBufferRules()); + if (isGLSLBuiltinName(varDecl)) + return nullptr; + auto type = varDecl->getType(); if( varDecl->HasModifier<InModifier>() || type->As<GLSLInputParameterBlockType>()) { // Special case to handle "arrayed" shader inputs, as used @@ -353,9 +416,20 @@ getTypeLayoutForGlobalShaderParameter_GLSL( break; } - return CreateTypeLayout(type, rules->getVaryingInputRules()); + return type; } + return nullptr; +} + +RefPtr<ExpressionType> tryGetEffectiveTypeForGLSLVaryingOutput( + ParameterBindingContext* context, + VarDeclBase* varDecl) +{ + if (isGLSLBuiltinName(varDecl)) + return nullptr; + + auto type = varDecl->getType(); if( varDecl->HasModifier<OutModifier>() || type->As<GLSLOutputParameterBlockType>()) { // Special case to handle "arrayed" shader outputs, as used @@ -381,7 +455,55 @@ getTypeLayoutForGlobalShaderParameter_GLSL( break; } - return CreateTypeLayout(type, rules->getVaryingOutputRules()); + return type; + } + + return nullptr; +} + +RefPtr<TypeLayout> +getTypeLayoutForGlobalShaderParameter_GLSL( + ParameterBindingContext* context, + VarDeclBase* varDecl) +{ + auto rules = context->layoutRules; + auto type = varDecl->getType(); + + // A GLSL shader parameter will be marked with + // a qualifier to match the boundary it uses + // + // In the case of a parameter block, we will have + // consumed this qualifier as part of parsing, + // so that it won't be present on the declaration + // any more. As such we also inspect the type + // of the variable. + + // We want to check for a constant-buffer type with a `push_constant` layout + // qualifier before we move on to anything else. + if (varDecl->HasModifier<GLSLPushConstantLayoutModifier>() && type->As<ConstantBufferType>()) + return CreateTypeLayout(type, rules->getPushConstantBufferRules()); + + // TODO(tfoley): We have multiple variations of + // the `uniform` modifier right now, and that + // needs to get fixed... + if(varDecl->HasModifier<HLSLUniformModifier>() || type->As<ConstantBufferType>()) + return CreateTypeLayout(type, rules->getConstantBufferRules()); + + if(varDecl->HasModifier<GLSLBufferModifier>() || type->As<GLSLShaderStorageBufferType>()) + return CreateTypeLayout(type, rules->getShaderStorageBufferRules()); + + if (auto effectiveVaryingInputType = tryGetEffectiveTypeForGLSLVaryingInput(context, varDecl)) + { + // We expect to handle these elsewhere + assert(!"unexpected"); + return CreateTypeLayout(effectiveVaryingInputType, rules->getVaryingInputRules()); + } + + if (auto effectiveVaryingOutputType = tryGetEffectiveTypeForGLSLVaryingOutput(context, varDecl)) + { + // We expect to handle these elsewhere + assert(!"unexpected"); + return CreateTypeLayout(effectiveVaryingOutputType, rules->getVaryingOutputRules()); } // A `const` global with a `layout(constant_id = ...)` modifier @@ -446,13 +568,83 @@ getTypeLayoutForGlobalShaderParameter( // +enum EntryPointParameterDirection +{ + kEntryPointParameterDirection_Input = 0x1, + kEntryPointParameterDirection_Output = 0x2, +}; +typedef unsigned int EntryPointParameterDirectionMask; + +struct EntryPointParameterState +{ + String* optSemanticName = nullptr; + int* ioSemanticIndex = nullptr; + EntryPointParameterDirectionMask directionMask; + int semanticSlotCount; +}; + + +static RefPtr<TypeLayout> processEntryPointParameter( + ParameterBindingContext* context, + RefPtr<ExpressionType> type, + EntryPointParameterState const& state, + RefPtr<VarLayout> varLayout); + +static void collectGlobalScopeGLSLVaryingParameter( + ParameterBindingContext* context, + RefPtr<VarDeclBase> varDecl, + RefPtr<ExpressionType> effectiveType, + EntryPointParameterDirection direction) +{ + int defaultSemanticIndex = 0; + + EntryPointParameterState state; + state.directionMask = direction; + state.ioSemanticIndex = &defaultSemanticIndex; + + RefPtr<VarLayout> varLayout = new VarLayout(); + varLayout->varDecl = makeDeclRef(varDecl.Ptr()); + + varLayout->typeLayout = processEntryPointParameter( + context, + effectiveType, + state, + varLayout); + + // Now add it to our list of reflection parameters, so + // that it can get a location assigned later... + auto parameterName = varDecl->Name.Content; + ParameterInfo* parameterInfo = new ParameterInfo(); + parameterInfo->translationUnit = context->translationUnit; + context->shared->parameters.Add(parameterInfo); + parameterInfo->varLayouts.Add(varLayout); +} // Collect a single declaration into our set of parameters static void collectGlobalScopeParameter( ParameterBindingContext* context, RefPtr<VarDeclBase> varDecl) { + // HACK: We need to intercept GLSL varying `in` and `out` here, way earlier + // in the process, so that we can avoid all kinds of nastiness that would + // otherwise be applied to them. + if (context->sourceLanguage == SourceLanguage::GLSL) + { + if (auto effectiveVaryingInputType = tryGetEffectiveTypeForGLSLVaryingInput(context, varDecl)) + { + collectGlobalScopeGLSLVaryingParameter(context, varDecl, effectiveVaryingInputType, kEntryPointParameterDirection_Input); + return; + } + + if (auto effectiveVaryingOutputType = tryGetEffectiveTypeForGLSLVaryingOutput(context, varDecl)) + { + collectGlobalScopeGLSLVaryingParameter(context, varDecl, effectiveVaryingOutputType, kEntryPointParameterDirection_Output); + return; + } + } + + // We use a single operation to both check whether the // variable represents a shader parameter, and to compute // the layout for that parameter's type. @@ -506,11 +698,41 @@ static void collectGlobalScopeParameter( parameterInfo->varLayouts.Add(varLayout); } +static RefPtr<UsedRangeSet> findUsedRangeSetForSpace( + ParameterBindingContext* context, + UInt space) +{ + RefPtr<UsedRangeSet> usedRangeSet; + if (context->shared->globalSpaceUsedRangeSets.TryGetValue(space, usedRangeSet)) + return usedRangeSet; + + usedRangeSet = new UsedRangeSet(); + context->shared->globalSpaceUsedRangeSets.Add(space, usedRangeSet); + return usedRangeSet; +} + +static RefPtr<UsedRangeSet> findUsedRangeSetForTranslationUnit( + ParameterBindingContext* context, + TranslationUnitRequest* translationUnit) +{ + if (!translationUnit) + return findUsedRangeSetForSpace(context, 0); + + RefPtr<UsedRangeSet> usedRangeSet; + if (context->shared->translationUnitUsedRangeSets.TryGetValue(translationUnit, usedRangeSet)) + return usedRangeSet; + + usedRangeSet = new UsedRangeSet(); + context->shared->translationUnitUsedRangeSets.Add(translationUnit, usedRangeSet); + return usedRangeSet; +} + static void addExplicitParameterBinding( ParameterBindingContext* context, RefPtr<ParameterInfo> parameterInfo, LayoutSemanticInfo const& semanticInfo, - UInt count) + UInt count, + RefPtr<UsedRangeSet> usedRangeSet = nullptr) { auto kind = semanticInfo.kind; @@ -524,7 +746,8 @@ static void addExplicitParameterBinding( || bindingInfo.index != semanticInfo.index || bindingInfo.space != semanticInfo.space ) { - // TODO: diagnose! + auto firstVarDecl = parameterInfo->varLayouts[0]->varDecl.getDecl(); + getSink(context)->diagnose(firstVarDecl, Diagnostics::conflictingExplicitBindingsForParameter, firstVarDecl->getName()); } // TODO(tfoley): `register` semantics can technically be @@ -536,14 +759,21 @@ static void addExplicitParameterBinding( bindingInfo.index = semanticInfo.index; bindingInfo.space = semanticInfo.space; - // If things are bound in `space0` (the default), then we need - // to lay claim to the register range used, so that automatic - // assignment doesn't go and use the same registers. - if (semanticInfo.space == 0) + if (!usedRangeSet) + { + usedRangeSet = findUsedRangeSetForSpace(context, semanticInfo.space); + } + auto overlappedParameterInfo = usedRangeSet->usedResourceRanges[(int)semanticInfo.kind].Add( + parameterInfo, + semanticInfo.index, + semanticInfo.index + count); + + if (overlappedParameterInfo) { - context->shared->usedResourceRanges[(int)semanticInfo.kind].Add( - semanticInfo.index, - semanticInfo.index + count); + auto paramA = parameterInfo->varLayouts[0]->varDecl.getDecl(); + auto paramB = overlappedParameterInfo->varLayouts[0]->varDecl.getDecl(); + + getSink(context)->diagnose(paramA, Diagnostics::parameterBindingsOverlap, paramA->getName(), paramB->getName()); } } } @@ -603,6 +833,11 @@ static void addExplicitParameterBindings_GLSL( // the index/offset/etc. // + // We also may need to store explicit binding info in a different place, + // in the case of varying input/output, since we don't want to collect + // things globally; + RefPtr<UsedRangeSet> usedRangeSet; + TypeLayout::ResourceInfo* resInfo = nullptr; LayoutSemanticInfo semanticInfo; semanticInfo.index = 0; @@ -620,12 +855,16 @@ static void addExplicitParameterBindings_GLSL( // Try to find `location` binding if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) return; + + usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); } else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::FragmentOutput)) != nullptr ) { // Try to find `location` binding if(!findLayoutArg<GLSLLocationLayoutModifier>(varDecl, &semanticInfo.index)) return; + + usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); } else if( (resInfo = typeLayout->FindResourceInfo(LayoutResourceKind::SpecializationConstant)) != nullptr ) { @@ -642,7 +881,7 @@ static void addExplicitParameterBindings_GLSL( auto count = resInfo->count; semanticInfo.kind = kind; - addExplicitParameterBinding(context, parameterInfo, semanticInfo, int(count)); + addExplicitParameterBinding(context, parameterInfo, semanticInfo, int(count), usedRangeSet); } // Given a single parameter, collect whatever information we have on @@ -696,12 +935,31 @@ static void completeBindingsForParameter( continue; } + // For now we only auto-generate bindings in space zero + // + // TODO: we may want to support searching for a space with + // capacity for our resource, just in case somebody has + // claimed the entire range... + UInt space = 0; + + RefPtr<UsedRangeSet> usedRangeSet; + switch (kind) + { + default: + usedRangeSet = findUsedRangeSetForSpace(context, space); + break; + + case LayoutResourceKind::VertexInput: + case LayoutResourceKind::FragmentOutput: + usedRangeSet = findUsedRangeSetForTranslationUnit(context, parameterInfo->translationUnit); + break; + } + auto count = typeRes.count; bindingInfo.count = count; - bindingInfo.index = context->shared->usedResourceRanges[(int)kind].Allocate((int) count); + bindingInfo.index = usedRangeSet->usedResourceRanges[(int)kind].Allocate(parameterInfo, (int) count); - // For now we only auto-generate bindings in space zero - bindingInfo.space = 0; + bindingInfo.space = space; } // At this point we should have explicit binding locations chosen for @@ -799,21 +1057,6 @@ SimpleSemanticInfo decomposeSimpleSemantic( return info; } -enum EntryPointParameterDirection -{ - kEntryPointParameterDirection_Input = 0x1, - kEntryPointParameterDirection_Output = 0x2, -}; -typedef unsigned int EntryPointParameterDirectionMask; - -struct EntryPointParameterState -{ - String* optSemanticName; - int* ioSemanticIndex; - EntryPointParameterDirectionMask directionMask; - int semanticSlotCount; -}; - static RefPtr<TypeLayout> processSimpleEntryPointParameter( ParameterBindingContext* context, RefPtr<ExpressionType> type, @@ -844,7 +1087,18 @@ static RefPtr<TypeLayout> processSimpleEntryPointParameter( // once we've gone to the trouble of looking it all up... if( sn == "sv_target" ) { - context->shared->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(semanticIndex, semanticIndex + semanticSlotCount); + // TODO: construct a `ParameterInfo` we can use here so that + // overlapped layout errors get reported nicely. + + auto usedResourceSet = findUsedRangeSetForSpace(context, 0); + usedResourceSet->usedResourceRanges[int(LayoutResourceKind::UnorderedAccess)].Add(nullptr, semanticIndex, semanticIndex + semanticSlotCount); + + + // We also need to track this as an ordinary varying output from the stage, + // since that is how GLSL will want to see it. + auto rules = context->layoutRules->getVaryingOutputRules(); + SimpleLayoutInfo layout = GetLayout(type, rules); + typeLayout->addResourceUsage(layout.kind, layout.size); } } @@ -882,12 +1136,6 @@ static RefPtr<TypeLayout> processSimpleEntryPointParameter( return typeLayout; } -static RefPtr<TypeLayout> processEntryPointParameter( - ParameterBindingContext* context, - RefPtr<ExpressionType> type, - EntryPointParameterState const& state, - RefPtr<VarLayout> varLayout); - static RefPtr<TypeLayout> processEntryPointParameterWithPossibleSemantic( ParameterBindingContext* context, Decl* declForSemantic, @@ -908,7 +1156,7 @@ static RefPtr<TypeLayout> processEntryPointParameterWithPossibleSemantic( subState.optSemanticName = &semanticInfo.name; subState.ioSemanticIndex = &semanticIndex; - processEntryPointParameter(context, type, subState, varLayout); + return processEntryPointParameter(context, type, subState, varLayout); } } @@ -1200,6 +1448,8 @@ static void collectModuleParameters( ParameterBindingContext contextData = *inContext; auto context = &contextData; + context->translationUnit = nullptr; + context->stage = Stage::Unknown; // All imported modules are implicitly Slang code @@ -1225,6 +1475,7 @@ static void collectParameters( for( auto& translationUnit : request->translationUnits ) { + context->translationUnit = translationUnit; context->stage = inferStageForTranslationUnit(translationUnit.Ptr()); context->sourceLanguage = translationUnit->sourceLanguage; @@ -1262,6 +1513,7 @@ void generateParameterBindings( // Create a context to hold shared state during the process // of generating parameter bindings SharedParameterBindingContext sharedContext; + sharedContext.compileRequest = request; sharedContext.defaultLayoutRules = rules; sharedContext.programLayout = programLayout; @@ -1269,6 +1521,7 @@ void generateParameterBindings( // declared into the global scope ParameterBindingContext context; context.shared = &sharedContext; + context.translationUnit = nullptr; context.layoutRules = sharedContext.defaultLayoutRules; // Walk through AST to discover all the parameters @@ -1299,12 +1552,18 @@ void generateParameterBindings( ParameterBindingInfo globalConstantBufferBinding; if( anyGlobalUniforms ) { + // TODO: this logic is only correct for D3D targets, where + // global-scope uniforms get wrapped into a constant buffer. + + UInt space = 0; + auto usedRangeSet = findUsedRangeSetForSpace(&context, space); + globalConstantBufferBinding.index = - context.shared->usedResourceRanges[ - (int)LayoutResourceKind::ConstantBuffer].Allocate(1); + usedRangeSet->usedResourceRanges[ + (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1); // For now we only auto-generate bindings in space zero - globalConstantBufferBinding.space = 0; + globalConstantBufferBinding.space = space; } diff --git a/tests/reflection/thread-group-size.hlsl.expected b/tests/reflection/thread-group-size.hlsl.expected index 60d5e822c..d139c5b64 100644 --- a/tests/reflection/thread-group-size.hlsl.expected +++ b/tests/reflection/thread-group-size.hlsl.expected @@ -21,7 +21,9 @@ standard output = { "parameters": [ { "name": "tid", - "binding": {"kind": "vertexInput", "index": 0}, + "bindings": [ + + ], "type": { "kind": "vector", "elementCount": 3, diff --git a/tests/rewriter/varying-struct.slang b/tests/rewriter/varying-struct.slang new file mode 100644 index 000000000..92e9dda2e --- /dev/null +++ b/tests/rewriter/varying-struct.slang @@ -0,0 +1,21 @@ +//TEST_IGNORE_FILE: + +struct VS_IN +{ + float4 x : X; + float4 y : Y; +}; + +struct VS_OUT +{ + float4 color : COLOR; + float4 posH : SV_Position; +}; + +VS_OUT doIt(VS_IN i) +{ + VS_OUT o; + o.color = i.x; + o.posH = i.y; + return o; +} diff --git a/tests/rewriter/varying-struct.vert b/tests/rewriter/varying-struct.vert new file mode 100644 index 000000000..74ca8be37 --- /dev/null +++ b/tests/rewriter/varying-struct.vert @@ -0,0 +1,54 @@ +#version 450 core +//TEST:COMPARE_GLSL: + +#if defined(__SLANG__) + +__import varying_struct; + +in VS_IN foo; +out VS_OUT bar; + +void main() +{ + bar = doIt(foo); +} + +#else + +struct VS_IN +{ + vec4 x; + vec4 y; +}; + +struct VS_OUT +{ + vec4 color; + vec4 posH; +}; + +VS_OUT doIt(VS_IN i) +{ + VS_OUT o; + o.color = i.x; + o.posH = i.y; + return o; +} + +layout(location = 0) +out vec4 SLANG_out_bar_color; + +layout(location = 0) +in vec4 SLANG_in_foo_x; + +layout(location = 1) +in vec4 SLANG_in_foo_y; + +void main() +{ + VS_OUT SLANG_tmp_0 = doIt(VS_IN(SLANG_in_foo_x, SLANG_in_foo_y)); + SLANG_out_bar_color = SLANG_tmp_0.color; + gl_Position = SLANG_tmp_0.posH; +} + +#endif |
