diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-12-06 13:55:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-12-06 13:55:31 -0800 |
| commit | 301cdf5ef42797b1073d9e6c741ef0ba98a38792 (patch) | |
| tree | f3b1af0ab973bddd4c9138f7e482aef59b7acbd0 | |
| parent | b487516880f56fd69ff76bf7cb3f0f1711bc356d (diff) | |
Make AST and IR share type legalization code (#303)
* Make AST and IR share type legalization code
A previous change already made it so that the AST-to-AST lowering/legalization pass could work together with IR-based lowering of `import`ed code, but that change didn't take into account the case where a function written in the AST needed to call an IR function and pass in a type that required legalization.
Both the IR-based and AST-based passes had their own approaches to type legalization, that mostly agreed on the desired output, but they ended up creating their own representations for legalized types which would mean that for a function call the caller and callee might end up legalizing the parameter list to use different types.
This change tries to fix this issue (and adds a new test case that relies on the fix) by massively overhauling the AST-based legalization pass so that it uses the same type legalization code as the IR. The shared code has been moved out into `legalize-types.{h,cpp}`.
Notes:
- I eliminated the `FilteredTupleType` type, since it was starting to cause code duplication in a lot of places. Instead, type legalization just creates new `struct` types to represent the result of filtering.
- One big consequence of this is that the `LegalType::pair` case needs to remember for each field in the original type which field (if any) in the new `struct` type it maps to
- A big source of complexity (and probably bugs) in this code is trying to figure out how to parent these new `struct` definitions effectively. A good follow-on change would be something that outputs declarations on-demand during the AST emit logic (as we do for the IR), just to avoid some of this song and dance.
- The old AST type legalization had a notion of both a "tuple" type and a "varying tuple" type. The "tuple" case was quite complex, and combined behavior currently handled by `LegalType::pair` (for splitting into ordinary and special sides) and `LegalType::tuple` (for holding multiple distinct elements to represent the fields of an aggregate). The "varying tuple" case was closer to `LegalType::tuple`, so I tried to just re-use the existing logic for that too. The one place this potentially gets messy is in `reifyTuple()`.
- The messiest bit of handling the "varying tuple" concept (which is used for GLSL shader inputs/outputs since they have to be scalarized) is that when passing them as function arguments we need to reify the tuple back into a structured value. Because the `LegalExpr` hierarchy doesn't have type information, but constructing a value of the "original" type requires such information, things get a little messy.
- I did *not* try to deal with any of the logic related to handling system inputs/outputs for cross-compilation purposes. Of course, the long-term goal is that any actual cross-compilation is handled via the IR, but this change can't afford to break the AST-based path just yet. As a result, there is still quite a bit of complexity in the handling of assignment, to deal with cases where "fixups" are required.
* fixup: bad code in macro, not caught by Visual Studio compiler
* fixup: more stuff missed by VS compiler
* fixup: VS continutes to miss stuff in UNREACHABLE_RETURN
| -rw-r--r-- | slang.h | 1 | ||||
| -rw-r--r-- | source/slang/ast-legalize.cpp | 2495 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 12 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 256 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 1047 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 19 | ||||
| -rw-r--r-- | source/slang/legalize-types.cpp | 1086 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 296 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 15 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 8 | ||||
| -rw-r--r-- | source/slang/modifier-defs.h | 19 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 2 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 127 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 34 | ||||
| -rw-r--r-- | tests/compute/rewriter-array-type.hlsl | 35 | ||||
| -rw-r--r-- | tests/compute/rewriter-array-type.hlsl.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/rewriter-types.hlsl | 34 | ||||
| -rw-r--r-- | tests/compute/rewriter-types.hlsl.expected.txt | 4 | ||||
| -rw-r--r-- | tests/compute/rewriter-types.slang | 17 |
20 files changed, 2933 insertions, 2580 deletions
@@ -1174,6 +1174,7 @@ namespace slang #include "source/slang/emit.cpp" #include "source/slang/ir.cpp" #include "source/slang/ir-legalize-types.cpp" +#include "source/slang/legalize-types.cpp" #include "source/slang/lexer.cpp" #include "source/slang/mangle.cpp" #include "source/slang/name.cpp" diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 1b045cbc9..eaceecc00 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -3,6 +3,7 @@ #include "emit.h" #include "ir-insts.h" +#include "legalize-types.h" #include "type-layout.h" #include "visitor.h" @@ -37,100 +38,10 @@ struct CloneVisitor // -// - -class TupleExpr; -class TupleVarDecl; -class VaryingTupleExpr; -class VaryingTupleVarDecl; - - -// The result of lowering a declaration will usually be a declaration, -// but it might also be a "tuple" declaration, in cases where we needed -// to sclarize (or partially scalarize) things to guarantee validity. -struct LoweredDecl -{ - enum class Flavor - { - Decl, // A single declaration (the default case) - Tuple, // A `TupleVarDecl` representing multiple decls - VaryingTuple, // A `VaryingTupleVarDecl` representing multiple decls - }; - - LoweredDecl() - : flavor(Flavor::Decl) - {} - - LoweredDecl(Decl* decl) - : value(decl) - , flavor(Flavor::Decl) - {} - - LoweredDecl(TupleVarDecl* decl) - : value((RefObject*) decl) - , flavor(Flavor::Tuple) - {} - - LoweredDecl(VaryingTupleVarDecl* decl) - : value((RefObject*) decl) - , flavor(Flavor::VaryingTuple) - {} - - Flavor getFlavor() const { return flavor; } - RefObject* getValue() const { return value; } - - Decl* getDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::Decl); - return (Decl*) value.Ptr(); - } - - TupleVarDecl* getTupleDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::Tuple); - return (TupleVarDecl*) value.Ptr(); - } - - VaryingTupleVarDecl* getVaryingTupleDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::VaryingTuple); - return (VaryingTupleVarDecl*) value.Ptr(); - } - - Decl* asDecl() const - { - return (getFlavor() == Flavor::Decl) ? getDecl() : nullptr; - } - - TupleVarDecl* asTupleDecl() const - { - return (getFlavor() == Flavor::Tuple) ? getTupleDecl() : nullptr; - } - - VaryingTupleVarDecl* asVaryingTupleDecl() const - { - return (getFlavor() == Flavor::VaryingTuple) ? getVaryingTupleDecl() : nullptr; - } - -private: - RefPtr<RefObject> value; - Flavor flavor; -}; - -struct LoweredDeclRef -{ -public: - LoweredDecl decl; - RefPtr<Substitutions> substitutions; - - LoweredDecl getDecl() { return decl; } - - template<typename T> - DeclRef<T> As() - { - return DeclRef<Decl>(decl.getDecl(), substitutions).As<T>(); - } -}; +// Forward-declare types used by `LegalExpr` +class ImplicitDerefPseudoExpr; +class TuplePseudoExpr; +class PairPseudoExpr; // @@ -152,7 +63,7 @@ struct StructuralTransformVisitorBase template<typename T> DeclRef<T> transformDeclField(DeclRef<T> const& decl) { - LoweredDeclRef declRef = visitor->translateDeclRef(decl); + DeclRef<Decl> declRef = visitor->translateDeclRef(decl); return declRef.As<T>(); } @@ -204,18 +115,6 @@ struct StructuralTransformVisitorBase } }; -#if 0 -template<typename V> -RefPtr<Stmt> structuralTransform( - Stmt* stmt, - V* visitor) -{ - StructuralTransformStmtVisitor<V> transformer; - transformer.visitor = visitor; - return transformer.dispatch(stmt); -} -#endif - template<typename V> struct StructuralTransformExprVisitor : StructuralTransformVisitorBase<V> @@ -267,71 +166,71 @@ RefPtr<Expr> structuralTransform( -// The result of lowering an exrpession will usually be just a single +// The result of legalizing an exrpession will usually be just a single // expression, but it might also be a "tuple" expression that encodes // multiple expressions. -struct LoweredExpr +struct LegalExpr { - enum class Flavor - { - Expr, - Tuple, - VaryingTuple, - }; + typedef LegalType::Flavor Flavor; - LoweredExpr() - : flavor(Flavor::Expr) + LegalExpr() + : flavor(Flavor::none) {} - LoweredExpr(Expr* expr) + LegalExpr(Expr* expr) : value(expr) - , flavor(Flavor::Expr) + , flavor(Flavor::simple) {} - LoweredExpr(TupleExpr* expr) + LegalExpr(TuplePseudoExpr* expr) : value((RefObject*) expr) - , flavor(Flavor::Tuple) + , flavor(Flavor::tuple) {} - LoweredExpr(VaryingTupleExpr* expr) + LegalExpr(PairPseudoExpr* expr) : value((RefObject*) expr) - , flavor(Flavor::VaryingTuple) + , flavor(Flavor::pair) {} - Flavor getFlavor() const { return flavor; } + LegalExpr(ImplicitDerefPseudoExpr* expr) + : value((RefObject*) expr) + , flavor(Flavor::implicitDeref) + {} - Expr* getExpr() const - { - assert(getFlavor() == Flavor::Expr); - return (Expr*)value.Ptr(); - } + Flavor getFlavor() const { return flavor; } - TupleExpr* getTupleExpr() const + Expr* getSimple() const { - assert(getFlavor() == Flavor::Tuple); - return (TupleExpr*)value.Ptr(); - } + switch (getFlavor()) + { + case Flavor::none: + return nullptr; + case Flavor::simple: + return (Expr*)value.Ptr(); - VaryingTupleExpr* getVaryingTupleExpr() const - { - assert(getFlavor() == Flavor::VaryingTuple); - return (VaryingTupleExpr*)value.Ptr(); + default: + assert(getFlavor() == Flavor::simple); + return nullptr; + } } - Expr* asExpr() const + TuplePseudoExpr* getTuple() const { - return (getFlavor() == Flavor::Expr) ? getExpr() : nullptr; + assert(getFlavor() == Flavor::tuple); + return (TuplePseudoExpr*)value.Ptr(); } - TupleExpr* asTuple() const + PairPseudoExpr* getPair() const { - return (getFlavor() == Flavor::Tuple) ? getTupleExpr() : nullptr; + assert(getFlavor() == Flavor::pair); + return (PairPseudoExpr*)value.Ptr(); } - VaryingTupleExpr* asVaryingTuple() const + ImplicitDerefPseudoExpr* getImplicitDeref() const { - return (getFlavor() == Flavor::VaryingTuple) ? getVaryingTupleExpr() : nullptr; + assert(getFlavor() == Flavor::implicitDeref); + return (ImplicitDerefPseudoExpr*)value.Ptr(); } // Allow use in boolean contexts @@ -345,84 +244,77 @@ private: Flavor flavor; }; -// Pseudo-syntax used during lowering -class PseudoVarDecl : public RefObject +struct LegalTypeExpr { -public: - NameLoc nameAndLoc; - SourceLoc loc; - TypeExp type; -}; + LegalType type; + RefPtr<Expr> expr; -class TupleVarDecl : public PseudoVarDecl -{ -public: - struct Element + LegalTypeExpr() + {} + + LegalTypeExpr(LegalType const& type) + : type(type) { - RefPtr<TupleVarModifier> tupleVarMod; - LoweredDecl decl; - }; + } - TupleTypeModifier* tupleType; - RefPtr<VarDeclBase> primaryDecl; - List<Element> tupleDecls; + LegalTypeExpr(TypeExp const& typeExpr) + { + type = LegalType::simple(typeExpr.type); + expr = typeExpr.exp; + } + + TypeExp getSimple() const + { + TypeExp result; + result.type = type.getSimple(); + result.exp = expr; + return result; + } }; class PseudoExpr : public RefObject { public: - SourceLoc loc; - QualType type; + SourceLoc loc; }; -// Pseudo-syntax used during lowering: -// represents an ordered list of expressions as a single unit -class TupleExpr : public PseudoExpr +class ImplicitDerefPseudoExpr : public PseudoExpr { public: - struct Element - { - DeclRef<VarDeclBase> tupleFieldDeclRef; - LoweredExpr expr; - }; - - // Optional reference to the "primary" value of the tuple, - // in the case of a tuple type with "orinary" fields - RefPtr<Expr> primaryExpr; - - // Additional fields to store values for any non-ordinary fields - // (or fields that aren't exclusively orginary) - List<Element> tupleElements; + LegalExpr valueExpr; }; -// Pseudo-syntax used during lowering -class VaryingTupleVarDecl : public PseudoVarDecl +class TuplePseudoExpr : public PseudoExpr { public: - LoweredExpr expr; + struct Element + { + LegalExpr expr; + DeclRef<VarDeclBase> fieldDeclRef; + }; + + List<Element> elements; }; -// Pseudo-syntax used during lowering: -// represents an ordered list of expressions as a single unit -class VaryingTupleExpr : public PseudoExpr +class PairPseudoExpr : public PseudoExpr { public: - struct Element - { - DeclRef<VarDeclBase> originalFieldDeclRef; - LoweredExpr expr; - }; + LegalExpr ordinary; + LegalExpr special; - List<Element> elements; + RefPtr<PairInfo> pairInfo; }; -static SourceLoc getPosition(LoweredExpr const& expr) +static SourceLoc getPosition(LegalExpr const& expr) { switch (expr.getFlavor()) { - case LoweredExpr::Flavor::Expr: return expr.getExpr() ->loc; - case LoweredExpr::Flavor::Tuple: return expr.getTupleExpr() ->loc; - case LoweredExpr::Flavor::VaryingTuple: return expr.getVaryingTupleExpr()->loc; + case LegalExpr::Flavor::none: return SourceLoc(); + case LegalExpr::Flavor::simple: return expr.getSimple() ->loc; + case LegalExpr::Flavor::tuple: return expr.getTuple() ->loc; + case LegalExpr::Flavor::pair: return expr.getPair() ->loc; + case LegalExpr::Flavor::implicitDeref: return expr.getImplicitDeref() ->loc; + default: SLANG_UNREACHABLE("all cases handled"); UNREACHABLE_RETURN(SourceLoc()); @@ -456,7 +348,8 @@ struct SharedLoweringContext RefPtr<ModuleDecl> loweredProgram; - Dictionary<Decl*, LoweredDecl> loweredDecls; + Dictionary<Decl*, RefPtr<Decl>> mapOriginalDeclToLowered; + Dictionary<Decl*, LegalExpr> mapOriginalDeclToExpr; Dictionary<RefObject*, Decl*> mapLoweredDeclToOriginal; // Work to be done at the very start and end of the entry point @@ -474,6 +367,9 @@ struct SharedLoweringContext // The actual result we want to return LoweredEntryPoint result; + + /// State to use when legalizing types. + TypeLegalizationContext* typeLegalizationContext; }; static void attachLayout( @@ -491,9 +387,9 @@ void requireGLSLVersion( ProfileVersion version); struct LoweringVisitor - : ExprVisitor<LoweringVisitor, LoweredExpr> + : ExprVisitor<LoweringVisitor, LegalExpr> , StmtVisitor<LoweringVisitor, void> - , DeclVisitor<LoweringVisitor, LoweredDecl> + , DeclVisitor<LoweringVisitor, RefPtr<Decl>> , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>> { // @@ -512,6 +408,11 @@ struct LoweringVisitor // then this will point to that variable. RefPtr<Variable> resultVariable; + TypeLegalizationContext* getTypeLegalizationContext() + { + return shared->typeLegalizationContext; + } + Session* getSession() { return shared->compileRequest->mSession; @@ -708,22 +609,113 @@ struct LoweringVisitor // Types // - RefPtr<Type> lowerType( + RefPtr<Type> lowerTypeEx( Type* type) { if (!type) return nullptr; - return TypeVisitor::dispatch(type); + RefPtr<Type> loweredType = dispatchType(type); + return loweredType; + } + + LegalType lowerLegalType( + LegalType legalType) + { + switch(legalType.flavor) + { + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + return LegalType::simple( + lowerTypeEx(legalType.getSimple())); + + case LegalType::Flavor::tuple: + { + auto inputTuple = legalType.getTuple(); + RefPtr<TuplePseudoType> resultTuple = new TuplePseudoType(); + for(auto ee : inputTuple->elements) + { + TuplePseudoType::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.type = lowerLegalType(ee.type); + resultTuple->elements.Add(element); + } + return LegalType::tuple(resultTuple); + } + break; + + case LegalType::Flavor::pair: + { + auto inputPair = legalType.getPair(); + RefPtr<PairPseudoType> resultPair = new PairPseudoType(); + return LegalType::pair( + lowerLegalType(inputPair->ordinaryType), + lowerLegalType(inputPair->specialType), + inputPair->pairInfo); + } + break; + + case LegalType::Flavor::implicitDeref: + { + return LegalType::implicitDeref( + lowerLegalType(legalType.getImplicitDeref()->valueType)); + } + break; + + default: + SLANG_UNEXPECTED("uhandled type flavor"); + UNREACHABLE_RETURN(LegalType()); + } } - TypeExp lowerType( + LegalType lowerAndLegalizeType( + Type* type) + { + if (!type) return LegalType(); + + // We will first attempt to legalize the type, so that any parts of + // it that won't be allowed on the target get excised. Once we are + // done with that, we will do the "lowering" process of copying + // any needed bits of AST over to the new module. + LegalType legalType = legalizeType( + getTypeLegalizationContext(), + type); + + LegalType loweredType = lowerLegalType(legalType); + + return loweredType; + } + + TypeExp lowerTypeExprEx( TypeExp const& typeExp) { TypeExp result; - result.type = lowerType(typeExp.type); - result.exp = lowerExpr(typeExp.exp); + result.type = lowerTypeEx(typeExp.type); + result.exp = legalizeSimpleExpr(typeExp.exp); + return result; + } + + LegalTypeExpr lowerAndLegalizeTypeExpr( + TypeExp const& typeExp) + { + LegalTypeExpr result; + result.type = lowerAndLegalizeType(typeExp.type); + result.expr = legalizeSimpleExpr(typeExp.exp); return result; } + RefPtr<Type> lowerAndLegalizeSimpleType( + Type* type) + { + return lowerAndLegalizeType(type).getSimple(); + } + + TypeExp lowerAndlegalizeSimpleTypeExpr( + TypeExp const& typeExp) + { + return lowerAndLegalizeTypeExpr(typeExp).getSimple(); + } + RefPtr<Type> visitIRBasicBlockType(IRBasicBlockType* type) { return type; @@ -756,13 +748,10 @@ struct LoweringVisitor { RefPtr<FuncType> loweredType = new FuncType(); loweredType->setSession(getSession()); - loweredType->resultType = lowerType(type->resultType); + loweredType->resultType = lowerTypeEx(type->resultType); for (auto paramType : type->paramTypes) { - auto loweredParamType = lowerType(paramType); - - // TODO: it seems like this step needs to scalarize - // in the case where a parameter type is a tuple... + auto loweredParamType = lowerTypeEx(paramType); loweredType->paramTypes.Add(loweredParamType); } return loweredType; @@ -781,7 +770,7 @@ struct LoweringVisitor if (shared->target == CodeGenTarget::GLSL) { // GLSL does not support `typedef`, so we will lower it out of existence here - return lowerType(GetType(type->declRef)); + return lowerTypeEx(GetType(type->declRef)); } return getNamedType( @@ -789,30 +778,15 @@ struct LoweringVisitor translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); } - RefPtr<Type> visitFilteredTupleType(FilteredTupleType* type) - { - RefPtr<FilteredTupleType> loweredType = new FilteredTupleType(); - loweredType->setSession(type->getSession()); - loweredType->originalType = lowerType(type->originalType); - for (auto ee : type->elements) - { - FilteredTupleType::Element element; - element.fieldDeclRef = ee.fieldDeclRef; - element.type = lowerType(ee.type); - loweredType->elements.Add(element); - } - return loweredType; - } - RefPtr<Type> visitTypeType(TypeType* type) { - return getTypeType(lowerType(type->type)); + return getTypeType(lowerTypeEx(type->type)); } RefPtr<Type> visitArrayExpressionType(ArrayExpressionType* type) { RefPtr<ArrayExpressionType> loweredType = Slang::getArrayType( - lowerType(type->baseType), + lowerTypeEx(type->baseType), lowerVal(type->ArrayLength).As<IntVal>()); return loweredType; } @@ -820,7 +794,7 @@ struct LoweringVisitor RefPtr<Type> visitGroupSharedType(GroupSharedType* type) { return getSession()->getGroupSharedType( - lowerType(type->valueType)); + lowerTypeEx(type->valueType)); } RefPtr<Type> visitParameterBlockType(ParameterBlockType* type) @@ -832,14 +806,15 @@ struct LoweringVisitor // directly to its stated element type, and see how // that works. - return lowerType(type->getElementType()); + return lowerTypeEx(type->getElementType()); // return getSession()->getConstantBufferType( // lowerType(type->getElementType()); } RefPtr<Type> transformSyntaxField(Type* type) { - return lowerType(type); + // TODO: how to handle this... + return type; } RefPtr<Val> visitIRProxyVal(IRProxyVal* val) @@ -851,32 +826,33 @@ struct LoweringVisitor // Expressions // - LoweredExpr lowerExprOrTuple( + LegalExpr legalizeExpr( Expr* expr) { - if (!expr) return LoweredExpr(); + if (!expr) return LegalExpr(); return ExprVisitor::dispatch(expr); } - RefPtr<Expr> lowerExpr( + RefPtr<Expr> legalizeSimpleExpr( Expr* expr) { if (!expr) return nullptr; - auto result = lowerExprOrTuple(expr); - return maybeReifyTuple(result); + auto type = lowerAndLegalizeType(expr->type.type); + auto result = legalizeExpr(expr); + return maybeReifyTuple(result, type).getSimple(); } // catch-all - LoweredExpr visitExpr( + LegalExpr visitExpr( Expr* expr) { - return LoweredExpr(structuralTransform(expr, this)); + return LegalExpr(structuralTransform(expr, this)); } RefPtr<Expr> transformSyntaxField(Expr* expr) { - return lowerExpr(expr); + return legalizeSimpleExpr(expr); } void lowerExprCommon( @@ -884,16 +860,16 @@ struct LoweringVisitor Expr* expr) { loweredExpr->loc = expr->loc; - loweredExpr->type.type = lowerType(expr->type.type); + loweredExpr->type.type = lowerTypeEx(expr->type.type); } void lowerExprCommon( - LoweredExpr const& loweredExpr, - Expr* expr) + LegalExpr const& legalExpr, + Expr* expr) { - if (auto simpleExpr = loweredExpr.asExpr()) + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) { - lowerExprCommon(simpleExpr, expr); + lowerExprCommon(legalExpr.getSimple(), expr); } } @@ -925,95 +901,49 @@ struct LoweringVisitor return result; } - LoweredExpr createVarRef( - SourceLoc const& loc, - LoweredDecl const& decl) + LegalExpr createVarRef( + SourceLoc const& loc, + VarDeclBase* decl) { - switch (decl.getFlavor()) - { - case LoweredDecl::Flavor::Decl: - return LoweredExpr(createSimpleVarRef(loc, decl.getDecl()->As<VarDeclBase>())); - - case LoweredDecl::Flavor::Tuple: - return createTupleRef(loc, decl.getTupleDecl()); - - case LoweredDecl::Flavor::VaryingTuple: - return createVaryingTupleRef(loc, decl.getVaryingTupleDecl()); - - default: - SLANG_UNREACHABLE("all cases handled"); - UNREACHABLE_RETURN(LoweredExpr()); - } + return LegalExpr(createSimpleVarRef(loc, decl)); } - - LoweredExpr createTupleRef( - SourceLoc const& loc, - TupleVarDecl* decl) + RefPtr<Expr> createSimpleVarExpr( + VarExpr* expr, + DeclRef<Decl> const& declRef) { - RefPtr<TupleExpr> result = new TupleExpr(); - result->loc = loc; - result->type.type = decl->type.type; - - if (auto primaryDecl = decl->primaryDecl) - { - result->primaryExpr = createSimpleVarRef(loc, primaryDecl); - } - - for (auto declElem : decl->tupleDecls) + RefPtr<VarExpr> loweredExpr = new VarExpr(); + if (expr) { - auto tupleVarMod = declElem.tupleVarMod; - SLANG_RELEASE_ASSERT(tupleVarMod); - auto tupleFieldMod = tupleVarMod->tupleField; - SLANG_RELEASE_ASSERT(tupleFieldMod); - SLANG_RELEASE_ASSERT(tupleFieldMod->decl); - - TupleExpr::Element elem; - elem.tupleFieldDeclRef = makeDeclRef(tupleFieldMod->decl); - elem.expr = createVarRef(loc, declElem.decl); - result->tupleElements.Add(elem); + lowerExprCommon(loweredExpr, expr); } - - return LoweredExpr(result); - } - - LoweredExpr createVaryingTupleRef( - SourceLoc const& /*loc*/, - VaryingTupleVarDecl* decl) - { - return decl->expr; + loweredExpr->declRef = declRef; + loweredExpr->name = expr->name; + return loweredExpr; } - LoweredExpr visitVarExpr( + LegalExpr visitVarExpr( VarExpr* expr) { // If the expression didn't get resolved, we can leave it as-is if (!expr->declRef) return expr; + // Ensure that lowering has been applied to the declaration auto loweredDeclRef = translateDeclRef(expr->declRef); - auto loweredDecl = loweredDeclRef.getDecl(); - if (auto tupleVarDecl = loweredDecl.asTupleDecl()) - { - // If we are referencing a declaration that got tuple-ified, - // then we need to produce a tuple expression as well. - - return createTupleRef(expr->loc, tupleVarDecl); - } - else if (auto varyingTupleVarDecl = loweredDecl.asVaryingTupleDecl()) - { - return createVaryingTupleRef(expr->loc, varyingTupleVarDecl); - } + // Is there a value already registered for use when looking + // up this variable? + LegalExpr legalExpr; + if (this->shared->mapOriginalDeclToExpr.TryGetValue(expr->declRef.getDecl(), legalExpr)) + return legalExpr; - RefPtr<VarExpr> loweredExpr = new VarExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->declRef = loweredDeclRef.As<Decl>(); - loweredExpr->name = expr->name; - return LoweredExpr(loweredExpr); + return LegalExpr(createSimpleVarExpr( + expr, + loweredDeclRef)); } - LoweredExpr visitOverloadedExpr( + LegalExpr visitOverloadedExpr( OverloadedExpr* expr) { // The presence of an overloaded expression in the output @@ -1077,48 +1007,64 @@ struct LoweringVisitor return moveTemp(expr); } - LoweredExpr maybeMoveTemp( - LoweredExpr expr) + LegalExpr maybeMoveTemp( + LegalExpr expr) { - if (auto tupleExpr = expr.asTuple()) + switch (expr.getFlavor()) { - RefPtr<TupleExpr> resultExpr = new TupleExpr(); - resultExpr->loc = tupleExpr->loc; - resultExpr->type = tupleExpr->type; - if (tupleExpr->primaryExpr) + case LegalExpr::Flavor::none: + return LegalExpr(); + + case LegalExpr::Flavor::simple: + return LegalExpr(maybeMoveTemp(expr.getSimple())); + + case LegalExpr::Flavor::tuple: { - resultExpr->primaryExpr = maybeMoveTemp(tupleExpr->primaryExpr); + auto tupleExpr = expr.getTuple(); + RefPtr<TuplePseudoExpr> resultExpr = new TuplePseudoExpr(); + resultExpr->loc = tupleExpr->loc; + + for (auto ee : tupleExpr->elements) + { + TuplePseudoExpr::Element element; + element.expr = maybeMoveTemp(ee.expr); + element.fieldDeclRef = ee.fieldDeclRef; + resultExpr->elements.Add(element); + } + + return LegalExpr(resultExpr); } - for (auto ee : tupleExpr->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - TupleExpr::Element elem; - elem.tupleFieldDeclRef = ee.tupleFieldDeclRef; - elem.expr = maybeMoveTemp(ee.expr); + auto pairExpr = expr.getPair(); + RefPtr<PairPseudoExpr> resultExpr = new PairPseudoExpr(); + resultExpr->loc = pairExpr->loc; + resultExpr->pairInfo = pairExpr->pairInfo; - resultExpr->tupleElements.Add(elem); + resultExpr->ordinary = maybeMoveTemp(pairExpr->ordinary); + resultExpr->special = maybeMoveTemp(pairExpr->special); + + return LegalExpr(resultExpr); } + break; - return LoweredExpr(resultExpr); - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - RefPtr<VaryingTupleExpr> resultExpr = new VaryingTupleExpr(); - resultExpr->loc = varyingTupleExpr->loc; - resultExpr->type = varyingTupleExpr->type; - for (auto ee : varyingTupleExpr->elements) + case LegalExpr::Flavor::implicitDeref: { - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = ee.originalFieldDeclRef; - elem.expr = maybeMoveTemp(ee.expr); + auto implicitDerefExpr = expr.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> resultExpr = new ImplicitDerefPseudoExpr(); + resultExpr->loc = implicitDerefExpr->loc; + + resultExpr->valueExpr = maybeMoveTemp(implicitDerefExpr->valueExpr); - resultExpr->elements.Add(elem); + return LegalExpr(resultExpr); } + break; - return LoweredExpr(resultExpr); - } - else - { - return LoweredExpr(maybeMoveTemp(expr.getExpr())); + default: + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(LegalExpr()); } } @@ -1133,8 +1079,8 @@ struct LoweringVisitor return expr; } - LoweredExpr ensureSimpleLValue( - LoweredExpr expr) + LegalExpr ensureSimpleLValue( + LegalExpr expr) { // TODO: actually implement this properly! @@ -1393,197 +1339,139 @@ struct LoweringVisitor } } - LoweredExpr createAssignExpr( - LoweredExpr leftExpr, - LoweredExpr rightExpr, + LegalExpr createAssignExpr( + LegalExpr leftExpr, + LegalExpr rightExpr, AssignMode mode = AssignMode::Default) { - auto leftTuple = leftExpr.asTuple(); - auto rightTuple = rightExpr.asTuple(); - if (leftTuple && rightTuple) + switch (leftExpr.getFlavor()) { - RefPtr<TupleExpr> resultTuple = new TupleExpr(); - resultTuple->type = leftTuple->type; - - if (leftTuple->primaryExpr) - { - SLANG_RELEASE_ASSERT(rightTuple->primaryExpr); - - resultTuple->primaryExpr = createSimpleAssignExpr( - leftTuple->primaryExpr, - rightTuple->primaryExpr, - mode); - } + case LegalExpr::Flavor::none: + return LegalExpr(); - auto elementCount = leftTuple->tupleElements.Count(); - SLANG_RELEASE_ASSERT(elementCount == rightTuple->tupleElements.Count()); - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::simple: + switch (rightExpr.getFlavor()) { - auto leftElement = leftTuple->tupleElements[ee]; - auto rightElement = rightTuple->tupleElements[ee]; + case LegalExpr::Flavor::simple: + return LegalExpr(createSimpleAssignExpr( + leftExpr.getSimple(), + rightExpr.getSimple(), + mode)); - TupleExpr::Element resultElement; - - resultElement.tupleFieldDeclRef = leftElement.tupleFieldDeclRef; - resultElement.expr = createAssignExpr( - leftElement.expr, - rightElement.expr, - mode); + case LegalExpr::Flavor::tuple: + { + auto rightTuple = rightExpr.getTuple(); + RefPtr<TuplePseudoExpr> resultTuple = new TuplePseudoExpr(); + for (auto ee : rightTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createAssignExpr( + extractField(leftExpr, ee.fieldDeclRef), + ee.expr, + mode); + + resultTuple->elements.Add(element); + } + return LegalExpr(resultTuple); + } + break; - resultTuple->tupleElements.Add(resultElement); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); } + break; - return LoweredExpr(resultTuple); - } - else - { - SLANG_RELEASE_ASSERT(!leftTuple && !rightTuple); - } - - auto leftVaryingTuple = leftExpr.asVaryingTuple(); - auto rightVaryingTuple = rightExpr.asVaryingTuple(); - - RefPtr<Expr> leftSimpleExpr = leftExpr.asExpr(); - RefPtr<Expr> rightSimpleExpr = rightExpr.asExpr(); - - if (leftVaryingTuple && rightVaryingTuple) - { - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftVaryingTuple->type.type; - resultTuple->loc = leftVaryingTuple->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = leftVaryingTuple->elements.Count(); - SLANG_RELEASE_ASSERT(elementCount == rightVaryingTuple->elements.Count()); - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::tuple: { - auto leftElem = leftVaryingTuple->elements[ee]; - auto rightElem = rightVaryingTuple->elements[ee]; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = leftElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - leftElem.expr, - rightElem.expr, - mode); - } + rightExpr = maybeMoveTemp(rightExpr); - return LoweredExpr(resultTuple); - } - else if (leftVaryingTuple && rightSimpleExpr) - { - // Assigning from ordinary expression on RHS to tuple. - // This will naturally yield a tuple expression. - - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftVaryingTuple->type.type; - resultTuple->loc = leftVaryingTuple->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = leftVaryingTuple->elements.Count(); - - // Move everything into temps if we can - rightSimpleExpr = maybeMoveTemp(rightSimpleExpr); - for (UInt ee = 0; ee < elementCount; ++ee) - { - auto& leftElem = leftVaryingTuple->elements[ee]; - leftElem.expr = ensureSimpleLValue(leftElem.expr); + auto leftTuple = leftExpr.getTuple(); + RefPtr<TuplePseudoExpr> resultTuple = new TuplePseudoExpr(); + resultTuple->loc = leftTuple->loc; + for (auto ee : leftTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createAssignExpr( + ee.expr, + extractField(rightExpr, ee.fieldDeclRef), + mode); + + resultTuple->elements.Add(element); + } + return LegalExpr(resultTuple); } + break; - // - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::pair: { - auto leftElem = leftVaryingTuple->elements[ee]; - - - RefPtr<MemberExpr> rightElemExpr = new MemberExpr(); - rightElemExpr->loc = rightSimpleExpr->loc; - rightElemExpr->type.type = GetType(leftElem.originalFieldDeclRef); - rightElemExpr->declRef = leftElem.originalFieldDeclRef; - rightElemExpr->name = leftElem.originalFieldDeclRef.GetName(); - rightElemExpr->BaseExpression = rightSimpleExpr; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = leftElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - leftElem.expr, - LoweredExpr(rightElemExpr), - mode); - - resultTuple->elements.Add(elem); - } - - return LoweredExpr(resultTuple); - } - else if (leftSimpleExpr && rightVaryingTuple) - { - // Pretty much the same as the above case, and we should - // probably try to share code eventually. - - - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftSimpleExpr->type.type; - resultTuple->loc = leftSimpleExpr->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = rightVaryingTuple->elements.Count(); + auto leftPair = leftExpr.getPair(); + switch( rightExpr.getFlavor() ) + { + case LegalExpr::Flavor::pair: + { + auto rightPair = rightExpr.getPair(); + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->loc = leftPair->loc; + resultPair->pairInfo = leftPair->pairInfo; + + resultPair->ordinary = createAssignExpr( + leftPair->ordinary, + rightPair->ordinary, + mode); + resultPair->special = createAssignExpr( + leftPair->special, + rightPair->special, + mode); + + return LegalExpr(resultPair); + } + break; - // Move everything into temps if we can - leftSimpleExpr = ensureSimpleLValue(leftSimpleExpr); - for (UInt ee = 0; ee < elementCount; ++ee) - { - auto& rightElem = rightVaryingTuple->elements[ee]; - rightElem.expr = maybeMoveTemp(rightElem.expr); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + } } + break; - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::implicitDeref: { - auto rightElem = rightVaryingTuple->elements[ee]; - - - RefPtr<MemberExpr> leftElemExpr = new MemberExpr(); - leftElemExpr->loc = leftSimpleExpr->loc; - leftElemExpr->type.type = GetType(rightElem.originalFieldDeclRef); - leftElemExpr->declRef = rightElem.originalFieldDeclRef; - leftElemExpr->name = rightElem.originalFieldDeclRef.GetName(); - leftElemExpr->BaseExpression = leftSimpleExpr; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = rightElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - LoweredExpr(leftElemExpr), - rightElem.expr, - mode); + auto leftImplicitDeref = leftExpr.getImplicitDeref(); + switch(rightExpr.getFlavor()) + { + case LegalExpr::Flavor::implicitDeref: + { + auto rightImplicitDeref = rightExpr.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->loc = leftImplicitDeref->loc; + resultImplicitDeref->valueExpr = createAssignExpr( + leftImplicitDeref->valueExpr, + rightImplicitDeref->valueExpr, + mode); + + return LegalExpr(resultImplicitDeref); + } - resultTuple->elements.Add(elem); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + } } - return LoweredExpr(resultTuple); - } - else if (leftSimpleExpr && rightSimpleExpr) - { - // Default case: no tuples of any kind... - - return LoweredExpr(createSimpleAssignExpr(leftSimpleExpr, rightSimpleExpr, mode)); - } - else - { - // Some case wasn't handled: diagnose! - SLANG_UNEXPECTED("bad combination of tuple types"); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); } } - LoweredExpr visitAssignExpr( + LegalExpr visitAssignExpr( AssignExpr* expr) { - auto leftExpr = lowerExprOrTuple(expr->left); - auto rightExpr = lowerExprOrTuple(expr->right); + auto leftExpr = legalizeExpr(expr->left); + auto rightExpr = legalizeExpr(expr->right); auto loweredExpr = createAssignExpr(leftExpr, rightExpr); lowerExprCommon(loweredExpr, expr); @@ -1614,10 +1502,77 @@ struct LoweringVisitor return loweredExpr; } - LoweredExpr createSubscriptExpr( - LoweredExpr baseExpr, + LegalExpr createSubscriptExpr( + LegalExpr baseExpr, RefPtr<Expr> indexExpr) { + switch (baseExpr.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); + + case LegalExpr::Flavor::simple: + return LegalExpr(createSimpleSubscriptExpr( + baseExpr.getSimple(), + indexExpr)); + + case LegalExpr::Flavor::tuple: + { + indexExpr = maybeMoveTemp(indexExpr); + + auto baseTuple = baseExpr.getTuple(); + + auto resultTuple = new TuplePseudoExpr(); + resultTuple->loc = baseTuple->loc; + + for (auto ee : baseTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createSubscriptExpr( + ee.expr, + indexExpr); + + resultTuple->elements.Add(element); + } + + return LegalExpr(resultTuple); + } + break; + + case LegalExpr::Flavor::pair: + { + indexExpr = maybeMoveTemp(indexExpr); + + auto basePair = baseExpr.getPair(); + + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->loc = basePair->loc; + + resultPair->ordinary = createSubscriptExpr(basePair->ordinary, indexExpr); + resultPair->special = createSubscriptExpr(basePair->special, indexExpr); + + return LegalExpr(resultPair); + } + + case LegalExpr::Flavor::implicitDeref: + { + auto baseImplicitDeref = baseExpr.getImplicitDeref(); + + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->loc = baseImplicitDeref->loc; + + resultImplicitDeref->valueExpr = createSubscriptExpr(baseImplicitDeref->valueExpr, indexExpr); + + return LegalExpr(resultImplicitDeref); + } + + default: + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(LegalExpr()); + } + +#if 0 // TODO: This logic ends up duplicating the `indexExpr` // that was given, without worrying about any side // effects it might contain. That needs to be fixed. @@ -1670,71 +1625,30 @@ struct LoweringVisitor } else { - return LoweredExpr(createSimpleSubscriptExpr( + return LegalExpr(createSimpleSubscriptExpr( baseExpr.getExpr(), indexExpr)); } +#endif } - LoweredExpr visitIndexExpr( + LegalExpr visitIndexExpr( IndexExpr* subscriptExpr) { - auto baseExpr = lowerExprOrTuple(subscriptExpr->BaseExpression); - auto indexExpr = lowerExpr(subscriptExpr->IndexExpression); + auto baseExpr = legalizeExpr(subscriptExpr->BaseExpression); + auto indexExpr = legalizeSimpleExpr(subscriptExpr->IndexExpression); - // An attempt to subscript a tuple must be turned into a - // tuple of subscript expressions. - if (auto baseTuple = baseExpr.asTuple()) - { - return createSubscriptExpr(baseExpr, indexExpr); - } - else if (auto baseVaryingTuple = baseExpr.asVaryingTuple()) - { - return createSubscriptExpr(baseExpr, indexExpr); - } - else + if(baseExpr.getFlavor() == LegalExpr::Flavor::simple) { // Default case: just reconstrut a subscript expr RefPtr<IndexExpr> loweredExpr = new IndexExpr(); lowerExprCommon(loweredExpr, subscriptExpr); - loweredExpr->BaseExpression = baseExpr.getExpr(); + loweredExpr->BaseExpression = baseExpr.getSimple(); loweredExpr->IndexExpression = indexExpr; - return LoweredExpr(loweredExpr); - } - } - - RefPtr<Expr> maybeReifyTuple( - LoweredExpr expr) - { - if (auto tupleExpr = expr.asTuple()) - { - // TODO: need to diagnose - return tupleExpr->primaryExpr; - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - // Need to pass an ordinary (non-tuple) expression of - // the corresponding type here. - - // TODO(tfoley): This won't work at all for an `out` or `inout` - // function argument, so we'll need to figure out a plan - // to handle that case... - - RefPtr<AggTypeCtorExpr> resultExpr = new AggTypeCtorExpr(); - resultExpr->type = varyingTupleExpr->type; - resultExpr->base.type = varyingTupleExpr->type.type; - SLANG_RELEASE_ASSERT(resultExpr->type.type); - - for (auto elem : varyingTupleExpr->elements) - { - addArgs(resultExpr, elem.expr); - } - - return resultExpr; + return LegalExpr(loweredExpr); } - // Default case: nothing special to this expression - return expr.getExpr(); + return createSubscriptExpr(baseExpr, indexExpr); } bool needGlslangBug988Workaround( @@ -1833,31 +1747,130 @@ struct LoweringVisitor callExpr->Arguments.Add(argExpr); } + // Take a legalized expression that might be represented as a tuple, + // and turn it back into a single ordinary expression of the given type. + // + // This is used in the case where we tuple-ified a value that has + // a legal type, but just isn't legal to use in a particular context. + static RefPtr<Expr> reifyTuple( + LegalExpr legalExpr, + RefPtr<Type> type) + { + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) + return legalExpr.getSimple(); + + if (auto declRefType = type->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // We want a single value of an aggregate type, which + // means we need to extract each of its fields from + // the expression. + + switch (legalExpr.getFlavor()) + { + case LegalExpr::Flavor::tuple: + { + auto tupleExpr = legalExpr.getTuple(); + + RefPtr<AggTypeCtorExpr> resultExpr = new AggTypeCtorExpr(); + resultExpr->type.type = type; + resultExpr->base.type = type; + SLANG_RELEASE_ASSERT(resultExpr->type.type); + + UInt fieldCounter = 0; + for (auto fieldDeclRef : getMembersOfType<StructField>(aggTypeDeclRef)) + { + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + UInt fieldIndex = fieldCounter++; + + resultExpr->Arguments.Add(reifyTuple( + tupleExpr->elements[fieldIndex].expr, + GetType(fieldDeclRef))); + } + + return resultExpr; + } + break; + } + + } + } + // TODO: need to handle array types here... + + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(legalExpr.getSimple()); + } + + static LegalExpr maybeReifyTuple( + LegalExpr legalExpr, + LegalType expectedType) + { + if (expectedType.flavor != LegalType::Flavor::simple) + return legalExpr; + + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) + return legalExpr; + + return LegalExpr(reifyTuple(legalExpr, expectedType.getSimple())); + } + void addArgs( ExprWithArgsBase* callExpr, - LoweredExpr argExpr) + LegalType argType, + LegalExpr argExpr) { - if (auto argTuple = argExpr.asTuple()) + argExpr = maybeReifyTuple(argExpr, argType); + + if (argExpr.getFlavor() != argType.flavor) + { + SLANG_UNEXPECTED("expression and type do not match"); + } + + switch (argExpr.getFlavor()) { - if (argTuple->primaryExpr) + case LegalExpr::Flavor::none: + break; + + case LegalExpr::Flavor::simple: + addArg(callExpr, argExpr.getSimple()); + break; + + case LegalExpr::Flavor::tuple: { - addArg(callExpr, argTuple->primaryExpr); + auto aa = argExpr.getTuple(); + auto at = argType.getTuple(); + auto elementCount = aa->elements.Count(); + for (UInt ee = 0; ee < elementCount; ++ee) + { + addArgs(callExpr, at->elements[ee].type, aa->elements[ee].expr); + } } - for (auto elem : argTuple->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - addArgs(callExpr, elem.expr); + auto aa = argExpr.getPair(); + auto at = argType.getPair(); + addArgs(callExpr, at->ordinaryType, aa->ordinary); + addArgs(callExpr, at->specialType, aa->special); } - } - else if (auto varyingArgTuple = argExpr.asVaryingTuple()) - { - // Need to pass an ordinary (non-tuple) expression of - // the corresponding type here. + break; - callExpr->Arguments.Add(maybeReifyTuple(argExpr)); - } - else - { - addArg(callExpr, argExpr.getExpr()); + case LegalExpr::Flavor::implicitDeref: + { + auto aa = argExpr.getImplicitDeref(); + auto at = argType.getImplicitDeref(); + addArgs(callExpr, at->valueType, aa->valueExpr); + } + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; } } @@ -1867,38 +1880,51 @@ struct LoweringVisitor { lowerExprCommon(loweredExpr, expr); - loweredExpr->FunctionExpr = lowerExpr(expr->FunctionExpr); + loweredExpr->FunctionExpr = legalizeSimpleExpr(expr->FunctionExpr); for (auto arg : expr->Arguments) { - auto loweredArg = lowerExprOrTuple(arg); - addArgs(loweredExpr, loweredArg); + auto argType = lowerAndLegalizeType(arg->type.type); + auto loweredArg = legalizeExpr(arg); + addArgs(loweredExpr, argType, loweredArg); } return loweredExpr; } - LoweredExpr visitInvokeExpr( + LegalExpr visitInvokeExpr( InvokeExpr* expr) { // Create a clone with the same class InvokeExpr* loweredExpr = (InvokeExpr*) expr->getClass().createInstance(); - return LoweredExpr(lowerCallExpr(loweredExpr, expr)); + return LegalExpr(lowerCallExpr(loweredExpr, expr)); } - LoweredExpr visitSelectExpr( + LegalExpr visitSelectExpr( SelectExpr* expr) { // TODO: A tuple needs to be special-cased here - return LoweredExpr(lowerCallExpr(new SelectExpr(), expr)); + return LegalExpr(lowerCallExpr(new SelectExpr(), expr)); } - LoweredExpr visitDerefExpr( + LegalExpr visitDerefExpr( DerefExpr* expr) { - auto loweredBase = lowerExprOrTuple(expr->base); + auto legalBase = legalizeExpr(expr->base); + if (legalBase.getFlavor() == LegalExpr::Flavor::simple) + { + // Default case is just to lower a dereference opertion + // into another dereference. + RefPtr<DerefExpr> loweredExpr = new DerefExpr(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->base = legalBase.getSimple(); + return LegalExpr(loweredExpr); + } + + return deref(legalBase); +#if 0 if (auto baseTuple = loweredBase.asTuple()) { // In the case of a tuple created for "resources in structs" reasons, @@ -1938,13 +1964,7 @@ struct LoweringVisitor // // TODO: implement this. } - - // Default case is just to lower a dereference opertion - // into another dereference. - RefPtr<DerefExpr> loweredExpr = new DerefExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->base = loweredBase.getExpr(); - return LoweredExpr(loweredExpr); +#endif } DiagnosticSink* getSink() @@ -1952,108 +1972,43 @@ struct LoweringVisitor return &shared->compileRequest->mSink; } - LoweredExpr visitStaticMemberExpr( + LegalExpr visitStaticMemberExpr( StaticMemberExpr* expr) { - auto loweredBase = lowerExprOrTuple(expr->BaseExpression); + auto loweredBase = legalizeExpr(expr->BaseExpression); auto loweredDeclRef = translateDeclRef(expr->declRef); // TODO: we should probably support type-type members here. RefPtr<StaticMemberExpr> loweredExpr = new StaticMemberExpr(); lowerExprCommon(loweredExpr, expr); - loweredExpr->BaseExpression = loweredBase.getExpr(); + loweredExpr->BaseExpression = loweredBase.getSimple(); loweredExpr->declRef = loweredDeclRef.As<Decl>(); loweredExpr->name = expr->name; - return LoweredExpr(loweredExpr); + return LegalExpr(loweredExpr); } - LoweredExpr visitMemberExpr( + LegalExpr visitMemberExpr( MemberExpr* expr) { assert(expr->BaseExpression); - auto loweredBase = lowerExprOrTuple(expr->BaseExpression); - assert(loweredBase); - - auto loweredDeclRef = translateDeclRef(expr->declRef); - - - // Are we extracting an element from a tuple? - if (auto baseTuple = loweredBase.asTuple()) - { - auto loweredFieldDecl = loweredDeclRef.As<Decl>().getDecl(); - auto tupleFieldMod = loweredFieldDecl->FindModifier<TupleFieldModifier>(); - if (tupleFieldMod) - { - // This field has a tuple part to it, so we need to search for it - - LoweredExpr tupleFieldExpr; - for (auto elem : baseTuple->tupleElements) - { - if (loweredFieldDecl == elem.tupleFieldDeclRef.getDecl()) - { - tupleFieldExpr = elem.expr; - break; - } - } - - if (!tupleFieldMod->hasAnyNonTupleFields) - { - // We need to have found something! - assert(tupleFieldExpr); - return tupleFieldExpr; - } - - auto tupleFieldTupleExpr = tupleFieldExpr.asTuple(); - SLANG_RELEASE_ASSERT(tupleFieldTupleExpr); - SLANG_RELEASE_ASSERT(!tupleFieldTupleExpr->primaryExpr); - - - RefPtr<MemberExpr> loweredPrimaryExpr = new MemberExpr(); - lowerExprCommon(loweredPrimaryExpr, expr); - loweredPrimaryExpr->BaseExpression = baseTuple->primaryExpr; - loweredPrimaryExpr->declRef = loweredDeclRef.As<Decl>(); - loweredPrimaryExpr->name = expr->name; - - assert(loweredPrimaryExpr->BaseExpression); + auto legalBase = legalizeExpr(expr->BaseExpression); + assert(legalBase); - tupleFieldTupleExpr->primaryExpr = loweredPrimaryExpr; - return tupleFieldTupleExpr; - } - - // If the field was a non-tuple field, then we can - // simply fall through to the ordinary case below. - loweredBase = LoweredExpr(baseTuple->primaryExpr); - assert(baseTuple->primaryExpr); - } - else if (auto baseVaryingTuple = loweredBase.asVaryingTuple()) + if (legalBase.getFlavor() == LegalExpr::Flavor::simple) { - // Search for the element corresponding to this field - for(auto elem : baseVaryingTuple->elements) - { - if (expr->declRef.getDecl() == elem.originalFieldDeclRef.getDecl()) - { - // We found the field! - assert(elem.expr); - return elem.expr; - } - } - - SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, "failed to find tuple field during lowering"); + // Default handling: + RefPtr<MemberExpr> loweredExpr = new MemberExpr(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = legalBase.getSimple(); + loweredExpr->declRef = translateDeclRef(expr->declRef); + loweredExpr->name = expr->name; + assert(loweredExpr->BaseExpression); + return LegalExpr(loweredExpr); } - // Default handling: - - RefPtr<MemberExpr> loweredExpr = new MemberExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->BaseExpression = loweredBase.getExpr(); - loweredExpr->declRef = loweredDeclRef.As<Decl>(); - loweredExpr->name = expr->name; - - assert(loweredExpr->BaseExpression); - - return LoweredExpr(loweredExpr); + return extractField(legalBase, expr->declRef.As<VarDeclBase>()); } // @@ -2123,18 +2078,18 @@ struct LoweringVisitor StmtVisitor::dispatch(stmt); } - LoweredDecl visitScopeDecl(ScopeDecl* decl) + RefPtr<Decl> visitScopeDecl(ScopeDecl* decl) { RefPtr<ScopeDecl> loweredDecl = new ScopeDecl(); lowerDeclCommon(loweredDecl, decl); - return LoweredDecl(loweredDecl); + return loweredDecl; } LoweringVisitor pushScope( RefPtr<ScopeStmt> loweredStmt, RefPtr<ScopeStmt> originalStmt) { - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl)->As<ScopeDecl>(); LoweringVisitor subVisitor = *this; subVisitor.isBuildingStmt = true; @@ -2215,32 +2170,46 @@ struct LoweringVisitor } void addExprStmt( - LoweredExpr expr) + LegalExpr expr) { // Desugar tuples in statement position - if (auto tupleExpr = expr.asTuple()) + switch (expr.getFlavor()) { - if (tupleExpr->primaryExpr) + case LegalExpr::Flavor::none: + break; + + case LegalExpr::Flavor::simple: + addSimpleExprStmt(expr.getSimple()); + break; + + case LegalExpr::Flavor::tuple: { - addSimpleExprStmt(tupleExpr->primaryExpr); + auto tupleExpr = expr.getTuple(); + for (auto ee : tupleExpr->elements) + { + addExprStmt(ee.expr); + } } - for (auto ee : tupleExpr->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - addExprStmt(ee.expr); + auto pairExpr = expr.getPair(); + addExprStmt(pairExpr->ordinary); + addExprStmt(pairExpr->special); } - return; - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - for (auto ee : varyingTupleExpr->elements) + break; + + case LegalExpr::Flavor::implicitDeref: { - addExprStmt(ee.expr); + auto implicitDerefExpr = expr.getImplicitDeref(); + addExprStmt(implicitDerefExpr->valueExpr); } - return; - } - else - { - addSimpleExprStmt(expr.getExpr()); + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; } } @@ -2266,7 +2235,7 @@ struct LoweringVisitor void visitExpressionStmt(ExpressionStmt* stmt) { - addExprStmt(lowerExprOrTuple(stmt->Expression)); + addExprStmt(legalizeExpr(stmt->Expression)); } void visitDeclStmt(DeclStmt* stmt) @@ -2297,7 +2266,7 @@ struct LoweringVisitor ScopeStmt* originalStmt) { lowerStmtFields(loweredStmt, originalStmt); - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl)->As<ScopeDecl>(); } // Child statements reference their parent statement, @@ -2361,7 +2330,7 @@ struct LoweringVisitor RefPtr<CaseStmt> loweredStmt = new CaseStmt(); lowerChildStmtFields(loweredStmt, stmt); - loweredStmt->expr = lowerExpr(stmt->expr); + loweredStmt->expr = legalizeSimpleExpr(stmt->expr); addStmt(loweredStmt); } @@ -2371,7 +2340,7 @@ struct LoweringVisitor RefPtr<IfStmt> loweredStmt = new IfStmt(); lowerStmtFields(loweredStmt, stmt); - loweredStmt->Predicate = lowerExpr(stmt->Predicate); + loweredStmt->Predicate = legalizeSimpleExpr(stmt->Predicate); loweredStmt->PositiveStatement = lowerStmt(stmt->PositiveStatement); loweredStmt->NegativeStatement = lowerStmt(stmt->NegativeStatement); @@ -2385,7 +2354,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); - loweredStmt->condition = subVisitor.lowerExpr(stmt->condition); + loweredStmt->condition = subVisitor.legalizeSimpleExpr(stmt->condition); loweredStmt->body = subVisitor.lowerStmt(stmt->body); addStmt(loweredStmt); @@ -2400,8 +2369,8 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); loweredStmt->InitialStatement = subVisitor.lowerStmt(stmt->InitialStatement); - loweredStmt->SideEffectExpression = subVisitor.lowerExpr(stmt->SideEffectExpression); - loweredStmt->PredicateExpression = subVisitor.lowerExpr(stmt->PredicateExpression); + loweredStmt->SideEffectExpression = subVisitor.legalizeSimpleExpr(stmt->SideEffectExpression); + loweredStmt->PredicateExpression = subVisitor.legalizeSimpleExpr(stmt->PredicateExpression); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); addStmt(loweredStmt); @@ -2431,8 +2400,9 @@ struct LoweringVisitor return; auto varDecl = stmt->varDecl; + shared->mapOriginalDeclToLowered[varDecl] = nullptr; - auto varType = lowerType(varDecl->type); + auto varType = lowerTypeExprEx(varDecl->type); for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) { @@ -2441,12 +2411,7 @@ struct LoweringVisitor constExpr->ConstType = ConstantExpr::ConstantType::Int; constExpr->integerValue = ii; - RefPtr<VaryingTupleVarDecl> loweredVarDecl = new VaryingTupleVarDecl(); - loweredVarDecl->loc = varDecl->loc; - loweredVarDecl->type = varType; - loweredVarDecl->expr = LoweredExpr(constExpr); - - shared->loweredDecls[varDecl] = LoweredDecl(loweredVarDecl); + shared->mapOriginalDeclToExpr[varDecl] = LegalExpr(constExpr); lowerStmtImpl(stmt->body); } @@ -2459,7 +2424,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); - loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + loweredStmt->Predicate = subVisitor.legalizeSimpleExpr(stmt->Predicate); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); addStmt(loweredStmt); @@ -2473,7 +2438,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); - loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + loweredStmt->Predicate = subVisitor.legalizeSimpleExpr(stmt->Predicate); addStmt(loweredStmt); } @@ -2489,22 +2454,22 @@ struct LoweringVisitor } void assign( - LoweredExpr destExpr, - LoweredExpr srcExpr, + LegalExpr destExpr, + LegalExpr srcExpr, AssignMode mode = AssignMode::Default) { auto assignExpr = createAssignExpr(destExpr, srcExpr, mode); addExprStmt(assignExpr); } - void assign(VarDeclBase* varDecl, LoweredExpr expr) + void assign(VarDeclBase* varDecl, LegalExpr expr) { - assign(LoweredExpr(createVarRef(getPosition(expr), varDecl)), expr); + assign(LegalExpr(createVarRef(getPosition(expr), varDecl)), expr); } - void assign(LoweredExpr expr, VarDeclBase* varDecl) + void assign(LegalExpr expr, VarDeclBase* varDecl) { - assign(expr, LoweredExpr(createVarRef(getPosition(expr), varDecl))); + assign(expr, LegalExpr(createVarRef(getPosition(expr), varDecl))); } RefPtr<Expr> createTypeExpr( @@ -2537,20 +2502,20 @@ struct LoweringVisitor // where the types don't actually line up, because of // differences in how something is declared in HLSL vs. GLSL void assignWithFixups( - LoweredExpr destExpr, - LoweredExpr srcExpr) + LegalExpr destExpr, + LegalExpr srcExpr) { assign(destExpr, srcExpr, AssignMode::WithFixups); } - void assignWithFixups(VarDeclBase* varDecl, LoweredExpr expr) + void assignWithFixups(VarDeclBase* varDecl, LegalExpr expr) { - assignWithFixups(LoweredExpr(createVarRef(getPosition(expr), varDecl)), expr); + assignWithFixups(LegalExpr(createVarRef(getPosition(expr), varDecl)), expr); } - void assignWithFixups(LoweredExpr expr, VarDeclBase* varDecl) + void assignWithFixups(LegalExpr expr, VarDeclBase* varDecl) { - assignWithFixups(expr, LoweredExpr(createVarRef(getPosition(expr), varDecl))); + assignWithFixups(expr, LegalExpr(createVarRef(getPosition(expr), varDecl))); } void visitReturnStmt(ReturnStmt* stmt) @@ -2563,12 +2528,12 @@ struct LoweringVisitor if (resultVariable) { // Do it as an assignment - assign(resultVariable, lowerExprOrTuple(stmt->Expression)); + assign(resultVariable, legalizeExpr(stmt->Expression)); } else { // Simple case - loweredStmt->Expression = lowerExpr(stmt->Expression); + loweredStmt->Expression = legalizeSimpleExpr(stmt->Expression); } } @@ -2582,7 +2547,7 @@ struct LoweringVisitor RefPtr<Val> translateVal(Val* val) { if (auto type = dynamic_cast<Type*>(val)) - return lowerType(type); + return lowerTypeEx(type); if (auto litVal = dynamic_cast<ConstantIntVal*>(val)) return val; @@ -2597,7 +2562,7 @@ struct LoweringVisitor if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions)) { RefPtr<GenericSubstitution> result = new GenericSubstitution(); - result->genericDecl = translateDeclRef(genSubst->genericDecl).getDecl()->As<GenericDecl>(); + result->genericDecl = translateDeclRef(genSubst->genericDecl)->As<GenericDecl>(); for (auto arg : genSubst->args) { result->args.Add(translateVal(arg)); @@ -2622,38 +2587,26 @@ struct LoweringVisitor return decl; } - LoweredDeclRef translateDeclRef( + DeclRef<Decl> translateDeclRef( DeclRef<Decl> const& declRef) { - LoweredDeclRef result; + DeclRef<Decl> result; result.decl = translateDeclRefImpl(declRef); result.substitutions = translateSubstitutions(declRef.substitutions); return result; } - LoweredDecl translateDeclRef( + RefPtr<Decl> translateDeclRef( Decl* decl) { return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); } - // Try to find the module that (recursively) contains a given declaration. - ModuleDecl* findModuleForDecl( - Decl* decl) - { - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) - return moduleDecl; - } - return nullptr; - } - - LoweredDecl translateDeclRefImpl( + RefPtr<Decl> translateDeclRefImpl( DeclRef<Decl> declRef) { Decl* decl = declRef.getDecl(); - if (!decl) return LoweredDecl(); + if (!decl) return nullptr; // We don't want to translate references to built-in declarations, // since they won't be subtituted anyway. @@ -2703,8 +2656,17 @@ struct LoweringVisitor } } - LoweredDecl loweredDecl; - if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + if (getModifiedDecl(decl)->HasModifier<LegalizedModifier>()) + { + // We are trying to translate a reference to a declaration + // that was created by the type legalization process. The + // target declaration should already be placed inside of + // the output module. + return decl; + } + + RefPtr<Decl> loweredDecl; + if (shared->mapOriginalDeclToLowered.TryGetValue(decl, loweredDecl)) return loweredDecl; // Time to force it @@ -2717,7 +2679,7 @@ struct LoweringVisitor return translateDeclRef(declRef).As<ContainerDecl>(); } - LoweredDecl lowerDeclBase( + RefPtr<Decl> lowerDeclBase( DeclBase* declBase) { if (Decl* decl = dynamic_cast<Decl*>(declBase)) @@ -2730,10 +2692,10 @@ struct LoweringVisitor } } - LoweredDecl lowerDecl( + RefPtr<Decl> lowerDecl( Decl* decl) { - LoweredDecl loweredDecl = DeclVisitor::dispatch(decl); + RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl); return loweredDecl; } @@ -2768,11 +2730,10 @@ struct LoweringVisitor addMember(parentDecl, decl); } - void registerLoweredDecl(LoweredDecl loweredDecl, Decl* decl) + void registerLoweredDecl(RefPtr<Decl> loweredDecl, Decl* decl) { - shared->loweredDecls.Add(decl, loweredDecl); - - shared->mapLoweredDeclToOriginal.Add(loweredDecl.getValue(), decl); + shared->mapOriginalDeclToLowered.Add(decl, loweredDecl); + shared->mapLoweredDeclToOriginal.Add(loweredDecl.Ptr(), decl); } // If the name of the declarations collides with a reserved word @@ -2821,9 +2782,9 @@ struct LoweringVisitor { RefPtr<Decl> loweredParent; if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>()) - loweredParent = translateDeclRef(genericParentDecl->ParentDecl).getDecl(); + loweredParent = translateDeclRef(genericParentDecl->ParentDecl); else - loweredParent = translateDeclRef(decl->ParentDecl).getDecl(); + loweredParent = translateDeclRef(decl->ParentDecl); if (loweredParent) { auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); @@ -2874,89 +2835,92 @@ struct LoweringVisitor // Catch-all - LoweredDecl visitSyntaxDecl(SyntaxDecl*) + RefPtr<Decl> visitSyntaxDecl(SyntaxDecl*) { - return LoweredDecl(); + return nullptr; } - LoweredDecl visitGenericValueParamDecl(GenericValueParamDecl*) + RefPtr<Decl> visitGenericValueParamDecl(GenericValueParamDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericTypeParamDecl(GenericTypeParamDecl*) + RefPtr<Decl> visitGenericTypeParamDecl(GenericTypeParamDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) + RefPtr<Decl> visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericDecl(GenericDecl*) + RefPtr<Decl> visitGenericDecl(GenericDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitModuleDecl(ModuleDecl*) + RefPtr<Decl> visitModuleDecl(ModuleDecl*) { SLANG_UNEXPECTED("module decls should be lowered explicitly"); } - LoweredDecl visitSubscriptDecl(SubscriptDecl*) + RefPtr<Decl> visitSubscriptDecl(SubscriptDecl*) { // We don't expect to find direct references to a subscript // declaration, but rather to the underlying accessors - return LoweredDecl(); + return nullptr; } - LoweredDecl visitInheritanceDecl(InheritanceDecl*) + RefPtr<Decl> visitInheritanceDecl(InheritanceDecl*) { // We should deal with these explicitly, as part of lowering // the type that contains them. - return LoweredDecl(); + return nullptr; } - LoweredDecl visitExtensionDecl(ExtensionDecl*) + RefPtr<Decl> visitExtensionDecl(ExtensionDecl*) { // Extensions won't exist in the lowered code: their members // will turn into ordinary functions that get called explicitly - return LoweredDecl(); + return nullptr; } - LoweredDecl visitAssocTypeDecl(AssocTypeDecl * /*assocType*/) + RefPtr<Decl> visitAssocTypeDecl(AssocTypeDecl * /*assocType*/) { // not supported SLANG_UNREACHABLE("visitAssocTypeDecl in LowerVisitor"); - UNREACHABLE_RETURN(LoweredDecl()); + UNREACHABLE_RETURN(nullptr); } - LoweredDecl visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/) + RefPtr<Decl> visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/) { // not supported SLANG_UNREACHABLE("visitGlobalGenericParamDecl in LowerVisitor"); - UNREACHABLE_RETURN(LoweredDecl()); + UNREACHABLE_RETURN(nullptr); } - LoweredDecl visitTypeDefDecl(TypeDefDecl* decl) + RefPtr<Decl> visitTypeDefDecl(TypeDefDecl* decl) { if (shared->target == CodeGenTarget::GLSL) { // GLSL does not support `typedef`, so we will lower it out of existence here - return LoweredDecl(); + return nullptr; } RefPtr<TypeDefDecl> loweredDecl = new TypeDefDecl(); lowerDeclCommon(loweredDecl, decl); - loweredDecl->type = lowerType(decl->type); + // TODO: Need to handle the case where we `typedef` an aggregate + // type that needs to be legalized; in that case we should desugar + // the `typedef` out of existence. + loweredDecl->type = lowerTypeExprEx(decl->type); addMember(shared->loweredProgram, loweredDecl); - return LoweredDecl(loweredDecl); + return loweredDecl; } - LoweredDecl visitImportDecl(ImportDecl*) + RefPtr<Decl> visitImportDecl(ImportDecl*) { // We could unconditionally output the declarations in the // imported code, but this could cause problems if any @@ -2975,10 +2939,10 @@ struct LoweringVisitor // Don't actually include a representation of // the import declaration in the output - return LoweredDecl(); + return nullptr; } - LoweredDecl visitEmptyDecl(EmptyDecl* decl) + RefPtr<Decl> visitEmptyDecl(EmptyDecl* decl) { // Empty declarations are really only useful in GLSL, // where they are used to hold metadata that doesn't @@ -2992,20 +2956,7 @@ struct LoweringVisitor addDecl(loweredDecl); - return LoweredDecl(loweredDecl); - } - - TupleTypeModifier* isTupleType(Type* type) - { - if (auto declRefType = type->As<DeclRefType>()) - { - if (auto tupleTypeMod = declRefType->declRef.getDecl()->FindModifier<TupleTypeModifier>()) - { - return tupleTypeMod; - } - } - - return nullptr; + return loweredDecl; } Type* unwrapArray(Type* inType) @@ -3018,164 +2969,66 @@ struct LoweringVisitor return type; } - TupleTypeModifier* isTupleTypeOrArrayOfTupleType(Type* type) + RefPtr<Decl> visitAggTypeDecl(AggTypeDecl* decl) { - return isTupleType(unwrapArray(type)); - } - - bool isResourceType(Type* type) - { - while (auto arrayType = type->As<ArrayExpressionType>()) - { - type = arrayType->baseType; - } - - if (auto textureTypeBase = type->As<TextureTypeBase>()) - { - return true; - } - else if (auto samplerType = type->As<SamplerStateType>()) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; - } - - LoweredDecl visitAggTypeDecl(AggTypeDecl* decl) - { - // We want to lower any aggregate type declaration - // to just a `struct` type that contains its fields. + // An aggregate type declaration might get "legalized away" + // and result in a new type declaration created by the + // type legalization logic. If that happens, we don't want + // the original type declaration to appear in the output. // - // Any non-field members (e.g., methods) will be - // lowered separately. + // If the result *doesn't* get legalized away, though, we + // need to try to reproduce this declaration as it originally + // appeared. - RefPtr<StructDecl> loweredDecl = new StructDecl(); - lowerDeclCommon(loweredDecl, decl); - - // We need to be ready to turn this type into a "tuple" type, - // if it has any members that can't normally be kept in a `struct` + // We start by creating a type to reference this declaration, + // and then we will try to legalize that. // - // We don't want to do this unconditionally, though, because - // then we'll end up changing the meaning of user code in - // languages like HLSL that support such nesting. + // Note: This logic shouldn't need to defend against generic + // types, since it won't get applied to Slang code that might + // include generics (just HLSL/GLSL code). + RefPtr<DeclRefType> declRefType = DeclRefType::Create( + getSession(), + makeDeclRef(decl)); + DeclRef<Decl> declRef = declRefType->declRef; - bool shouldDesugarTupleTypes = false; - if (getTarget() == CodeGenTarget::GLSL) - { - // Always desugar this stuff for GLSL, since it doesn't - // support nesting of resources in structs. - // - // TODO: Need a way to make this more fine-grained to - // handle cases where a nested member might be allowed - // due to, e.g., bindless textures. - shouldDesugarTupleTypes = true; - } - else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) + LegalType legalType = legalizeType(getTypeLegalizationContext(), declRefType); + if(legalType.flavor != LegalType::Flavor::simple) { - // If the user is directly asking us to do this transformation, - // then obviously we need to do it. + // Something happened to this type during legalization, so + // we don't want to let its declaration appear in the output. // - // TODO: The way this is defined here means it will even apply to user - // HLSL code (not just code written in Slang). We may want to - // reconsider that choice, and only split things that originated in Slang. + // However, we need to ensure that when declaration references + // that might reference this declaration get constructed (e.g., + // this might be the `T` in a `ConstantBuffer<T>`, we have something + // to stick in there. // - shouldDesugarTupleTypes = true; + // For now we'll use the original declaration and hope for the best. + return decl; } - bool isResultATupleType = false; - bool hasAnyNonTupleFields = false; + // if we get this far, then we want to produce an "equivalent" + // aggregate type declaration to what the user wrote. + + RefPtr<StructDecl> loweredDecl = new StructDecl(); + lowerDeclCommon(loweredDecl, decl); for (auto field : decl->getMembersOfType<VarDeclBase>()) { // We lower the field, which will involve lowering the field type - auto loweredField = translateDeclRef(field).getDecl()->As<VarDeclBase>(); + auto loweredField = translateDeclRef(field)->As<VarDeclBase>(); // Add the field to the result declaration addMember(loweredDecl, loweredField); - - // Don't consider any of the following desugaring logic, - // if we aren't supposed to be desugaring this type - if (!shouldDesugarTupleTypes) - { - hasAnyNonTupleFields = true; - continue; - } - - - // If the field is of a type that requires special handling, - // we need to make a note of it. - auto loweredFieldType = loweredField->type.type; - bool isTupleField = false; - bool fieldHasAnyNonTupleFields = false; - bool fieldHasTupleType = false; - if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredFieldType)) - { - isTupleField = true; - fieldHasTupleType = true; - if (fieldTupleTypeMod->hasAnyNonTupleFields) - { - fieldHasAnyNonTupleFields = true; - hasAnyNonTupleFields = true; - } - } - else if (isResourceType(loweredFieldType)) - { - isTupleField = true; - } - else - { - hasAnyNonTupleFields = true; - } - - if (isTupleField) - { - isResultATupleType = true; - - RefPtr<TupleFieldModifier> tupleFieldMod = new TupleFieldModifier(); - tupleFieldMod->decl = loweredField; - tupleFieldMod->hasAnyNonTupleFields = fieldHasAnyNonTupleFields; - tupleFieldMod->isNestedTuple = fieldHasTupleType; - - addModifier(loweredField, tupleFieldMod); - } } - // An empty `struct` must be treated as a tuple type, - // in order to ensure that we don't mess up layout logic - // - // (Also, GLSL doesn't allow empty structs IIRC) - // - // Note: in this one case we are desugaring things even - // when targetting HLSL, just to keep things manageable. - if (!hasAnyNonTupleFields) - { - isResultATupleType = true; - } + // TODO: we should really be copying over *all* the members, + // in the case where this is a user-authored type. - if (isResultATupleType) - { - RefPtr<TupleTypeModifier> tupleTypeMod = new TupleTypeModifier(); - tupleTypeMod->decl = loweredDecl; - tupleTypeMod->hasAnyNonTupleFields = hasAnyNonTupleFields; - addModifier(loweredDecl, tupleTypeMod); - } + addMember( + shared->loweredProgram, + loweredDecl); - if (isResultATupleType && !hasAnyNonTupleFields) - { - // We don't want any pure-tuple types showing up in - // the output program, so we skip that here. - } - else - { - addMember( - shared->loweredProgram, - loweredDecl); - } - - return LoweredDecl(loweredDecl); + return loweredDecl; } RefPtr<VarDeclBase> lowerSimpleVarDeclCommon( @@ -3186,7 +3039,7 @@ struct LoweringVisitor lowerDeclCommon(loweredDecl, decl); loweredDecl->type = loweredType; - loweredDecl->initExpr = lowerExpr(decl->initExpr); + loweredDecl->initExpr = legalizeSimpleExpr(decl->initExpr); return loweredDecl; } @@ -3195,380 +3048,435 @@ struct LoweringVisitor RefPtr<VarDeclBase> loweredDecl, VarDeclBase* decl) { - auto loweredType = lowerType(decl->type); + auto loweredType = lowerTypeExprEx(decl->type); return lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType); } - struct TupleTypeSecondaryVarArraySpec + RefPtr<StructTypeLayout> getBodyStructTypeLayout(RefPtr<TypeLayout> typeLayout) { - TupleTypeSecondaryVarArraySpec* next; - RefPtr<IntVal> elementCount; - }; + if (!typeLayout) + return nullptr; - struct TupleSecondaryVarInfo - { - // Parent tuple decl to add the secondary decl into - RefPtr<TupleVarDecl> tupleDecl; + while (auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>()) + { + typeLayout = parameterGroupTypeLayout->elementTypeLayout; + } - // Syntax class for declarations to create - SyntaxClass<VarDeclBase> varDeclClass; + while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } - // name "stem" to use for any actual variables we create - String name; + if (auto structTypeLayout = typeLayout.As<StructTypeLayout>()) + { + return structTypeLayout; + } - // The parent tuple type (or array thereof) we are scalarizing - RefPtr<Type> tupleType; + return nullptr; + } - // The actual declaration of the tuple type (which will give us the fields) - DeclRef<AggTypeDecl> tupleTypeDecl; + LegalExpr deref( + LegalExpr base) + { + switch (base.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); - // An initializer expression to use for the tuple members - RefPtr<Expr> initExpr; + case LegalExpr::Flavor::simple: + { + auto simpleBase = base.getSimple(); - // The original layout given to the top-level variable - RefPtr<VarLayout> primaryVarLayout; + RefPtr<DerefExpr> resultExpr = new DerefExpr(); + // TODO: need to fill in a type here? + resultExpr->base = simpleBase; + return LegalExpr(resultExpr); + } + break; - // The computed layout of the tuple type itself - RefPtr<StructTypeLayout> tupleTypeLayout; + case LegalExpr::Flavor::tuple: + { + auto tupleExpr = base.getTuple(); + RefPtr<TuplePseudoExpr> resultExpr = new TuplePseudoExpr(); - TupleTypeSecondaryVarArraySpec* arraySpecs = nullptr; - }; + for (auto ee : tupleExpr->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = deref(ee.expr); - void createTupleTypeSecondaryVarDecls( - TupleSecondaryVarInfo const& info) - { - if (auto arrayType = info.tupleType->As<ArrayExpressionType>()) - { - TupleTypeSecondaryVarArraySpec arraySpec; - arraySpec.next = info.arraySpecs; - arraySpec.elementCount = arrayType->ArrayLength; + resultExpr->elements.Add(element); + } - TupleSecondaryVarInfo subInfo = info; - subInfo.tupleType = arrayType->baseType; - subInfo.arraySpecs = &arraySpec; - createTupleTypeSecondaryVarDecls(subInfo); - return; - } + return LegalExpr(resultExpr); + } + break; - // Next, we need to go through the declarations in the aggregate - // type, and deal with all of those that should be tuple-ified. - for (auto dd : getMembersOfType<VarDeclBase>(info.tupleTypeDecl)) - { - if (dd.getDecl()->HasModifier<HLSLStaticModifier>()) - continue; + case LegalExpr::Flavor::pair: + { + auto basePair = base.getPair(); + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->pairInfo = basePair->pairInfo; - auto tupleFieldMod = dd.getDecl()->FindModifier<TupleFieldModifier>(); - if (!tupleFieldMod) - continue; + resultPair->ordinary = deref(basePair->ordinary); + resultPair->special = deref(basePair->special); - // TODO: need to extract the initializer for this field - SLANG_RELEASE_ASSERT(!info.initExpr); - RefPtr<Expr> fieldInitExpr; + return LegalExpr(resultPair); + } - String fieldName = info.name + "_" + getText(dd.GetName()); + case LegalExpr::Flavor::implicitDeref: + { + auto implicitDerefExpr = base.getImplicitDeref(); + return implicitDerefExpr->valueExpr; + } + break; - auto fieldType = GetType(dd); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + break; + } + } - Decl* originalFieldDecl; - shared->mapLoweredDeclToOriginal.TryGetValue(dd, originalFieldDecl); - SLANG_RELEASE_ASSERT(originalFieldDecl); + LegalExpr extractField( + LegalExpr base, + DeclRef<VarDeclBase> fieldDeclRef) + { + switch (base.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); - RefPtr<VarLayout> fieldLayout; - if(info.tupleTypeLayout) + case LegalExpr::Flavor::simple: { - info.tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + auto simpleBase = base.getSimple(); + + RefPtr<MemberExpr> resultExpr = new MemberExpr(); + resultExpr->BaseExpression = simpleBase; + resultExpr->type.type = GetType(fieldDeclRef); + resultExpr->declRef = translateDeclRef(fieldDeclRef.As<Decl>()); + resultExpr->name = fieldDeclRef.GetName(); + return LegalExpr(resultExpr); } - if (fieldLayout && info.primaryVarLayout) + break; + + case LegalExpr::Flavor::tuple: { - // The layout for a field may need to be adjusted - // based on a base offset stored in the primary - // variable. - // - // For example, if the primary variable was recoreded - // to start at descriptor-table slot N, then the - // field layout might say it uses slot k, but that - // needs to be understood relative to the parent, - // so we want slot N + k... and actuall N + k + 1, - // in the case where the parent itself took up - // space of that type... - - bool needsOffset = false; - for (auto rr : fieldLayout->resourceInfos) - { - if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(rr.kind)) - { - if (parentInfo->index != 0 || parentInfo->space != 0) - { - needsOffset = true; - break; - } - } - } - if (needsOffset) + auto baseTuple = base.getTuple(); + for (auto ee : baseTuple->elements) { - RefPtr<VarLayout> newFieldLayout = new VarLayout(); - newFieldLayout->typeLayout = fieldLayout->typeLayout; - newFieldLayout->flags = fieldLayout->flags; - newFieldLayout->stage = fieldLayout->stage; - newFieldLayout->varDecl = fieldLayout->varDecl; - newFieldLayout->systemValueSemantic = fieldLayout->systemValueSemantic; - newFieldLayout->systemValueSemanticIndex = fieldLayout->systemValueSemanticIndex; - newFieldLayout->semanticName = fieldLayout->semanticName; - newFieldLayout->semanticIndex = fieldLayout->semanticIndex; - - for (auto resInfo : fieldLayout->resourceInfos) + if (ee.fieldDeclRef.Equals(fieldDeclRef)) { - auto newResInfo = newFieldLayout->findOrAddResourceInfo(resInfo.kind); - newResInfo->index = resInfo.index; - newResInfo->space = resInfo.space; - if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(resInfo.kind)) - { - newResInfo->index += parentInfo->index; - newResInfo->space += parentInfo->space; - } + return ee.expr; } - - fieldLayout = newFieldLayout; } + SLANG_UNEXPECTED("failed to find tuple element"); } + break; - LoweredDecl fieldVarOrTupleDecl; - if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(fieldType)) + case LegalExpr::Flavor::pair: { - // If the field is itself a tuple, then recurse - RefPtr<TupleVarDecl> fieldTupleDecl = new TupleVarDecl(); - - TupleSecondaryVarInfo fieldInfo; - fieldInfo.tupleDecl = fieldTupleDecl; - fieldInfo.varDeclClass = info.varDeclClass; - fieldInfo.name = fieldName; - fieldInfo.tupleType = fieldType; - fieldInfo.tupleTypeDecl = makeDeclRef(fieldTupleTypeMod->decl); - fieldInfo.initExpr = fieldInitExpr; - fieldInfo.primaryVarLayout = fieldLayout; - fieldInfo.tupleTypeLayout = getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr); - fieldInfo.arraySpecs = info.arraySpecs; - - fieldTupleDecl->tupleType = fieldTupleTypeMod; - createTupleTypeSecondaryVarDecls(fieldInfo); - - fieldVarOrTupleDecl = LoweredDecl(fieldTupleDecl); - } - else - { - // Otherwise the field has a simple type, and we just need to declare the variable here + auto basePair = base.getPair(); - RefPtr<Type> fieldVarType = fieldType; - for (auto aa = info.arraySpecs; aa; aa = aa->next) - { - RefPtr<ArrayExpressionType> arrayType = Slang::getArrayType( - fieldVarType, - aa->elementCount); + // Need to determine if this field is on the + // ordinary side, the special side, or both. - fieldVarType = arrayType; + auto pairInfo = basePair->pairInfo; + auto pairElement = pairInfo->findElement(fieldDeclRef); + if (!pairElement) + { + SLANG_UNEXPECTED("failed to find tuple element"); + UNREACHABLE_RETURN(LegalExpr()); } - RefPtr<VarDeclBase> fieldVarDecl = info.varDeclClass.createInstance(); - fieldVarDecl->nameAndLoc = NameLoc(getName(fieldName)); - fieldVarDecl->type.type = fieldVarType; - - addDecl(fieldVarDecl); - - if (fieldLayout) + if ((pairElement->flags & PairInfo::kFlag_hasOrdinaryAndSpecial) == PairInfo::kFlag_hasOrdinaryAndSpecial) { - RefPtr<ComputedLayoutModifier> layoutMod = new ComputedLayoutModifier(); - layoutMod->layout = fieldLayout; - addModifier(fieldVarDecl, layoutMod); + // we have both flags + LegalExpr ordinaryResult = extractField(basePair->ordinary, + pairElement->ordinaryFieldDeclRef.As<VarDeclBase>()); + LegalExpr specialResult = extractField(basePair->special, fieldDeclRef); + + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->ordinary = ordinaryResult; + resultPair->special = specialResult; + resultPair->pairInfo = pairElement->type.getPair()->pairInfo; + return LegalExpr(resultPair); + } + else if(pairElement->flags & PairInfo::kFlag_hasOrdinary) + { + return extractField(basePair->ordinary, + pairElement->ordinaryFieldDeclRef.As<VarDeclBase>()); + } + else + { + SLANG_ASSERT(pairElement->flags & PairInfo::kFlag_hasSpecial); + return extractField(basePair->special, fieldDeclRef); } - - fieldVarOrTupleDecl = LoweredDecl(fieldVarDecl); } + break; - RefPtr<TupleVarModifier> fieldTupleVarMod = new TupleVarModifier(); - fieldTupleVarMod->tupleField = tupleFieldMod; + case LegalExpr::Flavor::implicitDeref: + { + auto baseImplicitDeref = base.getImplicitDeref(); - TupleVarDecl::Element elem; - elem.decl = fieldVarOrTupleDecl; - elem.tupleVarMod = fieldTupleVarMod; + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->valueExpr = extractField( + baseImplicitDeref->valueExpr, + fieldDeclRef); + return LegalExpr(resultImplicitDeref); + } - info.tupleDecl->tupleDecls.Add(elem); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + break; } } - LoweredDecl createTupleTypeVarDecls( - SyntaxClass<VarDeclBase> varDeclClass, - RefPtr<VarDeclBase> originalVarDecl, - String const& name, - RefPtr<Type> tupleType, - DeclRef<AggTypeDecl> tupleTypeDecl, - TupleTypeModifier* tupleTypeMod, - RefPtr<Expr> initExpr, - RefPtr<VarLayout> primaryVarLayout, - RefPtr<StructTypeLayout> tupleTypeLayout) + void attachLayoutModifier( + VarDeclBase* decl, + VarLayout* layout) { - // Not handling initializers just yet... - SLANG_RELEASE_ASSERT(!initExpr); + if (!layout) + return; - // We'll need a placeholder declaration to wrap the whole thing up: - RefPtr<TupleVarDecl> tupleDecl = new TupleVarDecl(); - tupleDecl->nameAndLoc = NameLoc(getName(name)); + RefPtr<ComputedLayoutModifier> mod = new ComputedLayoutModifier(); + mod->layout = layout; + addModifier(decl, mod); + } - // First, if the tuple type had any "ordinary" data, - // then we go ahead and create a declaration for that stuff - if (tupleTypeMod->hasAnyNonTupleFields) + RefPtr<VarDeclBase> declareSimpleVar( + VarDeclBase* decl, + SourceLoc const& loc, + Name* name, + SyntaxClass<VarDeclBase> loweredDeclClass, + VarLayout* varLayout, + RefPtr<Expr> initExpr, + TypeExp const& typeExpr) + { + RefPtr<VarDeclBase> loweredDecl = loweredDeclClass.createInstance(); + if (decl) { - RefPtr<VarDeclBase> primaryVarDecl = varDeclClass.createInstance(); - primaryVarDecl->nameAndLoc.name = getName(name); - primaryVarDecl->type.type = tupleType; + lowerDeclCommon(loweredDecl, decl); + } + loweredDecl->nameAndLoc.name = name; + loweredDecl->nameAndLoc.loc = loc; - primaryVarDecl->modifiers = shallowCloneModifiers(originalVarDecl->modifiers); + loweredDecl->type = typeExpr; + loweredDecl->initExpr = initExpr; - tupleDecl->primaryDecl = primaryVarDecl; + attachLayoutModifier(loweredDecl, varLayout); - if (primaryVarLayout) + addDecl(loweredDecl); + return loweredDecl; + } + + LegalExpr declareSimpleVar( + VarDeclBase* originalDecl, + LegalVarChain* varChain, + SourceLoc const& loc, + String const& name, + SyntaxClass<VarDeclBase> loweredDeclClass, + TypeLayout* typeLayout, + LegalExpr legalInit, + LegalTypeExpr const& legalTypeExpr) + { + RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); + + RefPtr<VarDeclBase> varDecl = declareSimpleVar( + originalDecl, + loc, + getName(name), + loweredDeclClass, + varLayout, + legalInit.getSimple(), + legalTypeExpr.getSimple()); + + return createVarRef(loc, varDecl); + } + + LegalExpr declareVars( + VarDeclBase* originalDecl, + LegalVarChain* varChain, + SourceLoc const& loc, + String const& name, + SyntaxClass<VarDeclBase> loweredDeclClass, + TypeLayout* typeLayout, + LegalExpr legalInit, + LegalTypeExpr const& legalTypeExpr) + { + auto& legalType = legalTypeExpr.type; + switch (legalType.flavor) + { + case LegalType::Flavor::simple: { - RefPtr<ComputedLayoutModifier> layoutMod = new ComputedLayoutModifier(); - layoutMod->layout = primaryVarLayout; - addModifier(primaryVarDecl, layoutMod); - } + return declareSimpleVar( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + legalTypeExpr); - addDecl(primaryVarDecl); - } + } + break; - TupleSecondaryVarInfo info; - info.tupleDecl = tupleDecl; - info.varDeclClass = varDeclClass; - info.name = name; - info.tupleType = tupleType; - info.tupleTypeDecl = tupleTypeDecl; - info.initExpr = initExpr; - info.primaryVarLayout = primaryVarLayout; - info.tupleTypeLayout = tupleTypeLayout; + case LegalType::Flavor::implicitDeref: + { + auto implicitDerefType = legalType.getImplicitDeref(); + + auto valueType = implicitDerefType->valueType; + auto valueTypeLayout = getDerefTypeLayout(typeLayout); + SLANG_ASSERT(valueTypeLayout || !typeLayout); + auto valueInit = deref(legalInit); + + LegalExpr valueExpr = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + valueTypeLayout, + valueInit, + valueType); - createTupleTypeSecondaryVarDecls(info); + RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr(); + implicitDerefExpr->valueExpr = valueExpr; + return LegalExpr(implicitDerefExpr); + } + break; - return LoweredDecl(tupleDecl); - } + case LegalType::Flavor::tuple: + { + auto tupleType = legalType.getTuple(); - RefPtr<StructTypeLayout> getBodyStructTypeLayout(RefPtr<TypeLayout> typeLayout) - { - if (!typeLayout) - return nullptr; + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); - while (auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>()) - { - typeLayout = parameterGroupTypeLayout->elementTypeLayout; - } + for (auto ff : tupleType->elements) + { + RefPtr<VarLayout> fieldLayout = getFieldLayout( + typeLayout, + ff.fieldDeclRef); + RefPtr<TypeLayout> fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; + SLANG_ASSERT(fieldLayout || !typeLayout); + LegalExpr fieldInit = extractField(legalInit, ff.fieldDeclRef); + + String fieldName = name + "_" + getText(ff.fieldDeclRef.GetName()); + + LegalVarChain fieldVarChain; + fieldVarChain.next = varChain; + fieldVarChain.varLayout = fieldLayout; + + LegalExpr fieldExpr = declareVars( + nullptr, + &fieldVarChain, + loc, + fieldName, + loweredDeclClass, + fieldTypeLayout, + fieldInit, + ff.type); + + TuplePseudoExpr::Element element; + element.expr = fieldExpr; + element.fieldDeclRef = ff.fieldDeclRef; + + tupleExpr->elements.Add(element); + } - while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) - { - typeLayout = arrayTypeLayout->elementTypeLayout; - } + return LegalExpr(tupleExpr); + } + break; - if (auto structTypeLayout = typeLayout.As<StructTypeLayout>()) - { - return structTypeLayout; - } + case LegalType::Flavor::pair: + { + auto pairType = legalType.getPair(); + RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr(); + pairExpr->pairInfo = pairType->pairInfo; + pairExpr->loc = loc; + + pairExpr->ordinary = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + pairType->ordinaryType); + + pairExpr->special = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + pairType->specialType); - return nullptr; - } + return LegalExpr(pairExpr); + } + break; - LoweredDecl createTupleTypeVarDecls( - SyntaxClass<VarDeclBase> varDeclClass, - RefPtr<VarDeclBase> originalVarDecl, - String const& name, - RefPtr<Type> tupleType, - TupleTypeModifier* tupleTypeMod, - RefPtr<Expr> initExpr, - RefPtr<VarLayout> primaryVarLayout) - { - RefPtr<StructTypeLayout> tupleTypeLayout; - if (primaryVarLayout) - { - auto primaryTypeLayout = primaryVarLayout->typeLayout; - tupleTypeLayout = getBodyStructTypeLayout(primaryTypeLayout); + default: + SLANG_UNEXPECTED("unhandled legalized type flavor"); + UNREACHABLE_RETURN(LegalExpr()); + break; } - - return createTupleTypeVarDecls( - varDeclClass, - originalVarDecl, - name, - tupleType, - makeDeclRef(tupleTypeMod->decl), - tupleTypeMod, - initExpr, - primaryVarLayout, - tupleTypeLayout); } - LoweredDecl lowerVarDeclCommonInner( + void lowerVarDeclCommonInner( VarDeclBase* decl, SyntaxClass<VarDeclBase> loweredDeclClass) { - auto loweredType = lowerType(decl->type); - - if (auto tupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredType)) - { - auto varLayout = tryToFindLayout(decl).As<VarLayout>(); + auto legalTypeExpr = lowerAndLegalizeTypeExpr(decl->type); - // The type for the variable is a "tuple type" - // so we need to go ahead and create multiple variables - // to represent it. + auto varLayout = tryToFindLayout(decl).As<VarLayout>(); - // If the variable had an initializer, we expect it - // to resolve to a tuple *value* - auto loweredInit = lowerExpr(decl->initExpr); - - // TODO: need to extract layout here and propagate it down! + // Note: we lower the initialization expression, if any, + // *before* we add the declaration to the current context (e.g., a statement being + // built), so that any operations inside the initialization expression that + // might need to inject statements/temporaries/whatever happen *before* + // the declaration of this variable. + auto legalInit = legalizeExpr(decl->initExpr); - auto tupleDecl = createTupleTypeVarDecls( - loweredDeclClass, + if (legalTypeExpr.type.flavor == LegalType::Flavor::simple) + { + declareSimpleVar( decl, - getText(decl->getName()), - loweredType.type, - tupleTypeMod, - loweredInit, - varLayout); - - shared->loweredDecls.Add(decl, tupleDecl); - return tupleDecl; + decl->nameAndLoc.loc, + decl->getName(), + loweredDeclClass, + varLayout, + legalInit.getSimple(), + legalTypeExpr.getSimple()); } - if (auto bufferType = loweredType->As<UniformParameterGroupType>()) + else { - auto varLayout = tryToFindLayout(decl).As<VarLayout>(); - - auto elementType = bufferType->elementType; - if (auto elementTupleTypeMod = isTupleTypeOrArrayOfTupleType(elementType)) - { - auto tupleDecl = createTupleTypeVarDecls( - loweredDeclClass, - decl, - getText(decl->getName()), - loweredType.type, - elementTupleTypeMod, - nullptr, - varLayout); - - shared->loweredDecls.Add(decl, tupleDecl); - return tupleDecl; - } - } + LegalVarChain varChain; + varChain.next = nullptr; + varChain.varLayout = varLayout; - RefPtr<VarDeclBase> loweredDecl = loweredDeclClass.createInstance(); - - // Note: we lower the declaration (including its initialization expression, if any) - // *before* we add the declaration to the current context (e.g., a statement being - // built), so that any operations inside the initialization expression that - // might need to inject statements/temporaries/whatever happen *before* - // the declaration of this variable. - auto result = lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType); - addDecl(loweredDecl); + LegalExpr legalExpr = declareVars( + decl, + &varChain, + decl->nameAndLoc.loc, + getText(decl->getName()), + loweredDeclClass, + varLayout ? varLayout->typeLayout : nullptr, + legalInit, + legalTypeExpr); - return LoweredDecl(result); + shared->mapOriginalDeclToExpr.Add(decl, legalExpr); + shared->mapOriginalDeclToLowered.AddIfNotExists(decl, nullptr); + } } - LoweredDecl lowerVarDeclCommon( + void lowerVarDeclCommon( VarDeclBase* decl, SyntaxClass<VarDeclBase> loweredDeclClass) { @@ -3581,17 +3489,17 @@ struct LoweringVisitor if (auto parentModuleDecl = pp.As<ModuleDecl>()) { LoweringVisitor subVisitor = *this; - subVisitor.parentDecl = translateDeclRef(parentModuleDecl).getDecl()->As<ContainerDecl>(); + subVisitor.parentDecl = translateDeclRef(parentModuleDecl)->As<ContainerDecl>(); subVisitor.isBuildingStmt = false; - return subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); + subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); } // TODO: handle `static` function-scope variables else { // The default behavior is to lower into whatever // scope was already in places - return lowerVarDeclCommonInner(decl, loweredDeclClass); + lowerVarDeclCommonInner(decl, loweredDeclClass); } } @@ -3658,7 +3566,7 @@ struct LoweringVisitor return false; } - LoweredDecl visitVariable( + RefPtr<Decl> visitVariable( Variable* decl) { if (dynamic_cast<ModuleDecl*>(decl->ParentDecl)) @@ -3679,7 +3587,7 @@ struct LoweringVisitor // We can't easily support `in out` declarations with this approach SLANG_RELEASE_ASSERT(!(inRes && outRes)); - LoweredExpr loweredExpr; + LegalExpr loweredExpr; if (inRes) { loweredExpr = lowerShaderParameterToGLSLGLobals( @@ -3696,54 +3604,49 @@ struct LoweringVisitor VaryingParameterDirection::Output); } -// SLANG_RELEASE_ASSERT(loweredExpr); - auto loweredDecl = createVaryingTupleVarDecl( - decl, - loweredExpr); - - registerLoweredDecl(LoweredDecl(loweredDecl), decl); - return LoweredDecl(loweredDecl); + shared->mapOriginalDeclToExpr.Add(decl, loweredExpr); + shared->mapOriginalDeclToLowered.Add(decl, nullptr); + return nullptr; } } } - auto loweredDecl = lowerVarDeclCommon(decl, getClass<Variable>()); - if(!loweredDecl.getValue()) - return LoweredDecl(); + lowerVarDeclCommon(decl, getClass<Variable>()); - return loweredDecl; + return nullptr; } - LoweredDecl visitStructField( + RefPtr<Decl> visitStructField( StructField* decl) { - return LoweredDecl(lowerSimpleVarDeclCommon(new StructField(), decl)); + return lowerSimpleVarDeclCommon(new StructField(), decl); } - LoweredDecl visitParamDecl( + RefPtr<Decl> visitParamDecl( ParamDecl* decl) { - auto loweredDecl = lowerVarDeclCommon(decl, getClass<ParamDecl>()); - return loweredDecl; + lowerVarDeclCommon(decl, getClass<ParamDecl>()); + + return nullptr; } - LoweredDecl transformSyntaxField(DeclBase* decl) + RefPtr<Decl> transformSyntaxField(DeclBase* decl) { return lowerDeclBase(decl); } - LoweredDecl visitDeclGroup( + RefPtr<Decl> visitDeclGroup( DeclGroup* group) { for (auto decl : group->decls) { lowerDecl(decl); } - return LoweredDecl(); + return nullptr; } - LoweredDecl visitFunctionDeclBase( + RefPtr<Decl> visitFunctionDeclBase( FunctionDeclBase* decl) { // TODO: need to generate a name @@ -3766,7 +3669,7 @@ struct LoweringVisitor subVisitor.translateDeclRef(paramDecl); } - auto loweredReturnType = subVisitor.lowerType(decl->ReturnType); + auto loweredReturnType = subVisitor.lowerAndlegalizeSimpleTypeExpr(decl->ReturnType); loweredDecl->ReturnType = loweredReturnType; @@ -3776,7 +3679,7 @@ struct LoweringVisitor // even if it had been a member function when declared. addMember(shared->loweredProgram, loweredDecl); - return LoweredDecl(loweredDecl); + return loweredDecl; } // @@ -3969,7 +3872,7 @@ struct LoweringVisitor return Slang::getArrayType(elementType, getConstantIntVal(elementCount)); } - LoweredExpr lowerSimpleShaderParameterToGLSLGlobal( + LegalExpr lowerSimpleShaderParameterToGLSLGlobal( VaryingParameterInfo const& info, RefPtr<Type> varType, RefPtr<VarLayout> varLayout) @@ -4246,10 +4149,10 @@ struct LoweringVisitor globalVarExpr = globalVarRef; } - return LoweredExpr(globalVarExpr); + return LegalExpr(globalVarExpr); } - LoweredExpr lowerShaderParameterToGLSLGLobalsRec( + LegalExpr lowerShaderParameterToGLSLGLobalsRec( VaryingParameterInfo const& info, RefPtr<Type> varType, RefPtr<VarLayout> varLayout) @@ -4301,10 +4204,7 @@ struct LoweringVisitor // The shader parameter had a structured type, so we need // to destructure it into its constituent fields - RefPtr<VaryingTupleExpr> tupleExpr = new VaryingTupleExpr(); - tupleExpr->type.type = varType; - - SLANG_RELEASE_ASSERT(tupleExpr->type.type); + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) { @@ -4337,15 +4237,15 @@ struct LoweringVisitor GetType(fieldDeclRef), fieldLayout); - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = makeDeclRef(originalFieldDecl).As<VarDeclBase>(); + TuplePseudoExpr::Element elem; + elem.fieldDeclRef = makeDeclRef(originalFieldDecl).As<VarDeclBase>(); elem.expr = loweredFieldExpr; tupleExpr->elements.Add(elem); } // Okay, we are done with this parameter - return LoweredExpr(tupleExpr); + return LegalExpr(tupleExpr); } } @@ -4353,7 +4253,7 @@ struct LoweringVisitor return lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout); } - LoweredExpr lowerShaderParameterToGLSLGLobals( + LegalExpr lowerShaderParameterToGLSLGLobals( RefPtr<VarDeclBase> originalVarDecl, RefPtr<VarLayout> paramLayout, VaryingParameterDirection direction) @@ -4383,47 +4283,16 @@ struct LoweringVisitor break; } - auto loweredType = lowerType(originalVarDecl->type); + auto loweredType = lowerAndLegalizeTypeExpr(originalVarDecl->type); auto loweredExpr = lowerShaderParameterToGLSLGLobalsRec( info, - loweredType.type, + loweredType.type.getSimple(), // TODO: handle non-simple? paramLayout); -#if 0 - RefPtr<VaryingTupleVarDecl> loweredDecl = createVaryingTupleVarDecl( - originalVarDecl, - loweredType, - loweredExpr); - - registerLoweredDecl(loweredDecl, originalVarDecl); - addDecl(loweredDecl); -#endif - return loweredExpr; } - RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( - RefPtr<VarDeclBase> originalVarDecl, - TypeExp const& loweredType, - LoweredExpr loweredExpr) - { - RefPtr<VaryingTupleVarDecl> loweredDecl = new VaryingTupleVarDecl(); - loweredDecl->nameAndLoc = originalVarDecl->nameAndLoc; - loweredDecl->type = loweredType; - loweredDecl->expr = loweredExpr; - - return loweredDecl; - } - - RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( - RefPtr<VarDeclBase> originalVarDecl, - LoweredExpr loweredExpr) - { - auto loweredType = lowerType(originalVarDecl->type); - return createVaryingTupleVarDecl(originalVarDecl, loweredType, loweredExpr); - } - struct EntryPointParamPair { RefPtr<ParamDecl> original; @@ -4436,7 +4305,7 @@ struct LoweringVisitor RefPtr<EntryPointLayout> entryPointLayout) { // First, loer the entry-point function as an ordinary function: - auto loweredEntryPointFunc = visitFunctionDeclBase(entryPointDecl).getDecl()->As<FunctionDeclBase>(); + auto loweredEntryPointFunc = visitFunctionDeclBase(entryPointDecl)->As<FunctionDeclBase>(); auto mainName = getName("main"); @@ -4476,7 +4345,7 @@ struct LoweringVisitor RefPtr<Variable> localVarDecl = new Variable(); localVarDecl->loc = paramDecl->loc; localVarDecl->nameAndLoc = paramDecl->getNameAndLoc(); - localVarDecl->type = lowerType(paramDecl->type); + localVarDecl->type = lowerAndlegalizeSimpleTypeExpr(paramDecl->type); ensureDeclHasAValidName(localVarDecl); @@ -4553,12 +4422,12 @@ struct LoweringVisitor if (resultVarDecl) { // Non-`void` return type, so we need to store it - subVisitor.assign(resultVarDecl, LoweredExpr(callExpr)); + subVisitor.assign(resultVarDecl, LegalExpr(callExpr)); } else { // `void` return type: just call it - subVisitor.addExprStmt(LoweredExpr(callExpr)); + subVisitor.addExprStmt(LegalExpr(callExpr)); } @@ -4594,8 +4463,8 @@ struct LoweringVisitor if (shared->requiresCopyGLPositionToPositionPerView) { subVisitor.assign( - LoweredExpr(createSimpleVarExpr("gl_PositionPerViewNV[0]")), - LoweredExpr(createSimpleVarExpr("gl_Position"))); + LegalExpr(createSimpleVarExpr("gl_PositionPerViewNV[0]")), + LegalExpr(createSimpleVarExpr("gl_Position"))); } bodyStmt->body = subVisitor.stmtBeingBuilt; @@ -4666,7 +4535,7 @@ struct LoweringVisitor { // Default case: lower an entry point just like any other function default: - return visitFunctionDeclBase(entryPointDecl).getDecl()->As<FuncDecl>(); + return visitFunctionDeclBase(entryPointDecl)->As<FuncDecl>(); // For Slang->GLSL translation, we need to lower things from HLSL-style // declarations over to GLSL conventions @@ -4719,11 +4588,12 @@ bool isRewriteRequest( LoweredEntryPoint lowerEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker, - IRSpecializationState* irSpecializationState) + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState, + TypeLegalizationContext* typeLegalizationContext) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4732,6 +4602,7 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.target = target; sharedContext.extensionUsageTracker = extensionUsageTracker; sharedContext.irSpecializationState = irSpecializationState; + sharedContext.typeLegalizationContext = typeLegalizationContext; auto translationUnit = entryPoint->getTranslationUnit(); sharedContext.mainModuleDecl = translationUnit->SyntaxNode; @@ -4742,6 +4613,10 @@ LoweredEntryPoint lowerEntryPoint( RefPtr<ModuleDecl> loweredProgram = new ModuleDecl(); sharedContext.loweredProgram = loweredProgram; + typeLegalizationContext->mainModuleDecl = sharedContext.mainModuleDecl; + typeLegalizationContext->outputModuleDecl = loweredProgram; + + LoweringVisitor visitor; visitor.shared = &sharedContext; visitor.parentDecl = loweredProgram; @@ -4753,7 +4628,7 @@ LoweredEntryPoint lowerEntryPoint( // of the existing translation unit declaration. visitor.registerLoweredDecl( - LoweredDecl(loweredProgram), + loweredProgram, translationUnit->SyntaxNode); // We also need to register the lowered program as the lowered version @@ -4761,9 +4636,9 @@ LoweredEntryPoint lowerEntryPoint( // a single module for code generation). for (auto rr : entryPoint->compileRequest->loadedModulesList) { - sharedContext.loweredDecls.Add( + sharedContext.mapOriginalDeclToLowered.Add( rr->moduleDecl, - LoweredDecl(loweredProgram)); + loweredProgram); } // We also want to remember the layout information for diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 9046e8df0..23a150002 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -42,6 +42,7 @@ namespace Slang struct IRSpecializationState; class ProgramLayout; class TranslationUnitRequest; + struct TypeLegalizationContext; struct LoweredEntryPoint @@ -63,10 +64,11 @@ namespace Slang // Emit code for a single entry point, based on // the input translation unit. LoweredEntryPoint lowerEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker, - IRSpecializationState* irSpecializationState); + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState, + TypeLegalizationContext* typeLegalizationContext); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 4a084c714..b9aa3c027 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -3,6 +3,7 @@ #include "ast-legalize.h" #include "ir-insts.h" +#include "legalize-types.h" #include "lower-to-ir.h" #include "mangle.h" #include "name.h" @@ -1236,13 +1237,6 @@ struct EmitVisitor emitTypeImpl(type->valueType, arg.declarator); } - void visitFilteredTupleType(FilteredTupleType* type, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; - emit(getMangledTypeName(type)); - EmitDeclarator(declarator); - } - void EmitType( RefPtr<Type> type, SourceLoc const& typeLoc, @@ -2952,7 +2946,7 @@ struct EmitVisitor Decl* decl = declRef.getDecl(); if(irDeclSet->Contains(decl)) { - emit(getMangledName(declRef)); + emit(getIRName(declRef)); return; } } @@ -3605,10 +3599,6 @@ struct EmitVisitor // The data type that describes where stuff in the constant buffer should go RefPtr<Type> dataType = parameterGroupType->elementType; - // We expect/require the data type to be a user-defined `struct` type - auto declRefType = dataType->As<DeclRefType>(); - SLANG_RELEASE_ASSERT(declRefType); - // We expect to always have layout information layout = maybeFetchLayout(varDecl, layout); SLANG_RELEASE_ASSERT(layout); @@ -3642,27 +3632,48 @@ struct EmitVisitor emitHLSLRegisterSemantic(*info); Emit("\n{\n"); - if (auto structRef = declRefType->declRef.As<StructDecl>()) - { - int fieldCounter = 0; - for (auto field : getMembersOfType<StructField>(structRef)) + // We expect the data type to be a user-defined `struct` type, + // but it might also be a "filtered" type that represents the + // case where only some fields of the original type are valid + // to appear inside of a `struct`. + if (auto declRefType = dataType->As<DeclRefType>()) + { + if (auto structRef = declRefType->declRef.As<StructDecl>()) { - int fieldIndex = fieldCounter++; + int fieldCounter = 0; - emitVarDeclHead(field); + for (auto field : getMembersOfType<StructField>(structRef)) + { + int fieldIndex = fieldCounter++; + + // Skip fields that have `void` type, since these represent + // declarations that got legalized out of existence. + if(GetType(field)->Equals(getSession()->getVoidType())) + continue; - RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; - SLANG_RELEASE_ASSERT(fieldLayout->varDecl.GetName() == field.GetName()); + emitVarDeclHead(field); - // Emit explicit layout annotations for every field - emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + SLANG_RELEASE_ASSERT(fieldLayout->varDecl.GetName() == field.GetName()); - emitVarDeclInit(field); + // Emit explicit layout annotations for every field + emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); - Emit(";\n"); + emitVarDeclInit(field); + + Emit(";\n"); + } + } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); } } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); + } Emit("}\n"); } @@ -3773,10 +3784,6 @@ struct EmitVisitor // The data type that describes where stuff in the constant buffer should go RefPtr<Type> dataType = parameterGroupType->elementType; - // We expect/require the data type to be a user-defined `struct` type - auto declRefType = dataType->As<DeclRefType>(); - SLANG_RELEASE_ASSERT(declRefType); - // We expect the layout, if present, to be for a structured type... RefPtr<StructTypeLayout> structTypeLayout; if (layout) @@ -3827,26 +3834,54 @@ struct EmitVisitor } Emit("\n{\n"); - if (auto structRef = declRefType->declRef.As<StructDecl>()) + + // We expect the data type to be a user-defined `struct` type, + // but it might also be a "filtered" type that represents the + // case where only some fields of the original type are valid + // to appear inside of a `struct`. + if (auto declRefType = dataType->As<DeclRefType>()) { - for (auto field : getMembersOfType<StructField>(structRef)) + + if (auto structRef = declRefType->declRef.As<StructDecl>()) { - if (structTypeLayout) + int fieldCounter = 0; + for (auto field : getMembersOfType<StructField>(structRef)) { - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - // assert(fieldLayout); + int fieldIndex = fieldCounter++; + + // Skip fields that have `void` type, since these represent + // declarations that got legalized out of existence. + if(GetType(field)->Equals(getSession()->getVoidType())) + continue; + + if (structTypeLayout) + { + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + // assert(fieldLayout); + + // TODO(tfoley): We may want to emit *some* of these, + // some of the time... + // emitGLSLLayoutQualifiers(fieldLayout); + } - // TODO(tfoley): We may want to emit *some* of these, - // some of the time... - // emitGLSLLayoutQualifiers(fieldLayout); - } - EmitVarDeclCommon(field); + EmitVarDeclCommon(field); - Emit(";\n"); + Emit(";\n"); + } + } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); } } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); + } + + + Emit("}"); if( varDecl->getNameLoc().isValid() ) @@ -3890,13 +3925,10 @@ struct EmitVisitor } } - // Skip fields that have been tuple-ified and don't contribute - // any fields of "ordinary" type. - if (auto tupleFieldMod = decl->FindModifier<TupleFieldModifier>()) - { - if (!tupleFieldMod->hasAnyNonTupleFields) - return; - } + // Skip fields that have `void` type, since these may be introduced + // as part of type leglaization. + if(decl->getType()->Equals(getSession()->getVoidType())) + return; RefPtr<VarLayout> layout = arg.layout; layout = maybeFetchLayout(decl, layout); @@ -4129,16 +4161,6 @@ emitDeclImpl(decl, nullptr); String getIRName(Decl* decl) { - // It is a bit ugly, but we need a deterministic way - // to get a name for things when emitting from the IR - // that won't conflict with any keywords, builtins, etc. - // in the target language. - // - // Eventually we should accomplish this by using - // mangled names everywhere, but that complicates things - // when we are also using direct comparison to fxc/glslang - // output for some of our tests. - // // TODO: need a flag to get rid of the step that adds // a prefix here, so that we can get "clean" output // when needed. @@ -4155,7 +4177,27 @@ emitDeclImpl(decl, nullptr); String getIRName(DeclRefBase const& declRef) { - return getIRName(declRef.decl); + // It is a bit ugly, but we need a deterministic way + // to get a name for things when emitting from the IR + // that won't conflict with any keywords, builtins, etc. + // in the target language. + // + // Eventually we should accomplish this by using + // mangled names everywhere, but that complicates things + // when we are also using direct comparison to fxc/glslang + // output for some of our tests. + // + + String name; + if (context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING) + { + name.append(getText(declRef.GetName())); + } + else + { + name.append(getMangledName(declRef)); + } + return name; } String getGLSLSystemValueName( @@ -6279,9 +6321,12 @@ emitDeclImpl(decl, nullptr); auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldType = GetType(ff); + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRVarModifiers(ctx, fieldLayout); - auto fieldType = GetType(ff); emitIRType(ctx, fieldType, getIRName(ff)); emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); @@ -6290,26 +6335,6 @@ emitDeclImpl(decl, nullptr); } } } - else if (auto filteredTupleType = elementType->As<FilteredTupleType>()) - { - auto structTypeLayout = typeLayout.As<StructTypeLayout>(); - assert(structTypeLayout); - - for (auto ee : filteredTupleType->elements) - { - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(ee.fieldDeclRef, fieldLayout); - - emitIRVarModifiers(ctx, fieldLayout); - - auto fieldType = ee.type; - emitIRType(ctx, fieldType, getIRName(ee.fieldDeclRef)); - - emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); - - emit(";\n"); - } - } else { emit("/* unexpected */"); @@ -6376,9 +6401,12 @@ emitDeclImpl(decl, nullptr); auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldType = GetType(ff); + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRVarModifiers(ctx, fieldLayout); - auto fieldType = GetType(ff); emitIRType(ctx, fieldType, getIRName(ff)); // emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); @@ -6589,7 +6617,7 @@ emitDeclImpl(decl, nullptr); } Emit("struct "); - emit(declRef.GetName()); + EmitDeclRef(declRef); Emit("\n{\n"); for( auto ff : GetFields(declRef) ) { @@ -6597,6 +6625,11 @@ emitDeclImpl(decl, nullptr); continue; auto fieldType = GetType(ff); + + // Skip `void` fields that might have been created by legalization. + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRType(ctx, fieldType, getIRName(ff)); EmitSemantics(ff.getDecl()); @@ -6655,43 +6688,6 @@ emitDeclImpl(decl, nullptr); ensureStructDecl(ctx, structDeclRef); } } - else if (auto filteredTupleType = type->As<FilteredTupleType>()) - { - // First, ensure that the element types are ready: - for (auto ee : filteredTupleType->elements) - { - if (ee.type) - { - emitIRUsedType(ctx, ee.type); - } - } - - // Now, we want to ensure we've emitted a - // matching `struct` type declaration. - - String mangledName = getMangledTypeName(filteredTupleType); - if (!ctx->shared->irTupleTypes.Contains(mangledName)) - { - ctx->shared->irTupleTypes.Add(mangledName); - - // Emit the damn `struct` decl... - - Emit("struct "); - emit(mangledName); - Emit("\n{\n"); - for( auto ee : filteredTupleType->elements ) - { - if (!ee.type) - continue; - - emitIRType(ctx, ee.type, getIRName(ee.fieldDeclRef)); - - emit(";\n"); - } - Emit("};\n"); - - } - } else {} } @@ -6846,7 +6842,8 @@ StructTypeLayout* getGlobalStructLayout( } void legalizeTypes( - IRModule* module); + TypeLegalizationContext* context, + IRModule* module); String emitEntryPoint( EntryPointRequest* entryPoint, @@ -6937,6 +6934,9 @@ String emitEntryPoint( // Next we will check for case (2a): else if (!(translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)) { + TypeLegalizationContext typeLegalizationContext; + typeLegalizationContext.session = entryPoint->compileRequest->mSession; + // This case means the user has opted out of using the IR (so we can't use the // cases below), but they either turned on semantic checking *or* imported some // Slang code, so they can't use the case above. @@ -6958,7 +6958,8 @@ String emitEntryPoint( programLayout, target, &sharedContext.extensionUsageTracker, - nullptr); + nullptr, + &typeLegalizationContext); sharedContext.program = lowered.program; // Note that we emit the main body code of the program *before* @@ -6976,6 +6977,9 @@ String emitEntryPoint( // are certain steps that need to be shared. else { + TypeLegalizationContext typeLegalizationContext; + typeLegalizationContext.session = entryPoint->compileRequest->mSession; + // We are going to create a fresh IR module that we will use to // clone any code needed by the user's entry point. IRSpecializationState* irSpecializationState = createIRSpecializationState( @@ -6985,6 +6989,8 @@ String emitEntryPoint( targetRequest); IRModule* irModule = getIRModule(irSpecializationState); + typeLegalizationContext.irModule = irModule; + LoweredEntryPoint lowered; if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { @@ -7002,7 +7008,8 @@ String emitEntryPoint( programLayout, target, &sharedContext.extensionUsageTracker, - irSpecializationState); + irSpecializationState, + &typeLegalizationContext); } else { @@ -7044,7 +7051,9 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes(irModule); + legalizeTypes( + &typeLegalizationContext, + irModule); // Debugging output of legalization #if 0 @@ -7053,6 +7062,8 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + sharedContext.irDeclSetForAST = &lowered.irDecls; + // After all of the required optimization and legalization // passes have been performed, we can emit target code from // the IR module. @@ -7065,7 +7076,6 @@ String emitEntryPoint( // that we need to output, we'll do it now. if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { - sharedContext.irDeclSetForAST = &lowered.irDecls; visitor.EmitDeclsInContainer(lowered.program); } diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 7fad2ccc5..6e909106d 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -1,254 +1,22 @@ // ir-legalize-types.cpp -// This file implements a pass that takes IR -// that has been fully specialized (no more -// generics/interfaces needing to be specialized -// away) and replaces any types that can't actually -// be used as-is on the target. +// This file implements type legalization for the IR. +// It uses the core legalization logic in +// `legalize-types.{h,cpp}` to decide what to do with +// the types, while this file handles the actual +// rewriting of the IR to use the new types. // -// The particular case we are focused on is -// aggregate types (e.g., `struct` types) that -// contain resources (textures, samplers, etc.) -// or that mix resources and ordinary "uniform" -// data. +// This pass should only be applied to IR that has been +// fully specialized (no more generics/interfaces), so +// that the concrete type of everything is known. #include "ir.h" #include "ir-insts.h" +#include "legalize-types.h" namespace Slang { -struct LegalTypeImpl : RefObject -{ -}; -struct ImplicitDerefType; -struct TuplePseudoType; -struct PairPseudoType; -struct PairInfo; - -struct LegalType -{ - enum class Flavor - { - // Nothing: a NULL type - none, - - // A simple type that can be represented directly as a `Type` - simple, - - // Logically, we have a pointer-like type, but we are - // going to represnet it as the pointed-to type - implicitDeref, - - // A compound type was broken apart into its constituent fields, - // so a tuple "pseduo-type" is being used to collect - // those fields together. - tuple, - - // A type has to get split into "ordinary" and "special" parts, - // each of which will be represented with its own `LegalType`. - pair, - }; - - Flavor flavor = Flavor::none; - RefPtr<RefObject> obj; - - static LegalType simple(Type* type) - { - LegalType result; - result.flavor = Flavor::simple; - result.obj = type; - return result; - } - - RefPtr<Type> getSimple() - { - assert(flavor == Flavor::simple); - return obj.As<Type>(); - } - - static LegalType implicitDeref( - LegalType const& valueType); - - RefPtr<ImplicitDerefType> getImplicitDeref() - { - assert(flavor == Flavor::implicitDeref); - return obj.As<ImplicitDerefType>(); - } - - static LegalType tuple( - RefPtr<TuplePseudoType> tupleType); - - RefPtr<TuplePseudoType> getTuple() - { - assert(flavor == Flavor::tuple); - return obj.As<TuplePseudoType>(); - } - - static LegalType pair( - RefPtr<PairPseudoType> pairType); - - static LegalType pair( - RefPtr<Type> ordinaryType, - LegalType const& specialType, - RefPtr<PairInfo> pairInfo); - - RefPtr<PairPseudoType> getPair() - { - assert(flavor == Flavor::pair); - return obj.As<PairPseudoType>(); - } -}; - -struct ImplicitDerefType : LegalTypeImpl -{ - LegalType valueType; -}; - -LegalType LegalType::implicitDeref( - LegalType const& valueType) -{ - RefPtr<ImplicitDerefType> obj = new ImplicitDerefType(); - obj->valueType = valueType; - - LegalType result; - result.flavor = Flavor::implicitDeref; - result.obj = obj; - return result; -} - -// Represents the pseudo-type for a compound type -// that had to be broken apart because it contained -// one or more fields of types that shouldn't be -// allowed in aggregates. -// -// A tuple pseduo-type will have an element for -// each field of the original type, that represents -// the legalization of that field's type. -// -// It optionally also contains an "ordinary" type -// that packs together any per-field data that -// itself has (or contains) an ordinary type. -struct TuplePseudoType : LegalTypeImpl -{ - // Represents one element of the tuple pseudo-type - struct Element - { - // The field that this element replaces - DeclRef<VarDeclBase> fieldDeclRef; - - // The legalized type of the element - LegalType type; - }; - - // All of the elements of the tuple pseduo-type. - List<Element> elements; -}; - -LegalType LegalType::tuple( - RefPtr<TuplePseudoType> tupleType) -{ - LegalType result; - result.flavor = Flavor::tuple; - result.obj = tupleType; - return result; -} - -struct PairInfo : RefObject -{ - typedef unsigned int Flags; - enum - { - kFlag_hasOrdinary = 0x1, - kFlag_hasSpecial = 0x2, - }; - - struct Element - { - // The field the element represents - DeclRef<Decl> fieldDeclRef; - - // The conceptual type of the field. - // If both the `hasOrdinary` and - // `hasSpecial` bits are set, then - // this is expected to be a - // `LegalType::Flavor::pair` - LegalType type; - - // Is the value represented on - // the ordinary side, the special - // side, or both? - Flags flags; - }; - - // For a pair type or value, we need to track - // which fields are on which side(s). - List<Element> elements; - - Element* findElement(DeclRef<Decl> const& fieldDeclRef) - { - for (auto& ee : elements) - { - if(ee.fieldDeclRef.Equals(fieldDeclRef)) - return ⅇ - } - return nullptr; - } -}; - -struct PairPseudoType : LegalTypeImpl -{ - // Any field(s) with ordinary types will - // get captured here, as a completely - // standard AST-level type. - RefPtr<Type> ordinaryType; - - // Any fields with "special" (not ordinary) - // types will get captured here (usually - // with a tuple). - LegalType specialType; - - RefPtr<PairInfo> pairInfo; -}; - -LegalType LegalType::pair( - RefPtr<PairPseudoType> pairType) -{ - LegalType result; - result.flavor = Flavor::pair; - result.obj = pairType; - return result; -} - -LegalType LegalType::pair( - RefPtr<Type> ordinaryType, - LegalType const& specialType, - RefPtr<PairInfo> pairInfo) -{ - // Handle some special cases for when - // one or the other of the types isn't - // actually used. - - if (!ordinaryType) - { - // There was nothing ordinary. - return specialType; - } - - if (specialType.flavor == LegalType::Flavor::none) - { - return LegalType::simple(ordinaryType); - } - - // There were both ordinary and special fields, - // and so we need to handle them here. - - RefPtr<PairPseudoType> obj = new PairPseudoType(); - obj->ordinaryType = ordinaryType; - obj->specialType = specialType; - obj->pairInfo = pairInfo; - return LegalType::pair(obj); -} struct LegalValImpl : RefObject @@ -390,12 +158,15 @@ LegalVal LegalVal::getImplicitDeref() } -struct TypeLegalizationContext +struct IRTypeLegalizationContext { Session* session; IRModule* module; IRBuilder* builder; + /// Context to use for underlying (non-IR) type legalization. + TypeLegalizationContext* typeLegalizationContext; + // When inserting new globals, put them before this one. IRGlobalValue* insertBeforeGlobal = nullptr; @@ -411,633 +182,31 @@ struct TypeLegalizationContext }; static void registerLegalizedValue( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRValue* irValue, LegalVal const& legalVal) { context->mapValToLegalVal.Add(irValue, legalVal); } - -static bool isResourceType(Type* type) -{ - while (auto arrayType = type->As<ArrayExpressionType>()) - { - type = arrayType->baseType; - } - - if (auto textureTypeBase = type->As<TextureTypeBase>()) - { - return true; - } - else if (auto samplerType = type->As<SamplerStateType>()) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; -} - -static LegalType legalizeType( - TypeLegalizationContext* context, - Type* type); - -// Helper type for legalization of aggregate types -// that might need to be turned into tuple pseudo-types. -struct TupleTypeBuilder -{ - TypeLegalizationContext* context; - RefPtr<Type> type; - - List<FilteredTupleType::Element> ordinaryElements; - List<TuplePseudoType::Element> specialElements; - - List<PairInfo::Element> pairElements; - - // Did we have any fields that forced us to change - // the actual type away from the declared type? - bool anyComplex = false; - - // Did we have any fields that actually required - // storage in the "special" part of things? - bool anySpecial = false; - - // Did we have any fields that actually used ordinary storage? - bool anyOrdinary = false; - - // Add a field to the (pseudo-)type we are building - void addField( - DeclRef<VarDeclBase> fieldDeclRef, - LegalType legalFieldType, - LegalType legalLeafType, - bool isResource) - { - RefPtr<Type> ordinaryType; - LegalType specialType; - RefPtr<PairInfo> elementPairInfo; - switch (legalLeafType.flavor) - { - case LegalType::Flavor::simple: - { - // We need to add an actual field, but we need - // to check if it is a resource type to know - // whether it should go in the "ordinary" list or not. - if (!isResource) - { - ordinaryType = legalLeafType.getSimple(); - } - else - { - specialType = legalFieldType; - } - } - break; - - case LegalType::Flavor::implicitDeref: - { - // TODO: we may want to say that any use - // of `implicitDeref` puts the entire thing - // into the "special" category, rather than - // try to look under the hood... - - anyComplex = true; - - // We want to recursively add data - // based on the unwrapped type. - // - // Note: this assumes we can't have a tuple - // or a pair "under" an `implicitDeref`, so - // we'll need to ensure that elsewhere. - addField( - fieldDeclRef, - legalFieldType, - legalLeafType.getImplicitDeref()->valueType, - isResource); - return; - } - break; - - case LegalType::Flavor::pair: - { - // The field's type had both special and non-special parts - auto pairType = legalLeafType.getPair(); - ordinaryType = pairType->ordinaryType; - specialType = pairType->specialType; - elementPairInfo = pairType->pairInfo; - } - break; - - case LegalType::Flavor::tuple: - { - // A tuple always represents "special" data - specialType = legalFieldType; - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - break; - } - - - PairInfo::Element pairElement; - pairElement.flags = 0; - pairElement.fieldDeclRef = fieldDeclRef; - - if (ordinaryType) - { - anyOrdinary = true; - pairElement.flags |= PairInfo::kFlag_hasOrdinary; - - FilteredTupleType::Element ordinaryElement; - ordinaryElement.fieldDeclRef = fieldDeclRef; - ordinaryElement.type = ordinaryType; - ordinaryElements.Add(ordinaryElement); - } - - if (specialType.flavor != LegalType::Flavor::none) - { - anySpecial = true; - anyComplex = true; - pairElement.flags |= PairInfo::kFlag_hasSpecial; - - TuplePseudoType::Element specialElement; - specialElement.fieldDeclRef = fieldDeclRef; - specialElement.type = specialType; - specialElements.Add(specialElement); - } - - pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); - pairElements.Add(pairElement); - } - - // Add a field to the (pseudo-)type we are building - void addField( - DeclRef<VarDeclBase> fieldDeclRef) - { - // Skip `static` fields. - if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) - return; - - auto fieldType = GetType(fieldDeclRef); - - bool isResourceField = isResourceType(fieldType); - - auto legalFieldType = legalizeType(context, fieldType); - addField( - fieldDeclRef, - legalFieldType, - legalFieldType, - isResourceField); - } - - LegalType getResult() - { - // If we didn't see anything "special" - // then we can use the type as-is. - // we can conceivably just use the type as-is - // - // TODO: this might be a good place to turn - // a reference to a generic `struct` type into - // a concrete non-generic type so that downstream - // codegen doesn't have to deal with generics... - // - // TODO: In fact, why not just fully replace - // all aggregate types here with some structural - // types defined in the IR? - if (!anyComplex) - { - return LegalType::simple(type); - } - - // If there were any "ordinary" fields along the way, - // then we need to collect them into a type to - // represent the ordinary part of things. - // - RefPtr<Type> ordinaryType; - if (anyOrdinary) - { - RefPtr<FilteredTupleType> ordinaryTypeImpl = new FilteredTupleType(); - ordinaryTypeImpl->setSession(context->session); - ordinaryTypeImpl->originalType = type; - ordinaryTypeImpl->elements = ordinaryElements; - ordinaryType = ordinaryTypeImpl; - } - - LegalType specialType; - if (anySpecial) - { - RefPtr<TuplePseudoType> specialTuple = new TuplePseudoType(); - specialTuple->elements = specialElements; - specialType = LegalType::tuple(specialTuple); - } - - RefPtr<PairInfo> pairInfo; - if (anyOrdinary && anySpecial) - { - pairInfo = new PairInfo(); - pairInfo->elements = pairElements; - } - - return LegalType::pair(ordinaryType, specialType, pairInfo); - } - -}; - -static RefPtr<Type> createBuiltinGenericType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - RefPtr<Type> elementType) -{ - // We are going to take the type for the original - // decl-ref and construct a new one that uses - // our new element type as its parameter. - // - // TODO: we should have library code to make - // manipulations like this way easier. - - RefPtr<GenericSubstitution> oldGenericSubst = getGenericSubstitution( - typeDeclRef.substitutions); - SLANG_ASSERT(oldGenericSubst); - - RefPtr<GenericSubstitution> newGenericSubst = new GenericSubstitution(); - - newGenericSubst->outer = oldGenericSubst->outer; - newGenericSubst->genericDecl = oldGenericSubst->genericDecl; - newGenericSubst->args = oldGenericSubst->args; - newGenericSubst->args[0] = elementType; - - auto newDeclRef = DeclRef<Decl>( - typeDeclRef.getDecl(), - newGenericSubst); - - auto newType = DeclRefType::Create( - context->session, - newDeclRef); - - return newType; -} - -// Create a uniform buffer type with a given legalized -// element type. -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - LegalType legalElementType) -{ - switch (legalElementType.flavor) - { - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - return LegalType::simple(createBuiltinGenericType( - context, - typeDeclRef, - legalElementType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // This is actually an annoying case, because - // we are being asked to convert, e.g.,: - // - // cbuffer Foo { ParameterBlock<Bar> bar; } - // - // into the equivalent of: - // - // cbuffer Foo { Bar bar; } - // - // Which would really require a new `LegalType` that - // would reprerent a resource type with a modified - // element type. - // - // I'm going to attempt to hack this for now. - return LegalType::implicitDeref(createLegalUniformBufferType( - context, - typeDeclRef, - legalElementType.getImplicitDeref()->valueType)); - } - break; - - case LegalType::Flavor::pair: - { - // We assume that the "ordinary" part of things - // will get wrapped in a constant-buffer type, - // and the "special" part needs to be wrapped - // with an `implicitDeref`. - auto pairType = legalElementType.getPair(); - - auto ordinaryType = createBuiltinGenericType( - context, - typeDeclRef, - pairType->ordinaryType); - auto specialType = LegalType::implicitDeref(pairType->specialType); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // if we have a tuple type, then it must be representing - // the fields that can't be stored in a buffer anyway, - // so we just need to wrap each of them in an `implicitDeref` - - auto elementPseudoTupleType = legalElementType.getTuple(); - - RefPtr<TuplePseudoType> bufferPseudoTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : elementPseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.fieldDeclRef = ee.fieldDeclRef; - newElement.type = LegalType::implicitDeref(ee.type); - - bufferPseudoTupleType->elements.Add(newElement); - } - - return LegalType::tuple(bufferPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - UniformParameterGroupType* uniformBufferType, - LegalType legalElementType) -{ - return createLegalUniformBufferType( - context, - uniformBufferType->declRef, - legalElementType); -} - -// Create a pointer type with a given legalized value type. -static LegalType createLegalPtrType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - LegalType legalValueType) -{ - switch (legalValueType.flavor) - { - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - return LegalType::simple(createBuiltinGenericType( - context, - typeDeclRef, - legalValueType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // We are being asked to create a pointer type to something - // that is implicitly dereferenced, meaning we had: - // - // Ptr(PtrLink(T)) - // - // and now are being asked to make: - // - // Ptr(implicitDeref(LegalT)) - // - // So it seems like we can just create: - // - // implicitDeref(Ptr(LegalT)) - // - // and nobody should really be able to tell the difference, right? - return LegalType::implicitDeref(createLegalPtrType( - context, - typeDeclRef, - legalValueType.getImplicitDeref()->valueType)); - } - break; - - case LegalType::Flavor::pair: - { - // We just need to pointer-ify both sides of the pair. - auto pairType = legalValueType.getPair(); - - auto ordinaryType = createBuiltinGenericType( - context, - typeDeclRef, - pairType->ordinaryType); - auto specialType = createLegalPtrType( - context, - typeDeclRef, - pairType->specialType); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // Wrap each of the tuple elements up as a pointer. - auto valuePseudoTupleType = legalValueType.getTuple(); - - RefPtr<TuplePseudoType> ptrPseudoTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : valuePseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.fieldDeclRef = ee.fieldDeclRef; - newElement.type = createLegalPtrType( - context, - typeDeclRef, - ee.type); - - ptrPseudoTupleType->elements.Add(newElement); - } - - return LegalType::tuple(ptrPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -// Legalize a type, including any nested types -// that it transitively contains. -static LegalType legalizeType( - TypeLegalizationContext* context, - Type* type) -{ - if (auto parameterBlockType = type->As<ParameterBlockType>()) - { - // We basically legalize the `ParameterBlock<T>` type - // over to `T`. In order to represent this preoperly, - // we need to be careful to wrap it up in a way that - // tells us to eliminate downstream deferences... - - auto legalElementType = legalizeType(context, - parameterBlockType->getElementType()); - return LegalType::implicitDeref(legalElementType); - } - else if (auto uniformBufferType = type->As<UniformParameterGroupType>()) - { - // We have a `ConstantBuffer<T>` or `TextureBuffer<T>` or - // other pointer-like type that represents uniform parameters. - // We need to pull any resource-type fields out of it, but - // leave the non-resource fields where they are. - - // Legalize the element type to see what we are working with. - auto legalElementType = legalizeType(context, - uniformBufferType->getElementType()); - - switch (legalElementType.flavor) - { - case LegalType::Flavor::simple: - return LegalType::simple(type); - - default: - return createLegalUniformBufferType( - context, - uniformBufferType, - legalElementType); - } - - } - else if (isResourceType(type)) - { - // We assume that any resource types not handled above - // are legal as-is. - return LegalType::simple(type); - } - else if (type->As<BasicExpressionType>()) - { - return LegalType::simple(type); - } - else if (type->As<VectorExpressionType>()) - { - return LegalType::simple(type); - } - else if (type->As<MatrixExpressionType>()) - { - return LegalType::simple(type); - } - else if (auto ptrType = type->As<PtrTypeBase>()) - { - auto legalValueType = legalizeType(context, ptrType->getValueType()); - return createLegalPtrType(context, ptrType->declRef, legalValueType); - } - else if (auto declRefType = type->As<DeclRefType>()) - { - auto declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) - { - // Look at the (non-static) fields, and - // see if anything needs to be cleaned up. - // The things that need to be "cleaned up" for - // our purposes are: - // - // - Fields of resource type, or any other future - // type we run into that isn't allowed in - // aggregates for at least some targets - // - // - Fields with types that themselves had to - // get legalized. - // - // If we don't run into any of these, we - // can just use the type as-is. Hooray! - // - // Otherwise, we are effectively going to split - // the type apart and create a `TuplePseudoType`. - // Every field of the original type will be - // represented as an element of this pseudo-type. - // Each element will record its `LegalType`, - // and the original field that it was created from. - // An element will also track whether it contains - // any "ordinary" data, and if so, it will remember - // an element index in a real (AST-level, non-pseudo) - // `TupleType` that is used to bundle together - // such fields. - // - // Storing all the simple fields together like this - // obviously adds complexity to the legalization - // pass, but it has important benefits: - // - // - It avoids creating functions with a very large - // number of parameters (when passing a structure - // with many fields), which might confuse downstream - // compilers. - // - // - It avoids applying AOS->SOA conversion to fields - // that don't actually need it, which is basically - // required if we want type layout to work. - // - // - It ensures that we can actually construct a - // constant-buffer type that wraps a legalized - // aggregate type; the ordinary fields will get - // placed inside a new constant-buffer type, - // while the special ones will get left outside. - // - - TupleTypeBuilder builder; - builder.context = context; - builder.type = type; - - - for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) - { - builder.addField(ff); - } - - return builder.getResult(); - } - - // TODO: for other declaration-reference types, we really - // need to legalize the types used in substitutions, and - // signal an error if any of them turn out to be non-simple. - // - // The limited cases of types that can handle having non-simple - // types as generic arguments all need to be special-cased here. - // (For example, we can't handle `Texture2D<SomeStructWithTexturesInIt>`. - // - } - - return LegalType::simple(type); -} - -// Represents the "chain" of declarations that -// were followed to get to a variable that we -// are now declaring as a leaf variable. -struct LegalVarChain -{ - LegalVarChain* next; - VarLayout* varLayout; -}; - static LegalVal declareVars( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain); +static LegalType legalizeType( + IRTypeLegalizationContext* context, + Type* type) +{ + return legalizeType(context->typeLegalizationContext, type); +} + // Legalize a type, and then expect it to // result in a simple type. static RefPtr<Type> legalizeSimpleType( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, Type* type) { auto legalType = legalizeType(context, type); @@ -1056,7 +225,7 @@ static RefPtr<Type> legalizeSimpleType( // Take a value that is being used as an operand, // and turn it into the equivalent legalized value. static LegalVal legalizeOperand( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRValue* irValue) { LegalVal legalVal; @@ -1111,7 +280,7 @@ static void getArgumentValues( } static LegalVal legalizeCall( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRCall* callInst) { // TODO: implement legalization of non-simple return types @@ -1130,7 +299,7 @@ static LegalVal legalizeCall( } static LegalVal legalizeLoad( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalVal legalPtrVal) { switch (legalPtrVal.flavor) @@ -1186,7 +355,7 @@ static LegalVal legalizeLoad( } static LegalVal legalizeStore( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalVal legalPtrVal, LegalVal legalVal) { @@ -1241,19 +410,13 @@ static LegalVal legalizeStore( } static LegalVal legalizeFieldAddress( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, - LegalVal legalFieldOperand) + DeclRef<Decl> fieldDeclRef) { auto builder = context->builder; - // We don't expect any legalization to affect - // the "field" argument. - auto fieldOperand = legalFieldOperand.getSimple(); - assert(fieldOperand->op == kIROp_decl_ref); - auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; - switch (legalPtrOperand.flavor) { case LegalVal::Flavor::simple: @@ -1261,7 +424,7 @@ static LegalVal legalizeFieldAddress( builder->emitFieldAddress( type.getSimple(), legalPtrOperand.getSimple(), - fieldOperand)); + builder->getDeclRefVal(fieldDeclRef))); case LegalVal::Flavor::pair: { @@ -1286,7 +449,7 @@ static LegalVal legalizeFieldAddress( { auto fieldPairType = type.getPair(); fieldPairInfo = fieldPairType->pairInfo; - ordinaryType = LegalType::simple(fieldPairType->ordinaryType); + ordinaryType = fieldPairType->ordinaryType; specialType = fieldPairType->specialType; } @@ -1295,12 +458,27 @@ static LegalVal legalizeFieldAddress( if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { - ordinaryVal = legalizeFieldAddress(context, ordinaryType, pairVal->ordinaryVal, legalFieldOperand); + // Note: the ordinary side of the pair is expected + // to be a filtered `struct` type, and so it will + // have different field declarations than the + // oridinal type. The element of the `PairInfo` + // structure stores the correct field decl-ref to use + // as `ordinaryFieldDeclRef`. + + ordinaryVal = legalizeFieldAddress( + context, + ordinaryType, + pairVal->ordinaryVal, + pairElement->ordinaryFieldDeclRef); } if (pairElement->flags & PairInfo::kFlag_hasSpecial) { - specialVal = legalizeFieldAddress(context, specialType, pairVal->specialVal, legalFieldOperand); + specialVal = legalizeFieldAddress( + context, + specialType, + pairVal->specialVal, + fieldDeclRef); } return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); } @@ -1335,8 +513,27 @@ static LegalVal legalizeFieldAddress( } } +static LegalVal legalizeFieldAddress( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalFieldOperand) +{ + // We don't expect any legalization to affect + // the "field" argument. + auto fieldOperand = legalFieldOperand.getSimple(); + assert(fieldOperand->op == kIROp_decl_ref); + auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; + + return legalizeFieldAddress( + context, + type, + legalPtrOperand, + fieldDeclRef); +} + static LegalVal legalizeInst( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRInst* inst, LegalType type, LegalVal const* args) @@ -1370,7 +567,7 @@ RefPtr<VarLayout> findVarLayout(IRValue* value) } static LegalVal legalizeLocalVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRVar* irLocalVar) { // Legalize the type for the variable's value @@ -1424,7 +621,7 @@ static LegalVal legalizeLocalVar( } static LegalVal legalizeInst( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRInst* inst) { if (inst->op == kIROp_Var) @@ -1512,7 +709,7 @@ static void addParamType(IRFuncType * ftype, LegalType t) case LegalType::Flavor::pair: { auto pairInfo = t.getPair(); - addParamType(ftype, LegalType::simple(pairInfo->ordinaryType)); + addParamType(ftype, pairInfo->ordinaryType); addParamType(ftype, pairInfo->specialType); } break; @@ -1529,7 +726,7 @@ static void addParamType(IRFuncType * ftype, LegalType t) } static void legalizeFunc( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRFunc* irFunc) { // Overwrite the function's type with @@ -1606,61 +803,13 @@ static void legalizeFunc( } static LegalVal declareSimpleVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, Type* type, TypeLayout* typeLayout, LegalVarChain* varChain) { - RefPtr<VarLayout> varLayout; - if (typeLayout) - { - // We need to construct a layout for the new variable - // that reflects both the type we have given it, as - // well as all the offset information that has accumulated - // along the chain of parent variables. - - // TODO: this logic needs to propagate through semantics... - - varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - - for (auto rr : typeLayout->resourceInfos) - { - auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); - - for (auto vv = varChain; vv; vv = vv->next) - { - if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) - { - resInfo->index += parentResInfo->index; - resInfo->space += parentResInfo->space; - } - } - } - - // Some of the parent variables might actually contain offsets - // to the `space` or `set` of the field, and we need to apply - // those to all the nested resource infos. - for (auto vv = varChain; vv; vv = vv->next) - { - auto parentSpaceInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace); - if (!parentSpaceInfo) - continue; - - for (auto& rr : varLayout->resourceInfos) - { - if (rr.kind == LayoutResourceKind::RegisterSpace) - { - rr.index += parentSpaceInfo->index; - } - else - { - rr.space += parentSpaceInfo->index; - } - } - } - } + RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); DeclRef<VarDeclBase> varDeclRef; if (varChain) @@ -1733,39 +882,8 @@ static LegalVal declareSimpleVar( return legalVarVal; } -static RefPtr<TypeLayout> getDerefTypeLayout( - TypeLayout* typeLayout) -{ - if (!typeLayout) - return nullptr; - - if (auto parameterGroupTypeLayout = dynamic_cast<ParameterGroupTypeLayout*>(typeLayout)) - { - return parameterGroupTypeLayout->elementTypeLayout; - } - - return typeLayout; -} - -static RefPtr<VarLayout> getFieldLayout( - TypeLayout* typeLayout, - DeclRef<VarDeclBase> fieldDeclRef) -{ - if (!typeLayout) - return nullptr; - - if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) - { - RefPtr<VarLayout> fieldLayout; - if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) - return fieldLayout; - } - - return nullptr; -} - static LegalVal declareVars( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, @@ -1798,7 +916,7 @@ static LegalVal declareVars( case LegalType::Flavor::pair: { auto pairType = type.getPair(); - auto ordinaryVal = declareVars(context, op, LegalType::simple(pairType->ordinaryType), typeLayout, varChain); + auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain); auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain); return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); } @@ -1852,7 +970,7 @@ static LegalVal declareVars( } static void legalizeGlobalVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRGlobalVar* irGlobalVar) { // Legalize the type for the variable's value @@ -1899,7 +1017,7 @@ static void legalizeGlobalVar( } static void legalizeGlobalValue( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRGlobalValue* irValue) { switch (irValue->op) @@ -1923,7 +1041,7 @@ static void legalizeGlobalValue( } static void legalizeTypes( - TypeLegalizationContext* context) + IRTypeLegalizationContext* context) { auto module = context->module; for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) @@ -1934,7 +1052,8 @@ static void legalizeTypes( void legalizeTypes( - IRModule* module) + TypeLegalizationContext* typeLegalizationContext, + IRModule* module) { auto session = module->session; @@ -1950,13 +1069,15 @@ void legalizeTypes( builder->sharedBuilder = sharedBuilder; - TypeLegalizationContext contextStorage; + IRTypeLegalizationContext contextStorage; auto context = &contextStorage; context->session = session; context->module = module; context->builder = builder; + context->typeLegalizationContext = typeLegalizationContext; + legalizeTypes(context); } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 3a8aabd85..d35ffe2d1 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3764,6 +3764,14 @@ namespace Slang String const& mangledName, IRGlobalValue* originalVal) { + // Check if we've already cloned this value, for the case where + // an original value has already been established. + IRValue* clonedVal = nullptr; + if( originalVal && context->getClonedValues().TryGetValue(originalVal, clonedVal) ) + { + return (IRGlobalValue*) clonedVal; + } + if(mangledName.Length() == 0) { // If there is no mangled name, then we assume this is a local symbol, @@ -3779,6 +3787,9 @@ namespace Slang RefPtr<IRSpecSymbol> sym; if( !context->getSymbols().TryGetValue(mangledName, sym) ) { + if(!originalVal) + return nullptr; + // This shouldn't happen! SLANG_UNEXPECTED("no matching values registered"); UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); @@ -3798,6 +3809,14 @@ namespace Slang bestVal = newVal; } + // Check if we've already cloned this value, for the case where + // we didn't have an original value (just a name), but we've + // now found a representative value. + if( !originalVal && context->getClonedValues().TryGetValue(bestVal, clonedVal) ) + { + return (IRGlobalValue*) clonedVal; + } + return cloneGlobalValueImpl(context, bestVal, sym); } diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp new file mode 100644 index 000000000..510b9acd3 --- /dev/null +++ b/source/slang/legalize-types.cpp @@ -0,0 +1,1086 @@ +// legalize-types.cpp +#include "legalize-types.h" + +namespace Slang +{ + +LegalType LegalType::implicitDeref( + LegalType const& valueType) +{ + RefPtr<ImplicitDerefType> obj = new ImplicitDerefType(); + obj->valueType = valueType; + + LegalType result; + result.flavor = Flavor::implicitDeref; + result.obj = obj; + return result; +} + +LegalType LegalType::tuple( + RefPtr<TuplePseudoType> tupleType) +{ + LegalType result; + result.flavor = Flavor::tuple; + result.obj = tupleType; + return result; +} + +LegalType LegalType::pair( + RefPtr<PairPseudoType> pairType) +{ + LegalType result; + result.flavor = Flavor::pair; + result.obj = pairType; + return result; +} + +LegalType LegalType::pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo) +{ + // Handle some special cases for when + // one or the other of the types isn't + // actually used. + + if (ordinaryType.flavor == LegalType::Flavor::none) + { + // There was nothing ordinary. + return specialType; + } + + if (specialType.flavor == LegalType::Flavor::none) + { + return ordinaryType; + } + + // There were both ordinary and special fields, + // and so we need to handle them here. + + RefPtr<PairPseudoType> obj = new PairPseudoType(); + obj->ordinaryType = ordinaryType; + obj->specialType = specialType; + obj->pairInfo = pairInfo; + return LegalType::pair(obj); +} + +// + +static bool isResourceType(Type* type) +{ + while (auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->baseType; + } + + if (auto resourceTypeBase = type->As<ResourceTypeBase>()) + { + return true; + } + else if (auto builtinGenericType = type->As<BuiltinGenericType>()) + { + return true; + } + else if (auto pointerLikeType = type->As<PointerLikeType>()) + { + return true; + } + else if (auto samplerType = type->As<SamplerStateType>()) + { + return true; + } + + // TODO: need more comprehensive coverage here + + return false; +} + +ModuleDecl* findModuleForDecl( + Decl* decl) +{ + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) + return moduleDecl; + } + return nullptr; +} + + +// Helper type for legalization of aggregate types +// that might need to be turned into tuple pseudo-types. +struct TupleTypeBuilder +{ + TypeLegalizationContext* context; + RefPtr<Type> type; + DeclRef<AggTypeDecl> typeDeclRef; + + struct OrdinaryElement + { + DeclRef<VarDeclBase> fieldDeclRef; + RefPtr<Type> type; + }; + + + List<OrdinaryElement> ordinaryElements; + List<TuplePseudoType::Element> specialElements; + + List<PairInfo::Element> pairElements; + + // Did we have any fields that forced us to change + // the actual type away from the declared type? + bool anyComplex = false; + + // Did we have any fields that actually required + // storage in the "special" part of things? + bool anySpecial = false; + + // Did we have any fields that actually used ordinary storage? + bool anyOrdinary = false; + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef, + LegalType legalFieldType, + LegalType legalLeafType, + bool isResource) + { + LegalType ordinaryType; + LegalType specialType; + RefPtr<PairInfo> elementPairInfo; + switch (legalLeafType.flavor) + { + case LegalType::Flavor::simple: + { + // We need to add an actual field, but we need + // to check if it is a resource type to know + // whether it should go in the "ordinary" list or not. + if (!isResource) + { + ordinaryType = legalLeafType; + } + else + { + specialType = legalFieldType; + } + } + break; + + case LegalType::Flavor::implicitDeref: + { + // TODO: we may want to say that any use + // of `implicitDeref` puts the entire thing + // into the "special" category, rather than + // try to look under the hood... + + anyComplex = true; + + // We want to recursively add data + // based on the unwrapped type. + // + // Note: this assumes we can't have a tuple + // or a pair "under" an `implicitDeref`, so + // we'll need to ensure that elsewhere. + addField( + fieldDeclRef, + legalFieldType, + legalLeafType.getImplicitDeref()->valueType, + isResource); + return; + } + break; + + case LegalType::Flavor::pair: + { + // The field's type had both special and non-special parts + auto pairType = legalLeafType.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + elementPairInfo = pairType->pairInfo; + } + break; + + case LegalType::Flavor::tuple: + { + // A tuple always represents "special" data + specialType = legalFieldType; + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + break; + } + + + PairInfo::Element pairElement; + pairElement.flags = 0; + pairElement.fieldDeclRef = fieldDeclRef; + pairElement.fieldPairInfo = elementPairInfo; + + // We will always add a field to the "ordinary" + // side of things, even if it has no ordinary + // data, just to keep the list of fields aligned + // with the original type. + OrdinaryElement ordinaryElement; + ordinaryElement.fieldDeclRef = fieldDeclRef; + if (ordinaryType.flavor != LegalType::Flavor::none) + { + anyOrdinary = true; + pairElement.flags |= PairInfo::kFlag_hasOrdinary; + + LegalType ot = ordinaryType; + + // TODO: any cases we should "unwrap" here? + // E.g., `implicitDeref`? + + if(ot.flavor == LegalType::Flavor::simple) + { + ordinaryElement.type = ot.getSimple(); + } + else + { + SLANG_UNEXPECTED("unexpected ordinary field type"); + } + } + ordinaryElements.Add(ordinaryElement); + + if (specialType.flavor != LegalType::Flavor::none) + { + anySpecial = true; + anyComplex = true; + pairElement.flags |= PairInfo::kFlag_hasSpecial; + + TuplePseudoType::Element specialElement; + specialElement.fieldDeclRef = fieldDeclRef; + specialElement.type = specialType; + specialElements.Add(specialElement); + } + + pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); + pairElements.Add(pairElement); + } + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef) + { + // Skip `static` fields. + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return; + + auto fieldType = GetType(fieldDeclRef); + + bool isResourceField = isResourceType(fieldType); + + auto legalFieldType = legalizeType(context, fieldType); + addField( + fieldDeclRef, + legalFieldType, + legalFieldType, + isResourceField); + } + + LegalType getResult() + { + // If we didn't see anything "special" + // then we can use the type as-is. + // we can conceivably just use the type as-is + // + // TODO: this might be a good place to turn + // a reference to a generic `struct` type into + // a concrete non-generic type so that downstream + // codegen doesn't have to deal with generics... + // + // TODO: In fact, why not just fully replace + // all aggregate types here with some structural + // types defined in the IR? + if (!anyComplex) + { + return LegalType::simple(type); + } + + // If there were any "ordinary" fields along the way, + // then we need to collect them into a new `struct` type + // that represents these fields. + // + LegalType ordinaryType; + if (anyOrdinary) + { + // We are going to create a new `struct` type declaration that clones + // the fields we care about from the original `struct` type. Note that + // these fields may have different types from what they did before, + // because the fields themselves might have been legalized. + // + // Our new declaration will have the same name as the old one, so + // downstream code is going to need to be careful not to emit declarations + // for both of them. This should be okay, though, because the original + // type was illegal (that was the whole point) and so it shouldn't be + // allowed in the output anyway. + RefPtr<StructDecl> ordinaryStructDecl = new StructDecl(); + ordinaryStructDecl->loc = typeDeclRef.getDecl()->loc; + ordinaryStructDecl->nameAndLoc = typeDeclRef.getDecl()->nameAndLoc; + + addModifier(ordinaryStructDecl, new LegalizedModifier()); + + // We will do something a bit unsavory here, by setting the logical + // parent of the new `struct` type to be the same as the orignal type + // (All of this helps ensure it gets the same mangled name). + // + ordinaryStructDecl->ParentDecl = typeDeclRef.getDecl()->ParentDecl; + + if (context->mainModuleDecl) + { + // If the declaration we are lowering belongs to the AST-based + // module being lowered (rather than translated to IR), then we + // need to add any new declaration we create to that output. + + // If we are *not* outputting an IR module as well, then + // everything needs to wind up in a single AST module. + if (!context->irModule) + { + context->outputModuleDecl->Members.Add(ordinaryStructDecl); + } + else + { + // Otherwise, check if this declaration belongs to the main + // module (which is being lowered via the AST-to-AST pass), + // and add it to the output if needed. + // + // TODO: This won't work correctly if a type from the AST + // module is used to specialize a generic in the IR module, + // since the declaration would need to precede the specialized + // func... + auto parentModule = findModuleForDecl(typeDeclRef.getDecl()); + if (parentModule && (parentModule == context->mainModuleDecl)) + { + context->outputModuleDecl->Members.Add(ordinaryStructDecl); + } + } + } + + // For memory management reasons, we need to keep a reference to + // the declaration live, no matter what. + context->createdDecls.Add(ordinaryStructDecl); + + UInt elementCounter = 0; + for(auto ee : ordinaryElements) + { + UInt elementIndex = elementCounter++; + + // We will ensure that all the original fields are represented, + // although they may have different types (due to legalization). + // For fields that have *no* ordinary data, we will give them + // a dummy `void` type and rely on downstream passes to not + // actually emit declarations for those fields. + // + // (This helps keeps things simple because both the original + // and modified type will have the same number of fields, so + // we can continue to look up field layouts by index in the + // emit logic) + RefPtr<Type> fieldType = ee.type; + if(!fieldType) + fieldType = context->session->getVoidType(); + + // TODO: shallow clone of modifiers, etc. + + RefPtr<StructField> fieldDecl = new StructField(); + fieldDecl->loc = ee.fieldDeclRef.getDecl()->loc; + fieldDecl->nameAndLoc = ee.fieldDeclRef.getDecl()->nameAndLoc; + fieldDecl->type.type = fieldType; + + fieldDecl->ParentDecl = ordinaryStructDecl; + ordinaryStructDecl->Members.Add(fieldDecl); + + pairElements[elementIndex].ordinaryFieldDeclRef = makeDeclRef(fieldDecl.Ptr()); + + addModifier(fieldDecl, new LegalizedModifier()); + } + + RefPtr<Type> ordinaryStructType = DeclRefType::Create( + context->session, + makeDeclRef(ordinaryStructDecl.Ptr())); + + ordinaryType = LegalType::simple(ordinaryStructType); + } + + LegalType specialType; + if (anySpecial) + { + RefPtr<TuplePseudoType> specialTuple = new TuplePseudoType(); + specialTuple->elements = specialElements; + specialType = LegalType::tuple(specialTuple); + } + + RefPtr<PairInfo> pairInfo; + if (anyOrdinary && anySpecial) + { + pairInfo = new PairInfo(); + pairInfo->elements = pairElements; + } + + return LegalType::pair(ordinaryType, specialType, pairInfo); + } + +}; + +static RefPtr<Type> createBuiltinGenericType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + RefPtr<Type> elementType) +{ + // We are going to take the type for the original + // decl-ref and construct a new one that uses + // our new element type as its parameter. + // + // TODO: we should have library code to make + // manipulations like this way easier. + + RefPtr<GenericSubstitution> oldGenericSubst = getGenericSubstitution( + typeDeclRef.substitutions); + SLANG_ASSERT(oldGenericSubst); + + RefPtr<GenericSubstitution> newGenericSubst = new GenericSubstitution(); + + newGenericSubst->outer = oldGenericSubst->outer; + newGenericSubst->genericDecl = oldGenericSubst->genericDecl; + newGenericSubst->args = oldGenericSubst->args; + newGenericSubst->args[0] = elementType; + + auto newDeclRef = DeclRef<Decl>( + typeDeclRef.getDecl(), + newGenericSubst); + + auto newType = DeclRefType::Create( + context->session, + newDeclRef); + + return newType; +} + +// Create a uniform buffer type with a given legalized +// element type. +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalElementType) +{ + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalElementType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // This is actually an annoying case, because + // we are being asked to convert, e.g.,: + // + // cbuffer Foo { ParameterBlock<Bar> bar; } + // + // into the equivalent of: + // + // cbuffer Foo { Bar bar; } + // + // Which would really require a new `LegalType` that + // would reprerent a resource type with a modified + // element type. + // + // I'm going to attempt to hack this for now. + return LegalType::implicitDeref(createLegalUniformBufferType( + context, + typeDeclRef, + legalElementType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We assume that the "ordinary" part of things + // will get wrapped in a constant-buffer type, + // and the "special" part needs to be wrapped + // with an `implicitDeref`. + auto pairType = legalElementType.getPair(); + + auto ordinaryType = createLegalUniformBufferType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = LegalType::implicitDeref(pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // if we have a tuple type, then it must be representing + // the fields that can't be stored in a buffer anyway, + // so we just need to wrap each of them in an `implicitDeref` + + auto elementPseudoTupleType = legalElementType.getTuple(); + + RefPtr<TuplePseudoType> bufferPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : elementPseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = LegalType::implicitDeref(ee.type); + + bufferPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(bufferPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + UniformParameterGroupType* uniformBufferType, + LegalType legalElementType) +{ + return createLegalUniformBufferType( + context, + uniformBufferType->declRef, + legalElementType); +} + +// Create a pointer type with a given legalized value type. +static LegalType createLegalPtrType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalValueType) +{ + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalValueType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // We are being asked to create a pointer type to something + // that is implicitly dereferenced, meaning we had: + // + // Ptr(PtrLink(T)) + // + // and now are being asked to make: + // + // Ptr(implicitDeref(LegalT)) + // + // So it seems like we can just create: + // + // implicitDeref(Ptr(LegalT)) + // + // and nobody should really be able to tell the difference, right? + return LegalType::implicitDeref(createLegalPtrType( + context, + typeDeclRef, + legalValueType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalValueType.getPair(); + + auto ordinaryType = createLegalPtrType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = createLegalPtrType( + context, + typeDeclRef, + pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto valuePseudoTupleType = legalValueType.getTuple(); + + RefPtr<TuplePseudoType> ptrPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : valuePseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = createLegalPtrType( + context, + typeDeclRef, + ee.type); + + ptrPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(ptrPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +struct LegalTypeWrapper +{ + virtual LegalType wrap(TypeLegalizationContext* context, Type* type) = 0; +}; + +struct ArrayLegalTypeWrapper : LegalTypeWrapper +{ + ArrayExpressionType* arrayType; + + LegalType wrap(TypeLegalizationContext* context, Type* type) + { + return LegalType::simple(context->session->getArrayType( + type, + arrayType->ArrayLength)); + } +}; + +struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper +{ + DeclRef<Decl> declRef; + + LegalType wrap(TypeLegalizationContext* context, Type* type) + { + return LegalType::simple(createBuiltinGenericType( + context, + declRef, + type)); + } +}; + + +struct ImplicitDerefLegalTypeWrapper : LegalTypeWrapper +{ + LegalType wrap(TypeLegalizationContext*, Type* type) + { + return LegalType::implicitDeref(LegalType::simple(type)); + } +}; + +static LegalType wrapLegalType( + TypeLegalizationContext* context, + LegalType legalType, + LegalTypeWrapper* ordinaryWrapper, + LegalTypeWrapper* specialWrapper) +{ + switch (legalType.flavor) + { + case LegalType::Flavor::simple: + { + return ordinaryWrapper->wrap(context, legalType.getSimple()); + } + break; + + case LegalType::Flavor::implicitDeref: + { + return LegalType::implicitDeref(wrapLegalType( + context, + legalType, + ordinaryWrapper, + specialWrapper)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalType.getPair(); + + auto ordinaryType = wrapLegalType( + context, + pairType->ordinaryType, + ordinaryWrapper, + ordinaryWrapper); + auto specialType = wrapLegalType( + context, + pairType->specialType, + specialWrapper, + specialWrapper); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto tupleType = legalType.getTuple(); + + RefPtr<TuplePseudoType> resultTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : tupleType->elements) + { + TuplePseudoType::Element element; + + element.fieldDeclRef = ee.fieldDeclRef; + element.type = wrapLegalType( + context, + ee.type, + ordinaryWrapper, + specialWrapper); + + resultTupleType->elements.Add(element); + } + + return LegalType::tuple(resultTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + + +// Legalize a type, including any nested types +// that it transitively contains. +LegalType legalizeType( + TypeLegalizationContext* context, + Type* type) +{ + if (auto parameterBlockType = type->As<ParameterBlockType>()) + { + // We basically legalize the `ParameterBlock<T>` type + // over to `T`. In order to represent this preoperly, + // we need to be careful to wrap it up in a way that + // tells us to eliminate downstream deferences... + + auto legalElementType = legalizeType(context, + parameterBlockType->getElementType()); + return LegalType::implicitDeref(legalElementType); + } + else if (auto uniformBufferType = type->As<UniformParameterGroupType>()) + { + // We have a `ConstantBuffer<T>` or `TextureBuffer<T>` or + // other pointer-like type that represents uniform parameters. + // We need to pull any resource-type fields out of it, but + // leave the non-resource fields where they are. + + // Legalize the element type to see what we are working with. + auto legalElementType = legalizeType(context, + uniformBufferType->getElementType()); + + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + return LegalType::simple(type); + + default: + return createLegalUniformBufferType( + context, + uniformBufferType, + legalElementType); + } + + } + else if (isResourceType(type)) + { + // We assume that any resource types not handled above + // are legal as-is. + return LegalType::simple(type); + } + else if (type->As<BasicExpressionType>()) + { + return LegalType::simple(type); + } + else if (type->As<VectorExpressionType>()) + { + return LegalType::simple(type); + } + else if (type->As<MatrixExpressionType>()) + { + return LegalType::simple(type); + } + else if (auto ptrType = type->As<PtrTypeBase>()) + { + auto legalValueType = legalizeType(context, ptrType->getValueType()); + return createLegalPtrType(context, ptrType->declRef, legalValueType); + } + else if (auto declRefType = type->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + + LegalType legalType; + if(context->mapDeclRefToLegalType.TryGetValue(declRef, legalType)) + return legalType; + + + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // Look at the (non-static) fields, and + // see if anything needs to be cleaned up. + // The things that need to be "cleaned up" for + // our purposes are: + // + // - Fields of resource type, or any other future + // type we run into that isn't allowed in + // aggregates for at least some targets + // + // - Fields with types that themselves had to + // get legalized. + // + // If we don't run into any of these, we + // can just use the type as-is. Hooray! + // + // Otherwise, we are effectively going to split + // the type apart and create a `TuplePseudoType`. + // Every field of the original type will be + // represented as an element of this pseudo-type. + // Each element will record its `LegalType`, + // and the original field that it was created from. + // An element will also track whether it contains + // any "ordinary" data, and if so, it will remember + // an element index in a real (AST-level, non-pseudo) + // `TupleType` that is used to bundle together + // such fields. + // + // Storing all the simple fields together like this + // obviously adds complexity to the legalization + // pass, but it has important benefits: + // + // - It avoids creating functions with a very large + // number of parameters (when passing a structure + // with many fields), which might confuse downstream + // compilers. + // + // - It avoids applying AOS->SOA conversion to fields + // that don't actually need it, which is basically + // required if we want type layout to work. + // + // - It ensures that we can actually construct a + // constant-buffer type that wraps a legalized + // aggregate type; the ordinary fields will get + // placed inside a new constant-buffer type, + // while the special ones will get left outside. + // + + TupleTypeBuilder builder; + builder.context = context; + builder.type = type; + builder.typeDeclRef = aggTypeDeclRef; + + + for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) + { + builder.addField(ff); + } + + legalType = builder.getResult(); + context->mapDeclRefToLegalType.Add(declRef, legalType); + return legalType; + } + + // TODO: for other declaration-reference types, we really + // need to legalize the types used in substitutions, and + // signal an error if any of them turn out to be non-simple. + // + // The limited cases of types that can handle having non-simple + // types as generic arguments all need to be special-cased here. + // (For example, we can't handle `Texture2D<SomeStructWithTexturesInIt>`. + // + } + else if(auto arrayType = type->As<ArrayExpressionType>()) + { + auto legalElementType = legalizeType( + context, + arrayType->baseType); + + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + // Element type didn't need to be legalized, so + // we can just use this type as-is. + return LegalType::simple(type); + + default: + { + ArrayLegalTypeWrapper wrapper; + wrapper.arrayType = arrayType; + + return wrapLegalType( + context, + legalElementType, + &wrapper, + &wrapper); + } + break; + } + + } + + return LegalType::simple(type); +} + +// + +RefPtr<TypeLayout> getDerefTypeLayout( + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + if (auto parameterGroupTypeLayout = dynamic_cast<ParameterGroupTypeLayout*>(typeLayout)) + { + return parameterGroupTypeLayout->elementTypeLayout; + } + + return typeLayout; +} + +RefPtr<VarLayout> getFieldLayout( + TypeLayout* typeLayout, + DeclRef<VarDeclBase> fieldDeclRef) +{ + if (!typeLayout) + return nullptr; + + while(auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout)) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } + + if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) + { + RefPtr<VarLayout> fieldLayout; + if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) + return fieldLayout; + } + + return nullptr; +} + +RefPtr<VarLayout> createVarLayout( + LegalVarChain* varChain, + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + // We need to construct a layout for the new variable + // that reflects both the type we have given it, as + // well as all the offset information that has accumulated + // along the chain of parent variables. + + // TODO: this logic needs to propagate through semantics... + + RefPtr<VarLayout> varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + + for (auto rr : typeLayout->resourceInfos) + { + auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); + + for (auto vv = varChain; vv; vv = vv->next) + { + if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) + { + resInfo->index += parentResInfo->index; + resInfo->space += parentResInfo->space; + } + } + } + + // Some of the parent variables might actually contain offsets + // to the `space` or `set` of the field, and we need to apply + // those to all the nested resource infos. + for (auto vv = varChain; vv; vv = vv->next) + { + auto parentSpaceInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace); + if (!parentSpaceInfo) + continue; + + for (auto& rr : varLayout->resourceInfos) + { + if (rr.kind == LayoutResourceKind::RegisterSpace) + { + rr.index += parentSpaceInfo->index; + } + else + { + rr.space += parentSpaceInfo->index; + } + } + } + + return varLayout; +} + +// + +// TODO(tfoley): The code captured here is the logic that used to be +// applied to decide whether or not to desugar aggregate types that +// contain resources. Right now the implementation will *always* legalize +// away such types (since the IR always does this), while the AST-to-AST +// pass would only do it if required (according to the tests below). +// +// For right now this is an academic distinction, since the only project +// using Slang right now enables this tansformation unconditionally, but +// we probably need to re-parent this code back into the `TypeLegalizationContext` +// somewhere. +#if 0 + +bool shouldDesugarTupleTypes = false; +if (getTarget() == CodeGenTarget::GLSL) +{ + // Always desugar this stuff for GLSL, since it doesn't + // support nesting of resources in structs. + // + // TODO: Need a way to make this more fine-grained to + // handle cases where a nested member might be allowed + // due to, e.g., bindless textures. + shouldDesugarTupleTypes = true; +} +else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) +{ + // If the user is directly asking us to do this transformation, + // then obviously we need to do it. + // + // TODO: The way this is defined here means it will even apply to user + // HLSL code (not just code written in Slang). We may want to + // reconsider that choice, and only split things that originated in Slang. + // + shouldDesugarTupleTypes = true; +} + +#endif + +} diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h new file mode 100644 index 000000000..36e4223b6 --- /dev/null +++ b/source/slang/legalize-types.h @@ -0,0 +1,296 @@ +// legalize-types.h +#ifndef SLANG_LEGALIZE_TYPES_H_INCLUDED +#define SLANG_LEGALIZE_TYPES_H_INCLUDED + +// This file and `legalize-types.cpp` implement the core +// logic for taking a `Type` as produced by the front-end, +// and turning it into a suitable representation for use +// on a particular back-end. +// +// The main work applies to aggregate (e.g., `struct`) types, +// since various targets have rules about what is and isn't +// allowed in an aggregate (or where aggregates are allowed +// to be used). +// +// We might completely replace an aggregate `Type` with a +// "pseudo-type" that is just the enumeration of its field +// types (sort of a tuple type) so that a variable declared +// with the original type should be transformed into a +// bunch of individual variables. +// +// Alternatively, we might replace an aggregate type, where +// only *some* of the fields are illegal with a combination +// of an aggregate (containing the legal/legalized fields), +// and some extra tuple-ified fields. + +#include "../core/basic.h" +#include "syntax.h" +#include "type-layout.h" + +namespace Slang +{ + +struct LegalTypeImpl : RefObject +{ +}; +struct ImplicitDerefType; +struct TuplePseudoType; +struct PairPseudoType; +struct PairInfo; + +struct LegalType +{ + enum class Flavor + { + // Nothing: a NULL type + none, + + // A simple type that can be represented directly as a `Type` + simple, + + // Logically, we have a pointer-like type, but we are + // going to represnet it as the pointed-to type + implicitDeref, + + // A compound type was broken apart into its constituent fields, + // so a tuple "pseduo-type" is being used to collect + // those fields together. + tuple, + + // A type has to get split into "ordinary" and "special" parts, + // each of which will be represented with its own `LegalType`. + pair, + }; + + Flavor flavor = Flavor::none; + RefPtr<RefObject> obj; + + static LegalType simple(Type* type) + { + LegalType result; + result.flavor = Flavor::simple; + result.obj = type; + return result; + } + + RefPtr<Type> getSimple() const + { + assert(flavor == Flavor::simple); + return obj.As<Type>(); + } + + static LegalType implicitDeref( + LegalType const& valueType); + + RefPtr<ImplicitDerefType> getImplicitDeref() const + { + assert(flavor == Flavor::implicitDeref); + return obj.As<ImplicitDerefType>(); + } + + static LegalType tuple( + RefPtr<TuplePseudoType> tupleType); + + RefPtr<TuplePseudoType> getTuple() const + { + assert(flavor == Flavor::tuple); + return obj.As<TuplePseudoType>(); + } + + static LegalType pair( + RefPtr<PairPseudoType> pairType); + + static LegalType pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoType> getPair() const + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoType>(); + } +}; + +// Represents the pseudo-type of a type that is pointer-like +// (and thus requires dereferencing, even if implicit), but +// was legalized to just use the type of the pointed-type value. +struct ImplicitDerefType : LegalTypeImpl +{ + LegalType valueType; +}; + +// Represents the pseudo-type for a compound type +// that had to be broken apart because it contained +// one or more fields of types that shouldn't be +// allowed in aggregates. +// +// A tuple pseduo-type will have an element for +// each field of the original type, that represents +// the legalization of that field's type. +// +// It optionally also contains an "ordinary" type +// that packs together any per-field data that +// itself has (or contains) an ordinary type. +struct TuplePseudoType : LegalTypeImpl +{ + // Represents one element of the tuple pseudo-type + struct Element + { + // The field that this element replaces + DeclRef<VarDeclBase> fieldDeclRef; + + // The legalized type of the element + LegalType type; + }; + + // All of the elements of the tuple pseduo-type. + List<Element> elements; +}; + +struct PairInfo : RefObject +{ + typedef unsigned int Flags; + enum + { + kFlag_hasOrdinary = 0x1, + kFlag_hasSpecial = 0x2, + kFlag_hasOrdinaryAndSpecial = kFlag_hasOrdinary | kFlag_hasSpecial, + }; + + struct Element + { + // The original field the element represents + DeclRef<Decl> fieldDeclRef; + + // The conceptual type of the field. + // If both the `hasOrdinary` and + // `hasSpecial` bits are set, then + // this is expected to be a + // `LegalType::Flavor::pair` + LegalType type; + + // Is the value represented on + // the ordinary side, the special + // side, or both? + Flags flags; + + // If the type of this element is + // itself a pair type (that is, + // it both `hasOrdinary` and `hasSpecial`) + // then this is the `PairInfo` for that + // pair type: + RefPtr<PairInfo> fieldPairInfo; + + // The actual field decl-ref that needs + // to be used for looking up this element + // in the ordinary type. + DeclRef<Decl> ordinaryFieldDeclRef; + }; + + // For a pair type or value, we need to track + // which fields are on which side(s). + List<Element> elements; + + Element* findElement(DeclRef<Decl> const& fieldDeclRef) + { + for (auto& ee : elements) + { + if(ee.fieldDeclRef.Equals(fieldDeclRef)) + return ⅇ + } + return nullptr; + } +}; + +struct PairPseudoType : LegalTypeImpl +{ + // Any field(s) with ordinary types will + // get captured here (usually as a `fieldRemap` + // type) + LegalType ordinaryType; + + // Any fields with "special" (not ordinary) + // types will get captured here (usually + // with a tuple). + LegalType specialType; + + // The `pairInfo` field helps to tell us which members + // of the original aggregate type appear on which side(s) + // of the new pair type. + RefPtr<PairInfo> pairInfo; +}; + +// + +RefPtr<TypeLayout> getDerefTypeLayout( + TypeLayout* typeLayout); + +RefPtr<VarLayout> getFieldLayout( + TypeLayout* typeLayout, + DeclRef<VarDeclBase> fieldDeclRef); + +// Represents the "chain" of declarations that +// were followed to get to a variable that we +// are now declaring as a leaf variable. +struct LegalVarChain +{ + LegalVarChain* next; + VarLayout* varLayout; +}; + +RefPtr<VarLayout> createVarLayout( + LegalVarChain* varChain, + TypeLayout* typeLayout); + +// + +struct TypeLegalizationContext +{ + /// The overall compilation session (used when + /// constructing types). + Session* session; + + // If the type we are legalizing comes from an + // AST module being lowered via AST-to-AST translation, + // then we want to add any new declaration we create + // to represent it to the appropriate output module. + // We store some fields here to enable that: + RefPtr<ModuleDecl> mainModuleDecl; + RefPtr<ModuleDecl> outputModuleDecl; + + // We also need to know whether the IR is involved + // at all, because if it is, then it will own certain + // declarations instead. + // + // We do this in a slightly silly way by storing a pointer + // to the IR module (if any), and assume that its presence + // or absence is the indicator we need. + IRModule* irModule = nullptr; + + /// A list to retain any AST objects created during type legalization. + List<RefPtr<Decl>> createdDecls; + + /// A mapping from declaration references to the resulting + /// legalized type. + /// + /// For declaration-reference types, this map can be used + /// to cache a legalization so that it will be re-used + /// for equivalent declaration references (and so avoid + /// emitting declarations of legalized `struct` types + /// multiple times). + Dictionary<DeclRef<Decl>, LegalType> mapDeclRefToLegalType; +}; + + +LegalType legalizeType( + TypeLegalizationContext* context, + Type* type); + +/// Try to find the module that (recursively) contains a given declaration. +ModuleDecl* findModuleForDecl( + Decl* decl); + +} + +#endif diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 5ec175668..1e43a4d31 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -947,14 +947,6 @@ LoweredTypeInfo lowerType( return visitor.dispatchType(type); } -#if 0 -struct LoweringVisitor - : ExprVisitor<LoweringVisitor, LoweredExpr> - , StmtVisitor<LoweringVisitor, void> - , DeclVisitor<LoweringVisitor, LoweredDecl> - , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>> -#endif - LoweredValInfo createVar( IRGenContext* context, RefPtr<Type> type, @@ -982,6 +974,8 @@ void addArgs( case LoweredValInfo::Flavor::Simple: case LoweredValInfo::Flavor::Ptr: case LoweredValInfo::Flavor::SwizzledLValue: + case LoweredValInfo::Flavor::BoundSubscript: + case LoweredValInfo::Flavor::BoundMember: args.Add(getSimpleVal(context, argInfo)); break; @@ -2535,6 +2529,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> SLANG_UNIMPLEMENTED_X("decl catch-all"); } + LoweredValInfo visitImportDecl(ImportDecl* /*decl*/) + { + return LoweredValInfo(); + } + LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/) { return LoweredValInfo(); diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index e2db1b456..721072b82 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -124,14 +124,6 @@ namespace Slang { emitQualifiedName(context, declRefType->declRef); } - else if (auto tupleType = dynamic_cast<FilteredTupleType*>(type)) - { - // TODO: this doesn't handle the possibility of multiple different - // filtered versions of the same type... - emitRaw(context, "t"); - emitType(context, tupleType->originalType); - emitRaw(context, "_"); - } else { SLANG_UNEXPECTED("unimplemented case in mangling"); diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index 0f92fcd61..cd8f2524b 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -318,21 +318,9 @@ SYNTAX_CLASS(ComputedLayoutModifier, Modifier) FIELD(RefPtr<Layout>, layout) END_SYNTAX_CLASS() -// A modifier attached to types during lowering, to indicate that they -// are logically a "tuple" type -SYNTAX_CLASS(TupleTypeModifier, Modifier) - FIELD_INIT(AggTypeDecl*, decl, nullptr) - FIELD_INIT(bool, hasAnyNonTupleFields, false) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(TupleFieldModifier, Modifier) - FIELD_INIT(VarDeclBase*, decl, nullptr) - FIELD_INIT(bool, hasAnyNonTupleFields, false) - FIELD_INIT(bool, isNestedTuple, false) -END_SYNTAX_CLASS() SYNTAX_CLASS(TupleVarModifier, Modifier) - FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) +// FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) END_SYNTAX_CLASS() // A modifier to indicate that a constructor/initializer can be used @@ -342,3 +330,8 @@ SYNTAX_CLASS(ImplicitConversionModifier, Modifier) // The conversion cost, used to rank conversions FIELD(ConversionCost, cost) END_SYNTAX_CLASS() + +// A marker modifier used to indicate that a declaration was created as +// part of type legalization. +SIMPLE_SYNTAX_CLASS(LegalizedModifier, Modifier) + diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index d0e4e840e..9835640e0 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -179,6 +179,7 @@ <ClInclude Include="ir-inst-defs.h" /> <ClInclude Include="ir-insts.h" /> <ClInclude Include="ir.h" /> + <ClInclude Include="legalize-types.h" /> <ClInclude Include="lexer.h" /> <ClInclude Include="lookup.h" /> <ClInclude Include="lower-to-ir.h" /> @@ -217,6 +218,7 @@ <ClCompile Include="emit.cpp" /> <ClCompile Include="ir-legalize-types.cpp" /> <ClCompile Include="ir.cpp" /> + <ClCompile Include="legalize-types.cpp" /> <ClCompile Include="lexer.cpp" /> <ClCompile Include="lookup.cpp" /> <ClCompile Include="lower-to-ir.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index c93934fff..f207c6dcd 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -43,6 +43,7 @@ <ClInclude Include="vm.h" /> <ClInclude Include="mangle.h" /> <ClInclude Include="ast-legalize.h" /> + <ClInclude Include="legalize-types.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp" /> @@ -72,6 +73,7 @@ <ClCompile Include="dxc-support.cpp" /> <ClCompile Include="ir-legalize-types.cpp" /> <ClCompile Include="ast-legalize.cpp" /> + <ClCompile Include="legalize-types.cpp" /> </ItemGroup> <ItemGroup> <CustomBuild Include="core.meta.slang" /> diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index f8237359d..dd08fa2b1 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -321,6 +321,7 @@ void Type::accept(IValVisitor* visitor, void* extra) IntVal* elementCount) { RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->setSession(this); arrayType->baseType = elementType; arrayType->ArrayLength = elementCount; return arrayType; @@ -865,8 +866,16 @@ void Type::accept(IValVisitor* visitor, void* extra) int NamedExpressionType::GetHashCode() { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(0); + // Type equality is based on comparing canonical types, + // so the hash code for a type needs to come from the + // canonical version of the type. This really means + // that `Type::GetHashCode()` should dispatch out to + // something like `Type::GetHashCodeImpl()` on the + // canonical version of a type, but it is less invasive + // for now (and hopefully equivalent) to just have any + // named types automaticlaly route hash-code requests + // to their canonical type. + return GetCanonicalType()->GetHashCode(); } // FuncType @@ -1875,118 +1884,4 @@ void Type::accept(IValVisitor* visitor, void* extra) insertGlobalGenericSubstitutions(newSubst, subst, ioDiff); return newSubst; } - - // FilteredTupleType - - String FilteredTupleType::ToString() - { - StringBuilder sb; - sb.append(originalType->ToString()); - sb.append("{"); - bool first = true; - for (auto ee : elements) - { - if (!ee.type) - continue; - - if (!first) sb.append(", "); - - sb.append(ee.fieldDeclRef.GetName()->text); - sb.append(":"); - sb.append(ee.type->ToString()); - - first = false; - } - sb.append("}"); - return sb.ProduceString(); - } - - RefPtr<Val> FilteredTupleType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - int diff = 0; - auto substOriginalType = originalType->SubstituteImpl(subst, &diff).As<Type>(); - - List<Element> substElements; - for (auto ee : elements) - { - Element substElement; - substElement.fieldDeclRef = ee.fieldDeclRef.SubstituteImpl(subst, &diff); - substElement.type = ee.type->SubstituteImpl(subst, &diff).As<Type>(); - substElements.Add(substElement); - } - - if (!diff) - return this; - - (*ioDiff)++; - RefPtr<FilteredTupleType> substType = new FilteredTupleType(); - substType->setSession(session); - substType->originalType = substOriginalType; - substType->elements = substElements; - return substType; - } - - bool FilteredTupleType::EqualsImpl(Type * type) - { - auto tupleType = type->As<FilteredTupleType>(); - if (!tupleType) - return false; - - if (!originalType->Equals(tupleType->originalType)) - return false; - - auto elementCount = elements.Count(); - if (tupleType->elements.Count() != elementCount) - return false; - - for (UInt ee = 0; ee < elementCount; ee++) - { - if (!elements[ee].type || !tupleType->elements[ee].type) - { - if (!elements[ee].type != !tupleType->elements[ee].type) - return false; - - continue; - } - - if (!elements[ee].fieldDeclRef.Equals(tupleType->elements[ee].fieldDeclRef)) - return false; - - if (!elements[ee].type->Equals(tupleType->elements[ee].type)) - return false; - } - return true; - } - - int FilteredTupleType::GetHashCode() - { - int hash = (int)(typeid(this).hash_code()); - hash = combineHash(hash, - originalType->GetHashCode()); - for (auto ee : elements) - { - hash = combineHash(hash, - ee.fieldDeclRef.GetHashCode()); - hash = combineHash(hash, - ee.type->GetHashCode()); - } - return hash; - } - - Type* FilteredTupleType::CreateCanonicalType() - { - RefPtr<FilteredTupleType> canTupleType = new FilteredTupleType(); - canTupleType->setSession(session); - canTupleType->originalType = originalType->GetCanonicalType(); - for (auto ee : elements) - { - Element element; - element.fieldDeclRef = ee.fieldDeclRef; - element.type = ee.type ? ee.type->GetCanonicalType() : nullptr; - - canTupleType->elements.Add(element); - } - getSession()->canonicalTypes.Add(canTupleType); - return canTupleType; - } } diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 72bf6fe4c..da6e27d17 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -493,37 +493,3 @@ protected: virtual Type* CreateCanonicalType() override; ) END_SYNTAX_CLASS() - -// A type created to represent the result of filtering -// the fields of an aggregate type. -SYNTAX_CLASS(FilteredTupleType, Type) -RAW( - struct Element - { - // The original field this element represents - DeclRef<VarDeclBase> fieldDeclRef; - - // The type being used for the new field - RefPtr<Type> type; - }; -) - - FIELD(RefPtr<Type>, originalType); - FIELD(List<Element>, elements); - -RAW( - FilteredTupleType() - {} - - RefPtr<Type> getOriginalType() const { return originalType; } - List<Element> const& getElements() const { return elements; } - virtual String ToString() override; - -protected: - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual Type* CreateCanonicalType() override; -) - -END_SYNTAX_CLASS() diff --git a/tests/compute/rewriter-array-type.hlsl b/tests/compute/rewriter-array-type.hlsl new file mode 100644 index 000000000..38a2abc1e --- /dev/null +++ b/tests/compute/rewriter-array-type.hlsl @@ -0,0 +1,35 @@ +//TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir + +//TEST_INPUT:cbuffer(data=[16 0 0 0 32 0 0 0]):dxbinding(0),glbinding(0) +//TEST_INPUT:cbuffer(data=[256 0 0 0 512 0 0 0 768 0 0 0 1024 0 0 0]):dxbinding(1),glbinding(1) +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out +//TEST_INPUT:ubuffer(data=[90 91 92 93], stride=4):dxbinding(1),glbinding(1) +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(2),glbinding(2) + +// Test the case of user code with that uses the "rewriter" mode +// (`-no-checking` flag) and that uses a type declared in +// imported code (that will compile via IR). Also test +// the case where such a type requires legalization. + +import rewriter_types; + +RWStructuredBuffer<int> outputBuffer : register(u0); + +cbuffer C +{ + MyHelper myHelpers[2]; +} + +cbuffer D +{ + Other others[2]; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal, myHelpers[1], others[tid]); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/rewriter-array-type.hlsl.expected.txt b/tests/compute/rewriter-array-type.hlsl.expected.txt new file mode 100644 index 000000000..9bde1c082 --- /dev/null +++ b/tests/compute/rewriter-array-type.hlsl.expected.txt @@ -0,0 +1,4 @@ +120 +222 +324 +426 diff --git a/tests/compute/rewriter-types.hlsl b/tests/compute/rewriter-types.hlsl new file mode 100644 index 000000000..01da92997 --- /dev/null +++ b/tests/compute/rewriter-types.hlsl @@ -0,0 +1,34 @@ +//TEST(compute):HLSL_COMPUTE:-xslang -no-checking -xslang -use-ir + +//TEST_INPUT:cbuffer(data=[16]):dxbinding(0),glbinding(0) +//TEST_INPUT:cbuffer(data=[256]):dxbinding(1),glbinding(1) +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):dxbinding(1),glbinding(1) + +// Test the case of user code with that uses the "rewriter" mode +// (`-no-checking` flag) and that uses a type declared in +// imported code (that will compile via IR). Also test +// the case where such a type requires legalization. + +import rewriter_types; + +RWStructuredBuffer<int> outputBuffer : register(u0); + +cbuffer C +{ + MyHelper myHelper; +} + +cbuffer D +{ + Other other; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal, myHelper, other); + outputBuffer[tid] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/rewriter-types.hlsl.expected.txt b/tests/compute/rewriter-types.hlsl.expected.txt new file mode 100644 index 000000000..b68ea4a4f --- /dev/null +++ b/tests/compute/rewriter-types.hlsl.expected.txt @@ -0,0 +1,4 @@ +110 +112 +114 +116 diff --git a/tests/compute/rewriter-types.slang b/tests/compute/rewriter-types.slang new file mode 100644 index 000000000..a627b7891 --- /dev/null +++ b/tests/compute/rewriter-types.slang @@ -0,0 +1,17 @@ +//TEST_IGNORE_FILE: + +struct MyHelper +{ + int val; + RWStructuredBuffer<int> buf; +}; + +struct Other +{ + int val; +}; + +int test(int inVal, MyHelper helper, Other other) +{ + return inVal + helper.val + helper.buf[inVal] + other.val; +} |
