diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/ast-legalize.cpp | 2495 | ||||
| -rw-r--r-- | source/slang/ast-legalize.h | 12 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 256 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 1047 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 19 | ||||
| -rw-r--r-- | source/slang/legalize-types.cpp | 1086 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 296 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 15 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 8 | ||||
| -rw-r--r-- | source/slang/modifier-defs.h | 19 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 2 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 127 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 34 |
14 files changed, 2838 insertions, 2580 deletions
diff --git a/source/slang/ast-legalize.cpp b/source/slang/ast-legalize.cpp index 1b045cbc9..eaceecc00 100644 --- a/source/slang/ast-legalize.cpp +++ b/source/slang/ast-legalize.cpp @@ -3,6 +3,7 @@ #include "emit.h" #include "ir-insts.h" +#include "legalize-types.h" #include "type-layout.h" #include "visitor.h" @@ -37,100 +38,10 @@ struct CloneVisitor // -// - -class TupleExpr; -class TupleVarDecl; -class VaryingTupleExpr; -class VaryingTupleVarDecl; - - -// The result of lowering a declaration will usually be a declaration, -// but it might also be a "tuple" declaration, in cases where we needed -// to sclarize (or partially scalarize) things to guarantee validity. -struct LoweredDecl -{ - enum class Flavor - { - Decl, // A single declaration (the default case) - Tuple, // A `TupleVarDecl` representing multiple decls - VaryingTuple, // A `VaryingTupleVarDecl` representing multiple decls - }; - - LoweredDecl() - : flavor(Flavor::Decl) - {} - - LoweredDecl(Decl* decl) - : value(decl) - , flavor(Flavor::Decl) - {} - - LoweredDecl(TupleVarDecl* decl) - : value((RefObject*) decl) - , flavor(Flavor::Tuple) - {} - - LoweredDecl(VaryingTupleVarDecl* decl) - : value((RefObject*) decl) - , flavor(Flavor::VaryingTuple) - {} - - Flavor getFlavor() const { return flavor; } - RefObject* getValue() const { return value; } - - Decl* getDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::Decl); - return (Decl*) value.Ptr(); - } - - TupleVarDecl* getTupleDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::Tuple); - return (TupleVarDecl*) value.Ptr(); - } - - VaryingTupleVarDecl* getVaryingTupleDecl() const - { - SLANG_ASSERT(getFlavor() == Flavor::VaryingTuple); - return (VaryingTupleVarDecl*) value.Ptr(); - } - - Decl* asDecl() const - { - return (getFlavor() == Flavor::Decl) ? getDecl() : nullptr; - } - - TupleVarDecl* asTupleDecl() const - { - return (getFlavor() == Flavor::Tuple) ? getTupleDecl() : nullptr; - } - - VaryingTupleVarDecl* asVaryingTupleDecl() const - { - return (getFlavor() == Flavor::VaryingTuple) ? getVaryingTupleDecl() : nullptr; - } - -private: - RefPtr<RefObject> value; - Flavor flavor; -}; - -struct LoweredDeclRef -{ -public: - LoweredDecl decl; - RefPtr<Substitutions> substitutions; - - LoweredDecl getDecl() { return decl; } - - template<typename T> - DeclRef<T> As() - { - return DeclRef<Decl>(decl.getDecl(), substitutions).As<T>(); - } -}; +// Forward-declare types used by `LegalExpr` +class ImplicitDerefPseudoExpr; +class TuplePseudoExpr; +class PairPseudoExpr; // @@ -152,7 +63,7 @@ struct StructuralTransformVisitorBase template<typename T> DeclRef<T> transformDeclField(DeclRef<T> const& decl) { - LoweredDeclRef declRef = visitor->translateDeclRef(decl); + DeclRef<Decl> declRef = visitor->translateDeclRef(decl); return declRef.As<T>(); } @@ -204,18 +115,6 @@ struct StructuralTransformVisitorBase } }; -#if 0 -template<typename V> -RefPtr<Stmt> structuralTransform( - Stmt* stmt, - V* visitor) -{ - StructuralTransformStmtVisitor<V> transformer; - transformer.visitor = visitor; - return transformer.dispatch(stmt); -} -#endif - template<typename V> struct StructuralTransformExprVisitor : StructuralTransformVisitorBase<V> @@ -267,71 +166,71 @@ RefPtr<Expr> structuralTransform( -// The result of lowering an exrpession will usually be just a single +// The result of legalizing an exrpession will usually be just a single // expression, but it might also be a "tuple" expression that encodes // multiple expressions. -struct LoweredExpr +struct LegalExpr { - enum class Flavor - { - Expr, - Tuple, - VaryingTuple, - }; + typedef LegalType::Flavor Flavor; - LoweredExpr() - : flavor(Flavor::Expr) + LegalExpr() + : flavor(Flavor::none) {} - LoweredExpr(Expr* expr) + LegalExpr(Expr* expr) : value(expr) - , flavor(Flavor::Expr) + , flavor(Flavor::simple) {} - LoweredExpr(TupleExpr* expr) + LegalExpr(TuplePseudoExpr* expr) : value((RefObject*) expr) - , flavor(Flavor::Tuple) + , flavor(Flavor::tuple) {} - LoweredExpr(VaryingTupleExpr* expr) + LegalExpr(PairPseudoExpr* expr) : value((RefObject*) expr) - , flavor(Flavor::VaryingTuple) + , flavor(Flavor::pair) {} - Flavor getFlavor() const { return flavor; } + LegalExpr(ImplicitDerefPseudoExpr* expr) + : value((RefObject*) expr) + , flavor(Flavor::implicitDeref) + {} - Expr* getExpr() const - { - assert(getFlavor() == Flavor::Expr); - return (Expr*)value.Ptr(); - } + Flavor getFlavor() const { return flavor; } - TupleExpr* getTupleExpr() const + Expr* getSimple() const { - assert(getFlavor() == Flavor::Tuple); - return (TupleExpr*)value.Ptr(); - } + switch (getFlavor()) + { + case Flavor::none: + return nullptr; + case Flavor::simple: + return (Expr*)value.Ptr(); - VaryingTupleExpr* getVaryingTupleExpr() const - { - assert(getFlavor() == Flavor::VaryingTuple); - return (VaryingTupleExpr*)value.Ptr(); + default: + assert(getFlavor() == Flavor::simple); + return nullptr; + } } - Expr* asExpr() const + TuplePseudoExpr* getTuple() const { - return (getFlavor() == Flavor::Expr) ? getExpr() : nullptr; + assert(getFlavor() == Flavor::tuple); + return (TuplePseudoExpr*)value.Ptr(); } - TupleExpr* asTuple() const + PairPseudoExpr* getPair() const { - return (getFlavor() == Flavor::Tuple) ? getTupleExpr() : nullptr; + assert(getFlavor() == Flavor::pair); + return (PairPseudoExpr*)value.Ptr(); } - VaryingTupleExpr* asVaryingTuple() const + ImplicitDerefPseudoExpr* getImplicitDeref() const { - return (getFlavor() == Flavor::VaryingTuple) ? getVaryingTupleExpr() : nullptr; + assert(getFlavor() == Flavor::implicitDeref); + return (ImplicitDerefPseudoExpr*)value.Ptr(); } // Allow use in boolean contexts @@ -345,84 +244,77 @@ private: Flavor flavor; }; -// Pseudo-syntax used during lowering -class PseudoVarDecl : public RefObject +struct LegalTypeExpr { -public: - NameLoc nameAndLoc; - SourceLoc loc; - TypeExp type; -}; + LegalType type; + RefPtr<Expr> expr; -class TupleVarDecl : public PseudoVarDecl -{ -public: - struct Element + LegalTypeExpr() + {} + + LegalTypeExpr(LegalType const& type) + : type(type) { - RefPtr<TupleVarModifier> tupleVarMod; - LoweredDecl decl; - }; + } - TupleTypeModifier* tupleType; - RefPtr<VarDeclBase> primaryDecl; - List<Element> tupleDecls; + LegalTypeExpr(TypeExp const& typeExpr) + { + type = LegalType::simple(typeExpr.type); + expr = typeExpr.exp; + } + + TypeExp getSimple() const + { + TypeExp result; + result.type = type.getSimple(); + result.exp = expr; + return result; + } }; class PseudoExpr : public RefObject { public: - SourceLoc loc; - QualType type; + SourceLoc loc; }; -// Pseudo-syntax used during lowering: -// represents an ordered list of expressions as a single unit -class TupleExpr : public PseudoExpr +class ImplicitDerefPseudoExpr : public PseudoExpr { public: - struct Element - { - DeclRef<VarDeclBase> tupleFieldDeclRef; - LoweredExpr expr; - }; - - // Optional reference to the "primary" value of the tuple, - // in the case of a tuple type with "orinary" fields - RefPtr<Expr> primaryExpr; - - // Additional fields to store values for any non-ordinary fields - // (or fields that aren't exclusively orginary) - List<Element> tupleElements; + LegalExpr valueExpr; }; -// Pseudo-syntax used during lowering -class VaryingTupleVarDecl : public PseudoVarDecl +class TuplePseudoExpr : public PseudoExpr { public: - LoweredExpr expr; + struct Element + { + LegalExpr expr; + DeclRef<VarDeclBase> fieldDeclRef; + }; + + List<Element> elements; }; -// Pseudo-syntax used during lowering: -// represents an ordered list of expressions as a single unit -class VaryingTupleExpr : public PseudoExpr +class PairPseudoExpr : public PseudoExpr { public: - struct Element - { - DeclRef<VarDeclBase> originalFieldDeclRef; - LoweredExpr expr; - }; + LegalExpr ordinary; + LegalExpr special; - List<Element> elements; + RefPtr<PairInfo> pairInfo; }; -static SourceLoc getPosition(LoweredExpr const& expr) +static SourceLoc getPosition(LegalExpr const& expr) { switch (expr.getFlavor()) { - case LoweredExpr::Flavor::Expr: return expr.getExpr() ->loc; - case LoweredExpr::Flavor::Tuple: return expr.getTupleExpr() ->loc; - case LoweredExpr::Flavor::VaryingTuple: return expr.getVaryingTupleExpr()->loc; + case LegalExpr::Flavor::none: return SourceLoc(); + case LegalExpr::Flavor::simple: return expr.getSimple() ->loc; + case LegalExpr::Flavor::tuple: return expr.getTuple() ->loc; + case LegalExpr::Flavor::pair: return expr.getPair() ->loc; + case LegalExpr::Flavor::implicitDeref: return expr.getImplicitDeref() ->loc; + default: SLANG_UNREACHABLE("all cases handled"); UNREACHABLE_RETURN(SourceLoc()); @@ -456,7 +348,8 @@ struct SharedLoweringContext RefPtr<ModuleDecl> loweredProgram; - Dictionary<Decl*, LoweredDecl> loweredDecls; + Dictionary<Decl*, RefPtr<Decl>> mapOriginalDeclToLowered; + Dictionary<Decl*, LegalExpr> mapOriginalDeclToExpr; Dictionary<RefObject*, Decl*> mapLoweredDeclToOriginal; // Work to be done at the very start and end of the entry point @@ -474,6 +367,9 @@ struct SharedLoweringContext // The actual result we want to return LoweredEntryPoint result; + + /// State to use when legalizing types. + TypeLegalizationContext* typeLegalizationContext; }; static void attachLayout( @@ -491,9 +387,9 @@ void requireGLSLVersion( ProfileVersion version); struct LoweringVisitor - : ExprVisitor<LoweringVisitor, LoweredExpr> + : ExprVisitor<LoweringVisitor, LegalExpr> , StmtVisitor<LoweringVisitor, void> - , DeclVisitor<LoweringVisitor, LoweredDecl> + , DeclVisitor<LoweringVisitor, RefPtr<Decl>> , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>> { // @@ -512,6 +408,11 @@ struct LoweringVisitor // then this will point to that variable. RefPtr<Variable> resultVariable; + TypeLegalizationContext* getTypeLegalizationContext() + { + return shared->typeLegalizationContext; + } + Session* getSession() { return shared->compileRequest->mSession; @@ -708,22 +609,113 @@ struct LoweringVisitor // Types // - RefPtr<Type> lowerType( + RefPtr<Type> lowerTypeEx( Type* type) { if (!type) return nullptr; - return TypeVisitor::dispatch(type); + RefPtr<Type> loweredType = dispatchType(type); + return loweredType; + } + + LegalType lowerLegalType( + LegalType legalType) + { + switch(legalType.flavor) + { + case LegalType::Flavor::none: + return LegalType(); + + case LegalType::Flavor::simple: + return LegalType::simple( + lowerTypeEx(legalType.getSimple())); + + case LegalType::Flavor::tuple: + { + auto inputTuple = legalType.getTuple(); + RefPtr<TuplePseudoType> resultTuple = new TuplePseudoType(); + for(auto ee : inputTuple->elements) + { + TuplePseudoType::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.type = lowerLegalType(ee.type); + resultTuple->elements.Add(element); + } + return LegalType::tuple(resultTuple); + } + break; + + case LegalType::Flavor::pair: + { + auto inputPair = legalType.getPair(); + RefPtr<PairPseudoType> resultPair = new PairPseudoType(); + return LegalType::pair( + lowerLegalType(inputPair->ordinaryType), + lowerLegalType(inputPair->specialType), + inputPair->pairInfo); + } + break; + + case LegalType::Flavor::implicitDeref: + { + return LegalType::implicitDeref( + lowerLegalType(legalType.getImplicitDeref()->valueType)); + } + break; + + default: + SLANG_UNEXPECTED("uhandled type flavor"); + UNREACHABLE_RETURN(LegalType()); + } } - TypeExp lowerType( + LegalType lowerAndLegalizeType( + Type* type) + { + if (!type) return LegalType(); + + // We will first attempt to legalize the type, so that any parts of + // it that won't be allowed on the target get excised. Once we are + // done with that, we will do the "lowering" process of copying + // any needed bits of AST over to the new module. + LegalType legalType = legalizeType( + getTypeLegalizationContext(), + type); + + LegalType loweredType = lowerLegalType(legalType); + + return loweredType; + } + + TypeExp lowerTypeExprEx( TypeExp const& typeExp) { TypeExp result; - result.type = lowerType(typeExp.type); - result.exp = lowerExpr(typeExp.exp); + result.type = lowerTypeEx(typeExp.type); + result.exp = legalizeSimpleExpr(typeExp.exp); + return result; + } + + LegalTypeExpr lowerAndLegalizeTypeExpr( + TypeExp const& typeExp) + { + LegalTypeExpr result; + result.type = lowerAndLegalizeType(typeExp.type); + result.expr = legalizeSimpleExpr(typeExp.exp); return result; } + RefPtr<Type> lowerAndLegalizeSimpleType( + Type* type) + { + return lowerAndLegalizeType(type).getSimple(); + } + + TypeExp lowerAndlegalizeSimpleTypeExpr( + TypeExp const& typeExp) + { + return lowerAndLegalizeTypeExpr(typeExp).getSimple(); + } + RefPtr<Type> visitIRBasicBlockType(IRBasicBlockType* type) { return type; @@ -756,13 +748,10 @@ struct LoweringVisitor { RefPtr<FuncType> loweredType = new FuncType(); loweredType->setSession(getSession()); - loweredType->resultType = lowerType(type->resultType); + loweredType->resultType = lowerTypeEx(type->resultType); for (auto paramType : type->paramTypes) { - auto loweredParamType = lowerType(paramType); - - // TODO: it seems like this step needs to scalarize - // in the case where a parameter type is a tuple... + auto loweredParamType = lowerTypeEx(paramType); loweredType->paramTypes.Add(loweredParamType); } return loweredType; @@ -781,7 +770,7 @@ struct LoweringVisitor if (shared->target == CodeGenTarget::GLSL) { // GLSL does not support `typedef`, so we will lower it out of existence here - return lowerType(GetType(type->declRef)); + return lowerTypeEx(GetType(type->declRef)); } return getNamedType( @@ -789,30 +778,15 @@ struct LoweringVisitor translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); } - RefPtr<Type> visitFilteredTupleType(FilteredTupleType* type) - { - RefPtr<FilteredTupleType> loweredType = new FilteredTupleType(); - loweredType->setSession(type->getSession()); - loweredType->originalType = lowerType(type->originalType); - for (auto ee : type->elements) - { - FilteredTupleType::Element element; - element.fieldDeclRef = ee.fieldDeclRef; - element.type = lowerType(ee.type); - loweredType->elements.Add(element); - } - return loweredType; - } - RefPtr<Type> visitTypeType(TypeType* type) { - return getTypeType(lowerType(type->type)); + return getTypeType(lowerTypeEx(type->type)); } RefPtr<Type> visitArrayExpressionType(ArrayExpressionType* type) { RefPtr<ArrayExpressionType> loweredType = Slang::getArrayType( - lowerType(type->baseType), + lowerTypeEx(type->baseType), lowerVal(type->ArrayLength).As<IntVal>()); return loweredType; } @@ -820,7 +794,7 @@ struct LoweringVisitor RefPtr<Type> visitGroupSharedType(GroupSharedType* type) { return getSession()->getGroupSharedType( - lowerType(type->valueType)); + lowerTypeEx(type->valueType)); } RefPtr<Type> visitParameterBlockType(ParameterBlockType* type) @@ -832,14 +806,15 @@ struct LoweringVisitor // directly to its stated element type, and see how // that works. - return lowerType(type->getElementType()); + return lowerTypeEx(type->getElementType()); // return getSession()->getConstantBufferType( // lowerType(type->getElementType()); } RefPtr<Type> transformSyntaxField(Type* type) { - return lowerType(type); + // TODO: how to handle this... + return type; } RefPtr<Val> visitIRProxyVal(IRProxyVal* val) @@ -851,32 +826,33 @@ struct LoweringVisitor // Expressions // - LoweredExpr lowerExprOrTuple( + LegalExpr legalizeExpr( Expr* expr) { - if (!expr) return LoweredExpr(); + if (!expr) return LegalExpr(); return ExprVisitor::dispatch(expr); } - RefPtr<Expr> lowerExpr( + RefPtr<Expr> legalizeSimpleExpr( Expr* expr) { if (!expr) return nullptr; - auto result = lowerExprOrTuple(expr); - return maybeReifyTuple(result); + auto type = lowerAndLegalizeType(expr->type.type); + auto result = legalizeExpr(expr); + return maybeReifyTuple(result, type).getSimple(); } // catch-all - LoweredExpr visitExpr( + LegalExpr visitExpr( Expr* expr) { - return LoweredExpr(structuralTransform(expr, this)); + return LegalExpr(structuralTransform(expr, this)); } RefPtr<Expr> transformSyntaxField(Expr* expr) { - return lowerExpr(expr); + return legalizeSimpleExpr(expr); } void lowerExprCommon( @@ -884,16 +860,16 @@ struct LoweringVisitor Expr* expr) { loweredExpr->loc = expr->loc; - loweredExpr->type.type = lowerType(expr->type.type); + loweredExpr->type.type = lowerTypeEx(expr->type.type); } void lowerExprCommon( - LoweredExpr const& loweredExpr, - Expr* expr) + LegalExpr const& legalExpr, + Expr* expr) { - if (auto simpleExpr = loweredExpr.asExpr()) + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) { - lowerExprCommon(simpleExpr, expr); + lowerExprCommon(legalExpr.getSimple(), expr); } } @@ -925,95 +901,49 @@ struct LoweringVisitor return result; } - LoweredExpr createVarRef( - SourceLoc const& loc, - LoweredDecl const& decl) + LegalExpr createVarRef( + SourceLoc const& loc, + VarDeclBase* decl) { - switch (decl.getFlavor()) - { - case LoweredDecl::Flavor::Decl: - return LoweredExpr(createSimpleVarRef(loc, decl.getDecl()->As<VarDeclBase>())); - - case LoweredDecl::Flavor::Tuple: - return createTupleRef(loc, decl.getTupleDecl()); - - case LoweredDecl::Flavor::VaryingTuple: - return createVaryingTupleRef(loc, decl.getVaryingTupleDecl()); - - default: - SLANG_UNREACHABLE("all cases handled"); - UNREACHABLE_RETURN(LoweredExpr()); - } + return LegalExpr(createSimpleVarRef(loc, decl)); } - - LoweredExpr createTupleRef( - SourceLoc const& loc, - TupleVarDecl* decl) + RefPtr<Expr> createSimpleVarExpr( + VarExpr* expr, + DeclRef<Decl> const& declRef) { - RefPtr<TupleExpr> result = new TupleExpr(); - result->loc = loc; - result->type.type = decl->type.type; - - if (auto primaryDecl = decl->primaryDecl) - { - result->primaryExpr = createSimpleVarRef(loc, primaryDecl); - } - - for (auto declElem : decl->tupleDecls) + RefPtr<VarExpr> loweredExpr = new VarExpr(); + if (expr) { - auto tupleVarMod = declElem.tupleVarMod; - SLANG_RELEASE_ASSERT(tupleVarMod); - auto tupleFieldMod = tupleVarMod->tupleField; - SLANG_RELEASE_ASSERT(tupleFieldMod); - SLANG_RELEASE_ASSERT(tupleFieldMod->decl); - - TupleExpr::Element elem; - elem.tupleFieldDeclRef = makeDeclRef(tupleFieldMod->decl); - elem.expr = createVarRef(loc, declElem.decl); - result->tupleElements.Add(elem); + lowerExprCommon(loweredExpr, expr); } - - return LoweredExpr(result); - } - - LoweredExpr createVaryingTupleRef( - SourceLoc const& /*loc*/, - VaryingTupleVarDecl* decl) - { - return decl->expr; + loweredExpr->declRef = declRef; + loweredExpr->name = expr->name; + return loweredExpr; } - LoweredExpr visitVarExpr( + LegalExpr visitVarExpr( VarExpr* expr) { // If the expression didn't get resolved, we can leave it as-is if (!expr->declRef) return expr; + // Ensure that lowering has been applied to the declaration auto loweredDeclRef = translateDeclRef(expr->declRef); - auto loweredDecl = loweredDeclRef.getDecl(); - if (auto tupleVarDecl = loweredDecl.asTupleDecl()) - { - // If we are referencing a declaration that got tuple-ified, - // then we need to produce a tuple expression as well. - - return createTupleRef(expr->loc, tupleVarDecl); - } - else if (auto varyingTupleVarDecl = loweredDecl.asVaryingTupleDecl()) - { - return createVaryingTupleRef(expr->loc, varyingTupleVarDecl); - } + // Is there a value already registered for use when looking + // up this variable? + LegalExpr legalExpr; + if (this->shared->mapOriginalDeclToExpr.TryGetValue(expr->declRef.getDecl(), legalExpr)) + return legalExpr; - RefPtr<VarExpr> loweredExpr = new VarExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->declRef = loweredDeclRef.As<Decl>(); - loweredExpr->name = expr->name; - return LoweredExpr(loweredExpr); + return LegalExpr(createSimpleVarExpr( + expr, + loweredDeclRef)); } - LoweredExpr visitOverloadedExpr( + LegalExpr visitOverloadedExpr( OverloadedExpr* expr) { // The presence of an overloaded expression in the output @@ -1077,48 +1007,64 @@ struct LoweringVisitor return moveTemp(expr); } - LoweredExpr maybeMoveTemp( - LoweredExpr expr) + LegalExpr maybeMoveTemp( + LegalExpr expr) { - if (auto tupleExpr = expr.asTuple()) + switch (expr.getFlavor()) { - RefPtr<TupleExpr> resultExpr = new TupleExpr(); - resultExpr->loc = tupleExpr->loc; - resultExpr->type = tupleExpr->type; - if (tupleExpr->primaryExpr) + case LegalExpr::Flavor::none: + return LegalExpr(); + + case LegalExpr::Flavor::simple: + return LegalExpr(maybeMoveTemp(expr.getSimple())); + + case LegalExpr::Flavor::tuple: { - resultExpr->primaryExpr = maybeMoveTemp(tupleExpr->primaryExpr); + auto tupleExpr = expr.getTuple(); + RefPtr<TuplePseudoExpr> resultExpr = new TuplePseudoExpr(); + resultExpr->loc = tupleExpr->loc; + + for (auto ee : tupleExpr->elements) + { + TuplePseudoExpr::Element element; + element.expr = maybeMoveTemp(ee.expr); + element.fieldDeclRef = ee.fieldDeclRef; + resultExpr->elements.Add(element); + } + + return LegalExpr(resultExpr); } - for (auto ee : tupleExpr->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - TupleExpr::Element elem; - elem.tupleFieldDeclRef = ee.tupleFieldDeclRef; - elem.expr = maybeMoveTemp(ee.expr); + auto pairExpr = expr.getPair(); + RefPtr<PairPseudoExpr> resultExpr = new PairPseudoExpr(); + resultExpr->loc = pairExpr->loc; + resultExpr->pairInfo = pairExpr->pairInfo; - resultExpr->tupleElements.Add(elem); + resultExpr->ordinary = maybeMoveTemp(pairExpr->ordinary); + resultExpr->special = maybeMoveTemp(pairExpr->special); + + return LegalExpr(resultExpr); } + break; - return LoweredExpr(resultExpr); - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - RefPtr<VaryingTupleExpr> resultExpr = new VaryingTupleExpr(); - resultExpr->loc = varyingTupleExpr->loc; - resultExpr->type = varyingTupleExpr->type; - for (auto ee : varyingTupleExpr->elements) + case LegalExpr::Flavor::implicitDeref: { - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = ee.originalFieldDeclRef; - elem.expr = maybeMoveTemp(ee.expr); + auto implicitDerefExpr = expr.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> resultExpr = new ImplicitDerefPseudoExpr(); + resultExpr->loc = implicitDerefExpr->loc; + + resultExpr->valueExpr = maybeMoveTemp(implicitDerefExpr->valueExpr); - resultExpr->elements.Add(elem); + return LegalExpr(resultExpr); } + break; - return LoweredExpr(resultExpr); - } - else - { - return LoweredExpr(maybeMoveTemp(expr.getExpr())); + default: + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(LegalExpr()); } } @@ -1133,8 +1079,8 @@ struct LoweringVisitor return expr; } - LoweredExpr ensureSimpleLValue( - LoweredExpr expr) + LegalExpr ensureSimpleLValue( + LegalExpr expr) { // TODO: actually implement this properly! @@ -1393,197 +1339,139 @@ struct LoweringVisitor } } - LoweredExpr createAssignExpr( - LoweredExpr leftExpr, - LoweredExpr rightExpr, + LegalExpr createAssignExpr( + LegalExpr leftExpr, + LegalExpr rightExpr, AssignMode mode = AssignMode::Default) { - auto leftTuple = leftExpr.asTuple(); - auto rightTuple = rightExpr.asTuple(); - if (leftTuple && rightTuple) + switch (leftExpr.getFlavor()) { - RefPtr<TupleExpr> resultTuple = new TupleExpr(); - resultTuple->type = leftTuple->type; - - if (leftTuple->primaryExpr) - { - SLANG_RELEASE_ASSERT(rightTuple->primaryExpr); - - resultTuple->primaryExpr = createSimpleAssignExpr( - leftTuple->primaryExpr, - rightTuple->primaryExpr, - mode); - } + case LegalExpr::Flavor::none: + return LegalExpr(); - auto elementCount = leftTuple->tupleElements.Count(); - SLANG_RELEASE_ASSERT(elementCount == rightTuple->tupleElements.Count()); - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::simple: + switch (rightExpr.getFlavor()) { - auto leftElement = leftTuple->tupleElements[ee]; - auto rightElement = rightTuple->tupleElements[ee]; + case LegalExpr::Flavor::simple: + return LegalExpr(createSimpleAssignExpr( + leftExpr.getSimple(), + rightExpr.getSimple(), + mode)); - TupleExpr::Element resultElement; - - resultElement.tupleFieldDeclRef = leftElement.tupleFieldDeclRef; - resultElement.expr = createAssignExpr( - leftElement.expr, - rightElement.expr, - mode); + case LegalExpr::Flavor::tuple: + { + auto rightTuple = rightExpr.getTuple(); + RefPtr<TuplePseudoExpr> resultTuple = new TuplePseudoExpr(); + for (auto ee : rightTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createAssignExpr( + extractField(leftExpr, ee.fieldDeclRef), + ee.expr, + mode); + + resultTuple->elements.Add(element); + } + return LegalExpr(resultTuple); + } + break; - resultTuple->tupleElements.Add(resultElement); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); } + break; - return LoweredExpr(resultTuple); - } - else - { - SLANG_RELEASE_ASSERT(!leftTuple && !rightTuple); - } - - auto leftVaryingTuple = leftExpr.asVaryingTuple(); - auto rightVaryingTuple = rightExpr.asVaryingTuple(); - - RefPtr<Expr> leftSimpleExpr = leftExpr.asExpr(); - RefPtr<Expr> rightSimpleExpr = rightExpr.asExpr(); - - if (leftVaryingTuple && rightVaryingTuple) - { - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftVaryingTuple->type.type; - resultTuple->loc = leftVaryingTuple->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = leftVaryingTuple->elements.Count(); - SLANG_RELEASE_ASSERT(elementCount == rightVaryingTuple->elements.Count()); - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::tuple: { - auto leftElem = leftVaryingTuple->elements[ee]; - auto rightElem = rightVaryingTuple->elements[ee]; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = leftElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - leftElem.expr, - rightElem.expr, - mode); - } + rightExpr = maybeMoveTemp(rightExpr); - return LoweredExpr(resultTuple); - } - else if (leftVaryingTuple && rightSimpleExpr) - { - // Assigning from ordinary expression on RHS to tuple. - // This will naturally yield a tuple expression. - - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftVaryingTuple->type.type; - resultTuple->loc = leftVaryingTuple->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = leftVaryingTuple->elements.Count(); - - // Move everything into temps if we can - rightSimpleExpr = maybeMoveTemp(rightSimpleExpr); - for (UInt ee = 0; ee < elementCount; ++ee) - { - auto& leftElem = leftVaryingTuple->elements[ee]; - leftElem.expr = ensureSimpleLValue(leftElem.expr); + auto leftTuple = leftExpr.getTuple(); + RefPtr<TuplePseudoExpr> resultTuple = new TuplePseudoExpr(); + resultTuple->loc = leftTuple->loc; + for (auto ee : leftTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createAssignExpr( + ee.expr, + extractField(rightExpr, ee.fieldDeclRef), + mode); + + resultTuple->elements.Add(element); + } + return LegalExpr(resultTuple); } + break; - // - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::pair: { - auto leftElem = leftVaryingTuple->elements[ee]; - - - RefPtr<MemberExpr> rightElemExpr = new MemberExpr(); - rightElemExpr->loc = rightSimpleExpr->loc; - rightElemExpr->type.type = GetType(leftElem.originalFieldDeclRef); - rightElemExpr->declRef = leftElem.originalFieldDeclRef; - rightElemExpr->name = leftElem.originalFieldDeclRef.GetName(); - rightElemExpr->BaseExpression = rightSimpleExpr; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = leftElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - leftElem.expr, - LoweredExpr(rightElemExpr), - mode); - - resultTuple->elements.Add(elem); - } - - return LoweredExpr(resultTuple); - } - else if (leftSimpleExpr && rightVaryingTuple) - { - // Pretty much the same as the above case, and we should - // probably try to share code eventually. - - - RefPtr<VaryingTupleExpr> resultTuple = new VaryingTupleExpr(); - resultTuple->type.type = leftSimpleExpr->type.type; - resultTuple->loc = leftSimpleExpr->loc; - - SLANG_RELEASE_ASSERT(resultTuple->type.type); - - UInt elementCount = rightVaryingTuple->elements.Count(); + auto leftPair = leftExpr.getPair(); + switch( rightExpr.getFlavor() ) + { + case LegalExpr::Flavor::pair: + { + auto rightPair = rightExpr.getPair(); + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->loc = leftPair->loc; + resultPair->pairInfo = leftPair->pairInfo; + + resultPair->ordinary = createAssignExpr( + leftPair->ordinary, + rightPair->ordinary, + mode); + resultPair->special = createAssignExpr( + leftPair->special, + rightPair->special, + mode); + + return LegalExpr(resultPair); + } + break; - // Move everything into temps if we can - leftSimpleExpr = ensureSimpleLValue(leftSimpleExpr); - for (UInt ee = 0; ee < elementCount; ++ee) - { - auto& rightElem = rightVaryingTuple->elements[ee]; - rightElem.expr = maybeMoveTemp(rightElem.expr); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + } } + break; - - for (UInt ee = 0; ee < elementCount; ++ee) + case LegalExpr::Flavor::implicitDeref: { - auto rightElem = rightVaryingTuple->elements[ee]; - - - RefPtr<MemberExpr> leftElemExpr = new MemberExpr(); - leftElemExpr->loc = leftSimpleExpr->loc; - leftElemExpr->type.type = GetType(rightElem.originalFieldDeclRef); - leftElemExpr->declRef = rightElem.originalFieldDeclRef; - leftElemExpr->name = rightElem.originalFieldDeclRef.GetName(); - leftElemExpr->BaseExpression = leftSimpleExpr; - - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = rightElem.originalFieldDeclRef; - elem.expr = createAssignExpr( - LoweredExpr(leftElemExpr), - rightElem.expr, - mode); + auto leftImplicitDeref = leftExpr.getImplicitDeref(); + switch(rightExpr.getFlavor()) + { + case LegalExpr::Flavor::implicitDeref: + { + auto rightImplicitDeref = rightExpr.getImplicitDeref(); + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->loc = leftImplicitDeref->loc; + resultImplicitDeref->valueExpr = createAssignExpr( + leftImplicitDeref->valueExpr, + rightImplicitDeref->valueExpr, + mode); + + return LegalExpr(resultImplicitDeref); + } - resultTuple->elements.Add(elem); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + } } - return LoweredExpr(resultTuple); - } - else if (leftSimpleExpr && rightSimpleExpr) - { - // Default case: no tuples of any kind... - - return LoweredExpr(createSimpleAssignExpr(leftSimpleExpr, rightSimpleExpr, mode)); - } - else - { - // Some case wasn't handled: diagnose! - SLANG_UNEXPECTED("bad combination of tuple types"); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); } } - LoweredExpr visitAssignExpr( + LegalExpr visitAssignExpr( AssignExpr* expr) { - auto leftExpr = lowerExprOrTuple(expr->left); - auto rightExpr = lowerExprOrTuple(expr->right); + auto leftExpr = legalizeExpr(expr->left); + auto rightExpr = legalizeExpr(expr->right); auto loweredExpr = createAssignExpr(leftExpr, rightExpr); lowerExprCommon(loweredExpr, expr); @@ -1614,10 +1502,77 @@ struct LoweringVisitor return loweredExpr; } - LoweredExpr createSubscriptExpr( - LoweredExpr baseExpr, + LegalExpr createSubscriptExpr( + LegalExpr baseExpr, RefPtr<Expr> indexExpr) { + switch (baseExpr.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); + + case LegalExpr::Flavor::simple: + return LegalExpr(createSimpleSubscriptExpr( + baseExpr.getSimple(), + indexExpr)); + + case LegalExpr::Flavor::tuple: + { + indexExpr = maybeMoveTemp(indexExpr); + + auto baseTuple = baseExpr.getTuple(); + + auto resultTuple = new TuplePseudoExpr(); + resultTuple->loc = baseTuple->loc; + + for (auto ee : baseTuple->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = createSubscriptExpr( + ee.expr, + indexExpr); + + resultTuple->elements.Add(element); + } + + return LegalExpr(resultTuple); + } + break; + + case LegalExpr::Flavor::pair: + { + indexExpr = maybeMoveTemp(indexExpr); + + auto basePair = baseExpr.getPair(); + + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->loc = basePair->loc; + + resultPair->ordinary = createSubscriptExpr(basePair->ordinary, indexExpr); + resultPair->special = createSubscriptExpr(basePair->special, indexExpr); + + return LegalExpr(resultPair); + } + + case LegalExpr::Flavor::implicitDeref: + { + auto baseImplicitDeref = baseExpr.getImplicitDeref(); + + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->loc = baseImplicitDeref->loc; + + resultImplicitDeref->valueExpr = createSubscriptExpr(baseImplicitDeref->valueExpr, indexExpr); + + return LegalExpr(resultImplicitDeref); + } + + default: + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(LegalExpr()); + } + +#if 0 // TODO: This logic ends up duplicating the `indexExpr` // that was given, without worrying about any side // effects it might contain. That needs to be fixed. @@ -1670,71 +1625,30 @@ struct LoweringVisitor } else { - return LoweredExpr(createSimpleSubscriptExpr( + return LegalExpr(createSimpleSubscriptExpr( baseExpr.getExpr(), indexExpr)); } +#endif } - LoweredExpr visitIndexExpr( + LegalExpr visitIndexExpr( IndexExpr* subscriptExpr) { - auto baseExpr = lowerExprOrTuple(subscriptExpr->BaseExpression); - auto indexExpr = lowerExpr(subscriptExpr->IndexExpression); + auto baseExpr = legalizeExpr(subscriptExpr->BaseExpression); + auto indexExpr = legalizeSimpleExpr(subscriptExpr->IndexExpression); - // An attempt to subscript a tuple must be turned into a - // tuple of subscript expressions. - if (auto baseTuple = baseExpr.asTuple()) - { - return createSubscriptExpr(baseExpr, indexExpr); - } - else if (auto baseVaryingTuple = baseExpr.asVaryingTuple()) - { - return createSubscriptExpr(baseExpr, indexExpr); - } - else + if(baseExpr.getFlavor() == LegalExpr::Flavor::simple) { // Default case: just reconstrut a subscript expr RefPtr<IndexExpr> loweredExpr = new IndexExpr(); lowerExprCommon(loweredExpr, subscriptExpr); - loweredExpr->BaseExpression = baseExpr.getExpr(); + loweredExpr->BaseExpression = baseExpr.getSimple(); loweredExpr->IndexExpression = indexExpr; - return LoweredExpr(loweredExpr); - } - } - - RefPtr<Expr> maybeReifyTuple( - LoweredExpr expr) - { - if (auto tupleExpr = expr.asTuple()) - { - // TODO: need to diagnose - return tupleExpr->primaryExpr; - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - // Need to pass an ordinary (non-tuple) expression of - // the corresponding type here. - - // TODO(tfoley): This won't work at all for an `out` or `inout` - // function argument, so we'll need to figure out a plan - // to handle that case... - - RefPtr<AggTypeCtorExpr> resultExpr = new AggTypeCtorExpr(); - resultExpr->type = varyingTupleExpr->type; - resultExpr->base.type = varyingTupleExpr->type.type; - SLANG_RELEASE_ASSERT(resultExpr->type.type); - - for (auto elem : varyingTupleExpr->elements) - { - addArgs(resultExpr, elem.expr); - } - - return resultExpr; + return LegalExpr(loweredExpr); } - // Default case: nothing special to this expression - return expr.getExpr(); + return createSubscriptExpr(baseExpr, indexExpr); } bool needGlslangBug988Workaround( @@ -1833,31 +1747,130 @@ struct LoweringVisitor callExpr->Arguments.Add(argExpr); } + // Take a legalized expression that might be represented as a tuple, + // and turn it back into a single ordinary expression of the given type. + // + // This is used in the case where we tuple-ified a value that has + // a legal type, but just isn't legal to use in a particular context. + static RefPtr<Expr> reifyTuple( + LegalExpr legalExpr, + RefPtr<Type> type) + { + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) + return legalExpr.getSimple(); + + if (auto declRefType = type->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // We want a single value of an aggregate type, which + // means we need to extract each of its fields from + // the expression. + + switch (legalExpr.getFlavor()) + { + case LegalExpr::Flavor::tuple: + { + auto tupleExpr = legalExpr.getTuple(); + + RefPtr<AggTypeCtorExpr> resultExpr = new AggTypeCtorExpr(); + resultExpr->type.type = type; + resultExpr->base.type = type; + SLANG_RELEASE_ASSERT(resultExpr->type.type); + + UInt fieldCounter = 0; + for (auto fieldDeclRef : getMembersOfType<StructField>(aggTypeDeclRef)) + { + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + UInt fieldIndex = fieldCounter++; + + resultExpr->Arguments.Add(reifyTuple( + tupleExpr->elements[fieldIndex].expr, + GetType(fieldDeclRef))); + } + + return resultExpr; + } + break; + } + + } + } + // TODO: need to handle array types here... + + SLANG_UNEXPECTED("unhandled case"); + UNREACHABLE_RETURN(legalExpr.getSimple()); + } + + static LegalExpr maybeReifyTuple( + LegalExpr legalExpr, + LegalType expectedType) + { + if (expectedType.flavor != LegalType::Flavor::simple) + return legalExpr; + + if (legalExpr.getFlavor() == LegalExpr::Flavor::simple) + return legalExpr; + + return LegalExpr(reifyTuple(legalExpr, expectedType.getSimple())); + } + void addArgs( ExprWithArgsBase* callExpr, - LoweredExpr argExpr) + LegalType argType, + LegalExpr argExpr) { - if (auto argTuple = argExpr.asTuple()) + argExpr = maybeReifyTuple(argExpr, argType); + + if (argExpr.getFlavor() != argType.flavor) + { + SLANG_UNEXPECTED("expression and type do not match"); + } + + switch (argExpr.getFlavor()) { - if (argTuple->primaryExpr) + case LegalExpr::Flavor::none: + break; + + case LegalExpr::Flavor::simple: + addArg(callExpr, argExpr.getSimple()); + break; + + case LegalExpr::Flavor::tuple: { - addArg(callExpr, argTuple->primaryExpr); + auto aa = argExpr.getTuple(); + auto at = argType.getTuple(); + auto elementCount = aa->elements.Count(); + for (UInt ee = 0; ee < elementCount; ++ee) + { + addArgs(callExpr, at->elements[ee].type, aa->elements[ee].expr); + } } - for (auto elem : argTuple->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - addArgs(callExpr, elem.expr); + auto aa = argExpr.getPair(); + auto at = argType.getPair(); + addArgs(callExpr, at->ordinaryType, aa->ordinary); + addArgs(callExpr, at->specialType, aa->special); } - } - else if (auto varyingArgTuple = argExpr.asVaryingTuple()) - { - // Need to pass an ordinary (non-tuple) expression of - // the corresponding type here. + break; - callExpr->Arguments.Add(maybeReifyTuple(argExpr)); - } - else - { - addArg(callExpr, argExpr.getExpr()); + case LegalExpr::Flavor::implicitDeref: + { + auto aa = argExpr.getImplicitDeref(); + auto at = argType.getImplicitDeref(); + addArgs(callExpr, at->valueType, aa->valueExpr); + } + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; } } @@ -1867,38 +1880,51 @@ struct LoweringVisitor { lowerExprCommon(loweredExpr, expr); - loweredExpr->FunctionExpr = lowerExpr(expr->FunctionExpr); + loweredExpr->FunctionExpr = legalizeSimpleExpr(expr->FunctionExpr); for (auto arg : expr->Arguments) { - auto loweredArg = lowerExprOrTuple(arg); - addArgs(loweredExpr, loweredArg); + auto argType = lowerAndLegalizeType(arg->type.type); + auto loweredArg = legalizeExpr(arg); + addArgs(loweredExpr, argType, loweredArg); } return loweredExpr; } - LoweredExpr visitInvokeExpr( + LegalExpr visitInvokeExpr( InvokeExpr* expr) { // Create a clone with the same class InvokeExpr* loweredExpr = (InvokeExpr*) expr->getClass().createInstance(); - return LoweredExpr(lowerCallExpr(loweredExpr, expr)); + return LegalExpr(lowerCallExpr(loweredExpr, expr)); } - LoweredExpr visitSelectExpr( + LegalExpr visitSelectExpr( SelectExpr* expr) { // TODO: A tuple needs to be special-cased here - return LoweredExpr(lowerCallExpr(new SelectExpr(), expr)); + return LegalExpr(lowerCallExpr(new SelectExpr(), expr)); } - LoweredExpr visitDerefExpr( + LegalExpr visitDerefExpr( DerefExpr* expr) { - auto loweredBase = lowerExprOrTuple(expr->base); + auto legalBase = legalizeExpr(expr->base); + if (legalBase.getFlavor() == LegalExpr::Flavor::simple) + { + // Default case is just to lower a dereference opertion + // into another dereference. + RefPtr<DerefExpr> loweredExpr = new DerefExpr(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->base = legalBase.getSimple(); + return LegalExpr(loweredExpr); + } + + return deref(legalBase); +#if 0 if (auto baseTuple = loweredBase.asTuple()) { // In the case of a tuple created for "resources in structs" reasons, @@ -1938,13 +1964,7 @@ struct LoweringVisitor // // TODO: implement this. } - - // Default case is just to lower a dereference opertion - // into another dereference. - RefPtr<DerefExpr> loweredExpr = new DerefExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->base = loweredBase.getExpr(); - return LoweredExpr(loweredExpr); +#endif } DiagnosticSink* getSink() @@ -1952,108 +1972,43 @@ struct LoweringVisitor return &shared->compileRequest->mSink; } - LoweredExpr visitStaticMemberExpr( + LegalExpr visitStaticMemberExpr( StaticMemberExpr* expr) { - auto loweredBase = lowerExprOrTuple(expr->BaseExpression); + auto loweredBase = legalizeExpr(expr->BaseExpression); auto loweredDeclRef = translateDeclRef(expr->declRef); // TODO: we should probably support type-type members here. RefPtr<StaticMemberExpr> loweredExpr = new StaticMemberExpr(); lowerExprCommon(loweredExpr, expr); - loweredExpr->BaseExpression = loweredBase.getExpr(); + loweredExpr->BaseExpression = loweredBase.getSimple(); loweredExpr->declRef = loweredDeclRef.As<Decl>(); loweredExpr->name = expr->name; - return LoweredExpr(loweredExpr); + return LegalExpr(loweredExpr); } - LoweredExpr visitMemberExpr( + LegalExpr visitMemberExpr( MemberExpr* expr) { assert(expr->BaseExpression); - auto loweredBase = lowerExprOrTuple(expr->BaseExpression); - assert(loweredBase); - - auto loweredDeclRef = translateDeclRef(expr->declRef); - - - // Are we extracting an element from a tuple? - if (auto baseTuple = loweredBase.asTuple()) - { - auto loweredFieldDecl = loweredDeclRef.As<Decl>().getDecl(); - auto tupleFieldMod = loweredFieldDecl->FindModifier<TupleFieldModifier>(); - if (tupleFieldMod) - { - // This field has a tuple part to it, so we need to search for it - - LoweredExpr tupleFieldExpr; - for (auto elem : baseTuple->tupleElements) - { - if (loweredFieldDecl == elem.tupleFieldDeclRef.getDecl()) - { - tupleFieldExpr = elem.expr; - break; - } - } - - if (!tupleFieldMod->hasAnyNonTupleFields) - { - // We need to have found something! - assert(tupleFieldExpr); - return tupleFieldExpr; - } - - auto tupleFieldTupleExpr = tupleFieldExpr.asTuple(); - SLANG_RELEASE_ASSERT(tupleFieldTupleExpr); - SLANG_RELEASE_ASSERT(!tupleFieldTupleExpr->primaryExpr); - - - RefPtr<MemberExpr> loweredPrimaryExpr = new MemberExpr(); - lowerExprCommon(loweredPrimaryExpr, expr); - loweredPrimaryExpr->BaseExpression = baseTuple->primaryExpr; - loweredPrimaryExpr->declRef = loweredDeclRef.As<Decl>(); - loweredPrimaryExpr->name = expr->name; - - assert(loweredPrimaryExpr->BaseExpression); + auto legalBase = legalizeExpr(expr->BaseExpression); + assert(legalBase); - tupleFieldTupleExpr->primaryExpr = loweredPrimaryExpr; - return tupleFieldTupleExpr; - } - - // If the field was a non-tuple field, then we can - // simply fall through to the ordinary case below. - loweredBase = LoweredExpr(baseTuple->primaryExpr); - assert(baseTuple->primaryExpr); - } - else if (auto baseVaryingTuple = loweredBase.asVaryingTuple()) + if (legalBase.getFlavor() == LegalExpr::Flavor::simple) { - // Search for the element corresponding to this field - for(auto elem : baseVaryingTuple->elements) - { - if (expr->declRef.getDecl() == elem.originalFieldDeclRef.getDecl()) - { - // We found the field! - assert(elem.expr); - return elem.expr; - } - } - - SLANG_DIAGNOSE_UNEXPECTED(getSink(), expr, "failed to find tuple field during lowering"); + // Default handling: + RefPtr<MemberExpr> loweredExpr = new MemberExpr(); + lowerExprCommon(loweredExpr, expr); + loweredExpr->BaseExpression = legalBase.getSimple(); + loweredExpr->declRef = translateDeclRef(expr->declRef); + loweredExpr->name = expr->name; + assert(loweredExpr->BaseExpression); + return LegalExpr(loweredExpr); } - // Default handling: - - RefPtr<MemberExpr> loweredExpr = new MemberExpr(); - lowerExprCommon(loweredExpr, expr); - loweredExpr->BaseExpression = loweredBase.getExpr(); - loweredExpr->declRef = loweredDeclRef.As<Decl>(); - loweredExpr->name = expr->name; - - assert(loweredExpr->BaseExpression); - - return LoweredExpr(loweredExpr); + return extractField(legalBase, expr->declRef.As<VarDeclBase>()); } // @@ -2123,18 +2078,18 @@ struct LoweringVisitor StmtVisitor::dispatch(stmt); } - LoweredDecl visitScopeDecl(ScopeDecl* decl) + RefPtr<Decl> visitScopeDecl(ScopeDecl* decl) { RefPtr<ScopeDecl> loweredDecl = new ScopeDecl(); lowerDeclCommon(loweredDecl, decl); - return LoweredDecl(loweredDecl); + return loweredDecl; } LoweringVisitor pushScope( RefPtr<ScopeStmt> loweredStmt, RefPtr<ScopeStmt> originalStmt) { - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl)->As<ScopeDecl>(); LoweringVisitor subVisitor = *this; subVisitor.isBuildingStmt = true; @@ -2215,32 +2170,46 @@ struct LoweringVisitor } void addExprStmt( - LoweredExpr expr) + LegalExpr expr) { // Desugar tuples in statement position - if (auto tupleExpr = expr.asTuple()) + switch (expr.getFlavor()) { - if (tupleExpr->primaryExpr) + case LegalExpr::Flavor::none: + break; + + case LegalExpr::Flavor::simple: + addSimpleExprStmt(expr.getSimple()); + break; + + case LegalExpr::Flavor::tuple: { - addSimpleExprStmt(tupleExpr->primaryExpr); + auto tupleExpr = expr.getTuple(); + for (auto ee : tupleExpr->elements) + { + addExprStmt(ee.expr); + } } - for (auto ee : tupleExpr->tupleElements) + break; + + case LegalExpr::Flavor::pair: { - addExprStmt(ee.expr); + auto pairExpr = expr.getPair(); + addExprStmt(pairExpr->ordinary); + addExprStmt(pairExpr->special); } - return; - } - else if (auto varyingTupleExpr = expr.asVaryingTuple()) - { - for (auto ee : varyingTupleExpr->elements) + break; + + case LegalExpr::Flavor::implicitDeref: { - addExprStmt(ee.expr); + auto implicitDerefExpr = expr.getImplicitDeref(); + addExprStmt(implicitDerefExpr->valueExpr); } - return; - } - else - { - addSimpleExprStmt(expr.getExpr()); + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; } } @@ -2266,7 +2235,7 @@ struct LoweringVisitor void visitExpressionStmt(ExpressionStmt* stmt) { - addExprStmt(lowerExprOrTuple(stmt->Expression)); + addExprStmt(legalizeExpr(stmt->Expression)); } void visitDeclStmt(DeclStmt* stmt) @@ -2297,7 +2266,7 @@ struct LoweringVisitor ScopeStmt* originalStmt) { lowerStmtFields(loweredStmt, originalStmt); - loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl).getDecl()->As<ScopeDecl>(); + loweredStmt->scopeDecl = translateDeclRef(originalStmt->scopeDecl)->As<ScopeDecl>(); } // Child statements reference their parent statement, @@ -2361,7 +2330,7 @@ struct LoweringVisitor RefPtr<CaseStmt> loweredStmt = new CaseStmt(); lowerChildStmtFields(loweredStmt, stmt); - loweredStmt->expr = lowerExpr(stmt->expr); + loweredStmt->expr = legalizeSimpleExpr(stmt->expr); addStmt(loweredStmt); } @@ -2371,7 +2340,7 @@ struct LoweringVisitor RefPtr<IfStmt> loweredStmt = new IfStmt(); lowerStmtFields(loweredStmt, stmt); - loweredStmt->Predicate = lowerExpr(stmt->Predicate); + loweredStmt->Predicate = legalizeSimpleExpr(stmt->Predicate); loweredStmt->PositiveStatement = lowerStmt(stmt->PositiveStatement); loweredStmt->NegativeStatement = lowerStmt(stmt->NegativeStatement); @@ -2385,7 +2354,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); - loweredStmt->condition = subVisitor.lowerExpr(stmt->condition); + loweredStmt->condition = subVisitor.legalizeSimpleExpr(stmt->condition); loweredStmt->body = subVisitor.lowerStmt(stmt->body); addStmt(loweredStmt); @@ -2400,8 +2369,8 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); loweredStmt->InitialStatement = subVisitor.lowerStmt(stmt->InitialStatement); - loweredStmt->SideEffectExpression = subVisitor.lowerExpr(stmt->SideEffectExpression); - loweredStmt->PredicateExpression = subVisitor.lowerExpr(stmt->PredicateExpression); + loweredStmt->SideEffectExpression = subVisitor.legalizeSimpleExpr(stmt->SideEffectExpression); + loweredStmt->PredicateExpression = subVisitor.legalizeSimpleExpr(stmt->PredicateExpression); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); addStmt(loweredStmt); @@ -2431,8 +2400,9 @@ struct LoweringVisitor return; auto varDecl = stmt->varDecl; + shared->mapOriginalDeclToLowered[varDecl] = nullptr; - auto varType = lowerType(varDecl->type); + auto varType = lowerTypeExprEx(varDecl->type); for (IntegerLiteralValue ii = rangeBeginVal; ii < rangeEndVal; ++ii) { @@ -2441,12 +2411,7 @@ struct LoweringVisitor constExpr->ConstType = ConstantExpr::ConstantType::Int; constExpr->integerValue = ii; - RefPtr<VaryingTupleVarDecl> loweredVarDecl = new VaryingTupleVarDecl(); - loweredVarDecl->loc = varDecl->loc; - loweredVarDecl->type = varType; - loweredVarDecl->expr = LoweredExpr(constExpr); - - shared->loweredDecls[varDecl] = LoweredDecl(loweredVarDecl); + shared->mapOriginalDeclToExpr[varDecl] = LegalExpr(constExpr); lowerStmtImpl(stmt->body); } @@ -2459,7 +2424,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); - loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + loweredStmt->Predicate = subVisitor.legalizeSimpleExpr(stmt->Predicate); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); addStmt(loweredStmt); @@ -2473,7 +2438,7 @@ struct LoweringVisitor LoweringVisitor subVisitor = pushScope(loweredStmt, stmt); loweredStmt->Statement = subVisitor.lowerStmt(stmt->Statement); - loweredStmt->Predicate = subVisitor.lowerExpr(stmt->Predicate); + loweredStmt->Predicate = subVisitor.legalizeSimpleExpr(stmt->Predicate); addStmt(loweredStmt); } @@ -2489,22 +2454,22 @@ struct LoweringVisitor } void assign( - LoweredExpr destExpr, - LoweredExpr srcExpr, + LegalExpr destExpr, + LegalExpr srcExpr, AssignMode mode = AssignMode::Default) { auto assignExpr = createAssignExpr(destExpr, srcExpr, mode); addExprStmt(assignExpr); } - void assign(VarDeclBase* varDecl, LoweredExpr expr) + void assign(VarDeclBase* varDecl, LegalExpr expr) { - assign(LoweredExpr(createVarRef(getPosition(expr), varDecl)), expr); + assign(LegalExpr(createVarRef(getPosition(expr), varDecl)), expr); } - void assign(LoweredExpr expr, VarDeclBase* varDecl) + void assign(LegalExpr expr, VarDeclBase* varDecl) { - assign(expr, LoweredExpr(createVarRef(getPosition(expr), varDecl))); + assign(expr, LegalExpr(createVarRef(getPosition(expr), varDecl))); } RefPtr<Expr> createTypeExpr( @@ -2537,20 +2502,20 @@ struct LoweringVisitor // where the types don't actually line up, because of // differences in how something is declared in HLSL vs. GLSL void assignWithFixups( - LoweredExpr destExpr, - LoweredExpr srcExpr) + LegalExpr destExpr, + LegalExpr srcExpr) { assign(destExpr, srcExpr, AssignMode::WithFixups); } - void assignWithFixups(VarDeclBase* varDecl, LoweredExpr expr) + void assignWithFixups(VarDeclBase* varDecl, LegalExpr expr) { - assignWithFixups(LoweredExpr(createVarRef(getPosition(expr), varDecl)), expr); + assignWithFixups(LegalExpr(createVarRef(getPosition(expr), varDecl)), expr); } - void assignWithFixups(LoweredExpr expr, VarDeclBase* varDecl) + void assignWithFixups(LegalExpr expr, VarDeclBase* varDecl) { - assignWithFixups(expr, LoweredExpr(createVarRef(getPosition(expr), varDecl))); + assignWithFixups(expr, LegalExpr(createVarRef(getPosition(expr), varDecl))); } void visitReturnStmt(ReturnStmt* stmt) @@ -2563,12 +2528,12 @@ struct LoweringVisitor if (resultVariable) { // Do it as an assignment - assign(resultVariable, lowerExprOrTuple(stmt->Expression)); + assign(resultVariable, legalizeExpr(stmt->Expression)); } else { // Simple case - loweredStmt->Expression = lowerExpr(stmt->Expression); + loweredStmt->Expression = legalizeSimpleExpr(stmt->Expression); } } @@ -2582,7 +2547,7 @@ struct LoweringVisitor RefPtr<Val> translateVal(Val* val) { if (auto type = dynamic_cast<Type*>(val)) - return lowerType(type); + return lowerTypeEx(type); if (auto litVal = dynamic_cast<ConstantIntVal*>(val)) return val; @@ -2597,7 +2562,7 @@ struct LoweringVisitor if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions)) { RefPtr<GenericSubstitution> result = new GenericSubstitution(); - result->genericDecl = translateDeclRef(genSubst->genericDecl).getDecl()->As<GenericDecl>(); + result->genericDecl = translateDeclRef(genSubst->genericDecl)->As<GenericDecl>(); for (auto arg : genSubst->args) { result->args.Add(translateVal(arg)); @@ -2622,38 +2587,26 @@ struct LoweringVisitor return decl; } - LoweredDeclRef translateDeclRef( + DeclRef<Decl> translateDeclRef( DeclRef<Decl> const& declRef) { - LoweredDeclRef result; + DeclRef<Decl> result; result.decl = translateDeclRefImpl(declRef); result.substitutions = translateSubstitutions(declRef.substitutions); return result; } - LoweredDecl translateDeclRef( + RefPtr<Decl> translateDeclRef( Decl* decl) { return translateDeclRefImpl(DeclRef<Decl>(decl, nullptr)); } - // Try to find the module that (recursively) contains a given declaration. - ModuleDecl* findModuleForDecl( - Decl* decl) - { - for (auto dd = decl; dd; dd = dd->ParentDecl) - { - if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) - return moduleDecl; - } - return nullptr; - } - - LoweredDecl translateDeclRefImpl( + RefPtr<Decl> translateDeclRefImpl( DeclRef<Decl> declRef) { Decl* decl = declRef.getDecl(); - if (!decl) return LoweredDecl(); + if (!decl) return nullptr; // We don't want to translate references to built-in declarations, // since they won't be subtituted anyway. @@ -2703,8 +2656,17 @@ struct LoweringVisitor } } - LoweredDecl loweredDecl; - if (shared->loweredDecls.TryGetValue(decl, loweredDecl)) + if (getModifiedDecl(decl)->HasModifier<LegalizedModifier>()) + { + // We are trying to translate a reference to a declaration + // that was created by the type legalization process. The + // target declaration should already be placed inside of + // the output module. + return decl; + } + + RefPtr<Decl> loweredDecl; + if (shared->mapOriginalDeclToLowered.TryGetValue(decl, loweredDecl)) return loweredDecl; // Time to force it @@ -2717,7 +2679,7 @@ struct LoweringVisitor return translateDeclRef(declRef).As<ContainerDecl>(); } - LoweredDecl lowerDeclBase( + RefPtr<Decl> lowerDeclBase( DeclBase* declBase) { if (Decl* decl = dynamic_cast<Decl*>(declBase)) @@ -2730,10 +2692,10 @@ struct LoweringVisitor } } - LoweredDecl lowerDecl( + RefPtr<Decl> lowerDecl( Decl* decl) { - LoweredDecl loweredDecl = DeclVisitor::dispatch(decl); + RefPtr<Decl> loweredDecl = DeclVisitor::dispatch(decl); return loweredDecl; } @@ -2768,11 +2730,10 @@ struct LoweringVisitor addMember(parentDecl, decl); } - void registerLoweredDecl(LoweredDecl loweredDecl, Decl* decl) + void registerLoweredDecl(RefPtr<Decl> loweredDecl, Decl* decl) { - shared->loweredDecls.Add(decl, loweredDecl); - - shared->mapLoweredDeclToOriginal.Add(loweredDecl.getValue(), decl); + shared->mapOriginalDeclToLowered.Add(decl, loweredDecl); + shared->mapLoweredDeclToOriginal.Add(loweredDecl.Ptr(), decl); } // If the name of the declarations collides with a reserved word @@ -2821,9 +2782,9 @@ struct LoweringVisitor { RefPtr<Decl> loweredParent; if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>()) - loweredParent = translateDeclRef(genericParentDecl->ParentDecl).getDecl(); + loweredParent = translateDeclRef(genericParentDecl->ParentDecl); else - loweredParent = translateDeclRef(decl->ParentDecl).getDecl(); + loweredParent = translateDeclRef(decl->ParentDecl); if (loweredParent) { auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); @@ -2874,89 +2835,92 @@ struct LoweringVisitor // Catch-all - LoweredDecl visitSyntaxDecl(SyntaxDecl*) + RefPtr<Decl> visitSyntaxDecl(SyntaxDecl*) { - return LoweredDecl(); + return nullptr; } - LoweredDecl visitGenericValueParamDecl(GenericValueParamDecl*) + RefPtr<Decl> visitGenericValueParamDecl(GenericValueParamDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericTypeParamDecl(GenericTypeParamDecl*) + RefPtr<Decl> visitGenericTypeParamDecl(GenericTypeParamDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) + RefPtr<Decl> visitGenericTypeConstraintDecl(GenericTypeConstraintDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitGenericDecl(GenericDecl*) + RefPtr<Decl> visitGenericDecl(GenericDecl*) { SLANG_UNEXPECTED("generics should be lowered to specialized decls"); } - LoweredDecl visitModuleDecl(ModuleDecl*) + RefPtr<Decl> visitModuleDecl(ModuleDecl*) { SLANG_UNEXPECTED("module decls should be lowered explicitly"); } - LoweredDecl visitSubscriptDecl(SubscriptDecl*) + RefPtr<Decl> visitSubscriptDecl(SubscriptDecl*) { // We don't expect to find direct references to a subscript // declaration, but rather to the underlying accessors - return LoweredDecl(); + return nullptr; } - LoweredDecl visitInheritanceDecl(InheritanceDecl*) + RefPtr<Decl> visitInheritanceDecl(InheritanceDecl*) { // We should deal with these explicitly, as part of lowering // the type that contains them. - return LoweredDecl(); + return nullptr; } - LoweredDecl visitExtensionDecl(ExtensionDecl*) + RefPtr<Decl> visitExtensionDecl(ExtensionDecl*) { // Extensions won't exist in the lowered code: their members // will turn into ordinary functions that get called explicitly - return LoweredDecl(); + return nullptr; } - LoweredDecl visitAssocTypeDecl(AssocTypeDecl * /*assocType*/) + RefPtr<Decl> visitAssocTypeDecl(AssocTypeDecl * /*assocType*/) { // not supported SLANG_UNREACHABLE("visitAssocTypeDecl in LowerVisitor"); - UNREACHABLE_RETURN(LoweredDecl()); + UNREACHABLE_RETURN(nullptr); } - LoweredDecl visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/) + RefPtr<Decl> visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/) { // not supported SLANG_UNREACHABLE("visitGlobalGenericParamDecl in LowerVisitor"); - UNREACHABLE_RETURN(LoweredDecl()); + UNREACHABLE_RETURN(nullptr); } - LoweredDecl visitTypeDefDecl(TypeDefDecl* decl) + RefPtr<Decl> visitTypeDefDecl(TypeDefDecl* decl) { if (shared->target == CodeGenTarget::GLSL) { // GLSL does not support `typedef`, so we will lower it out of existence here - return LoweredDecl(); + return nullptr; } RefPtr<TypeDefDecl> loweredDecl = new TypeDefDecl(); lowerDeclCommon(loweredDecl, decl); - loweredDecl->type = lowerType(decl->type); + // TODO: Need to handle the case where we `typedef` an aggregate + // type that needs to be legalized; in that case we should desugar + // the `typedef` out of existence. + loweredDecl->type = lowerTypeExprEx(decl->type); addMember(shared->loweredProgram, loweredDecl); - return LoweredDecl(loweredDecl); + return loweredDecl; } - LoweredDecl visitImportDecl(ImportDecl*) + RefPtr<Decl> visitImportDecl(ImportDecl*) { // We could unconditionally output the declarations in the // imported code, but this could cause problems if any @@ -2975,10 +2939,10 @@ struct LoweringVisitor // Don't actually include a representation of // the import declaration in the output - return LoweredDecl(); + return nullptr; } - LoweredDecl visitEmptyDecl(EmptyDecl* decl) + RefPtr<Decl> visitEmptyDecl(EmptyDecl* decl) { // Empty declarations are really only useful in GLSL, // where they are used to hold metadata that doesn't @@ -2992,20 +2956,7 @@ struct LoweringVisitor addDecl(loweredDecl); - return LoweredDecl(loweredDecl); - } - - TupleTypeModifier* isTupleType(Type* type) - { - if (auto declRefType = type->As<DeclRefType>()) - { - if (auto tupleTypeMod = declRefType->declRef.getDecl()->FindModifier<TupleTypeModifier>()) - { - return tupleTypeMod; - } - } - - return nullptr; + return loweredDecl; } Type* unwrapArray(Type* inType) @@ -3018,164 +2969,66 @@ struct LoweringVisitor return type; } - TupleTypeModifier* isTupleTypeOrArrayOfTupleType(Type* type) + RefPtr<Decl> visitAggTypeDecl(AggTypeDecl* decl) { - return isTupleType(unwrapArray(type)); - } - - bool isResourceType(Type* type) - { - while (auto arrayType = type->As<ArrayExpressionType>()) - { - type = arrayType->baseType; - } - - if (auto textureTypeBase = type->As<TextureTypeBase>()) - { - return true; - } - else if (auto samplerType = type->As<SamplerStateType>()) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; - } - - LoweredDecl visitAggTypeDecl(AggTypeDecl* decl) - { - // We want to lower any aggregate type declaration - // to just a `struct` type that contains its fields. + // An aggregate type declaration might get "legalized away" + // and result in a new type declaration created by the + // type legalization logic. If that happens, we don't want + // the original type declaration to appear in the output. // - // Any non-field members (e.g., methods) will be - // lowered separately. + // If the result *doesn't* get legalized away, though, we + // need to try to reproduce this declaration as it originally + // appeared. - RefPtr<StructDecl> loweredDecl = new StructDecl(); - lowerDeclCommon(loweredDecl, decl); - - // We need to be ready to turn this type into a "tuple" type, - // if it has any members that can't normally be kept in a `struct` + // We start by creating a type to reference this declaration, + // and then we will try to legalize that. // - // We don't want to do this unconditionally, though, because - // then we'll end up changing the meaning of user code in - // languages like HLSL that support such nesting. + // Note: This logic shouldn't need to defend against generic + // types, since it won't get applied to Slang code that might + // include generics (just HLSL/GLSL code). + RefPtr<DeclRefType> declRefType = DeclRefType::Create( + getSession(), + makeDeclRef(decl)); + DeclRef<Decl> declRef = declRefType->declRef; - bool shouldDesugarTupleTypes = false; - if (getTarget() == CodeGenTarget::GLSL) - { - // Always desugar this stuff for GLSL, since it doesn't - // support nesting of resources in structs. - // - // TODO: Need a way to make this more fine-grained to - // handle cases where a nested member might be allowed - // due to, e.g., bindless textures. - shouldDesugarTupleTypes = true; - } - else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) + LegalType legalType = legalizeType(getTypeLegalizationContext(), declRefType); + if(legalType.flavor != LegalType::Flavor::simple) { - // If the user is directly asking us to do this transformation, - // then obviously we need to do it. + // Something happened to this type during legalization, so + // we don't want to let its declaration appear in the output. // - // TODO: The way this is defined here means it will even apply to user - // HLSL code (not just code written in Slang). We may want to - // reconsider that choice, and only split things that originated in Slang. + // However, we need to ensure that when declaration references + // that might reference this declaration get constructed (e.g., + // this might be the `T` in a `ConstantBuffer<T>`, we have something + // to stick in there. // - shouldDesugarTupleTypes = true; + // For now we'll use the original declaration and hope for the best. + return decl; } - bool isResultATupleType = false; - bool hasAnyNonTupleFields = false; + // if we get this far, then we want to produce an "equivalent" + // aggregate type declaration to what the user wrote. + + RefPtr<StructDecl> loweredDecl = new StructDecl(); + lowerDeclCommon(loweredDecl, decl); for (auto field : decl->getMembersOfType<VarDeclBase>()) { // We lower the field, which will involve lowering the field type - auto loweredField = translateDeclRef(field).getDecl()->As<VarDeclBase>(); + auto loweredField = translateDeclRef(field)->As<VarDeclBase>(); // Add the field to the result declaration addMember(loweredDecl, loweredField); - - // Don't consider any of the following desugaring logic, - // if we aren't supposed to be desugaring this type - if (!shouldDesugarTupleTypes) - { - hasAnyNonTupleFields = true; - continue; - } - - - // If the field is of a type that requires special handling, - // we need to make a note of it. - auto loweredFieldType = loweredField->type.type; - bool isTupleField = false; - bool fieldHasAnyNonTupleFields = false; - bool fieldHasTupleType = false; - if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredFieldType)) - { - isTupleField = true; - fieldHasTupleType = true; - if (fieldTupleTypeMod->hasAnyNonTupleFields) - { - fieldHasAnyNonTupleFields = true; - hasAnyNonTupleFields = true; - } - } - else if (isResourceType(loweredFieldType)) - { - isTupleField = true; - } - else - { - hasAnyNonTupleFields = true; - } - - if (isTupleField) - { - isResultATupleType = true; - - RefPtr<TupleFieldModifier> tupleFieldMod = new TupleFieldModifier(); - tupleFieldMod->decl = loweredField; - tupleFieldMod->hasAnyNonTupleFields = fieldHasAnyNonTupleFields; - tupleFieldMod->isNestedTuple = fieldHasTupleType; - - addModifier(loweredField, tupleFieldMod); - } } - // An empty `struct` must be treated as a tuple type, - // in order to ensure that we don't mess up layout logic - // - // (Also, GLSL doesn't allow empty structs IIRC) - // - // Note: in this one case we are desugaring things even - // when targetting HLSL, just to keep things manageable. - if (!hasAnyNonTupleFields) - { - isResultATupleType = true; - } + // TODO: we should really be copying over *all* the members, + // in the case where this is a user-authored type. - if (isResultATupleType) - { - RefPtr<TupleTypeModifier> tupleTypeMod = new TupleTypeModifier(); - tupleTypeMod->decl = loweredDecl; - tupleTypeMod->hasAnyNonTupleFields = hasAnyNonTupleFields; - addModifier(loweredDecl, tupleTypeMod); - } + addMember( + shared->loweredProgram, + loweredDecl); - if (isResultATupleType && !hasAnyNonTupleFields) - { - // We don't want any pure-tuple types showing up in - // the output program, so we skip that here. - } - else - { - addMember( - shared->loweredProgram, - loweredDecl); - } - - return LoweredDecl(loweredDecl); + return loweredDecl; } RefPtr<VarDeclBase> lowerSimpleVarDeclCommon( @@ -3186,7 +3039,7 @@ struct LoweringVisitor lowerDeclCommon(loweredDecl, decl); loweredDecl->type = loweredType; - loweredDecl->initExpr = lowerExpr(decl->initExpr); + loweredDecl->initExpr = legalizeSimpleExpr(decl->initExpr); return loweredDecl; } @@ -3195,380 +3048,435 @@ struct LoweringVisitor RefPtr<VarDeclBase> loweredDecl, VarDeclBase* decl) { - auto loweredType = lowerType(decl->type); + auto loweredType = lowerTypeExprEx(decl->type); return lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType); } - struct TupleTypeSecondaryVarArraySpec + RefPtr<StructTypeLayout> getBodyStructTypeLayout(RefPtr<TypeLayout> typeLayout) { - TupleTypeSecondaryVarArraySpec* next; - RefPtr<IntVal> elementCount; - }; + if (!typeLayout) + return nullptr; - struct TupleSecondaryVarInfo - { - // Parent tuple decl to add the secondary decl into - RefPtr<TupleVarDecl> tupleDecl; + while (auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>()) + { + typeLayout = parameterGroupTypeLayout->elementTypeLayout; + } - // Syntax class for declarations to create - SyntaxClass<VarDeclBase> varDeclClass; + while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } - // name "stem" to use for any actual variables we create - String name; + if (auto structTypeLayout = typeLayout.As<StructTypeLayout>()) + { + return structTypeLayout; + } - // The parent tuple type (or array thereof) we are scalarizing - RefPtr<Type> tupleType; + return nullptr; + } - // The actual declaration of the tuple type (which will give us the fields) - DeclRef<AggTypeDecl> tupleTypeDecl; + LegalExpr deref( + LegalExpr base) + { + switch (base.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); - // An initializer expression to use for the tuple members - RefPtr<Expr> initExpr; + case LegalExpr::Flavor::simple: + { + auto simpleBase = base.getSimple(); - // The original layout given to the top-level variable - RefPtr<VarLayout> primaryVarLayout; + RefPtr<DerefExpr> resultExpr = new DerefExpr(); + // TODO: need to fill in a type here? + resultExpr->base = simpleBase; + return LegalExpr(resultExpr); + } + break; - // The computed layout of the tuple type itself - RefPtr<StructTypeLayout> tupleTypeLayout; + case LegalExpr::Flavor::tuple: + { + auto tupleExpr = base.getTuple(); + RefPtr<TuplePseudoExpr> resultExpr = new TuplePseudoExpr(); - TupleTypeSecondaryVarArraySpec* arraySpecs = nullptr; - }; + for (auto ee : tupleExpr->elements) + { + TuplePseudoExpr::Element element; + element.fieldDeclRef = ee.fieldDeclRef; + element.expr = deref(ee.expr); - void createTupleTypeSecondaryVarDecls( - TupleSecondaryVarInfo const& info) - { - if (auto arrayType = info.tupleType->As<ArrayExpressionType>()) - { - TupleTypeSecondaryVarArraySpec arraySpec; - arraySpec.next = info.arraySpecs; - arraySpec.elementCount = arrayType->ArrayLength; + resultExpr->elements.Add(element); + } - TupleSecondaryVarInfo subInfo = info; - subInfo.tupleType = arrayType->baseType; - subInfo.arraySpecs = &arraySpec; - createTupleTypeSecondaryVarDecls(subInfo); - return; - } + return LegalExpr(resultExpr); + } + break; - // Next, we need to go through the declarations in the aggregate - // type, and deal with all of those that should be tuple-ified. - for (auto dd : getMembersOfType<VarDeclBase>(info.tupleTypeDecl)) - { - if (dd.getDecl()->HasModifier<HLSLStaticModifier>()) - continue; + case LegalExpr::Flavor::pair: + { + auto basePair = base.getPair(); + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->pairInfo = basePair->pairInfo; - auto tupleFieldMod = dd.getDecl()->FindModifier<TupleFieldModifier>(); - if (!tupleFieldMod) - continue; + resultPair->ordinary = deref(basePair->ordinary); + resultPair->special = deref(basePair->special); - // TODO: need to extract the initializer for this field - SLANG_RELEASE_ASSERT(!info.initExpr); - RefPtr<Expr> fieldInitExpr; + return LegalExpr(resultPair); + } - String fieldName = info.name + "_" + getText(dd.GetName()); + case LegalExpr::Flavor::implicitDeref: + { + auto implicitDerefExpr = base.getImplicitDeref(); + return implicitDerefExpr->valueExpr; + } + break; - auto fieldType = GetType(dd); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + break; + } + } - Decl* originalFieldDecl; - shared->mapLoweredDeclToOriginal.TryGetValue(dd, originalFieldDecl); - SLANG_RELEASE_ASSERT(originalFieldDecl); + LegalExpr extractField( + LegalExpr base, + DeclRef<VarDeclBase> fieldDeclRef) + { + switch (base.getFlavor()) + { + case LegalExpr::Flavor::none: + return LegalExpr(); - RefPtr<VarLayout> fieldLayout; - if(info.tupleTypeLayout) + case LegalExpr::Flavor::simple: { - info.tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + auto simpleBase = base.getSimple(); + + RefPtr<MemberExpr> resultExpr = new MemberExpr(); + resultExpr->BaseExpression = simpleBase; + resultExpr->type.type = GetType(fieldDeclRef); + resultExpr->declRef = translateDeclRef(fieldDeclRef.As<Decl>()); + resultExpr->name = fieldDeclRef.GetName(); + return LegalExpr(resultExpr); } - if (fieldLayout && info.primaryVarLayout) + break; + + case LegalExpr::Flavor::tuple: { - // The layout for a field may need to be adjusted - // based on a base offset stored in the primary - // variable. - // - // For example, if the primary variable was recoreded - // to start at descriptor-table slot N, then the - // field layout might say it uses slot k, but that - // needs to be understood relative to the parent, - // so we want slot N + k... and actuall N + k + 1, - // in the case where the parent itself took up - // space of that type... - - bool needsOffset = false; - for (auto rr : fieldLayout->resourceInfos) - { - if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(rr.kind)) - { - if (parentInfo->index != 0 || parentInfo->space != 0) - { - needsOffset = true; - break; - } - } - } - if (needsOffset) + auto baseTuple = base.getTuple(); + for (auto ee : baseTuple->elements) { - RefPtr<VarLayout> newFieldLayout = new VarLayout(); - newFieldLayout->typeLayout = fieldLayout->typeLayout; - newFieldLayout->flags = fieldLayout->flags; - newFieldLayout->stage = fieldLayout->stage; - newFieldLayout->varDecl = fieldLayout->varDecl; - newFieldLayout->systemValueSemantic = fieldLayout->systemValueSemantic; - newFieldLayout->systemValueSemanticIndex = fieldLayout->systemValueSemanticIndex; - newFieldLayout->semanticName = fieldLayout->semanticName; - newFieldLayout->semanticIndex = fieldLayout->semanticIndex; - - for (auto resInfo : fieldLayout->resourceInfos) + if (ee.fieldDeclRef.Equals(fieldDeclRef)) { - auto newResInfo = newFieldLayout->findOrAddResourceInfo(resInfo.kind); - newResInfo->index = resInfo.index; - newResInfo->space = resInfo.space; - if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(resInfo.kind)) - { - newResInfo->index += parentInfo->index; - newResInfo->space += parentInfo->space; - } + return ee.expr; } - - fieldLayout = newFieldLayout; } + SLANG_UNEXPECTED("failed to find tuple element"); } + break; - LoweredDecl fieldVarOrTupleDecl; - if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(fieldType)) + case LegalExpr::Flavor::pair: { - // If the field is itself a tuple, then recurse - RefPtr<TupleVarDecl> fieldTupleDecl = new TupleVarDecl(); - - TupleSecondaryVarInfo fieldInfo; - fieldInfo.tupleDecl = fieldTupleDecl; - fieldInfo.varDeclClass = info.varDeclClass; - fieldInfo.name = fieldName; - fieldInfo.tupleType = fieldType; - fieldInfo.tupleTypeDecl = makeDeclRef(fieldTupleTypeMod->decl); - fieldInfo.initExpr = fieldInitExpr; - fieldInfo.primaryVarLayout = fieldLayout; - fieldInfo.tupleTypeLayout = getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr); - fieldInfo.arraySpecs = info.arraySpecs; - - fieldTupleDecl->tupleType = fieldTupleTypeMod; - createTupleTypeSecondaryVarDecls(fieldInfo); - - fieldVarOrTupleDecl = LoweredDecl(fieldTupleDecl); - } - else - { - // Otherwise the field has a simple type, and we just need to declare the variable here + auto basePair = base.getPair(); - RefPtr<Type> fieldVarType = fieldType; - for (auto aa = info.arraySpecs; aa; aa = aa->next) - { - RefPtr<ArrayExpressionType> arrayType = Slang::getArrayType( - fieldVarType, - aa->elementCount); + // Need to determine if this field is on the + // ordinary side, the special side, or both. - fieldVarType = arrayType; + auto pairInfo = basePair->pairInfo; + auto pairElement = pairInfo->findElement(fieldDeclRef); + if (!pairElement) + { + SLANG_UNEXPECTED("failed to find tuple element"); + UNREACHABLE_RETURN(LegalExpr()); } - RefPtr<VarDeclBase> fieldVarDecl = info.varDeclClass.createInstance(); - fieldVarDecl->nameAndLoc = NameLoc(getName(fieldName)); - fieldVarDecl->type.type = fieldVarType; - - addDecl(fieldVarDecl); - - if (fieldLayout) + if ((pairElement->flags & PairInfo::kFlag_hasOrdinaryAndSpecial) == PairInfo::kFlag_hasOrdinaryAndSpecial) { - RefPtr<ComputedLayoutModifier> layoutMod = new ComputedLayoutModifier(); - layoutMod->layout = fieldLayout; - addModifier(fieldVarDecl, layoutMod); + // we have both flags + LegalExpr ordinaryResult = extractField(basePair->ordinary, + pairElement->ordinaryFieldDeclRef.As<VarDeclBase>()); + LegalExpr specialResult = extractField(basePair->special, fieldDeclRef); + + RefPtr<PairPseudoExpr> resultPair = new PairPseudoExpr(); + resultPair->ordinary = ordinaryResult; + resultPair->special = specialResult; + resultPair->pairInfo = pairElement->type.getPair()->pairInfo; + return LegalExpr(resultPair); + } + else if(pairElement->flags & PairInfo::kFlag_hasOrdinary) + { + return extractField(basePair->ordinary, + pairElement->ordinaryFieldDeclRef.As<VarDeclBase>()); + } + else + { + SLANG_ASSERT(pairElement->flags & PairInfo::kFlag_hasSpecial); + return extractField(basePair->special, fieldDeclRef); } - - fieldVarOrTupleDecl = LoweredDecl(fieldVarDecl); } + break; - RefPtr<TupleVarModifier> fieldTupleVarMod = new TupleVarModifier(); - fieldTupleVarMod->tupleField = tupleFieldMod; + case LegalExpr::Flavor::implicitDeref: + { + auto baseImplicitDeref = base.getImplicitDeref(); - TupleVarDecl::Element elem; - elem.decl = fieldVarOrTupleDecl; - elem.tupleVarMod = fieldTupleVarMod; + RefPtr<ImplicitDerefPseudoExpr> resultImplicitDeref = new ImplicitDerefPseudoExpr(); + resultImplicitDeref->valueExpr = extractField( + baseImplicitDeref->valueExpr, + fieldDeclRef); + return LegalExpr(resultImplicitDeref); + } - info.tupleDecl->tupleDecls.Add(elem); + default: + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(LegalExpr()); + break; } } - LoweredDecl createTupleTypeVarDecls( - SyntaxClass<VarDeclBase> varDeclClass, - RefPtr<VarDeclBase> originalVarDecl, - String const& name, - RefPtr<Type> tupleType, - DeclRef<AggTypeDecl> tupleTypeDecl, - TupleTypeModifier* tupleTypeMod, - RefPtr<Expr> initExpr, - RefPtr<VarLayout> primaryVarLayout, - RefPtr<StructTypeLayout> tupleTypeLayout) + void attachLayoutModifier( + VarDeclBase* decl, + VarLayout* layout) { - // Not handling initializers just yet... - SLANG_RELEASE_ASSERT(!initExpr); + if (!layout) + return; - // We'll need a placeholder declaration to wrap the whole thing up: - RefPtr<TupleVarDecl> tupleDecl = new TupleVarDecl(); - tupleDecl->nameAndLoc = NameLoc(getName(name)); + RefPtr<ComputedLayoutModifier> mod = new ComputedLayoutModifier(); + mod->layout = layout; + addModifier(decl, mod); + } - // First, if the tuple type had any "ordinary" data, - // then we go ahead and create a declaration for that stuff - if (tupleTypeMod->hasAnyNonTupleFields) + RefPtr<VarDeclBase> declareSimpleVar( + VarDeclBase* decl, + SourceLoc const& loc, + Name* name, + SyntaxClass<VarDeclBase> loweredDeclClass, + VarLayout* varLayout, + RefPtr<Expr> initExpr, + TypeExp const& typeExpr) + { + RefPtr<VarDeclBase> loweredDecl = loweredDeclClass.createInstance(); + if (decl) { - RefPtr<VarDeclBase> primaryVarDecl = varDeclClass.createInstance(); - primaryVarDecl->nameAndLoc.name = getName(name); - primaryVarDecl->type.type = tupleType; + lowerDeclCommon(loweredDecl, decl); + } + loweredDecl->nameAndLoc.name = name; + loweredDecl->nameAndLoc.loc = loc; - primaryVarDecl->modifiers = shallowCloneModifiers(originalVarDecl->modifiers); + loweredDecl->type = typeExpr; + loweredDecl->initExpr = initExpr; - tupleDecl->primaryDecl = primaryVarDecl; + attachLayoutModifier(loweredDecl, varLayout); - if (primaryVarLayout) + addDecl(loweredDecl); + return loweredDecl; + } + + LegalExpr declareSimpleVar( + VarDeclBase* originalDecl, + LegalVarChain* varChain, + SourceLoc const& loc, + String const& name, + SyntaxClass<VarDeclBase> loweredDeclClass, + TypeLayout* typeLayout, + LegalExpr legalInit, + LegalTypeExpr const& legalTypeExpr) + { + RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); + + RefPtr<VarDeclBase> varDecl = declareSimpleVar( + originalDecl, + loc, + getName(name), + loweredDeclClass, + varLayout, + legalInit.getSimple(), + legalTypeExpr.getSimple()); + + return createVarRef(loc, varDecl); + } + + LegalExpr declareVars( + VarDeclBase* originalDecl, + LegalVarChain* varChain, + SourceLoc const& loc, + String const& name, + SyntaxClass<VarDeclBase> loweredDeclClass, + TypeLayout* typeLayout, + LegalExpr legalInit, + LegalTypeExpr const& legalTypeExpr) + { + auto& legalType = legalTypeExpr.type; + switch (legalType.flavor) + { + case LegalType::Flavor::simple: { - RefPtr<ComputedLayoutModifier> layoutMod = new ComputedLayoutModifier(); - layoutMod->layout = primaryVarLayout; - addModifier(primaryVarDecl, layoutMod); - } + return declareSimpleVar( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + legalTypeExpr); - addDecl(primaryVarDecl); - } + } + break; - TupleSecondaryVarInfo info; - info.tupleDecl = tupleDecl; - info.varDeclClass = varDeclClass; - info.name = name; - info.tupleType = tupleType; - info.tupleTypeDecl = tupleTypeDecl; - info.initExpr = initExpr; - info.primaryVarLayout = primaryVarLayout; - info.tupleTypeLayout = tupleTypeLayout; + case LegalType::Flavor::implicitDeref: + { + auto implicitDerefType = legalType.getImplicitDeref(); + + auto valueType = implicitDerefType->valueType; + auto valueTypeLayout = getDerefTypeLayout(typeLayout); + SLANG_ASSERT(valueTypeLayout || !typeLayout); + auto valueInit = deref(legalInit); + + LegalExpr valueExpr = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + valueTypeLayout, + valueInit, + valueType); - createTupleTypeSecondaryVarDecls(info); + RefPtr<ImplicitDerefPseudoExpr> implicitDerefExpr = new ImplicitDerefPseudoExpr(); + implicitDerefExpr->valueExpr = valueExpr; + return LegalExpr(implicitDerefExpr); + } + break; - return LoweredDecl(tupleDecl); - } + case LegalType::Flavor::tuple: + { + auto tupleType = legalType.getTuple(); - RefPtr<StructTypeLayout> getBodyStructTypeLayout(RefPtr<TypeLayout> typeLayout) - { - if (!typeLayout) - return nullptr; + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); - while (auto parameterGroupTypeLayout = typeLayout.As<ParameterGroupTypeLayout>()) - { - typeLayout = parameterGroupTypeLayout->elementTypeLayout; - } + for (auto ff : tupleType->elements) + { + RefPtr<VarLayout> fieldLayout = getFieldLayout( + typeLayout, + ff.fieldDeclRef); + RefPtr<TypeLayout> fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; + SLANG_ASSERT(fieldLayout || !typeLayout); + LegalExpr fieldInit = extractField(legalInit, ff.fieldDeclRef); + + String fieldName = name + "_" + getText(ff.fieldDeclRef.GetName()); + + LegalVarChain fieldVarChain; + fieldVarChain.next = varChain; + fieldVarChain.varLayout = fieldLayout; + + LegalExpr fieldExpr = declareVars( + nullptr, + &fieldVarChain, + loc, + fieldName, + loweredDeclClass, + fieldTypeLayout, + fieldInit, + ff.type); + + TuplePseudoExpr::Element element; + element.expr = fieldExpr; + element.fieldDeclRef = ff.fieldDeclRef; + + tupleExpr->elements.Add(element); + } - while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) - { - typeLayout = arrayTypeLayout->elementTypeLayout; - } + return LegalExpr(tupleExpr); + } + break; - if (auto structTypeLayout = typeLayout.As<StructTypeLayout>()) - { - return structTypeLayout; - } + case LegalType::Flavor::pair: + { + auto pairType = legalType.getPair(); + RefPtr<PairPseudoExpr> pairExpr = new PairPseudoExpr(); + pairExpr->pairInfo = pairType->pairInfo; + pairExpr->loc = loc; + + pairExpr->ordinary = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + pairType->ordinaryType); + + pairExpr->special = declareVars( + originalDecl, + varChain, + loc, + name, + loweredDeclClass, + typeLayout, + legalInit, + pairType->specialType); - return nullptr; - } + return LegalExpr(pairExpr); + } + break; - LoweredDecl createTupleTypeVarDecls( - SyntaxClass<VarDeclBase> varDeclClass, - RefPtr<VarDeclBase> originalVarDecl, - String const& name, - RefPtr<Type> tupleType, - TupleTypeModifier* tupleTypeMod, - RefPtr<Expr> initExpr, - RefPtr<VarLayout> primaryVarLayout) - { - RefPtr<StructTypeLayout> tupleTypeLayout; - if (primaryVarLayout) - { - auto primaryTypeLayout = primaryVarLayout->typeLayout; - tupleTypeLayout = getBodyStructTypeLayout(primaryTypeLayout); + default: + SLANG_UNEXPECTED("unhandled legalized type flavor"); + UNREACHABLE_RETURN(LegalExpr()); + break; } - - return createTupleTypeVarDecls( - varDeclClass, - originalVarDecl, - name, - tupleType, - makeDeclRef(tupleTypeMod->decl), - tupleTypeMod, - initExpr, - primaryVarLayout, - tupleTypeLayout); } - LoweredDecl lowerVarDeclCommonInner( + void lowerVarDeclCommonInner( VarDeclBase* decl, SyntaxClass<VarDeclBase> loweredDeclClass) { - auto loweredType = lowerType(decl->type); - - if (auto tupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredType)) - { - auto varLayout = tryToFindLayout(decl).As<VarLayout>(); + auto legalTypeExpr = lowerAndLegalizeTypeExpr(decl->type); - // The type for the variable is a "tuple type" - // so we need to go ahead and create multiple variables - // to represent it. + auto varLayout = tryToFindLayout(decl).As<VarLayout>(); - // If the variable had an initializer, we expect it - // to resolve to a tuple *value* - auto loweredInit = lowerExpr(decl->initExpr); - - // TODO: need to extract layout here and propagate it down! + // Note: we lower the initialization expression, if any, + // *before* we add the declaration to the current context (e.g., a statement being + // built), so that any operations inside the initialization expression that + // might need to inject statements/temporaries/whatever happen *before* + // the declaration of this variable. + auto legalInit = legalizeExpr(decl->initExpr); - auto tupleDecl = createTupleTypeVarDecls( - loweredDeclClass, + if (legalTypeExpr.type.flavor == LegalType::Flavor::simple) + { + declareSimpleVar( decl, - getText(decl->getName()), - loweredType.type, - tupleTypeMod, - loweredInit, - varLayout); - - shared->loweredDecls.Add(decl, tupleDecl); - return tupleDecl; + decl->nameAndLoc.loc, + decl->getName(), + loweredDeclClass, + varLayout, + legalInit.getSimple(), + legalTypeExpr.getSimple()); } - if (auto bufferType = loweredType->As<UniformParameterGroupType>()) + else { - auto varLayout = tryToFindLayout(decl).As<VarLayout>(); - - auto elementType = bufferType->elementType; - if (auto elementTupleTypeMod = isTupleTypeOrArrayOfTupleType(elementType)) - { - auto tupleDecl = createTupleTypeVarDecls( - loweredDeclClass, - decl, - getText(decl->getName()), - loweredType.type, - elementTupleTypeMod, - nullptr, - varLayout); - - shared->loweredDecls.Add(decl, tupleDecl); - return tupleDecl; - } - } + LegalVarChain varChain; + varChain.next = nullptr; + varChain.varLayout = varLayout; - RefPtr<VarDeclBase> loweredDecl = loweredDeclClass.createInstance(); - - // Note: we lower the declaration (including its initialization expression, if any) - // *before* we add the declaration to the current context (e.g., a statement being - // built), so that any operations inside the initialization expression that - // might need to inject statements/temporaries/whatever happen *before* - // the declaration of this variable. - auto result = lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType); - addDecl(loweredDecl); + LegalExpr legalExpr = declareVars( + decl, + &varChain, + decl->nameAndLoc.loc, + getText(decl->getName()), + loweredDeclClass, + varLayout ? varLayout->typeLayout : nullptr, + legalInit, + legalTypeExpr); - return LoweredDecl(result); + shared->mapOriginalDeclToExpr.Add(decl, legalExpr); + shared->mapOriginalDeclToLowered.AddIfNotExists(decl, nullptr); + } } - LoweredDecl lowerVarDeclCommon( + void lowerVarDeclCommon( VarDeclBase* decl, SyntaxClass<VarDeclBase> loweredDeclClass) { @@ -3581,17 +3489,17 @@ struct LoweringVisitor if (auto parentModuleDecl = pp.As<ModuleDecl>()) { LoweringVisitor subVisitor = *this; - subVisitor.parentDecl = translateDeclRef(parentModuleDecl).getDecl()->As<ContainerDecl>(); + subVisitor.parentDecl = translateDeclRef(parentModuleDecl)->As<ContainerDecl>(); subVisitor.isBuildingStmt = false; - return subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); + subVisitor.lowerVarDeclCommonInner(decl, loweredDeclClass); } // TODO: handle `static` function-scope variables else { // The default behavior is to lower into whatever // scope was already in places - return lowerVarDeclCommonInner(decl, loweredDeclClass); + lowerVarDeclCommonInner(decl, loweredDeclClass); } } @@ -3658,7 +3566,7 @@ struct LoweringVisitor return false; } - LoweredDecl visitVariable( + RefPtr<Decl> visitVariable( Variable* decl) { if (dynamic_cast<ModuleDecl*>(decl->ParentDecl)) @@ -3679,7 +3587,7 @@ struct LoweringVisitor // We can't easily support `in out` declarations with this approach SLANG_RELEASE_ASSERT(!(inRes && outRes)); - LoweredExpr loweredExpr; + LegalExpr loweredExpr; if (inRes) { loweredExpr = lowerShaderParameterToGLSLGLobals( @@ -3696,54 +3604,49 @@ struct LoweringVisitor VaryingParameterDirection::Output); } -// SLANG_RELEASE_ASSERT(loweredExpr); - auto loweredDecl = createVaryingTupleVarDecl( - decl, - loweredExpr); - - registerLoweredDecl(LoweredDecl(loweredDecl), decl); - return LoweredDecl(loweredDecl); + shared->mapOriginalDeclToExpr.Add(decl, loweredExpr); + shared->mapOriginalDeclToLowered.Add(decl, nullptr); + return nullptr; } } } - auto loweredDecl = lowerVarDeclCommon(decl, getClass<Variable>()); - if(!loweredDecl.getValue()) - return LoweredDecl(); + lowerVarDeclCommon(decl, getClass<Variable>()); - return loweredDecl; + return nullptr; } - LoweredDecl visitStructField( + RefPtr<Decl> visitStructField( StructField* decl) { - return LoweredDecl(lowerSimpleVarDeclCommon(new StructField(), decl)); + return lowerSimpleVarDeclCommon(new StructField(), decl); } - LoweredDecl visitParamDecl( + RefPtr<Decl> visitParamDecl( ParamDecl* decl) { - auto loweredDecl = lowerVarDeclCommon(decl, getClass<ParamDecl>()); - return loweredDecl; + lowerVarDeclCommon(decl, getClass<ParamDecl>()); + + return nullptr; } - LoweredDecl transformSyntaxField(DeclBase* decl) + RefPtr<Decl> transformSyntaxField(DeclBase* decl) { return lowerDeclBase(decl); } - LoweredDecl visitDeclGroup( + RefPtr<Decl> visitDeclGroup( DeclGroup* group) { for (auto decl : group->decls) { lowerDecl(decl); } - return LoweredDecl(); + return nullptr; } - LoweredDecl visitFunctionDeclBase( + RefPtr<Decl> visitFunctionDeclBase( FunctionDeclBase* decl) { // TODO: need to generate a name @@ -3766,7 +3669,7 @@ struct LoweringVisitor subVisitor.translateDeclRef(paramDecl); } - auto loweredReturnType = subVisitor.lowerType(decl->ReturnType); + auto loweredReturnType = subVisitor.lowerAndlegalizeSimpleTypeExpr(decl->ReturnType); loweredDecl->ReturnType = loweredReturnType; @@ -3776,7 +3679,7 @@ struct LoweringVisitor // even if it had been a member function when declared. addMember(shared->loweredProgram, loweredDecl); - return LoweredDecl(loweredDecl); + return loweredDecl; } // @@ -3969,7 +3872,7 @@ struct LoweringVisitor return Slang::getArrayType(elementType, getConstantIntVal(elementCount)); } - LoweredExpr lowerSimpleShaderParameterToGLSLGlobal( + LegalExpr lowerSimpleShaderParameterToGLSLGlobal( VaryingParameterInfo const& info, RefPtr<Type> varType, RefPtr<VarLayout> varLayout) @@ -4246,10 +4149,10 @@ struct LoweringVisitor globalVarExpr = globalVarRef; } - return LoweredExpr(globalVarExpr); + return LegalExpr(globalVarExpr); } - LoweredExpr lowerShaderParameterToGLSLGLobalsRec( + LegalExpr lowerShaderParameterToGLSLGLobalsRec( VaryingParameterInfo const& info, RefPtr<Type> varType, RefPtr<VarLayout> varLayout) @@ -4301,10 +4204,7 @@ struct LoweringVisitor // The shader parameter had a structured type, so we need // to destructure it into its constituent fields - RefPtr<VaryingTupleExpr> tupleExpr = new VaryingTupleExpr(); - tupleExpr->type.type = varType; - - SLANG_RELEASE_ASSERT(tupleExpr->type.type); + RefPtr<TuplePseudoExpr> tupleExpr = new TuplePseudoExpr(); for (auto fieldDeclRef : getMembersOfType<VarDeclBase>(aggTypeDeclRef)) { @@ -4337,15 +4237,15 @@ struct LoweringVisitor GetType(fieldDeclRef), fieldLayout); - VaryingTupleExpr::Element elem; - elem.originalFieldDeclRef = makeDeclRef(originalFieldDecl).As<VarDeclBase>(); + TuplePseudoExpr::Element elem; + elem.fieldDeclRef = makeDeclRef(originalFieldDecl).As<VarDeclBase>(); elem.expr = loweredFieldExpr; tupleExpr->elements.Add(elem); } // Okay, we are done with this parameter - return LoweredExpr(tupleExpr); + return LegalExpr(tupleExpr); } } @@ -4353,7 +4253,7 @@ struct LoweringVisitor return lowerSimpleShaderParameterToGLSLGlobal(info, varType, varLayout); } - LoweredExpr lowerShaderParameterToGLSLGLobals( + LegalExpr lowerShaderParameterToGLSLGLobals( RefPtr<VarDeclBase> originalVarDecl, RefPtr<VarLayout> paramLayout, VaryingParameterDirection direction) @@ -4383,47 +4283,16 @@ struct LoweringVisitor break; } - auto loweredType = lowerType(originalVarDecl->type); + auto loweredType = lowerAndLegalizeTypeExpr(originalVarDecl->type); auto loweredExpr = lowerShaderParameterToGLSLGLobalsRec( info, - loweredType.type, + loweredType.type.getSimple(), // TODO: handle non-simple? paramLayout); -#if 0 - RefPtr<VaryingTupleVarDecl> loweredDecl = createVaryingTupleVarDecl( - originalVarDecl, - loweredType, - loweredExpr); - - registerLoweredDecl(loweredDecl, originalVarDecl); - addDecl(loweredDecl); -#endif - return loweredExpr; } - RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( - RefPtr<VarDeclBase> originalVarDecl, - TypeExp const& loweredType, - LoweredExpr loweredExpr) - { - RefPtr<VaryingTupleVarDecl> loweredDecl = new VaryingTupleVarDecl(); - loweredDecl->nameAndLoc = originalVarDecl->nameAndLoc; - loweredDecl->type = loweredType; - loweredDecl->expr = loweredExpr; - - return loweredDecl; - } - - RefPtr<VaryingTupleVarDecl> createVaryingTupleVarDecl( - RefPtr<VarDeclBase> originalVarDecl, - LoweredExpr loweredExpr) - { - auto loweredType = lowerType(originalVarDecl->type); - return createVaryingTupleVarDecl(originalVarDecl, loweredType, loweredExpr); - } - struct EntryPointParamPair { RefPtr<ParamDecl> original; @@ -4436,7 +4305,7 @@ struct LoweringVisitor RefPtr<EntryPointLayout> entryPointLayout) { // First, loer the entry-point function as an ordinary function: - auto loweredEntryPointFunc = visitFunctionDeclBase(entryPointDecl).getDecl()->As<FunctionDeclBase>(); + auto loweredEntryPointFunc = visitFunctionDeclBase(entryPointDecl)->As<FunctionDeclBase>(); auto mainName = getName("main"); @@ -4476,7 +4345,7 @@ struct LoweringVisitor RefPtr<Variable> localVarDecl = new Variable(); localVarDecl->loc = paramDecl->loc; localVarDecl->nameAndLoc = paramDecl->getNameAndLoc(); - localVarDecl->type = lowerType(paramDecl->type); + localVarDecl->type = lowerAndlegalizeSimpleTypeExpr(paramDecl->type); ensureDeclHasAValidName(localVarDecl); @@ -4553,12 +4422,12 @@ struct LoweringVisitor if (resultVarDecl) { // Non-`void` return type, so we need to store it - subVisitor.assign(resultVarDecl, LoweredExpr(callExpr)); + subVisitor.assign(resultVarDecl, LegalExpr(callExpr)); } else { // `void` return type: just call it - subVisitor.addExprStmt(LoweredExpr(callExpr)); + subVisitor.addExprStmt(LegalExpr(callExpr)); } @@ -4594,8 +4463,8 @@ struct LoweringVisitor if (shared->requiresCopyGLPositionToPositionPerView) { subVisitor.assign( - LoweredExpr(createSimpleVarExpr("gl_PositionPerViewNV[0]")), - LoweredExpr(createSimpleVarExpr("gl_Position"))); + LegalExpr(createSimpleVarExpr("gl_PositionPerViewNV[0]")), + LegalExpr(createSimpleVarExpr("gl_Position"))); } bodyStmt->body = subVisitor.stmtBeingBuilt; @@ -4666,7 +4535,7 @@ struct LoweringVisitor { // Default case: lower an entry point just like any other function default: - return visitFunctionDeclBase(entryPointDecl).getDecl()->As<FuncDecl>(); + return visitFunctionDeclBase(entryPointDecl)->As<FuncDecl>(); // For Slang->GLSL translation, we need to lower things from HLSL-style // declarations over to GLSL conventions @@ -4719,11 +4588,12 @@ bool isRewriteRequest( LoweredEntryPoint lowerEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker, - IRSpecializationState* irSpecializationState) + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState, + TypeLegalizationContext* typeLegalizationContext) { SharedLoweringContext sharedContext; sharedContext.compileRequest = entryPoint->compileRequest; @@ -4732,6 +4602,7 @@ LoweredEntryPoint lowerEntryPoint( sharedContext.target = target; sharedContext.extensionUsageTracker = extensionUsageTracker; sharedContext.irSpecializationState = irSpecializationState; + sharedContext.typeLegalizationContext = typeLegalizationContext; auto translationUnit = entryPoint->getTranslationUnit(); sharedContext.mainModuleDecl = translationUnit->SyntaxNode; @@ -4742,6 +4613,10 @@ LoweredEntryPoint lowerEntryPoint( RefPtr<ModuleDecl> loweredProgram = new ModuleDecl(); sharedContext.loweredProgram = loweredProgram; + typeLegalizationContext->mainModuleDecl = sharedContext.mainModuleDecl; + typeLegalizationContext->outputModuleDecl = loweredProgram; + + LoweringVisitor visitor; visitor.shared = &sharedContext; visitor.parentDecl = loweredProgram; @@ -4753,7 +4628,7 @@ LoweredEntryPoint lowerEntryPoint( // of the existing translation unit declaration. visitor.registerLoweredDecl( - LoweredDecl(loweredProgram), + loweredProgram, translationUnit->SyntaxNode); // We also need to register the lowered program as the lowered version @@ -4761,9 +4636,9 @@ LoweredEntryPoint lowerEntryPoint( // a single module for code generation). for (auto rr : entryPoint->compileRequest->loadedModulesList) { - sharedContext.loweredDecls.Add( + sharedContext.mapOriginalDeclToLowered.Add( rr->moduleDecl, - LoweredDecl(loweredProgram)); + loweredProgram); } // We also want to remember the layout information for diff --git a/source/slang/ast-legalize.h b/source/slang/ast-legalize.h index 9046e8df0..23a150002 100644 --- a/source/slang/ast-legalize.h +++ b/source/slang/ast-legalize.h @@ -42,6 +42,7 @@ namespace Slang struct IRSpecializationState; class ProgramLayout; class TranslationUnitRequest; + struct TypeLegalizationContext; struct LoweredEntryPoint @@ -63,10 +64,11 @@ namespace Slang // Emit code for a single entry point, based on // the input translation unit. LoweredEntryPoint lowerEntryPoint( - EntryPointRequest* entryPoint, - ProgramLayout* programLayout, - CodeGenTarget target, - ExtensionUsageTracker* extensionUsageTracker, - IRSpecializationState* irSpecializationState); + EntryPointRequest* entryPoint, + ProgramLayout* programLayout, + CodeGenTarget target, + ExtensionUsageTracker* extensionUsageTracker, + IRSpecializationState* irSpecializationState, + TypeLegalizationContext* typeLegalizationContext); } #endif diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 4a084c714..b9aa3c027 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -3,6 +3,7 @@ #include "ast-legalize.h" #include "ir-insts.h" +#include "legalize-types.h" #include "lower-to-ir.h" #include "mangle.h" #include "name.h" @@ -1236,13 +1237,6 @@ struct EmitVisitor emitTypeImpl(type->valueType, arg.declarator); } - void visitFilteredTupleType(FilteredTupleType* type, TypeEmitArg const& arg) - { - auto declarator = arg.declarator; - emit(getMangledTypeName(type)); - EmitDeclarator(declarator); - } - void EmitType( RefPtr<Type> type, SourceLoc const& typeLoc, @@ -2952,7 +2946,7 @@ struct EmitVisitor Decl* decl = declRef.getDecl(); if(irDeclSet->Contains(decl)) { - emit(getMangledName(declRef)); + emit(getIRName(declRef)); return; } } @@ -3605,10 +3599,6 @@ struct EmitVisitor // The data type that describes where stuff in the constant buffer should go RefPtr<Type> dataType = parameterGroupType->elementType; - // We expect/require the data type to be a user-defined `struct` type - auto declRefType = dataType->As<DeclRefType>(); - SLANG_RELEASE_ASSERT(declRefType); - // We expect to always have layout information layout = maybeFetchLayout(varDecl, layout); SLANG_RELEASE_ASSERT(layout); @@ -3642,27 +3632,48 @@ struct EmitVisitor emitHLSLRegisterSemantic(*info); Emit("\n{\n"); - if (auto structRef = declRefType->declRef.As<StructDecl>()) - { - int fieldCounter = 0; - for (auto field : getMembersOfType<StructField>(structRef)) + // We expect the data type to be a user-defined `struct` type, + // but it might also be a "filtered" type that represents the + // case where only some fields of the original type are valid + // to appear inside of a `struct`. + if (auto declRefType = dataType->As<DeclRefType>()) + { + if (auto structRef = declRefType->declRef.As<StructDecl>()) { - int fieldIndex = fieldCounter++; + int fieldCounter = 0; - emitVarDeclHead(field); + for (auto field : getMembersOfType<StructField>(structRef)) + { + int fieldIndex = fieldCounter++; + + // Skip fields that have `void` type, since these represent + // declarations that got legalized out of existence. + if(GetType(field)->Equals(getSession()->getVoidType())) + continue; - RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; - SLANG_RELEASE_ASSERT(fieldLayout->varDecl.GetName() == field.GetName()); + emitVarDeclHead(field); - // Emit explicit layout annotations for every field - emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + SLANG_RELEASE_ASSERT(fieldLayout->varDecl.GetName() == field.GetName()); - emitVarDeclInit(field); + // Emit explicit layout annotations for every field + emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); - Emit(";\n"); + emitVarDeclInit(field); + + Emit(";\n"); + } + } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); } } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); + } Emit("}\n"); } @@ -3773,10 +3784,6 @@ struct EmitVisitor // The data type that describes where stuff in the constant buffer should go RefPtr<Type> dataType = parameterGroupType->elementType; - // We expect/require the data type to be a user-defined `struct` type - auto declRefType = dataType->As<DeclRefType>(); - SLANG_RELEASE_ASSERT(declRefType); - // We expect the layout, if present, to be for a structured type... RefPtr<StructTypeLayout> structTypeLayout; if (layout) @@ -3827,26 +3834,54 @@ struct EmitVisitor } Emit("\n{\n"); - if (auto structRef = declRefType->declRef.As<StructDecl>()) + + // We expect the data type to be a user-defined `struct` type, + // but it might also be a "filtered" type that represents the + // case where only some fields of the original type are valid + // to appear inside of a `struct`. + if (auto declRefType = dataType->As<DeclRefType>()) { - for (auto field : getMembersOfType<StructField>(structRef)) + + if (auto structRef = declRefType->declRef.As<StructDecl>()) { - if (structTypeLayout) + int fieldCounter = 0; + for (auto field : getMembersOfType<StructField>(structRef)) { - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(field.getDecl(), fieldLayout); - // assert(fieldLayout); + int fieldIndex = fieldCounter++; + + // Skip fields that have `void` type, since these represent + // declarations that got legalized out of existence. + if(GetType(field)->Equals(getSession()->getVoidType())) + continue; + + if (structTypeLayout) + { + RefPtr<VarLayout> fieldLayout = structTypeLayout->fields[fieldIndex]; + // assert(fieldLayout); + + // TODO(tfoley): We may want to emit *some* of these, + // some of the time... + // emitGLSLLayoutQualifiers(fieldLayout); + } - // TODO(tfoley): We may want to emit *some* of these, - // some of the time... - // emitGLSLLayoutQualifiers(fieldLayout); - } - EmitVarDeclCommon(field); + EmitVarDeclCommon(field); - Emit(";\n"); + Emit(";\n"); + } + } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); } } + else + { + SLANG_UNEXPECTED("unexpected element type for parameter group"); + } + + + Emit("}"); if( varDecl->getNameLoc().isValid() ) @@ -3890,13 +3925,10 @@ struct EmitVisitor } } - // Skip fields that have been tuple-ified and don't contribute - // any fields of "ordinary" type. - if (auto tupleFieldMod = decl->FindModifier<TupleFieldModifier>()) - { - if (!tupleFieldMod->hasAnyNonTupleFields) - return; - } + // Skip fields that have `void` type, since these may be introduced + // as part of type leglaization. + if(decl->getType()->Equals(getSession()->getVoidType())) + return; RefPtr<VarLayout> layout = arg.layout; layout = maybeFetchLayout(decl, layout); @@ -4129,16 +4161,6 @@ emitDeclImpl(decl, nullptr); String getIRName(Decl* decl) { - // It is a bit ugly, but we need a deterministic way - // to get a name for things when emitting from the IR - // that won't conflict with any keywords, builtins, etc. - // in the target language. - // - // Eventually we should accomplish this by using - // mangled names everywhere, but that complicates things - // when we are also using direct comparison to fxc/glslang - // output for some of our tests. - // // TODO: need a flag to get rid of the step that adds // a prefix here, so that we can get "clean" output // when needed. @@ -4155,7 +4177,27 @@ emitDeclImpl(decl, nullptr); String getIRName(DeclRefBase const& declRef) { - return getIRName(declRef.decl); + // It is a bit ugly, but we need a deterministic way + // to get a name for things when emitting from the IR + // that won't conflict with any keywords, builtins, etc. + // in the target language. + // + // Eventually we should accomplish this by using + // mangled names everywhere, but that complicates things + // when we are also using direct comparison to fxc/glslang + // output for some of our tests. + // + + String name; + if (context->shared->entryPoint->compileRequest->compileFlags & SLANG_COMPILE_FLAG_NO_MANGLING) + { + name.append(getText(declRef.GetName())); + } + else + { + name.append(getMangledName(declRef)); + } + return name; } String getGLSLSystemValueName( @@ -6279,9 +6321,12 @@ emitDeclImpl(decl, nullptr); auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldType = GetType(ff); + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRVarModifiers(ctx, fieldLayout); - auto fieldType = GetType(ff); emitIRType(ctx, fieldType, getIRName(ff)); emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); @@ -6290,26 +6335,6 @@ emitDeclImpl(decl, nullptr); } } } - else if (auto filteredTupleType = elementType->As<FilteredTupleType>()) - { - auto structTypeLayout = typeLayout.As<StructTypeLayout>(); - assert(structTypeLayout); - - for (auto ee : filteredTupleType->elements) - { - RefPtr<VarLayout> fieldLayout; - structTypeLayout->mapVarToLayout.TryGetValue(ee.fieldDeclRef, fieldLayout); - - emitIRVarModifiers(ctx, fieldLayout); - - auto fieldType = ee.type; - emitIRType(ctx, fieldType, getIRName(ee.fieldDeclRef)); - - emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); - - emit(";\n"); - } - } else { emit("/* unexpected */"); @@ -6376,9 +6401,12 @@ emitDeclImpl(decl, nullptr); auto fieldLayout = structTypeLayout->fields[fieldIndex++]; + auto fieldType = GetType(ff); + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRVarModifiers(ctx, fieldLayout); - auto fieldType = GetType(ff); emitIRType(ctx, fieldType, getIRName(ff)); // emitHLSLParameterGroupFieldLayoutSemantics(layout, fieldLayout); @@ -6589,7 +6617,7 @@ emitDeclImpl(decl, nullptr); } Emit("struct "); - emit(declRef.GetName()); + EmitDeclRef(declRef); Emit("\n{\n"); for( auto ff : GetFields(declRef) ) { @@ -6597,6 +6625,11 @@ emitDeclImpl(decl, nullptr); continue; auto fieldType = GetType(ff); + + // Skip `void` fields that might have been created by legalization. + if(fieldType->Equals(getSession()->getVoidType())) + continue; + emitIRType(ctx, fieldType, getIRName(ff)); EmitSemantics(ff.getDecl()); @@ -6655,43 +6688,6 @@ emitDeclImpl(decl, nullptr); ensureStructDecl(ctx, structDeclRef); } } - else if (auto filteredTupleType = type->As<FilteredTupleType>()) - { - // First, ensure that the element types are ready: - for (auto ee : filteredTupleType->elements) - { - if (ee.type) - { - emitIRUsedType(ctx, ee.type); - } - } - - // Now, we want to ensure we've emitted a - // matching `struct` type declaration. - - String mangledName = getMangledTypeName(filteredTupleType); - if (!ctx->shared->irTupleTypes.Contains(mangledName)) - { - ctx->shared->irTupleTypes.Add(mangledName); - - // Emit the damn `struct` decl... - - Emit("struct "); - emit(mangledName); - Emit("\n{\n"); - for( auto ee : filteredTupleType->elements ) - { - if (!ee.type) - continue; - - emitIRType(ctx, ee.type, getIRName(ee.fieldDeclRef)); - - emit(";\n"); - } - Emit("};\n"); - - } - } else {} } @@ -6846,7 +6842,8 @@ StructTypeLayout* getGlobalStructLayout( } void legalizeTypes( - IRModule* module); + TypeLegalizationContext* context, + IRModule* module); String emitEntryPoint( EntryPointRequest* entryPoint, @@ -6937,6 +6934,9 @@ String emitEntryPoint( // Next we will check for case (2a): else if (!(translationUnit->compileRequest->compileFlags & SLANG_COMPILE_FLAG_USE_IR)) { + TypeLegalizationContext typeLegalizationContext; + typeLegalizationContext.session = entryPoint->compileRequest->mSession; + // This case means the user has opted out of using the IR (so we can't use the // cases below), but they either turned on semantic checking *or* imported some // Slang code, so they can't use the case above. @@ -6958,7 +6958,8 @@ String emitEntryPoint( programLayout, target, &sharedContext.extensionUsageTracker, - nullptr); + nullptr, + &typeLegalizationContext); sharedContext.program = lowered.program; // Note that we emit the main body code of the program *before* @@ -6976,6 +6977,9 @@ String emitEntryPoint( // are certain steps that need to be shared. else { + TypeLegalizationContext typeLegalizationContext; + typeLegalizationContext.session = entryPoint->compileRequest->mSession; + // We are going to create a fresh IR module that we will use to // clone any code needed by the user's entry point. IRSpecializationState* irSpecializationState = createIRSpecializationState( @@ -6985,6 +6989,8 @@ String emitEntryPoint( targetRequest); IRModule* irModule = getIRModule(irSpecializationState); + typeLegalizationContext.irModule = irModule; + LoweredEntryPoint lowered; if(translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { @@ -7002,7 +7008,8 @@ String emitEntryPoint( programLayout, target, &sharedContext.extensionUsageTracker, - irSpecializationState); + irSpecializationState, + &typeLegalizationContext); } else { @@ -7044,7 +7051,9 @@ String emitEntryPoint( // we need to ensure that the code only uses types // that are legal on the chosen target. // - legalizeTypes(irModule); + legalizeTypes( + &typeLegalizationContext, + irModule); // Debugging output of legalization #if 0 @@ -7053,6 +7062,8 @@ String emitEntryPoint( fprintf(stderr, "###\n"); #endif + sharedContext.irDeclSetForAST = &lowered.irDecls; + // After all of the required optimization and legalization // passes have been performed, we can emit target code from // the IR module. @@ -7065,7 +7076,6 @@ String emitEntryPoint( // that we need to output, we'll do it now. if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING) { - sharedContext.irDeclSetForAST = &lowered.irDecls; visitor.EmitDeclsInContainer(lowered.program); } diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 7fad2ccc5..6e909106d 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -1,254 +1,22 @@ // ir-legalize-types.cpp -// This file implements a pass that takes IR -// that has been fully specialized (no more -// generics/interfaces needing to be specialized -// away) and replaces any types that can't actually -// be used as-is on the target. +// This file implements type legalization for the IR. +// It uses the core legalization logic in +// `legalize-types.{h,cpp}` to decide what to do with +// the types, while this file handles the actual +// rewriting of the IR to use the new types. // -// The particular case we are focused on is -// aggregate types (e.g., `struct` types) that -// contain resources (textures, samplers, etc.) -// or that mix resources and ordinary "uniform" -// data. +// This pass should only be applied to IR that has been +// fully specialized (no more generics/interfaces), so +// that the concrete type of everything is known. #include "ir.h" #include "ir-insts.h" +#include "legalize-types.h" namespace Slang { -struct LegalTypeImpl : RefObject -{ -}; -struct ImplicitDerefType; -struct TuplePseudoType; -struct PairPseudoType; -struct PairInfo; - -struct LegalType -{ - enum class Flavor - { - // Nothing: a NULL type - none, - - // A simple type that can be represented directly as a `Type` - simple, - - // Logically, we have a pointer-like type, but we are - // going to represnet it as the pointed-to type - implicitDeref, - - // A compound type was broken apart into its constituent fields, - // so a tuple "pseduo-type" is being used to collect - // those fields together. - tuple, - - // A type has to get split into "ordinary" and "special" parts, - // each of which will be represented with its own `LegalType`. - pair, - }; - - Flavor flavor = Flavor::none; - RefPtr<RefObject> obj; - - static LegalType simple(Type* type) - { - LegalType result; - result.flavor = Flavor::simple; - result.obj = type; - return result; - } - - RefPtr<Type> getSimple() - { - assert(flavor == Flavor::simple); - return obj.As<Type>(); - } - - static LegalType implicitDeref( - LegalType const& valueType); - - RefPtr<ImplicitDerefType> getImplicitDeref() - { - assert(flavor == Flavor::implicitDeref); - return obj.As<ImplicitDerefType>(); - } - - static LegalType tuple( - RefPtr<TuplePseudoType> tupleType); - - RefPtr<TuplePseudoType> getTuple() - { - assert(flavor == Flavor::tuple); - return obj.As<TuplePseudoType>(); - } - - static LegalType pair( - RefPtr<PairPseudoType> pairType); - - static LegalType pair( - RefPtr<Type> ordinaryType, - LegalType const& specialType, - RefPtr<PairInfo> pairInfo); - - RefPtr<PairPseudoType> getPair() - { - assert(flavor == Flavor::pair); - return obj.As<PairPseudoType>(); - } -}; - -struct ImplicitDerefType : LegalTypeImpl -{ - LegalType valueType; -}; - -LegalType LegalType::implicitDeref( - LegalType const& valueType) -{ - RefPtr<ImplicitDerefType> obj = new ImplicitDerefType(); - obj->valueType = valueType; - - LegalType result; - result.flavor = Flavor::implicitDeref; - result.obj = obj; - return result; -} - -// Represents the pseudo-type for a compound type -// that had to be broken apart because it contained -// one or more fields of types that shouldn't be -// allowed in aggregates. -// -// A tuple pseduo-type will have an element for -// each field of the original type, that represents -// the legalization of that field's type. -// -// It optionally also contains an "ordinary" type -// that packs together any per-field data that -// itself has (or contains) an ordinary type. -struct TuplePseudoType : LegalTypeImpl -{ - // Represents one element of the tuple pseudo-type - struct Element - { - // The field that this element replaces - DeclRef<VarDeclBase> fieldDeclRef; - - // The legalized type of the element - LegalType type; - }; - - // All of the elements of the tuple pseduo-type. - List<Element> elements; -}; - -LegalType LegalType::tuple( - RefPtr<TuplePseudoType> tupleType) -{ - LegalType result; - result.flavor = Flavor::tuple; - result.obj = tupleType; - return result; -} - -struct PairInfo : RefObject -{ - typedef unsigned int Flags; - enum - { - kFlag_hasOrdinary = 0x1, - kFlag_hasSpecial = 0x2, - }; - - struct Element - { - // The field the element represents - DeclRef<Decl> fieldDeclRef; - - // The conceptual type of the field. - // If both the `hasOrdinary` and - // `hasSpecial` bits are set, then - // this is expected to be a - // `LegalType::Flavor::pair` - LegalType type; - - // Is the value represented on - // the ordinary side, the special - // side, or both? - Flags flags; - }; - - // For a pair type or value, we need to track - // which fields are on which side(s). - List<Element> elements; - - Element* findElement(DeclRef<Decl> const& fieldDeclRef) - { - for (auto& ee : elements) - { - if(ee.fieldDeclRef.Equals(fieldDeclRef)) - return ⅇ - } - return nullptr; - } -}; - -struct PairPseudoType : LegalTypeImpl -{ - // Any field(s) with ordinary types will - // get captured here, as a completely - // standard AST-level type. - RefPtr<Type> ordinaryType; - - // Any fields with "special" (not ordinary) - // types will get captured here (usually - // with a tuple). - LegalType specialType; - - RefPtr<PairInfo> pairInfo; -}; - -LegalType LegalType::pair( - RefPtr<PairPseudoType> pairType) -{ - LegalType result; - result.flavor = Flavor::pair; - result.obj = pairType; - return result; -} - -LegalType LegalType::pair( - RefPtr<Type> ordinaryType, - LegalType const& specialType, - RefPtr<PairInfo> pairInfo) -{ - // Handle some special cases for when - // one or the other of the types isn't - // actually used. - - if (!ordinaryType) - { - // There was nothing ordinary. - return specialType; - } - - if (specialType.flavor == LegalType::Flavor::none) - { - return LegalType::simple(ordinaryType); - } - - // There were both ordinary and special fields, - // and so we need to handle them here. - - RefPtr<PairPseudoType> obj = new PairPseudoType(); - obj->ordinaryType = ordinaryType; - obj->specialType = specialType; - obj->pairInfo = pairInfo; - return LegalType::pair(obj); -} struct LegalValImpl : RefObject @@ -390,12 +158,15 @@ LegalVal LegalVal::getImplicitDeref() } -struct TypeLegalizationContext +struct IRTypeLegalizationContext { Session* session; IRModule* module; IRBuilder* builder; + /// Context to use for underlying (non-IR) type legalization. + TypeLegalizationContext* typeLegalizationContext; + // When inserting new globals, put them before this one. IRGlobalValue* insertBeforeGlobal = nullptr; @@ -411,633 +182,31 @@ struct TypeLegalizationContext }; static void registerLegalizedValue( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRValue* irValue, LegalVal const& legalVal) { context->mapValToLegalVal.Add(irValue, legalVal); } - -static bool isResourceType(Type* type) -{ - while (auto arrayType = type->As<ArrayExpressionType>()) - { - type = arrayType->baseType; - } - - if (auto textureTypeBase = type->As<TextureTypeBase>()) - { - return true; - } - else if (auto samplerType = type->As<SamplerStateType>()) - { - return true; - } - - // TODO: need more comprehensive coverage here - - return false; -} - -static LegalType legalizeType( - TypeLegalizationContext* context, - Type* type); - -// Helper type for legalization of aggregate types -// that might need to be turned into tuple pseudo-types. -struct TupleTypeBuilder -{ - TypeLegalizationContext* context; - RefPtr<Type> type; - - List<FilteredTupleType::Element> ordinaryElements; - List<TuplePseudoType::Element> specialElements; - - List<PairInfo::Element> pairElements; - - // Did we have any fields that forced us to change - // the actual type away from the declared type? - bool anyComplex = false; - - // Did we have any fields that actually required - // storage in the "special" part of things? - bool anySpecial = false; - - // Did we have any fields that actually used ordinary storage? - bool anyOrdinary = false; - - // Add a field to the (pseudo-)type we are building - void addField( - DeclRef<VarDeclBase> fieldDeclRef, - LegalType legalFieldType, - LegalType legalLeafType, - bool isResource) - { - RefPtr<Type> ordinaryType; - LegalType specialType; - RefPtr<PairInfo> elementPairInfo; - switch (legalLeafType.flavor) - { - case LegalType::Flavor::simple: - { - // We need to add an actual field, but we need - // to check if it is a resource type to know - // whether it should go in the "ordinary" list or not. - if (!isResource) - { - ordinaryType = legalLeafType.getSimple(); - } - else - { - specialType = legalFieldType; - } - } - break; - - case LegalType::Flavor::implicitDeref: - { - // TODO: we may want to say that any use - // of `implicitDeref` puts the entire thing - // into the "special" category, rather than - // try to look under the hood... - - anyComplex = true; - - // We want to recursively add data - // based on the unwrapped type. - // - // Note: this assumes we can't have a tuple - // or a pair "under" an `implicitDeref`, so - // we'll need to ensure that elsewhere. - addField( - fieldDeclRef, - legalFieldType, - legalLeafType.getImplicitDeref()->valueType, - isResource); - return; - } - break; - - case LegalType::Flavor::pair: - { - // The field's type had both special and non-special parts - auto pairType = legalLeafType.getPair(); - ordinaryType = pairType->ordinaryType; - specialType = pairType->specialType; - elementPairInfo = pairType->pairInfo; - } - break; - - case LegalType::Flavor::tuple: - { - // A tuple always represents "special" data - specialType = legalFieldType; - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - break; - } - - - PairInfo::Element pairElement; - pairElement.flags = 0; - pairElement.fieldDeclRef = fieldDeclRef; - - if (ordinaryType) - { - anyOrdinary = true; - pairElement.flags |= PairInfo::kFlag_hasOrdinary; - - FilteredTupleType::Element ordinaryElement; - ordinaryElement.fieldDeclRef = fieldDeclRef; - ordinaryElement.type = ordinaryType; - ordinaryElements.Add(ordinaryElement); - } - - if (specialType.flavor != LegalType::Flavor::none) - { - anySpecial = true; - anyComplex = true; - pairElement.flags |= PairInfo::kFlag_hasSpecial; - - TuplePseudoType::Element specialElement; - specialElement.fieldDeclRef = fieldDeclRef; - specialElement.type = specialType; - specialElements.Add(specialElement); - } - - pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); - pairElements.Add(pairElement); - } - - // Add a field to the (pseudo-)type we are building - void addField( - DeclRef<VarDeclBase> fieldDeclRef) - { - // Skip `static` fields. - if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) - return; - - auto fieldType = GetType(fieldDeclRef); - - bool isResourceField = isResourceType(fieldType); - - auto legalFieldType = legalizeType(context, fieldType); - addField( - fieldDeclRef, - legalFieldType, - legalFieldType, - isResourceField); - } - - LegalType getResult() - { - // If we didn't see anything "special" - // then we can use the type as-is. - // we can conceivably just use the type as-is - // - // TODO: this might be a good place to turn - // a reference to a generic `struct` type into - // a concrete non-generic type so that downstream - // codegen doesn't have to deal with generics... - // - // TODO: In fact, why not just fully replace - // all aggregate types here with some structural - // types defined in the IR? - if (!anyComplex) - { - return LegalType::simple(type); - } - - // If there were any "ordinary" fields along the way, - // then we need to collect them into a type to - // represent the ordinary part of things. - // - RefPtr<Type> ordinaryType; - if (anyOrdinary) - { - RefPtr<FilteredTupleType> ordinaryTypeImpl = new FilteredTupleType(); - ordinaryTypeImpl->setSession(context->session); - ordinaryTypeImpl->originalType = type; - ordinaryTypeImpl->elements = ordinaryElements; - ordinaryType = ordinaryTypeImpl; - } - - LegalType specialType; - if (anySpecial) - { - RefPtr<TuplePseudoType> specialTuple = new TuplePseudoType(); - specialTuple->elements = specialElements; - specialType = LegalType::tuple(specialTuple); - } - - RefPtr<PairInfo> pairInfo; - if (anyOrdinary && anySpecial) - { - pairInfo = new PairInfo(); - pairInfo->elements = pairElements; - } - - return LegalType::pair(ordinaryType, specialType, pairInfo); - } - -}; - -static RefPtr<Type> createBuiltinGenericType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - RefPtr<Type> elementType) -{ - // We are going to take the type for the original - // decl-ref and construct a new one that uses - // our new element type as its parameter. - // - // TODO: we should have library code to make - // manipulations like this way easier. - - RefPtr<GenericSubstitution> oldGenericSubst = getGenericSubstitution( - typeDeclRef.substitutions); - SLANG_ASSERT(oldGenericSubst); - - RefPtr<GenericSubstitution> newGenericSubst = new GenericSubstitution(); - - newGenericSubst->outer = oldGenericSubst->outer; - newGenericSubst->genericDecl = oldGenericSubst->genericDecl; - newGenericSubst->args = oldGenericSubst->args; - newGenericSubst->args[0] = elementType; - - auto newDeclRef = DeclRef<Decl>( - typeDeclRef.getDecl(), - newGenericSubst); - - auto newType = DeclRefType::Create( - context->session, - newDeclRef); - - return newType; -} - -// Create a uniform buffer type with a given legalized -// element type. -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - LegalType legalElementType) -{ - switch (legalElementType.flavor) - { - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - return LegalType::simple(createBuiltinGenericType( - context, - typeDeclRef, - legalElementType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // This is actually an annoying case, because - // we are being asked to convert, e.g.,: - // - // cbuffer Foo { ParameterBlock<Bar> bar; } - // - // into the equivalent of: - // - // cbuffer Foo { Bar bar; } - // - // Which would really require a new `LegalType` that - // would reprerent a resource type with a modified - // element type. - // - // I'm going to attempt to hack this for now. - return LegalType::implicitDeref(createLegalUniformBufferType( - context, - typeDeclRef, - legalElementType.getImplicitDeref()->valueType)); - } - break; - - case LegalType::Flavor::pair: - { - // We assume that the "ordinary" part of things - // will get wrapped in a constant-buffer type, - // and the "special" part needs to be wrapped - // with an `implicitDeref`. - auto pairType = legalElementType.getPair(); - - auto ordinaryType = createBuiltinGenericType( - context, - typeDeclRef, - pairType->ordinaryType); - auto specialType = LegalType::implicitDeref(pairType->specialType); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // if we have a tuple type, then it must be representing - // the fields that can't be stored in a buffer anyway, - // so we just need to wrap each of them in an `implicitDeref` - - auto elementPseudoTupleType = legalElementType.getTuple(); - - RefPtr<TuplePseudoType> bufferPseudoTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : elementPseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.fieldDeclRef = ee.fieldDeclRef; - newElement.type = LegalType::implicitDeref(ee.type); - - bufferPseudoTupleType->elements.Add(newElement); - } - - return LegalType::tuple(bufferPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -static LegalType createLegalUniformBufferType( - TypeLegalizationContext* context, - UniformParameterGroupType* uniformBufferType, - LegalType legalElementType) -{ - return createLegalUniformBufferType( - context, - uniformBufferType->declRef, - legalElementType); -} - -// Create a pointer type with a given legalized value type. -static LegalType createLegalPtrType( - TypeLegalizationContext* context, - DeclRef<Decl> const& typeDeclRef, - LegalType legalValueType) -{ - switch (legalValueType.flavor) - { - case LegalType::Flavor::simple: - { - // Easy case: we just have a simple element type, - // so we want to create a uniform buffer that wraps it. - return LegalType::simple(createBuiltinGenericType( - context, - typeDeclRef, - legalValueType.getSimple())); - } - break; - - case LegalType::Flavor::implicitDeref: - { - // We are being asked to create a pointer type to something - // that is implicitly dereferenced, meaning we had: - // - // Ptr(PtrLink(T)) - // - // and now are being asked to make: - // - // Ptr(implicitDeref(LegalT)) - // - // So it seems like we can just create: - // - // implicitDeref(Ptr(LegalT)) - // - // and nobody should really be able to tell the difference, right? - return LegalType::implicitDeref(createLegalPtrType( - context, - typeDeclRef, - legalValueType.getImplicitDeref()->valueType)); - } - break; - - case LegalType::Flavor::pair: - { - // We just need to pointer-ify both sides of the pair. - auto pairType = legalValueType.getPair(); - - auto ordinaryType = createBuiltinGenericType( - context, - typeDeclRef, - pairType->ordinaryType); - auto specialType = createLegalPtrType( - context, - typeDeclRef, - pairType->specialType); - - return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); - } - - case LegalType::Flavor::tuple: - { - // Wrap each of the tuple elements up as a pointer. - auto valuePseudoTupleType = legalValueType.getTuple(); - - RefPtr<TuplePseudoType> ptrPseudoTupleType = new TuplePseudoType(); - - // Wrap all the pseudo-tuple elements with `implicitDeref`, - // since they used to be inside a tuple, but aren't any more. - for (auto ee : valuePseudoTupleType->elements) - { - TuplePseudoType::Element newElement; - - newElement.fieldDeclRef = ee.fieldDeclRef; - newElement.type = createLegalPtrType( - context, - typeDeclRef, - ee.type); - - ptrPseudoTupleType->elements.Add(newElement); - } - - return LegalType::tuple(ptrPseudoTupleType); - } - break; - - default: - SLANG_UNEXPECTED("unknown legal type flavor"); - UNREACHABLE_RETURN(LegalType()); - break; - } -} - -// Legalize a type, including any nested types -// that it transitively contains. -static LegalType legalizeType( - TypeLegalizationContext* context, - Type* type) -{ - if (auto parameterBlockType = type->As<ParameterBlockType>()) - { - // We basically legalize the `ParameterBlock<T>` type - // over to `T`. In order to represent this preoperly, - // we need to be careful to wrap it up in a way that - // tells us to eliminate downstream deferences... - - auto legalElementType = legalizeType(context, - parameterBlockType->getElementType()); - return LegalType::implicitDeref(legalElementType); - } - else if (auto uniformBufferType = type->As<UniformParameterGroupType>()) - { - // We have a `ConstantBuffer<T>` or `TextureBuffer<T>` or - // other pointer-like type that represents uniform parameters. - // We need to pull any resource-type fields out of it, but - // leave the non-resource fields where they are. - - // Legalize the element type to see what we are working with. - auto legalElementType = legalizeType(context, - uniformBufferType->getElementType()); - - switch (legalElementType.flavor) - { - case LegalType::Flavor::simple: - return LegalType::simple(type); - - default: - return createLegalUniformBufferType( - context, - uniformBufferType, - legalElementType); - } - - } - else if (isResourceType(type)) - { - // We assume that any resource types not handled above - // are legal as-is. - return LegalType::simple(type); - } - else if (type->As<BasicExpressionType>()) - { - return LegalType::simple(type); - } - else if (type->As<VectorExpressionType>()) - { - return LegalType::simple(type); - } - else if (type->As<MatrixExpressionType>()) - { - return LegalType::simple(type); - } - else if (auto ptrType = type->As<PtrTypeBase>()) - { - auto legalValueType = legalizeType(context, ptrType->getValueType()); - return createLegalPtrType(context, ptrType->declRef, legalValueType); - } - else if (auto declRefType = type->As<DeclRefType>()) - { - auto declRef = declRefType->declRef; - if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) - { - // Look at the (non-static) fields, and - // see if anything needs to be cleaned up. - // The things that need to be "cleaned up" for - // our purposes are: - // - // - Fields of resource type, or any other future - // type we run into that isn't allowed in - // aggregates for at least some targets - // - // - Fields with types that themselves had to - // get legalized. - // - // If we don't run into any of these, we - // can just use the type as-is. Hooray! - // - // Otherwise, we are effectively going to split - // the type apart and create a `TuplePseudoType`. - // Every field of the original type will be - // represented as an element of this pseudo-type. - // Each element will record its `LegalType`, - // and the original field that it was created from. - // An element will also track whether it contains - // any "ordinary" data, and if so, it will remember - // an element index in a real (AST-level, non-pseudo) - // `TupleType` that is used to bundle together - // such fields. - // - // Storing all the simple fields together like this - // obviously adds complexity to the legalization - // pass, but it has important benefits: - // - // - It avoids creating functions with a very large - // number of parameters (when passing a structure - // with many fields), which might confuse downstream - // compilers. - // - // - It avoids applying AOS->SOA conversion to fields - // that don't actually need it, which is basically - // required if we want type layout to work. - // - // - It ensures that we can actually construct a - // constant-buffer type that wraps a legalized - // aggregate type; the ordinary fields will get - // placed inside a new constant-buffer type, - // while the special ones will get left outside. - // - - TupleTypeBuilder builder; - builder.context = context; - builder.type = type; - - - for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) - { - builder.addField(ff); - } - - return builder.getResult(); - } - - // TODO: for other declaration-reference types, we really - // need to legalize the types used in substitutions, and - // signal an error if any of them turn out to be non-simple. - // - // The limited cases of types that can handle having non-simple - // types as generic arguments all need to be special-cased here. - // (For example, we can't handle `Texture2D<SomeStructWithTexturesInIt>`. - // - } - - return LegalType::simple(type); -} - -// Represents the "chain" of declarations that -// were followed to get to a variable that we -// are now declaring as a leaf variable. -struct LegalVarChain -{ - LegalVarChain* next; - VarLayout* varLayout; -}; - static LegalVal declareVars( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, LegalVarChain* varChain); +static LegalType legalizeType( + IRTypeLegalizationContext* context, + Type* type) +{ + return legalizeType(context->typeLegalizationContext, type); +} + // Legalize a type, and then expect it to // result in a simple type. static RefPtr<Type> legalizeSimpleType( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, Type* type) { auto legalType = legalizeType(context, type); @@ -1056,7 +225,7 @@ static RefPtr<Type> legalizeSimpleType( // Take a value that is being used as an operand, // and turn it into the equivalent legalized value. static LegalVal legalizeOperand( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRValue* irValue) { LegalVal legalVal; @@ -1111,7 +280,7 @@ static void getArgumentValues( } static LegalVal legalizeCall( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRCall* callInst) { // TODO: implement legalization of non-simple return types @@ -1130,7 +299,7 @@ static LegalVal legalizeCall( } static LegalVal legalizeLoad( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalVal legalPtrVal) { switch (legalPtrVal.flavor) @@ -1186,7 +355,7 @@ static LegalVal legalizeLoad( } static LegalVal legalizeStore( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalVal legalPtrVal, LegalVal legalVal) { @@ -1241,19 +410,13 @@ static LegalVal legalizeStore( } static LegalVal legalizeFieldAddress( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, LegalType type, LegalVal legalPtrOperand, - LegalVal legalFieldOperand) + DeclRef<Decl> fieldDeclRef) { auto builder = context->builder; - // We don't expect any legalization to affect - // the "field" argument. - auto fieldOperand = legalFieldOperand.getSimple(); - assert(fieldOperand->op == kIROp_decl_ref); - auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; - switch (legalPtrOperand.flavor) { case LegalVal::Flavor::simple: @@ -1261,7 +424,7 @@ static LegalVal legalizeFieldAddress( builder->emitFieldAddress( type.getSimple(), legalPtrOperand.getSimple(), - fieldOperand)); + builder->getDeclRefVal(fieldDeclRef))); case LegalVal::Flavor::pair: { @@ -1286,7 +449,7 @@ static LegalVal legalizeFieldAddress( { auto fieldPairType = type.getPair(); fieldPairInfo = fieldPairType->pairInfo; - ordinaryType = LegalType::simple(fieldPairType->ordinaryType); + ordinaryType = fieldPairType->ordinaryType; specialType = fieldPairType->specialType; } @@ -1295,12 +458,27 @@ static LegalVal legalizeFieldAddress( if (pairElement->flags & PairInfo::kFlag_hasOrdinary) { - ordinaryVal = legalizeFieldAddress(context, ordinaryType, pairVal->ordinaryVal, legalFieldOperand); + // Note: the ordinary side of the pair is expected + // to be a filtered `struct` type, and so it will + // have different field declarations than the + // oridinal type. The element of the `PairInfo` + // structure stores the correct field decl-ref to use + // as `ordinaryFieldDeclRef`. + + ordinaryVal = legalizeFieldAddress( + context, + ordinaryType, + pairVal->ordinaryVal, + pairElement->ordinaryFieldDeclRef); } if (pairElement->flags & PairInfo::kFlag_hasSpecial) { - specialVal = legalizeFieldAddress(context, specialType, pairVal->specialVal, legalFieldOperand); + specialVal = legalizeFieldAddress( + context, + specialType, + pairVal->specialVal, + fieldDeclRef); } return LegalVal::pair(ordinaryVal, specialVal, fieldPairInfo); } @@ -1335,8 +513,27 @@ static LegalVal legalizeFieldAddress( } } +static LegalVal legalizeFieldAddress( + IRTypeLegalizationContext* context, + LegalType type, + LegalVal legalPtrOperand, + LegalVal legalFieldOperand) +{ + // We don't expect any legalization to affect + // the "field" argument. + auto fieldOperand = legalFieldOperand.getSimple(); + assert(fieldOperand->op == kIROp_decl_ref); + auto fieldDeclRef = ((IRDeclRef*)fieldOperand)->declRef; + + return legalizeFieldAddress( + context, + type, + legalPtrOperand, + fieldDeclRef); +} + static LegalVal legalizeInst( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRInst* inst, LegalType type, LegalVal const* args) @@ -1370,7 +567,7 @@ RefPtr<VarLayout> findVarLayout(IRValue* value) } static LegalVal legalizeLocalVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRVar* irLocalVar) { // Legalize the type for the variable's value @@ -1424,7 +621,7 @@ static LegalVal legalizeLocalVar( } static LegalVal legalizeInst( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRInst* inst) { if (inst->op == kIROp_Var) @@ -1512,7 +709,7 @@ static void addParamType(IRFuncType * ftype, LegalType t) case LegalType::Flavor::pair: { auto pairInfo = t.getPair(); - addParamType(ftype, LegalType::simple(pairInfo->ordinaryType)); + addParamType(ftype, pairInfo->ordinaryType); addParamType(ftype, pairInfo->specialType); } break; @@ -1529,7 +726,7 @@ static void addParamType(IRFuncType * ftype, LegalType t) } static void legalizeFunc( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRFunc* irFunc) { // Overwrite the function's type with @@ -1606,61 +803,13 @@ static void legalizeFunc( } static LegalVal declareSimpleVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, Type* type, TypeLayout* typeLayout, LegalVarChain* varChain) { - RefPtr<VarLayout> varLayout; - if (typeLayout) - { - // We need to construct a layout for the new variable - // that reflects both the type we have given it, as - // well as all the offset information that has accumulated - // along the chain of parent variables. - - // TODO: this logic needs to propagate through semantics... - - varLayout = new VarLayout(); - varLayout->typeLayout = typeLayout; - - for (auto rr : typeLayout->resourceInfos) - { - auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); - - for (auto vv = varChain; vv; vv = vv->next) - { - if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) - { - resInfo->index += parentResInfo->index; - resInfo->space += parentResInfo->space; - } - } - } - - // Some of the parent variables might actually contain offsets - // to the `space` or `set` of the field, and we need to apply - // those to all the nested resource infos. - for (auto vv = varChain; vv; vv = vv->next) - { - auto parentSpaceInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace); - if (!parentSpaceInfo) - continue; - - for (auto& rr : varLayout->resourceInfos) - { - if (rr.kind == LayoutResourceKind::RegisterSpace) - { - rr.index += parentSpaceInfo->index; - } - else - { - rr.space += parentSpaceInfo->index; - } - } - } - } + RefPtr<VarLayout> varLayout = createVarLayout(varChain, typeLayout); DeclRef<VarDeclBase> varDeclRef; if (varChain) @@ -1733,39 +882,8 @@ static LegalVal declareSimpleVar( return legalVarVal; } -static RefPtr<TypeLayout> getDerefTypeLayout( - TypeLayout* typeLayout) -{ - if (!typeLayout) - return nullptr; - - if (auto parameterGroupTypeLayout = dynamic_cast<ParameterGroupTypeLayout*>(typeLayout)) - { - return parameterGroupTypeLayout->elementTypeLayout; - } - - return typeLayout; -} - -static RefPtr<VarLayout> getFieldLayout( - TypeLayout* typeLayout, - DeclRef<VarDeclBase> fieldDeclRef) -{ - if (!typeLayout) - return nullptr; - - if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) - { - RefPtr<VarLayout> fieldLayout; - if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) - return fieldLayout; - } - - return nullptr; -} - static LegalVal declareVars( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IROp op, LegalType type, TypeLayout* typeLayout, @@ -1798,7 +916,7 @@ static LegalVal declareVars( case LegalType::Flavor::pair: { auto pairType = type.getPair(); - auto ordinaryVal = declareVars(context, op, LegalType::simple(pairType->ordinaryType), typeLayout, varChain); + auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain); auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain); return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo); } @@ -1852,7 +970,7 @@ static LegalVal declareVars( } static void legalizeGlobalVar( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRGlobalVar* irGlobalVar) { // Legalize the type for the variable's value @@ -1899,7 +1017,7 @@ static void legalizeGlobalVar( } static void legalizeGlobalValue( - TypeLegalizationContext* context, + IRTypeLegalizationContext* context, IRGlobalValue* irValue) { switch (irValue->op) @@ -1923,7 +1041,7 @@ static void legalizeGlobalValue( } static void legalizeTypes( - TypeLegalizationContext* context) + IRTypeLegalizationContext* context) { auto module = context->module; for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) @@ -1934,7 +1052,8 @@ static void legalizeTypes( void legalizeTypes( - IRModule* module) + TypeLegalizationContext* typeLegalizationContext, + IRModule* module) { auto session = module->session; @@ -1950,13 +1069,15 @@ void legalizeTypes( builder->sharedBuilder = sharedBuilder; - TypeLegalizationContext contextStorage; + IRTypeLegalizationContext contextStorage; auto context = &contextStorage; context->session = session; context->module = module; context->builder = builder; + context->typeLegalizationContext = typeLegalizationContext; + legalizeTypes(context); } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 3a8aabd85..d35ffe2d1 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3764,6 +3764,14 @@ namespace Slang String const& mangledName, IRGlobalValue* originalVal) { + // Check if we've already cloned this value, for the case where + // an original value has already been established. + IRValue* clonedVal = nullptr; + if( originalVal && context->getClonedValues().TryGetValue(originalVal, clonedVal) ) + { + return (IRGlobalValue*) clonedVal; + } + if(mangledName.Length() == 0) { // If there is no mangled name, then we assume this is a local symbol, @@ -3779,6 +3787,9 @@ namespace Slang RefPtr<IRSpecSymbol> sym; if( !context->getSymbols().TryGetValue(mangledName, sym) ) { + if(!originalVal) + return nullptr; + // This shouldn't happen! SLANG_UNEXPECTED("no matching values registered"); UNREACHABLE_RETURN(cloneGlobalValueImpl(context, originalVal, nullptr)); @@ -3798,6 +3809,14 @@ namespace Slang bestVal = newVal; } + // Check if we've already cloned this value, for the case where + // we didn't have an original value (just a name), but we've + // now found a representative value. + if( !originalVal && context->getClonedValues().TryGetValue(bestVal, clonedVal) ) + { + return (IRGlobalValue*) clonedVal; + } + return cloneGlobalValueImpl(context, bestVal, sym); } diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp new file mode 100644 index 000000000..510b9acd3 --- /dev/null +++ b/source/slang/legalize-types.cpp @@ -0,0 +1,1086 @@ +// legalize-types.cpp +#include "legalize-types.h" + +namespace Slang +{ + +LegalType LegalType::implicitDeref( + LegalType const& valueType) +{ + RefPtr<ImplicitDerefType> obj = new ImplicitDerefType(); + obj->valueType = valueType; + + LegalType result; + result.flavor = Flavor::implicitDeref; + result.obj = obj; + return result; +} + +LegalType LegalType::tuple( + RefPtr<TuplePseudoType> tupleType) +{ + LegalType result; + result.flavor = Flavor::tuple; + result.obj = tupleType; + return result; +} + +LegalType LegalType::pair( + RefPtr<PairPseudoType> pairType) +{ + LegalType result; + result.flavor = Flavor::pair; + result.obj = pairType; + return result; +} + +LegalType LegalType::pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo) +{ + // Handle some special cases for when + // one or the other of the types isn't + // actually used. + + if (ordinaryType.flavor == LegalType::Flavor::none) + { + // There was nothing ordinary. + return specialType; + } + + if (specialType.flavor == LegalType::Flavor::none) + { + return ordinaryType; + } + + // There were both ordinary and special fields, + // and so we need to handle them here. + + RefPtr<PairPseudoType> obj = new PairPseudoType(); + obj->ordinaryType = ordinaryType; + obj->specialType = specialType; + obj->pairInfo = pairInfo; + return LegalType::pair(obj); +} + +// + +static bool isResourceType(Type* type) +{ + while (auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->baseType; + } + + if (auto resourceTypeBase = type->As<ResourceTypeBase>()) + { + return true; + } + else if (auto builtinGenericType = type->As<BuiltinGenericType>()) + { + return true; + } + else if (auto pointerLikeType = type->As<PointerLikeType>()) + { + return true; + } + else if (auto samplerType = type->As<SamplerStateType>()) + { + return true; + } + + // TODO: need more comprehensive coverage here + + return false; +} + +ModuleDecl* findModuleForDecl( + Decl* decl) +{ + for (auto dd = decl; dd; dd = dd->ParentDecl) + { + if (auto moduleDecl = dynamic_cast<ModuleDecl*>(dd)) + return moduleDecl; + } + return nullptr; +} + + +// Helper type for legalization of aggregate types +// that might need to be turned into tuple pseudo-types. +struct TupleTypeBuilder +{ + TypeLegalizationContext* context; + RefPtr<Type> type; + DeclRef<AggTypeDecl> typeDeclRef; + + struct OrdinaryElement + { + DeclRef<VarDeclBase> fieldDeclRef; + RefPtr<Type> type; + }; + + + List<OrdinaryElement> ordinaryElements; + List<TuplePseudoType::Element> specialElements; + + List<PairInfo::Element> pairElements; + + // Did we have any fields that forced us to change + // the actual type away from the declared type? + bool anyComplex = false; + + // Did we have any fields that actually required + // storage in the "special" part of things? + bool anySpecial = false; + + // Did we have any fields that actually used ordinary storage? + bool anyOrdinary = false; + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef, + LegalType legalFieldType, + LegalType legalLeafType, + bool isResource) + { + LegalType ordinaryType; + LegalType specialType; + RefPtr<PairInfo> elementPairInfo; + switch (legalLeafType.flavor) + { + case LegalType::Flavor::simple: + { + // We need to add an actual field, but we need + // to check if it is a resource type to know + // whether it should go in the "ordinary" list or not. + if (!isResource) + { + ordinaryType = legalLeafType; + } + else + { + specialType = legalFieldType; + } + } + break; + + case LegalType::Flavor::implicitDeref: + { + // TODO: we may want to say that any use + // of `implicitDeref` puts the entire thing + // into the "special" category, rather than + // try to look under the hood... + + anyComplex = true; + + // We want to recursively add data + // based on the unwrapped type. + // + // Note: this assumes we can't have a tuple + // or a pair "under" an `implicitDeref`, so + // we'll need to ensure that elsewhere. + addField( + fieldDeclRef, + legalFieldType, + legalLeafType.getImplicitDeref()->valueType, + isResource); + return; + } + break; + + case LegalType::Flavor::pair: + { + // The field's type had both special and non-special parts + auto pairType = legalLeafType.getPair(); + ordinaryType = pairType->ordinaryType; + specialType = pairType->specialType; + elementPairInfo = pairType->pairInfo; + } + break; + + case LegalType::Flavor::tuple: + { + // A tuple always represents "special" data + specialType = legalFieldType; + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + break; + } + + + PairInfo::Element pairElement; + pairElement.flags = 0; + pairElement.fieldDeclRef = fieldDeclRef; + pairElement.fieldPairInfo = elementPairInfo; + + // We will always add a field to the "ordinary" + // side of things, even if it has no ordinary + // data, just to keep the list of fields aligned + // with the original type. + OrdinaryElement ordinaryElement; + ordinaryElement.fieldDeclRef = fieldDeclRef; + if (ordinaryType.flavor != LegalType::Flavor::none) + { + anyOrdinary = true; + pairElement.flags |= PairInfo::kFlag_hasOrdinary; + + LegalType ot = ordinaryType; + + // TODO: any cases we should "unwrap" here? + // E.g., `implicitDeref`? + + if(ot.flavor == LegalType::Flavor::simple) + { + ordinaryElement.type = ot.getSimple(); + } + else + { + SLANG_UNEXPECTED("unexpected ordinary field type"); + } + } + ordinaryElements.Add(ordinaryElement); + + if (specialType.flavor != LegalType::Flavor::none) + { + anySpecial = true; + anyComplex = true; + pairElement.flags |= PairInfo::kFlag_hasSpecial; + + TuplePseudoType::Element specialElement; + specialElement.fieldDeclRef = fieldDeclRef; + specialElement.type = specialType; + specialElements.Add(specialElement); + } + + pairElement.type = LegalType::pair(ordinaryType, specialType, elementPairInfo); + pairElements.Add(pairElement); + } + + // Add a field to the (pseudo-)type we are building + void addField( + DeclRef<VarDeclBase> fieldDeclRef) + { + // Skip `static` fields. + if (fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + return; + + auto fieldType = GetType(fieldDeclRef); + + bool isResourceField = isResourceType(fieldType); + + auto legalFieldType = legalizeType(context, fieldType); + addField( + fieldDeclRef, + legalFieldType, + legalFieldType, + isResourceField); + } + + LegalType getResult() + { + // If we didn't see anything "special" + // then we can use the type as-is. + // we can conceivably just use the type as-is + // + // TODO: this might be a good place to turn + // a reference to a generic `struct` type into + // a concrete non-generic type so that downstream + // codegen doesn't have to deal with generics... + // + // TODO: In fact, why not just fully replace + // all aggregate types here with some structural + // types defined in the IR? + if (!anyComplex) + { + return LegalType::simple(type); + } + + // If there were any "ordinary" fields along the way, + // then we need to collect them into a new `struct` type + // that represents these fields. + // + LegalType ordinaryType; + if (anyOrdinary) + { + // We are going to create a new `struct` type declaration that clones + // the fields we care about from the original `struct` type. Note that + // these fields may have different types from what they did before, + // because the fields themselves might have been legalized. + // + // Our new declaration will have the same name as the old one, so + // downstream code is going to need to be careful not to emit declarations + // for both of them. This should be okay, though, because the original + // type was illegal (that was the whole point) and so it shouldn't be + // allowed in the output anyway. + RefPtr<StructDecl> ordinaryStructDecl = new StructDecl(); + ordinaryStructDecl->loc = typeDeclRef.getDecl()->loc; + ordinaryStructDecl->nameAndLoc = typeDeclRef.getDecl()->nameAndLoc; + + addModifier(ordinaryStructDecl, new LegalizedModifier()); + + // We will do something a bit unsavory here, by setting the logical + // parent of the new `struct` type to be the same as the orignal type + // (All of this helps ensure it gets the same mangled name). + // + ordinaryStructDecl->ParentDecl = typeDeclRef.getDecl()->ParentDecl; + + if (context->mainModuleDecl) + { + // If the declaration we are lowering belongs to the AST-based + // module being lowered (rather than translated to IR), then we + // need to add any new declaration we create to that output. + + // If we are *not* outputting an IR module as well, then + // everything needs to wind up in a single AST module. + if (!context->irModule) + { + context->outputModuleDecl->Members.Add(ordinaryStructDecl); + } + else + { + // Otherwise, check if this declaration belongs to the main + // module (which is being lowered via the AST-to-AST pass), + // and add it to the output if needed. + // + // TODO: This won't work correctly if a type from the AST + // module is used to specialize a generic in the IR module, + // since the declaration would need to precede the specialized + // func... + auto parentModule = findModuleForDecl(typeDeclRef.getDecl()); + if (parentModule && (parentModule == context->mainModuleDecl)) + { + context->outputModuleDecl->Members.Add(ordinaryStructDecl); + } + } + } + + // For memory management reasons, we need to keep a reference to + // the declaration live, no matter what. + context->createdDecls.Add(ordinaryStructDecl); + + UInt elementCounter = 0; + for(auto ee : ordinaryElements) + { + UInt elementIndex = elementCounter++; + + // We will ensure that all the original fields are represented, + // although they may have different types (due to legalization). + // For fields that have *no* ordinary data, we will give them + // a dummy `void` type and rely on downstream passes to not + // actually emit declarations for those fields. + // + // (This helps keeps things simple because both the original + // and modified type will have the same number of fields, so + // we can continue to look up field layouts by index in the + // emit logic) + RefPtr<Type> fieldType = ee.type; + if(!fieldType) + fieldType = context->session->getVoidType(); + + // TODO: shallow clone of modifiers, etc. + + RefPtr<StructField> fieldDecl = new StructField(); + fieldDecl->loc = ee.fieldDeclRef.getDecl()->loc; + fieldDecl->nameAndLoc = ee.fieldDeclRef.getDecl()->nameAndLoc; + fieldDecl->type.type = fieldType; + + fieldDecl->ParentDecl = ordinaryStructDecl; + ordinaryStructDecl->Members.Add(fieldDecl); + + pairElements[elementIndex].ordinaryFieldDeclRef = makeDeclRef(fieldDecl.Ptr()); + + addModifier(fieldDecl, new LegalizedModifier()); + } + + RefPtr<Type> ordinaryStructType = DeclRefType::Create( + context->session, + makeDeclRef(ordinaryStructDecl.Ptr())); + + ordinaryType = LegalType::simple(ordinaryStructType); + } + + LegalType specialType; + if (anySpecial) + { + RefPtr<TuplePseudoType> specialTuple = new TuplePseudoType(); + specialTuple->elements = specialElements; + specialType = LegalType::tuple(specialTuple); + } + + RefPtr<PairInfo> pairInfo; + if (anyOrdinary && anySpecial) + { + pairInfo = new PairInfo(); + pairInfo->elements = pairElements; + } + + return LegalType::pair(ordinaryType, specialType, pairInfo); + } + +}; + +static RefPtr<Type> createBuiltinGenericType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + RefPtr<Type> elementType) +{ + // We are going to take the type for the original + // decl-ref and construct a new one that uses + // our new element type as its parameter. + // + // TODO: we should have library code to make + // manipulations like this way easier. + + RefPtr<GenericSubstitution> oldGenericSubst = getGenericSubstitution( + typeDeclRef.substitutions); + SLANG_ASSERT(oldGenericSubst); + + RefPtr<GenericSubstitution> newGenericSubst = new GenericSubstitution(); + + newGenericSubst->outer = oldGenericSubst->outer; + newGenericSubst->genericDecl = oldGenericSubst->genericDecl; + newGenericSubst->args = oldGenericSubst->args; + newGenericSubst->args[0] = elementType; + + auto newDeclRef = DeclRef<Decl>( + typeDeclRef.getDecl(), + newGenericSubst); + + auto newType = DeclRefType::Create( + context->session, + newDeclRef); + + return newType; +} + +// Create a uniform buffer type with a given legalized +// element type. +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalElementType) +{ + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalElementType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // This is actually an annoying case, because + // we are being asked to convert, e.g.,: + // + // cbuffer Foo { ParameterBlock<Bar> bar; } + // + // into the equivalent of: + // + // cbuffer Foo { Bar bar; } + // + // Which would really require a new `LegalType` that + // would reprerent a resource type with a modified + // element type. + // + // I'm going to attempt to hack this for now. + return LegalType::implicitDeref(createLegalUniformBufferType( + context, + typeDeclRef, + legalElementType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We assume that the "ordinary" part of things + // will get wrapped in a constant-buffer type, + // and the "special" part needs to be wrapped + // with an `implicitDeref`. + auto pairType = legalElementType.getPair(); + + auto ordinaryType = createLegalUniformBufferType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = LegalType::implicitDeref(pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // if we have a tuple type, then it must be representing + // the fields that can't be stored in a buffer anyway, + // so we just need to wrap each of them in an `implicitDeref` + + auto elementPseudoTupleType = legalElementType.getTuple(); + + RefPtr<TuplePseudoType> bufferPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : elementPseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = LegalType::implicitDeref(ee.type); + + bufferPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(bufferPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +static LegalType createLegalUniformBufferType( + TypeLegalizationContext* context, + UniformParameterGroupType* uniformBufferType, + LegalType legalElementType) +{ + return createLegalUniformBufferType( + context, + uniformBufferType->declRef, + legalElementType); +} + +// Create a pointer type with a given legalized value type. +static LegalType createLegalPtrType( + TypeLegalizationContext* context, + DeclRef<Decl> const& typeDeclRef, + LegalType legalValueType) +{ + switch (legalValueType.flavor) + { + case LegalType::Flavor::simple: + { + // Easy case: we just have a simple element type, + // so we want to create a uniform buffer that wraps it. + return LegalType::simple(createBuiltinGenericType( + context, + typeDeclRef, + legalValueType.getSimple())); + } + break; + + case LegalType::Flavor::implicitDeref: + { + // We are being asked to create a pointer type to something + // that is implicitly dereferenced, meaning we had: + // + // Ptr(PtrLink(T)) + // + // and now are being asked to make: + // + // Ptr(implicitDeref(LegalT)) + // + // So it seems like we can just create: + // + // implicitDeref(Ptr(LegalT)) + // + // and nobody should really be able to tell the difference, right? + return LegalType::implicitDeref(createLegalPtrType( + context, + typeDeclRef, + legalValueType.getImplicitDeref()->valueType)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalValueType.getPair(); + + auto ordinaryType = createLegalPtrType( + context, + typeDeclRef, + pairType->ordinaryType); + auto specialType = createLegalPtrType( + context, + typeDeclRef, + pairType->specialType); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto valuePseudoTupleType = legalValueType.getTuple(); + + RefPtr<TuplePseudoType> ptrPseudoTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : valuePseudoTupleType->elements) + { + TuplePseudoType::Element newElement; + + newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.type = createLegalPtrType( + context, + typeDeclRef, + ee.type); + + ptrPseudoTupleType->elements.Add(newElement); + } + + return LegalType::tuple(ptrPseudoTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + +struct LegalTypeWrapper +{ + virtual LegalType wrap(TypeLegalizationContext* context, Type* type) = 0; +}; + +struct ArrayLegalTypeWrapper : LegalTypeWrapper +{ + ArrayExpressionType* arrayType; + + LegalType wrap(TypeLegalizationContext* context, Type* type) + { + return LegalType::simple(context->session->getArrayType( + type, + arrayType->ArrayLength)); + } +}; + +struct BuiltinGenericLegalTypeWrapper : LegalTypeWrapper +{ + DeclRef<Decl> declRef; + + LegalType wrap(TypeLegalizationContext* context, Type* type) + { + return LegalType::simple(createBuiltinGenericType( + context, + declRef, + type)); + } +}; + + +struct ImplicitDerefLegalTypeWrapper : LegalTypeWrapper +{ + LegalType wrap(TypeLegalizationContext*, Type* type) + { + return LegalType::implicitDeref(LegalType::simple(type)); + } +}; + +static LegalType wrapLegalType( + TypeLegalizationContext* context, + LegalType legalType, + LegalTypeWrapper* ordinaryWrapper, + LegalTypeWrapper* specialWrapper) +{ + switch (legalType.flavor) + { + case LegalType::Flavor::simple: + { + return ordinaryWrapper->wrap(context, legalType.getSimple()); + } + break; + + case LegalType::Flavor::implicitDeref: + { + return LegalType::implicitDeref(wrapLegalType( + context, + legalType, + ordinaryWrapper, + specialWrapper)); + } + break; + + case LegalType::Flavor::pair: + { + // We just need to pointer-ify both sides of the pair. + auto pairType = legalType.getPair(); + + auto ordinaryType = wrapLegalType( + context, + pairType->ordinaryType, + ordinaryWrapper, + ordinaryWrapper); + auto specialType = wrapLegalType( + context, + pairType->specialType, + specialWrapper, + specialWrapper); + + return LegalType::pair(ordinaryType, specialType, pairType->pairInfo); + } + + case LegalType::Flavor::tuple: + { + // Wrap each of the tuple elements up as a pointer. + auto tupleType = legalType.getTuple(); + + RefPtr<TuplePseudoType> resultTupleType = new TuplePseudoType(); + + // Wrap all the pseudo-tuple elements with `implicitDeref`, + // since they used to be inside a tuple, but aren't any more. + for (auto ee : tupleType->elements) + { + TuplePseudoType::Element element; + + element.fieldDeclRef = ee.fieldDeclRef; + element.type = wrapLegalType( + context, + ee.type, + ordinaryWrapper, + specialWrapper); + + resultTupleType->elements.Add(element); + } + + return LegalType::tuple(resultTupleType); + } + break; + + default: + SLANG_UNEXPECTED("unknown legal type flavor"); + UNREACHABLE_RETURN(LegalType()); + break; + } +} + + +// Legalize a type, including any nested types +// that it transitively contains. +LegalType legalizeType( + TypeLegalizationContext* context, + Type* type) +{ + if (auto parameterBlockType = type->As<ParameterBlockType>()) + { + // We basically legalize the `ParameterBlock<T>` type + // over to `T`. In order to represent this preoperly, + // we need to be careful to wrap it up in a way that + // tells us to eliminate downstream deferences... + + auto legalElementType = legalizeType(context, + parameterBlockType->getElementType()); + return LegalType::implicitDeref(legalElementType); + } + else if (auto uniformBufferType = type->As<UniformParameterGroupType>()) + { + // We have a `ConstantBuffer<T>` or `TextureBuffer<T>` or + // other pointer-like type that represents uniform parameters. + // We need to pull any resource-type fields out of it, but + // leave the non-resource fields where they are. + + // Legalize the element type to see what we are working with. + auto legalElementType = legalizeType(context, + uniformBufferType->getElementType()); + + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + return LegalType::simple(type); + + default: + return createLegalUniformBufferType( + context, + uniformBufferType, + legalElementType); + } + + } + else if (isResourceType(type)) + { + // We assume that any resource types not handled above + // are legal as-is. + return LegalType::simple(type); + } + else if (type->As<BasicExpressionType>()) + { + return LegalType::simple(type); + } + else if (type->As<VectorExpressionType>()) + { + return LegalType::simple(type); + } + else if (type->As<MatrixExpressionType>()) + { + return LegalType::simple(type); + } + else if (auto ptrType = type->As<PtrTypeBase>()) + { + auto legalValueType = legalizeType(context, ptrType->getValueType()); + return createLegalPtrType(context, ptrType->declRef, legalValueType); + } + else if (auto declRefType = type->As<DeclRefType>()) + { + auto declRef = declRefType->declRef; + + LegalType legalType; + if(context->mapDeclRefToLegalType.TryGetValue(declRef, legalType)) + return legalType; + + + if (auto aggTypeDeclRef = declRef.As<AggTypeDecl>()) + { + // Look at the (non-static) fields, and + // see if anything needs to be cleaned up. + // The things that need to be "cleaned up" for + // our purposes are: + // + // - Fields of resource type, or any other future + // type we run into that isn't allowed in + // aggregates for at least some targets + // + // - Fields with types that themselves had to + // get legalized. + // + // If we don't run into any of these, we + // can just use the type as-is. Hooray! + // + // Otherwise, we are effectively going to split + // the type apart and create a `TuplePseudoType`. + // Every field of the original type will be + // represented as an element of this pseudo-type. + // Each element will record its `LegalType`, + // and the original field that it was created from. + // An element will also track whether it contains + // any "ordinary" data, and if so, it will remember + // an element index in a real (AST-level, non-pseudo) + // `TupleType` that is used to bundle together + // such fields. + // + // Storing all the simple fields together like this + // obviously adds complexity to the legalization + // pass, but it has important benefits: + // + // - It avoids creating functions with a very large + // number of parameters (when passing a structure + // with many fields), which might confuse downstream + // compilers. + // + // - It avoids applying AOS->SOA conversion to fields + // that don't actually need it, which is basically + // required if we want type layout to work. + // + // - It ensures that we can actually construct a + // constant-buffer type that wraps a legalized + // aggregate type; the ordinary fields will get + // placed inside a new constant-buffer type, + // while the special ones will get left outside. + // + + TupleTypeBuilder builder; + builder.context = context; + builder.type = type; + builder.typeDeclRef = aggTypeDeclRef; + + + for (auto ff : getMembersOfType<StructField>(aggTypeDeclRef)) + { + builder.addField(ff); + } + + legalType = builder.getResult(); + context->mapDeclRefToLegalType.Add(declRef, legalType); + return legalType; + } + + // TODO: for other declaration-reference types, we really + // need to legalize the types used in substitutions, and + // signal an error if any of them turn out to be non-simple. + // + // The limited cases of types that can handle having non-simple + // types as generic arguments all need to be special-cased here. + // (For example, we can't handle `Texture2D<SomeStructWithTexturesInIt>`. + // + } + else if(auto arrayType = type->As<ArrayExpressionType>()) + { + auto legalElementType = legalizeType( + context, + arrayType->baseType); + + switch (legalElementType.flavor) + { + case LegalType::Flavor::simple: + // Element type didn't need to be legalized, so + // we can just use this type as-is. + return LegalType::simple(type); + + default: + { + ArrayLegalTypeWrapper wrapper; + wrapper.arrayType = arrayType; + + return wrapLegalType( + context, + legalElementType, + &wrapper, + &wrapper); + } + break; + } + + } + + return LegalType::simple(type); +} + +// + +RefPtr<TypeLayout> getDerefTypeLayout( + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + if (auto parameterGroupTypeLayout = dynamic_cast<ParameterGroupTypeLayout*>(typeLayout)) + { + return parameterGroupTypeLayout->elementTypeLayout; + } + + return typeLayout; +} + +RefPtr<VarLayout> getFieldLayout( + TypeLayout* typeLayout, + DeclRef<VarDeclBase> fieldDeclRef) +{ + if (!typeLayout) + return nullptr; + + while(auto arrayTypeLayout = dynamic_cast<ArrayTypeLayout*>(typeLayout)) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } + + if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) + { + RefPtr<VarLayout> fieldLayout; + if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) + return fieldLayout; + } + + return nullptr; +} + +RefPtr<VarLayout> createVarLayout( + LegalVarChain* varChain, + TypeLayout* typeLayout) +{ + if (!typeLayout) + return nullptr; + + // We need to construct a layout for the new variable + // that reflects both the type we have given it, as + // well as all the offset information that has accumulated + // along the chain of parent variables. + + // TODO: this logic needs to propagate through semantics... + + RefPtr<VarLayout> varLayout = new VarLayout(); + varLayout->typeLayout = typeLayout; + + for (auto rr : typeLayout->resourceInfos) + { + auto resInfo = varLayout->findOrAddResourceInfo(rr.kind); + + for (auto vv = varChain; vv; vv = vv->next) + { + if (auto parentResInfo = vv->varLayout->FindResourceInfo(rr.kind)) + { + resInfo->index += parentResInfo->index; + resInfo->space += parentResInfo->space; + } + } + } + + // Some of the parent variables might actually contain offsets + // to the `space` or `set` of the field, and we need to apply + // those to all the nested resource infos. + for (auto vv = varChain; vv; vv = vv->next) + { + auto parentSpaceInfo = vv->varLayout->FindResourceInfo(LayoutResourceKind::RegisterSpace); + if (!parentSpaceInfo) + continue; + + for (auto& rr : varLayout->resourceInfos) + { + if (rr.kind == LayoutResourceKind::RegisterSpace) + { + rr.index += parentSpaceInfo->index; + } + else + { + rr.space += parentSpaceInfo->index; + } + } + } + + return varLayout; +} + +// + +// TODO(tfoley): The code captured here is the logic that used to be +// applied to decide whether or not to desugar aggregate types that +// contain resources. Right now the implementation will *always* legalize +// away such types (since the IR always does this), while the AST-to-AST +// pass would only do it if required (according to the tests below). +// +// For right now this is an academic distinction, since the only project +// using Slang right now enables this tansformation unconditionally, but +// we probably need to re-parent this code back into the `TypeLegalizationContext` +// somewhere. +#if 0 + +bool shouldDesugarTupleTypes = false; +if (getTarget() == CodeGenTarget::GLSL) +{ + // Always desugar this stuff for GLSL, since it doesn't + // support nesting of resources in structs. + // + // TODO: Need a way to make this more fine-grained to + // handle cases where a nested member might be allowed + // due to, e.g., bindless textures. + shouldDesugarTupleTypes = true; +} +else if( shared->compileRequest->compileFlags & SLANG_COMPILE_FLAG_SPLIT_MIXED_TYPES ) +{ + // If the user is directly asking us to do this transformation, + // then obviously we need to do it. + // + // TODO: The way this is defined here means it will even apply to user + // HLSL code (not just code written in Slang). We may want to + // reconsider that choice, and only split things that originated in Slang. + // + shouldDesugarTupleTypes = true; +} + +#endif + +} diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h new file mode 100644 index 000000000..36e4223b6 --- /dev/null +++ b/source/slang/legalize-types.h @@ -0,0 +1,296 @@ +// legalize-types.h +#ifndef SLANG_LEGALIZE_TYPES_H_INCLUDED +#define SLANG_LEGALIZE_TYPES_H_INCLUDED + +// This file and `legalize-types.cpp` implement the core +// logic for taking a `Type` as produced by the front-end, +// and turning it into a suitable representation for use +// on a particular back-end. +// +// The main work applies to aggregate (e.g., `struct`) types, +// since various targets have rules about what is and isn't +// allowed in an aggregate (or where aggregates are allowed +// to be used). +// +// We might completely replace an aggregate `Type` with a +// "pseudo-type" that is just the enumeration of its field +// types (sort of a tuple type) so that a variable declared +// with the original type should be transformed into a +// bunch of individual variables. +// +// Alternatively, we might replace an aggregate type, where +// only *some* of the fields are illegal with a combination +// of an aggregate (containing the legal/legalized fields), +// and some extra tuple-ified fields. + +#include "../core/basic.h" +#include "syntax.h" +#include "type-layout.h" + +namespace Slang +{ + +struct LegalTypeImpl : RefObject +{ +}; +struct ImplicitDerefType; +struct TuplePseudoType; +struct PairPseudoType; +struct PairInfo; + +struct LegalType +{ + enum class Flavor + { + // Nothing: a NULL type + none, + + // A simple type that can be represented directly as a `Type` + simple, + + // Logically, we have a pointer-like type, but we are + // going to represnet it as the pointed-to type + implicitDeref, + + // A compound type was broken apart into its constituent fields, + // so a tuple "pseduo-type" is being used to collect + // those fields together. + tuple, + + // A type has to get split into "ordinary" and "special" parts, + // each of which will be represented with its own `LegalType`. + pair, + }; + + Flavor flavor = Flavor::none; + RefPtr<RefObject> obj; + + static LegalType simple(Type* type) + { + LegalType result; + result.flavor = Flavor::simple; + result.obj = type; + return result; + } + + RefPtr<Type> getSimple() const + { + assert(flavor == Flavor::simple); + return obj.As<Type>(); + } + + static LegalType implicitDeref( + LegalType const& valueType); + + RefPtr<ImplicitDerefType> getImplicitDeref() const + { + assert(flavor == Flavor::implicitDeref); + return obj.As<ImplicitDerefType>(); + } + + static LegalType tuple( + RefPtr<TuplePseudoType> tupleType); + + RefPtr<TuplePseudoType> getTuple() const + { + assert(flavor == Flavor::tuple); + return obj.As<TuplePseudoType>(); + } + + static LegalType pair( + RefPtr<PairPseudoType> pairType); + + static LegalType pair( + LegalType const& ordinaryType, + LegalType const& specialType, + RefPtr<PairInfo> pairInfo); + + RefPtr<PairPseudoType> getPair() const + { + assert(flavor == Flavor::pair); + return obj.As<PairPseudoType>(); + } +}; + +// Represents the pseudo-type of a type that is pointer-like +// (and thus requires dereferencing, even if implicit), but +// was legalized to just use the type of the pointed-type value. +struct ImplicitDerefType : LegalTypeImpl +{ + LegalType valueType; +}; + +// Represents the pseudo-type for a compound type +// that had to be broken apart because it contained +// one or more fields of types that shouldn't be +// allowed in aggregates. +// +// A tuple pseduo-type will have an element for +// each field of the original type, that represents +// the legalization of that field's type. +// +// It optionally also contains an "ordinary" type +// that packs together any per-field data that +// itself has (or contains) an ordinary type. +struct TuplePseudoType : LegalTypeImpl +{ + // Represents one element of the tuple pseudo-type + struct Element + { + // The field that this element replaces + DeclRef<VarDeclBase> fieldDeclRef; + + // The legalized type of the element + LegalType type; + }; + + // All of the elements of the tuple pseduo-type. + List<Element> elements; +}; + +struct PairInfo : RefObject +{ + typedef unsigned int Flags; + enum + { + kFlag_hasOrdinary = 0x1, + kFlag_hasSpecial = 0x2, + kFlag_hasOrdinaryAndSpecial = kFlag_hasOrdinary | kFlag_hasSpecial, + }; + + struct Element + { + // The original field the element represents + DeclRef<Decl> fieldDeclRef; + + // The conceptual type of the field. + // If both the `hasOrdinary` and + // `hasSpecial` bits are set, then + // this is expected to be a + // `LegalType::Flavor::pair` + LegalType type; + + // Is the value represented on + // the ordinary side, the special + // side, or both? + Flags flags; + + // If the type of this element is + // itself a pair type (that is, + // it both `hasOrdinary` and `hasSpecial`) + // then this is the `PairInfo` for that + // pair type: + RefPtr<PairInfo> fieldPairInfo; + + // The actual field decl-ref that needs + // to be used for looking up this element + // in the ordinary type. + DeclRef<Decl> ordinaryFieldDeclRef; + }; + + // For a pair type or value, we need to track + // which fields are on which side(s). + List<Element> elements; + + Element* findElement(DeclRef<Decl> const& fieldDeclRef) + { + for (auto& ee : elements) + { + if(ee.fieldDeclRef.Equals(fieldDeclRef)) + return ⅇ + } + return nullptr; + } +}; + +struct PairPseudoType : LegalTypeImpl +{ + // Any field(s) with ordinary types will + // get captured here (usually as a `fieldRemap` + // type) + LegalType ordinaryType; + + // Any fields with "special" (not ordinary) + // types will get captured here (usually + // with a tuple). + LegalType specialType; + + // The `pairInfo` field helps to tell us which members + // of the original aggregate type appear on which side(s) + // of the new pair type. + RefPtr<PairInfo> pairInfo; +}; + +// + +RefPtr<TypeLayout> getDerefTypeLayout( + TypeLayout* typeLayout); + +RefPtr<VarLayout> getFieldLayout( + TypeLayout* typeLayout, + DeclRef<VarDeclBase> fieldDeclRef); + +// Represents the "chain" of declarations that +// were followed to get to a variable that we +// are now declaring as a leaf variable. +struct LegalVarChain +{ + LegalVarChain* next; + VarLayout* varLayout; +}; + +RefPtr<VarLayout> createVarLayout( + LegalVarChain* varChain, + TypeLayout* typeLayout); + +// + +struct TypeLegalizationContext +{ + /// The overall compilation session (used when + /// constructing types). + Session* session; + + // If the type we are legalizing comes from an + // AST module being lowered via AST-to-AST translation, + // then we want to add any new declaration we create + // to represent it to the appropriate output module. + // We store some fields here to enable that: + RefPtr<ModuleDecl> mainModuleDecl; + RefPtr<ModuleDecl> outputModuleDecl; + + // We also need to know whether the IR is involved + // at all, because if it is, then it will own certain + // declarations instead. + // + // We do this in a slightly silly way by storing a pointer + // to the IR module (if any), and assume that its presence + // or absence is the indicator we need. + IRModule* irModule = nullptr; + + /// A list to retain any AST objects created during type legalization. + List<RefPtr<Decl>> createdDecls; + + /// A mapping from declaration references to the resulting + /// legalized type. + /// + /// For declaration-reference types, this map can be used + /// to cache a legalization so that it will be re-used + /// for equivalent declaration references (and so avoid + /// emitting declarations of legalized `struct` types + /// multiple times). + Dictionary<DeclRef<Decl>, LegalType> mapDeclRefToLegalType; +}; + + +LegalType legalizeType( + TypeLegalizationContext* context, + Type* type); + +/// Try to find the module that (recursively) contains a given declaration. +ModuleDecl* findModuleForDecl( + Decl* decl); + +} + +#endif diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 5ec175668..1e43a4d31 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -947,14 +947,6 @@ LoweredTypeInfo lowerType( return visitor.dispatchType(type); } -#if 0 -struct LoweringVisitor - : ExprVisitor<LoweringVisitor, LoweredExpr> - , StmtVisitor<LoweringVisitor, void> - , DeclVisitor<LoweringVisitor, LoweredDecl> - , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>> -#endif - LoweredValInfo createVar( IRGenContext* context, RefPtr<Type> type, @@ -982,6 +974,8 @@ void addArgs( case LoweredValInfo::Flavor::Simple: case LoweredValInfo::Flavor::Ptr: case LoweredValInfo::Flavor::SwizzledLValue: + case LoweredValInfo::Flavor::BoundSubscript: + case LoweredValInfo::Flavor::BoundMember: args.Add(getSimpleVal(context, argInfo)); break; @@ -2535,6 +2529,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> SLANG_UNIMPLEMENTED_X("decl catch-all"); } + LoweredValInfo visitImportDecl(ImportDecl* /*decl*/) + { + return LoweredValInfo(); + } + LoweredValInfo visitEmptyDecl(EmptyDecl* /*decl*/) { return LoweredValInfo(); diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index e2db1b456..721072b82 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -124,14 +124,6 @@ namespace Slang { emitQualifiedName(context, declRefType->declRef); } - else if (auto tupleType = dynamic_cast<FilteredTupleType*>(type)) - { - // TODO: this doesn't handle the possibility of multiple different - // filtered versions of the same type... - emitRaw(context, "t"); - emitType(context, tupleType->originalType); - emitRaw(context, "_"); - } else { SLANG_UNEXPECTED("unimplemented case in mangling"); diff --git a/source/slang/modifier-defs.h b/source/slang/modifier-defs.h index 0f92fcd61..cd8f2524b 100644 --- a/source/slang/modifier-defs.h +++ b/source/slang/modifier-defs.h @@ -318,21 +318,9 @@ SYNTAX_CLASS(ComputedLayoutModifier, Modifier) FIELD(RefPtr<Layout>, layout) END_SYNTAX_CLASS() -// A modifier attached to types during lowering, to indicate that they -// are logically a "tuple" type -SYNTAX_CLASS(TupleTypeModifier, Modifier) - FIELD_INIT(AggTypeDecl*, decl, nullptr) - FIELD_INIT(bool, hasAnyNonTupleFields, false) -END_SYNTAX_CLASS() - -SYNTAX_CLASS(TupleFieldModifier, Modifier) - FIELD_INIT(VarDeclBase*, decl, nullptr) - FIELD_INIT(bool, hasAnyNonTupleFields, false) - FIELD_INIT(bool, isNestedTuple, false) -END_SYNTAX_CLASS() SYNTAX_CLASS(TupleVarModifier, Modifier) - FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) +// FIELD_INIT(TupleFieldModifier*, tupleField, nullptr) END_SYNTAX_CLASS() // A modifier to indicate that a constructor/initializer can be used @@ -342,3 +330,8 @@ SYNTAX_CLASS(ImplicitConversionModifier, Modifier) // The conversion cost, used to rank conversions FIELD(ConversionCost, cost) END_SYNTAX_CLASS() + +// A marker modifier used to indicate that a declaration was created as +// part of type legalization. +SIMPLE_SYNTAX_CLASS(LegalizedModifier, Modifier) + diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index d0e4e840e..9835640e0 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -179,6 +179,7 @@ <ClInclude Include="ir-inst-defs.h" /> <ClInclude Include="ir-insts.h" /> <ClInclude Include="ir.h" /> + <ClInclude Include="legalize-types.h" /> <ClInclude Include="lexer.h" /> <ClInclude Include="lookup.h" /> <ClInclude Include="lower-to-ir.h" /> @@ -217,6 +218,7 @@ <ClCompile Include="emit.cpp" /> <ClCompile Include="ir-legalize-types.cpp" /> <ClCompile Include="ir.cpp" /> + <ClCompile Include="legalize-types.cpp" /> <ClCompile Include="lexer.cpp" /> <ClCompile Include="lookup.cpp" /> <ClCompile Include="lower-to-ir.cpp" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index c93934fff..f207c6dcd 100644 --- a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -43,6 +43,7 @@ <ClInclude Include="vm.h" /> <ClInclude Include="mangle.h" /> <ClInclude Include="ast-legalize.h" /> + <ClInclude Include="legalize-types.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp" /> @@ -72,6 +73,7 @@ <ClCompile Include="dxc-support.cpp" /> <ClCompile Include="ir-legalize-types.cpp" /> <ClCompile Include="ast-legalize.cpp" /> + <ClCompile Include="legalize-types.cpp" /> </ItemGroup> <ItemGroup> <CustomBuild Include="core.meta.slang" /> diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index f8237359d..dd08fa2b1 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -321,6 +321,7 @@ void Type::accept(IValVisitor* visitor, void* extra) IntVal* elementCount) { RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->setSession(this); arrayType->baseType = elementType; arrayType->ArrayLength = elementCount; return arrayType; @@ -865,8 +866,16 @@ void Type::accept(IValVisitor* visitor, void* extra) int NamedExpressionType::GetHashCode() { - SLANG_UNEXPECTED("unreachable"); - UNREACHABLE_RETURN(0); + // Type equality is based on comparing canonical types, + // so the hash code for a type needs to come from the + // canonical version of the type. This really means + // that `Type::GetHashCode()` should dispatch out to + // something like `Type::GetHashCodeImpl()` on the + // canonical version of a type, but it is less invasive + // for now (and hopefully equivalent) to just have any + // named types automaticlaly route hash-code requests + // to their canonical type. + return GetCanonicalType()->GetHashCode(); } // FuncType @@ -1875,118 +1884,4 @@ void Type::accept(IValVisitor* visitor, void* extra) insertGlobalGenericSubstitutions(newSubst, subst, ioDiff); return newSubst; } - - // FilteredTupleType - - String FilteredTupleType::ToString() - { - StringBuilder sb; - sb.append(originalType->ToString()); - sb.append("{"); - bool first = true; - for (auto ee : elements) - { - if (!ee.type) - continue; - - if (!first) sb.append(", "); - - sb.append(ee.fieldDeclRef.GetName()->text); - sb.append(":"); - sb.append(ee.type->ToString()); - - first = false; - } - sb.append("}"); - return sb.ProduceString(); - } - - RefPtr<Val> FilteredTupleType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - int diff = 0; - auto substOriginalType = originalType->SubstituteImpl(subst, &diff).As<Type>(); - - List<Element> substElements; - for (auto ee : elements) - { - Element substElement; - substElement.fieldDeclRef = ee.fieldDeclRef.SubstituteImpl(subst, &diff); - substElement.type = ee.type->SubstituteImpl(subst, &diff).As<Type>(); - substElements.Add(substElement); - } - - if (!diff) - return this; - - (*ioDiff)++; - RefPtr<FilteredTupleType> substType = new FilteredTupleType(); - substType->setSession(session); - substType->originalType = substOriginalType; - substType->elements = substElements; - return substType; - } - - bool FilteredTupleType::EqualsImpl(Type * type) - { - auto tupleType = type->As<FilteredTupleType>(); - if (!tupleType) - return false; - - if (!originalType->Equals(tupleType->originalType)) - return false; - - auto elementCount = elements.Count(); - if (tupleType->elements.Count() != elementCount) - return false; - - for (UInt ee = 0; ee < elementCount; ee++) - { - if (!elements[ee].type || !tupleType->elements[ee].type) - { - if (!elements[ee].type != !tupleType->elements[ee].type) - return false; - - continue; - } - - if (!elements[ee].fieldDeclRef.Equals(tupleType->elements[ee].fieldDeclRef)) - return false; - - if (!elements[ee].type->Equals(tupleType->elements[ee].type)) - return false; - } - return true; - } - - int FilteredTupleType::GetHashCode() - { - int hash = (int)(typeid(this).hash_code()); - hash = combineHash(hash, - originalType->GetHashCode()); - for (auto ee : elements) - { - hash = combineHash(hash, - ee.fieldDeclRef.GetHashCode()); - hash = combineHash(hash, - ee.type->GetHashCode()); - } - return hash; - } - - Type* FilteredTupleType::CreateCanonicalType() - { - RefPtr<FilteredTupleType> canTupleType = new FilteredTupleType(); - canTupleType->setSession(session); - canTupleType->originalType = originalType->GetCanonicalType(); - for (auto ee : elements) - { - Element element; - element.fieldDeclRef = ee.fieldDeclRef; - element.type = ee.type ? ee.type->GetCanonicalType() : nullptr; - - canTupleType->elements.Add(element); - } - getSession()->canonicalTypes.Add(canTupleType); - return canTupleType; - } } diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 72bf6fe4c..da6e27d17 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -493,37 +493,3 @@ protected: virtual Type* CreateCanonicalType() override; ) END_SYNTAX_CLASS() - -// A type created to represent the result of filtering -// the fields of an aggregate type. -SYNTAX_CLASS(FilteredTupleType, Type) -RAW( - struct Element - { - // The original field this element represents - DeclRef<VarDeclBase> fieldDeclRef; - - // The type being used for the new field - RefPtr<Type> type; - }; -) - - FIELD(RefPtr<Type>, originalType); - FIELD(List<Element>, elements); - -RAW( - FilteredTupleType() - {} - - RefPtr<Type> getOriginalType() const { return originalType; } - List<Element> const& getElements() const { return elements; } - virtual String ToString() override; - -protected: - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual Type* CreateCanonicalType() override; -) - -END_SYNTAX_CLASS() |
