diff options
Diffstat (limited to 'source')
35 files changed, 12378 insertions, 12499 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 6ff8efe9e..b3e1baf79 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -7,3169 +7,3112 @@ namespace Slang { - namespace Compiler + bool IsNumeric(BaseType t) { - bool IsNumeric(BaseType t) - { - return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; - } - - String TranslateHLSLTypeNames(String name) - { - if (name == "float2" || name == "half2") - return "vec2"; - else if (name == "float3" || name == "half3") - return "vec3"; - else if (name == "float4" || name == "half4") - return "vec4"; - else if (name == "half") - return "float"; - else if (name == "int2") - return "ivec2"; - else if (name == "int3") - return "ivec3"; - else if (name == "int4") - return "ivec4"; - else if (name == "uint2") - return "uvec2"; - else if (name == "uint3") - return "uvec3"; - else if (name == "uint4") - return "uvec4"; - else if (name == "float3x3" || name == "half3x3") - return "mat3"; - else if (name == "float4x4" || name == "half4x4") - return "mat4"; - else - return name; - } + return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; + } - class SemanticsVisitor : public SyntaxVisitor - { - ProgramSyntaxNode * program = nullptr; - FunctionSyntaxNode * function = nullptr; - CompileOptions const* options = nullptr; - CompileRequest* request = nullptr; + String TranslateHLSLTypeNames(String name) + { + if (name == "float2" || name == "half2") + return "vec2"; + else if (name == "float3" || name == "half3") + return "vec3"; + else if (name == "float4" || name == "half4") + return "vec4"; + else if (name == "half") + return "float"; + else if (name == "int2") + return "ivec2"; + else if (name == "int3") + return "ivec3"; + else if (name == "int4") + return "ivec4"; + else if (name == "uint2") + return "uvec2"; + else if (name == "uint3") + return "uvec3"; + else if (name == "uint4") + return "uvec4"; + else if (name == "float3x3" || name == "half3x3") + return "mat3"; + else if (name == "float4x4" || name == "half4x4") + return "mat4"; + else + return name; + } - // lexical outer statements - List<StatementSyntaxNode*> outerStmts; - public: - SemanticsVisitor( - DiagnosticSink * pErr, - CompileOptions const& options, - CompileRequest* request) - : SyntaxVisitor(pErr) - , options(&options) - , request(request) - { - } + class SemanticsVisitor : public SyntaxVisitor + { + ProgramSyntaxNode * program = nullptr; + FunctionSyntaxNode * function = nullptr; + CompileOptions const* options = nullptr; + CompileRequest* request = nullptr; + + // lexical outer statements + List<StatementSyntaxNode*> outerStmts; + public: + SemanticsVisitor( + DiagnosticSink * pErr, + CompileOptions const& options, + CompileRequest* request) + : SyntaxVisitor(pErr) + , options(&options) + , request(request) + { + } - CompileOptions const& getOptions() { return *options; } + CompileOptions const& getOptions() { return *options; } - public: - // Translate Types - RefPtr<ExpressionType> typeResult; - RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node) - { - if (!node) return nullptr; - auto expr = node->Accept(this).As<ExpressionSyntaxNode>(); - expr = ExpectATypeRepr(expr); - return expr; - } - RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr) + public: + // Translate Types + RefPtr<ExpressionType> typeResult; + RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto expr = node->Accept(this).As<ExpressionSyntaxNode>(); + expr = ExpectATypeRepr(expr); + return expr; + } + RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr) + { + if (!typeRepr) return nullptr; + if (auto typeType = typeRepr->Type->As<TypeType>()) { - if (!typeRepr) return nullptr; - if (auto typeType = typeRepr->Type->As<TypeType>()) - { - return typeType->type; - } - return ExpressionType::Error; + return typeType->type; } - RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node) + return ExpressionType::Error; + } + RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto typeRepr = TranslateTypeNodeImpl(node); + return ExtractTypeFromTypeRepr(typeRepr); + } + TypeExp TranslateTypeNode(TypeExp const& typeExp) + { + // HACK(tfoley): It seems that in some cases we end up re-checking + // syntax that we've already checked. We need to root-cause that + // issue, but for now a quick fix in this case is to early + // exist if we've already got a type associated here: + if (typeExp.type) { - if (!node) return nullptr; - auto typeRepr = TranslateTypeNodeImpl(node); - return ExtractTypeFromTypeRepr(typeRepr); + return typeExp; } - TypeExp TranslateTypeNode(TypeExp const& typeExp) - { - // HACK(tfoley): It seems that in some cases we end up re-checking - // syntax that we've already checked. We need to root-cause that - // issue, but for now a quick fix in this case is to early - // exist if we've already got a type associated here: - if (typeExp.type) - { - return typeExp; - } - auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); + auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); - TypeExp result; - result.exp = typeRepr; - result.type = ExtractTypeFromTypeRepr(typeRepr); - return result; - } + TypeExp result; + result.exp = typeRepr; + result.type = ExtractTypeFromTypeRepr(typeRepr); + return result; + } - RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr( - DeclRef declRef, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr( + DeclRef declRef, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (baseExpr) + { + auto expr = new MemberExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->BaseExpression = baseExpr; + expr->MemberName = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; + } + else { - if (baseExpr) - { - auto expr = new MemberExpressionSyntaxNode(); - expr->Position = originalExpr->Position; - expr->BaseExpression = baseExpr; - expr->MemberName = declRef.GetName(); - expr->Type = GetTypeForDeclRef(declRef); - expr->declRef = declRef; - return expr; - } - else - { - auto expr = new VarExpressionSyntaxNode(); - expr->Position = originalExpr->Position; - expr->Variable = declRef.GetName(); - expr->Type = GetTypeForDeclRef(declRef); - expr->declRef = declRef; - return expr; - } + auto expr = new VarExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->Variable = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; } + } - RefPtr<ExpressionSyntaxNode> ConstructDerefExpr( - RefPtr<ExpressionSyntaxNode> base, - RefPtr<ExpressionSyntaxNode> originalExpr) - { - auto ptrLikeType = base->Type->As<PointerLikeType>(); - assert(ptrLikeType); + RefPtr<ExpressionSyntaxNode> ConstructDerefExpr( + RefPtr<ExpressionSyntaxNode> base, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + auto ptrLikeType = base->Type->As<PointerLikeType>(); + assert(ptrLikeType); - auto derefExpr = new DerefExpr(); - derefExpr->Position = originalExpr->Position; - derefExpr->base = base; - derefExpr->Type = ptrLikeType->elementType; + auto derefExpr = new DerefExpr(); + derefExpr->Position = originalExpr->Position; + derefExpr->base = base; + derefExpr->Type = ptrLikeType->elementType; - // TODO(tfoley): handle l-value status here + // TODO(tfoley): handle l-value status here - return derefExpr; - } + return derefExpr; + } - RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr( - LookupResultItem const& item, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr( + LookupResultItem const& item, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + // If we collected any breadcrumbs, then these represent + // additional segments of the lookup path that we need + // to expand here. + auto bb = baseExpr; + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) { - // If we collected any breadcrumbs, then these represent - // additional segments of the lookup path that we need - // to expand here. - auto bb = baseExpr; - for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + switch (breadcrumb->kind) { - switch (breadcrumb->kind) - { - case LookupResultItem::Breadcrumb::Kind::Member: - bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr); - break; - case LookupResultItem::Breadcrumb::Kind::Deref: - bb = ConstructDerefExpr(bb, originalExpr); - break; - default: - SLANG_UNREACHABLE("all cases handle"); - } + case LookupResultItem::Breadcrumb::Kind::Member: + bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr); + break; + case LookupResultItem::Breadcrumb::Kind::Deref: + bb = ConstructDerefExpr(bb, originalExpr); + break; + default: + SLANG_UNREACHABLE("all cases handle"); } - - return ConstructDeclRefExpr(item.declRef, bb, originalExpr); } - RefPtr<ExpressionSyntaxNode> createLookupResultExpr( - LookupResult const& lookupResult, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + return ConstructDeclRefExpr(item.declRef, bb, originalExpr); + } + + RefPtr<ExpressionSyntaxNode> createLookupResultExpr( + LookupResult const& lookupResult, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (lookupResult.isOverloaded()) { - if (lookupResult.isOverloaded()) - { - auto overloadedExpr = new OverloadedExpr(); - overloadedExpr->Position = originalExpr->Position; - overloadedExpr->Type = ExpressionType::Overloaded; - overloadedExpr->base = baseExpr; - overloadedExpr->lookupResult2 = lookupResult; - return overloadedExpr; - } - else - { - return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr); - } + auto overloadedExpr = new OverloadedExpr(); + overloadedExpr->Position = originalExpr->Position; + overloadedExpr->Type = ExpressionType::Overloaded; + overloadedExpr->base = baseExpr; + overloadedExpr->lookupResult2 = lookupResult; + return overloadedExpr; } - - RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask) + else { - auto lookupResult = overloadedExpr->lookupResult2; - assert(lookupResult.isValid() && lookupResult.isOverloaded()); - - // Take the lookup result we had, and refine it based on what is expected in context. - lookupResult = refineLookup(lookupResult, mask); - - if (!lookupResult.isValid()) - { - // If we didn't find any symbols after filtering, then just - // use the original and report errors that way - return overloadedExpr; - } - - if (lookupResult.isOverloaded()) - { - // We had an ambiguity anyway, so report it. - getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); + return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr); + } + } - for(auto item : lookupResult.items) - { - String declString = getDeclSignatureString(item); - getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); - } + RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isValid() && lookupResult.isOverloaded()); - // TODO(tfoley): should we construct a new ErrorExpr here? - overloadedExpr->Type = ExpressionType::Error; - return overloadedExpr; - } + // Take the lookup result we had, and refine it based on what is expected in context. + lookupResult = refineLookup(lookupResult, mask); - // otherwise, we had a single decl and it was valid, hooray! - return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr); + if (!lookupResult.isValid()) + { + // If we didn't find any symbols after filtering, then just + // use the original and report errors that way + return overloadedExpr; } - RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr) + if (lookupResult.isOverloaded()) { - if (auto overloadedExpr = expr.As<OverloadedExpr>()) - { - expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); - } + // We had an ambiguity anyway, so report it. + getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); - if (auto typeType = expr->Type.type->As<TypeType>()) + for(auto item : lookupResult.items) { - return expr; - } - else if (expr->Type.type->Equals(ExpressionType::Error)) - { - return expr; + String declString = getDeclSignatureString(item); + getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); } - getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); - // TODO: construct some kind of `ErrorExpr`? - return expr; + // TODO(tfoley): should we construct a new ErrorExpr here? + overloadedExpr->Type = ExpressionType::Error; + return overloadedExpr; } - RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr) + // otherwise, we had a single decl and it was valid, hooray! + return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr); + } + + RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto overloadedExpr = expr.As<OverloadedExpr>()) { - auto typeRepr = ExpectATypeRepr(expr); - if (auto typeType = typeRepr->Type->As<TypeType>()) - { - return typeType->type; - } - return ExpressionType::Error; + expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); } - RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp) + if (auto typeType = expr->Type.type->As<TypeType>()) { - return ExpectAType(exp); + return expr; } + else if (expr->Type.type->Equals(ExpressionType::Error)) + { + return expr; + } + + getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); + // TODO: construct some kind of `ErrorExpr`? + return expr; + } - RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp) + RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr) + { + auto typeRepr = ExpectATypeRepr(expr); + if (auto typeType = typeRepr->Type->As<TypeType>()) { - return CheckIntegerConstantExpression(exp.Ptr()); + return typeType->type; } + return ExpressionType::Error; + } + + RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp) + { + return ExpectAType(exp); + } + + RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp) + { + return CheckIntegerConstantExpression(exp.Ptr()); + } - RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp) + RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp) + { + if (auto overloadedExpr = exp.As<OverloadedExpr>()) { - if (auto overloadedExpr = exp.As<OverloadedExpr>()) - { - // assume that if it is overloaded, we want a type - exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); - } + // assume that if it is overloaded, we want a type + exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); + } - if (auto typeType = exp->Type->As<TypeType>()) - { - return typeType->type; - } - else if (exp->Type->Equals(ExpressionType::Error)) - { - return exp->Type.type; - } - else - { - return ExtractGenericArgInteger(exp); - } + if (auto typeType = exp->Type->As<TypeType>()) + { + return typeType->type; + } + else if (exp->Type->Equals(ExpressionType::Error)) + { + return exp->Type.type; } + else + { + return ExtractGenericArgInteger(exp); + } + } - // Construct a type reprsenting the instantiation of - // the given generic declaration for the given arguments. - // The arguments should already be checked against - // the declaration. - RefPtr<ExpressionType> InstantiateGenericType( - GenericDeclRef genericDeclRef, - List<RefPtr<ExpressionSyntaxNode>> const& args) + // Construct a type reprsenting the instantiation of + // the given generic declaration for the given arguments. + // The arguments should already be checked against + // the declaration. + RefPtr<ExpressionType> InstantiateGenericType( + GenericDeclRef genericDeclRef, + List<RefPtr<ExpressionSyntaxNode>> const& args) + { + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = genericDeclRef.GetDecl(); + subst->outer = genericDeclRef.substitutions; + + for (auto argExpr : args) { - RefPtr<Substitutions> subst = new Substitutions(); - subst->genericDecl = genericDeclRef.GetDecl(); - subst->outer = genericDeclRef.substitutions; + subst->args.Add(ExtractGenericArgVal(argExpr)); + } - for (auto argExpr : args) - { - subst->args.Add(ExtractGenericArgVal(argExpr)); - } + DeclRef innerDeclRef; + innerDeclRef.decl = genericDeclRef.GetInner(); + innerDeclRef.substitutions = subst; - DeclRef innerDeclRef; - innerDeclRef.decl = genericDeclRef.GetInner(); - innerDeclRef.substitutions = subst; + return DeclRefType::Create(innerDeclRef); + } - return DeclRefType::Create(innerDeclRef); + // Make sure a declaration has been checked, so we can refer to it. + // Note that this may lead to us recursively invoking checking, + // so this may not be the best way to handle things. + void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader) + { + if (decl->IsChecked(state)) return; + if (decl->checkState == DeclCheckState::CheckingHeader) + { + // We tried to reference the same declaration while checking it! + throw "circularity"; } - // Make sure a declaration has been checked, so we can refer to it. - // Note that this may lead to us recursively invoking checking, - // so this may not be the best way to handle things. - void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader) + if (DeclCheckState::CheckingHeader > decl->checkState) { - if (decl->IsChecked(state)) return; - if (decl->checkState == DeclCheckState::CheckingHeader) - { - // We tried to reference the same declaration while checking it! - throw "circularity"; - } - - if (DeclCheckState::CheckingHeader > decl->checkState) - { - decl->SetCheckState(DeclCheckState::CheckingHeader); - } + decl->SetCheckState(DeclCheckState::CheckingHeader); + } - // TODO: not all of the `Visit` cases are ready to - // handle this being called on-the-fly - decl->Accept(this); + // TODO: not all of the `Visit` cases are ready to + // handle this being called on-the-fly + decl->Accept(this); - decl->SetCheckState(DeclCheckState::Checked); - } + decl->SetCheckState(DeclCheckState::Checked); + } - void EnusreAllDeclsRec(RefPtr<Decl> decl) + void EnusreAllDeclsRec(RefPtr<Decl> decl) + { + EnsureDecl(decl, DeclCheckState::Checked); + if (auto containerDecl = decl.As<ContainerDecl>()) { - EnsureDecl(decl, DeclCheckState::Checked); - if (auto containerDecl = decl.As<ContainerDecl>()) + for (auto m : containerDecl->Members) { - for (auto m : containerDecl->Members) - { - EnusreAllDeclsRec(m); - } + EnusreAllDeclsRec(m); } } + } - // A "proper" type is one that can be used as the type of an expression. - // Put simply, it can be a concrete type like `int`, or a generic - // type that is applied to arguments, like `Texture2D<float4>`. - // The type `void` is also a proper type, since we can have expressions - // that return a `void` result (e.g., many function calls). - // - // A "non-proper" type is any type that can't actually have values. - // A simple example of this in C++ is `std::vector` - you can't have - // a value of this type. - // - // Part of what this function does is give errors if somebody tries - // to use a non-proper type as the type of a variable (or anything - // else that needs a proper type). - // - // The other thing it handles is the fact that HLSL lets you use - // the name of a non-proper type, and then have the compiler fill - // in the default values for its type arguments (e.g., a variable - // given type `Texture2D` will actually have type `Texture2D<float4>`). - bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType) - { - ExpressionType* type = typeExp.type.Ptr(); - if (auto genericDeclRefType = type->As<GenericDeclRefType>()) - { - // We are using a reference to a generic declaration as a concrete - // type. This means we should substitute in any default parameter values - // if they are available. - // - // TODO(tfoley): A more expressive type system would substitute in - // "fresh" variables and then solve for their values... - // + // A "proper" type is one that can be used as the type of an expression. + // Put simply, it can be a concrete type like `int`, or a generic + // type that is applied to arguments, like `Texture2D<float4>`. + // The type `void` is also a proper type, since we can have expressions + // that return a `void` result (e.g., many function calls). + // + // A "non-proper" type is any type that can't actually have values. + // A simple example of this in C++ is `std::vector` - you can't have + // a value of this type. + // + // Part of what this function does is give errors if somebody tries + // to use a non-proper type as the type of a variable (or anything + // else that needs a proper type). + // + // The other thing it handles is the fact that HLSL lets you use + // the name of a non-proper type, and then have the compiler fill + // in the default values for its type arguments (e.g., a variable + // given type `Texture2D` will actually have type `Texture2D<float4>`). + bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType) + { + ExpressionType* type = typeExp.type.Ptr(); + if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + { + // We are using a reference to a generic declaration as a concrete + // type. This means we should substitute in any default parameter values + // if they are available. + // + // TODO(tfoley): A more expressive type system would substitute in + // "fresh" variables and then solve for their values... + // - auto genericDeclRef = genericDeclRefType->GetDeclRef(); - EnsureDecl(genericDeclRef.decl); - List<RefPtr<ExpressionSyntaxNode>> args; - for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + auto genericDeclRef = genericDeclRefType->GetDeclRef(); + EnsureDecl(genericDeclRef.decl); + List<RefPtr<ExpressionSyntaxNode>> args; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + { + if (auto typeParam = member.As<GenericTypeParamDecl>()) { - if (auto typeParam = member.As<GenericTypeParamDecl>()) + if (!typeParam->initType.exp) { - if (!typeParam->initType.exp) + if (outProperType) { - if (outProperType) - { - getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = ExpressionType::Error; - } - return false; + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; } - - // TODO: this is one place where syntax should get cloned! - if(outProperType) - args.Add(typeParam->initType.exp); + return false; } - else if (auto valParam = member.As<GenericValueParamDecl>()) + + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(typeParam->initType.exp); + } + else if (auto valParam = member.As<GenericValueParamDecl>()) + { + if (!valParam->Expr) { - if (!valParam->Expr) + if (outProperType) { - if (outProperType) - { - getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = ExpressionType::Error; - } - return false; + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; } - - // TODO: this is one place where syntax should get cloned! - if(outProperType) - args.Add(valParam->Expr); - } - else - { - // ignore non-parameter members + return false; } - } - if (outProperType) + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(valParam->Expr); + } + else { - *outProperType = InstantiateGenericType(genericDeclRef, args); + // ignore non-parameter members } - return true; } - else + + if (outProperType) { - // default case: we expect this to already be a proper type - if (outProperType) - { - *outProperType = type; - } - return true; + *outProperType = InstantiateGenericType(genericDeclRef, args); } + return true; + } + else + { + // default case: we expect this to already be a proper type + if (outProperType) + { + *outProperType = type; + } + return true; } + } - TypeExp CoerceToProperType(TypeExp const& typeExp) - { - TypeExp result = typeExp; - CoerceToProperTypeImpl(typeExp, &result.type); - return result; - } + TypeExp CoerceToProperType(TypeExp const& typeExp) + { + TypeExp result = typeExp; + CoerceToProperTypeImpl(typeExp, &result.type); + return result; + } - bool CanCoerceToProperType(TypeExp const& typeExp) - { - return CoerceToProperTypeImpl(typeExp, nullptr); - } + bool CanCoerceToProperType(TypeExp const& typeExp) + { + return CoerceToProperTypeImpl(typeExp, nullptr); + } - // Check a type, and coerce it to be proper - TypeExp CheckProperType(TypeExp typeExp) - { - return CoerceToProperType(TranslateTypeNode(typeExp)); - } + // Check a type, and coerce it to be proper + TypeExp CheckProperType(TypeExp typeExp) + { + return CoerceToProperType(TranslateTypeNode(typeExp)); + } - // For our purposes, a "usable" type is one that can be - // used to declare a function parameter, variable, etc. - // These turn out to be all the proper types except - // `void`. - // - // TODO(tfoley): consider just allowing `void` as a - // simple example of a "unit" type, and get rid of - // this check. - TypeExp CoerceToUsableType(TypeExp const& typeExp) - { - TypeExp result = CoerceToProperType(typeExp); - ExpressionType* type = result.type.Ptr(); - if (auto basicType = type->As<BasicExpressionType>()) + // For our purposes, a "usable" type is one that can be + // used to declare a function parameter, variable, etc. + // These turn out to be all the proper types except + // `void`. + // + // TODO(tfoley): consider just allowing `void` as a + // simple example of a "unit" type, and get rid of + // this check. + TypeExp CoerceToUsableType(TypeExp const& typeExp) + { + TypeExp result = CoerceToProperType(typeExp); + ExpressionType* type = result.type.Ptr(); + if (auto basicType = type->As<BasicExpressionType>()) + { + // TODO: `void` shouldn't be a basic type, to make this easier to avoid + if (basicType->BaseType == BaseType::Void) { - // TODO: `void` shouldn't be a basic type, to make this easier to avoid - if (basicType->BaseType == BaseType::Void) - { - // TODO(tfoley): pick the right diagnostic message - getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); - result.type = ExpressionType::Error; - return result; - } + // TODO(tfoley): pick the right diagnostic message + getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); + result.type = ExpressionType::Error; + return result; } - return result; } + return result; + } - // Check a type, and coerce it to be usable - TypeExp CheckUsableType(TypeExp typeExp) - { - return CoerceToUsableType(TranslateTypeNode(typeExp)); - } + // Check a type, and coerce it to be usable + TypeExp CheckUsableType(TypeExp typeExp) + { + return CoerceToUsableType(TranslateTypeNode(typeExp)); + } - RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term) - { - if (!term) return nullptr; - return term->Accept(this).As<ExpressionSyntaxNode>(); - } + RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term) + { + if (!term) return nullptr; + return term->Accept(this).As<ExpressionSyntaxNode>(); + } - RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr) - { - expr->Type = ExpressionType::Error; - return expr; - } + RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr) + { + expr->Type = ExpressionType::Error; + return expr; + } - bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr) - { - // TODO: we may want other cases here... + bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: we may want other cases here... - if (expr->Type->Equals(ExpressionType::Error)) - return true; + if (expr->Type->Equals(ExpressionType::Error)) + return true; - return false; - } + return false; + } - // Capture the "base" expression in case this is a member reference - RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr) + // Capture the "base" expression in case this is a member reference + RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) { - if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) - { - return memberExpr->BaseExpression; - } - else if(auto overloadedExpr = expr.As<OverloadedExpr>()) - { - return overloadedExpr->base; - } - return nullptr; + return memberExpr->BaseExpression; + } + else if(auto overloadedExpr = expr.As<OverloadedExpr>()) + { + return overloadedExpr->base; } + return nullptr; + } - public: + public: - typedef unsigned int ConversionCost; - enum : ConversionCost - { - // No conversion at all - kConversionCost_None = 0, + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_RankPromotion = 100, + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_RankPromotion = 100, - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 300, + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 300, - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_ScalarToVector = 1, - }; + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_ScalarToVector = 1, + }; - enum BaseTypeConversionKind : uint8_t - { - kBaseTypeConversionKind_Signed, - kBaseTypeConversionKind_Unsigned, - kBaseTypeConversionKind_Float, - kBaseTypeConversionKind_Error, - }; + enum BaseTypeConversionKind : uint8_t + { + kBaseTypeConversionKind_Signed, + kBaseTypeConversionKind_Unsigned, + kBaseTypeConversionKind_Float, + kBaseTypeConversionKind_Error, + }; - enum BaseTypeConversionRank : uint8_t - { - kBaseTypeConversionRank_Bool, - kBaseTypeConversionRank_Int8, - kBaseTypeConversionRank_Int16, - kBaseTypeConversionRank_Int32, - kBaseTypeConversionRank_IntPtr, - kBaseTypeConversionRank_Int64, - kBaseTypeConversionRank_Error, - }; + enum BaseTypeConversionRank : uint8_t + { + kBaseTypeConversionRank_Bool, + kBaseTypeConversionRank_Int8, + kBaseTypeConversionRank_Int16, + kBaseTypeConversionRank_Int32, + kBaseTypeConversionRank_IntPtr, + kBaseTypeConversionRank_Int64, + kBaseTypeConversionRank_Error, + }; - struct BaseTypeConversionInfo - { - BaseTypeConversionKind kind; - BaseTypeConversionRank rank; - }; - static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType) + struct BaseTypeConversionInfo + { + BaseTypeConversionKind kind; + BaseTypeConversionRank rank; + }; + static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType) + { + switch (baseType) { - switch (baseType) - { - #define CASE(TAG, KIND, RANK) \ - case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break + #define CASE(TAG, KIND, RANK) \ + case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break - CASE(Bool, Unsigned, Bool); - CASE(Int, Signed, Int32); - CASE(UInt, Unsigned, Int32); - CASE(UInt64, Unsigned, Int64); - CASE(Float, Float, Int32); - CASE(Void, Error, Error); + CASE(Bool, Unsigned, Bool); + CASE(Int, Signed, Int32); + CASE(UInt, Unsigned, Int32); + CASE(UInt64, Unsigned, Int64); + CASE(Float, Float, Int32); + CASE(Void, Error, Error); - #undef CASE + #undef CASE - default: - break; - } - SLANG_UNREACHABLE("all cases handled"); + default: + break; } + SLANG_UNREACHABLE("all cases handled"); + } - bool ValuesAreEqual( - RefPtr<IntVal> left, - RefPtr<IntVal> right) - { - if(left == right) return true; + bool ValuesAreEqual( + RefPtr<IntVal> left, + RefPtr<IntVal> right) + { + if(left == right) return true; - if(auto leftConst = left.As<ConstantIntVal>()) + if(auto leftConst = left.As<ConstantIntVal>()) + { + if(auto rightConst = right.As<ConstantIntVal>()) { - if(auto rightConst = right.As<ConstantIntVal>()) - { - return leftConst->value == rightConst->value; - } + return leftConst->value == rightConst->value; } + } - if(auto leftVar = left.As<GenericParamIntVal>()) + if(auto leftVar = left.As<GenericParamIntVal>()) + { + if(auto rightVar = right.As<GenericParamIntVal>()) { - if(auto rightVar = right.As<GenericParamIntVal>()) - { - return leftVar->declRef.Equals(rightVar->declRef); - } + return leftVar->declRef.Equals(rightVar->declRef); } - - return false; } - // Central engine for implementing implicit coercion logic - bool TryCoerceImpl( - RefPtr<ExpressionType> toType, // the target type for conversion - RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression - RefPtr<ExpressionType> fromType, // the source type for the conversion - RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression - ConversionCost* outCost) // (optional) a place to stuff the conversion cost + return false; + } + + // Central engine for implementing implicit coercion logic + bool TryCoerceImpl( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression + RefPtr<ExpressionType> fromType, // the source type for the conversion + RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression + ConversionCost* outCost) // (optional) a place to stuff the conversion cost + { + // Easy case: the types are equal + if (toType->Equals(fromType)) { - // Easy case: the types are equal - if (toType->Equals(fromType)) - { - if (outToExpr) - *outToExpr = fromExpr; - if (outCost) - *outCost = kConversionCost_None; - return true; - } + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // If either type is an error, then let things pass. - if (toType->As<ErrorType>() || fromType->As<ErrorType>()) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = kConversionCost_None; - return true; - } + // If either type is an error, then let things pass. + if (toType->As<ErrorType>() || fromType->As<ErrorType>()) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // Coercion from an initializer list is allowed for many types - if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) - { - auto argCount = fromInitializerListExpr->args.Count(); + // Coercion from an initializer list is allowed for many types + if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) + { + auto argCount = fromInitializerListExpr->args.Count(); - // In the case where we need to build a reuslt expression, - // we will collect the new arguments here - List<RefPtr<ExpressionSyntaxNode>> coercedArgs; + // In the case where we need to build a reuslt expression, + // we will collect the new arguments here + List<RefPtr<ExpressionSyntaxNode>> coercedArgs; - if(auto toDeclRefType = toType->As<DeclRefType>()) + if(auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>()) { - auto toTypeDeclRef = toDeclRefType->declRef; - if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>()) - { - // Trying to initialize a `struct` type given an initializer list. - // We will go through the fields in order and try to match them - // up with initializer arguments. + // Trying to initialize a `struct` type given an initializer list. + // We will go through the fields in order and try to match them + // up with initializer arguments. - int argIndex = 0; - for(auto fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>()) + int argIndex = 0; + for(auto fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>()) + { + if(argIndex >= argCount) { - if(argIndex >= argCount) - { - // We've consumed all the arguments, so we should stop - break; - } - - auto arg = fromInitializerListExpr->args[argIndex++]; - - // - RefPtr<ExpressionSyntaxNode> coercedArg; - ConversionCost argCost; - - bool argResult = TryCoerceImpl( - fieldDeclRef.GetType(), - outToExpr ? &coercedArg : nullptr, - arg->Type, - arg, - outCost ? &argCost : nullptr); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - // TODO(tfoley): what to do with cost? - // This only matters if/when we allow an initializer list as an argument to - // an overloaded call. - - if( outToExpr ) - { - coercedArgs.Add(coercedArg); - } + // We've consumed all the arguments, so we should stop + break; } - } - } - else if(auto toArrayType = toType->As<ArrayExpressionType>()) - { - // TODO(tfoley): If we can compute the size of the array statically, - // then we want to check that there aren't too many initializers present - auto toElementType = toArrayType->BaseType; + auto arg = fromInitializerListExpr->args[argIndex++]; - for(auto& arg : fromInitializerListExpr->args) - { + // RefPtr<ExpressionSyntaxNode> coercedArg; ConversionCost argCost; bool argResult = TryCoerceImpl( - toElementType, - outToExpr ? &coercedArg : nullptr, - arg->Type, - arg, - outCost ? &argCost : nullptr); + fieldDeclRef.GetType(), + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); // No point in trying further if any argument fails if(!argResult) return false; + // TODO(tfoley): what to do with cost? + // This only matters if/when we allow an initializer list as an argument to + // an overloaded call. + if( outToExpr ) { coercedArgs.Add(coercedArg); } } } - else - { - // By default, we don't allow a type to be initialized using - // an initializer list. - return false; - } + } + else if(auto toArrayType = toType->As<ArrayExpressionType>()) + { + // TODO(tfoley): If we can compute the size of the array statically, + // then we want to check that there aren't too many initializers present - // For now, coercion from an initializer list has no cost - if(outCost) + auto toElementType = toArrayType->BaseType; + + for(auto& arg : fromInitializerListExpr->args) { - *outCost = kConversionCost_None; + RefPtr<ExpressionSyntaxNode> coercedArg; + ConversionCost argCost; + + bool argResult = TryCoerceImpl( + toElementType, + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( outToExpr ) + { + coercedArgs.Add(coercedArg); + } } + } + else + { + // By default, we don't allow a type to be initialized using + // an initializer list. + return false; + } - // We were able to coerce all the arguments given, and so - // we need to construct a suitable expression to remember the result - if(outToExpr) - { - auto toInitializerListExpr = new InitializerListExpr(); - toInitializerListExpr->Position = fromInitializerListExpr->Position; - toInitializerListExpr->Type = toType; - toInitializerListExpr->args = coercedArgs; + // For now, coercion from an initializer list has no cost + if(outCost) + { + *outCost = kConversionCost_None; + } + // We were able to coerce all the arguments given, and so + // we need to construct a suitable expression to remember the result + if(outToExpr) + { + auto toInitializerListExpr = new InitializerListExpr(); + toInitializerListExpr->Position = fromInitializerListExpr->Position; + toInitializerListExpr->Type = toType; + toInitializerListExpr->args = coercedArgs; - *outToExpr = toInitializerListExpr; - } - return true; + *outToExpr = toInitializerListExpr; } - // + return true; + } - if (auto toBasicType = toType->AsBasicType()) + // + + if (auto toBasicType = toType->AsBasicType()) + { + if (auto fromBasicType = fromType->AsBasicType()) { - if (auto fromBasicType = fromType->AsBasicType()) - { - // Conversions between base types are always allowed, - // and the only question is what the cost will be. + // Conversions between base types are always allowed, + // and the only question is what the cost will be. - auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType); - auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType); + auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType); + auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType); - // We expect identical types to have been dealt with already. - assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank); + // We expect identical types to have been dealt with already. + assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank); - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) + if (outCost) + { + // Conversions within the same kind are easist to handle + if (toInfo.kind == fromInfo.kind) { - // Conversions within the same kind are easist to handle - if (toInfo.kind == fromInfo.kind) - { - // If we are converting to a "larger" type, then - // we are doing a lossless promotion, and otherwise - // we are doing a demotion. - if( toInfo.rank > fromInfo.rank) - *outCost = kConversionCost_RankPromotion; - else - *outCost = kConversionCost_GeneralConversion; - } - // If we are converting from an unsigned integer type to - // a signed integer type that is guaranteed to be larger, - // then that is also a lossless promotion. - else if(toInfo.kind == kBaseTypeConversionKind_Signed - && fromInfo.kind == kBaseTypeConversionKind_Unsigned - && toInfo.rank > fromInfo.rank) - { - // TODO: probably need to weed out cases involving - // "pointer-sized" integers if these are treated - // as distinct from 32- and 64-bit types. - // E.g., there is no guarantee that conversion - // from 32-bit unsigned to pointer-sized signed - // is lossless, because pointers could be 32-bit, - // and the same applies for conversion from - // `uintptr` to `uint64`. - *outCost = kConversionCost_UnsignedToSignedPromotion; - } - // Conversion from signed to unsigned is always lossy, - // but it is preferred over conversions from unsigned - // to signed, for same-size types. - else if(toInfo.kind == kBaseTypeConversionKind_Unsigned - && fromInfo.kind == kBaseTypeConversionKind_Signed - && toInfo.rank >= fromInfo.rank) - { - *outCost = kConversionCost_SignedToUnsignedConversion; - } - // Conversion from an integer to a floating-point type - // is never considered a promotion (even when the value - // would fit in the available bits). - // If the destination type is at least 32 bits we consider - // this a reasonably good conversion, though. - else if (toInfo.kind == kBaseTypeConversionKind_Float - && toInfo.rank >= kBaseTypeConversionRank_Int32) - { - *outCost = kConversionCost_IntegerToFloatConversion; - } - // All other cases are considered as "general" conversions, - // where we don't consider any one conversion better than - // any others. + // If we are converting to a "larger" type, then + // we are doing a lossless promotion, and otherwise + // we are doing a demotion. + if( toInfo.rank > fromInfo.rank) + *outCost = kConversionCost_RankPromotion; else - { *outCost = kConversionCost_GeneralConversion; - } } - - return true; + // If we are converting from an unsigned integer type to + // a signed integer type that is guaranteed to be larger, + // then that is also a lossless promotion. + else if(toInfo.kind == kBaseTypeConversionKind_Signed + && fromInfo.kind == kBaseTypeConversionKind_Unsigned + && toInfo.rank > fromInfo.rank) + { + // TODO: probably need to weed out cases involving + // "pointer-sized" integers if these are treated + // as distinct from 32- and 64-bit types. + // E.g., there is no guarantee that conversion + // from 32-bit unsigned to pointer-sized signed + // is lossless, because pointers could be 32-bit, + // and the same applies for conversion from + // `uintptr` to `uint64`. + *outCost = kConversionCost_UnsignedToSignedPromotion; + } + // Conversion from signed to unsigned is always lossy, + // but it is preferred over conversions from unsigned + // to signed, for same-size types. + else if(toInfo.kind == kBaseTypeConversionKind_Unsigned + && fromInfo.kind == kBaseTypeConversionKind_Signed + && toInfo.rank >= fromInfo.rank) + { + *outCost = kConversionCost_SignedToUnsignedConversion; + } + // Conversion from an integer to a floating-point type + // is never considered a promotion (even when the value + // would fit in the available bits). + // If the destination type is at least 32 bits we consider + // this a reasonably good conversion, though. + else if (toInfo.kind == kBaseTypeConversionKind_Float + && toInfo.rank >= kBaseTypeConversionRank_Int32) + { + *outCost = kConversionCost_IntegerToFloatConversion; + } + // All other cases are considered as "general" conversions, + // where we don't consider any one conversion better than + // any others. + else + { + *outCost = kConversionCost_GeneralConversion; + } } + + return true; } + } - if (auto toVectorType = toType->AsVectorType()) + if (auto toVectorType = toType->AsVectorType()) + { + if (auto fromVectorType = fromType->AsVectorType()) { - if (auto fromVectorType = fromType->AsVectorType()) - { - // Conversion between vector types. + // Conversion between vector types. - // If element counts don't match, then bail: - if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount)) - return false; + // If element counts don't match, then bail: + if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount)) + return false; - // Otherwise, if we can convert the element types, we are golden - ConversionCost elementCost; - if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = elementCost; - return true; - } + // Otherwise, if we can convert the element types, we are golden + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost; + return true; } - else if (auto fromScalarType = fromType->AsBasicType()) + } + else if (auto fromScalarType = fromType->AsBasicType()) + { + // Conversion from scalar to vector. + // Should allow as long as we can coerce the scalar to our element type. + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost)) { - // Conversion from scalar to vector. - // Should allow as long as we can coerce the scalar to our element type. - ConversionCost elementCost; - if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = elementCost + kConversionCost_ScalarToVector; - return true; - } + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost + kConversionCost_ScalarToVector; + return true; } } + } - if (auto toDeclRefType = toType->As<DeclRefType>()) + if (auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDeclRef>()) { - auto toTypeDeclRef = toDeclRefType->declRef; - if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDeclRef>()) + // Trying to convert to an interface type. + // + // We will allow this if the type conforms to the interface. + if (DoesTypeConformToInterface(fromType, interfaceDeclRef)) { - // Trying to convert to an interface type. - // - // We will allow this if the type conforms to the interface. - if (DoesTypeConformToInterface(fromType, interfaceDeclRef)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = kConversionCost_CastToInterface; - return true; - } + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_CastToInterface; + return true; } } + } - // TODO: more cases! + // TODO: more cases! - return false; - } + return false; + } - // Check whether a type coercion is possible - bool CanCoerce( - RefPtr<ExpressionType> toType, // the target type for conversion - RefPtr<ExpressionType> fromType, // the source type for the conversion - ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost - { - return TryCoerceImpl( - toType, - nullptr, - fromType, - nullptr, - outCost); - } + // Check whether a type coercion is possible + bool CanCoerce( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionType> fromType, // the source type for the conversion + ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost + { + return TryCoerceImpl( + toType, + nullptr, + fromType, + nullptr, + outCost); + } - RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr( - RefPtr<ExpressionType> toType, - RefPtr<ExpressionSyntaxNode> fromExpr) + RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + auto castExpr = new TypeCastExpressionSyntaxNode(); + castExpr->Position = fromExpr->Position; + castExpr->TargetType.type = toType; + castExpr->Type = toType; + castExpr->Expression = fromExpr; + return castExpr; + } + + + // Perform type coercion, and emit errors if it isn't possible + RefPtr<ExpressionSyntaxNode> Coerce( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + // If semantic checking is being suppressed, then we might see + // expressions without a type, and we need to ignore them. + if( !fromExpr->Type.type ) { - auto castExpr = new TypeCastExpressionSyntaxNode(); - castExpr->Position = fromExpr->Position; - castExpr->TargetType.type = toType; - castExpr->Type = toType; - castExpr->Expression = fromExpr; - return castExpr; + if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING ) + return fromExpr; } - - // Perform type coercion, and emit errors if it isn't possible - RefPtr<ExpressionSyntaxNode> Coerce( - RefPtr<ExpressionType> toType, - RefPtr<ExpressionSyntaxNode> fromExpr) + RefPtr<ExpressionSyntaxNode> expr; + if (!TryCoerceImpl( + toType, + &expr, + fromExpr->Type.Ptr(), + fromExpr.Ptr(), + nullptr)) { - // If semantic checking is being suppressed, then we might see - // expressions without a type, and we need to ignore them. - if( !fromExpr->Type.type ) + if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING)) { - if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING ) - return fromExpr; + getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type); } - RefPtr<ExpressionSyntaxNode> expr; - if (!TryCoerceImpl( - toType, - &expr, - fromExpr->Type.Ptr(), - fromExpr.Ptr(), - nullptr)) - { - if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING)) - { - getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type); - } - - // Note(tfoley): We don't call `CreateErrorExpr` here, because that would - // clobber the type on `fromExpr`, and an invariant here is that coercion - // really shouldn't *change* the expression that is passed in, but should - // introduce new AST nodes to coerce its value to a different type... - return CreateImplicitCastExpr(ExpressionType::Error, fromExpr); - } - return expr; + // Note(tfoley): We don't call `CreateErrorExpr` here, because that would + // clobber the type on `fromExpr`, and an invariant here is that coercion + // really shouldn't *change* the expression that is passed in, but should + // introduce new AST nodes to coerce its value to a different type... + return CreateImplicitCastExpr(ExpressionType::Error, fromExpr); } + return expr; + } - void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) - { - // Check the type, if one was given - TypeExp type = CheckUsableType(varDecl->Type); + void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) + { + // Check the type, if one was given + TypeExp type = CheckUsableType(varDecl->Type); - // TODO: Additional validation rules on types should go here, - // but we need to deal with the fact that some cases might be - // allowed in one context (e.g., an unsized array parameter) - // but not in othters (e.g., an unsized array field in a struct). + // TODO: Additional validation rules on types should go here, + // but we need to deal with the fact that some cases might be + // allowed in one context (e.g., an unsized array parameter) + // but not in othters (e.g., an unsized array field in a struct). - // Check the initializers, if one was given - RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr); + // Check the initializers, if one was given + RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr); - // If a type was given, ... - if (type.Ptr()) + // If a type was given, ... + if (type.Ptr()) + { + // then coerce any initializer to the type + if (initExpr) { - // then coerce any initializer to the type - if (initExpr) - { - initExpr = Coerce(type, initExpr); - } + initExpr = Coerce(type, initExpr); + } + } + else + { + // TODO: infer a type from the initializers + if (!initExpr) + { + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer"); } else { - // TODO: infer a type from the initializers - if (!initExpr) - { - getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer"); - } - else - { - getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration"); - } + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration"); } - - varDecl->Type = type; - varDecl->Expr = initExpr; } - void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) - { - // TODO: are there any other validations we can do at this point? - // - // There probably needs to be a kind of "occurs check" to make - // sure that the constraint actually applies to at least one - // of the parameters of the generic. + varDecl->Type = type; + varDecl->Expr = initExpr; + } - decl->sub = TranslateTypeNode(decl->sub); - decl->sup = TranslateTypeNode(decl->sup); - } + void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) + { + // TODO: are there any other validations we can do at this point? + // + // There probably needs to be a kind of "occurs check" to make + // sure that the constraint actually applies to at least one + // of the parameters of the generic. - virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override + decl->sub = TranslateTypeNode(decl->sub); + decl->sup = TranslateTypeNode(decl->sup); + } + + virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override + { + // check the parameters + for (auto m : genericDecl->Members) { - // check the parameters - for (auto m : genericDecl->Members) + if (auto typeParam = m.As<GenericTypeParamDecl>()) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) - { - typeParam->initType = CheckProperType(typeParam->initType); - } - else if (auto valParam = m.As<GenericValueParamDecl>()) - { - // TODO: some real checking here... - CheckVarDeclCommon(valParam); - } - else if(auto constraint = m.As<GenericTypeConstraintDecl>()) - { - CheckGenericConstraintDecl(constraint.Ptr()); - } + typeParam->initType = CheckProperType(typeParam->initType); + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + // TODO: some real checking here... + CheckVarDeclCommon(valParam); + } + else if(auto constraint = m.As<GenericTypeConstraintDecl>()) + { + CheckGenericConstraintDecl(constraint.Ptr()); } - - // check the nested declaration - // TODO: this needs to be done in an appropriate environment... - genericDecl->inner->Accept(this); - return genericDecl; } - virtual void visitInterfaceDecl(InterfaceDecl* decl) override - { - // TODO: do some actual checking of members here - } + // check the nested declaration + // TODO: this needs to be done in an appropriate environment... + genericDecl->inner->Accept(this); + return genericDecl; + } - virtual void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) override - { - // check the type being inherited from - auto base = inheritanceDecl->base; - base = TranslateTypeNode(base); - inheritanceDecl->base = base; + virtual void visitInterfaceDecl(InterfaceDecl* decl) override + { + // TODO: do some actual checking of members here + } + + virtual void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) override + { + // check the type being inherited from + auto base = inheritanceDecl->base; + base = TranslateTypeNode(base); + inheritanceDecl->base = base; - // For now we only allow inheritance from interfaces, so - // we will validate that the type expression names an interface + // For now we only allow inheritance from interfaces, so + // we will validate that the type expression names an interface - if(auto declRefType = base.type->As<DeclRefType>()) + if(auto declRefType = base.type->As<DeclRefType>()) + { + if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDeclRef>()) { - if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDeclRef>()) - { - return; - } + return; } - - // If type expression didn't name an interface, we'll emit an error here - // TODO: deal with the case of an error in the type expression (don't cascade) - getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); } - RefPtr<ConstantIntVal> checkConstantIntVal( - RefPtr<ExpressionSyntaxNode> expr) - { - // First type-check the expression as normal - expr = CheckExpr(expr); + // If type expression didn't name an interface, we'll emit an error here + // TODO: deal with the case of an error in the type expression (don't cascade) + getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); + } - auto intVal = CheckIntegerConstantExpression(expr.Ptr()); - if(!intVal) - return nullptr; + RefPtr<ConstantIntVal> checkConstantIntVal( + RefPtr<ExpressionSyntaxNode> expr) + { + // First type-check the expression as normal + expr = CheckExpr(expr); - auto constIntVal = intVal.As<ConstantIntVal>(); - if(!constIntVal) - { - getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; + auto intVal = CheckIntegerConstantExpression(expr.Ptr()); + if(!intVal) + return nullptr; + + auto constIntVal = intVal.As<ConstantIntVal>(); + if(!constIntVal) + { + getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; } + return constIntVal; + } - RefPtr<Modifier> checkModifier( - RefPtr<Modifier> m, - Decl* /*decl*/) + RefPtr<Modifier> checkModifier( + RefPtr<Modifier> m, + Decl* /*decl*/) + { + if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>()) { - if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>()) + // We have an HLSL `[name(arg,...)]` attribute, and we'd like + // to check that it is provides all the expected arguments + // + // For now we will do this in a completely ad hoc fashion, + // but it would be nice to have some generic routine to + // do the needed type checking/coercion. + if(hlslUncheckedAttribute->nameToken.Content == "numthreads") { - // We have an HLSL `[name(arg,...)]` attribute, and we'd like - // to check that it is provides all the expected arguments - // - // For now we will do this in a completely ad hoc fashion, - // but it would be nice to have some generic routine to - // do the needed type checking/coercion. - if(hlslUncheckedAttribute->nameToken.Content == "numthreads") - { - if(hlslUncheckedAttribute->args.Count() != 3) - return m; + if(hlslUncheckedAttribute->args.Count() != 3) + return m; - auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]); - auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]); - auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]); + auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]); + auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]); + auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]); - if(!xVal) return m; - if(!yVal) return m; - if(!zVal) return m; + if(!xVal) return m; + if(!yVal) return m; + if(!zVal) return m; - auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute(); + auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute(); - hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position; - hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken; - hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args; - hlslNumThreadsAttribute->x = xVal->value; - hlslNumThreadsAttribute->y = yVal->value; - hlslNumThreadsAttribute->z = zVal->value; + hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position; + hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken; + hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args; + hlslNumThreadsAttribute->x = xVal->value; + hlslNumThreadsAttribute->y = yVal->value; + hlslNumThreadsAttribute->z = zVal->value; - return hlslNumThreadsAttribute; - } + return hlslNumThreadsAttribute; } + } - // Default behavior is to leave things as they are, - // and assume that modifiers are mostly already checked. - // - // TODO: This would be a good place to validate that - // a modifier is actually valid for the thing it is - // being applied to, and potentially to check that - // it isn't in conflict with any other modifiers - // on the same declaration. + // Default behavior is to leave things as they are, + // and assume that modifiers are mostly already checked. + // + // TODO: This would be a good place to validate that + // a modifier is actually valid for the thing it is + // being applied to, and potentially to check that + // it isn't in conflict with any other modifiers + // on the same declaration. - return m; - } + return m; + } - void checkModifiers(Decl* decl) + void checkModifiers(Decl* decl) + { + // TODO(tfoley): need to make sure this only + // performs semantic checks on a `SharedModifier` once... + + // The process of checking a modifier may produce a new modifier in its place, + // so we will build up a new linked list of modifiers that will replace + // the old list. + RefPtr<Modifier> resultModifiers; + RefPtr<Modifier>* resultModifierLink = &resultModifiers; + + RefPtr<Modifier> modifier = decl->modifiers.first; + while(modifier) { - // TODO(tfoley): need to make sure this only - // performs semantic checks on a `SharedModifier` once... + // Because we are rewriting the list in place, we need to extract + // the next modifier here (not at the end of the loop). + auto next = modifier->next; - // The process of checking a modifier may produce a new modifier in its place, - // so we will build up a new linked list of modifiers that will replace - // the old list. - RefPtr<Modifier> resultModifiers; - RefPtr<Modifier>* resultModifierLink = &resultModifiers; + // We also go ahead and clobber the `next` field on the modifier + // itself, so that the default behavior of `checkModifier()` can + // be to return a single unlinked modifier. + modifier->next = nullptr; - RefPtr<Modifier> modifier = decl->modifiers.first; - while(modifier) + auto checkedModifier = checkModifier(modifier, decl); + if(checkedModifier) { - // Because we are rewriting the list in place, we need to extract - // the next modifier here (not at the end of the loop). - auto next = modifier->next; + // If checking gave us a modifier to add, then we + // had better add it. - // We also go ahead and clobber the `next` field on the modifier - // itself, so that the default behavior of `checkModifier()` can - // be to return a single unlinked modifier. - modifier->next = nullptr; + // Just in case `checkModifier` ever returns multiple + // modifiers, lets advance to the end of the list we + // are building. + while(*resultModifierLink) + resultModifierLink = &(*resultModifierLink)->next; - auto checkedModifier = checkModifier(modifier, decl); - if(checkedModifier) - { - // If checking gave us a modifier to add, then we - // had better add it. - - // Just in case `checkModifier` ever returns multiple - // modifiers, lets advance to the end of the list we - // are building. - while(*resultModifierLink) - resultModifierLink = &(*resultModifierLink)->next; - - // attach the new modifier at the end of the list, - // and now set the "link" to point to its `next` field - *resultModifierLink = checkedModifier; - resultModifierLink = &checkedModifier->next; - } - - // Move along to the next modifier - modifier = next; + // attach the new modifier at the end of the list, + // and now set the "link" to point to its `next` field + *resultModifierLink = checkedModifier; + resultModifierLink = &checkedModifier->next; } - // Whether we actually re-wrote anything or note, lets - // install the new list of modifiers on the declaration - decl->modifiers.first = resultModifiers; + // Move along to the next modifier + modifier = next; } - virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override + // Whether we actually re-wrote anything or note, lets + // install the new list of modifiers on the declaration + decl->modifiers.first = resultModifiers; + } + + virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override + { + // Try to register all the builtin decls + for (auto decl : programNode->Members) { - // Try to register all the builtin decls - for (auto decl : programNode->Members) + auto inner = decl; + if (auto genericDecl = decl.As<GenericDecl>()) { - auto inner = decl; - if (auto genericDecl = decl.As<GenericDecl>()) - { - inner = genericDecl->inner; - } - - if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>()) - { - RegisterBuiltinDecl(decl, builtinMod); - } - if (auto magicMod = inner->FindModifier<MagicTypeModifier>()) - { - RegisterMagicDecl(decl, magicMod); - } + inner = genericDecl->inner; } - // - - HashSet<String> funcNames; - this->program = programNode; - this->function = nullptr; - - for (auto & s : program->GetTypeDefs()) - VisitTypeDefDecl(s.Ptr()); - for (auto & s : program->GetStructs()) + if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>()) { - VisitStruct(s.Ptr()); + RegisterBuiltinDecl(decl, builtinMod); } - for (auto & s : program->GetClasses()) - { - VisitClass(s.Ptr()); - } - // HACK(tfoley): Visiting all generic declarations here, - // because otherwise they won't get visited. - for (auto & g : program->GetMembersOfType<GenericDecl>()) + if (auto magicMod = inner->FindModifier<MagicTypeModifier>()) { - VisitGenericDecl(g.Ptr()); + RegisterMagicDecl(decl, magicMod); } + } - for (auto & func : program->GetFunctions()) - { - if (!func->IsChecked(DeclCheckState::Checked)) - { - VisitFunctionDeclaration(func.Ptr()); - } - } - for (auto & func : program->GetFunctions()) - { - EnsureDecl(func); - } - - if (sink->GetErrorCount() != 0) - return programNode; - - // Force everything to be fully checked, just in case - // Note that we don't just call this on the program, - // because we'd end up recursing into this very code path... - for (auto d : programNode->Members) - { - EnusreAllDeclsRec(d); - } + // - // Do any semantic checking required on modifiers? - for (auto d : programNode->Members) - { - checkModifiers(d.Ptr()); - } + HashSet<String> funcNames; + this->program = programNode; + this->function = nullptr; - return programNode; + for (auto & s : program->GetTypeDefs()) + VisitTypeDefDecl(s.Ptr()); + for (auto & s : program->GetStructs()) + { + VisitStruct(s.Ptr()); } - - virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override + for (auto & s : program->GetClasses()) { - if (classNode->IsChecked(DeclCheckState::Checked)) - return classNode; - classNode->SetCheckState(DeclCheckState::Checked); - - for (auto field : classNode->GetFields()) - { - field->Type = CheckUsableType(field->Type); - field->SetCheckState(DeclCheckState::Checked); - } - return classNode; + VisitClass(s.Ptr()); } - - virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override + // HACK(tfoley): Visiting all generic declarations here, + // because otherwise they won't get visited. + for (auto & g : program->GetMembersOfType<GenericDecl>()) { - if (structNode->IsChecked(DeclCheckState::Checked)) - return structNode; - structNode->SetCheckState(DeclCheckState::Checked); + VisitGenericDecl(g.Ptr()); + } - for (auto field : structNode->GetFields()) + for (auto & func : program->GetFunctions()) + { + if (!func->IsChecked(DeclCheckState::Checked)) { - field->Type = CheckUsableType(field->Type); - field->SetCheckState(DeclCheckState::Checked); + VisitFunctionDeclaration(func.Ptr()); } - return structNode; } - - virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override + for (auto & func : program->GetFunctions()) { - if (decl->IsChecked(DeclCheckState::Checked)) return decl; - - decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->Type = CheckProperType(decl->Type); - decl->SetCheckState(DeclCheckState::Checked); - return decl; + EnsureDecl(func); } - - virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override + + if (sink->GetErrorCount() != 0) + return programNode; + + // Force everything to be fully checked, just in case + // Note that we don't just call this on the program, + // because we'd end up recursing into this very code path... + for (auto d : programNode->Members) { - if (functionNode->IsChecked(DeclCheckState::Checked)) - return functionNode; - - VisitFunctionDeclaration(functionNode); - functionNode->SetCheckState(DeclCheckState::Checked); - - if (!functionNode->IsExtern()) - { - this->function = functionNode; - if (functionNode->Body) - { - functionNode->Body->Accept(this); - } - this->function = nullptr; - } - return functionNode; + EnusreAllDeclsRec(d); } - // Check if two functions have the same signature for the purposes - // of overload resolution. - bool DoFunctionSignaturesMatch( - FunctionSyntaxNode* fst, - FunctionSyntaxNode* snd) + // Do any semantic checking required on modifiers? + for (auto d : programNode->Members) { - // TODO(tfoley): This function won't do anything sensible for generics, - // so we need to figure out a plan for that... + checkModifiers(d.Ptr()); + } - // TODO(tfoley): This copies the parameter array, which is bad for performance. - auto fstParams = fst->GetParameters().ToArray(); - auto sndParams = snd->GetParameters().ToArray(); + return programNode; + } - // If the functions have different numbers of parameters, then - // their signatures trivially don't match. - auto fstParamCount = fstParams.Count(); - auto sndParamCount = sndParams.Count(); - if (fstParamCount != sndParamCount) - return false; + virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override + { + if (classNode->IsChecked(DeclCheckState::Checked)) + return classNode; + classNode->SetCheckState(DeclCheckState::Checked); - for (int ii = 0; ii < fstParamCount; ++ii) - { - auto fstParam = fstParams[ii]; - auto sndParam = sndParams[ii]; + for (auto field : classNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return classNode; + } - // If a given parameter type doesn't match, then signatures don't match - if (!fstParam->Type.Equals(sndParam->Type)) - return false; + virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override + { + if (structNode->IsChecked(DeclCheckState::Checked)) + return structNode; + structNode->SetCheckState(DeclCheckState::Checked); - // If one parameter is `out` and the other isn't, then they don't match - // - // Note(tfoley): we don't consider `out` and `inout` as distinct here, - // because there is no way for overload resolution to pick between them. - if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>()) - return false; - } + for (auto field : structNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return structNode; + } - // Note(tfoley): return type doesn't enter into it, because we can't take - // calling context into account during overload resolution. + virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return decl; - return true; - } + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->Type = CheckProperType(decl->Type); + decl->SetCheckState(DeclCheckState::Checked); + return decl; + } - void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl) - { - auto parentDecl = funcDecl->ParentDecl; - assert(parentDecl); - if (!parentDecl) return; + virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override + { + if (functionNode->IsChecked(DeclCheckState::Checked)) + return functionNode; - // Look at previously-declared functions with the same name, - // in the same container - buildMemberDictionary(parentDecl); + VisitFunctionDeclaration(functionNode); + functionNode->SetCheckState(DeclCheckState::Checked); - for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName) + if (!functionNode->IsExtern()) + { + this->function = functionNode; + if (functionNode->Body) { - // Look through generics to the declaration underneath - auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl); - if (prevGenericDecl) - prevDecl = prevGenericDecl->inner.Ptr(); - - // We only care about previously-declared functions - // Note(tfoley): although we should really error out if the - // name is already in use for something else, like a variable... - auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl); - if (!prevFuncDecl) - continue; + functionNode->Body->Accept(this); + } + this->function = nullptr; + } + return functionNode; + } - // If the parameter signatures don't match, then don't worry - if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl)) - continue; + // Check if two functions have the same signature for the purposes + // of overload resolution. + bool DoFunctionSignaturesMatch( + FunctionSyntaxNode* fst, + FunctionSyntaxNode* snd) + { + // TODO(tfoley): This function won't do anything sensible for generics, + // so we need to figure out a plan for that... + + // TODO(tfoley): This copies the parameter array, which is bad for performance. + auto fstParams = fst->GetParameters().ToArray(); + auto sndParams = snd->GetParameters().ToArray(); + + // If the functions have different numbers of parameters, then + // their signatures trivially don't match. + auto fstParamCount = fstParams.Count(); + auto sndParamCount = sndParams.Count(); + if (fstParamCount != sndParamCount) + return false; - // If we get this far, then we've got two declarations in the same - // scope, with the same name and signature. - // - // They might just be redeclarations, which we would want to allow. + for (int ii = 0; ii < fstParamCount; ++ii) + { + auto fstParam = fstParams[ii]; + auto sndParam = sndParams[ii]; - // First, check if the return types match. - // TODO(tfolye): this code won't work for generics - if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType)) - { - // Bad dedeclaration - getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type"); + // If a given parameter type doesn't match, then signatures don't match + if (!fstParam->Type.Equals(sndParam->Type)) + return false; - // Don't bother emitting other errors at this point - break; - } + // If one parameter is `out` and the other isn't, then they don't match + // + // Note(tfoley): we don't consider `out` and `inout` as distinct here, + // because there is no way for overload resolution to pick between them. + if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>()) + return false; + } - // TODO(tfoley): track the fact that there is redeclaration going on, - // so that we can detect it and react accordingly during overload resolution - // (e.g., by only considering one declaration as the canonical one...) + // Note(tfoley): return type doesn't enter into it, because we can't take + // calling context into account during overload resolution. - // If both have a body, then there is trouble - if (funcDecl->Body && prevFuncDecl->Body) - { - // Redefinition - getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition"); + return true; + } - // Don't bother emitting other errors - break; - } + void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl) + { + auto parentDecl = funcDecl->ParentDecl; + assert(parentDecl); + if (!parentDecl) return; + + // Look at previously-declared functions with the same name, + // in the same container + buildMemberDictionary(parentDecl); + + for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName) + { + // Look through generics to the declaration underneath + auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl); + if (prevGenericDecl) + prevDecl = prevGenericDecl->inner.Ptr(); + + // We only care about previously-declared functions + // Note(tfoley): although we should really error out if the + // name is already in use for something else, like a variable... + auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl); + if (!prevFuncDecl) + continue; + + // If the parameter signatures don't match, then don't worry + if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl)) + continue; + + // If we get this far, then we've got two declarations in the same + // scope, with the same name and signature. + // + // They might just be redeclarations, which we would want to allow. + + // First, check if the return types match. + // TODO(tfolye): this code won't work for generics + if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType)) + { + // Bad dedeclaration + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type"); - // TODO(tfoley): If both specific default argument expressions - // for the same value, then that is an error too... + // Don't bother emitting other errors at this point + break; } - } - void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode) - { - if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; - functionNode->SetCheckState(DeclCheckState::CheckingHeader); + // TODO(tfoley): track the fact that there is redeclaration going on, + // so that we can detect it and react accordingly during overload resolution + // (e.g., by only considering one declaration as the canonical one...) - this->function = functionNode; - auto returnType = CheckProperType(functionNode->ReturnType); - functionNode->ReturnType = returnType; - HashSet<String> paraNames; - for (auto & para : functionNode->GetParameters()) + // If both have a body, then there is trouble + if (funcDecl->Body && prevFuncDecl->Body) { - if (paraNames.Contains(para->Name.Content)) - getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name); - else - paraNames.Add(para->Name.Content); - para->Type = CheckUsableType(para->Type); - if (para->Type.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid); + // Redefinition + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition"); + + // Don't bother emitting other errors + break; } - this->function = NULL; - functionNode->SetCheckState(DeclCheckState::CheckedHeader); - // One last bit of validation: check if we are redeclaring an existing function - ValidateFunctionRedeclaration(functionNode); + // TODO(tfoley): If both specific default argument expressions + // for the same value, then that is an error too... } + } - virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override + void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode) + { + if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; + functionNode->SetCheckState(DeclCheckState::CheckingHeader); + + this->function = functionNode; + auto returnType = CheckProperType(functionNode->ReturnType); + functionNode->ReturnType = returnType; + HashSet<String> paraNames; + for (auto & para : functionNode->GetParameters()) { - for (auto & node : stmt->Statements) - { - node->Accept(this); - } - return stmt; + if (paraNames.Contains(para->Name.Content)) + getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name); + else + paraNames.Add(para->Name.Content); + para->Type = CheckUsableType(para->Type); + if (para->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid); } + this->function = NULL; + functionNode->SetCheckState(DeclCheckState::CheckedHeader); - template<typename T> - T* FindOuterStmt() + // One last bit of validation: check if we are redeclaring an existing function + ValidateFunctionRedeclaration(functionNode); + } + + virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override + { + for (auto & node : stmt->Statements) { - int outerStmtCount = outerStmts.Count(); - for (int ii = outerStmtCount - 1; ii >= 0; --ii) - { - auto outerStmt = outerStmts[ii]; - auto found = dynamic_cast<T*>(outerStmt); - if (found) - return found; - } - return nullptr; + node->Accept(this); } + return stmt; + } - virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override + template<typename T> + T* FindOuterStmt() + { + int outerStmtCount = outerStmts.Count(); + for (int ii = outerStmtCount - 1; ii >= 0; --ii) { - auto outer = FindOuterStmt<BreakableStmt>(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); - } - stmt->parentStmt = outer; - return stmt; + auto outerStmt = outerStmts[ii]; + auto found = dynamic_cast<T*>(outerStmt); + if (found) + return found; } - virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override + return nullptr; + } + + virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<BreakableStmt>(); + if (!outer) { - auto outer = FindOuterStmt<LoopStmt>(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); - } - stmt->parentStmt = outer; - return stmt; + getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } - - void PushOuterStmt(StatementSyntaxNode* stmt) + stmt->parentStmt = outer; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<LoopStmt>(); + if (!outer) { - outerStmts.Add(stmt); + getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + stmt->parentStmt = outer; + return stmt; + } + + void PushOuterStmt(StatementSyntaxNode* stmt) + { + outerStmts.Add(stmt); + } - void PopOuterStmt(StatementSyntaxNode* /*stmt*/) + void PopOuterStmt(StatementSyntaxNode* /*stmt*/) + { + outerStmts.RemoveAt(outerStmts.Count() - 1); + } + + virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) { - outerStmts.RemoveAt(outerStmts.Count() - 1); + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError); } + stmt->Statement->Accept(this); - virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->InitialStatement) { - PushOuterStmt(stmt); - if (stmt->Predicate != NULL) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) - { - getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError); - } - stmt->Statement->Accept(this); - - PopOuterStmt(stmt); - return stmt; + stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); } - virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override + if (stmt->PredicateExpression) { - PushOuterStmt(stmt); - if (stmt->InitialStatement) - { - stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); - } - if (stmt->PredicateExpression) - { - stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) && - !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) && - !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt())) - { - getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError); - } - } - if (stmt->SideEffectExpression) + stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt())) { - stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); + getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError); } - stmt->Statement->Accept(this); - - PopOuterStmt(stmt); - return stmt; } - virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override + if (stmt->SideEffectExpression) { - PushOuterStmt(stmt); - // TODO(tfoley): need to coerce condition to an integral type... - stmt->condition = CheckExpr(stmt->condition); - stmt->body->Accept(this); - PopOuterStmt(stmt); - return stmt; + stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); } - virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override - { - auto expr = CheckExpr(stmt->expr); - auto switchStmt = FindOuterStmt<SwitchStmt>(); + stmt->Statement->Accept(this); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); - } - else - { - // TODO: need to do some basic matching to ensure the type - // for the `case` is consistent with the type for the `switch`... - } - - stmt->expr = expr; - stmt->parentStmt = switchStmt; + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override + { + PushOuterStmt(stmt); + // TODO(tfoley): need to coerce condition to an integral type... + stmt->condition = CheckExpr(stmt->condition); + stmt->body->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override + { + auto expr = CheckExpr(stmt->expr); + auto switchStmt = FindOuterStmt<SwitchStmt>(); - return stmt; + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); } - virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override + else { - auto switchStmt = FindOuterStmt<SwitchStmt>(); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); - } - stmt->parentStmt = switchStmt; - return stmt; + // TODO: need to do some basic matching to ensure the type + // for the `case` is consistent with the type for the `switch`... } - virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override + + stmt->expr = expr; + stmt->parentStmt = switchStmt; + + return stmt; + } + virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override + { + auto switchStmt = FindOuterStmt<SwitchStmt>(); + if (!switchStmt) { - auto condition = stmt->Predicate; - condition = CheckTerm(condition); - condition = Coerce(ExpressionType::GetBool(), condition); + getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); + } + stmt->parentStmt = switchStmt; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override + { + auto condition = stmt->Predicate; + condition = CheckTerm(condition); + condition = Coerce(ExpressionType::GetBool(), condition); - stmt->Predicate = condition; + stmt->Predicate = condition; #if 0 - if (stmt->Predicate != NULL) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) - && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))) - getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError); + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) + && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))) + getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError); #endif - if (stmt->PositiveStatement != NULL) - stmt->PositiveStatement->Accept(this); + if (stmt->PositiveStatement != NULL) + stmt->PositiveStatement->Accept(this); - if (stmt->NegativeStatement != NULL) - stmt->NegativeStatement->Accept(this); - return stmt; + if (stmt->NegativeStatement != NULL) + stmt->NegativeStatement->Accept(this); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override + { + if (!stmt->Expression) + { + if (function && !function->ReturnType.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); } - virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override + else { - if (!stmt->Expression) - { - if (function && !function->ReturnType.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); - } - else + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr())) { - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr())) + if (function) { - if (function) - { - stmt->Expression = Coerce(function->ReturnType, stmt->Expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). + stmt->Expression = Coerce(function->ReturnType, stmt->Expression); + } + else + { + // TODO(tfoley): this case currently gets triggered for member functions, + // which aren't being checked consistently (because of the whole symbol + // table idea getting in the way). // getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); - } } } - return stmt; } + return stmt; + } - int GetMinBound(RefPtr<IntVal> val) - { - if (auto constantVal = val.As<ConstantIntVal>()) - return constantVal->value; + int GetMinBound(RefPtr<IntVal> val) + { + if (auto constantVal = val.As<ConstantIntVal>()) + return constantVal->value; - // TODO(tfoley): Need to track intervals so that this isn't just a lie... - return 1; - } + // TODO(tfoley): Need to track intervals so that this isn't just a lie... + return 1; + } - void maybeInferArraySizeForVariable(Variable* varDecl) - { - // Not an array? - auto arrayType = varDecl->Type->AsArrayType(); - if (!arrayType) return; + void maybeInferArraySizeForVariable(Variable* varDecl) + { + // Not an array? + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; - // Explicit element count given? - auto elementCount = arrayType->ArrayLength; - if (elementCount) return; + // Explicit element count given? + auto elementCount = arrayType->ArrayLength; + if (elementCount) return; - // No initializer? - auto initExpr = varDecl->Expr; - if(!initExpr) return; + // No initializer? + auto initExpr = varDecl->Expr; + if(!initExpr) return; - // Is the initializer an initializer list? - if(auto initializerListExpr = initExpr.As<InitializerListExpr>()) - { - auto argCount = initializerListExpr->args.Count(); - elementCount = new ConstantIntVal(argCount); - } - // Is the type of the initializer an array type? - else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>()) - { - elementCount = arrayInitType->ArrayLength; - } - else - { - // Nothing to do: we couldn't infer a size - return; - } + // Is the initializer an initializer list? + if(auto initializerListExpr = initExpr.As<InitializerListExpr>()) + { + auto argCount = initializerListExpr->args.Count(); + elementCount = new ConstantIntVal(argCount); + } + // Is the type of the initializer an array type? + else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>()) + { + elementCount = arrayInitType->ArrayLength; + } + else + { + // Nothing to do: we couldn't infer a size + return; + } - // Create a new array type based on the size we found, - // and install it into our type. - auto newArrayType = new ArrayExpressionType(); - newArrayType->BaseType = arrayType->BaseType; - newArrayType->ArrayLength = elementCount; + // Create a new array type based on the size we found, + // and install it into our type. + auto newArrayType = new ArrayExpressionType(); + newArrayType->BaseType = arrayType->BaseType; + newArrayType->ArrayLength = elementCount; - // Okay we are good to go! - varDecl->Type.type = newArrayType; - } + // Okay we are good to go! + varDecl->Type.type = newArrayType; + } - void ValidateArraySizeForVariable(Variable* varDecl) - { - auto arrayType = varDecl->Type->AsArrayType(); - if (!arrayType) return; + void ValidateArraySizeForVariable(Variable* varDecl) + { + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; - auto elementCount = arrayType->ArrayLength; - if (!elementCount) - { - // Note(tfoley): For now we allow arrays of unspecified size - // everywhere, because some source languages (e.g., GLSL) - // allow them in specific cases. + auto elementCount = arrayType->ArrayLength; + if (!elementCount) + { + // Note(tfoley): For now we allow arrays of unspecified size + // everywhere, because some source languages (e.g., GLSL) + // allow them in specific cases. #if 0 - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); #endif - return; - } - - // TODO(tfoley): How to handle the case where bound isn't known? - if (GetMinBound(elementCount) <= 0) - { - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); - return; - } + return; } - virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl) + // TODO(tfoley): How to handle the case where bound isn't known? + if (GetMinBound(elementCount) <= 0) { - TypeExp typeExp = CheckUsableType(varDecl->Type); + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + return; + } + } + + virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl) + { + TypeExp typeExp = CheckUsableType(varDecl->Type); #if 0 - if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable) + if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable) + { + // We don't want to allow bindable resource types as local variables (at least for now). + auto parentDecl = varDecl->ParentDecl; + if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl)) { - // We don't want to allow bindable resource types as local variables (at least for now). - auto parentDecl = varDecl->ParentDecl; - if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl)) - { - getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable); - } + getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable); } + } #endif - varDecl->Type = typeExp; - if (varDecl->Type.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); - - if(auto initExpr = varDecl->Expr) - { - initExpr = CheckTerm(initExpr); - varDecl->Expr = initExpr; - } - - // If this is an array variable, then we first want to give - // it a chance to infer an array size from its initializer - // - // TODO(tfoley): May need to extend this to handle the - // multi-dimensional case... - maybeInferArraySizeForVariable(varDecl); - // - // Next we want to make sure that the declared (or inferred) - // size for the array meets whatever language-specific - // constraints we want to enforce (e.g., disallow empty - // arrays in specific cases) - ValidateArraySizeForVariable(varDecl); + varDecl->Type = typeExp; + if (varDecl->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); + if(auto initExpr = varDecl->Expr) + { + initExpr = CheckTerm(initExpr); + varDecl->Expr = initExpr; + } - if(auto initExpr = varDecl->Expr) - { - // TODO(tfoley): should coercion of initializer lists be special-cased - // here, or handled as a general case for coercion? + // If this is an array variable, then we first want to give + // it a chance to infer an array size from its initializer + // + // TODO(tfoley): May need to extend this to handle the + // multi-dimensional case... + maybeInferArraySizeForVariable(varDecl); + // + // Next we want to make sure that the declared (or inferred) + // size for the array meets whatever language-specific + // constraints we want to enforce (e.g., disallow empty + // arrays in specific cases) + ValidateArraySizeForVariable(varDecl); - initExpr = Coerce(varDecl->Type, initExpr); - varDecl->Expr = initExpr; - } - varDecl->SetCheckState(DeclCheckState::Checked); + if(auto initExpr = varDecl->Expr) + { + // TODO(tfoley): should coercion of initializer lists be special-cased + // here, or handled as a general case for coercion? - return varDecl; + initExpr = Coerce(varDecl->Type, initExpr); + varDecl->Expr = initExpr; } - virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override - { - PushOuterStmt(stmt); - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) - getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2); + varDecl->SetCheckState(DeclCheckState::Checked); - stmt->Statement->Accept(this); - PopOuterStmt(stmt); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override - { - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override - { + return varDecl; + } + + virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2); + + stmt->Statement->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override + { + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override + { #if 0 - for (int i = 0; i < expr->Arguments.Count(); i++) - expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>(); - auto & leftType = expr->Arguments[0]->Type; - QualType rightType; - if (expr->Arguments.Count() == 2) - rightType = expr->Arguments[1]->Type; - RefPtr<ExpressionType> matchedType; - auto checkAssign = [&]() + for (int i = 0; i < expr->Arguments.Count(); i++) + expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>(); + auto & leftType = expr->Arguments[0]->Type; + QualType rightType; + if (expr->Arguments.Count() == 2) + rightType = expr->Arguments[1]->Type; + RefPtr<ExpressionType> matchedType; + auto checkAssign = [&]() + { + if (!leftType.IsLeftValue && + !leftType->Equals(ExpressionType::Error.Ptr())) + getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue); + if (expr->Operator == Operator::AndAssign || + expr->Operator == Operator::OrAssign || + expr->Operator == Operator::XorAssign || + expr->Operator == Operator::LshAssign || + expr->Operator == Operator::RshAssign) { - if (!leftType.IsLeftValue && - !leftType->Equals(ExpressionType::Error.Ptr())) - getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue); - if (expr->Operator == Operator::AndAssign || - expr->Operator == Operator::OrAssign || - expr->Operator == Operator::XorAssign || - expr->Operator == Operator::LshAssign || - expr->Operator == Operator::RshAssign) - { #if 0 - if (!(leftType->IsIntegral() && rightType->IsIntegral())) - { - // TODO(tfoley): This diagnostic shouldn't be handled here + if (!(leftType->IsIntegral() && rightType->IsIntegral())) + { + // TODO(tfoley): This diagnostic shouldn't be handled here // getSink()->diagnose(expr, Diagnostics::bitOperationNonIntegral); - } -#endif } - - // TODO(tfoley): Need to actual insert coercion here... - if(CanCoerce(leftType, expr->Type)) - expr->Type = leftType; - else - expr->Type = ExpressionType::Error; - }; -#if 0 - if (expr->Operator == Operator::Assign) - { - expr->Type = rightType; - checkAssign(); - } - else #endif - { - expr->FunctionExpr = CheckExpr(expr->FunctionExpr); - CheckInvokeExprWithCheckedOperands(expr); - if (expr->Operator > Operator::Assign) - checkAssign(); } - return expr; -#endif - // Treat operator application just like a function call - return VisitInvokeExpression(expr); + // TODO(tfoley): Need to actual insert coercion here... + if(CanCoerce(leftType, expr->Type)) + expr->Type = leftType; + else + expr->Type = ExpressionType::Error; + }; +#if 0 + if (expr->Operator == Operator::Assign) + { + expr->Type = rightType; + checkAssign(); } - virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override + else +#endif { - switch (expr->ConstType) - { - case ConstantExpressionSyntaxNode::ConstantType::Int: - expr->Type = ExpressionType::GetInt(); - break; - case ConstantExpressionSyntaxNode::ConstantType::Bool: - expr->Type = ExpressionType::GetBool(); - break; - case ConstantExpressionSyntaxNode::ConstantType::Float: - expr->Type = ExpressionType::GetFloat(); - break; - default: - expr->Type = ExpressionType::Error; - throw "Invalid constant type."; - break; - } - return expr; + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + CheckInvokeExprWithCheckedOperands(expr); + if (expr->Operator > Operator::Assign) + checkAssign(); } + return expr; +#endif - IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr) - { - // TODO(tfoley): don't keep allocating here! - return new ConstantIntVal(expr->IntValue); + // Treat operator application just like a function call + return VisitInvokeExpression(expr); + } + virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override + { + switch (expr->ConstType) + { + case ConstantExpressionSyntaxNode::ConstantType::Int: + expr->Type = ExpressionType::GetInt(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Bool: + expr->Type = ExpressionType::GetBool(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Float: + expr->Type = ExpressionType::GetFloat(); + break; + default: + expr->Type = ExpressionType::Error; + throw "Invalid constant type."; + break; } + return expr; + } - RefPtr<IntVal> TryConstantFoldExpr( - InvokeExpressionSyntaxNode* invokeExpr) - { - // We need all the operands to the expression + IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr) + { + // TODO(tfoley): don't keep allocating here! + return new ConstantIntVal(expr->IntValue); + } - // Check if the callee is an operation that is amenable to constant-folding. - // - // For right now we will look for calls to intrinsic functions, and then inspect - // their names (this is bad and slow). - auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); - if (!funcDeclRefExpr) return nullptr; - - auto funcDeclRef = funcDeclRefExpr->declRef; - auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicOpModifier>(); - if (!intrinsicMod) return nullptr; - - // Let's not constant-fold operations with more than a certain number of arguments, for simplicity - static const int kMaxArgs = 8; - if (invokeExpr->Arguments.Count() > kMaxArgs) - return nullptr; + RefPtr<IntVal> TryConstantFoldExpr( + InvokeExpressionSyntaxNode* invokeExpr) + { + // We need all the operands to the expression - // Before checking the operation name, let's look at the arguments - RefPtr<IntVal> argVals[kMaxArgs]; - int constArgVals[kMaxArgs]; - int argCount = 0; - bool allConst = true; - for (auto argExpr : invokeExpr->Arguments) - { - auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); - if (!argVal) - return nullptr; + // Check if the callee is an operation that is amenable to constant-folding. + // + // For right now we will look for calls to intrinsic functions, and then inspect + // their names (this is bad and slow). + auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); + if (!funcDeclRefExpr) return nullptr; + + auto funcDeclRef = funcDeclRefExpr->declRef; + auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicOpModifier>(); + if (!intrinsicMod) return nullptr; + + // Let's not constant-fold operations with more than a certain number of arguments, for simplicity + static const int kMaxArgs = 8; + if (invokeExpr->Arguments.Count() > kMaxArgs) + return nullptr; - argVals[argCount] = argVal; + // Before checking the operation name, let's look at the arguments + RefPtr<IntVal> argVals[kMaxArgs]; + int constArgVals[kMaxArgs]; + int argCount = 0; + bool allConst = true; + for (auto argExpr : invokeExpr->Arguments) + { + auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); + if (!argVal) + return nullptr; - if (auto constArgVal = argVal.As<ConstantIntVal>()) - { - constArgVals[argCount] = constArgVal->value; - } - else - { - allConst = false; - } - argCount++; - } + argVals[argCount] = argVal; - if (!allConst) + if (auto constArgVal = argVal.As<ConstantIntVal>()) { - // TODO(tfoley): We probably want to support a very limited number of operations - // on "constants" that aren't actually known, to be able to handle a generic - // that takes an integer `N` but then constructs a vector of size `N+1`. - // - // The hard part there is implementing the rules for value unification in the - // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd - // need inference to be smart enough to know that `2 + N` and `N + 2` are the - // same value, as are `N + M + 1 + 1` and `M + 2 + N`. - // - // For now we can just bail in this case. - return nullptr; + constArgVals[argCount] = constArgVal->value; + } + else + { + allConst = false; } + argCount++; + } - // At this point, all the operands had simple integer values, so we are golden. - int resultValue = 0; - auto opName = funcDeclRef.GetName(); + if (!allConst) + { + // TODO(tfoley): We probably want to support a very limited number of operations + // on "constants" that aren't actually known, to be able to handle a generic + // that takes an integer `N` but then constructs a vector of size `N+1`. + // + // The hard part there is implementing the rules for value unification in the + // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd + // need inference to be smart enough to know that `2 + N` and `N + 2` are the + // same value, as are `N + M + 1 + 1` and `M + 2 + N`. + // + // For now we can just bail in this case. + return nullptr; + } + + // At this point, all the operands had simple integer values, so we are golden. + int resultValue = 0; + auto opName = funcDeclRef.GetName(); - // handle binary operators - if (opName == "-") + // handle binary operators + if (opName == "-") + { + if (argCount == 1) { - if (argCount == 1) - { - resultValue = -constArgVals[0]; - } - else if (argCount == 2) - { - resultValue = constArgVals[0] - constArgVals[1]; - } + resultValue = -constArgVals[0]; + } + else if (argCount == 2) + { + resultValue = constArgVals[0] - constArgVals[1]; } + } - // simple binary operators + // simple binary operators #define CASE(OP) \ - else if(opName == #OP) do { \ - if(argCount != 2) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) - CASE(+); // TODO: this can also be unary... - CASE(*); + CASE(+); // TODO: this can also be unary... + CASE(*); #undef CASE - // binary operators with chance of divide-by-zero - // TODO: issue a suitable error in that case + // binary operators with chance of divide-by-zero + // TODO: issue a suitable error in that case #define CASE(OP) \ - else if(opName == #OP) do { \ - if(argCount != 2) return nullptr; \ - if(!constArgVals[1]) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - - CASE(/); - CASE(%); + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + if(!constArgVals[1]) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(/); + CASE(%); #undef CASE - // TODO(tfoley): more cases - else - { - return nullptr; - } + // TODO(tfoley): more cases + else + { + return nullptr; + } + + RefPtr<IntVal> result = new ConstantIntVal(resultValue); + return result; + } - RefPtr<IntVal> result = new ConstantIntVal(resultValue); - return result; + RefPtr<IntVal> TryConstantFoldExpr( + ExpressionSyntaxNode* expr) + { + // TODO(tfoley): more serious constant folding here + if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr)) + { + return GetIntVal(constExp); } - RefPtr<IntVal> TryConstantFoldExpr( - ExpressionSyntaxNode* expr) + // it is possible that we are referring to a generic value param + if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr)) { - // TODO(tfoley): more serious constant folding here - if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr)) + auto declRef = declRefExpr->declRef; + + if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>()) { - return GetIntVal(constExp); + // TODO(tfoley): handle the case of non-`int` value parameters... + return new GenericParamIntVal(genericValParamRef); } - // it is possible that we are referring to a generic value param - if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr)) + // We may also need to check for references to variables that + // are defined in a way that can be used as a constant expression: + if(auto varRef = declRef.As<VarDeclBaseRef>()) { - auto declRef = declRefExpr->declRef; - - if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>()) - { - // TODO(tfoley): handle the case of non-`int` value parameters... - return new GenericParamIntVal(genericValParamRef); - } + auto varDecl = varRef.GetDecl(); - // We may also need to check for references to variables that - // are defined in a way that can be used as a constant expression: - if(auto varRef = declRef.As<VarDeclBaseRef>()) + switch(sourceLanguage) { - auto varDecl = varRef.GetDecl(); - - switch(sourceLanguage) + case SourceLanguage::Slang: + case SourceLanguage::HLSL: + // HLSL: `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) { - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - // HLSL: `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->FindModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = varRef.getInitExpr()) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - break; - - case SourceLanguage::GLSL: - // GLSL: `const` indicates compile-time constant expression - // - // TODO(tfoley): The current logic here isn't robust against - // GLSL "specialization constants" - we will extract the - // initializer for a `const` variable and use it to extract - // a value, when we really should be using an opaque - // reference to the variable. if(auto constAttr = varDecl->FindModifier<ConstModifier>()) { - // We need to handle a "specialization constant" (with a `constant_id` layout modifier) - // differently from an ordinary compile-time constant. The latter can/should be reduced - // to a value, while the former should be kept as a symbolic reference - - if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) - { - // Retain the specialization constant as a symbolic reference - // - // TODO(tfoley): handle the case of non-`int` value parameters... - // - // TODO(tfoley): this is cloned from the case above that handles generic value parameters - return new GenericParamIntVal(varRef); - } - else if(auto initExpr = varRef.getInitExpr()) + // HLSL `static const` can be used as a constant expression + if(auto initExpr = varRef.getInitExpr()) { - // This is an ordinary constant, and not a specialization constant, so we - // can try to fold its value right now. return TryConstantFoldExpr(initExpr.Ptr()); } } - break; } + break; + case SourceLanguage::GLSL: + // GLSL: `const` indicates compile-time constant expression + // + // TODO(tfoley): The current logic here isn't robust against + // GLSL "specialization constants" - we will extract the + // initializer for a `const` variable and use it to extract + // a value, when we really should be using an opaque + // reference to the variable. + if(auto constAttr = varDecl->FindModifier<ConstModifier>()) + { + // We need to handle a "specialization constant" (with a `constant_id` layout modifier) + // differently from an ordinary compile-time constant. The latter can/should be reduced + // to a value, while the former should be kept as a symbolic reference + + if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) + { + // Retain the specialization constant as a symbolic reference + // + // TODO(tfoley): handle the case of non-`int` value parameters... + // + // TODO(tfoley): this is cloned from the case above that handles generic value parameters + return new GenericParamIntVal(varRef); + } + else if(auto initExpr = varRef.getInitExpr()) + { + // This is an ordinary constant, and not a specialization constant, so we + // can try to fold its value right now. + return TryConstantFoldExpr(initExpr.Ptr()); + } + } + break; } - } - if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr)) - { - auto val = TryConstantFoldExpr(invokeExpr); - if (val) - return val; - } - else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr)) - { - auto val = TryConstantFoldExpr(castExpr->Expression.Ptr()); - if(val) - return val; } + } - return nullptr; + if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(invokeExpr); + if (val) + return val; + } + else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(castExpr->Expression.Ptr()); + if(val) + return val; } - // Try to check an integer constant expression, either returning the value, - // or NULL if the expression isn't recognized as a constant. - RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp) + return nullptr; + } + + // Try to check an integer constant expression, either returning the value, + // or NULL if the expression isn't recognized as a constant. + RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp) + { + if (!exp->Type.type->Equals(ExpressionType::GetInt())) { - if (!exp->Type.type->Equals(ExpressionType::GetInt())) - { - return nullptr; - } + return nullptr; + } - // Otherwise, we need to consider operations that we might be able to constant-fold... - return TryConstantFoldExpr(exp); - } + // Otherwise, we need to consider operations that we might be able to constant-fold... + return TryConstantFoldExpr(exp); + } - // Enforce that an expression resolves to an integer constant, and get its value - RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr) + // Enforce that an expression resolves to an integer constant, and get its value + RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr) + { + // First coerce the expression to the expected type + auto expr = Coerce(ExpressionType::GetInt(),inExpr); + auto result = TryCheckIntegerConstantExpression(expr.Ptr()); + if (!result) { - // First coerce the expression to the expected type - auto expr = Coerce(ExpressionType::GetInt(),inExpr); - auto result = TryCheckIntegerConstantExpression(expr.Ptr()); - if (!result) - { - getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); } + return result; + } - RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr( - RefPtr<IndexExpressionSyntaxNode> subscriptExpr, - RefPtr<ExpressionType> elementType) - { - auto baseExpr = subscriptExpr->BaseExpression; - auto indexExpr = subscriptExpr->IndexExpression; + RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr( + RefPtr<IndexExpressionSyntaxNode> subscriptExpr, + RefPtr<ExpressionType> elementType) + { + auto baseExpr = subscriptExpr->BaseExpression; + auto indexExpr = subscriptExpr->IndexExpression; - if (!indexExpr->Type->Equals(ExpressionType::GetInt()) && - !indexExpr->Type->Equals(ExpressionType::GetUInt())) - { - getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); - return CreateErrorExpr(subscriptExpr.Ptr()); - } + if (!indexExpr->Type->Equals(ExpressionType::GetInt()) && + !indexExpr->Type->Equals(ExpressionType::GetUInt())) + { + getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); + return CreateErrorExpr(subscriptExpr.Ptr()); + } - subscriptExpr->Type = elementType; + subscriptExpr->Type = elementType; - // TODO(tfoley): need to be more careful about this stuff - subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue; + // TODO(tfoley): need to be more careful about this stuff + subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue; - return subscriptExpr; - } + return subscriptExpr; + } - // The way that we have designed out type system, pretyt much *every* - // type is a reference to some declaration in the standard library. - // That means that when we construct a new type on the fly, we need - // to make sure that it is wired up to reference the appropriate - // declaration, or else it won't compare as equal to other types - // that *do* reference the declaration. - // - // This function is used to construct a `vector<T,N>` type - // programmatically, so that it will work just like a type of - // that form constructed by the user. - RefPtr<VectorExpressionType> createVectorType( - RefPtr<ExpressionType> elementType, - RefPtr<IntVal> elementCount) - { - auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>(); - auto vectorTypeDecl = vectorGenericDecl->inner; + // The way that we have designed out type system, pretyt much *every* + // type is a reference to some declaration in the standard library. + // That means that when we construct a new type on the fly, we need + // to make sure that it is wired up to reference the appropriate + // declaration, or else it won't compare as equal to other types + // that *do* reference the declaration. + // + // This function is used to construct a `vector<T,N>` type + // programmatically, so that it will work just like a type of + // that form constructed by the user. + RefPtr<VectorExpressionType> createVectorType( + RefPtr<ExpressionType> elementType, + RefPtr<IntVal> elementCount) + { + auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>(); + auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = new Substitutions(); - substitutions->genericDecl = vectorGenericDecl.Ptr(); - substitutions->args.Add(elementType); - substitutions->args.Add(elementCount); + auto substitutions = new Substitutions(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.Add(elementType); + substitutions->args.Add(elementCount); - auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::Create(declRef)->As<VectorExpressionType>(); + } + + virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override + { + auto baseExpr = subscriptExpr->BaseExpression; + baseExpr = CheckExpr(baseExpr); - return DeclRefType::Create(declRef)->As<VectorExpressionType>(); + RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression; + if (indexExpr) + { + indexExpr = CheckExpr(indexExpr); } - virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override + subscriptExpr->BaseExpression = baseExpr; + subscriptExpr->IndexExpression = indexExpr; + + // If anything went wrong in the base expression, + // then just move along... + if (IsErrorExpr(baseExpr)) + return CreateErrorExpr(subscriptExpr); + + // Otherwise, we need to look at the type of the base expression, + // to figure out how subscripting should work. + auto baseType = baseExpr->Type.Ptr(); + if (auto baseTypeType = baseType->As<TypeType>()) { - auto baseExpr = subscriptExpr->BaseExpression; - baseExpr = CheckExpr(baseExpr); + // We are trying to "index" into a type, so we have an expression like `float[2]` + // which should be interpreted as resolving to an array type. - RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression; + RefPtr<IntVal> elementCount = nullptr; if (indexExpr) { - indexExpr = CheckExpr(indexExpr); + elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); } - subscriptExpr->BaseExpression = baseExpr; - subscriptExpr->IndexExpression = indexExpr; + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto arrayType = new ArrayExpressionType(); + arrayType->BaseType = elementType; + arrayType->ArrayLength = elementCount; - // If anything went wrong in the base expression, - // then just move along... - if (IsErrorExpr(baseExpr)) - return CreateErrorExpr(subscriptExpr); - - // Otherwise, we need to look at the type of the base expression, - // to figure out how subscripting should work. - auto baseType = baseExpr->Type.Ptr(); - if (auto baseTypeType = baseType->As<TypeType>()) - { - // We are trying to "index" into a type, so we have an expression like `float[2]` - // which should be interpreted as resolving to an array type. + typeResult = arrayType; + subscriptExpr->Type = new TypeType(arrayType); + return subscriptExpr; + } + else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + baseArrayType->BaseType); + } + else if (auto vecType = baseType->As<VectorExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + vecType->elementType); + } + else if (auto matType = baseType->As<MatrixExpressionType>()) + { + // TODO(tfoley): We shouldn't go and recompute + // row types over and over like this... :( + auto rowType = createVectorType( + matType->getElementType(), + matType->getColumnCount()); - RefPtr<IntVal> elementCount = nullptr; - if (indexExpr) - { - elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); - } + return CheckSimpleSubscriptExpr( + subscriptExpr, + rowType); + } - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); - auto arrayType = new ArrayExpressionType(); - arrayType->BaseType = elementType; - arrayType->ArrayLength = elementCount; + // Default behavior is to look at all available `__subscript` + // declarations on the type and try to call one of them. - typeResult = arrayType; - subscriptExpr->Type = new TypeType(arrayType); - return subscriptExpr; - } - else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - baseArrayType->BaseType); - } - else if (auto vecType = baseType->As<VectorExpressionType>()) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - vecType->elementType); - } - else if (auto matType = baseType->As<MatrixExpressionType>()) + if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - // TODO(tfoley): We shouldn't go and recompute - // row types over and over like this... :( - auto rowType = createVectorType( - matType->getElementType(), - matType->getColumnCount()); - - return CheckSimpleSubscriptExpr( - subscriptExpr, - rowType); - } + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - // Default behavior is to look at all available `__subscript` - // declarations on the type and try to call one of them. - - if (auto declRefType = baseType->AsDeclRefType()) - { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) + // Note(tfoley): The name used for lookup here is a bit magical, since + // it must match what the parser installed in subscript declarations. + LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef); + if (!lookupResult.isValid()) { - // Checking of the type must be complete before we can reference its members safely - EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - - // Note(tfoley): The name used for lookup here is a bit magical, since - // it must match what the parser installed in subscript declarations. - LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef); - if (!lookupResult.isValid()) - { - goto fail; - } + goto fail; + } - RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr( - lookupResult, subscriptExpr->BaseExpression, subscriptExpr); + RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr( + lookupResult, subscriptExpr->BaseExpression, subscriptExpr); - // Now that we know there is at least one subscript member, - // we will construct a reference to it and try to call it + // Now that we know there is at least one subscript member, + // we will construct a reference to it and try to call it - RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode(); - subscriptCallExpr->Position = subscriptExpr->Position; - subscriptCallExpr->FunctionExpr = subscriptFuncExpr; + RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode(); + subscriptCallExpr->Position = subscriptExpr->Position; + subscriptCallExpr->FunctionExpr = subscriptFuncExpr; - // TODO(tfoley): This path can support multiple arguments easily - subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression); + // TODO(tfoley): This path can support multiple arguments easily + subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression); - return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); - } - } - - fail: - { - getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); - return CreateErrorExpr(subscriptExpr); + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); } } - bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args) + fail: { - if (functionNode->GetParameters().Count() != args.Count()) - return false; - int i = 0; - for (auto param : functionNode->GetParameters()) - { - if (!param->Type.Equals(args[i]->Type.Ptr())) - return false; - i++; - } - return true; + getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); + return CreateErrorExpr(subscriptExpr); } + } - // Coerce an expression to a specific type that it is expected to have in context - RefPtr<ExpressionSyntaxNode> CoerceExprToType( - RefPtr<ExpressionSyntaxNode> expr, - RefPtr<ExpressionType> type) + bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args) + { + if (functionNode->GetParameters().Count() != args.Count()) + return false; + int i = 0; + for (auto param : functionNode->GetParameters()) { - // TODO(tfoley): clean this up so there is only one version... - return Coerce(type, expr); + if (!param->Type.Equals(args[i]->Type.Ptr())) + return false; + i++; } + return true; + } - // Resolve a call to a function, represented here - // by a symbol with a `FuncType` type. - RefPtr<ExpressionSyntaxNode> ResolveFunctionApp( - RefPtr<FuncType> funcType, - InvokeExpressionSyntaxNode* /*appExpr*/) - { - // TODO(tfoley): Actual checking logic needs to go here... + // Coerce an expression to a specific type that it is expected to have in context + RefPtr<ExpressionSyntaxNode> CoerceExprToType( + RefPtr<ExpressionSyntaxNode> expr, + RefPtr<ExpressionType> type) + { + // TODO(tfoley): clean this up so there is only one version... + return Coerce(type, expr); + } + + // Resolve a call to a function, represented here + // by a symbol with a `FuncType` type. + RefPtr<ExpressionSyntaxNode> ResolveFunctionApp( + RefPtr<FuncType> funcType, + InvokeExpressionSyntaxNode* /*appExpr*/) + { + // TODO(tfoley): Actual checking logic needs to go here... #if 0 - auto& args = appExpr->Arguments; - List<RefPtr<ParameterSyntaxNode>> params; - RefPtr<ExpressionType> resultType; - if (auto funcDeclRef = funcType->declRef) - { - EnsureDecl(funcDeclRef.GetDecl()); + auto& args = appExpr->Arguments; + List<RefPtr<ParameterSyntaxNode>> params; + RefPtr<ExpressionType> resultType; + if (auto funcDeclRef = funcType->declRef) + { + EnsureDecl(funcDeclRef.GetDecl()); - params = funcDeclRef->GetParameters().ToArray(); - resultType = funcDecl->ReturnType; - } - else if (auto funcSym = funcType->Func) - { - auto funcDecl = funcSym->SyntaxNode; - EnsureDecl(funcDecl); + params = funcDeclRef->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto funcSym = funcType->Func) + { + auto funcDecl = funcSym->SyntaxNode; + EnsureDecl(funcDecl); - params = funcDecl->GetParameters().ToArray(); - resultType = funcDecl->ReturnType; - } - else if (auto componentFuncSym = funcType->Component) - { - auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode; - params = componentFuncDecl->GetParameters().ToArray(); - resultType = componentFuncDecl->Type; - } + params = funcDecl->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto componentFuncSym = funcType->Component) + { + auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode; + params = componentFuncDecl->GetParameters().ToArray(); + resultType = componentFuncDecl->Type; + } - auto argCount = args.Count(); - auto paramCount = params.Count(); - if (argCount != paramCount) - { - getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call"); - appExpr->Type = ExpressionType::Error; - return appExpr; - } + auto argCount = args.Count(); + auto paramCount = params.Count(); + if (argCount != paramCount) + { + getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call"); + appExpr->Type = ExpressionType::Error; + return appExpr; + } - for (int ii = 0; ii < argCount; ++ii) - { - auto arg = args[ii]; - auto param = params[ii]; + for (int ii = 0; ii < argCount; ++ii) + { + auto arg = args[ii]; + auto param = params[ii]; - arg = CoerceExprToType(arg, param->Type); + arg = CoerceExprToType(arg, param->Type); - args[ii] = arg; - } + args[ii] = arg; + } - assert(resultType); - appExpr->Type = resultType; - return appExpr; + assert(resultType); + appExpr->Type = resultType; + return appExpr; #else - throw "unimplemented"; + throw "unimplemented"; #endif - } + } - // Resolve a constructor call, formed by apply a type to arguments - RefPtr<ExpressionSyntaxNode> ResolveConstructorApp( - RefPtr<ExpressionType> type, - InvokeExpressionSyntaxNode* appExpr) - { - // TODO(tfoley): Actual checking logic needs to go here... + // Resolve a constructor call, formed by apply a type to arguments + RefPtr<ExpressionSyntaxNode> ResolveConstructorApp( + RefPtr<ExpressionType> type, + InvokeExpressionSyntaxNode* appExpr) + { + // TODO(tfoley): Actual checking logic needs to go here... - appExpr->Type = type; - return appExpr; - } + appExpr->Type = type; + return appExpr; + } - // + // - virtual void VisitExtensionDecl(ExtensionDecl* decl) override - { - if (decl->IsChecked(DeclCheckState::Checked)) return; + virtual void VisitExtensionDecl(ExtensionDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->targetType = CheckProperType(decl->targetType); + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->targetType = CheckProperType(decl->targetType); - // TODO: need to check that the target type names a declaration... + // TODO: need to check that the target type names a declaration... - if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) - { - // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>()) - { - auto aggTypeDecl = aggTypeDeclRef.GetDecl(); - decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; - aggTypeDecl->candidateExtensions = decl; - } - else - { - getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); - } - } - else if (decl->targetType->Equals(ExpressionType::Error)) + if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) + { + // Attach our extension to that type as a candidate... + if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>()) { - // there was an error, so ignore + auto aggTypeDecl = aggTypeDeclRef.GetDecl(); + decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; + aggTypeDecl->candidateExtensions = decl; } else { getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); } - - decl->SetCheckState(DeclCheckState::CheckedHeader); - - // now check the members of the extension - for (auto m : decl->Members) - { - EnsureDecl(m); - } - - decl->SetCheckState(DeclCheckState::Checked); } - - virtual void VisitConstructorDecl(ConstructorDecl* decl) override + else if (decl->targetType->Equals(ExpressionType::Error)) { - if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); + // there was an error, so ignore + } + else + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->Type = CheckUsableType(paramDecl->Type); - } - decl->SetCheckState(DeclCheckState::CheckedHeader); + decl->SetCheckState(DeclCheckState::CheckedHeader); - // TODO(tfoley): check body - decl->SetCheckState(DeclCheckState::Checked); + // now check the members of the extension + for (auto m : decl->Members) + { + EnsureDecl(m); } + decl->SetCheckState(DeclCheckState::Checked); + } - virtual void visitSubscriptDecl(SubscriptDecl* decl) override + virtual void VisitConstructorDecl(ConstructorDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); + + for (auto& paramDecl : decl->GetParameters()) { - if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); + paramDecl->Type = CheckUsableType(paramDecl->Type); + } + decl->SetCheckState(DeclCheckState::CheckedHeader); - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->Type = CheckUsableType(paramDecl->Type); - } + // TODO(tfoley): check body + decl->SetCheckState(DeclCheckState::Checked); + } - decl->ReturnType = CheckUsableType(decl->ReturnType); - decl->SetCheckState(DeclCheckState::CheckedHeader); + virtual void visitSubscriptDecl(SubscriptDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->SetCheckState(DeclCheckState::Checked); + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->Type = CheckUsableType(paramDecl->Type); } - virtual void visitAccessorDecl(AccessorDecl* decl) override - { - // TODO: check the body! + decl->ReturnType = CheckUsableType(decl->ReturnType); - decl->SetCheckState(DeclCheckState::Checked); - } + decl->SetCheckState(DeclCheckState::CheckedHeader); + decl->SetCheckState(DeclCheckState::Checked); + } - // + virtual void visitAccessorDecl(AccessorDecl* decl) override + { + // TODO: check the body! - struct Constraint - { - Decl* decl; // the declaration of the thing being constraints - RefPtr<Val> val; // the value to which we are constraining it - bool satisfied = false; // Has this constraint been met? - }; + decl->SetCheckState(DeclCheckState::Checked); + } - // A collection of constraints that will need to be satisified (solved) - // in order for checking to suceed. - struct ConstraintSystem - { - List<Constraint> constraints; - }; - RefPtr<ExpressionType> TryJoinVectorAndScalarType( - RefPtr<VectorExpressionType> vectorType, - RefPtr<BasicExpressionType> scalarType) - { - // Join( vector<T,N>, S ) -> vetor<Join(T,S), N> - // - // That is, the join of a vector and a scalar type is - // a vector type with a joined element type. - auto joinElementType = TryJoinTypes( - vectorType->elementType, - scalarType); - if(!joinElementType) - return nullptr; + // - return createVectorType( - joinElementType, - vectorType->elementCount); - } + struct Constraint + { + Decl* decl; // the declaration of the thing being constraints + RefPtr<Val> val; // the value to which we are constraining it + bool satisfied = false; // Has this constraint been met? + }; + + // A collection of constraints that will need to be satisified (solved) + // in order for checking to suceed. + struct ConstraintSystem + { + List<Constraint> constraints; + }; + + RefPtr<ExpressionType> TryJoinVectorAndScalarType( + RefPtr<VectorExpressionType> vectorType, + RefPtr<BasicExpressionType> scalarType) + { + // Join( vector<T,N>, S ) -> vetor<Join(T,S), N> + // + // That is, the join of a vector and a scalar type is + // a vector type with a joined element type. + auto joinElementType = TryJoinTypes( + vectorType->elementType, + scalarType); + if(!joinElementType) + return nullptr; - bool DoesTypeConformToInterface( - RefPtr<ExpressionType> type, - InterfaceDeclRef interfaceDeclRef) + return createVectorType( + joinElementType, + vectorType->elementCount); + } + + bool DoesTypeConformToInterface( + RefPtr<ExpressionType> type, + InterfaceDeclRef interfaceDeclRef) + { + // for now look up a conformance member... + if(auto declRefType = type->As<DeclRefType>()) { - // for now look up a conformance member... - if(auto declRefType = type->As<DeclRefType>()) + if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) { - if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) + for( auto inheritanceDeclRef : aggTypeDeclRef.GetMembersOfType<InheritanceDeclRef>()) { - for( auto inheritanceDeclRef : aggTypeDeclRef.GetMembersOfType<InheritanceDeclRef>()) - { - EnsureDecl(inheritanceDeclRef.GetDecl()); + EnsureDecl(inheritanceDeclRef.GetDecl()); - auto inheritedDeclRefType = inheritanceDeclRef.getBaseType()->As<DeclRefType>(); - if (!inheritedDeclRefType) - continue; + auto inheritedDeclRefType = inheritanceDeclRef.getBaseType()->As<DeclRefType>(); + if (!inheritedDeclRefType) + continue; - if(interfaceDeclRef.Equals(inheritedDeclRefType->declRef)) - return true; - } + if(interfaceDeclRef.Equals(inheritedDeclRefType->declRef)) + return true; } } - - // default is failure - return false; } - RefPtr<ExpressionType> TryJoinTypeWithInterface( - RefPtr<ExpressionType> type, - InterfaceDeclRef interfaceDeclRef) - { - // The most basic test here should be: does the type declare conformance to the trait. - if(DoesTypeConformToInterface(type, interfaceDeclRef)) - return type; + // default is failure + return false; + } - // There is a more nuanced case if `type` is a builtin type, and we need to make it - // conform to a trait that some but not all builtin types support (the main problem - // here is when an operation wants an integer type, but one of our operands is a `float`. - // The HLSL rules will allow that, with implicit conversion, but our default join rules - // will end up picking `float` and we don't want that...). + RefPtr<ExpressionType> TryJoinTypeWithInterface( + RefPtr<ExpressionType> type, + InterfaceDeclRef interfaceDeclRef) + { + // The most basic test here should be: does the type declare conformance to the trait. + if(DoesTypeConformToInterface(type, interfaceDeclRef)) + return type; - // For now we don't handle the hard case and just bail - return nullptr; - } + // There is a more nuanced case if `type` is a builtin type, and we need to make it + // conform to a trait that some but not all builtin types support (the main problem + // here is when an operation wants an integer type, but one of our operands is a `float`. + // The HLSL rules will allow that, with implicit conversion, but our default join rules + // will end up picking `float` and we don't want that...). - // Try to compute the "join" between two types - RefPtr<ExpressionType> TryJoinTypes( - RefPtr<ExpressionType> left, - RefPtr<ExpressionType> right) - { - // Easy case: they are the same type! - if (left->Equals(right)) - return left; + // For now we don't handle the hard case and just bail + return nullptr; + } - // We can join two basic types by picking the "better" of the two - if (auto leftBasic = left->As<BasicExpressionType>()) - { - if (auto rightBasic = right->As<BasicExpressionType>()) - { - auto leftFlavor = leftBasic->BaseType; - auto rightFlavor = rightBasic->BaseType; + // Try to compute the "join" between two types + RefPtr<ExpressionType> TryJoinTypes( + RefPtr<ExpressionType> left, + RefPtr<ExpressionType> right) + { + // Easy case: they are the same type! + if (left->Equals(right)) + return left; - // TODO(tfoley): Need a special-case rule here that if - // either operand is of type `half`, then we promote - // to at least `float` + // We can join two basic types by picking the "better" of the two + if (auto leftBasic = left->As<BasicExpressionType>()) + { + if (auto rightBasic = right->As<BasicExpressionType>()) + { + auto leftFlavor = leftBasic->BaseType; + auto rightFlavor = rightBasic->BaseType; - // Return the one that had higher rank... - if (leftFlavor > rightFlavor) - return left; - else - { - assert(rightFlavor > leftFlavor); - return right; - } - } + // TODO(tfoley): Need a special-case rule here that if + // either operand is of type `half`, then we promote + // to at least `float` - // We can also join a vector and a scalar - if(auto rightVector = right->As<VectorExpressionType>()) + // Return the one that had higher rank... + if (leftFlavor > rightFlavor) + return left; + else { - return TryJoinVectorAndScalarType(rightVector, leftBasic); + assert(rightFlavor > leftFlavor); + return right; } } - // We can join two vector types by joining their element types - // (and also their sizes...) - if( auto leftVector = left->As<VectorExpressionType>()) + // We can also join a vector and a scalar + if(auto rightVector = right->As<VectorExpressionType>()) { - if(auto rightVector = right->As<VectorExpressionType>()) - { - // Check if the vector sizes match - if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) - return nullptr; - - // Try to join the element types - auto joinElementType = TryJoinTypes( - leftVector->elementType, - rightVector->elementType); - if(!joinElementType) - return nullptr; - - return createVectorType( - joinElementType, - leftVector->elementCount); - } - - // We can also join a vector and a scalar - if(auto rightBasic = right->As<BasicExpressionType>()) - { - return TryJoinVectorAndScalarType(leftVector, rightBasic); - } + return TryJoinVectorAndScalarType(rightVector, leftBasic); } + } - // HACK: trying to work trait types in here... - if(auto leftDeclRefType = left->As<DeclRefType>()) + // We can join two vector types by joining their element types + // (and also their sizes...) + if( auto leftVector = left->As<VectorExpressionType>()) + { + if(auto rightVector = right->As<VectorExpressionType>()) { - if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDeclRef>() ) - { - // - return TryJoinTypeWithInterface(right, leftInterfaceRef); - } + // Check if the vector sizes match + if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) + return nullptr; + + // Try to join the element types + auto joinElementType = TryJoinTypes( + leftVector->elementType, + rightVector->elementType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + leftVector->elementCount); } - if(auto rightDeclRefType = right->As<DeclRefType>()) + + // We can also join a vector and a scalar + if(auto rightBasic = right->As<BasicExpressionType>()) { - if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDeclRef>() ) - { - // - return TryJoinTypeWithInterface(left, rightInterfaceRef); - } + return TryJoinVectorAndScalarType(leftVector, rightBasic); } - - // TODO: all the cases for vectors apply to matrices too! - - // Default case is that we just fail. - return nullptr; } - // Try to solve a system of generic constraints. - // The `system` argument provides the constraints. - // The `varSubst` argument provides the list of constraint - // variables that were created for the system. - // - // Returns a new substitution representing the values that - // we solved for along the way. - RefPtr<Substitutions> TrySolveConstraintSystem( - ConstraintSystem* system, - GenericDeclRef genericDeclRef) + // HACK: trying to work trait types in here... + if(auto leftDeclRefType = left->As<DeclRefType>()) { - // For now the "solver" is going to be ridiculously simplistic. - - // The generic itself will have some constraints, so we need to try and solve those too - for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() ) + if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDeclRef>() ) { - if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup())) - return nullptr; + // + return TryJoinTypeWithInterface(right, leftInterfaceRef); } + } + if(auto rightDeclRefType = right->As<DeclRefType>()) + { + if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDeclRef>() ) + { + // + return TryJoinTypeWithInterface(left, rightInterfaceRef); + } + } + + // TODO: all the cases for vectors apply to matrices too! + + // Default case is that we just fail. + return nullptr; + } + + // Try to solve a system of generic constraints. + // The `system` argument provides the constraints. + // The `varSubst` argument provides the list of constraint + // variables that were created for the system. + // + // Returns a new substitution representing the values that + // we solved for along the way. + RefPtr<Substitutions> TrySolveConstraintSystem( + ConstraintSystem* system, + GenericDeclRef genericDeclRef) + { + // For now the "solver" is going to be ridiculously simplistic. + + // The generic itself will have some constraints, so we need to try and solve those too + for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() ) + { + if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup())) + return nullptr; + } - // We will loop over the generic parameters, and for - // each we will try to find a way to satisfy all - // the constraints for that parameter - List<RefPtr<Val>> args; - for (auto m : genericDeclRef.GetMembers()) + // We will loop over the generic parameters, and for + // each we will try to find a way to satisfy all + // the constraints for that parameter + List<RefPtr<Val>> args; + for (auto m : genericDeclRef.GetMembers()) + { + if (auto typeParam = m.As<GenericTypeParamDeclRef>()) { - if (auto typeParam = m.As<GenericTypeParamDeclRef>()) + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) { - RefPtr<ExpressionType> type = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != typeParam.GetDecl()) - continue; + if (c.decl != typeParam.GetDecl()) + continue; - auto cType = c.val.As<ExpressionType>(); - assert(cType.Ptr()); + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); - if (!type) - { - type = cType; - } - else + if (!type) + { + type = cType; + } + else + { + auto joinType = TryJoinTypes(type, cType); + if (!joinType) { - auto joinType = TryJoinTypes(type, cType); - if (!joinType) - { - // failure! - return nullptr; - } - type = joinType; + // failure! + return nullptr; } - - c.satisfied = true; + type = joinType; } - if (!type) - { - // failure! - return nullptr; - } - args.Add(type); + c.satisfied = true; } - else if (auto valParam = m.As<GenericValueParamDeclRef>()) + + if (!type) { - // TODO(tfoley): maybe support more than integers some day? - // TODO(tfoley): figure out how this needs to interact with - // compile-time integers that aren't just constants... - RefPtr<IntVal> val = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != valParam.GetDecl()) - continue; + // failure! + return nullptr; + } + args.Add(type); + } + else if (auto valParam = m.As<GenericValueParamDeclRef>()) + { + // TODO(tfoley): maybe support more than integers some day? + // TODO(tfoley): figure out how this needs to interact with + // compile-time integers that aren't just constants... + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valParam.GetDecl()) + continue; - auto cVal = c.val.As<IntVal>(); - assert(cVal.Ptr()); + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); - if (!val) - { - val = cVal; - } - else + if (!val) + { + val = cVal; + } + else + { + if(!val->EqualsVal(cVal.Ptr())) { - if(!val->EqualsVal(cVal.Ptr())) - { - // failure! - return nullptr; - } + // failure! + return nullptr; } - - c.satisfied = true; } - if (!val) - { - // failure! - return nullptr; - } - args.Add(val); + c.satisfied = true; } - else + + if (!val) { - // ignore anything that isn't a generic parameter + // failure! + return nullptr; } + args.Add(val); } + else + { + // ignore anything that isn't a generic parameter + } + } - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) { - if (!c.satisfied) - { - return nullptr; - } + return nullptr; } + } - // Consruct a reference to the extension with our constraint variables - // as the - RefPtr<Substitutions> solvedSubst = new Substitutions(); - solvedSubst->genericDecl = genericDeclRef.GetDecl(); - solvedSubst->outer = genericDeclRef.substitutions; - solvedSubst->args = args; + // Consruct a reference to the extension with our constraint variables + // as the + RefPtr<Substitutions> solvedSubst = new Substitutions(); + solvedSubst->genericDecl = genericDeclRef.GetDecl(); + solvedSubst->outer = genericDeclRef.substitutions; + solvedSubst->args = args; - return solvedSubst; + return solvedSubst; #if 0 - List<RefPtr<Val>> solvedArgs; - for (auto varArg : varSubst->args) + List<RefPtr<Val>> solvedArgs; + for (auto varArg : varSubst->args) + { + if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr())) { - if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr())) + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) { - RefPtr<ExpressionType> type = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != typeVar->declRef.GetDecl()) - continue; + if (c.decl != typeVar->declRef.GetDecl()) + continue; - auto cType = c.val.As<ExpressionType>(); - assert(cType.Ptr()); + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); - if (!type) - { - type = cType; - } - else + if (!type) + { + type = cType; + } + else + { + if (!type->Equals(cType)) { - if (!type->Equals(cType)) - { - // failure! - return nullptr; - } + // failure! + return nullptr; } - - c.satisfied = true; } - if (!type) - { - // failure! - return nullptr; - } - solvedArgs.Add(type); + c.satisfied = true; } - else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr())) + + if (!type) { - // TODO(tfoley): maybe support more than integers some day? - RefPtr<IntVal> val = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != valueVar->declRef.GetDecl()) - continue; + // failure! + return nullptr; + } + solvedArgs.Add(type); + } + else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr())) + { + // TODO(tfoley): maybe support more than integers some day? + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valueVar->declRef.GetDecl()) + continue; - auto cVal = c.val.As<IntVal>(); - assert(cVal.Ptr()); + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); - if (!val) + if (!val) + { + val = cVal; + } + else + { + if (val->value != cVal->value) { - val = cVal; + // failure! + return nullptr; } - else - { - if (val->value != cVal->value) - { - // failure! - return nullptr; - } - } - - c.satisfied = true; } - if (!val) - { - // failure! - return nullptr; - } - solvedArgs.Add(val); + c.satisfied = true; } - else + + if (!val) { - // ignore anything that isn't a generic parameter + // failure! + return nullptr; } + solvedArgs.Add(val); } + else + { + // ignore anything that isn't a generic parameter + } + } - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) { - if (!c.satisfied) - { - return nullptr; - } + return nullptr; } + } - RefPtr<Substitutions> newSubst = new Substitutions(); - newSubst->genericDecl = varSubst->genericDecl; - newSubst->outer = varSubst->outer; - newSubst->args = solvedArgs; - return newSubst; + RefPtr<Substitutions> newSubst = new Substitutions(); + newSubst->genericDecl = varSubst->genericDecl; + newSubst->outer = varSubst->outer; + newSubst->args = solvedArgs; + return newSubst; #endif - } + } - // + // - struct OverloadCandidate + struct OverloadCandidate + { + enum class Flavor { - enum class Flavor - { - Func, - Generic, - UnspecializedGeneric, - }; - Flavor flavor; + Func, + Generic, + UnspecializedGeneric, + }; + Flavor flavor; - enum class Status - { - GenericArgumentInferenceFailed, - Unchecked, - ArityChecked, - FixityChecked, - TypeChecked, - Appicable, - }; - Status status = Status::Unchecked; - - // Reference to the declaration being applied - LookupResultItem item; - - // The type of the result expression if this candidate is selected - RefPtr<ExpressionType> resultType; - - // A system for tracking constraints introduced on generic parameters - ConstraintSystem constraintSystem; - - // How much conversion cost should be considered for this overload, - // when ranking candidates. - ConversionCost conversionCostSum = kConversionCost_None; + enum class Status + { + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + Appicable, }; + Status status = Status::Unchecked; + // Reference to the declaration being applied + LookupResultItem item; + // The type of the result expression if this candidate is selected + RefPtr<ExpressionType> resultType; - // State related to overload resolution for a call - // to an overloaded symbol - struct OverloadResolveContext - { - enum class Mode - { - // We are just checking if a candidate works or not - JustTrying, + // A system for tracking constraints introduced on generic parameters + ConstraintSystem constraintSystem; - // We want to actually update the AST for a chosen candidate - ForReal, - }; + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; + }; - RefPtr<AppExprBase> appExpr; - RefPtr<ExpressionSyntaxNode> baseExpr; - // Are we still trying out candidates, or are we - // checking the chosen one for real? - Mode mode = Mode::JustTrying; - // We store one candidate directly, so that we don't - // need to do dynamic allocation on the list every time - OverloadCandidate bestCandidateStorage; - OverloadCandidate* bestCandidate = nullptr; + // State related to overload resolution for a call + // to an overloaded symbol + struct OverloadResolveContext + { + enum class Mode + { + // We are just checking if a candidate works or not + JustTrying, - // Full list of all candidates being considered, in the ambiguous case - List<OverloadCandidate> bestCandidates; + // We want to actually update the AST for a chosen candidate + ForReal, }; - struct ParamCounts - { - int required; - int allowed; - }; + RefPtr<AppExprBase> appExpr; + RefPtr<ExpressionSyntaxNode> baseExpr; - // count the number of parameters required/allowed for a callable - ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params) + // Are we still trying out candidates, or are we + // checking the chosen one for real? + Mode mode = Mode::JustTrying; + + // We store one candidate directly, so that we don't + // need to do dynamic allocation on the list every time + OverloadCandidate bestCandidateStorage; + OverloadCandidate* bestCandidate = nullptr; + + // Full list of all candidates being considered, in the ambiguous case + List<OverloadCandidate> bestCandidates; + }; + + struct ParamCounts + { + int required; + int allowed; + }; + + // count the number of parameters required/allowed for a callable + ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params) + { + ParamCounts counts = { 0, 0 }; + for (auto param : params) { - ParamCounts counts = { 0, 0 }; - for (auto param : params) - { - counts.allowed++; + counts.allowed++; - // No initializer means no default value - // - // TODO(tfoley): The logic here is currently broken in two ways: - // - // 1. We are assuming that once one parameter has a default, then all do. - // This can/should be validated earlier, so that we can assume it here. - // - // 2. We are not handling the possibility of multiple declarations for - // a single function, where we'd need to merge default parameters across - // all the declarations. - if (!param.GetDecl()->Expr) - { - counts.required++; - } + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. + if (!param.GetDecl()->Expr) + { + counts.required++; } - return counts; } + return counts; + } - // count the number of parameters required/allowed for a generic - ParamCounts CountParameters(GenericDeclRef genericRef) + // count the number of parameters required/allowed for a generic + ParamCounts CountParameters(GenericDeclRef genericRef) + { + ParamCounts counts = { 0, 0 }; + for (auto m : genericRef.GetDecl()->Members) { - ParamCounts counts = { 0, 0 }; - for (auto m : genericRef.GetDecl()->Members) + if (auto typeParam = m.As<GenericTypeParamDecl>()) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) + counts.allowed++; + if (!typeParam->initType.Ptr()) { - counts.allowed++; - if (!typeParam->initType.Ptr()) - { - counts.required++; - } + counts.required++; } - else if (auto valParam = m.As<GenericValueParamDecl>()) + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + counts.allowed++; + if (!valParam->Expr) { - counts.allowed++; - if (!valParam->Expr) - { - counts.required++; - } + counts.required++; } } - return counts; } + return counts; + } - bool TryCheckOverloadCandidateArity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) + bool TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + int argCount = context.appExpr->Arguments.Count(); + ParamCounts paramCounts = { 0, 0 }; + switch (candidate.flavor) { - int argCount = context.appExpr->Arguments.Count(); - ParamCounts paramCounts = { 0, 0 }; - switch (candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters()); - break; + case OverloadCandidate::Flavor::Func: + paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters()); + break; - case OverloadCandidate::Flavor::Generic: - paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>()); - break; + case OverloadCandidate::Flavor::Generic: + paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>()); + break; - default: - assert(!"unexpected"); - break; - } + default: + assert(!"unexpected"); + break; + } - if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) - return true; + if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) + return true; - // Emit an error message if we are checking this call for real - if (context.mode != OverloadResolveContext::Mode::JustTrying) + // Emit an error message if we are checking this call for real + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + if (argCount < paramCounts.required) { - if (argCount < paramCounts.required) - { - getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required); - } - else - { - assert(argCount > paramCounts.allowed); - getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); - } + getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required); + } + else + { + assert(argCount > paramCounts.allowed); + getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); } - - return false; } - bool TryCheckOverloadCandidateFixity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - auto expr = context.appExpr; + return false; + } - auto decl = candidate.item.declRef.decl; + bool TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + auto expr = context.appExpr; - if(auto prefixExpr = expr.As<PrefixExpr>()) - { - if(decl->HasModifier<PrefixModifier>()) - return true; + auto decl = candidate.item.declRef.decl; - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + if(auto prefixExpr = expr.As<PrefixExpr>()) + { + if(decl->HasModifier<PrefixModifier>()) + return true; - return false; - } - else if(auto postfixExpr = expr.As<PostfixExpr>()) + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - if(decl->HasModifier<PostfixModifier>()) - return true; + getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + return false; + } + else if(auto postfixExpr = expr.As<PostfixExpr>()) + { + if(decl->HasModifier<PostfixModifier>()) + return true; - return false; - } - else + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - return true; + getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); } return false; } - - bool TryCheckGenericOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) + else { - auto& args = context.appExpr->Arguments; + return true; + } + + return false; + } + + bool TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; - auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>(); + auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>(); - int aa = 0; - for (auto memberRef : genericDeclRef.GetMembers()) + int aa = 0; + for (auto memberRef : genericDeclRef.GetMembers()) + { + if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>()) { - if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>()) - { - auto arg = args[aa++]; + auto arg = args[aa++]; - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - if (!CanCoerceToProperType(TypeExp(arg))) - { - return false; - } - } - else - { - TypeExp typeExp = CoerceToProperType(TypeExp(arg)); - } - } - else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>()) + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - auto arg = args[aa++]; - - if (context.mode == OverloadResolveContext::Mode::JustTrying) + if (!CanCoerceToProperType(TypeExp(arg))) { - ConversionCost cost = kConversionCost_None; - if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost)) - { - return false; - } - candidate.conversionCostSum += cost; - } - else - { - arg = Coerce(valParamRef.GetType(), arg); - auto val = ExtractGenericArgInteger(arg); + return false; } } else { - continue; + TypeExp typeExp = CoerceToProperType(TypeExp(arg)); } } - - return true; - } - - bool TryCheckOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - auto& args = context.appExpr->Arguments; - int argCount = args.Count(); - - List<ParamDeclRef> params; - switch (candidate.flavor) + else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>()) { - case OverloadCandidate::Flavor::Func: - params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray(); - break; - - case OverloadCandidate::Flavor::Generic: - return TryCheckGenericOverloadCandidateTypes(context, candidate); - - default: - assert(!"unexpected"); - break; - } - - // Note(tfoley): We might have fewer arguments than parameters in the - // case where one or more parameters had defaults. - assert(argCount <= params.Count()); - - for (int ii = 0; ii < argCount; ++ii) - { - auto& arg = args[ii]; - auto param = params[ii]; + auto arg = args[aa++]; if (context.mode == OverloadResolveContext::Mode::JustTrying) { ConversionCost cost = kConversionCost_None; - if (!CanCoerce(param.GetType(), arg->Type, &cost)) + if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost)) { return false; } @@ -3177,531 +3120,532 @@ namespace Slang } else { - arg = Coerce(param.GetType(), arg); + arg = Coerce(valParamRef.GetType(), arg); + auto val = ExtractGenericArgInteger(arg); } } - return true; + else + { + continue; + } } - bool TryCheckOverloadCandidateDirections( - OverloadResolveContext& /*context*/, - OverloadCandidate const& /*candidate*/) - { - // TODO(tfoley): check `in` and `out` markers, as needed. - return true; - } + return true; + } - // Try to check an overload candidate, but bail out - // if any step fails - void TryCheckOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - if (!TryCheckOverloadCandidateArity(context, candidate)) - return; + bool TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; + int argCount = args.Count(); - candidate.status = OverloadCandidate::Status::ArityChecked; - if (!TryCheckOverloadCandidateFixity(context, candidate)) - return; + List<ParamDeclRef> params; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray(); + break; - candidate.status = OverloadCandidate::Status::FixityChecked; - if (!TryCheckOverloadCandidateTypes(context, candidate)) - return; + case OverloadCandidate::Flavor::Generic: + return TryCheckGenericOverloadCandidateTypes(context, candidate); - candidate.status = OverloadCandidate::Status::TypeChecked; - if (!TryCheckOverloadCandidateDirections(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::Appicable; + default: + assert(!"unexpected"); + break; } - // Create the representation of a given generic applied to some arguments - RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef( - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<AppExprBase> appExpr) + // Note(tfoley): We might have fewer arguments than parameters in the + // case where one or more parameters had defaults. + assert(argCount <= params.Count()); + + for (int ii = 0; ii < argCount; ++ii) { - auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); - if (!baseDeclRefExpr) + auto& arg = args[ii]; + auto param = params[ii]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - assert(!"unexpected"); - return CreateErrorExpr(appExpr.Ptr()); + ConversionCost cost = kConversionCost_None; + if (!CanCoerce(param.GetType(), arg->Type, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; } - auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>(); - if (!baseGenericRef) + else { - assert(!"unexpected"); - return CreateErrorExpr(appExpr.Ptr()); + arg = Coerce(param.GetType(), arg); } + } + return true; + } - RefPtr<Substitutions> subst = new Substitutions(); - subst->genericDecl = baseGenericRef.GetDecl(); - subst->outer = baseGenericRef.substitutions; + bool TryCheckOverloadCandidateDirections( + OverloadResolveContext& /*context*/, + OverloadCandidate const& /*candidate*/) + { + // TODO(tfoley): check `in` and `out` markers, as needed. + return true; + } - for (auto arg : appExpr->Arguments) - { - subst->args.Add(ExtractGenericArgVal(arg)); - } + // Try to check an overload candidate, but bail out + // if any step fails + void TryCheckOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + if (!TryCheckOverloadCandidateArity(context, candidate)) + return; - DeclRef innerDeclRef(baseGenericRef.GetInner(), subst); + candidate.status = OverloadCandidate::Status::ArityChecked; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + return; - return ConstructDeclRefExpr( - innerDeclRef, - nullptr, - appExpr); + candidate.status = OverloadCandidate::Status::FixityChecked; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::TypeChecked; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::Appicable; + } + + // Create the representation of a given generic applied to some arguments + RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef( + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<AppExprBase> appExpr) + { + auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); + if (!baseDeclRefExpr) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); + } + auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>(); + if (!baseGenericRef) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); } - // Take an overload candidate that previously got through - // `TryCheckOverloadCandidate` above, and try to finish - // up the work and turn it into a real expression. - // - // If the candidate isn't actually applicable, this is - // where we'd start reporting the issue(s). - RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // special case for generic argument inference failure - if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) - { - String callString = GetCallSignatureString(context.appExpr); - getSink()->diagnose( - context.appExpr, - Diagnostics::genericArgumentInferenceFailed, - callString); + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = baseGenericRef.GetDecl(); + subst->outer = baseGenericRef.substitutions; - String declString = getDeclSignatureString(candidate.item); - getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); - goto error; - } + for (auto arg : appExpr->Arguments) + { + subst->args.Add(ExtractGenericArgVal(arg)); + } - context.mode = OverloadResolveContext::Mode::ForReal; - context.appExpr->Type = ExpressionType::Error; + DeclRef innerDeclRef(baseGenericRef.GetInner(), subst); - if (!TryCheckOverloadCandidateArity(context, candidate)) - goto error; + return ConstructDeclRefExpr( + innerDeclRef, + nullptr, + appExpr); + } - if (!TryCheckOverloadCandidateFixity(context, candidate)) - goto error; + // Take an overload candidate that previously got through + // `TryCheckOverloadCandidate` above, and try to finish + // up the work and turn it into a real expression. + // + // If the candidate isn't actually applicable, this is + // where we'd start reporting the issue(s). + RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // special case for generic argument inference failure + if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) + { + String callString = GetCallSignatureString(context.appExpr); + getSink()->diagnose( + context.appExpr, + Diagnostics::genericArgumentInferenceFailed, + callString); - if (!TryCheckOverloadCandidateTypes(context, candidate)) - goto error; + String declString = getDeclSignatureString(candidate.item); + getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); + goto error; + } - if (!TryCheckOverloadCandidateDirections(context, candidate)) - goto error; + context.mode = OverloadResolveContext::Mode::ForReal; + context.appExpr->Type = ExpressionType::Error; + if (!TryCheckOverloadCandidateArity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateFixity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateTypes(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateDirections(context, candidate)) + goto error; + + { + auto baseExpr = ConstructLookupResultExpr( + candidate.item, context.baseExpr, context.appExpr->FunctionExpr); + + switch(candidate.flavor) { - auto baseExpr = ConstructLookupResultExpr( - candidate.item, context.baseExpr, context.appExpr->FunctionExpr); + case OverloadCandidate::Flavor::Func: + context.appExpr->FunctionExpr = baseExpr; + context.appExpr->Type = candidate.resultType; - switch(candidate.flavor) + // A call may yield an l-value, and we should take a look at the candidate to be sure + if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>()) { - case OverloadCandidate::Flavor::Func: - context.appExpr->FunctionExpr = baseExpr; - context.appExpr->Type = candidate.resultType; - - // A call may yield an l-value, and we should take a look at the candidate to be sure - if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>()) + for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>()) { - for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>()) - { - context.appExpr->Type.IsLeftValue = true; - } + context.appExpr->Type.IsLeftValue = true; } + } - // TODO: there may be other cases that confer l-value-ness + // TODO: there may be other cases that confer l-value-ness - return context.appExpr; - break; + return context.appExpr; + break; - case OverloadCandidate::Flavor::Generic: - return CreateGenericDeclRef(baseExpr, context.appExpr); - break; + case OverloadCandidate::Flavor::Generic: + return CreateGenericDeclRef(baseExpr, context.appExpr); + break; - default: - assert(!"unexpected"); - break; - } + default: + assert(!"unexpected"); + break; } - - - error: - return CreateErrorExpr(context.appExpr.Ptr()); } - // Implement a comparison operation between overload candidates, - // so that the better candidate compares as less-than the other - int CompareOverloadCandidates( - OverloadCandidate* left, - OverloadCandidate* right) - { - // If one candidate got further along in validation, pick it - if (left->status != right->status) - return int(right->status) - int(left->status); - // If both candidates are applicable, then we need to compare - // the costs of their type conversion sequences - if(left->status == OverloadCandidate::Status::Appicable) - { - if (left->conversionCostSum != right->conversionCostSum) - return left->conversionCostSum - right->conversionCostSum; - } + error: + return CreateErrorExpr(context.appExpr.Ptr()); + } - return 0; - } + // Implement a comparison operation between overload candidates, + // so that the better candidate compares as less-than the other + int CompareOverloadCandidates( + OverloadCandidate* left, + OverloadCandidate* right) + { + // If one candidate got further along in validation, pick it + if (left->status != right->status) + return int(right->status) - int(left->status); - void AddOverloadCandidateInner( - OverloadResolveContext& context, - OverloadCandidate& candidate) + // If both candidates are applicable, then we need to compare + // the costs of their type conversion sequences + if(left->status == OverloadCandidate::Status::Appicable) { - // Filter our existing candidates, to remove any that are worse than our new one + if (left->conversionCostSum != right->conversionCostSum) + return left->conversionCostSum - right->conversionCostSum; + } - bool keepThisCandidate = true; // should this candidate be kept? + return 0; + } - if (context.bestCandidates.Count() != 0) - { - // We have multiple candidates right now, so filter them. - bool anyFiltered = false; - // Note that we are querying the list length on every iteration, - // because we might remove things. - for (int cc = 0; cc < context.bestCandidates.Count(); ++cc) - { - int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); - if (cmp < 0) - { - // our new candidate is better! + void AddOverloadCandidateInner( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Filter our existing candidates, to remove any that are worse than our new one - // remove it from the list (by swapping in a later one) - context.bestCandidates.FastRemoveAt(cc); - // and then reduce our index so that we re-visit the same index - --cc; + bool keepThisCandidate = true; // should this candidate be kept? - anyFiltered = true; - } - else if(cmp > 0) - { - // our candidate is worse! - keepThisCandidate = false; - } - } - // It should not be possible that we removed some existing candidate *and* - // chose not to keep this candidate (otherwise the better-ness relation - // isn't transitive). Therefore we confirm that we either chose to keep - // this candidate (in which case filtering is okay), or we didn't filter - // anything. - assert(keepThisCandidate || !anyFiltered); - } - else if(context.bestCandidate) + if (context.bestCandidates.Count() != 0) + { + // We have multiple candidates right now, so filter them. + bool anyFiltered = false; + // Note that we are querying the list length on every iteration, + // because we might remove things. + for (int cc = 0; cc < context.bestCandidates.Count(); ++cc) { - // There's only one candidate so far - int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); - if(cmp < 0) + int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); + if (cmp < 0) { // our new candidate is better! - context.bestCandidate = nullptr; + + // remove it from the list (by swapping in a later one) + context.bestCandidates.FastRemoveAt(cc); + // and then reduce our index so that we re-visit the same index + --cc; + + anyFiltered = true; } - else if (cmp > 0) + else if(cmp > 0) { // our candidate is worse! keepThisCandidate = false; } } - - // If our candidate isn't good enough, then drop it - if (!keepThisCandidate) - return; - - // Otherwise we want to keep the candidate - if (context.bestCandidates.Count() > 0) - { - // There were already multiple candidates, and we are adding one more - context.bestCandidates.Add(candidate); - } - else if (context.bestCandidate) + // It should not be possible that we removed some existing candidate *and* + // chose not to keep this candidate (otherwise the better-ness relation + // isn't transitive). Therefore we confirm that we either chose to keep + // this candidate (in which case filtering is okay), or we didn't filter + // anything. + assert(keepThisCandidate || !anyFiltered); + } + else if(context.bestCandidate) + { + // There's only one candidate so far + int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); + if(cmp < 0) { - // There was a unique best candidate, but now we are ambiguous - context.bestCandidates.Add(*context.bestCandidate); - context.bestCandidates.Add(candidate); + // our new candidate is better! context.bestCandidate = nullptr; } - else + else if (cmp > 0) { - // This is the only candidate worthe keeping track of right now - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; + // our candidate is worse! + keepThisCandidate = false; } } - void AddOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // Try the candidate out, to see if it is applicable at all. - TryCheckOverloadCandidate(context, candidate); + // If our candidate isn't good enough, then drop it + if (!keepThisCandidate) + return; - // Now (potentially) add it to the set of candidate overloads to consider. - AddOverloadCandidateInner(context, candidate); + // Otherwise we want to keep the candidate + if (context.bestCandidates.Count() > 0) + { + // There were already multiple candidates, and we are adding one more + context.bestCandidates.Add(candidate); } - - void AddFuncOverloadCandidate( - LookupResultItem item, - CallableDeclRef funcDeclRef, - OverloadResolveContext& context) + else if (context.bestCandidate) { - EnsureDecl(funcDeclRef.GetDecl()); - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = item; - candidate.resultType = funcDeclRef.GetResultType(); - - AddOverloadCandidate(context, candidate); + // There was a unique best candidate, but now we are ambiguous + context.bestCandidates.Add(*context.bestCandidate); + context.bestCandidates.Add(candidate); + context.bestCandidate = nullptr; } - - void AddFuncOverloadCandidate( - RefPtr<FuncType> /*funcType*/, - OverloadResolveContext& /*context*/) + else { -#if 0 - if (funcType->decl) - { - AddFuncOverloadCandidate(funcType->decl, context); - } - else if (funcType->Func) - { - AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); - } - else if (funcType->Component) - { - AddComponentFuncOverloadCandidate(funcType->Component, context); - } -#else - throw "unimplemented"; -#endif + // This is the only candidate worthe keeping track of right now + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; } + } - void AddCtorOverloadCandidate( - LookupResultItem typeItem, - RefPtr<ExpressionType> type, - ConstructorDeclRef ctorDeclRef, - OverloadResolveContext& context) - { - EnsureDecl(ctorDeclRef.GetDecl()); + void AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Try the candidate out, to see if it is applicable at all. + TryCheckOverloadCandidate(context, candidate); - // `typeItem` refers to the type being constructed (the thing - // that was applied as a function) so we need to construct - // a `LookupResultItem` that refers to the constructor instead + // Now (potentially) add it to the set of candidate overloads to consider. + AddOverloadCandidateInner(context, candidate); + } - LookupResultItem ctorItem; - ctorItem.declRef = ctorDeclRef; - ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs); + void AddFuncOverloadCandidate( + LookupResultItem item, + CallableDeclRef funcDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(funcDeclRef.GetDecl()); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = ctorItem; - candidate.resultType = type; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = item; + candidate.resultType = funcDeclRef.GetResultType(); - AddOverloadCandidate(context, candidate); - } + AddOverloadCandidate(context, candidate); + } - // If the given declaration has generic parameters, then - // return the corresponding `GenericDecl` that holds the - // parameters, etc. - GenericDecl* GetOuterGeneric(Decl* decl) + void AddFuncOverloadCandidate( + RefPtr<FuncType> /*funcType*/, + OverloadResolveContext& /*context*/) + { +#if 0 + if (funcType->decl) { - auto parentDecl = decl->ParentDecl; - if (!parentDecl) return nullptr; - auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl); - return parentGeneric; + AddFuncOverloadCandidate(funcType->decl, context); } - - // Try to find a unification for two values - bool TryUnifyVals( - ConstraintSystem& constraints, - RefPtr<Val> fst, - RefPtr<Val> snd) + else if (funcType->Func) { - // if both values are types, then unify types - if (auto fstType = fst.As<ExpressionType>()) - { - if (auto sndType = snd.As<ExpressionType>()) - { - return TryUnifyTypes(constraints, fstType, sndType); - } - } + AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); + } + else if (funcType->Component) + { + AddComponentFuncOverloadCandidate(funcType->Component, context); + } +#else + throw "unimplemented"; +#endif + } - // if both values are constant integers, then compare them - if (auto fstIntVal = fst.As<ConstantIntVal>()) - { - if (auto sndIntVal = snd.As<ConstantIntVal>()) - { - return fstIntVal->value == sndIntVal->value; - } - } + void AddCtorOverloadCandidate( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + ConstructorDeclRef ctorDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(ctorDeclRef.GetDecl()); - // Check if both are integer values in general - if (auto fstInt = fst.As<IntVal>()) - { - if (auto sndInt = snd.As<IntVal>()) - { - auto fstParam = fstInt.As<GenericParamIntVal>(); - auto sndParam = sndInt.As<GenericParamIntVal>(); + // `typeItem` refers to the type being constructed (the thing + // that was applied as a function) so we need to construct + // a `LookupResultItem` that refers to the constructor instead - if (fstParam) - TryUnifyIntParam(constraints, fstParam->declRef, sndInt); - if (sndParam) - TryUnifyIntParam(constraints, sndParam->declRef, fstInt); + LookupResultItem ctorItem; + ctorItem.declRef = ctorDeclRef; + ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs); - if (fstParam || sndParam) - return true; - } - } + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = ctorItem; + candidate.resultType = type; - throw "unimplemented"; + AddOverloadCandidate(context, candidate); + } - // default: fail - return false; - } + // If the given declaration has generic parameters, then + // return the corresponding `GenericDecl` that holds the + // parameters, etc. + GenericDecl* GetOuterGeneric(Decl* decl) + { + auto parentDecl = decl->ParentDecl; + if (!parentDecl) return nullptr; + auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl); + return parentGeneric; + } - bool TryUnifySubstitutions( - ConstraintSystem& constraints, - RefPtr<Substitutions> fst, - RefPtr<Substitutions> snd) + // Try to find a unification for two values + bool TryUnifyVals( + ConstraintSystem& constraints, + RefPtr<Val> fst, + RefPtr<Val> snd) + { + // if both values are types, then unify types + if (auto fstType = fst.As<ExpressionType>()) { - // They must both be NULL or non-NULL - if (!fst || !snd) - return fst == snd; - - // They must be specializing the same generic - if (fst->genericDecl != snd->genericDecl) - return false; - - // Their arguments must unify - assert(fst->args.Count() == snd->args.Count()); - int argCount = fst->args.Count(); - for (int aa = 0; aa < argCount; ++aa) + if (auto sndType = snd.As<ExpressionType>()) { - if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) - return false; + return TryUnifyTypes(constraints, fstType, sndType); } + } - // Their "base" specializations must unify - if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) - return false; - - return true; + // if both values are constant integers, then compare them + if (auto fstIntVal = fst.As<ConstantIntVal>()) + { + if (auto sndIntVal = snd.As<ConstantIntVal>()) + { + return fstIntVal->value == sndIntVal->value; + } } - bool TryUnifyTypeParam( - ConstraintSystem& constraints, - RefPtr<GenericTypeParamDecl> typeParamDecl, - RefPtr<ExpressionType> type) + // Check if both are integer values in general + if (auto fstInt = fst.As<IntVal>()) { - // We want to constrain the given type parameter - // to equal the given type. - Constraint constraint; - constraint.decl = typeParamDecl.Ptr(); - constraint.val = type; + if (auto sndInt = snd.As<IntVal>()) + { + auto fstParam = fstInt.As<GenericParamIntVal>(); + auto sndParam = sndInt.As<GenericParamIntVal>(); - constraints.constraints.Add(constraint); + if (fstParam) + TryUnifyIntParam(constraints, fstParam->declRef, sndInt); + if (sndParam) + TryUnifyIntParam(constraints, sndParam->declRef, fstInt); - return true; + if (fstParam || sndParam) + return true; + } } - bool TryUnifyIntParam( - ConstraintSystem& constraints, - RefPtr<GenericValueParamDecl> paramDecl, - RefPtr<IntVal> val) - { - // We want to constrain the given parameter to equal the given value. - Constraint constraint; - constraint.decl = paramDecl.Ptr(); - constraint.val = val; + throw "unimplemented"; - constraints.constraints.Add(constraint); + // default: fail + return false; + } - return true; - } + bool TryUnifySubstitutions( + ConstraintSystem& constraints, + RefPtr<Substitutions> fst, + RefPtr<Substitutions> snd) + { + // They must both be NULL or non-NULL + if (!fst || !snd) + return fst == snd; + + // They must be specializing the same generic + if (fst->genericDecl != snd->genericDecl) + return false; - bool TryUnifyIntParam( - ConstraintSystem& constraints, - VarDeclBaseRef const& varRef, - RefPtr<IntVal> val) + // Their arguments must unify + assert(fst->args.Count() == snd->args.Count()); + int argCount = fst->args.Count(); + for (int aa = 0; aa < argCount; ++aa) { - if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>()) - { - return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val); - } - else - { + if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) return false; - } } - bool TryUnifyTypesByStructuralMatch( - ConstraintSystem& constraints, - RefPtr<ExpressionType> fst, - RefPtr<ExpressionType> snd) - { - if (auto fstDeclRefType = fst->As<DeclRefType>()) - { - auto fstDeclRef = fstDeclRefType->declRef; + // Their "base" specializations must unify + if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) + return false; + + return true; + } - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); + bool TryUnifyTypeParam( + ConstraintSystem& constraints, + RefPtr<GenericTypeParamDecl> typeParamDecl, + RefPtr<ExpressionType> type) + { + // We want to constrain the given type parameter + // to equal the given type. + Constraint constraint; + constraint.decl = typeParamDecl.Ptr(); + constraint.val = type; - if (auto sndDeclRefType = snd->As<DeclRefType>()) - { - auto sndDeclRef = sndDeclRefType->declRef; + constraints.constraints.Add(constraint); - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); + return true; + } - // can't be unified if they refer to differnt declarations. - if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false; + bool TryUnifyIntParam( + ConstraintSystem& constraints, + RefPtr<GenericValueParamDecl> paramDecl, + RefPtr<IntVal> val) + { + // We want to constrain the given parameter to equal the given value. + Constraint constraint; + constraint.decl = paramDecl.Ptr(); + constraint.val = val; - // next we need to unify the substitutions applied - // to each decalration reference. - if (!TryUnifySubstitutions( - constraints, - fstDeclRef.substitutions, - sndDeclRef.substitutions)) - { - return false; - } + constraints.constraints.Add(constraint); - return true; - } - } + return true; + } + bool TryUnifyIntParam( + ConstraintSystem& constraints, + VarDeclBaseRef const& varRef, + RefPtr<IntVal> val) + { + if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>()) + { + return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val); + } + else + { return false; } + } - bool TryUnifyTypes( - ConstraintSystem& constraints, - RefPtr<ExpressionType> fst, - RefPtr<ExpressionType> snd) + bool TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (auto fstDeclRefType = fst->As<DeclRefType>()) { - if (fst->Equals(snd)) return true; - - // An error type can unify with anything, just so we avoid cascading errors. + auto fstDeclRef = fstDeclRefType->declRef; - if (auto fstErrorType = fst->As<ErrorType>()) - return true; - - if (auto sndErrorType = snd->As<ErrorType>()) - return true; - - // A generic parameter type can unify with anything. - // TODO: there actually needs to be some kind of "occurs check" sort - // of thing here... - - if (auto fstDeclRefType = fst->As<DeclRefType>()) - { - auto fstDeclRef = fstDeclRefType->declRef; - - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); - } + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); if (auto sndDeclRefType = snd->As<DeclRefType>()) { @@ -3709,1336 +3653,1389 @@ namespace Slang if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, fst); - } - - // If we can unify the types structurally, then we are golden - if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) - return true; - // Now we need to consider cases where coercion might - // need to be applied. For now we can try to do this - // in a completely ad hoc fashion, but eventually we'd - // want to do it more formally. + // can't be unified if they refer to differnt declarations. + if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false; - if(auto fstVectorType = fst->As<VectorExpressionType>()) - { - if(auto sndScalarType = snd->As<BasicExpressionType>()) + // next we need to unify the substitutions applied + // to each decalration reference. + if (!TryUnifySubstitutions( + constraints, + fstDeclRef.substitutions, + sndDeclRef.substitutions)) { - return TryUnifyTypes( - constraints, - fstVectorType->elementType, - sndScalarType); + return false; } - } - if(auto fstScalarType = fst->As<BasicExpressionType>()) - { - if(auto sndVectorType = snd->As<VectorExpressionType>()) - { - return TryUnifyTypes( - constraints, - fstScalarType, - sndVectorType->elementType); - } + return true; } + } - // TODO: the same thing for vectors... + return false; + } - return false; - } + bool TryUnifyTypes( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (fst->Equals(snd)) return true; + + // An error type can unify with anything, just so we avoid cascading errors. - // Is the candidate extension declaration actually applicable to the given type - ExtensionDeclRef ApplyExtensionToType( - ExtensionDecl* extDecl, - RefPtr<ExpressionType> type) + if (auto fstErrorType = fst->As<ErrorType>()) + return true; + + if (auto sndErrorType = snd->As<ErrorType>()) + return true; + + // A generic parameter type can unify with anything. + // TODO: there actually needs to be some kind of "occurs check" sort + // of thing here... + + if (auto fstDeclRefType = fst->As<DeclRefType>()) { - if (auto extGenericDecl = GetOuterGeneric(extDecl)) - { - ConstraintSystem constraints; + auto fstDeclRef = fstDeclRefType->declRef; - if (!TryUnifyTypes(constraints, extDecl->targetType, type)) - return DeclRef().As<ExtensionDeclRef>(); + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + } - auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>()); - if (!constraintSubst) - { - return DeclRef().As<ExtensionDeclRef>(); - } + if (auto sndDeclRefType = snd->As<DeclRefType>()) + { + auto sndDeclRef = sndDeclRefType->declRef; - // Consruct a reference to the extension with our constraint variables - // set as they were found by solving the constraint system. - ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>(); + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + } - // We expect/require that the result of unification is such that - // the target types are now equal - assert(extDeclRef.GetTargetType()->Equals(type)); + // If we can unify the types structurally, then we are golden + if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) + return true; - return extDeclRef; - } - else + // Now we need to consider cases where coercion might + // need to be applied. For now we can try to do this + // in a completely ad hoc fashion, but eventually we'd + // want to do it more formally. + + if(auto fstVectorType = fst->As<VectorExpressionType>()) + { + if(auto sndScalarType = snd->As<BasicExpressionType>()) { - // The easy case is when the extension isn't generic: - // either it applies to the type or not. - if (!type->Equals(extDecl->targetType)) - return DeclRef().As<ExtensionDeclRef>(); - return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>(); + return TryUnifyTypes( + constraints, + fstVectorType->elementType, + sndScalarType); } } - bool TryUnifyArgAndParamTypes( - ConstraintSystem& system, - RefPtr<ExpressionSyntaxNode> argExpr, - ParamDeclRef paramDeclRef) + if(auto fstScalarType = fst->As<BasicExpressionType>()) { - // TODO(tfoley): potentially need a bit more - // nuance in case where argument might be - // an overload group... - return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType()); + if(auto sndVectorType = snd->As<VectorExpressionType>()) + { + return TryUnifyTypes( + constraints, + fstScalarType, + sndVectorType->elementType); + } } - // Take a generic declaration and try to specialize its parameters - // so that the resulting inner declaration can be applicable in - // a particular context... - DeclRef SpecializeGenericForOverload( - GenericDeclRef genericDeclRef, - OverloadResolveContext& context) + // TODO: the same thing for vectors... + + return false; + } + + // Is the candidate extension declaration actually applicable to the given type + ExtensionDeclRef ApplyExtensionToType( + ExtensionDecl* extDecl, + RefPtr<ExpressionType> type) + { + if (auto extGenericDecl = GetOuterGeneric(extDecl)) { ConstraintSystem constraints; - // Construct a reference to the inner declaration that has any generic - // parameter substitutions in place already, but *not* any substutions - // for the generic declaration we are currently trying to infer. - auto innerDecl = genericDeclRef.GetInner(); - DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + if (!TryUnifyTypes(constraints, extDecl->targetType, type)) + return DeclRef().As<ExtensionDeclRef>(); - // Check what type of declaration we are dealing with, and then try - // to match it up with the arguments accordingly... - if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>()) + auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>()); + if (!constraintSubst) { - auto& args = context.appExpr->Arguments; - auto params = funcDeclRef.GetParameters().ToArray(); + return DeclRef().As<ExtensionDeclRef>(); + } - int argCount = args.Count(); - int paramCount = params.Count(); + // Consruct a reference to the extension with our constraint variables + // set as they were found by solving the constraint system. + ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>(); - // Bail out on mismatch. - // TODO(tfoley): need more nuance here - if (argCount != paramCount) - { - return DeclRef(nullptr, nullptr); - } + // We expect/require that the result of unification is such that + // the target types are now equal + assert(extDeclRef.GetTargetType()->Equals(type)); - for (int aa = 0; aa < argCount; ++aa) - { -#if 0 - if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) - return DeclRef(nullptr, nullptr); -#else - // The question here is whether failure to "unify" an argument - // and parameter should lead to immediate failure. - // - // The case that is interesting is if we want to unify, say: - // `vector<float,N>` and `vector<int,3>` - // - // It is clear that we should solve with `N = 3`, and then - // a later step may find that the resulting types aren't - // actually a match. - // - // A more refined approach to "unification" could of course - // see that `int` can convert to `float` and use that fact. - // (and indeed we already use something like this to unify - // `float` and `vector<T,3>`) - // - // So the question is then whether a mismatch during the - // unification step should be taken as an immediate failure... + return extDeclRef; + } + else + { + // The easy case is when the extension isn't generic: + // either it applies to the type or not. + if (!type->Equals(extDecl->targetType)) + return DeclRef().As<ExtensionDeclRef>(); + return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>(); + } + } - TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]); -#endif - } - } - else + bool TryUnifyArgAndParamTypes( + ConstraintSystem& system, + RefPtr<ExpressionSyntaxNode> argExpr, + ParamDeclRef paramDeclRef) + { + // TODO(tfoley): potentially need a bit more + // nuance in case where argument might be + // an overload group... + return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType()); + } + + // Take a generic declaration and try to specialize its parameters + // so that the resulting inner declaration can be applicable in + // a particular context... + DeclRef SpecializeGenericForOverload( + GenericDeclRef genericDeclRef, + OverloadResolveContext& context) + { + ConstraintSystem constraints; + + // Construct a reference to the inner declaration that has any generic + // parameter substitutions in place already, but *not* any substutions + // for the generic declaration we are currently trying to infer. + auto innerDecl = genericDeclRef.GetInner(); + DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + + // Check what type of declaration we are dealing with, and then try + // to match it up with the arguments accordingly... + if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>()) + { + auto& args = context.appExpr->Arguments; + auto params = funcDeclRef.GetParameters().ToArray(); + + int argCount = args.Count(); + int paramCount = params.Count(); + + // Bail out on mismatch. + // TODO(tfoley): need more nuance here + if (argCount != paramCount) { - // TODO(tfoley): any other cases needed here? return DeclRef(nullptr, nullptr); } - auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); - if (!constraintSubst) + for (int aa = 0; aa < argCount; ++aa) { - // constraint solving failed - return DeclRef(nullptr, nullptr); +#if 0 + if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) + return DeclRef(nullptr, nullptr); +#else + // The question here is whether failure to "unify" an argument + // and parameter should lead to immediate failure. + // + // The case that is interesting is if we want to unify, say: + // `vector<float,N>` and `vector<int,3>` + // + // It is clear that we should solve with `N = 3`, and then + // a later step may find that the resulting types aren't + // actually a match. + // + // A more refined approach to "unification" could of course + // see that `int` can convert to `float` and use that fact. + // (and indeed we already use something like this to unify + // `float` and `vector<T,3>`) + // + // So the question is then whether a mismatch during the + // unification step should be taken as an immediate failure... + + TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]); +#endif } + } + else + { + // TODO(tfoley): any other cases needed here? + return DeclRef(nullptr, nullptr); + } + + auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); + if (!constraintSubst) + { + // constraint solving failed + return DeclRef(nullptr, nullptr); + } - // We can now construct a reference to the inner declaration using - // the solution to our constraints. - return DeclRef(innerDecl, constraintSubst); + // We can now construct a reference to the inner declaration using + // the solution to our constraints. + return DeclRef(innerDecl, constraintSubst); + } + + void AddAggTypeOverloadCandidates( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + AggTypeDeclRef aggTypeDeclRef, + OverloadResolveContext& context) + { + for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>()) + { + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); } - void AddAggTypeOverloadCandidates( - LookupResultItem typeItem, - RefPtr<ExpressionType> type, - AggTypeDeclRef aggTypeDeclRef, - OverloadResolveContext& context) + // Now walk through any extensions we can find for this types + for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension) { - for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>()) + auto extDeclRef = ApplyExtensionToType(ext, type); + if (!extDeclRef) + continue; + + for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>()) { + // TODO(tfoley): `typeItem` here should really reference the extension... + // now work through this candidate... AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); } - // Now walk through any extensions we can find for this types - for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension) + // Also check for generic constructors + for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>()) { - auto extDeclRef = ApplyExtensionToType(ext, type); - if (!extDeclRef) - continue; - - for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>()) + if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>()) { - // TODO(tfoley): `typeItem` here should really reference the extension... - - // now work through this candidate... - AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); - } - - // Also check for generic constructors - for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>()) - { - if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>()) - { - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (!innerRef) - continue; + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + if (!innerRef) + continue; - ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>(); + ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>(); - AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context); + AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context); - // TODO(tfoley): need a way to do the solving step for the constraint system - } + // TODO(tfoley): need a way to do the solving step for the constraint system } } } + } - void AddTypeOverloadCandidates( - RefPtr<ExpressionType> type, - OverloadResolveContext& context) + void AddTypeOverloadCandidates( + RefPtr<ExpressionType> type, + OverloadResolveContext& context) + { + if (auto declRefType = type->As<DeclRefType>()) { - if (auto declRefType = type->As<DeclRefType>()) + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) - { - AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context); - } + AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context); } } + } + + void AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context) + { + auto declRef = item.declRef; - void AddDeclRefOverloadCandidates( - LookupResultItem item, - OverloadResolveContext& context) + if (auto funcDeclRef = item.declRef.As<CallableDeclRef>()) { - auto declRef = item.declRef; + AddFuncOverloadCandidate(item, funcDeclRef, context); + } + else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context); + } + else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>()) + { + // Try to infer generic arguments, based on the context + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (auto funcDeclRef = item.declRef.As<CallableDeclRef>()) - { - AddFuncOverloadCandidate(item, funcDeclRef, context); - } - else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>()) - { - auto type = DeclRefType::Create(aggTypeDeclRef); - AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context); - } - else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>()) + if (innerRef) { - // Try to infer generic arguments, based on the context - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + // If inference works, then we've now got a + // specialized declaration reference we can apply. - if (innerRef) - { - // If inference works, then we've now got a - // specialized declaration reference we can apply. - - LookupResultItem innerItem; - innerItem.breadcrumbs = item.breadcrumbs; - innerItem.declRef = innerRef; + LookupResultItem innerItem; + innerItem.breadcrumbs = item.breadcrumbs; + innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context); - } - else - { - // If inference failed, then we need to create - // a candidate that can be used to reflect that fact - // (so we can report a good error) - OverloadCandidate candidate; - candidate.item = item; - candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; - candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; - - AddOverloadCandidateInner(context, candidate); - } - } - else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() ) - { - AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context); + AddDeclRefOverloadCandidates(innerItem, context); } else { - // TODO(tfoley): any other cases needed here? + // If inference failed, then we need to create + // a candidate that can be used to reflect that fact + // (so we can report a good error) + OverloadCandidate candidate; + candidate.item = item; + candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; + candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; + + AddOverloadCandidateInner(context, candidate); } } - - void AddOverloadCandidates( - RefPtr<ExpressionSyntaxNode> funcExpr, - OverloadResolveContext& context) + else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() ) + { + AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context); + } + else { - auto funcExprType = funcExpr->Type; + // TODO(tfoley): any other cases needed here? + } + } - if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) - { - // The expression referenced a function declaration - AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context); - } - else if (auto funcType = funcExprType->As<FuncType>()) - { - // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); - } - else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) - { - auto lookupResult = overloadedExpr->lookupResult2; - assert(lookupResult.isOverloaded()); - for(auto item : lookupResult.items) - { - AddDeclRefOverloadCandidates(item, context); - } - } - else if (auto typeType = funcExprType->As<TypeType>()) + void AddOverloadCandidates( + RefPtr<ExpressionSyntaxNode> funcExpr, + OverloadResolveContext& context) + { + auto funcExprType = funcExpr->Type; + + if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) + { + // The expression referenced a function declaration + AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context); + } + else if (auto funcType = funcExprType->As<FuncType>()) + { + // TODO(tfoley): deprecate this path... + AddFuncOverloadCandidate(funcType, context); + } + else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isOverloaded()); + for(auto item : lookupResult.items) { - // If none of the above cases matched, but we are - // looking at a type, then I suppose we have - // a constructor call on our hands. - // - // TODO(tfoley): are there any meaningful types left - // that aren't declaration references? - AddTypeOverloadCandidates(typeType->type, context); - return; + AddDeclRefOverloadCandidates(item, context); } } + else if (auto typeType = funcExprType->As<TypeType>()) + { + // If none of the above cases matched, but we are + // looking at a type, then I suppose we have + // a constructor call on our hands. + // + // TODO(tfoley): are there any meaningful types left + // that aren't declaration references? + AddTypeOverloadCandidates(typeType->type, context); + return; + } + } + + void formatType(StringBuilder& sb, RefPtr<ExpressionType> type) + { + sb << type->ToString(); + } + + void formatVal(StringBuilder& sb, RefPtr<Val> val) + { + sb << val->ToString(); + } - void formatType(StringBuilder& sb, RefPtr<ExpressionType> type) + void formatDeclPath(StringBuilder& sb, DeclRef declRef) + { + // Find the parent declaration + auto parentDeclRef = declRef.GetParent(); + + // If the immediate parent is a generic, then we probably + // want the declaration above that... + auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>(); + if(parentGenericDeclRef) { - sb << type->ToString(); + parentDeclRef = parentGenericDeclRef.GetParent(); } - void formatVal(StringBuilder& sb, RefPtr<Val> val) + // Depending on what the parent is, we may want to format things specially + if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>()) { - sb << val->ToString(); + formatDeclPath(sb, aggTypeDeclRef); + sb << "."; } - void formatDeclPath(StringBuilder& sb, DeclRef declRef) + sb << declRef.GetName(); + + // If the parent declaration is a generic, then we need to print out its + // signature + if( parentGenericDeclRef ) { - // Find the parent declaration - auto parentDeclRef = declRef.GetParent(); + assert(declRef.substitutions); + assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl()); - // If the immediate parent is a generic, then we probably - // want the declaration above that... - auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>(); - if(parentGenericDeclRef) + sb << "<"; + bool first = true; + for(auto arg : declRef.substitutions->args) { - parentDeclRef = parentGenericDeclRef.GetParent(); + if(!first) sb << ", "; + formatVal(sb, arg); + first = false; } + sb << ">"; + } + } - // Depending on what the parent is, we may want to format things specially - if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>()) - { - formatDeclPath(sb, aggTypeDeclRef); - sb << "."; - } + void formatDeclParams(StringBuilder& sb, DeclRef declRef) + { + if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { - sb << declRef.GetName(); + // This is something callable, so we need to also print parameter types for overloading + sb << "("; - // If the parent declaration is a generic, then we need to print out its - // signature - if( parentGenericDeclRef ) + bool first = true; + for (auto paramDeclRef : funcDeclRef.GetParameters()) { - assert(declRef.substitutions); - assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl()); + if (!first) sb << ", "; + + formatType(sb, paramDeclRef.GetType()); + + first = false; - sb << "<"; - bool first = true; - for(auto arg : declRef.substitutions->args) - { - if(!first) sb << ", "; - formatVal(sb, arg); - first = false; - } - sb << ">"; } - } - void formatDeclParams(StringBuilder& sb, DeclRef declRef) + sb << ")"; + } + else if(auto genericDeclRef = declRef.As<GenericDeclRef>()) { - if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + sb << "<"; + bool first = true; + for (auto paramDeclRef : genericDeclRef.GetMembers()) { - - // This is something callable, so we need to also print parameter types for overloading - sb << "("; - - bool first = true; - for (auto paramDeclRef : funcDeclRef.GetParameters()) + if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>()) { if (!first) sb << ", "; - - formatType(sb, paramDeclRef.GetType()); - first = false; + sb << genericTypeParam.GetName(); } - - sb << ")"; - } - else if(auto genericDeclRef = declRef.As<GenericDeclRef>()) - { - sb << "<"; - bool first = true; - for (auto paramDeclRef : genericDeclRef.GetMembers()) + else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>()) { - if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>()) - { - if (!first) sb << ", "; - first = false; - - sb << genericTypeParam.GetName(); - } - else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>()) - { - if (!first) sb << ", "; - first = false; + if (!first) sb << ", "; + first = false; - formatType(sb, genericValParam.GetType()); - sb << " "; - sb << genericValParam.GetName(); - } - else - {} + formatType(sb, genericValParam.GetType()); + sb << " "; + sb << genericValParam.GetName(); } - sb << ">"; - - formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions)); - } - else - { + else + {} } - } + sb << ">"; - void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions)); + } + else { - formatDeclPath(sb, declRef); - formatDeclParams(sb, declRef); } + } + + void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + { + formatDeclPath(sb, declRef); + formatDeclParams(sb, declRef); + } + + String getDeclSignatureString(DeclRef declRef) + { + StringBuilder sb; + formatDeclSignature(sb, declRef); + return sb.ProduceString(); + } + + String getDeclSignatureString(LookupResultItem item) + { + return getDeclSignatureString(item.declRef); + } - String getDeclSignatureString(DeclRef declRef) + String GetCallSignatureString(RefPtr<AppExprBase> expr) + { + StringBuilder argsListBuilder; + argsListBuilder << "("; + bool first = true; + for (auto a : expr->Arguments) { - StringBuilder sb; - formatDeclSignature(sb, declRef); - return sb.ProduceString(); + if (!first) argsListBuilder << ", "; + argsListBuilder << a->Type->ToString(); + first = false; } + argsListBuilder << ")"; + return argsListBuilder.ProduceString(); + } + - String getDeclSignatureString(LookupResultItem item) + RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr) + { + // Look at the base expression for the call, and figure out how to invoke it. + auto funcExpr = expr->FunctionExpr; + auto funcExprType = funcExpr->Type; + + // If we are trying to apply an erroroneous expression, then just bail out now. + if(IsErrorExpr(funcExpr)) { - return getDeclSignatureString(item.declRef); + return CreateErrorExpr(expr); } - String GetCallSignatureString(RefPtr<AppExprBase> expr) + OverloadResolveContext context; + context.appExpr = expr; + if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) { - StringBuilder argsListBuilder; - argsListBuilder << "("; - bool first = true; - for (auto a : expr->Arguments) - { - if (!first) argsListBuilder << ", "; - argsListBuilder << a->Type->ToString(); - first = false; - } - argsListBuilder << ")"; - return argsListBuilder.ProduceString(); + context.baseExpr = funcMemberExpr->BaseExpression; } - - - RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr) + else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) { - // Look at the base expression for the call, and figure out how to invoke it. - auto funcExpr = expr->FunctionExpr; - auto funcExprType = funcExpr->Type; - - // If we are trying to apply an erroroneous expression, then just bail out now. - if(IsErrorExpr(funcExpr)) - { - return CreateErrorExpr(expr); - } + context.baseExpr = funcOverloadExpr->base; + } + AddOverloadCandidates(funcExpr, context); - OverloadResolveContext context; - context.appExpr = expr; - if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) - { - context.baseExpr = funcMemberExpr->BaseExpression; - } - else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) - { - context.baseExpr = funcOverloadExpr->base; - } - AddOverloadCandidates(funcExpr, context); + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. - if (context.bestCandidates.Count() > 0) + // It might be that things were only ambiguous because + // one of the argument expressions had an error, and + // so a bunch of candidates could match at that position. + // + // If any argument was an error, we skip out on printing + // another message, to avoid cascading errors. + for (auto arg : expr->Arguments) { - // Things were ambiguous. - - // It might be that things were only ambiguous because - // one of the argument expressions had an error, and - // so a bunch of candidates could match at that position. - // - // If any argument was an error, we skip out on printing - // another message, to avoid cascading errors. - for (auto arg : expr->Arguments) + if (IsErrorExpr(arg)) { - if (IsErrorExpr(arg)) - { - return CreateErrorExpr(expr); - } + return CreateErrorExpr(expr); } + } - String funcName; - if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>()) - funcName = baseVar->Variable; - else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>()) - funcName = baseMemberRef->MemberName; + String funcName; + if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>()) + funcName = baseVar->Variable; + else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>()) + funcName = baseMemberRef->MemberName; - String argsList = GetCallSignatureString(expr); + String argsList = GetCallSignatureString(expr); - if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + { + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + if (funcName.Length() != 0) { - // There were multple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. - if (funcName.Length() != 0) - { - getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); - } + getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); } else { - // There were multiple applicable candidates, so we need to report them. - - if (funcName.Length() != 0) - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); - } + getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); } + } + else + { + // There were multiple applicable candidates, so we need to report them. - int candidateCount = context.bestCandidates.Count(); - int maxCandidatesToPrint = 10; // don't show too many candidates at once... - int candidateIndex = 0; - for (auto candidate : context.bestCandidates) + if (funcName.Length() != 0) { - String declString = getDeclSignatureString(candidate.item); - - declString = declString + "[" + String(candidate.conversionCostSum) + "]"; - - getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); - - candidateIndex++; - if (candidateIndex == maxCandidatesToPrint) - break; + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); } - if (candidateIndex != candidateCount) + else { - getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); } - - return CreateErrorExpr(expr); } - else if (context.bestCandidate) + + int candidateCount = context.bestCandidates.Count(); + int maxCandidatesToPrint = 10; // don't show too many candidates at once... + int candidateIndex = 0; + for (auto candidate : context.bestCandidates) { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); + String declString = getDeclSignatureString(candidate.item); + + declString = declString + "[" + String(candidate.conversionCostSum) + "]"; + + getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); + + candidateIndex++; + if (candidateIndex == maxCandidatesToPrint) + break; } - else + if (candidateIndex != candidateCount) { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction); - expr->Type = ExpressionType::Error; - return expr; + getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); } + + return CreateErrorExpr(expr); + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction); + expr->Type = ExpressionType::Error; + return expr; + } + } - void AddGenericOverloadCandidate( - LookupResultItem baseItem, - OverloadResolveContext& context) + void AddGenericOverloadCandidate( + LookupResultItem baseItem, + OverloadResolveContext& context) + { + if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>()) { - if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>()) - { - EnsureDecl(genericDeclRef.GetDecl()); + EnsureDecl(genericDeclRef.GetDecl()); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Generic; - candidate.item = baseItem; - candidate.resultType = nullptr; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Generic; + candidate.item = baseItem; + candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate); - } + AddOverloadCandidate(context, candidate); } + } - void AddGenericOverloadCandidates( - RefPtr<ExpressionSyntaxNode> baseExpr, - OverloadResolveContext& context) + void AddGenericOverloadCandidates( + RefPtr<ExpressionSyntaxNode> baseExpr, + OverloadResolveContext& context) + { + if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) { - if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) - { - auto declRef = baseDeclRefExpr->declRef; - AddGenericOverloadCandidate(LookupResultItem(declRef), context); - } - else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) - { - // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) - { - AddGenericOverloadCandidate(item, context); - } - } - else + auto declRef = baseDeclRefExpr->declRef; + AddGenericOverloadCandidate(LookupResultItem(declRef), context); + } + else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) { - // any other cases? + AddGenericOverloadCandidate(item, context); } } - - RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override + else { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. + // any other cases? + } + } + RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. - // Start by checking the base expression and arguments. - auto& baseExpr = genericAppExpr->FunctionExpr; - baseExpr = CheckTerm(baseExpr); - auto& args = genericAppExpr->Arguments; - for (auto& arg : args) - { - arg = CheckTerm(arg); - } - // If there was an error in the base expression, or in any of - // the arguments, then just bail. - if (IsErrorExpr(baseExpr)) + // Start by checking the base expression and arguments. + auto& baseExpr = genericAppExpr->FunctionExpr; + baseExpr = CheckTerm(baseExpr); + auto& args = genericAppExpr->Arguments; + for (auto& arg : args) + { + arg = CheckTerm(arg); + } + + // If there was an error in the base expression, or in any of + // the arguments, then just bail. + if (IsErrorExpr(baseExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + for (auto argExpr : args) + { + if (IsErrorExpr(argExpr)) { return CreateErrorExpr(genericAppExpr); } - for (auto argExpr : args) - { - if (IsErrorExpr(argExpr)) - { - return CreateErrorExpr(genericAppExpr); - } - } + } - // Otherwise, let's start looking at how to find an overload... + // Otherwise, let's start looking at how to find an overload... - OverloadResolveContext context; - context.appExpr = genericAppExpr; - context.baseExpr = GetBaseExpr(baseExpr); + OverloadResolveContext context; + context.appExpr = genericAppExpr; + context.baseExpr = GetBaseExpr(baseExpr); - AddGenericOverloadCandidates(baseExpr, context); + AddGenericOverloadCandidates(baseExpr, context); - if (context.bestCandidates.Count() > 0) + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) { - // Things were ambiguous. - if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) - { - // There were multple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. - // TODO(tfoley): print a reasonable message here... + // TODO(tfoley): print a reasonable message here... - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); - return CreateErrorExpr(genericAppExpr); - } - else - { - // There were multiple viable candidates, but that isn't an error: we just need - // to complete all of them and create an overloaded expression as a result. + return CreateErrorExpr(genericAppExpr); + } + else + { + // There were multiple viable candidates, but that isn't an error: we just need + // to complete all of them and create an overloaded expression as a result. - LookupResult result; - for (auto candidate : context.bestCandidates) - { - auto candidateExpr = CompleteOverloadCandidate(context, candidate); - } + LookupResult result; + for (auto candidate : context.bestCandidates) + { + auto candidateExpr = CompleteOverloadCandidate(context, candidate); + } - throw "what now?"; + throw "what now?"; // auto overloadedExpr = new OverloadedExpr(); // return overloadedExpr; - } - } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); - } - else - { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); - return CreateErrorExpr(genericAppExpr); } + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); + return CreateErrorExpr(genericAppExpr); + } #if TIMREMOVED - if (IsErrorExpr(base)) - { - return CreateErrorExpr(typeNode); - } - else if(auto baseDeclRefExpr = base.As<DeclRefExpr>()) - { - auto declRef = baseDeclRefExpr->declRef; + if (IsErrorExpr(base)) + { + return CreateErrorExpr(typeNode); + } + else if(auto baseDeclRefExpr = base.As<DeclRefExpr>()) + { + auto declRef = baseDeclRefExpr->declRef; - if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + int argCount = typeNode->Args.Count(); + int argIndex = 0; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) { - int argCount = typeNode->Args.Count(); - int argIndex = 0; - for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + if (auto typeParam = member.As<GenericTypeParamDecl>()) { - if (auto typeParam = member.As<GenericTypeParamDecl>()) + if (argIndex == argCount) { - if (argIndex == argCount) - { - // Too few arguments! - - } + // Too few arguments! - // TODO: checking! } - else if (auto valParam = member.As<GenericValueParamDecl>()) - { - // TODO: checking - } - else - { - } + // TODO: checking! } - if (argIndex != argCount) + else if (auto valParam = member.As<GenericValueParamDecl>()) { - // Too many arguments! + // TODO: checking } + else + { - // Now instantiate the declaration given those arguments - auto type = InstantiateGenericType(genericDeclRef, args); - typeResult = type; - typeNode->Type = new TypeExpressionType(type); - return typeNode; + } } + if (argIndex != argCount) + { + // Too many arguments! + } + + // Now instantiate the declaration given those arguments + auto type = InstantiateGenericType(genericDeclRef, args); + typeResult = type; + typeNode->Type = new TypeExpressionType(type); + return typeNode; } - else if (auto overloadedExpr = base.As<OverloadedExpr>()) + } + else if (auto overloadedExpr = base.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) { - // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) - { - auto applied = TryApplyGeneric(item, typeNode); - if (!applied) - continue; + auto applied = TryApplyGeneric(item, typeNode); + if (!applied) + continue; - AddToLookupResult(result, appliedItem); - } + AddToLookupResult(result, appliedItem); } + } - // TODO: correct diagnostic here! - getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type); - return CreateErrorExpr(typeNode); + // TODO: correct diagnostic here! + getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type); + return CreateErrorExpr(typeNode); #endif - } + } - RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override + RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override + { + if (!expr->Type.Ptr()) { - if (!expr->Type.Ptr()) - { - expr->base = CheckProperType(expr->base); - expr->Type = expr->base.exp->Type; - } - return expr; + expr->base = CheckProperType(expr->base); + expr->Type = expr->base.exp->Type; } + return expr; + } - RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr) - { - return expr->Accept(this).As<ExpressionSyntaxNode>(); - } + RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr) + { + return expr->Accept(this).As<ExpressionSyntaxNode>(); + } - RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr) - { + RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr) + { - auto rs = ResolveInvoke(expr); - if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr())) + auto rs = ResolveInvoke(expr); + if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr())) + { + // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues + if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>()) { - // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues - if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>()) + List<RefPtr<ParameterSyntaxNode>> paramsStorage; + List<RefPtr<ParameterSyntaxNode>> * params = nullptr; + if (auto func = funcType->declRef.GetDecl()) { - List<RefPtr<ParameterSyntaxNode>> paramsStorage; - List<RefPtr<ParameterSyntaxNode>> * params = nullptr; - if (auto func = funcType->declRef.GetDecl()) - { - paramsStorage = func->GetParameters().ToArray(); - params = ¶msStorage; - } - if (params) + paramsStorage = func->GetParameters().ToArray(); + params = ¶msStorage; + } + if (params) + { + for (int i = 0; i < (*params).Count(); i++) { - for (int i = 0; i < (*params).Count(); i++) + if ((*params)[i]->HasModifier<OutModifier>()) { - if ((*params)[i]->HasModifier<OutModifier>()) + if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() && + !expr->Arguments[i]->Type.IsLeftValue) { - if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() && - !expr->Arguments[i]->Type.IsLeftValue) - { - getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name); - } + getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name); } } } } } - return rs; } + return rs; + } - virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override + virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override + { + // check the base expression first + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + + // Next check the argument expressions + for (auto & arg : expr->Arguments) { - // check the base expression first - expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + arg = CheckExpr(arg); + } - // Next check the argument expressions - for (auto & arg : expr->Arguments) - { - arg = CheckExpr(arg); - } + return CheckInvokeExprWithCheckedOperands(expr); + } - return CheckInvokeExprWithCheckedOperands(expr); - } + virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override + { + // If we've already resolved this expression, don't try again. + if (expr->declRef) + return expr; + + expr->Type = ExpressionType::Error; - virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override + auto lookupResult = LookUp(expr->Variable, expr->scope); + if (lookupResult.isValid()) { - // If we've already resolved this expression, don't try again. - if (expr->declRef) - return expr; + return createLookupResultExpr( + lookupResult, + nullptr, + expr); + } - expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable); - auto lookupResult = LookUp(expr->Variable, expr->scope); - if (lookupResult.isValid()) - { - return createLookupResultExpr( - lookupResult, - nullptr, - expr); - } - - getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override + { + expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>(); + auto targetType = CheckProperType(expr->TargetType); + expr->TargetType = targetType; + // The way to perform casting depends on the types involved + if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr())) + { + // If the expression being casted has an error type, then just silently succeed + expr->Type = targetType; return expr; } - virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override + else if (auto targetArithType = targetType->AsArithmeticType()) { - expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>(); - auto targetType = CheckProperType(expr->TargetType); - expr->TargetType = targetType; - - // The way to perform casting depends on the types involved - if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr())) + if (auto exprArithType = expr->Expression->Type->AsArithmeticType()) { - // If the expression being casted has an error type, then just silently succeed - expr->Type = targetType; - return expr; - } - else if (auto targetArithType = targetType->AsArithmeticType()) - { - if (auto exprArithType = expr->Expression->Type->AsArithmeticType()) - { - // Both source and destination types are arithmetic, so we might - // have a valid cast - auto targetScalarType = targetArithType->GetScalarType(); - auto exprScalarType = exprArithType->GetScalarType(); + // Both source and destination types are arithmetic, so we might + // have a valid cast + auto targetScalarType = targetArithType->GetScalarType(); + auto exprScalarType = exprArithType->GetScalarType(); - if (!IsNumeric(exprScalarType->BaseType)) goto fail; - if (!IsNumeric(targetScalarType->BaseType)) goto fail; + if (!IsNumeric(exprScalarType->BaseType)) goto fail; + if (!IsNumeric(targetScalarType->BaseType)) goto fail; - // TODO(tfoley): this checking is incomplete here, and could - // lead to downstream compilation failures - expr->Type = targetType; - return expr; - } + // TODO(tfoley): this checking is incomplete here, and could + // lead to downstream compilation failures + expr->Type = targetType; + return expr; } - - fail: - // Default: in no other case succeds, then the cast failed and we emit a diagnostic. - getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString()); - expr->Type = ExpressionType::Error; - return expr; } + + fail: + // Default: in no other case succeds, then the cast failed and we emit a diagnostic. + getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString()); + expr->Type = ExpressionType::Error; + return expr; + } #if TIMREMOVED - virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override + virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override + { + auto selectorExpr = expr->SelectorExpr; + selectorExpr = CheckExpr(selectorExpr); + selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr); + expr->SelectorExpr = selectorExpr; + + // TODO(tfoley): We need a general purpose "join" on types for inferring + // generic argument types for builtins/intrinsics, so this should really + // be using the exact same logic... + // + expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>(); + expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>(); + if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr())) { - auto selectorExpr = expr->SelectorExpr; - selectorExpr = CheckExpr(selectorExpr); - selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr); - expr->SelectorExpr = selectorExpr; - - // TODO(tfoley): We need a general purpose "join" on types for inferring - // generic argument types for builtins/intrinsics, so this should really - // be using the exact same logic... - // - expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>(); - expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>(); - if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr())) - { - getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch); - } - expr->Type = expr->Expr0->Type; - return expr; + getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch); } + expr->Type = expr->Expr0->Type; + return expr; + } #endif - // Get the type to use when referencing a declaration - QualType GetTypeForDeclRef(DeclRef declRef) - { - return getTypeForDeclRef( - this, - getSink(), - declRef, - &typeResult); - } + // Get the type to use when referencing a declaration + QualType GetTypeForDeclRef(DeclRef declRef) + { + return getTypeForDeclRef( + this, + getSink(), + declRef, + &typeResult); + } - RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr) + RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr) + { + RefPtr<ExpressionSyntaxNode> expr = inExpr; + for (;;) { - RefPtr<ExpressionSyntaxNode> expr = inExpr; - for (;;) + auto& type = expr->Type; + if (auto pointerLikeType = type->As<PointerLikeType>()) { - auto& type = expr->Type; - if (auto pointerLikeType = type->As<PointerLikeType>()) - { - type = pointerLikeType->elementType; + type = pointerLikeType->elementType; - auto derefExpr = new DerefExpr(); - derefExpr->base = expr; - derefExpr->Type = pointerLikeType->elementType; + auto derefExpr = new DerefExpr(); + derefExpr->base = expr; + derefExpr->Type = pointerLikeType->elementType; - // TODO(tfoley): deal with l-value-ness here - - expr = derefExpr; - continue; - } + // TODO(tfoley): deal with l-value-ness here - // Default case: just use the expression as-is - return expr; + expr = derefExpr; + continue; } + + // Default case: just use the expression as-is + return expr; } + } - RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( - MemberExpressionSyntaxNode* memberRefExpr, - RefPtr<ExpressionType> baseElementType, - int baseElementCount) - { - RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); - swizExpr->Position = memberRefExpr->Position; - swizExpr->base = memberRefExpr->BaseExpression; + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + int baseElementCount) + { + RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); + swizExpr->Position = memberRefExpr->Position; + swizExpr->base = memberRefExpr->BaseExpression; - int limitElement = baseElementCount; + int limitElement = baseElementCount; - int elementIndices[4]; - int elementCount = 0; + int elementIndices[4]; + int elementCount = 0; - bool elementUsed[4] = { false, false, false, false }; - bool anyDuplicates = false; - bool anyError = false; + bool elementUsed[4] = { false, false, false, false }; + bool anyDuplicates = false; + bool anyError = false; - for (int i = 0; i < memberRefExpr->MemberName.Length(); i++) + for (int i = 0; i < memberRefExpr->MemberName.Length(); i++) + { + auto ch = memberRefExpr->MemberName[i]; + int elementIndex = -1; + switch (ch) { - auto ch = memberRefExpr->MemberName[i]; - int elementIndex = -1; - switch (ch) - { - case 'x': case 'r': elementIndex = 0; break; - case 'y': case 'g': elementIndex = 1; break; - case 'z': case 'b': elementIndex = 2; break; - case 'w': case 'a': elementIndex = 3; break; - default: - // An invalid character in the swizzle is an error - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); - anyError = true; - continue; - } - - // TODO(tfoley): GLSL requires that all component names - // come from the same "family"... - - // Make sure the index is in range for the source type - if (elementIndex >= limitElement) - { - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); - anyError = true; - continue; - } - - // Check if we've seen this index before - for (int ee = 0; ee < elementCount; ee++) - { - if (elementIndices[ee] == elementIndex) - anyDuplicates = true; - } - - // add to our list... - elementIndices[elementCount++] = elementIndex; + case 'x': case 'r': elementIndex = 0; break; + case 'y': case 'g': elementIndex = 1; break; + case 'z': case 'b': elementIndex = 2; break; + case 'w': case 'a': elementIndex = 3; break; + default: + // An invalid character in the swizzle is an error + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); + anyError = true; + continue; } - for (int ee = 0; ee < elementCount; ++ee) - { - swizExpr->elementIndices[ee] = elementIndices[ee]; - } - swizExpr->elementCount = elementCount; + // TODO(tfoley): GLSL requires that all component names + // come from the same "family"... - if (anyError) + // Make sure the index is in range for the source type + if (elementIndex >= limitElement) { - swizExpr->Type = ExpressionType::Error; + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); + anyError = true; + continue; } - else if (elementCount == 1) - { - // single-component swizzle produces a scalar - // - // Note(tfoley): the official HLSL rules seem to be that it produces - // a one-component vector, which is then implicitly convertible to - // a scalar, but that seems like it just adds complexity. - swizExpr->Type = baseElementType; - } - else + + // Check if we've seen this index before + for (int ee = 0; ee < elementCount; ee++) { - // TODO(tfoley): would be nice to "re-sugar" type - // here if the input type had a sugared name... - swizExpr->Type = createVectorType( - baseElementType, - new ConstantIntVal(elementCount)); + if (elementIndices[ee] == elementIndex) + anyDuplicates = true; } - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->Type.IsLeftValue = !anyDuplicates; + // add to our list... + elementIndices[elementCount++] = elementIndex; + } - return swizExpr; + for (int ee = 0; ee < elementCount; ++ee) + { + swizExpr->elementIndices[ee] = elementIndices[ee]; } + swizExpr->elementCount = elementCount; - RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( - MemberExpressionSyntaxNode* memberRefExpr, - RefPtr<ExpressionType> baseElementType, - RefPtr<IntVal> baseElementCount) + if (anyError) { - if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) - { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); - } - else - { - getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); - return CreateErrorExpr(memberRefExpr); - } + swizExpr->Type = ExpressionType::Error; + } + else if (elementCount == 1) + { + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->Type = baseElementType; + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->Type = createVectorType( + baseElementType, + new ConstantIntVal(elementCount)); } + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->Type.IsLeftValue = !anyDuplicates; - virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override + return swizExpr; + } + + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + RefPtr<IntVal> baseElementCount) + { + if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) { - expr->BaseExpression = CheckExpr(expr->BaseExpression); + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + } + else + { + getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); + return CreateErrorExpr(memberRefExpr); + } + } - expr->BaseExpression = MaybeDereference(expr->BaseExpression); - auto & baseType = expr->BaseExpression->Type; + virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override + { + expr->BaseExpression = CheckExpr(expr->BaseExpression); - // Note: Checking for vector types before declaration-reference types, - // because vectors are also declaration reference types... - if (auto baseVecType = baseType->AsVectorType()) - { - return CheckSwizzleExpr( - expr, - baseVecType->elementType, - baseVecType->elementCount); - } - else if(auto baseScalarType = baseType->AsBasicType()) - { - // Treat scalar like a 1-element vector when swizzling - return CheckSwizzleExpr( - expr, - baseScalarType, - 1); - } - else if (auto declRefType = baseType->AsDeclRefType()) + expr->BaseExpression = MaybeDereference(expr->BaseExpression); + + auto & baseType = expr->BaseExpression->Type; + + // Note: Checking for vector types before declaration-reference types, + // because vectors are also declaration reference types... + if (auto baseVecType = baseType->AsVectorType()) + { + return CheckSwizzleExpr( + expr, + baseVecType->elementType, + baseVecType->elementCount); + } + else if(auto baseScalarType = baseType->AsBasicType()) + { + // Treat scalar like a 1-element vector when swizzling + return CheckSwizzleExpr( + expr, + baseScalarType, + 1); + } + else if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) - { - // Checking of the type must be complete before we can reference its members safely - EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef); - if (!lookupResult.isValid()) - { - goto fail; - } + LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef); + if (!lookupResult.isValid()) + { + goto fail; + } - return createLookupResultExpr( - lookupResult, - expr->BaseExpression, - expr); + return createLookupResultExpr( + lookupResult, + expr->BaseExpression, + expr); #if 0 - DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions); - return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr); + DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions); + return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr); #endif #if 0 - // TODO(tfoley): It is unfortunate that the lookup strategy - // here isn't unified with the ordinary `Scope` case. - // In particular, if we add support for "transparent" declarations, - // etc. here then we would need to add them in ordinary lookup - // as well. - - Decl* memberDecl = nullptr; // The first declaration we found, if any - Decl* secondDecl = nullptr; // Another declaration with the same name, if any - for (auto m : aggTypeDeclRef.GetMembers()) - { - if (m.GetName() != expr->MemberName) - continue; + // TODO(tfoley): It is unfortunate that the lookup strategy + // here isn't unified with the ordinary `Scope` case. + // In particular, if we add support for "transparent" declarations, + // etc. here then we would need to add them in ordinary lookup + // as well. - if (!memberDecl) - { - memberDecl = m.GetDecl(); - } - else - { - secondDecl = m.GetDecl(); - break; - } - } + Decl* memberDecl = nullptr; // The first declaration we found, if any + Decl* secondDecl = nullptr; // Another declaration with the same name, if any + for (auto m : aggTypeDeclRef.GetMembers()) + { + if (m.GetName() != expr->MemberName) + continue; - // If we didn't find any member, then we signal an error if (!memberDecl) { - expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); - return expr; + memberDecl = m.GetDecl(); } - - // If we found only a single member, then we are fine - if (!secondDecl) + else { - // TODO: need to - DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions); - - expr->declRef = memberDeclRef; - expr->Type = GetTypeForDeclRef(memberDeclRef); - - // When referencing a member variable, the result is an l-value - // if and only if the base expression was. - if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl)) - { - expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; - } - return expr; + secondDecl = m.GetDecl(); + break; } + } - // We found multiple members with the same name, and need - // to resolve the embiguity at some point... + // If we didn't find any member, then we signal an error + if (!memberDecl) + { expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference"); + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); return expr; + } -#endif + // If we found only a single member, then we are fine + if (!secondDecl) + { + // TODO: need to + DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions); -#if 0 + expr->declRef = memberDeclRef; + expr->Type = GetTypeForDeclRef(memberDeclRef); - StructField* field = structDecl->FindField(expr->MemberName); - if (!field) + // When referencing a member variable, the result is an l-value + // if and only if the base expression was. + if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl)) { - expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; } - else - expr->Type = field->Type; - - // A reference to a struct member is an l-value if the reference to the struct - // value was also an l-value. - expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; return expr; -#endif } - // catch-all - fail: - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + // We found multiple members with the same name, and need + // to resolve the embiguity at some point... expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference"); return expr; - } - // All remaining cases assume we have a `BasicType` - else if (!baseType->AsBasicType()) - expr->Type = ExpressionType::Error; - else - expr->Type = ExpressionType::Error; - if (!baseType->Equals(ExpressionType::Error.Ptr()) && - expr->Type->Equals(ExpressionType::Error.Ptr())) - { - getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName); - } - return expr; - } - SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; +#endif - // +#if 0 - virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override - { - // When faced with an initializer list, we first just check the sub-expressions blindly. - // Actually making them conform to a desired type will wait for when we know the desired - // type based on context. + StructField* field = structDecl->FindField(expr->MemberName); + if (!field) + { + expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + } + else + expr->Type = field->Type; - for( auto& arg : expr->args ) - { - arg = CheckTerm(arg); + // A reference to a struct member is an l-value if the reference to the struct + // value was also an l-value. + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; + return expr; +#endif } - expr->Type = ExpressionType::getInitializerListType(); - + // catch-all + fail: + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + expr->Type = ExpressionType::Error; return expr; } - - virtual void visitImportDecl(ImportDecl* decl) override + // All remaining cases assume we have a `BasicType` + else if (!baseType->AsBasicType()) + expr->Type = ExpressionType::Error; + else + expr->Type = ExpressionType::Error; + if (!baseType->Equals(ExpressionType::Error.Ptr()) && + expr->Type->Equals(ExpressionType::Error.Ptr())) { - // We need to look for a module with the specified name - // (whether it has already been loaded, or needs to - // be loaded), and then put its declarations into - // the current scope. + getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName); + } + return expr; + } + SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; - auto name = decl->nameToken.Content; - auto scope = decl->scope; - // Try to load a module matching the name - auto importedModuleDecl = findOrImportModule(request, name, decl->nameToken.Position); + // - // If we didn't find a matching module, then bail out - if (!importedModuleDecl) - return; + virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override + { + // When faced with an initializer list, we first just check the sub-expressions blindly. + // Actually making them conform to a desired type will wait for when we know the desired + // type based on context. - // Record the module that was imported, so that we can use - // it later during code generation. - decl->importedModuleDecl = importedModuleDecl; + for( auto& arg : expr->args ) + { + arg = CheckTerm(arg); + } - // Create a new sub-scope to wire the module - // into our lookup chain. - auto subScope = new Scope(); - subScope->containerDecl = importedModuleDecl.Ptr(); + expr->Type = ExpressionType::getInitializerListType(); - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - }; + return expr; + } - SyntaxVisitor* CreateSemanticsVisitor( - DiagnosticSink* err, - CompileOptions const& options, - CompileRequest* request) + virtual void visitImportDecl(ImportDecl* decl) override { - return new SemanticsVisitor(err, options, request); + // We need to look for a module with the specified name + // (whether it has already been loaded, or needs to + // be loaded), and then put its declarations into + // the current scope. + + auto name = decl->nameToken.Content; + auto scope = decl->scope; + + // Try to load a module matching the name + auto importedModuleDecl = findOrImportModule(request, name, decl->nameToken.Position); + + // If we didn't find a matching module, then bail out + if (!importedModuleDecl) + return; + + // Record the module that was imported, so that we can use + // it later during code generation. + decl->importedModuleDecl = importedModuleDecl; + + // Create a new sub-scope to wire the module + // into our lookup chain. + auto subScope = new Scope(); + subScope->containerDecl = importedModuleDecl.Ptr(); + + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; } + }; - // + SyntaxVisitor* CreateSemanticsVisitor( + DiagnosticSink* err, + CompileOptions const& options, + CompileRequest* request) + { + return new SemanticsVisitor(err, options, request); + } - // Get the type to use when referencing a declaration - QualType getTypeForDeclRef( - SemanticsVisitor* sema, - DiagnosticSink* sink, - DeclRef declRef, - RefPtr<ExpressionType>* outTypeResult) - { - if( sema ) - { - sema->EnsureDecl(declRef.GetDecl()); - } + // - // We need to insert an appropriate type for the expression, based on - // what we found. - if (auto varDeclRef = declRef.As<VarDeclBaseRef>()) - { - QualType qualType; - qualType.type = varDeclRef.GetType(); - qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables - return qualType; - } - else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>()) - { - auto type = new NamedExpressionType(typeAliasDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>()) - { - auto type = DeclRefType::Create(aggTypeDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>()) - { - auto type = DeclRefType::Create(simpleTypeDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto genericDeclRef = declRef.As<GenericDeclRef>()) - { - auto type = new GenericDeclRefType(genericDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto funcDeclRef = declRef.As<CallableDeclRef>()) - { - auto type = new FuncType(); - type->declRef = funcDeclRef; - return type; - } + // Get the type to use when referencing a declaration + QualType getTypeForDeclRef( + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr<ExpressionType>* outTypeResult) + { + if( sema ) + { + sema->EnsureDecl(declRef.GetDecl()); + } - if( sink ) - { - sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); - } - return ExpressionType::Error; + // We need to insert an appropriate type for the expression, based on + // what we found. + if (auto varDeclRef = declRef.As<VarDeclBaseRef>()) + { + QualType qualType; + qualType.type = varDeclRef.GetType(); + qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables + return qualType; + } + else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>()) + { + auto type = new NamedExpressionType(typeAliasDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>()) + { + auto type = DeclRefType::Create(simpleTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + auto type = new GenericDeclRefType(genericDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { + auto type = new FuncType(); + type->declRef = funcDeclRef; + return type; } - QualType getTypeForDeclRef( - DeclRef declRef) + if( sink ) { - RefPtr<ExpressionType> typeResult; - return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult); + sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); } + return ExpressionType::Error; + } + QualType getTypeForDeclRef( + DeclRef declRef) + { + RefPtr<ExpressionType> typeResult; + return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult); } -}
\ No newline at end of file + +} diff --git a/source/slang/compiled-program.h b/source/slang/compiled-program.h index 7a86dfe90..1766127b2 100644 --- a/source/slang/compiled-program.h +++ b/source/slang/compiled-program.h @@ -8,89 +8,42 @@ namespace Slang { - namespace Compiler + void IndentString(StringBuilder & sb, String src); + + struct EntryPointResult { -#if 0 - class ShaderMetaData - { - public: - CoreLib::String ShaderName; - CoreLib::EnumerableDictionary<CoreLib::String, CoreLib::RefPtr<ILModuleParameterSet>> ParameterSets; // bindingName->DescSet - }; + String outputSource; + }; - class StageSource - { - public: - String MainCode; - List<unsigned char> BinaryCode; - }; + struct TranslationUnitResult + { + String outputSource; + List<EntryPointResult> entryPoints; + }; - class CompiledShaderSource - { - public: - EnumerableDictionary<String, StageSource> Stages; - ShaderMetaData MetaData; - }; -#endif + class CompileResult + { + public: + DiagnosticSink* mSink = nullptr; - void IndentString(StringBuilder & sb, String src); + // Per-translation-unit results + List<TranslationUnitResult> translationUnits; - struct EntryPointResult + CompileResult() + {} + ~CompileResult() { - String outputSource; - }; - - struct TranslationUnitResult + } + DiagnosticSink * GetErrorWriter() { - String outputSource; - List<EntryPointResult> entryPoints; - }; - - class CompileResult + return mSink; + } + int GetErrorCount() { - public: - DiagnosticSink* mSink = nullptr; - -#if 0 - String ScheduleFile; - RefPtr<ILProgram> Program; - EnumerableDictionary<String, CompiledShaderSource> CompiledSource; // shader -> stage -> code -#endif - - // Per-translation-unit results - List<TranslationUnitResult> translationUnits; - -#if 0 - void PrintDiagnostics() - { - for (int i = 0; i < sink.diagnostics.Count(); i++) - { - fprintf(stderr, "%S(%d): %s %d: %S\n", - sink.diagnostics[i].Position.FileName.ToWString(), - sink.diagnostics[i].Position.Line, - getSeverityName(sink.diagnostics[i].severity), - sink.diagnostics[i].ErrorID, - sink.diagnostics[i].Message.ToWString()); - } - } -#endif - - CompileResult() - {} - ~CompileResult() - { - } - DiagnosticSink * GetErrorWriter() - { - return mSink; - } - int GetErrorCount() - { - return mSink->GetErrorCount(); - } - }; + return mSink->GetErrorCount(); + } + }; - } } #endif
\ No newline at end of file diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index d02e5d10b..d8298c604 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -32,523 +32,519 @@ using namespace CoreLib::Basic; using namespace CoreLib::IO; -using namespace Slang::Compiler; namespace Slang { - namespace Compiler - { - // + // - Profile Profile::LookUp(char const* name) - { - #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG; - #define PROFILE_ALIAS(TAG, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG; - #include "profile-defs.h" + Profile Profile::LookUp(char const* name) + { + #define PROFILE(TAG, NAME, STAGE, VERSION) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #define PROFILE_ALIAS(TAG, NAME) if(strcmp(name, #NAME) == 0) return Profile::TAG; + #include "profile-defs.h" - return Profile::Unknown; - } + return Profile::Unknown; + } - // + // - String EmitHLSL(ExtraContext& context) + String EmitHLSL(ExtraContext& context) + { + if (context.getOptions().passThrough != PassThroughMode::None) { - if (context.getOptions().passThrough != PassThroughMode::None) - { - return context.sourceText; - } - else - { - // TODO(tfoley): probably need a way to customize the emit logic... - return emitProgram( - context.programSyntax.Ptr(), - context.programLayout, - CodeGenTarget::HLSL); - } + return context.sourceText; } + else + { + // TODO(tfoley): probably need a way to customize the emit logic... + return emitProgram( + context.programSyntax.Ptr(), + context.programLayout, + CodeGenTarget::HLSL); + } + } - String emitGLSLForEntryPoint(ExtraContext& context, EntryPointOption const& /*entryPoint*/) + String emitGLSLForEntryPoint(ExtraContext& context, EntryPointOption const& /*entryPoint*/) + { + if (context.getOptions().passThrough != PassThroughMode::None) { - if (context.getOptions().passThrough != PassThroughMode::None) - { - return context.sourceText; - } - else - { - // TODO(tfoley): probably need a way to customize the emit logic... - return emitProgram( - context.programSyntax.Ptr(), - context.programLayout, - CodeGenTarget::GLSL); - } + return context.sourceText; + } + else + { + // TODO(tfoley): probably need a way to customize the emit logic... + return emitProgram( + context.programSyntax.Ptr(), + context.programLayout, + CodeGenTarget::GLSL); } + } - char const* GetHLSLProfileName(Profile profile) + char const* GetHLSLProfileName(Profile profile) + { + switch(profile.raw) { - switch(profile.raw) - { - #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; - #include "profile-defs.h" + #define PROFILE(TAG, NAME, STAGE, VERSION) case Profile::TAG: return #NAME; + #include "profile-defs.h" - default: - // TODO: emit an error here! - return "unknown"; - } + default: + // TODO: emit an error here! + return "unknown"; } + } #ifdef _WIN32 - void* GetD3DCompilerDLL() - { - // TODO(tfoley): let user specify version of d3dcompiler DLL to use. - static HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); - // TODO(tfoley): handle case where we can't find it gracefully - assert(d3dCompiler); - return d3dCompiler; - } + void* GetD3DCompilerDLL() + { + // TODO(tfoley): let user specify version of d3dcompiler DLL to use. + static HMODULE d3dCompiler = LoadLibraryA("d3dcompiler_47"); + // TODO(tfoley): handle case where we can't find it gracefully + assert(d3dCompiler); + return d3dCompiler; + } - List<uint8_t> EmitDXBytecodeForEntryPoint( - ExtraContext& context, - EntryPointOption const& entryPoint) + List<uint8_t> EmitDXBytecodeForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + static pD3DCompile D3DCompile_ = nullptr; + if (!D3DCompile_) { - static pD3DCompile D3DCompile_ = nullptr; - if (!D3DCompile_) - { - HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); - assert(d3dCompiler); + HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); + assert(d3dCompiler); - D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); - assert(D3DCompile_); - } + D3DCompile_ = (pD3DCompile)GetProcAddress(d3dCompiler, "D3DCompile"); + assert(D3DCompile_); + } - // The HLSL compiler will try to "canonicalize" our input file path, - // and we don't want it to do that, because they it won't report - // the same locations on error messages that we would. - // - // To work around that, we prepend a custom `#line` directive. + // The HLSL compiler will try to "canonicalize" our input file path, + // and we don't want it to do that, because they it won't report + // the same locations on error messages that we would. + // + // To work around that, we prepend a custom `#line` directive. - String rawHlslCode = EmitHLSL(context); + String rawHlslCode = EmitHLSL(context); - StringBuilder hlslCodeBuilder; - hlslCodeBuilder << "#line 1 \""; - for(auto c : context.sourcePath) + StringBuilder hlslCodeBuilder; + hlslCodeBuilder << "#line 1 \""; + for(auto c : context.sourcePath) + { + char buffer[] = { c, 0 }; + switch(c) { - char buffer[] = { c, 0 }; - switch(c) - { - default: - hlslCodeBuilder << buffer; - break; + default: + hlslCodeBuilder << buffer; + break; - case '\\': - hlslCodeBuilder << "\\\\"; - } + case '\\': + hlslCodeBuilder << "\\\\"; } - hlslCodeBuilder << "\"\n"; - hlslCodeBuilder << rawHlslCode; - - auto hlslCode = hlslCodeBuilder.ProduceString(); - - ID3DBlob* codeBlob; - ID3DBlob* diagnosticsBlob; - HRESULT hr = D3DCompile_( - hlslCode.begin(), - hlslCode.Length(), - context.sourcePath.begin(), - nullptr, - nullptr, - entryPoint.name.begin(), - GetHLSLProfileName(entryPoint.profile), - 0, - 0, - &codeBlob, - &diagnosticsBlob); - List<uint8_t> data; - if (codeBlob) - { - data.AddRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize()); - codeBlob->Release(); - } - if (diagnosticsBlob) + } + hlslCodeBuilder << "\"\n"; + hlslCodeBuilder << rawHlslCode; + + auto hlslCode = hlslCodeBuilder.ProduceString(); + + ID3DBlob* codeBlob; + ID3DBlob* diagnosticsBlob; + HRESULT hr = D3DCompile_( + hlslCode.begin(), + hlslCode.Length(), + context.sourcePath.begin(), + nullptr, + nullptr, + entryPoint.name.begin(), + GetHLSLProfileName(entryPoint.profile), + 0, + 0, + &codeBlob, + &diagnosticsBlob); + List<uint8_t> data; + if (codeBlob) + { + data.AddRange((uint8_t const*)codeBlob->GetBufferPointer(), (int)codeBlob->GetBufferSize()); + codeBlob->Release(); + } + if (diagnosticsBlob) + { + // TODO(tfoley): need a better policy for how we translate diagnostics + // back into the Slang world (although we should always try to generate + // HLSL that doesn't produce any diagnostics...) + String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); + fprintf(stderr, "%s", diagnostics.begin()); + OutputDebugStringA(diagnostics.begin()); + diagnosticsBlob->Release(); + } + if (FAILED(hr)) + { + // TODO(tfoley): What to do on failure? + } + return data; + } + + List<uint8_t> EmitDXBytecode( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() != 1) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) { - // TODO(tfoley): need a better policy for how we translate diagnostics - // back into the Slang world (although we should always try to generate - // HLSL that doesn't produce any diagnostics...) - String diagnostics = (char const*) diagnosticsBlob->GetBufferPointer(); - fprintf(stderr, "%s", diagnostics.begin()); - OutputDebugStringA(diagnostics.begin()); - diagnosticsBlob->Release(); + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); } - if (FAILED(hr)) + else { - // TODO(tfoley): What to do on failure? + fprintf(stderr, "multiple entry points specified\n"); } - return data; + return List<uint8_t>(); } - List<uint8_t> EmitDXBytecode( - ExtraContext& context) + return EmitDXBytecodeForEntryPoint(context, context.getTranslationUnitOptions().entryPoints[0]); + } + + String EmitDXBytecodeAssemblyForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + static pD3DDisassemble D3DDisassemble_ = nullptr; + if (!D3DDisassemble_) { - if(context.getTranslationUnitOptions().entryPoints.Count() != 1) - { - if(context.getTranslationUnitOptions().entryPoints.Count() == 0) - { - // TODO(tfoley): need to write diagnostics into this whole thing... - fprintf(stderr, "no entry point specified\n"); - } - else - { - fprintf(stderr, "multiple entry points specified\n"); - } - return List<uint8_t>(); - } + HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); + assert(d3dCompiler); - return EmitDXBytecodeForEntryPoint(context, context.getTranslationUnitOptions().entryPoints[0]); + D3DDisassemble_ = (pD3DDisassemble)GetProcAddress(d3dCompiler, "D3DDisassemble"); + assert(D3DDisassemble_); } - String EmitDXBytecodeAssemblyForEntryPoint( - ExtraContext& context, - EntryPointOption const& entryPoint) + List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(context, entryPoint); + if (!dxbc.Count()) { - static pD3DDisassemble D3DDisassemble_ = nullptr; - if (!D3DDisassemble_) - { - HMODULE d3dCompiler = (HMODULE)GetD3DCompilerDLL(); - assert(d3dCompiler); - - D3DDisassemble_ = (pD3DDisassemble)GetProcAddress(d3dCompiler, "D3DDisassemble"); - assert(D3DDisassemble_); - } - - List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(context, entryPoint); - if (!dxbc.Count()) - { - return ""; - } + return ""; + } - ID3DBlob* codeBlob; - HRESULT hr = D3DDisassemble_( - &dxbc[0], - dxbc.Count(), - 0, - nullptr, - &codeBlob); + ID3DBlob* codeBlob; + HRESULT hr = D3DDisassemble_( + &dxbc[0], + dxbc.Count(), + 0, + nullptr, + &codeBlob); - String result; - if (codeBlob) - { - result = String((char const*) codeBlob->GetBufferPointer()); - codeBlob->Release(); - } - if (FAILED(hr)) - { - // TODO(tfoley): need to figure out what to diagnose here... - } - return result; + String result; + if (codeBlob) + { + result = String((char const*) codeBlob->GetBufferPointer()); + codeBlob->Release(); + } + if (FAILED(hr)) + { + // TODO(tfoley): need to figure out what to diagnose here... } + return result; + } - String EmitDXBytecodeAssembly( - ExtraContext& context) + String EmitDXBytecodeAssembly( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) { - if(context.getTranslationUnitOptions().entryPoints.Count() == 0) - { - // TODO(tfoley): need to write diagnostics into this whole thing... - fprintf(stderr, "no entry point specified\n"); - return ""; - } - - StringBuilder sb; - for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) - { - sb << EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); - } - return sb.ProduceString(); + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); + return ""; } - - HMODULE getGLSLCompilerDLL() + StringBuilder sb; + for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) { - // TODO(tfoley): let user specify version of glslang DLL to use. - static HMODULE glslCompiler = LoadLibraryA("glslang"); - // TODO(tfoley): handle case where we can't find it gracefully - assert(glslCompiler); - return glslCompiler; + sb << EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); } + return sb.ProduceString(); + } - String emitSPIRVAssemblyForEntryPoint( - ExtraContext& context, - EntryPointOption const& entryPoint) - { - String rawGLSL = emitGLSLForEntryPoint(context, entryPoint); + HMODULE getGLSLCompilerDLL() + { + // TODO(tfoley): let user specify version of glslang DLL to use. + static HMODULE glslCompiler = LoadLibraryA("glslang"); + // TODO(tfoley): handle case where we can't find it gracefully + assert(glslCompiler); + return glslCompiler; + } - static glslang_CompileFunc glslang_compile = nullptr; - if (!glslang_compile) - { - HMODULE glslCompiler = getGLSLCompilerDLL(); - assert(glslCompiler); - glslang_compile = (glslang_CompileFunc)GetProcAddress(glslCompiler, "glslang_compile"); - assert(glslang_compile); - } + String emitSPIRVAssemblyForEntryPoint( + ExtraContext& context, + EntryPointOption const& entryPoint) + { + String rawGLSL = emitGLSLForEntryPoint(context, entryPoint); - StringBuilder diagnosticBuilder; - StringBuilder outputBuilder; + static glslang_CompileFunc glslang_compile = nullptr; + if (!glslang_compile) + { + HMODULE glslCompiler = getGLSLCompilerDLL(); + assert(glslCompiler); - auto outputFunc = [](char const* text, void* userData) - { - *(StringBuilder*)userData << text; - }; + glslang_compile = (glslang_CompileFunc)GetProcAddress(glslCompiler, "glslang_compile"); + assert(glslang_compile); + } - glslang_CompileRequest request; - request.sourcePath = context.sourcePath.begin(); - request.sourceText = rawGLSL.begin(); - request.slangStage = (SlangStage) entryPoint.profile.GetStage(); + StringBuilder diagnosticBuilder; + StringBuilder outputBuilder; - request.diagnosticFunc = outputFunc; - request.diagnosticUserData = &diagnosticBuilder; + auto outputFunc = [](char const* text, void* userData) + { + *(StringBuilder*)userData << text; + }; - request.outputFunc = outputFunc; - request.outputUserData = &outputBuilder; + glslang_CompileRequest request; + request.sourcePath = context.sourcePath.begin(); + request.sourceText = rawGLSL.begin(); + request.slangStage = (SlangStage) entryPoint.profile.GetStage(); - int err = glslang_compile(&request); + request.diagnosticFunc = outputFunc; + request.diagnosticUserData = &diagnosticBuilder; - String diagnostics = diagnosticBuilder.ProduceString(); - String output = outputBuilder.ProduceString(); + request.outputFunc = outputFunc; + request.outputUserData = &outputBuilder; - if(err) - { - OutputDebugStringA(diagnostics.Buffer()); - fprintf(stderr, "%s", diagnostics.Buffer()); - exit(1); - } + int err = glslang_compile(&request); - return output; + String diagnostics = diagnosticBuilder.ProduceString(); + String output = outputBuilder.ProduceString(); + + if(err) + { + OutputDebugStringA(diagnostics.Buffer()); + fprintf(stderr, "%s", diagnostics.Buffer()); + exit(1); } + + return output; + } #endif - String emitSPIRVAssembly( - ExtraContext& context) + String emitSPIRVAssembly( + ExtraContext& context) + { + if(context.getTranslationUnitOptions().entryPoints.Count() == 0) { - if(context.getTranslationUnitOptions().entryPoints.Count() == 0) - { - // TODO(tfoley): need to write diagnostics into this whole thing... - fprintf(stderr, "no entry point specified\n"); - return ""; - } - - StringBuilder sb; - for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) - { - sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint); - } - return sb.ProduceString(); + // TODO(tfoley): need to write diagnostics into this whole thing... + fprintf(stderr, "no entry point specified\n"); + return ""; } - // Do emit logic for a single entry point - EntryPointResult emitEntryPoint(ExtraContext& context, EntryPointOption& entryPoint) + StringBuilder sb; + for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) { - EntryPointResult result; + sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint); + } + return sb.ProduceString(); + } + + // Do emit logic for a single entry point + EntryPointResult emitEntryPoint(ExtraContext& context, EntryPointOption& entryPoint) + { + EntryPointResult result; - switch (context.getOptions().Target) + switch (context.getOptions().Target) + { + case CodeGenTarget::GLSL: { - case CodeGenTarget::GLSL: - { - String code = emitGLSLForEntryPoint(context, entryPoint); - result.outputSource = code; - } - break; + String code = emitGLSLForEntryPoint(context, entryPoint); + result.outputSource = code; + } + break; - case CodeGenTarget::DXBytecode: - { - auto code = EmitDXBytecodeForEntryPoint(context, entryPoint); + case CodeGenTarget::DXBytecode: + { + auto code = EmitDXBytecodeForEntryPoint(context, entryPoint); - // TODO(tfoley): Need to figure out an appropriate interface - // for returning binary code, in addition to source. + // TODO(tfoley): Need to figure out an appropriate interface + // for returning binary code, in addition to source. #if 0 - if (context.compileResult) - { - StringBuilder sb; - sb.Append((char*) code.begin(), code.Count()); + if (context.compileResult) + { + StringBuilder sb; + sb.Append((char*) code.begin(), code.Count()); - String codeString = sb.ProduceString(); - result.outputSource = codeString; - } - else + String codeString = sb.ProduceString(); + result.outputSource = codeString; + } + else #endif + { + int col = 0; + for(auto ii : code) { - int col = 0; - for(auto ii : code) - { - if(col != 0) fputs(" ", stdout); - fprintf(stdout, "%02X", ii); - col++; - if(col == 8) - { - fputs("\n", stdout); - col = 0; - } - } - if(col != 0) + if(col != 0) fputs(" ", stdout); + fprintf(stdout, "%02X", ii); + col++; + if(col == 8) { fputs("\n", stdout); + col = 0; } } - return result; - } - break; - - case CodeGenTarget::DXBytecodeAssembly: - { - String code = EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); - result.outputSource = code; - } - break; - - case CodeGenTarget::SPIRVAssembly: - { - String code = emitSPIRVAssemblyForEntryPoint(context, entryPoint); - result.outputSource = code; + if(col != 0) + { + fputs("\n", stdout); + } } - break; - - // Note(tfoley): We currently hit this case when compiling the stdlib - case CodeGenTarget::Unknown: - break; + return result; + } + break; - default: - throw "unimplemented"; + case CodeGenTarget::DXBytecodeAssembly: + { + String code = EmitDXBytecodeAssemblyForEntryPoint(context, entryPoint); + result.outputSource = code; } + break; - return result; + case CodeGenTarget::SPIRVAssembly: + { + String code = emitSPIRVAssemblyForEntryPoint(context, entryPoint); + result.outputSource = code; + } + break; + // Note(tfoley): We currently hit this case when compiling the stdlib + case CodeGenTarget::Unknown: + break; + default: + throw "unimplemented"; } - TranslationUnitResult emitTranslationUnitEntryPoints(ExtraContext& context) - { - TranslationUnitResult result; + return result; - for (auto& entryPoint : context.getTranslationUnitOptions().entryPoints) - { - EntryPointResult entryPointResult = emitEntryPoint(context, entryPoint); - result.entryPoints.Add(entryPointResult); - } + } - // The result for the translation unit will just be the concatenation - // of the results for each entry point. This doesn't actually make - // much sense, but it is good enough for now. - StringBuilder sb; - for (auto& entryPointResult : result.entryPoints) - { - sb << entryPointResult.outputSource; - } + TranslationUnitResult emitTranslationUnitEntryPoints(ExtraContext& context) + { + TranslationUnitResult result; - result.outputSource = sb.ProduceString(); + for (auto& entryPoint : context.getTranslationUnitOptions().entryPoints) + { + EntryPointResult entryPointResult = emitEntryPoint(context, entryPoint); - return result; + result.entryPoints.Add(entryPointResult); } - // Do emit logic for an entire translation unit, which might - // have zero or more entry points - TranslationUnitResult emitTranslationUnit(ExtraContext& context) + // The result for the translation unit will just be the concatenation + // of the results for each entry point. This doesn't actually make + // much sense, but it is good enough for now. + StringBuilder sb; + for (auto& entryPointResult : result.entryPoints) { - // Most of our code generation targets will require us - // to proceed through one entry point at a time, but - // in some cases we can emit an entire translation unit - // in one go. + sb << entryPointResult.outputSource; + } - switch (context.getOptions().Target) - { - default: - // The default behavior is going to loop over all the entry - // points, and then collect an aggregate result. - return emitTranslationUnitEntryPoints(context); + result.outputSource = sb.ProduceString(); - case CodeGenTarget::HLSL: - // When targetting HLSL, we can emit the entire translation unit - // as a single HLSL program, and include all the entry points. - { + return result; + } - String hlsl = EmitHLSL(context); + // Do emit logic for an entire translation unit, which might + // have zero or more entry points + TranslationUnitResult emitTranslationUnit(ExtraContext& context) + { + // Most of our code generation targets will require us + // to proceed through one entry point at a time, but + // in some cases we can emit an entire translation unit + // in one go. - TranslationUnitResult result; - result.outputSource = hlsl; + switch (context.getOptions().Target) + { + default: + // The default behavior is going to loop over all the entry + // points, and then collect an aggregate result. + return emitTranslationUnitEntryPoints(context); + + case CodeGenTarget::HLSL: + // When targetting HLSL, we can emit the entire translation unit + // as a single HLSL program, and include all the entry points. + { - // Because the user might ask for per-entry-point source, - // we will just attach the same string as the result for - // each entry point. - for( auto& entryPoint : context.getTranslationUnitOptions().entryPoints ) - { - (void)entryPoint; + String hlsl = EmitHLSL(context); - EntryPointResult entryPointResult; - entryPointResult.outputSource = hlsl; - result.entryPoints.Add(entryPointResult); - } + TranslationUnitResult result; + result.outputSource = hlsl; - return result; + // Because the user might ask for per-entry-point source, + // we will just attach the same string as the result for + // each entry point. + for( auto& entryPoint : context.getTranslationUnitOptions().entryPoints ) + { + (void)entryPoint; + + EntryPointResult entryPointResult; + entryPointResult.outputSource = hlsl; + result.entryPoints.Add(entryPointResult); } - break; + + return result; } + break; } + } - TranslationUnitResult generateOutput(ExtraContext& context) - { - TranslationUnitResult result = emitTranslationUnit(context); - return result; - } + TranslationUnitResult generateOutput(ExtraContext& context) + { + TranslationUnitResult result = emitTranslationUnit(context); + return result; + } - void generateOutput( - ExtraContext& context, - CollectionOfTranslationUnits* collectionOfTranslationUnits) + void generateOutput( + ExtraContext& context, + CollectionOfTranslationUnits* collectionOfTranslationUnits) + { + switch (context.getOptions().Target) { - switch (context.getOptions().Target) + default: + // For most targets, we will do things per-translation-unit + for( auto translationUnit : collectionOfTranslationUnits->translationUnits ) { - default: - // For most targets, we will do things per-translation-unit - for( auto translationUnit : collectionOfTranslationUnits->translationUnits ) - { - ExtraContext innerContext = context; - innerContext.translationUnitOptions = &translationUnit.options; - innerContext.programSyntax = translationUnit.SyntaxNode; - innerContext.sourcePath = "slang"; // don't have this any more! - innerContext.sourceText = ""; - - TranslationUnitResult translationUnitResult = generateOutput(innerContext); - context.compileResult->translationUnits.Add(translationUnitResult); - } - break; + ExtraContext innerContext = context; + innerContext.translationUnitOptions = &translationUnit.options; + innerContext.programSyntax = translationUnit.SyntaxNode; + innerContext.sourcePath = "slang"; // don't have this any more! + innerContext.sourceText = ""; + + TranslationUnitResult translationUnitResult = generateOutput(innerContext); + context.compileResult->translationUnits.Add(translationUnitResult); + } + break; - case CodeGenTarget::ReflectionJSON: - { - String reflectionJSON = emitReflectionJSON(context.programLayout); + case CodeGenTarget::ReflectionJSON: + { + String reflectionJSON = emitReflectionJSON(context.programLayout); - // HACK(tfoley): just print it out since that is what people probably expect. - // TODO: need a way to control where output gets routed across all possible targets. - fprintf(stdout, "%s", reflectionJSON.begin()); - } - break; + // HACK(tfoley): just print it out since that is what people probably expect. + // TODO: need a way to control where output gets routed across all possible targets. + fprintf(stdout, "%s", reflectionJSON.begin()); } + break; } + } - TranslationUnitResult passThrough( - String const& sourceText, - String const& sourcePath, - const CompileOptions & options, - TranslationUnitOptions const& translationUnitOptions) - { - ExtraContext extra; - extra.options = &options; - extra.translationUnitOptions = &translationUnitOptions; - extra.sourcePath = sourcePath; - extra.sourceText = sourceText; - - return generateOutput(extra); - } + TranslationUnitResult passThrough( + String const& sourceText, + String const& sourcePath, + const CompileOptions & options, + TranslationUnitOptions const& translationUnitOptions) + { + ExtraContext extra; + extra.options = &options; + extra.translationUnitOptions = &translationUnitOptions; + extra.sourcePath = sourcePath; + extra.sourceText = sourceText; + return generateOutput(extra); } + } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 01db7874f..fe69a7bf8 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -13,180 +13,176 @@ namespace Slang { - namespace Compiler + struct IncludeHandler; + struct CompileRequest; + + enum class CompilerMode + { + ProduceLibrary, + ProduceShader, + GenerateChoice + }; + + enum class StageTarget + { + Unknown, + VertexShader, + HullShader, + DomainShader, + GeometryShader, + FragmentShader, + ComputeShader, + }; + + enum class CodeGenTarget + { + Unknown = SLANG_TARGET_UNKNOWN, + GLSL = SLANG_GLSL, + GLSL_Vulkan = SLANG_GLSL_VULKAN, + GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC, + HLSL = SLANG_HLSL, + SPIRV = SLANG_SPIRV, + SPIRVAssembly = SLANG_SPIRV_ASM, + DXBytecode = SLANG_DXBC, + DXBytecodeAssembly = SLANG_DXBC_ASM, + ReflectionJSON = SLANG_REFLECTION_JSON, + }; + + // Describes an entry point that we've been requested to compile + struct EntryPointOption { - class ILConstOperand; - struct IncludeHandler; - struct CompileRequest; - - enum class CompilerMode - { - ProduceLibrary, - ProduceShader, - GenerateChoice - }; - - enum class StageTarget - { - Unknown, - VertexShader, - HullShader, - DomainShader, - GeometryShader, - FragmentShader, - ComputeShader, - }; - - enum class CodeGenTarget - { - Unknown = SLANG_TARGET_UNKNOWN, - GLSL = SLANG_GLSL, - GLSL_Vulkan = SLANG_GLSL_VULKAN, - GLSL_Vulkan_OneDesc = SLANG_GLSL_VULKAN_ONE_DESC, - HLSL = SLANG_HLSL, - SPIRV = SLANG_SPIRV, - SPIRVAssembly = SLANG_SPIRV_ASM, - DXBytecode = SLANG_DXBC, - DXBytecodeAssembly = SLANG_DXBC_ASM, - ReflectionJSON = SLANG_REFLECTION_JSON, - }; - - // Describes an entry point that we've been requested to compile - struct EntryPointOption - { - String name; - Profile profile; - }; - - enum class PassThroughMode : SlangPassThrough - { - None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler - HLSL = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API + String name; + Profile profile; + }; + + enum class PassThroughMode : SlangPassThrough + { + None = SLANG_PASS_THROUGH_NONE, // don't pass through: use Slang compiler + HLSL = SLANG_PASS_THROUGH_FXC, // pass through HLSL to `D3DCompile` API // GLSL, // pass through GLSL to `glslang` library - }; - - // Represents a single source file (either an on-disk file, or a - // "virtual" file passed in as a string) - class SourceFile : public RefObject - { - public: - // The file path for a real file, or the nominal path for a virtual file - String path; - - // The actual contents of the file - String content; - }; - - // Options for a single translation unit being requested by the user - class TranslationUnitOptions - { - public: - SourceLanguage sourceLanguage = SourceLanguage::Unknown; - - // All entry points we've been asked to compile for this translation unit - List<EntryPointOption> entryPoints; - - // The source file(s) that will be compiled to form this translation unit - List<RefPtr<SourceFile> > sourceFiles; - - // Preprocessor definitions to use for this translation unit only - // (whereas the ones on `CompileOptions` will be shared) - Dictionary<String, String> preprocessorDefinitions; - }; - - class CompileOptions - { - public: - // What target language are we compiling to? - CodeGenTarget Target = CodeGenTarget::Unknown; - - // Directories to search for `#include` files or `import`ed modules - List<String> SearchDirectories; - - // Definitions to provide during preprocessing - Dictionary<String, String> preprocessorDefinitions; - - // Translation units we are being asked to compile - List<TranslationUnitOptions> translationUnits; - - // The code generation profile we've been asked to use. - Profile profile; - - // Should we just pass the input to another compiler? - PassThroughMode passThrough = PassThroughMode::None; - - // Flags supplied through the API - SlangCompileFlags flags = 0; - }; - - // This is the representation of a given translation unit - class CompileUnit - { - public: - TranslationUnitOptions options; - RefPtr<ProgramSyntaxNode> SyntaxNode; - }; - - // TODO: pick an appropriate name for this... - class CollectionOfTranslationUnits : public RefObject - { - public: - List<CompileUnit> translationUnits; - - // TODO: this is more output-oriented, but maybe okay to have here... - RefPtr<ProgramLayout> layout; - }; - - // Context information for code generation - struct ExtraContext - { - CompileOptions const* options = nullptr; - TranslationUnitOptions const* translationUnitOptions = nullptr; - - CompileResult* compileResult = nullptr; - - RefPtr<ProgramSyntaxNode> programSyntax; - ProgramLayout* programLayout; - - String sourceText; - String sourcePath; - - CompileOptions const& getOptions() { return *options; } - TranslationUnitOptions const& getTranslationUnitOptions() { return *translationUnitOptions; } - }; + }; -#if 0 + // Represents a single source file (either an on-disk file, or a + // "virtual" file passed in as a string) + class SourceFile : public RefObject + { + public: + // The file path for a real file, or the nominal path for a virtual file + String path; - class ShaderCompiler : public CoreLib::Basic::Object - { - public: - virtual void Compile( - CompileResult& result, - CollectionOfTranslationUnits* collectionOfTranslationUnits, - const CompileOptions& options, - CompileRequest* request) = 0; + // The actual contents of the file + String content; + }; - virtual TranslationUnitResult PassThrough( - String const& sourceText, - String const& sourcePath, - const CompileOptions & options, - TranslationUnitOptions const& translationUnitOptions) = 0; + // Options for a single translation unit being requested by the user + class TranslationUnitOptions + { + public: + SourceLanguage sourceLanguage = SourceLanguage::Unknown; - }; + // All entry points we've been asked to compile for this translation unit + List<EntryPointOption> entryPoints; - ShaderCompiler * CreateShaderCompiler(); -#endif + // The source file(s) that will be compiled to form this translation unit + List<RefPtr<SourceFile> > sourceFiles; + + // Preprocessor definitions to use for this translation unit only + // (whereas the ones on `CompileOptions` will be shared) + Dictionary<String, String> preprocessorDefinitions; + }; + + class CompileOptions + { + public: + // What target language are we compiling to? + CodeGenTarget Target = CodeGenTarget::Unknown; + + // Directories to search for `#include` files or `import`ed modules + List<String> SearchDirectories; + + // Definitions to provide during preprocessing + Dictionary<String, String> preprocessorDefinitions; + + // Translation units we are being asked to compile + List<TranslationUnitOptions> translationUnits; + + // The code generation profile we've been asked to use. + Profile profile; + + // Should we just pass the input to another compiler? + PassThroughMode passThrough = PassThroughMode::None; - TranslationUnitResult passThrough( + // Flags supplied through the API + SlangCompileFlags flags = 0; + }; + + // This is the representation of a given translation unit + class CompileUnit + { + public: + TranslationUnitOptions options; + RefPtr<ProgramSyntaxNode> SyntaxNode; + }; + + // TODO: pick an appropriate name for this... + class CollectionOfTranslationUnits : public RefObject + { + public: + List<CompileUnit> translationUnits; + + // TODO: this is more output-oriented, but maybe okay to have here... + RefPtr<ProgramLayout> layout; + }; + + // Context information for code generation + struct ExtraContext + { + CompileOptions const* options = nullptr; + TranslationUnitOptions const* translationUnitOptions = nullptr; + + CompileResult* compileResult = nullptr; + + RefPtr<ProgramSyntaxNode> programSyntax; + ProgramLayout* programLayout; + + String sourceText; + String sourcePath; + + CompileOptions const& getOptions() { return *options; } + TranslationUnitOptions const& getTranslationUnitOptions() { return *translationUnitOptions; } + }; + +#if 0 + + class ShaderCompiler : public CoreLib::Basic::Object + { + public: + virtual void Compile( + CompileResult& result, + CollectionOfTranslationUnits* collectionOfTranslationUnits, + const CompileOptions& options, + CompileRequest* request) = 0; + + virtual TranslationUnitResult PassThrough( String const& sourceText, String const& sourcePath, const CompileOptions & options, - TranslationUnitOptions const& translationUnitOptions); + TranslationUnitOptions const& translationUnitOptions) = 0; + + }; + + ShaderCompiler * CreateShaderCompiler(); +#endif + + TranslationUnitResult passThrough( + String const& sourceText, + String const& sourcePath, + const CompileOptions & options, + TranslationUnitOptions const& translationUnitOptions); - void generateOutput( - ExtraContext& context, - CollectionOfTranslationUnits* collectionOfTranslationUnits); - } + void generateOutput( + ExtraContext& context, + CollectionOfTranslationUnits* collectionOfTranslationUnits); } #endif
\ No newline at end of file diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp index d8527466b..11f54e096 100644 --- a/source/slang/diagnostics.cpp +++ b/source/slang/diagnostics.cpp @@ -15,7 +15,6 @@ #endif namespace Slang { -namespace Compiler { void printDiagnosticArg(StringBuilder& sb, char const* str) { @@ -201,4 +200,4 @@ namespace Diagnostics } -}} // namespace Slang::Compiler +} // namespace Slang diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h index c1559df5d..c4a9532ce 100644 --- a/source/slang/diagnostics.h +++ b/source/slang/diagnostics.h @@ -10,206 +10,203 @@ namespace Slang { - namespace Compiler - { - using namespace CoreLib::Basic; + using namespace CoreLib::Basic; - enum class Severity + enum class Severity + { + Note, + Warning, + Error, + Fatal, + Internal, + }; + + // TODO(tfoley): move this into a source file... + inline const char* getSeverityName(Severity severity) + { + switch (severity) { - Note, - Warning, - Error, - Fatal, - Internal, - }; + case Severity::Note: return "note"; + case Severity::Warning: return "warning"; + case Severity::Error: return "error"; + case Severity::Fatal: return "fatal error"; + case Severity::Internal: return "internal error"; + default: return "unknown error"; + } + } - // TODO(tfoley): move this into a source file... - inline const char* getSeverityName(Severity severity) + // A structure to be used in static data describing different + // diagnostic messages. + struct DiagnosticInfo + { + int id; + Severity severity; + char const* messageFormat; + }; + + class Diagnostic + { + public: + String Message; + CodePosition Position; + int ErrorID; + Severity severity; + + Diagnostic() { - switch (severity) - { - case Severity::Note: return "note"; - case Severity::Warning: return "warning"; - case Severity::Error: return "error"; - case Severity::Fatal: return "fatal error"; - case Severity::Internal: return "internal error"; - default: return "unknown error"; - } + ErrorID = -1; } - - // A structure to be used in static data describing different - // diagnostic messages. - struct DiagnosticInfo + Diagnostic( + const String & msg, + int id, + const CodePosition & pos, + Severity severity) + : severity(severity) { - int id; - Severity severity; - char const* messageFormat; - }; + Message = msg; + ErrorID = id; + Position = pos; + } + }; + + class Decl; + class Type; + class ExpressionType; + class ILType; + class StageAttribute; + struct TypeExp; + struct QualType; + + void printDiagnosticArg(StringBuilder& sb, char const* str); + void printDiagnosticArg(StringBuilder& sb, int val); + void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str); + void printDiagnosticArg(StringBuilder& sb, Decl* decl); + void printDiagnosticArg(StringBuilder& sb, Type* type); + void printDiagnosticArg(StringBuilder& sb, ExpressionType* type); + void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); + void printDiagnosticArg(StringBuilder& sb, QualType const& type); + void printDiagnosticArg(StringBuilder& sb, TokenType tokenType); + void printDiagnosticArg(StringBuilder& sb, Token const& token); + + template<typename T> + void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr) + { + printDiagnosticArg(sb, ptr.Ptr()); + } + + inline CodePosition const& getDiagnosticPos(CodePosition const& pos) { return pos; } + + class SyntaxNode; + class ShaderClosure; + CodePosition const& getDiagnosticPos(SyntaxNode const* syntax); + CodePosition const& getDiagnosticPos(Token const& token); + CodePosition const& getDiagnosticPos(TypeExp const& typeExp); + + template<typename T> + CodePosition getDiagnosticPos(RefPtr<T> const& ptr) + { + return getDiagnosticPos(ptr.Ptr()); + } - class Diagnostic + struct DiagnosticArg + { + void* data; + void (*printFunc)(StringBuilder&, void*); + + template<typename T> + struct Helper { - public: - String Message; - CodePosition Position; - int ErrorID; - Severity severity; - - Diagnostic() - { - ErrorID = -1; - } - Diagnostic( - const String & msg, - int id, - const CodePosition & pos, - Severity severity) - : severity(severity) - { - Message = msg; - ErrorID = id; - Position = pos; - } + static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); } }; - class Decl; - class Type; - class ExpressionType; - class ILType; - class StageAttribute; - struct TypeExp; - struct QualType; - - void printDiagnosticArg(StringBuilder& sb, char const* str); - void printDiagnosticArg(StringBuilder& sb, int val); - void printDiagnosticArg(StringBuilder& sb, CoreLib::Basic::String const& str); - void printDiagnosticArg(StringBuilder& sb, Decl* decl); - void printDiagnosticArg(StringBuilder& sb, Type* type); - void printDiagnosticArg(StringBuilder& sb, ExpressionType* type); - void printDiagnosticArg(StringBuilder& sb, TypeExp const& type); - void printDiagnosticArg(StringBuilder& sb, QualType const& type); - void printDiagnosticArg(StringBuilder& sb, TokenType tokenType); - void printDiagnosticArg(StringBuilder& sb, Token const& token); - template<typename T> - void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr) + DiagnosticArg(T const& arg) + : data((void*)&arg) + , printFunc(&Helper<T>::printFunc) + {} + }; + + class DiagnosticSink + { + public: + StringBuilder outputBuffer; +// List<Diagnostic> diagnostics; + int errorCount = 0; + + SlangDiagnosticCallback callback = nullptr; + void* callbackUserData = nullptr; + +/* + void Error(int id, const String & msg, const CodePosition & pos) { - printDiagnosticArg(sb, ptr.Ptr()); + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error)); + errorCount++; } - inline CodePosition const& getDiagnosticPos(CodePosition const& pos) { return pos; } - - class SyntaxNode; - class ShaderClosure; - CodePosition const& getDiagnosticPos(SyntaxNode const* syntax); - CodePosition const& getDiagnosticPos(Token const& token); - CodePosition const& getDiagnosticPos(TypeExp const& typeExp); + void Warning(int id, const String & msg, const CodePosition & pos) + { + diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning)); + } +*/ + int GetErrorCount() { return errorCount; } - template<typename T> - CodePosition getDiagnosticPos(RefPtr<T> const& ptr) + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info) { - return getDiagnosticPos(ptr.Ptr()); + diagnoseImpl(pos, info, 0, NULL); } - struct DiagnosticArg + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0) { - void* data; - void (*printFunc)(StringBuilder&, void*); - - template<typename T> - struct Helper - { - static void printFunc(StringBuilder& sb, void* data) { printDiagnosticArg(sb, *(T*)data); } - }; - - template<typename T> - DiagnosticArg(T const& arg) - : data((void*)&arg) - , printFunc(&Helper<T>::printFunc) - {} - }; + DiagnosticArg const* args[] = { &arg0 }; + diagnoseImpl(pos, info, 1, args); + } - class DiagnosticSink + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1) { - public: - StringBuilder outputBuffer; -// List<Diagnostic> diagnostics; - int errorCount = 0; + DiagnosticArg const* args[] = { &arg0, &arg1 }; + diagnoseImpl(pos, info, 2, args); + } - SlangDiagnosticCallback callback = nullptr; - void* callbackUserData = nullptr; + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 }; + diagnoseImpl(pos, info, 3, args); + } -/* - void Error(int id, const String & msg, const CodePosition & pos) - { - diagnostics.Add(Diagnostic(msg, id, pos, Severity::Error)); - errorCount++; - } - - void Warning(int id, const String & msg, const CodePosition & pos) - { - diagnostics.Add(Diagnostic(msg, id, pos, Severity::Warning)); - } -*/ - int GetErrorCount() { return errorCount; } - - void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info) - { - diagnoseImpl(pos, info, 0, NULL); - } - - void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0) - { - DiagnosticArg const* args[] = { &arg0 }; - diagnoseImpl(pos, info, 1, args); - } - - void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1) - { - DiagnosticArg const* args[] = { &arg0, &arg1 }; - diagnoseImpl(pos, info, 2, args); - } - - void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2) - { - DiagnosticArg const* args[] = { &arg0, &arg1, &arg2 }; - diagnoseImpl(pos, info, 3, args); - } - - void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3) - { - DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 }; - diagnoseImpl(pos, info, 4, args); - } - - template<typename P, typename... Args> - void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args ) - { - diagnoseDispatch(getDiagnosticPos(pos), info, args...); - } - - void diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args); - }; + void diagnoseDispatch(CodePosition const& pos, DiagnosticInfo const& info, DiagnosticArg const& arg0, DiagnosticArg const& arg1, DiagnosticArg const& arg2, DiagnosticArg const& arg3) + { + DiagnosticArg const* args[] = { &arg0, &arg1, &arg2, &arg3 }; + diagnoseImpl(pos, info, 4, args); + } - namespace Diagnostics + template<typename P, typename... Args> + void diagnose(P const& pos, DiagnosticInfo const& info, Args const&... args ) { + diagnoseDispatch(getDiagnosticPos(pos), info, args...); + } + + void diagnoseImpl(CodePosition const& pos, DiagnosticInfo const& info, int argCount, DiagnosticArg const* const* args); + }; + + namespace Diagnostics + { #define DIAGNOSTIC(id, severity, name, messageFormat) extern const DiagnosticInfo name; #include "diagnostic-defs.h" - } } } #ifdef _DEBUG #define SLANG_INTERNAL_ERROR(sink, pos) \ - (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::internalCompilerError) + (sink)->diagnose(Slang::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::internalCompilerError) #define SLANG_UNIMPLEMENTED(sink, pos, what) \ - (sink)->diagnose(Slang::Compiler::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Compiler::Diagnostics::unimplemented, what) + (sink)->diagnose(Slang::CodePosition(__LINE__, 0, 0, __FILE__), Slang::Diagnostics::unimplemented, what) #define SLANG_UNREACHABLE(msg) do { assert(!"ureachable code:" msg); exit(1); } while(0) #else #define SLANG_INTERNAL_ERROR(sink, pos) \ - (sink)->diagnose(pos, Slang::Compiler::Diagnostics::internalCompilerError) + (sink)->diagnose(pos, Slang::Diagnostics::internalCompilerError) #define SLANG_UNIMPLEMENTED(sink, pos, what) \ - (sink)->diagnose(pos, Slang::Compiler::Diagnostics::unimplemented, what) + (sink)->diagnose(pos, Slang::Diagnostics::unimplemented, what) // TODO: find something that will perform better #define SLANG_UNREACHABLE(msg) exit(1) diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index f11757594..69bdaebf6 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -11,7 +11,7 @@ #pragma warning(disable:4996) #endif -namespace Slang { namespace Compiler { +namespace Slang { struct EmitContext { @@ -2748,4 +2748,4 @@ String emitProgram( } -}} // Slang::Compiler +} // namespace Slang diff --git a/source/slang/emit.h b/source/slang/emit.h index 05ea1550f..1cb8d2d81 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -8,17 +8,14 @@ namespace Slang { - namespace Compiler - { - using namespace CoreLib::Basic; + using namespace CoreLib::Basic; - class ProgramSyntaxNode; - class ProgramLayout; + class ProgramSyntaxNode; + class ProgramLayout; - String emitProgram( - ProgramSyntaxNode* program, - ProgramLayout* programLayout, - CodeGenTarget target); - } + String emitProgram( + ProgramSyntaxNode* program, + ProgramLayout* programLayout, + CodeGenTarget target); } #endif diff --git a/source/slang/lexer.cpp b/source/slang/lexer.cpp index 87b3eaf63..cb718b538 100644 --- a/source/slang/lexer.cpp +++ b/source/slang/lexer.cpp @@ -4,397 +4,402 @@ namespace Slang { - namespace Compiler + static Token GetEndOfFileToken() { - static Token GetEndOfFileToken() - { - return Token(TokenType::EndOfFile, "", 0, 0, 0, ""); - } + return Token(TokenType::EndOfFile, "", 0, 0, 0, ""); + } - Token* TokenList::begin() const - { - assert(mTokens.Count()); - return &mTokens[0]; - } + Token* TokenList::begin() const + { + assert(mTokens.Count()); + return &mTokens[0]; + } - Token* TokenList::end() const - { - assert(mTokens.Count()); - assert(mTokens[mTokens.Count()-1].Type == TokenType::EndOfFile); - return &mTokens[mTokens.Count() - 1]; - } + Token* TokenList::end() const + { + assert(mTokens.Count()); + assert(mTokens[mTokens.Count()-1].Type == TokenType::EndOfFile); + return &mTokens[mTokens.Count() - 1]; + } - TokenSpan::TokenSpan() - : mBegin(NULL) - , mEnd (NULL) - {} + TokenSpan::TokenSpan() + : mBegin(NULL) + , mEnd (NULL) + {} - TokenReader::TokenReader() - : mCursor(NULL) - , mEnd (NULL) - {} + TokenReader::TokenReader() + : mCursor(NULL) + , mEnd (NULL) + {} - Token TokenReader::PeekToken() const - { - if (!mCursor) - return GetEndOfFileToken(); + Token TokenReader::PeekToken() const + { + if (!mCursor) + return GetEndOfFileToken(); - Token token = *mCursor; - if (mCursor == mEnd) - token.Type = TokenType::EndOfFile; - return token; - } + Token token = *mCursor; + if (mCursor == mEnd) + token.Type = TokenType::EndOfFile; + return token; + } - TokenType TokenReader::PeekTokenType() const - { - if (mCursor == mEnd) - return TokenType::EndOfFile; - assert(mCursor); - return mCursor->Type; - } + TokenType TokenReader::PeekTokenType() const + { + if (mCursor == mEnd) + return TokenType::EndOfFile; + assert(mCursor); + return mCursor->Type; + } - CodePosition TokenReader::PeekLoc() const - { - if (!mCursor) - return CodePosition(); - assert(mCursor); - return mCursor->Position; - } + CodePosition TokenReader::PeekLoc() const + { + if (!mCursor) + return CodePosition(); + assert(mCursor); + return mCursor->Position; + } - Token TokenReader::AdvanceToken() - { - if (!mCursor) - return GetEndOfFileToken(); + Token TokenReader::AdvanceToken() + { + if (!mCursor) + return GetEndOfFileToken(); + + Token token = *mCursor; + if (mCursor == mEnd) + token.Type = TokenType::EndOfFile; + else + mCursor++; + return token; + } - Token token = *mCursor; - if (mCursor == mEnd) - token.Type = TokenType::EndOfFile; - else - mCursor++; - return token; - } + // Lexer - // Lexer + Lexer::Lexer( + String const& path, + String const& content, + DiagnosticSink* sink) + : path(path) + , content(content) + , sink(sink) + { + cursor = content.begin(); + end = content.end(); - Lexer::Lexer( - String const& path, - String const& content, - DiagnosticSink* sink) - : path(path) - , content(content) - , sink(sink) - { - cursor = content.begin(); - end = content.end(); + loc = CodePosition(1, 1, 0, path); + tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + lexerFlags = 0; + } - loc = CodePosition(1, 1, 0, path); - tokenFlags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - lexerFlags = 0; - } + Lexer::~Lexer() + { + } - Lexer::~Lexer() - { - } + enum { kEOF = -1 }; - enum { kEOF = -1 }; + // Get the next input byte, without any handling of + // escaped newlines, non-ASCII code points, source locations, etc. + static int peekRaw(Lexer* lexer) + { + // If we are at the end of the input, return a designated end-of-file value + if(lexer->cursor == lexer->end) + return kEOF; - // Get the next input byte, without any handling of - // escaped newlines, non-ASCII code points, source locations, etc. - static int peekRaw(Lexer* lexer) - { - // If we are at the end of the input, return a designated end-of-file value - if(lexer->cursor == lexer->end) - return kEOF; + // Otherwise, just look at the next byte + return *lexer->cursor; + } - // Otherwise, just look at the next byte - return *lexer->cursor; - } + // Read one input byte without any special handling (similar to `peekRaw`) + static int advanceRaw(Lexer* lexer) + { + // The logic here is basically the same as for `peekRaw()`, + // escape we advance `cursor` if we aren't at the end. - // Read one input byte without any special handling (similar to `peekRaw`) - static int advanceRaw(Lexer* lexer) - { - // The logic here is basically the same as for `peekRaw()`, - // escape we advance `cursor` if we aren't at the end. + if (lexer->cursor == lexer->end) + return kEOF; - if (lexer->cursor == lexer->end) - return kEOF; + return *lexer->cursor++; + } - return *lexer->cursor++; - } + // When the cursor is already at the first byte of an end-of-line sequence, + // consume one or two bytes that compose the sequence. + // + // Basically, a newline is one of: + // + // "\n" + // "\r" + // "\r\n" + // "\n\r" + // + // We always look for the longest match possible. + // + static void handleNewLineInner(Lexer* lexer, int c) + { + assert(c == '\n' || c == '\r'); - // When the cursor is already at the first byte of an end-of-line sequence, - // consume one or two bytes that compose the sequence. - // - // Basically, a newline is one of: - // - // "\n" - // "\r" - // "\r\n" - // "\n\r" - // - // We always look for the longest match possible. - // - static void handleNewLineInner(Lexer* lexer, int c) + int d = peekRaw(lexer); + if( (c ^ d) == ('\n' ^ '\r') ) { - assert(c == '\n' || c == '\r'); + advanceRaw(lexer); + } + + lexer->loc.Line++; + lexer->loc.Col = 1; + } - int d = peekRaw(lexer); - if( (c ^ d) == ('\n' ^ '\r') ) + // Look ahead one code point, dealing with complications like + // escaped newlines. + static int peek(Lexer* lexer) + { + // Look at the next raw byte, and decide what to do + int c = peekRaw(lexer); + + if(c == '\\') + { + // We might have a backslash-escaped newline. + // Look at the next byte (if any) to see. + // + // Note(tfoley): We are assuming a null-terminated input here, + // so that we can safely look at the next byte without issue. + int d = lexer->cursor[1]; + switch (d) { - advanceRaw(lexer); - } + case '\r': case '\n': + { + // The newline was escaped, so return the code point after *that* - lexer->loc.Line++; - lexer->loc.Col = 1; + int e = lexer->cursor[2]; + if ((d ^ e) == ('\r' ^ '\n')) + return lexer->cursor[3]; + return e; + } + + default: + break; + } } + // TODO: handle UTF-8 encoding for non-ASCII code points here + + // Default case is to just hand along the byte we read as an ASCII code point. + return c; + } - // Look ahead one code point, dealing with complications like - // escaped newlines. - static int peek(Lexer* lexer) + // Get the next code point from the input, and advance the cursor. + static int advance(Lexer* lexer) + { + // We are going to loop, but only as a way of handling + // escaped line endings. + for (;;) { + // If we are at the end of the input, then the task is easy. + if (lexer->cursor == lexer->end) + return kEOF; + // Look at the next raw byte, and decide what to do - int c = peekRaw(lexer); + int c = *lexer->cursor++; - if(c == '\\') + if (c == '\\') { // We might have a backslash-escaped newline. // Look at the next byte (if any) to see. // // Note(tfoley): We are assuming a null-terminated input here, // so that we can safely look at the next byte without issue. - int d = lexer->cursor[1]; + int d = *lexer->cursor; switch (d) { case '\r': case '\n': - { - // The newline was escaped, so return the code point after *that* + // handle the end-of-line for our source location tracking + lexer->cursor++; + handleNewLineInner(lexer, d); - int e = lexer->cursor[2]; - if ((d ^ e) == ('\r' ^ '\n')) - return lexer->cursor[3]; - return e; - } + // Now try again, looking at the character after the + // escaped nmewline. + continue; default: break; } } - // TODO: handle UTF-8 encoding for non-ASCII code points here - - // Default case is to just hand along the byte we read as an ASCII code point. - return c; - } - - // Get the next code point from the input, and advance the cursor. - static int advance(Lexer* lexer) - { - // We are going to loop, but only as a way of handling - // escaped line endings. - for (;;) - { - // If we are at the end of the input, then the task is easy. - if (lexer->cursor == lexer->end) - return kEOF; - - // Look at the next raw byte, and decide what to do - int c = *lexer->cursor++; - if (c == '\\') - { - // We might have a backslash-escaped newline. - // Look at the next byte (if any) to see. - // - // Note(tfoley): We are assuming a null-terminated input here, - // so that we can safely look at the next byte without issue. - int d = *lexer->cursor; - switch (d) - { - case '\r': case '\n': - // handle the end-of-line for our source location tracking - lexer->cursor++; - handleNewLineInner(lexer, d); - - // Now try again, looking at the character after the - // escaped nmewline. - continue; + // TODO: Need to handle non-ASCII code points. - default: - break; - } - } + // Default case is to advance by one location + // and return the raw byte we saw. - // TODO: Need to handle non-ASCII code points. + lexer->loc.Col++; + lexer->loc.Pos++; - // Default case is to advance by one location - // and return the raw byte we saw. - - lexer->loc.Col++; - lexer->loc.Pos++; - - return c; - } + return c; } + } - static void handleNewLine(Lexer* lexer) - { - int c = advance(lexer); - handleNewLineInner(lexer, c); - } + static void handleNewLine(Lexer* lexer) + { + int c = advance(lexer); + handleNewLineInner(lexer, c); + } - static void lexLineComment(Lexer* lexer) + static void lexLineComment(Lexer* lexer) + { + for(;;) { - for(;;) + switch(peek(lexer)) { - switch(peek(lexer)) - { - case '\n': case '\r': case kEOF: - return; + case '\n': case '\r': case kEOF: + return; - default: - advance(lexer); - continue; - } + default: + advance(lexer); + continue; } } + } - static void lexBlockComment(Lexer* lexer) + static void lexBlockComment(Lexer* lexer) + { + for(;;) { - for(;;) + switch(peek(lexer)) { - switch(peek(lexer)) - { - case kEOF: - // TODO(tfoley) diagnostic! - return; + case kEOF: + // TODO(tfoley) diagnostic! + return; - case '\n': case '\r': - handleNewLine(lexer); - continue; + case '\n': case '\r': + handleNewLine(lexer); + continue; - case '*': + case '*': + advance(lexer); + switch( peek(lexer) ) + { + case '/': advance(lexer); - switch( peek(lexer) ) - { - case '/': - advance(lexer); - return; - - default: - continue; - } + return; default: - advance(lexer); continue; } + + default: + advance(lexer); + continue; } } + } - static void lexHorizontalSpace(Lexer* lexer) + static void lexHorizontalSpace(Lexer* lexer) + { + for(;;) { - for(;;) + switch(peek(lexer)) { - switch(peek(lexer)) - { - case ' ': case '\t': - advance(lexer); - continue; + case ' ': case '\t': + advance(lexer); + continue; - default: - return; - } + default: + return; } } + } - static void lexIdentifier(Lexer* lexer) + static void lexIdentifier(Lexer* lexer) + { + for(;;) { - for(;;) + int c = peek(lexer); + if(('a' <= c ) && (c <= 'z') + || ('A' <= c) && (c <= 'Z') + || ('0' <= c) && (c <= '9') + || (c == '_')) { - int c = peek(lexer); - if(('a' <= c ) && (c <= 'z') - || ('A' <= c) && (c <= 'Z') - || ('0' <= c) && (c <= '9') - || (c == '_')) - { - advance(lexer); - continue; - } - - return; + advance(lexer); + continue; } + + return; } + } - static void lexDigits(Lexer* lexer, int base) + static void lexDigits(Lexer* lexer, int base) + { + for(;;) { - for(;;) - { - int c = peek(lexer); - - int digitVal = 0; - switch(c) - { - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - digitVal = c - '0'; - break; + int c = peek(lexer); - case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': - if(base <= 10) return; - digitVal = 10 + c - 'a'; - break; + int digitVal = 0; + switch(c) + { + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + digitVal = c - '0'; + break; - case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': - if(base <= 10) return; - digitVal = 10 + c - 'A'; - break; + case 'a': case 'b': case 'c': case 'd': case 'e': case 'f': + if(base <= 10) return; + digitVal = 10 + c - 'a'; + break; - default: - // Not more digits! - return; - } + case 'A': case 'B': case 'C': case 'D': case 'E': case 'F': + if(base <= 10) return; + digitVal = 10 + c - 'A'; + break; - if(digitVal >= base) - { - char buffer[] = { (char) c, 0 }; - lexer->sink->diagnose(lexer->loc, Diagnostics::invalidDigitForBase, buffer, base); - } + default: + // Not more digits! + return; + } - advance(lexer); + if(digitVal >= base) + { + char buffer[] = { (char) c, 0 }; + lexer->sink->diagnose(lexer->loc, Diagnostics::invalidDigitForBase, buffer, base); } + + advance(lexer); } + } - static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType) + static TokenType maybeLexNumberSuffix(Lexer* lexer, TokenType tokenType) + { + // First check for suffixes that + // indicate a floating-point number + switch(peek(lexer)) { - // First check for suffixes that - // indicate a floating-point number - switch(peek(lexer)) - { - case 'f': case 'F': - advance(lexer); - return TokenType::DoubleLiterial; + case 'f': case 'F': + advance(lexer); + return TokenType::DoubleLiterial; - default: - break; - } + default: + break; + } - // Once we've ruled out floating-point - // suffixes, we can check for the inter cases + // Once we've ruled out floating-point + // suffixes, we can check for the inter cases - // TODO: allow integer suffixes in any order... + // TODO: allow integer suffixes in any order... - // Leading `u` or `U` for unsigned - switch(peek(lexer)) - { - default: - break; + // Leading `u` or `U` for unsigned + switch(peek(lexer)) + { + default: + break; - case 'u': case 'U': - advance(lexer); - break; - } + case 'u': case 'U': + advance(lexer); + break; + } + + // Optional `l`, `L`, `ll`, or `LL` + switch(peek(lexer)) + { + default: + break; - // Optional `l`, `L`, `ll`, or `LL` + case 'l': case 'L': + advance(lexer); switch(peek(lexer)) { default: @@ -402,720 +407,712 @@ namespace Slang case 'l': case 'L': advance(lexer); - switch(peek(lexer)) - { - default: - break; - - case 'l': case 'L': - advance(lexer); - break; - } break; } + break; + } + + return tokenType; + } - return tokenType; + static bool maybeLexNumberExponent(Lexer* lexer, int base) + { + switch( peek(lexer) ) + { + default: + return false; + + case 'e': case 'E': + if(base != 10) return false; + advance(lexer); + break; + + case 'p': case 'P': + if(base != 16) return false; + advance(lexer); + break; } - static bool maybeLexNumberExponent(Lexer* lexer, int base) + // we saw an exponent marker, so we must + switch( peek(lexer) ) { - switch( peek(lexer) ) - { - default: - return false; + case '+': case '-': + advance(lexer); + break; + } - case 'e': case 'E': - if(base != 10) return false; - advance(lexer); - break; + // TODO(tfoley): it would be an error to not see digits here... - case 'p': case 'P': - if(base != 16) return false; - advance(lexer); - break; - } + lexDigits(lexer, 10); - // we saw an exponent marker, so we must - switch( peek(lexer) ) - { - case '+': case '-': - advance(lexer); - break; - } + return true; + } - // TODO(tfoley): it would be an error to not see digits here... + static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base) + { + lexDigits(lexer, base); + maybeLexNumberExponent(lexer, base); + + return maybeLexNumberSuffix(lexer, TokenType::DoubleLiterial); + } + + static TokenType lexNumber(Lexer* lexer, int base) + { + // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character. - lexDigits(lexer, 10); + TokenType tokenType = TokenType::IntLiterial; - return true; - } + // At the start of things, we just concern ourselves with digits + lexDigits(lexer, base); - static TokenType lexNumberAfterDecimalPoint(Lexer* lexer, int base) + if( peek(lexer) == '.' ) { + tokenType = TokenType::DoubleLiterial; + + advance(lexer); lexDigits(lexer, base); - maybeLexNumberExponent(lexer, base); - - return maybeLexNumberSuffix(lexer, TokenType::DoubleLiterial); } - static TokenType lexNumber(Lexer* lexer, int base) + if( maybeLexNumberExponent(lexer, base)) { - // TODO(tfoley): Need to consider whehter to allow any kind of digit separator character. - - TokenType tokenType = TokenType::IntLiterial; + tokenType = TokenType::DoubleLiterial; + } - // At the start of things, we just concern ourselves with digits - lexDigits(lexer, base); + maybeLexNumberSuffix(lexer, tokenType); + return tokenType; + } - if( peek(lexer) == '.' ) + static void lexStringLiteralBody(Lexer* lexer, char quote) + { + for(;;) + { + int c = peek(lexer); + if(c == quote) { - tokenType = TokenType::DoubleLiterial; - advance(lexer); - lexDigits(lexer, base); + return; } - if( maybeLexNumberExponent(lexer, base)) + switch(c) { - tokenType = TokenType::DoubleLiterial; - } + case kEOF: + lexer->sink->diagnose(lexer->loc, Diagnostics::endOfFileInLiteral); + return; - maybeLexNumberSuffix(lexer, tokenType); - return tokenType; - } + case '\n': case '\r': + lexer->sink->diagnose(lexer->loc, Diagnostics::newlineInLiteral); + return; - static void lexStringLiteralBody(Lexer* lexer, char quote) - { - for(;;) - { - int c = peek(lexer); - if(c == quote) + case '\\': + // Need to handle various escape sequence cases + advance(lexer); + switch(peek(lexer)) { + case '\'': + case '\"': + case '\\': + case '?': + case 'a': + case 'b': + case 'f': + case 'n': + case 'r': + case 't': + case 'v': advance(lexer); - return; - } - - switch(c) - { - case kEOF: - lexer->sink->diagnose(lexer->loc, Diagnostics::endOfFileInLiteral); - return; - - case '\n': case '\r': - lexer->sink->diagnose(lexer->loc, Diagnostics::newlineInLiteral); - return; + break; - case '\\': - // Need to handle various escape sequence cases + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': + // octal escape: up to 3 characters advance(lexer); - switch(peek(lexer)) + for(int ii = 0; ii < 3; ++ii) { - case '\'': - case '\"': - case '\\': - case '?': - case 'a': - case 'b': - case 'f': - case 'n': - case 'r': - case 't': - case 'v': - advance(lexer); - break; - - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': - // octal escape: up to 3 characters - advance(lexer); - for(int ii = 0; ii < 3; ++ii) + int d = peek(lexer); + if(('0' <= d) && (d <= '7')) { - int d = peek(lexer); - if(('0' <= d) && (d <= '7')) - { - advance(lexer); - continue; - } - else - { - break; - } + advance(lexer); + continue; } - break; - - case 'x': - // hexadecimal escape: any number of characters - advance(lexer); - for(;;) + else { - int d = peek(lexer); - if(('0' <= d) && (d <= '9') - || ('a' <= d) && (d <= 'f') - || ('A' <= d) && (d <= 'F')) - { - advance(lexer); - continue; - } - else - { - break; - } + break; } - break; - - // TODO: Unicode escape sequences - } break; - default: + case 'x': + // hexadecimal escape: any number of characters advance(lexer); - continue; + for(;;) + { + int d = peek(lexer); + if(('0' <= d) && (d <= '9') + || ('a' <= d) && (d <= 'f') + || ('A' <= d) && (d <= 'F')) + { + advance(lexer); + continue; + } + else + { + break; + } + } + break; + + // TODO: Unicode escape sequences + } + break; + + default: + advance(lexer); + continue; } } + } - String getStringLiteralTokenValue(Token const& token) - { - assert(token.Type == TokenType::StringLiterial - || token.Type == TokenType::CharLiterial); + String getStringLiteralTokenValue(Token const& token) + { + assert(token.Type == TokenType::StringLiterial + || token.Type == TokenType::CharLiterial); - char const* cursor = token.Content.begin(); - char const* end = token.Content.end(); + char const* cursor = token.Content.begin(); + char const* end = token.Content.end(); - auto quote = *cursor++; - assert(quote == '\'' || quote == '"'); + auto quote = *cursor++; + assert(quote == '\'' || quote == '"'); - StringBuilder valueBuilder; - for(;;) - { - assert(cursor != end); + StringBuilder valueBuilder; + for(;;) + { + assert(cursor != end); - auto c = *cursor++; + auto c = *cursor++; - // If we see a closing quote, then we are at the end of the string literal - if(c == quote) - { - assert(cursor == end); - return valueBuilder.ProduceString(); - } + // If we see a closing quote, then we are at the end of the string literal + if(c == quote) + { + assert(cursor == end); + return valueBuilder.ProduceString(); + } - // Charcters that don't being escape sequences are easy; - // just append them to the buffer and move on. - if(c != '\\') - { - valueBuilder.Append(c); - continue; - } + // Charcters that don't being escape sequences are easy; + // just append them to the buffer and move on. + if(c != '\\') + { + valueBuilder.Append(c); + continue; + } - // Now we look at another character to figure out the kind of - // escape sequence we are dealing with: + // Now we look at another character to figure out the kind of + // escape sequence we are dealing with: - int d = *cursor++; + int d = *cursor++; - switch(d) + switch(d) + { + // Simple characters that just needed to be escaped + case '\'': + case '\"': + case '\\': + case '?': + valueBuilder.Append(d); + continue; + + // Traditional escape sequences for special characters + case 'a': valueBuilder.Append('\a'); continue; + case 'b': valueBuilder.Append('\b'); continue; + case 'f': valueBuilder.Append('\f'); continue; + case 'n': valueBuilder.Append('\n'); continue; + case 'r': valueBuilder.Append('\r'); continue; + case 't': valueBuilder.Append('\t'); continue; + case 'v': valueBuilder.Append('\v'); continue; + + // Octal escape: up to 3 characterws + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': { - // Simple characters that just needed to be escaped - case '\'': - case '\"': - case '\\': - case '?': - valueBuilder.Append(d); - continue; - - // Traditional escape sequences for special characters - case 'a': valueBuilder.Append('\a'); continue; - case 'b': valueBuilder.Append('\b'); continue; - case 'f': valueBuilder.Append('\f'); continue; - case 'n': valueBuilder.Append('\n'); continue; - case 'r': valueBuilder.Append('\r'); continue; - case 't': valueBuilder.Append('\t'); continue; - case 'v': valueBuilder.Append('\v'); continue; - - // Octal escape: up to 3 characterws - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': + cursor--; + int value = 0; + for(int ii = 0; ii < 3; ++ii) { - cursor--; - int value = 0; - for(int ii = 0; ii < 3; ++ii) + d = *cursor; + if(('0' <= d) && (d <= '7')) { - d = *cursor; - if(('0' <= d) && (d <= '7')) - { - value = value*8 + (d - '0'); - - cursor++; - continue; - } - else - { - break; - } - } + value = value*8 + (d - '0'); - // TODO: add support for appending an arbitrary code point? - valueBuilder.Append((char) value); + cursor++; + continue; + } + else + { + break; + } } - continue; - // Hexadecimal escape: any number of characters - case 'x': + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); + } + continue; + + // Hexadecimal escape: any number of characters + case 'x': + { + cursor--; + int value = 0; + for(;;) { - cursor--; - int value = 0; - for(;;) + d = *cursor++; + int digitValue = 0; + if(('0' <= d) && (d <= '9')) + { + digitValue = d - '0'; + } + else if( ('a' <= d) && (d <= 'f') ) + { + digitValue = d - 'a'; + } + else if( ('A' <= d) && (d <= 'F') ) { - d = *cursor++; - int digitValue = 0; - if(('0' <= d) && (d <= '9')) - { - digitValue = d - '0'; - } - else if( ('a' <= d) && (d <= 'f') ) - { - digitValue = d - 'a'; - } - else if( ('A' <= d) && (d <= 'F') ) - { - digitValue = d - 'A'; - } - else - { - cursor--; - break; - } - - value = value*16 + digitValue; + digitValue = d - 'A'; + } + else + { + cursor--; + break; } - // TODO: add support for appending an arbitrary code point? - valueBuilder.Append((char) value); + value = value*16 + digitValue; } - continue; - - // TODO: Unicode escape sequences + // TODO: add support for appending an arbitrary code point? + valueBuilder.Append((char) value); } + continue; + + // TODO: Unicode escape sequences + } } + } - String getFileNameTokenValue(Token const& token) - { - // A file name usually doesn't process escape sequences - // (this is import on Windows, where `\\` is a valid - // path separator cahracter). + String getFileNameTokenValue(Token const& token) + { + // A file name usually doesn't process escape sequences + // (this is import on Windows, where `\\` is a valid + // path separator cahracter). - // Just trim off the first and last characters to remove the quotes - // (whether they were `""` or `<>`. - return token.Content.SubString(1, token.Content.Length()-2); - } + // Just trim off the first and last characters to remove the quotes + // (whether they were `""` or `<>`. + return token.Content.SubString(1, token.Content.Length()-2); + } - static TokenType lexTokenImpl(Lexer* lexer) + static TokenType lexTokenImpl(Lexer* lexer) + { + switch(peek(lexer)) { + default: + break; + + case kEOF: + if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + return TokenType::EndOfFile; + + case '\r': case '\n': + if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) + return TokenType::EndOfDirective; + handleNewLine(lexer); + return TokenType::NewLine; + + case ' ': case '\t': + lexHorizontalSpace(lexer); + return TokenType::WhiteSpace; + + case '.': + advance(lexer); switch(peek(lexer)) { - default: - break; + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumberAfterDecimalPoint(lexer, 10); - case kEOF: - if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) - return TokenType::EndOfDirective; - return TokenType::EndOfFile; + // TODO(tfoley): handle ellipsis (`...`) - case '\r': case '\n': - if((lexer->lexerFlags & kLexerFlag_InDirective) != 0) - return TokenType::EndOfDirective; - handleNewLine(lexer); - return TokenType::NewLine; + default: + return TokenType::Dot; + } - case ' ': case '\t': - lexHorizontalSpace(lexer); - return TokenType::WhiteSpace; + case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + return lexNumber(lexer, 10); - case '.': + case '0': + { + auto loc = lexer->loc; advance(lexer); switch(peek(lexer)) { - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - return lexNumberAfterDecimalPoint(lexer, 10); - - // TODO(tfoley): handle ellipsis (`...`) - default: - return TokenType::Dot; - } + return TokenType::IntLiterial; - case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - return lexNumber(lexer, 10); - - case '0': - { - auto loc = lexer->loc; + case '.': advance(lexer); - switch(peek(lexer)) - { - default: - return TokenType::IntLiterial; - - case '.': - advance(lexer); - return lexNumberAfterDecimalPoint(lexer, 10); + return lexNumberAfterDecimalPoint(lexer, 10); - case 'x': case 'X': - advance(lexer); - return lexNumber(lexer, 16); + case 'x': case 'X': + advance(lexer); + return lexNumber(lexer, 16); - case 'b': case 'B': - advance(lexer); - return lexNumber(lexer, 2); + case 'b': case 'B': + advance(lexer); + return lexNumber(lexer, 2); - case '0': case '1': case '2': case '3': case '4': - case '5': case '6': case '7': case '8': case '9': - lexer->sink->diagnose(loc, Diagnostics::octalLiteral); - return lexNumber(lexer, 8); - } + case '0': case '1': case '2': case '3': case '4': + case '5': case '6': case '7': case '8': case '9': + lexer->sink->diagnose(loc, Diagnostics::octalLiteral); + return lexNumber(lexer, 8); } + } - case 'a': case 'b': case 'c': case 'd': case 'e': - case 'f': case 'g': case 'h': case 'i': case 'j': - case 'k': case 'l': case 'm': case 'n': case 'o': - case 'p': case 'q': case 'r': case 's': case 't': - case 'u': case 'v': case 'w': case 'x': case 'y': - case 'z': - case 'A': case 'B': case 'C': case 'D': case 'E': - case 'F': case 'G': case 'H': case 'I': case 'J': - case 'K': case 'L': case 'M': case 'N': case 'O': - case 'P': case 'Q': case 'R': case 'S': case 'T': - case 'U': case 'V': case 'W': case 'X': case 'Y': - case 'Z': - case '_': - lexIdentifier(lexer); - return TokenType::Identifier; - - case '\"': - advance(lexer); - lexStringLiteralBody(lexer, '\"'); - return TokenType::StringLiterial; - - case '\'': - advance(lexer); - lexStringLiteralBody(lexer, '\''); - return TokenType::CharLiterial; - - case '+': - advance(lexer); - switch(peek(lexer)) - { - case '+': advance(lexer); return TokenType::OpInc; - case '=': advance(lexer); return TokenType::OpAddAssign; - default: - return TokenType::OpAdd; - } + case 'a': case 'b': case 'c': case 'd': case 'e': + case 'f': case 'g': case 'h': case 'i': case 'j': + case 'k': case 'l': case 'm': case 'n': case 'o': + case 'p': case 'q': case 'r': case 's': case 't': + case 'u': case 'v': case 'w': case 'x': case 'y': + case 'z': + case 'A': case 'B': case 'C': case 'D': case 'E': + case 'F': case 'G': case 'H': case 'I': case 'J': + case 'K': case 'L': case 'M': case 'N': case 'O': + case 'P': case 'Q': case 'R': case 'S': case 'T': + case 'U': case 'V': case 'W': case 'X': case 'Y': + case 'Z': + case '_': + lexIdentifier(lexer); + return TokenType::Identifier; + + case '\"': + advance(lexer); + lexStringLiteralBody(lexer, '\"'); + return TokenType::StringLiterial; + + case '\'': + advance(lexer); + lexStringLiteralBody(lexer, '\''); + return TokenType::CharLiterial; + + case '+': + advance(lexer); + switch(peek(lexer)) + { + case '+': advance(lexer); return TokenType::OpInc; + case '=': advance(lexer); return TokenType::OpAddAssign; + default: + return TokenType::OpAdd; + } - case '-': - advance(lexer); - switch(peek(lexer)) - { - case '-': advance(lexer); return TokenType::OpDec; - case '=': advance(lexer); return TokenType::OpSubAssign; - case '>': advance(lexer); return TokenType::RightArrow; - default: - return TokenType::OpSub; - } + case '-': + advance(lexer); + switch(peek(lexer)) + { + case '-': advance(lexer); return TokenType::OpDec; + case '=': advance(lexer); return TokenType::OpSubAssign; + case '>': advance(lexer); return TokenType::RightArrow; + default: + return TokenType::OpSub; + } - case '*': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpMulAssign; - default: - return TokenType::OpMul; - } + case '*': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpMulAssign; + default: + return TokenType::OpMul; + } - case '/': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpDivAssign; - case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment; - case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment; - default: - return TokenType::OpDiv; - } + case '/': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpDivAssign; + case '/': advance(lexer); lexLineComment(lexer); return TokenType::LineComment; + case '*': advance(lexer); lexBlockComment(lexer); return TokenType::BlockComment; + default: + return TokenType::OpDiv; + } - case '%': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpModAssign; - default: - return TokenType::OpMod; - } + case '%': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpModAssign; + default: + return TokenType::OpMod; + } - case '|': - advance(lexer); - switch(peek(lexer)) - { - case '|': advance(lexer); return TokenType::OpOr; - case '=': advance(lexer); return TokenType::OpOrAssign; - default: - return TokenType::OpBitOr; - } + case '|': + advance(lexer); + switch(peek(lexer)) + { + case '|': advance(lexer); return TokenType::OpOr; + case '=': advance(lexer); return TokenType::OpOrAssign; + default: + return TokenType::OpBitOr; + } - case '&': - advance(lexer); - switch(peek(lexer)) - { - case '&': advance(lexer); return TokenType::OpAnd; - case '=': advance(lexer); return TokenType::OpAndAssign; - default: - return TokenType::OpBitAnd; - } + case '&': + advance(lexer); + switch(peek(lexer)) + { + case '&': advance(lexer); return TokenType::OpAnd; + case '=': advance(lexer); return TokenType::OpAndAssign; + default: + return TokenType::OpBitAnd; + } - case '^': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpXorAssign; - default: - return TokenType::OpBitXor; - } + case '^': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpXorAssign; + default: + return TokenType::OpBitXor; + } + case '>': + advance(lexer); + switch(peek(lexer)) + { case '>': advance(lexer); switch(peek(lexer)) { - case '>': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpShrAssign; - default: return TokenType::OpRsh; - } - case '=': advance(lexer); return TokenType::OpGeq; - default: - return TokenType::OpGreater; + case '=': advance(lexer); return TokenType::OpShrAssign; + default: return TokenType::OpRsh; } + case '=': advance(lexer); return TokenType::OpGeq; + default: + return TokenType::OpGreater; + } + case '<': + advance(lexer); + switch(peek(lexer)) + { case '<': advance(lexer); switch(peek(lexer)) { - case '<': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpShlAssign; - default: return TokenType::OpLsh; - } - case '=': advance(lexer); return TokenType::OpLeq; - default: - return TokenType::OpLess; - } - - case '=': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpEql; - default: - return TokenType::OpAssign; + case '=': advance(lexer); return TokenType::OpShlAssign; + default: return TokenType::OpLsh; } + case '=': advance(lexer); return TokenType::OpLeq; + default: + return TokenType::OpLess; + } - case '!': - advance(lexer); - switch(peek(lexer)) - { - case '=': advance(lexer); return TokenType::OpNeq; - default: - return TokenType::OpNot; - } + case '=': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpEql; + default: + return TokenType::OpAssign; + } - case '#': - advance(lexer); - switch(peek(lexer)) - { - case '#': advance(lexer); return TokenType::PoundPound; - default: - return TokenType::Pound; - } + case '!': + advance(lexer); + switch(peek(lexer)) + { + case '=': advance(lexer); return TokenType::OpNeq; + default: + return TokenType::OpNot; + } - case '~': advance(lexer); return TokenType::OpBitNot; + case '#': + advance(lexer); + switch(peek(lexer)) + { + case '#': advance(lexer); return TokenType::PoundPound; + default: + return TokenType::Pound; + } - case ':': advance(lexer); return TokenType::Colon; - case ';': advance(lexer); return TokenType::Semicolon; - case ',': advance(lexer); return TokenType::Comma; + case '~': advance(lexer); return TokenType::OpBitNot; - case '{': advance(lexer); return TokenType::LBrace; - case '}': advance(lexer); return TokenType::RBrace; - case '[': advance(lexer); return TokenType::LBracket; - case ']': advance(lexer); return TokenType::RBracket; - case '(': advance(lexer); return TokenType::LParent; - case ')': advance(lexer); return TokenType::RParent; + case ':': advance(lexer); return TokenType::Colon; + case ';': advance(lexer); return TokenType::Semicolon; + case ',': advance(lexer); return TokenType::Comma; - case '?': advance(lexer); return TokenType::QuestionMark; - case '@': advance(lexer); return TokenType::At; - case '$': advance(lexer); return TokenType::Dollar; + case '{': advance(lexer); return TokenType::LBrace; + case '}': advance(lexer); return TokenType::RBrace; + case '[': advance(lexer); return TokenType::LBracket; + case ']': advance(lexer); return TokenType::RBracket; + case '(': advance(lexer); return TokenType::LParent; + case ')': advance(lexer); return TokenType::RParent; - } + case '?': advance(lexer); return TokenType::QuestionMark; + case '@': advance(lexer); return TokenType::At; + case '$': advance(lexer); return TokenType::Dollar; - // TODO(tfoley): If we ever wanted to support proper Unicode - // in identifiers, etc., then this would be the right place - // to perform a more expensive dispatch based on the actual - // code point (and not just the first byte). + } - { - // If none of the above cases matched, then we have an - // unexpected/invalid character. + // TODO(tfoley): If we ever wanted to support proper Unicode + // in identifiers, etc., then this would be the right place + // to perform a more expensive dispatch based on the actual + // code point (and not just the first byte). - auto loc = lexer->loc; - auto sink = lexer->sink; - int c = advance(lexer); - if(c >= 0x20 && c <= 0x7E) - { - char buffer[] = { (char) c, 0 }; - sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer); - } - else - { - // Fallback: print as hexadecimal - sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16)); - } + { + // If none of the above cases matched, then we have an + // unexpected/invalid character. - return TokenType::Invalid; + auto loc = lexer->loc; + auto sink = lexer->sink; + int c = advance(lexer); + if(c >= 0x20 && c <= 0x7E) + { + char buffer[] = { (char) c, 0 }; + sink->diagnose(loc, Diagnostics::illegalCharacterPrint, buffer); } - } - - Token Lexer::lexToken() - { - auto flags = this->tokenFlags; - for(;;) + else { - Token token; - token.Position = loc; - - char const* textBegin = cursor; + // Fallback: print as hexadecimal + sink->diagnose(loc, Diagnostics::illegalCharacterHex, String((unsigned char)c, 16)); + } - auto tokenType = lexTokenImpl(this); + return TokenType::Invalid; + } + } - // The low-level lexer produces tokens for things we want - // to ignore, such as white space, so we skip them here. - switch(tokenType) - { - case TokenType::Invalid: - flags = 0; - continue; + Token Lexer::lexToken() + { + auto flags = this->tokenFlags; + for(;;) + { + Token token; + token.Position = loc; - case TokenType::NewLine: - flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - continue; + char const* textBegin = cursor; - case TokenType::WhiteSpace: - case TokenType::LineComment: - case TokenType::BlockComment: - flags |= TokenFlag::AfterWhitespace; - continue; + auto tokenType = lexTokenImpl(this); - // We don't want to skip the end-of-file token, but we *do* - // want to make sure it has appropriate flags to make our life easier - case TokenType::EndOfFile: - flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; - break; + // The low-level lexer produces tokens for things we want + // to ignore, such as white space, so we skip them here. + switch(tokenType) + { + case TokenType::Invalid: + flags = 0; + continue; + + case TokenType::NewLine: + flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + continue; + + case TokenType::WhiteSpace: + case TokenType::LineComment: + case TokenType::BlockComment: + flags |= TokenFlag::AfterWhitespace; + continue; + + // We don't want to skip the end-of-file token, but we *do* + // want to make sure it has appropriate flags to make our life easier + case TokenType::EndOfFile: + flags = TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace; + break; - // We will also do some book-keeping around preprocessor directives here: - // - // If we see a `#` at the start of a line, then we are entering a - // preprocessor directive. - case TokenType::Pound: - if((flags & TokenFlag::AtStartOfLine) != 0) - lexerFlags |= kLexerFlag_InDirective; - break; - // - // And if we saw an end-of-line during a directive, then we are - // now leaving that directive. - // - case TokenType::EndOfDirective: - lexerFlags &= ~kLexerFlag_InDirective; - break; + // We will also do some book-keeping around preprocessor directives here: + // + // If we see a `#` at the start of a line, then we are entering a + // preprocessor directive. + case TokenType::Pound: + if((flags & TokenFlag::AtStartOfLine) != 0) + lexerFlags |= kLexerFlag_InDirective; + break; + // + // And if we saw an end-of-line during a directive, then we are + // now leaving that directive. + // + case TokenType::EndOfDirective: + lexerFlags &= ~kLexerFlag_InDirective; + break; - default: - break; - } + default: + break; + } - token.Type = tokenType; + token.Type = tokenType; - char const* textEnd = cursor; + char const* textEnd = cursor; - // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes - if(textEnd != textBegin) - { - StringBuilder valueBuilder; - valueBuilder.Append(textBegin, int(textEnd - textBegin)); - token.Content = valueBuilder.ProduceString(); - } + // Note(tfoley): `StringBuilder::Append()` seems to crash when appending zero bytes + if(textEnd != textBegin) + { + StringBuilder valueBuilder; + valueBuilder.Append(textBegin, int(textEnd - textBegin)); + token.Content = valueBuilder.ProduceString(); + } - token.flags = flags; + token.flags = flags; - this->tokenFlags = 0; + this->tokenFlags = 0; - return token; - } + return token; } + } - TokenList Lexer::lexAllTokens() + TokenList Lexer::lexAllTokens() + { + TokenList tokenList; + for(;;) { - TokenList tokenList; - for(;;) - { - Token token = lexToken(); - tokenList.mTokens.Add(token); + Token token = lexToken(); + tokenList.mTokens.Add(token); - if(token.Type == TokenType::EndOfFile) - return tokenList; - } + if(token.Type == TokenType::EndOfFile) + return tokenList; } + } #if 0 - TokenList Lexer::Parse(const String & fileName, const String & str, DiagnosticSink * sink) + TokenList Lexer::Parse(const String & fileName, const String & str, DiagnosticSink * sink) + { + TokenList tokenList; + tokenList.mTokens = TokenizeText(fileName, str, [&](TokenizeErrorType errType, CodePosition pos) { - TokenList tokenList; - tokenList.mTokens = TokenizeText(fileName, str, [&](TokenizeErrorType errType, CodePosition pos) + auto curChar = str[pos.Pos]; + switch (errType) { - auto curChar = str[pos.Pos]; - switch (errType) + case TokenizeErrorType::InvalidCharacter: + // Check if inside the ASCII "printable" range + if(curChar >= 0x20 && curChar <= 0x7E) { - case TokenizeErrorType::InvalidCharacter: - // Check if inside the ASCII "printable" range - if(curChar >= 0x20 && curChar <= 0x7E) - { - char buffer[] = { curChar, 0 }; - sink->diagnose(pos, Diagnostics::illegalCharacterPrint, buffer); - } - else - { - // Fallback: print as hexadecimal - sink->diagnose(pos, Diagnostics::illegalCharacterHex, String((unsigned char)curChar, 16)); - } - break; - case TokenizeErrorType::InvalidEscapeSequence: - sink->diagnose(pos, Diagnostics::illegalCharacterLiteral); - break; - default: - break; + char buffer[] = { curChar, 0 }; + sink->diagnose(pos, Diagnostics::illegalCharacterPrint, buffer); + } + else + { + // Fallback: print as hexadecimal + sink->diagnose(pos, Diagnostics::illegalCharacterHex, String((unsigned char)curChar, 16)); } - }); + break; + case TokenizeErrorType::InvalidEscapeSequence: + sink->diagnose(pos, Diagnostics::illegalCharacterLiteral); + break; + default: + break; + } + }); - // Add an end-of-file token so that we can reference it in diagnostic messages - tokenList.mTokens.Add(Token(TokenType::EndOfFile, "", 0, 0, 0, fileName, TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace)); - return tokenList; - } -#endif + // Add an end-of-file token so that we can reference it in diagnostic messages + tokenList.mTokens.Add(Token(TokenType::EndOfFile, "", 0, 0, 0, fileName, TokenFlag::AtStartOfLine | TokenFlag::AfterWhitespace)); + return tokenList; } +#endif }
\ No newline at end of file diff --git a/source/slang/lexer.h b/source/slang/lexer.h index d11e92d84..d599b3b7f 100644 --- a/source/slang/lexer.h +++ b/source/slang/lexer.h @@ -6,96 +6,93 @@ namespace Slang { - namespace Compiler + using namespace CoreLib::Basic; + + struct TokenList + { + Token* begin() const; + Token* end() const; + + List<Token> mTokens; + }; + + struct TokenSpan + { + TokenSpan(); + TokenSpan( + TokenList const& tokenList) + : mBegin(tokenList.begin()) + , mEnd (tokenList.end ()) + {} + + Token* begin() const { return mBegin; } + Token* end () const { return mEnd ; } + + int GetCount() { return (int)(mEnd - mBegin); } + + Token* mBegin; + Token* mEnd; + }; + + struct TokenReader { - using namespace CoreLib::Basic; - - struct TokenList - { - Token* begin() const; - Token* end() const; - - List<Token> mTokens; - }; - - struct TokenSpan - { - TokenSpan(); - TokenSpan( - TokenList const& tokenList) - : mBegin(tokenList.begin()) - , mEnd (tokenList.end ()) - {} - - Token* begin() const { return mBegin; } - Token* end () const { return mEnd ; } - - int GetCount() { return (int)(mEnd - mBegin); } - - Token* mBegin; - Token* mEnd; - }; - - struct TokenReader - { - TokenReader(); - explicit TokenReader(TokenSpan const& tokens) - : mCursor(tokens.begin()) - , mEnd (tokens.end ()) - {} - explicit TokenReader(TokenList const& tokens) - : mCursor(tokens.begin()) - , mEnd (tokens.end ()) - {} - - bool IsAtEnd() const { return mCursor == mEnd; } - Token PeekToken() const; - TokenType PeekTokenType() const; - CodePosition PeekLoc() const; - - Token AdvanceToken(); - - int GetCount() { return (int)(mEnd - mCursor); } - - Token* mCursor; - Token* mEnd; - }; - - typedef unsigned int LexerFlags; - enum - { - kLexerFlag_InDirective = 1 << 0, - kLexerFlag_ExpectFileName = 2 << 0, - }; - - struct Lexer - { - Lexer( - String const& path, - String const& content, - DiagnosticSink* sink); - - ~Lexer(); - - Token lexToken(); - - TokenList lexAllTokens(); - - String path; - String content; - DiagnosticSink* sink; - - char const* cursor; - char const* end; - CodePosition loc; - TokenFlags tokenFlags; - LexerFlags lexerFlags; - }; - - // Helper routines for extracting values from tokens - String getStringLiteralTokenValue(Token const& token); - String getFileNameTokenValue(Token const& token); - } + TokenReader(); + explicit TokenReader(TokenSpan const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + {} + explicit TokenReader(TokenList const& tokens) + : mCursor(tokens.begin()) + , mEnd (tokens.end ()) + {} + + bool IsAtEnd() const { return mCursor == mEnd; } + Token PeekToken() const; + TokenType PeekTokenType() const; + CodePosition PeekLoc() const; + + Token AdvanceToken(); + + int GetCount() { return (int)(mEnd - mCursor); } + + Token* mCursor; + Token* mEnd; + }; + + typedef unsigned int LexerFlags; + enum + { + kLexerFlag_InDirective = 1 << 0, + kLexerFlag_ExpectFileName = 2 << 0, + }; + + struct Lexer + { + Lexer( + String const& path, + String const& content, + DiagnosticSink* sink); + + ~Lexer(); + + Token lexToken(); + + TokenList lexAllTokens(); + + String path; + String content; + DiagnosticSink* sink; + + char const* cursor; + char const* end; + CodePosition loc; + TokenFlags tokenFlags; + LexerFlags lexerFlags; + }; + + // Helper routines for extracting values from tokens + String getStringLiteralTokenValue(Token const& token); + String getFileNameTokenValue(Token const& token); } #endif
\ No newline at end of file diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index 9731b1c8a..89d07e223 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -2,7 +2,6 @@ #include "lookup.h" namespace Slang { -namespace Compiler { // @@ -308,4 +307,4 @@ LookupResult LookUpLocal(String const& name, ContainerDecl* containerDecl) } -}} +} diff --git a/source/slang/lookup.h b/source/slang/lookup.h index 25b62738f..bdf392aec 100644 --- a/source/slang/lookup.h +++ b/source/slang/lookup.h @@ -4,7 +4,6 @@ #include "Syntax.h" namespace Slang { -namespace Compiler { // Take an existing lookup result and refine it to only include // results that pass the given `LookupMask`. @@ -36,6 +35,6 @@ QualType getTypeForDeclRef( DeclRef declRef); -}} +} #endif
\ No newline at end of file diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 8bbb566af..1a1fa4335 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -10,7 +10,6 @@ #define SLANG_EXHAUSTIVE_SWITCH() default: assert(!"unexpected"); break; namespace Slang { -namespace Compiler { // Information on ranges of registers already claimed/used struct UsedRange @@ -1249,4 +1248,4 @@ void GenerateParameterBindings( program->layout = programLayout; } -}} +} diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h index 8165f1b2e..2fff08090 100644 --- a/source/slang/parameter-binding.h +++ b/source/slang/parameter-binding.h @@ -7,7 +7,6 @@ #include "../../Slang.h" namespace Slang { -namespace Compiler { class CollectionOfTranslationUnits; @@ -27,6 +26,6 @@ class CollectionOfTranslationUnits; void GenerateParameterBindings( CollectionOfTranslationUnits* program); -}} +} #endif // SLANG_REFLECTION_H diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 9e76a68b9..df2986959 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -6,3149 +6,3146 @@ namespace Slang { - namespace Compiler - { - enum Precedence : int - { - Invalid = -1, - Comma, - Assignment, - TernaryConditional, - LogicalOr, - LogicalAnd, - BitOr, - BitXor, - BitAnd, - EqualityComparison, - RelationalComparison, - BitShift, - Additive, - Multiplicative, - Prefix, - Postfix, - }; - - // TODO: implement two pass parsing for file reference and struct type recognition - - class Parser + enum Precedence : int + { + Invalid = -1, + Comma, + Assignment, + TernaryConditional, + LogicalOr, + LogicalAnd, + BitOr, + BitXor, + BitAnd, + EqualityComparison, + RelationalComparison, + BitShift, + Additive, + Multiplicative, + Prefix, + Postfix, + }; + + // TODO: implement two pass parsing for file reference and struct type recognition + + class Parser + { + public: + CompileOptions& options; + int anonymousCounter = 0; + + RefPtr<Scope> outerScope; + RefPtr<Scope> currentScope; + + TokenReader tokenReader; + DiagnosticSink * sink; + String fileName; + int genericDepth = 0; + + // Is the parser in a "recovering" state? + // During recovery we don't emit additional errors, until we find + // a token that we expected, when we exit recovery. + bool isRecovering = false; + + void FillPosition(SyntaxNode * node) + { + node->Position = tokenReader.PeekLoc(); + } + void PushScope(ContainerDecl* containerDecl) + { + RefPtr<Scope> newScope = new Scope(); + newScope->containerDecl = containerDecl; + newScope->parent = currentScope; + + currentScope = newScope; + } + void PopScope() + { + currentScope = currentScope->parent; + } + Parser( + CompileOptions& options, + TokenSpan const& _tokens, + DiagnosticSink * sink, + String _fileName, + RefPtr<Scope> const& outerScope) + : options(options) + , tokenReader(_tokens) + , sink(sink) + , fileName(_fileName) + , outerScope(outerScope) + {} + RefPtr<ProgramSyntaxNode> Parse(); + + Token ReadToken(); + Token ReadToken(TokenType type); + Token ReadToken(const char * string); + bool LookAheadToken(TokenType type, int offset = 0); + bool LookAheadToken(const char * string, int offset = 0); + void parseSourceFile(ProgramSyntaxNode* program); + RefPtr<ProgramSyntaxNode> ParseProgram(); + RefPtr<StructSyntaxNode> ParseStruct(); + RefPtr<ClassSyntaxNode> ParseClass(); + RefPtr<StatementSyntaxNode> ParseStatement(); + RefPtr<StatementSyntaxNode> ParseBlockStatement(); + RefPtr<VarDeclrStatementSyntaxNode> ParseVarDeclrStatement(Modifiers modifiers); + RefPtr<IfStatementSyntaxNode> ParseIfStatement(); + RefPtr<ForStatementSyntaxNode> ParseForStatement(); + RefPtr<WhileStatementSyntaxNode> ParseWhileStatement(); + RefPtr<DoWhileStatementSyntaxNode> ParseDoWhileStatement(); + RefPtr<BreakStatementSyntaxNode> ParseBreakStatement(); + RefPtr<ContinueStatementSyntaxNode> ParseContinueStatement(); + RefPtr<ReturnStatementSyntaxNode> ParseReturnStatement(); + RefPtr<ExpressionStatementSyntaxNode> ParseExpressionStatement(); + RefPtr<ExpressionSyntaxNode> ParseExpression(Precedence level = Precedence::Comma); + + // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma + inline RefPtr<ExpressionSyntaxNode> ParseInitExpr() { return ParseExpression(Precedence::Assignment); } + inline RefPtr<ExpressionSyntaxNode> ParseArgExpr() { return ParseExpression(Precedence::Assignment); } + + RefPtr<ExpressionSyntaxNode> ParseLeafExpression(); + RefPtr<ParameterSyntaxNode> ParseParameter(); + RefPtr<ExpressionSyntaxNode> ParseType(); + TypeExp ParseTypeExp(); + + Parser & operator = (const Parser &) = delete; + }; + + // Forward Declarations + + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken); + + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl); + + static RefPtr<Modifier> ParseOptSemantics( + Parser* parser); + + static void ParseOptSemantics( + Parser* parser, + Decl* decl); + + static RefPtr<DeclBase> ParseDecl( + Parser* parser, + ContainerDecl* containerDecl); + + static RefPtr<Decl> ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl); + + // + + static void Unexpected( + Parser* parser) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) { - public: - CompileOptions& options; - int anonymousCounter = 0; + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken, + parser->tokenReader.PeekTokenType()); - RefPtr<Scope> outerScope; - RefPtr<Scope> currentScope; - - TokenReader tokenReader; - DiagnosticSink * sink; - String fileName; - int genericDepth = 0; - - // Is the parser in a "recovering" state? - // During recovery we don't emit additional errors, until we find - // a token that we expected, when we exit recovery. - bool isRecovering = false; - - void FillPosition(SyntaxNode * node) - { - node->Position = tokenReader.PeekLoc(); - } - void PushScope(ContainerDecl* containerDecl) - { - RefPtr<Scope> newScope = new Scope(); - newScope->containerDecl = containerDecl; - newScope->parent = currentScope; - - currentScope = newScope; - } - void PopScope() - { - currentScope = currentScope->parent; - } - Parser( - CompileOptions& options, - TokenSpan const& _tokens, - DiagnosticSink * sink, - String _fileName, - RefPtr<Scope> const& outerScope) - : options(options) - , tokenReader(_tokens) - , sink(sink) - , fileName(_fileName) - , outerScope(outerScope) - {} - RefPtr<ProgramSyntaxNode> Parse(); - - Token ReadToken(); - Token ReadToken(TokenType type); - Token ReadToken(const char * string); - bool LookAheadToken(TokenType type, int offset = 0); - bool LookAheadToken(const char * string, int offset = 0); - void parseSourceFile(ProgramSyntaxNode* program); - RefPtr<ProgramSyntaxNode> ParseProgram(); - RefPtr<StructSyntaxNode> ParseStruct(); - RefPtr<ClassSyntaxNode> ParseClass(); - RefPtr<StatementSyntaxNode> ParseStatement(); - RefPtr<StatementSyntaxNode> ParseBlockStatement(); - RefPtr<VarDeclrStatementSyntaxNode> ParseVarDeclrStatement(Modifiers modifiers); - RefPtr<IfStatementSyntaxNode> ParseIfStatement(); - RefPtr<ForStatementSyntaxNode> ParseForStatement(); - RefPtr<WhileStatementSyntaxNode> ParseWhileStatement(); - RefPtr<DoWhileStatementSyntaxNode> ParseDoWhileStatement(); - RefPtr<BreakStatementSyntaxNode> ParseBreakStatement(); - RefPtr<ContinueStatementSyntaxNode> ParseContinueStatement(); - RefPtr<ReturnStatementSyntaxNode> ParseReturnStatement(); - RefPtr<ExpressionStatementSyntaxNode> ParseExpressionStatement(); - RefPtr<ExpressionSyntaxNode> ParseExpression(Precedence level = Precedence::Comma); - - // Parse an expression that might be used in an initializer or argument context, so we should avoid operator-comma - inline RefPtr<ExpressionSyntaxNode> ParseInitExpr() { return ParseExpression(Precedence::Assignment); } - inline RefPtr<ExpressionSyntaxNode> ParseArgExpr() { return ParseExpression(Precedence::Assignment); } - - RefPtr<ExpressionSyntaxNode> ParseLeafExpression(); - RefPtr<ParameterSyntaxNode> ParseParameter(); - RefPtr<ExpressionSyntaxNode> ParseType(); - TypeExp ParseTypeExp(); - - Parser & operator = (const Parser &) = delete; - }; - - // Forward Declarations - - static void ParseDeclBody( - Parser* parser, - ContainerDecl* containerDecl, - TokenType closingToken); - - // Parse the `{}`-delimeted body of an aggregate type declaration - static void parseAggTypeDeclBody( - Parser* parser, - AggTypeDeclBase* decl); + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } - static RefPtr<Modifier> ParseOptSemantics( - Parser* parser); + static void Unexpected( + Parser* parser, + char const* expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName, + parser->tokenReader.PeekTokenType(), + expected); - static void ParseOptSemantics( - Parser* parser, - Decl* decl); + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } - static RefPtr<DeclBase> ParseDecl( - Parser* parser, - ContainerDecl* containerDecl); + static void Unexpected( + Parser* parser, + TokenType expected) + { + // Don't emit "unexpected token" errors if we are in recovering mode + if (!parser->isRecovering) + { + parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType, + parser->tokenReader.PeekTokenType(), + expected); - static RefPtr<Decl> ParseSingleDecl( - Parser* parser, - ContainerDecl* containerDecl); + // Switch into recovery mode, to suppress additional errors + parser->isRecovering = true; + } + } - // + static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType); - static void Unexpected( - Parser* parser) + // Skip a singel balanced token, which is either a single token in + // the common case, or a matched pair of tokens for `()`, `[]`, and `{}` + static TokenType SkipBalancedToken( + TokenReader* reader) + { + TokenType tokenType = reader->AdvanceToken().Type; + switch (tokenType) { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) - { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedToken, - parser->tokenReader.PeekTokenType()); + default: + break; - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; - } + case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break; + case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break; + case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break; } + return tokenType; + } - static void Unexpected( - Parser* parser, - char const* expected) + // Skip balanced + static TokenType SkipToMatchingToken( + TokenReader* reader, + TokenType tokenType) + { + for (;;) { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) + if (reader->IsAtEnd()) return TokenType::EndOfFile; + if (reader->PeekTokenType() == tokenType) { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenName, - parser->tokenReader.PeekTokenType(), - expected); - - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; + reader->AdvanceToken(); + return tokenType; } + SkipBalancedToken(reader); } + } - static void Unexpected( - Parser* parser, - TokenType expected) + // Is the given token type one that is used to "close" a + // balanced construct. + static bool IsClosingToken(TokenType tokenType) + { + switch (tokenType) { - // Don't emit "unexpected token" errors if we are in recovering mode - if (!parser->isRecovering) - { - parser->sink->diagnose(parser->tokenReader.PeekLoc(), Diagnostics::unexpectedTokenExpectedTokenType, - parser->tokenReader.PeekTokenType(), - expected); + case TokenType::EndOfFile: + case TokenType::RBracket: + case TokenType::RParent: + case TokenType::RBrace: + return true; - // Switch into recovery mode, to suppress additional errors - parser->isRecovering = true; - } + default: + return false; } + } - static TokenType SkipToMatchingToken(TokenReader* reader, TokenType tokenType); - // Skip a singel balanced token, which is either a single token in - // the common case, or a matched pair of tokens for `()`, `[]`, and `{}` - static TokenType SkipBalancedToken( - TokenReader* reader) + // Expect an identifier token with the given content, and consume it. + Token Parser::ReadToken(const char* expected) + { + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) { - TokenType tokenType = reader->AdvanceToken().Type; - switch (tokenType) - { - default: - break; - - case TokenType::LParent: tokenType = SkipToMatchingToken(reader, TokenType::RParent); break; - case TokenType::LBrace: tokenType = SkipToMatchingToken(reader, TokenType::RBrace); break; - case TokenType::LBracket: tokenType = SkipToMatchingToken(reader, TokenType::RBracket); break; - } - return tokenType; + isRecovering = false; + return tokenReader.AdvanceToken(); } - // Skip balanced - static TokenType SkipToMatchingToken( - TokenReader* reader, - TokenType tokenType) + if (!isRecovering) + { + Unexpected(this, expected); + return tokenReader.PeekToken(); + } + else { + // Try to find a place to recover for (;;) { - if (reader->IsAtEnd()) return TokenType::EndOfFile; - if (reader->PeekTokenType() == tokenType) + // The token we expected? + // Then exit recovery mode and pretend like all is well. + if (tokenReader.PeekTokenType() == TokenType::Identifier + && tokenReader.PeekToken().Content == expected) { - reader->AdvanceToken(); - return tokenType; + isRecovering = false; + return tokenReader.AdvanceToken(); } - SkipBalancedToken(reader); - } - } - - // Is the given token type one that is used to "close" a - // balanced construct. - static bool IsClosingToken(TokenType tokenType) - { - switch (tokenType) - { - case TokenType::EndOfFile: - case TokenType::RBracket: - case TokenType::RParent: - case TokenType::RBrace: - return true; - - default: - return false; - } - } - // Expect an identifier token with the given content, and consume it. - Token Parser::ReadToken(const char* expected) - { - if (tokenReader.PeekTokenType() == TokenType::Identifier - && tokenReader.PeekToken().Content == expected) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } + // Don't skip past any "closing" tokens. + if (IsClosingToken(tokenReader.PeekTokenType())) + { + return tokenReader.PeekToken(); + } - if (!isRecovering) - { - Unexpected(this, expected); - return tokenReader.PeekToken(); + // Skip balanced tokens and try again. + SkipBalancedToken(&tokenReader); } - else - { - // Try to find a place to recover - for (;;) - { - // The token we expected? - // Then exit recovery mode and pretend like all is well. - if (tokenReader.PeekTokenType() == TokenType::Identifier - && tokenReader.PeekToken().Content == expected) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } + } + } + Token Parser::ReadToken() + { + return tokenReader.AdvanceToken(); + } - // Don't skip past any "closing" tokens. - if (IsClosingToken(tokenReader.PeekTokenType())) - { - return tokenReader.PeekToken(); - } + static bool TryRecover( + Parser* parser, + TokenType const* recoverBefore, + int recoverBeforeCount, + TokenType const* recoverAfter, + int recoverAfterCount) + { + if (!parser->isRecovering) + return true; - // Skip balanced tokens and try again. - SkipBalancedToken(&tokenReader); - } - } + // Determine if we are looking for a closing token at all... + bool lookingForClose = false; + for (int ii = 0; ii < recoverBeforeCount; ++ii) + { + if (IsClosingToken(recoverBefore[ii])) + lookingForClose = true; } - - Token Parser::ReadToken() + for (int ii = 0; ii < recoverAfterCount; ++ii) { - return tokenReader.AdvanceToken(); + if (IsClosingToken(recoverAfter[ii])) + lookingForClose = true; } - static bool TryRecover( - Parser* parser, - TokenType const* recoverBefore, - int recoverBeforeCount, - TokenType const* recoverAfter, - int recoverAfterCount) + TokenReader* tokenReader = &parser->tokenReader; + for (;;) { - if (!parser->isRecovering) - return true; + TokenType peek = tokenReader->PeekTokenType(); - // Determine if we are looking for a closing token at all... - bool lookingForClose = false; + // Is the next token in our recover-before set? + // If so, then we have recovered successfully! for (int ii = 0; ii < recoverBeforeCount; ++ii) { - if (IsClosingToken(recoverBefore[ii])) - lookingForClose = true; - } - for (int ii = 0; ii < recoverAfterCount; ++ii) - { - if (IsClosingToken(recoverAfter[ii])) - lookingForClose = true; + if (peek == recoverBefore[ii]) + { + parser->isRecovering = false; + return true; + } } - TokenReader* tokenReader = &parser->tokenReader; - for (;;) + // If we are looking at a token in our recover-after set, + // then consume it and recover + for (int ii = 0; ii < recoverAfterCount; ++ii) { - TokenType peek = tokenReader->PeekTokenType(); - - // Is the next token in our recover-before set? - // If so, then we have recovered successfully! - for (int ii = 0; ii < recoverBeforeCount; ++ii) + if (peek == recoverAfter[ii]) { - if (peek == recoverBefore[ii]) - { - parser->isRecovering = false; - return true; - } + tokenReader->AdvanceToken(); + parser->isRecovering = false; + return true; } + } - // If we are looking at a token in our recover-after set, - // then consume it and recover - for (int ii = 0; ii < recoverAfterCount; ++ii) - { - if (peek == recoverAfter[ii]) - { - tokenReader->AdvanceToken(); - parser->isRecovering = false; - return true; - } - } + // Don't try to skip past end of file + if (peek == TokenType::EndOfFile) + return false; - // Don't try to skip past end of file - if (peek == TokenType::EndOfFile) + switch (peek) + { + // Don't skip past simple "closing" tokens, *unless* + // we are looking for a closing token + case TokenType::RParent: + case TokenType::RBracket: + if (!lookingForClose) return false; + break; - switch (peek) - { - // Don't skip past simple "closing" tokens, *unless* - // we are looking for a closing token - case TokenType::RParent: - case TokenType::RBracket: - if (!lookingForClose) - return false; - break; - - // never skip a `}`, to avoid spurious errors - case TokenType::RBrace: - return false; - } + // never skip a `}`, to avoid spurious errors + case TokenType::RBrace: + return false; + } - // Skip balanced tokens and try again. - TokenType skipped = SkipBalancedToken(tokenReader); + // Skip balanced tokens and try again. + TokenType skipped = SkipBalancedToken(tokenReader); - // If we happened to find a matched pair of tokens, and - // the end of it was a token we were looking for, - // then recover here - for (int ii = 0; ii < recoverAfterCount; ++ii) + // If we happened to find a matched pair of tokens, and + // the end of it was a token we were looking for, + // then recover here + for (int ii = 0; ii < recoverAfterCount; ++ii) + { + if (skipped == recoverAfter[ii]) { - if (skipped == recoverAfter[ii]) - { - parser->isRecovering = false; - return true; - } + parser->isRecovering = false; + return true; } } } + } - static bool TryRecoverBefore( - Parser* parser, - TokenType before0) + static bool TryRecoverBefore( + Parser* parser, + TokenType before0) + { + TokenType recoverBefore[] = { before0 }; + return TryRecover(parser, recoverBefore, 1, nullptr, 0); + } + + // Default recovery strategy, to use inside `{}`-delimeted blocks. + static bool TryRecover( + Parser* parser) + { + TokenType recoverBefore[] = { TokenType::RBrace }; + TokenType recoverAfter[] = { TokenType::Semicolon }; + return TryRecover(parser, recoverBefore, 1, recoverAfter, 1); + } + + Token Parser::ReadToken(TokenType expected) + { + if (tokenReader.PeekTokenType() == expected) { - TokenType recoverBefore[] = { before0 }; - return TryRecover(parser, recoverBefore, 1, nullptr, 0); + isRecovering = false; + return tokenReader.AdvanceToken(); } - // Default recovery strategy, to use inside `{}`-delimeted blocks. - static bool TryRecover( - Parser* parser) + if (!isRecovering) { - TokenType recoverBefore[] = { TokenType::RBrace }; - TokenType recoverAfter[] = { TokenType::Semicolon }; - return TryRecover(parser, recoverBefore, 1, recoverAfter, 1); + Unexpected(this, expected); + return tokenReader.PeekToken(); } - - Token Parser::ReadToken(TokenType expected) + else { - if (tokenReader.PeekTokenType() == expected) + // Try to find a place to recover + if (TryRecoverBefore(this, expected)) { isRecovering = false; return tokenReader.AdvanceToken(); } - if (!isRecovering) - { - Unexpected(this, expected); - return tokenReader.PeekToken(); - } - else - { - // Try to find a place to recover - if (TryRecoverBefore(this, expected)) - { - isRecovering = false; - return tokenReader.AdvanceToken(); - } - - return tokenReader.PeekToken(); - } + return tokenReader.PeekToken(); } + } - bool Parser::LookAheadToken(const char * string, int offset) - { - TokenReader r = tokenReader; - for (int ii = 0; ii < offset; ++ii) - r.AdvanceToken(); + bool Parser::LookAheadToken(const char * string, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); - return r.PeekTokenType() == TokenType::Identifier - && r.PeekToken().Content == string; - } + return r.PeekTokenType() == TokenType::Identifier + && r.PeekToken().Content == string; +} - bool Parser::LookAheadToken(TokenType type, int offset) - { - TokenReader r = tokenReader; - for (int ii = 0; ii < offset; ++ii) - r.AdvanceToken(); + bool Parser::LookAheadToken(TokenType type, int offset) + { + TokenReader r = tokenReader; + for (int ii = 0; ii < offset; ++ii) + r.AdvanceToken(); - return r.PeekTokenType() == type; - } + return r.PeekTokenType() == type; + } - // Consume a token and return true it if matches, otherwise false - bool AdvanceIf(Parser* parser, TokenType tokenType) + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, TokenType tokenType) + { + if (parser->LookAheadToken(tokenType)) { - if (parser->LookAheadToken(tokenType)) - { - parser->ReadToken(); - return true; - } - return false; + parser->ReadToken(); + return true; } + return false; + } - // Consume a token and return true it if matches, otherwise false - bool AdvanceIf(Parser* parser, char const* text) + // Consume a token and return true it if matches, otherwise false + bool AdvanceIf(Parser* parser, char const* text) + { + if (parser->LookAheadToken(text)) { - if (parser->LookAheadToken(text)) - { - parser->ReadToken(); - return true; - } - return false; + parser->ReadToken(); + return true; } + return false; + } - // Consume a token and return true if it matches, otherwise check - // for end-of-file and expect that token (potentially producing - // an error) and return true to maintain forward progress. - // Otherwise return false. - bool AdvanceIfMatch(Parser* parser, TokenType tokenType) + // Consume a token and return true if it matches, otherwise check + // for end-of-file and expect that token (potentially producing + // an error) and return true to maintain forward progress. + // Otherwise return false. + bool AdvanceIfMatch(Parser* parser, TokenType tokenType) + { + // If we've run into a syntax error, but haven't recovered inside + // the block, then try to recover here. + if (parser->isRecovering) { - // If we've run into a syntax error, but haven't recovered inside - // the block, then try to recover here. - if (parser->isRecovering) - { - TryRecoverBefore(parser, tokenType); - } - if (AdvanceIf(parser, tokenType)) - return true; - if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile) - { - parser->ReadToken(tokenType); - return true; - } - return false; + TryRecoverBefore(parser, tokenType); } - - RefPtr<ProgramSyntaxNode> Parser::Parse() + if (AdvanceIf(parser, tokenType)) + return true; + if (parser->tokenReader.PeekTokenType() == TokenType::EndOfFile) { - return ParseProgram(); + parser->ReadToken(tokenType); + return true; } + return false; + } - RefPtr<TypeDefDecl> ParseTypeDef(Parser* parser) - { - // Consume the `typedef` keyword - parser->ReadToken("typedef"); + RefPtr<ProgramSyntaxNode> Parser::Parse() + { + return ParseProgram(); + } - // TODO(tfoley): parse an actual declarator - auto type = parser->ParseTypeExp(); + RefPtr<TypeDefDecl> ParseTypeDef(Parser* parser) + { + // Consume the `typedef` keyword + parser->ReadToken("typedef"); - auto nameToken = parser->ReadToken(TokenType::Identifier); + // TODO(tfoley): parse an actual declarator + auto type = parser->ParseTypeExp(); - RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); - typeDefDecl->Name = nameToken; - typeDefDecl->Type = type; + auto nameToken = parser->ReadToken(TokenType::Identifier); - return typeDefDecl; - } + RefPtr<TypeDefDecl> typeDefDecl = new TypeDefDecl(); + typeDefDecl->Name = nameToken; + typeDefDecl->Type = type; - // Add a modifier to a list of modifiers being built - static void AddModifier(RefPtr<Modifier>** ioModifierLink, RefPtr<Modifier> modifier) - { - RefPtr<Modifier>*& modifierLink = *ioModifierLink; + return typeDefDecl; + } - while(*modifierLink) - modifierLink = &(*modifierLink)->next; + // Add a modifier to a list of modifiers being built + static void AddModifier(RefPtr<Modifier>** ioModifierLink, RefPtr<Modifier> modifier) + { + RefPtr<Modifier>*& modifierLink = *ioModifierLink; - *modifierLink = modifier; - modifierLink = &modifier->next; - } + while(*modifierLink) + modifierLink = &(*modifierLink)->next; - void addModifier( - RefPtr<ModifiableSyntaxNode> syntax, - RefPtr<Modifier> modifier) - { - auto modifierLink = &syntax->modifiers.first; - AddModifier(&modifierLink, modifier); - } + *modifierLink = modifier; + modifierLink = &modifier->next; + } + + void addModifier( + RefPtr<ModifiableSyntaxNode> syntax, + RefPtr<Modifier> modifier) + { + auto modifierLink = &syntax->modifiers.first; + AddModifier(&modifierLink, modifier); + } - // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers - static void ParseSquareBracketAttributes(Parser* parser, RefPtr<Modifier>** ioModifierLink) + // Parse HLSL-style `[name(arg, ...)]` style "attribute" modifiers + static void ParseSquareBracketAttributes(Parser* parser, RefPtr<Modifier>** ioModifierLink) + { + parser->ReadToken(TokenType::LBracket); + for(;;) { - parser->ReadToken(TokenType::LBracket); - for(;;) + auto nameToken = parser->ReadToken(TokenType::Identifier); + RefPtr<HLSLUncheckedAttribute> modifier = new HLSLUncheckedAttribute(); + modifier->nameToken = nameToken; + + if (AdvanceIf(parser, TokenType::LParent)) { - auto nameToken = parser->ReadToken(TokenType::Identifier); - RefPtr<HLSLUncheckedAttribute> modifier = new HLSLUncheckedAttribute(); - modifier->nameToken = nameToken; + // HLSL-style `[name(arg0, ...)]` attribute - if (AdvanceIf(parser, TokenType::LParent)) + while (!AdvanceIfMatch(parser, TokenType::RParent)) { - // HLSL-style `[name(arg0, ...)]` attribute - - while (!AdvanceIfMatch(parser, TokenType::RParent)) + auto arg = parser->ParseArgExpr(); + if (arg) { - auto arg = parser->ParseArgExpr(); - if (arg) - { - modifier->args.Add(arg); - } + modifier->args.Add(arg); + } - if (AdvanceIfMatch(parser, TokenType::RParent)) - break; + if (AdvanceIfMatch(parser, TokenType::RParent)) + break; - parser->ReadToken(TokenType::Comma); - } + parser->ReadToken(TokenType::Comma); } - AddModifier(ioModifierLink, modifier); + } + AddModifier(ioModifierLink, modifier); - if (AdvanceIfMatch(parser, TokenType::RBracket)) - break; + if (AdvanceIfMatch(parser, TokenType::RBracket)) + break; - parser->ReadToken(TokenType::Comma); - } + parser->ReadToken(TokenType::Comma); } + } - static Modifiers ParseModifiers(Parser* parser) - { - Modifiers modifiers; - RefPtr<Modifier>* modifierLink = &modifiers.first; - for (;;) - { - CodePosition loc = parser->tokenReader.PeekLoc(); - - if (0) {} - - #define CASE(KEYWORD, TYPE) \ - else if(AdvanceIf(parser, #KEYWORD)) do { \ - RefPtr<TYPE> modifier = new TYPE(); \ - modifier->Position = loc; \ - AddModifier(&modifierLink, modifier); \ - } while(0) - - CASE(in, InModifier); - CASE(input, InputModifier); - CASE(out, OutModifier); - CASE(inout, InOutModifier); - CASE(const, ConstModifier); - CASE(instance, InstanceModifier); - CASE(__builtin, BuiltinModifier); - - CASE(inline, InlineModifier); - CASE(public, PublicModifier); - CASE(require, RequireModifier); - CASE(param, ParamModifier); - CASE(extern, ExternModifier); - - CASE(row_major, HLSLRowMajorLayoutModifier); - CASE(column_major, HLSLColumnMajorLayoutModifier); - - CASE(nointerpolation, HLSLNoInterpolationModifier); - CASE(linear, HLSLLinearModifier); - CASE(sample, HLSLSampleModifier); - CASE(centroid, HLSLCentroidModifier); - CASE(precise, HLSLPreciseModifier); - CASE(shared, HLSLEffectSharedModifier); - CASE(groupshared, HLSLGroupSharedModifier); - CASE(static, HLSLStaticModifier); - CASE(uniform, HLSLUniformModifier); - CASE(volatile, HLSLVolatileModifier); - - // Modifiers for geometry shader input - CASE(point, HLSLPointModifier); - CASE(line, HLSLLineModifier); - CASE(triangle, HLSLTriangleModifier); - CASE(lineadj, HLSLLineAdjModifier); - CASE(triangleadj, HLSLTriangleAdjModifier); - - // Modifiers for unary operator declarations - CASE(__prefix, PrefixModifier); - CASE(__postfix, PostfixModifier); - - #undef CASE - - else if (AdvanceIf(parser, "__intrinsic_op")) + static Modifiers ParseModifiers(Parser* parser) + { + Modifiers modifiers; + RefPtr<Modifier>* modifierLink = &modifiers.first; + for (;;) + { + CodePosition loc = parser->tokenReader.PeekLoc(); + + if (0) {} + + #define CASE(KEYWORD, TYPE) \ + else if(AdvanceIf(parser, #KEYWORD)) do { \ + RefPtr<TYPE> modifier = new TYPE(); \ + modifier->Position = loc; \ + AddModifier(&modifierLink, modifier); \ + } while(0) + + CASE(in, InModifier); + CASE(input, InputModifier); + CASE(out, OutModifier); + CASE(inout, InOutModifier); + CASE(const, ConstModifier); + CASE(instance, InstanceModifier); + CASE(__builtin, BuiltinModifier); + + CASE(inline, InlineModifier); + CASE(public, PublicModifier); + CASE(require, RequireModifier); + CASE(param, ParamModifier); + CASE(extern, ExternModifier); + + CASE(row_major, HLSLRowMajorLayoutModifier); + CASE(column_major, HLSLColumnMajorLayoutModifier); + + CASE(nointerpolation, HLSLNoInterpolationModifier); + CASE(linear, HLSLLinearModifier); + CASE(sample, HLSLSampleModifier); + CASE(centroid, HLSLCentroidModifier); + CASE(precise, HLSLPreciseModifier); + CASE(shared, HLSLEffectSharedModifier); + CASE(groupshared, HLSLGroupSharedModifier); + CASE(static, HLSLStaticModifier); + CASE(uniform, HLSLUniformModifier); + CASE(volatile, HLSLVolatileModifier); + + // Modifiers for geometry shader input + CASE(point, HLSLPointModifier); + CASE(line, HLSLLineModifier); + CASE(triangle, HLSLTriangleModifier); + CASE(lineadj, HLSLLineAdjModifier); + CASE(triangleadj, HLSLTriangleAdjModifier); + + // Modifiers for unary operator declarations + CASE(__prefix, PrefixModifier); + CASE(__postfix, PostfixModifier); + + #undef CASE + + else if (AdvanceIf(parser, "__intrinsic_op")) + { + auto modifier = new IntrinsicOpModifier(); + modifier->Position = loc; + + parser->ReadToken(TokenType::LParent); + if (parser->LookAheadToken(TokenType::IntLiterial)) + { + modifier->op = (IntrinsicOp)StringToInt(parser->ReadToken().Content); + } + else { - auto modifier = new IntrinsicOpModifier(); - modifier->Position = loc; + modifier->opToken = parser->ReadToken(TokenType::Identifier); + + modifier->op = findIntrinsicOp(modifier->opToken.Content.Buffer()); - parser->ReadToken(TokenType::LParent); - if (parser->LookAheadToken(TokenType::IntLiterial)) + if (modifier->op == IntrinsicOp::Unknown) { - modifier->op = (IntrinsicOp)StringToInt(parser->ReadToken().Content); + parser->sink->diagnose(loc, Diagnostics::unimplemented, "unknown intrinsic op"); } - else - { - modifier->opToken = parser->ReadToken(TokenType::Identifier); - - modifier->op = findIntrinsicOp(modifier->opToken.Content.Buffer()); + } - if (modifier->op == IntrinsicOp::Unknown) - { - parser->sink->diagnose(loc, Diagnostics::unimplemented, "unknown intrinsic op"); - } - } + parser->ReadToken(TokenType::RParent); - parser->ReadToken(TokenType::RParent); + AddModifier(&modifierLink, modifier); + } - AddModifier(&modifierLink, modifier); - } + else if (AdvanceIf(parser, "__intrinsic")) + { + auto modifier = new TargetIntrinsicModifier(); + modifier->Position = loc; - else if (AdvanceIf(parser, "__intrinsic")) + if (AdvanceIf(parser, TokenType::LParent)) { - auto modifier = new TargetIntrinsicModifier(); - modifier->Position = loc; + modifier->targetToken = parser->ReadToken(TokenType::Identifier); - if (AdvanceIf(parser, TokenType::LParent)) + if( AdvanceIf(parser, TokenType::Comma) ) { - modifier->targetToken = parser->ReadToken(TokenType::Identifier); - - if( AdvanceIf(parser, TokenType::Comma) ) + if( parser->LookAheadToken(TokenType::StringLiterial) ) { - if( parser->LookAheadToken(TokenType::StringLiterial) ) - { - modifier->definitionToken = parser->ReadToken(); - } - else - { - modifier->definitionToken = parser->ReadToken(TokenType::Identifier); - } + modifier->definitionToken = parser->ReadToken(); + } + else + { + modifier->definitionToken = parser->ReadToken(TokenType::Identifier); } - - parser->ReadToken(TokenType::RParent); } - AddModifier(&modifierLink, modifier); + parser->ReadToken(TokenType::RParent); } + AddModifier(&modifierLink, modifier); + } + - else if (AdvanceIf(parser, "layout")) + else if (AdvanceIf(parser, "layout")) + { + parser->ReadToken(TokenType::LParent); + while (!AdvanceIfMatch(parser, TokenType::RParent)) { - parser->ReadToken(TokenType::LParent); - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - auto nameToken = parser->ReadToken(TokenType::Identifier); + auto nameToken = parser->ReadToken(TokenType::Identifier); - RefPtr<GLSLLayoutModifier> modifier; + RefPtr<GLSLLayoutModifier> modifier; - // TODO: better handling of this choise (e.g., lookup in scope) - if(0) {} - #define CASE(KEYWORD, CLASS) \ - else if(nameToken.Content == #KEYWORD) modifier = new CLASS() + // TODO: better handling of this choise (e.g., lookup in scope) + if(0) {} + #define CASE(KEYWORD, CLASS) \ + else if(nameToken.Content == #KEYWORD) modifier = new CLASS() - CASE(constant_id, GLSLConstantIDLayoutModifier); - CASE(binding, GLSLBindingLayoutModifier); - CASE(set, GLSLSetLayoutModifier); - CASE(location, GLSLLocationLayoutModifier); + CASE(constant_id, GLSLConstantIDLayoutModifier); + CASE(binding, GLSLBindingLayoutModifier); + CASE(set, GLSLSetLayoutModifier); + CASE(location, GLSLLocationLayoutModifier); - #undef CASE - else - { - modifier = new GLSLUnparsedLayoutModifier(); - } + #undef CASE + else + { + modifier = new GLSLUnparsedLayoutModifier(); + } - modifier->nameToken = nameToken; + modifier->nameToken = nameToken; - if(AdvanceIf(parser, TokenType::OpAssign)) - { - modifier->valToken = parser->ReadToken(TokenType::IntLiterial); - } + if(AdvanceIf(parser, TokenType::OpAssign)) + { + modifier->valToken = parser->ReadToken(TokenType::IntLiterial); + } - AddModifier(&modifierLink, modifier); + AddModifier(&modifierLink, modifier); - if (AdvanceIf(parser, TokenType::RParent)) - break; - parser->ReadToken(TokenType::Comma); - } + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); } - else if (parser->tokenReader.PeekTokenType() == TokenType::LBracket) + } + else if (parser->tokenReader.PeekTokenType() == TokenType::LBracket) + { + ParseSquareBracketAttributes(parser, &modifierLink); + } + else if (AdvanceIf(parser,"__builtin_type")) + { + RefPtr<BuiltinTypeModifier> modifier = new BuiltinTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); + parser->ReadToken(TokenType::RParent); + + AddModifier(&modifierLink, modifier); + } + else if (AdvanceIf(parser,"__magic_type")) + { + RefPtr<MagicTypeModifier> modifier = new MagicTypeModifier(); + parser->ReadToken(TokenType::LParent); + modifier->name = parser->ReadToken(TokenType::Identifier).Content; + if (AdvanceIf(parser, TokenType::Comma)) { - ParseSquareBracketAttributes(parser, &modifierLink); + modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); } - else if (AdvanceIf(parser,"__builtin_type")) - { - RefPtr<BuiltinTypeModifier> modifier = new BuiltinTypeModifier(); - parser->ReadToken(TokenType::LParent); - modifier->tag = BaseType(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); - parser->ReadToken(TokenType::RParent); + parser->ReadToken(TokenType::RParent); - AddModifier(&modifierLink, modifier); - } - else if (AdvanceIf(parser,"__magic_type")) - { - RefPtr<MagicTypeModifier> modifier = new MagicTypeModifier(); - parser->ReadToken(TokenType::LParent); - modifier->name = parser->ReadToken(TokenType::Identifier).Content; - if (AdvanceIf(parser, TokenType::Comma)) - { - modifier->tag = uint32_t(StringToInt(parser->ReadToken(TokenType::IntLiterial).Content)); - } - parser->ReadToken(TokenType::RParent); + AddModifier(&modifierLink, modifier); + } + else + { + // Fallback case if none of the above explicit cases matched. - AddModifier(&modifierLink, modifier); - } - else + // If we are looking at an identifier, then it may map to a + // modifier declaration visible in the current scope + if( parser->LookAheadToken(TokenType::Identifier) ) { - // Fallback case if none of the above explicit cases matched. + LookupResult lookupResult = LookUp( + parser->tokenReader.PeekToken().Content, + parser->currentScope); - // If we are looking at an identifier, then it may map to a - // modifier declaration visible in the current scope - if( parser->LookAheadToken(TokenType::Identifier) ) + if( lookupResult.isValid() && !lookupResult.isOverloaded() ) { - LookupResult lookupResult = LookUp( - parser->tokenReader.PeekToken().Content, - parser->currentScope); + auto& item = lookupResult.item; + auto decl = item.declRef.GetDecl(); - if( lookupResult.isValid() && !lookupResult.isOverloaded() ) + if( auto modifierDecl = dynamic_cast<ModifierDecl*>(decl) ) { - auto& item = lookupResult.item; - auto decl = item.declRef.GetDecl(); + // We found a declaration for some modifier syntax, + // so lets create an instance of the type it names + // here. - if( auto modifierDecl = dynamic_cast<ModifierDecl*>(decl) ) + auto syntax = createInstanceOfSyntaxClassByName(modifierDecl->classNameToken.Content); + auto modifier = dynamic_cast<Modifier*>(syntax); + + if( modifier ) + { + modifier->Position = parser->tokenReader.PeekLoc(); + modifier->nameToken = parser->ReadToken(TokenType::Identifier); + + AddModifier(&modifierLink, modifier); + continue; + } + else { - // We found a declaration for some modifier syntax, - // so lets create an instance of the type it names - // here. - - auto syntax = createInstanceOfSyntaxClassByName(modifierDecl->classNameToken.Content); - auto modifier = dynamic_cast<Modifier*>(syntax); - - if( modifier ) - { - modifier->Position = parser->tokenReader.PeekLoc(); - modifier->nameToken = parser->ReadToken(TokenType::Identifier); - - AddModifier(&modifierLink, modifier); - continue; - } - else - { - parser->ReadToken(TokenType::Identifier); - assert(!"unexpected"); - } + parser->ReadToken(TokenType::Identifier); + assert(!"unexpected"); } } } - - // Done with modifier list - return modifiers; } + + // Done with modifier list + return modifiers; } } + } - static RefPtr<Decl> parseImportDecl( - Parser* parser) - { - parser->ReadToken("__import"); + static RefPtr<Decl> parseImportDecl( + Parser* parser) + { + parser->ReadToken("__import"); - auto decl = new ImportDecl(); - decl->nameToken = parser->ReadToken(TokenType::Identifier); - decl->scope = parser->currentScope; + auto decl = new ImportDecl(); + decl->nameToken = parser->ReadToken(TokenType::Identifier); + decl->scope = parser->currentScope; - parser->ReadToken(TokenType::Semicolon); + parser->ReadToken(TokenType::Semicolon); - return decl; - } + return decl; + } - static Token ParseDeclName( - Parser* parser) + static Token ParseDeclName( + Parser* parser) + { + Token name; + if (AdvanceIf(parser, "operator")) { - Token name; - if (AdvanceIf(parser, "operator")) + name = parser->ReadToken(); + switch (name.Type) { - name = parser->ReadToken(); - switch (name.Type) - { - case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv: - case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh: - case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq: - case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd: - case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec: - case TokenType::OpAddAssign: - case TokenType::OpSubAssign: - case TokenType::OpMulAssign: - case TokenType::OpDivAssign: - case TokenType::OpModAssign: - case TokenType::OpShlAssign: - case TokenType::OpShrAssign: - case TokenType::OpOrAssign: - case TokenType::OpAndAssign: - case TokenType::OpXorAssign: - - // Note(tfoley): A bit of a hack: - case TokenType::Comma: - case TokenType::OpAssign: - break; + case TokenType::OpAdd: case TokenType::OpSub: case TokenType::OpMul: case TokenType::OpDiv: + case TokenType::OpMod: case TokenType::OpNot: case TokenType::OpBitNot: case TokenType::OpLsh: case TokenType::OpRsh: + case TokenType::OpEql: case TokenType::OpNeq: case TokenType::OpGreater: case TokenType::OpLess: case TokenType::OpGeq: + case TokenType::OpLeq: case TokenType::OpAnd: case TokenType::OpOr: case TokenType::OpBitXor: case TokenType::OpBitAnd: + case TokenType::OpBitOr: case TokenType::OpInc: case TokenType::OpDec: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: - // Note(tfoley): Even more of a hack! - case TokenType::QuestionMark: - if (AdvanceIf(parser, TokenType::Colon)) - { - name.Content = name.Content + ":"; - break; - } + // Note(tfoley): A bit of a hack: + case TokenType::Comma: + case TokenType::OpAssign: + break; - default: - parser->sink->diagnose(name.Position, Diagnostics::invalidOperator, name.Content); + // Note(tfoley): Even more of a hack! + case TokenType::QuestionMark: + if (AdvanceIf(parser, TokenType::Colon)) + { + name.Content = name.Content + ":"; break; } + + default: + parser->sink->diagnose(name.Position, Diagnostics::invalidOperator, name.Content); + break; } - else - { - name = parser->ReadToken(TokenType::Identifier); - } - return name; } - - // A "declarator" as used in C-style languages - struct Declarator : RefObject + else { - // Different cases of declarator appear as "flavors" here - enum class Flavor - { - Name, - Pointer, - Array, - }; - Flavor flavor; - }; + name = parser->ReadToken(TokenType::Identifier); + } + return name; + } - // The most common case of declarator uses a simple name - struct NameDeclarator : Declarator + // A "declarator" as used in C-style languages + struct Declarator : RefObject + { + // Different cases of declarator appear as "flavors" here + enum class Flavor { - Token nameToken; + Name, + Pointer, + Array, }; + Flavor flavor; + }; - // A declarator that declares a pointer type - struct PointerDeclarator : Declarator - { - // location of the `*` token - CodePosition starLoc; + // The most common case of declarator uses a simple name + struct NameDeclarator : Declarator + { + Token nameToken; + }; - RefPtr<Declarator> inner; - }; + // A declarator that declares a pointer type + struct PointerDeclarator : Declarator + { + // location of the `*` token + CodePosition starLoc; - // A declarator that declares an array type - struct ArrayDeclarator : Declarator - { - RefPtr<Declarator> inner; + RefPtr<Declarator> inner; + }; - // location of the `[` token - CodePosition openBracketLoc; + // A declarator that declares an array type + struct ArrayDeclarator : Declarator + { + RefPtr<Declarator> inner; - // The expression that yields the element count, or NULL - RefPtr<ExpressionSyntaxNode> elementCountExpr; - }; + // location of the `[` token + CodePosition openBracketLoc; - // "Unwrapped" information about a declarator - struct DeclaratorInfo - { - RefPtr<ExpressionSyntaxNode> typeSpec; - Token nameToken; - RefPtr<Modifier> semantics; - RefPtr<ExpressionSyntaxNode> initializer; - }; + // The expression that yields the element count, or NULL + RefPtr<ExpressionSyntaxNode> elementCountExpr; + }; - // Add a member declaration to its container, and ensure that its - // parent link is set up correctly. - static void AddMember(RefPtr<ContainerDecl> container, RefPtr<Decl> member) + // "Unwrapped" information about a declarator + struct DeclaratorInfo + { + RefPtr<ExpressionSyntaxNode> typeSpec; + Token nameToken; + RefPtr<Modifier> semantics; + RefPtr<ExpressionSyntaxNode> initializer; + }; + + // Add a member declaration to its container, and ensure that its + // parent link is set up correctly. + static void AddMember(RefPtr<ContainerDecl> container, RefPtr<Decl> member) + { + if (container) { - if (container) - { - member->ParentDecl = container.Ptr(); - container->Members.Add(member); + member->ParentDecl = container.Ptr(); + container->Members.Add(member); - container->memberDictionaryIsValid = false; - } + container->memberDictionaryIsValid = false; } + } - static void AddMember(RefPtr<Scope> scope, RefPtr<Decl> member) + static void AddMember(RefPtr<Scope> scope, RefPtr<Decl> member) + { + if (scope) { - if (scope) - { - AddMember(scope->containerDecl, member); - } + AddMember(scope->containerDecl, member); } + } - static void parseParameterList( - Parser* parser, - RefPtr<CallableDecl> decl) - { - parser->ReadToken(TokenType::LParent); - - // Allow a declaration to use the keyword `void` for a parameter list, - // since that was required in ancient C, and continues to be supported - // in a bunc hof its derivatives even if it is a Bad Design Choice - // - // TODO: conditionalize this so we don't keep this around for "pure" - // Slang code - if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) ) - { - parser->ReadToken("void"); - parser->ReadToken(TokenType::RParent); - return; - } + static void parseParameterList( + Parser* parser, + RefPtr<CallableDecl> decl) + { + parser->ReadToken(TokenType::LParent); - while (!AdvanceIfMatch(parser, TokenType::RParent)) - { - AddMember(decl, parser->ParseParameter()); - if (AdvanceIf(parser, TokenType::RParent)) - break; - parser->ReadToken(TokenType::Comma); - } + // Allow a declaration to use the keyword `void` for a parameter list, + // since that was required in ancient C, and continues to be supported + // in a bunc hof its derivatives even if it is a Bad Design Choice + // + // TODO: conditionalize this so we don't keep this around for "pure" + // Slang code + if( parser->LookAheadToken("void") && parser->LookAheadToken(TokenType::RParent, 1) ) + { + parser->ReadToken("void"); + parser->ReadToken(TokenType::RParent); + return; } - static void ParseFuncDeclHeader( - Parser* parser, - DeclaratorInfo const& declaratorInfo, - RefPtr<FunctionSyntaxNode> decl) + while (!AdvanceIfMatch(parser, TokenType::RParent)) { - parser->PushScope(decl.Ptr()); + AddMember(decl, parser->ParseParameter()); + if (AdvanceIf(parser, TokenType::RParent)) + break; + parser->ReadToken(TokenType::Comma); + } + } - parser->FillPosition(decl.Ptr()); - decl->Position = declaratorInfo.nameToken.Position; + static void ParseFuncDeclHeader( + Parser* parser, + DeclaratorInfo const& declaratorInfo, + RefPtr<FunctionSyntaxNode> decl) + { + parser->PushScope(decl.Ptr()); - decl->Name = declaratorInfo.nameToken; - decl->ReturnType = TypeExp(declaratorInfo.typeSpec); - parseParameterList(parser, decl); - ParseOptSemantics(parser, decl.Ptr()); - } + parser->FillPosition(decl.Ptr()); + decl->Position = declaratorInfo.nameToken.Position; - static RefPtr<Decl> ParseFuncDecl( - Parser* parser, - ContainerDecl* /*containerDecl*/, - DeclaratorInfo const& declaratorInfo) - { - RefPtr<FunctionSyntaxNode> decl = new FunctionSyntaxNode(); - ParseFuncDeclHeader(parser, declaratorInfo, decl); + decl->Name = declaratorInfo.nameToken; + decl->ReturnType = TypeExp(declaratorInfo.typeSpec); + parseParameterList(parser, decl); + ParseOptSemantics(parser, decl.Ptr()); + } - if (AdvanceIf(parser, TokenType::Semicolon)) - { - // empty body - } - else - { - decl->Body = parser->ParseBlockStatement(); - } + static RefPtr<Decl> ParseFuncDecl( + Parser* parser, + ContainerDecl* /*containerDecl*/, + DeclaratorInfo const& declaratorInfo) + { + RefPtr<FunctionSyntaxNode> decl = new FunctionSyntaxNode(); + ParseFuncDeclHeader(parser, declaratorInfo, decl); - parser->PopScope(); - return decl; + if (AdvanceIf(parser, TokenType::Semicolon)) + { + // empty body } - - static RefPtr<VarDeclBase> CreateVarDeclForContext( - ContainerDecl* containerDecl ) + else { - if (dynamic_cast<StructSyntaxNode*>(containerDecl) || dynamic_cast<ClassSyntaxNode*>(containerDecl)) - { - return new StructField(); - } - else if (dynamic_cast<CallableDecl*>(containerDecl)) - { - return new ParameterSyntaxNode(); - } - else - { - return new Variable(); - } + decl->Body = parser->ParseBlockStatement(); } - // Add modifiers to the end of the modifier list for a declaration - void AddModifiers(Decl* decl, RefPtr<Modifier> modifiers) - { - if (!modifiers) - return; + parser->PopScope(); + return decl; + } - RefPtr<Modifier>* link = &decl->modifiers.first; - while (*link) - { - link = &(*link)->next; - } - *link = modifiers; + static RefPtr<VarDeclBase> CreateVarDeclForContext( + ContainerDecl* containerDecl ) + { + if (dynamic_cast<StructSyntaxNode*>(containerDecl) || dynamic_cast<ClassSyntaxNode*>(containerDecl)) + { + return new StructField(); + } + else if (dynamic_cast<CallableDecl*>(containerDecl)) + { + return new ParameterSyntaxNode(); } + else + { + return new Variable(); + } + } + // Add modifiers to the end of the modifier list for a declaration + void AddModifiers(Decl* decl, RefPtr<Modifier> modifiers) + { + if (!modifiers) + return; - static String GenerateName(Parser* /*parser*/, String const& base) + RefPtr<Modifier>* link = &decl->modifiers.first; + while (*link) { - // TODO: somehow mangle the name to avoid clashes - return base; + link = &(*link)->next; } + *link = modifiers; + } + + + static String GenerateName(Parser* /*parser*/, String const& base) + { + // TODO: somehow mangle the name to avoid clashes + return base; + } + + static String GenerateName(Parser* parser) + { + return GenerateName(parser, "_anonymous_" + String(parser->anonymousCounter++)); + } + + + // Set up a variable declaration based on what we saw in its declarator... + static void CompleteVarDecl( + Parser* parser, + RefPtr<VarDeclBase> decl, + DeclaratorInfo const& declaratorInfo) + { + parser->FillPosition(decl.Ptr()); - static String GenerateName(Parser* parser) + if( declaratorInfo.nameToken.Type == TokenType::Unknown ) { - return GenerateName(parser, "_anonymous_" + String(parser->anonymousCounter++)); + // HACK(tfoley): we always give a name, even if the declarator didn't include one... :( + decl->Name.Content = GenerateName(parser); } + else + { + decl->Position = declaratorInfo.nameToken.Position; + decl->Name = declaratorInfo.nameToken; + } + decl->Type = TypeExp(declaratorInfo.typeSpec); + AddModifiers(decl.Ptr(), declaratorInfo.semantics); - // Set up a variable declaration based on what we saw in its declarator... - static void CompleteVarDecl( - Parser* parser, - RefPtr<VarDeclBase> decl, - DeclaratorInfo const& declaratorInfo) - { - parser->FillPosition(decl.Ptr()); + decl->Expr = declaratorInfo.initializer; + } + + static RefPtr<Declarator> ParseDeclarator(Parser* parser); - if( declaratorInfo.nameToken.Type == TokenType::Unknown ) + static RefPtr<Declarator> ParseDirectAbstractDeclarator( + Parser* parser) + { + RefPtr<Declarator> declarator; + switch( parser->tokenReader.PeekTokenType() ) + { + case TokenType::Identifier: { - // HACK(tfoley): we always give a name, even if the declarator didn't include one... :( - decl->Name.Content = GenerateName(parser); + auto nameDeclarator = new NameDeclarator(); + nameDeclarator->flavor = Declarator::Flavor::Name; + nameDeclarator->nameToken = ParseDeclName(parser); + declarator = nameDeclarator; } - else + break; + + case TokenType::LParent: { - decl->Position = declaratorInfo.nameToken.Position; - decl->Name = declaratorInfo.nameToken; + // Note(tfoley): This is a point where disambiguation is required. + // We could be looking at an abstract declarator for a function-type + // parameter: + // + // void F( int(int) ); + // + // Or we could be looking at the use of parenthesese in an ordinary + // declarator: + // + // void (*f)(int); + // + // The difference really doesn't matter right now, but we err in + // the direction of assuming the second case. + parser->ReadToken(TokenType::LParent); + declarator = ParseDeclarator(parser); + parser->ReadToken(TokenType::RParent); } - decl->Type = TypeExp(declaratorInfo.typeSpec); - - AddModifiers(decl.Ptr(), declaratorInfo.semantics); + break; - decl->Expr = declaratorInfo.initializer; + default: + // an empty declarator is allowed + return nullptr; } - static RefPtr<Declarator> ParseDeclarator(Parser* parser); - - static RefPtr<Declarator> ParseDirectAbstractDeclarator( - Parser* parser) + // postifx additions + for( ;;) { - RefPtr<Declarator> declarator; switch( parser->tokenReader.PeekTokenType() ) { - case TokenType::Identifier: + case TokenType::LBracket: { - auto nameDeclarator = new NameDeclarator(); - nameDeclarator->flavor = Declarator::Flavor::Name; - nameDeclarator->nameToken = ParseDeclName(parser); - declarator = nameDeclarator; + auto arrayDeclarator = new ArrayDeclarator(); + arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc(); + arrayDeclarator->flavor = Declarator::Flavor::Array; + arrayDeclarator->inner = declarator; + + parser->ReadToken(TokenType::LBracket); + if( parser->tokenReader.PeekTokenType() != TokenType::RBracket ) + { + arrayDeclarator->elementCountExpr = parser->ParseExpression(); + } + parser->ReadToken(TokenType::RBracket); + + declarator = arrayDeclarator; + continue; } - break; case TokenType::LParent: - { - // Note(tfoley): This is a point where disambiguation is required. - // We could be looking at an abstract declarator for a function-type - // parameter: - // - // void F( int(int) ); - // - // Or we could be looking at the use of parenthesese in an ordinary - // declarator: - // - // void (*f)(int); - // - // The difference really doesn't matter right now, but we err in - // the direction of assuming the second case. - parser->ReadToken(TokenType::LParent); - declarator = ParseDeclarator(parser); - parser->ReadToken(TokenType::RParent); - } break; default: - // an empty declarator is allowed - return nullptr; + break; } - // postifx additions - for( ;;) + break; + } + + return declarator; + } + + // Parse a declarator (or at least as much of one as we support) + static RefPtr<Declarator> ParseDeclarator( + Parser* parser) + { + if( parser->tokenReader.PeekTokenType() == TokenType::OpMul ) + { + auto ptrDeclarator = new PointerDeclarator(); + ptrDeclarator->starLoc = parser->tokenReader.PeekLoc(); + ptrDeclarator->flavor = Declarator::Flavor::Pointer; + + parser->ReadToken(TokenType::OpMul); + + // TODO(tfoley): allow qualifiers like `const` here? + + ptrDeclarator->inner = ParseDeclarator(parser); + return ptrDeclarator; + } + else + { + return ParseDirectAbstractDeclarator(parser); + } + } + + // A declarator plus optional semantics and initializer + struct InitDeclarator + { + RefPtr<Declarator> declarator; + RefPtr<Modifier> semantics; + RefPtr<ExpressionSyntaxNode> initializer; + }; + + // Parse a declarator plus optional semantics + static InitDeclarator ParseSemanticDeclarator( + Parser* parser) + { + InitDeclarator result; + result.declarator = ParseDeclarator(parser); + result.semantics = ParseOptSemantics(parser); + return result; + } + + // Parse a declarator plus optional semantics and initializer + static InitDeclarator ParseInitDeclarator( + Parser* parser) + { + InitDeclarator result = ParseSemanticDeclarator(parser); + if (AdvanceIf(parser, TokenType::OpAssign)) + { + result.initializer = parser->ParseInitExpr(); + } + return result; + } + + static void UnwrapDeclarator( + RefPtr<Declarator> declarator, + DeclaratorInfo* ioInfo) + { + while( declarator ) + { + switch(declarator->flavor) { - switch( parser->tokenReader.PeekTokenType() ) + case Declarator::Flavor::Name: { - case TokenType::LBracket: - { - auto arrayDeclarator = new ArrayDeclarator(); - arrayDeclarator->openBracketLoc = parser->tokenReader.PeekLoc(); - arrayDeclarator->flavor = Declarator::Flavor::Array; - arrayDeclarator->inner = declarator; + auto nameDeclarator = (NameDeclarator*) declarator.Ptr(); + ioInfo->nameToken = nameDeclarator->nameToken; + return; + } + break; - parser->ReadToken(TokenType::LBracket); - if( parser->tokenReader.PeekTokenType() != TokenType::RBracket ) - { - arrayDeclarator->elementCountExpr = parser->ParseExpression(); - } - parser->ReadToken(TokenType::RBracket); + case Declarator::Flavor::Pointer: + { + auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr(); - declarator = arrayDeclarator; - continue; - } + // TODO(tfoley): we don't support pointers for now + // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec); - case TokenType::LParent: - break; + declarator = ptrDeclarator->inner; + } + break; - default: - break; + case Declarator::Flavor::Array: + { + // TODO(tfoley): we don't support pointers for now + auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); + + auto arrayTypeExpr = new IndexExpressionSyntaxNode(); + arrayTypeExpr->Position = arrayDeclarator->openBracketLoc; + arrayTypeExpr->BaseExpression = ioInfo->typeSpec; + arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr; + ioInfo->typeSpec = arrayTypeExpr; + + declarator = arrayDeclarator->inner; } + break; + default: + SLANG_UNREACHABLE("all cases handled"); break; } - - return declarator; } + } - // Parse a declarator (or at least as much of one as we support) - static RefPtr<Declarator> ParseDeclarator( - Parser* parser) - { - if( parser->tokenReader.PeekTokenType() == TokenType::OpMul ) - { - auto ptrDeclarator = new PointerDeclarator(); - ptrDeclarator->starLoc = parser->tokenReader.PeekLoc(); - ptrDeclarator->flavor = Declarator::Flavor::Pointer; + static void UnwrapDeclarator( + InitDeclarator const& initDeclarator, + DeclaratorInfo* ioInfo) + { + UnwrapDeclarator(initDeclarator.declarator, ioInfo); + ioInfo->semantics = initDeclarator.semantics; + ioInfo->initializer = initDeclarator.initializer; + } - parser->ReadToken(TokenType::OpMul); + // Either a single declaration, or a group of them + struct DeclGroupBuilder + { + CodePosition startPosition; + RefPtr<Decl> decl; + RefPtr<DeclGroup> group; - // TODO(tfoley): allow qualifiers like `const` here? + // Add a new declaration to the potential group + void addDecl( + RefPtr<Decl> newDecl) + { + assert(newDecl); + + if( decl ) + { + group = new DeclGroup(); + group->Position = startPosition; + group->decls.Add(decl); + decl = nullptr; + } - ptrDeclarator->inner = ParseDeclarator(parser); - return ptrDeclarator; + if( group ) + { + group->decls.Add(newDecl); } else { - return ParseDirectAbstractDeclarator(parser); + decl = newDecl; } } - // A declarator plus optional semantics and initializer - struct InitDeclarator + RefPtr<DeclBase> getResult() { - RefPtr<Declarator> declarator; - RefPtr<Modifier> semantics; - RefPtr<ExpressionSyntaxNode> initializer; - }; + if(group) return group; + return decl; + } + }; + + // Pares an argument to an application of a generic + RefPtr<ExpressionSyntaxNode> ParseGenericArg(Parser* parser) + { + return parser->ParseArgExpr(); + } + + // Create a type expression that will refer to the given declaration + static RefPtr<ExpressionSyntaxNode> + createDeclRefType(Parser* parser, RefPtr<Decl> decl) + { + // For now we just construct an expression that + // will look up the given declaration by name. + // + // TODO: do this better, e.g. by filling in the `declRef` field directly + + auto expr = new VarExpressionSyntaxNode(); + expr->scope = parser->currentScope.Ptr(); + expr->Position = decl->getNameToken().Position; + expr->Variable = decl->getName(); + return expr; + } - // Parse a declarator plus optional semantics - static InitDeclarator ParseSemanticDeclarator( - Parser* parser) + // Representation for a parsed type specifier, which might + // include a declaration (e.g., of a `struct` type) + struct TypeSpec + { + // If the type-spec declared something, then put it here + RefPtr<Decl> decl; + + // Put the resulting expression (which should evaluate to a type) here + RefPtr<ExpressionSyntaxNode> expr; + }; + + static TypeSpec + parseTypeSpec(Parser* parser) + { + TypeSpec typeSpec; + + // We may see a `struct` type specified here, and need to act accordingly + // + // TODO(tfoley): Handle the case where the user is just using `struct` + // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`) + // + if( parser->LookAheadToken("struct") ) { - InitDeclarator result; - result.declarator = ParseDeclarator(parser); - result.semantics = ParseOptSemantics(parser); - return result; + auto decl = parser->ParseStruct(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; + } + else if( parser->LookAheadToken("class") ) + { + auto decl = parser->ParseClass(); + typeSpec.decl = decl; + typeSpec.expr = createDeclRefType(parser, decl); + return typeSpec; } - // Parse a declarator plus optional semantics and initializer - static InitDeclarator ParseInitDeclarator( - Parser* parser) + Token typeName = parser->ReadToken(TokenType::Identifier); + + auto basicType = new VarExpressionSyntaxNode(); + basicType->scope = parser->currentScope.Ptr(); + basicType->Position = typeName.Position; + basicType->Variable = typeName.Content; + + RefPtr<ExpressionSyntaxNode> typeExpr = basicType; + + if (parser->LookAheadToken(TokenType::OpLess)) { - InitDeclarator result = ParseSemanticDeclarator(parser); - if (AdvanceIf(parser, TokenType::OpAssign)) + RefPtr<GenericAppExpr> gtype = new GenericAppExpr(); + parser->FillPosition(gtype.Ptr()); // set up scope for lookup + gtype->Position = typeName.Position; + gtype->FunctionExpr = typeExpr; + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + // For now assume all generics have at least one argument + gtype->Arguments.Add(ParseGenericArg(parser)); + while (AdvanceIf(parser, TokenType::Comma)) { - result.initializer = parser->ParseInitExpr(); + gtype->Arguments.Add(ParseGenericArg(parser)); } - return result; + parser->genericDepth--; + parser->ReadToken(TokenType::OpGreater); + typeExpr = gtype; } - static void UnwrapDeclarator( - RefPtr<Declarator> declarator, - DeclaratorInfo* ioInfo) - { - while( declarator ) - { - switch(declarator->flavor) - { - case Declarator::Flavor::Name: - { - auto nameDeclarator = (NameDeclarator*) declarator.Ptr(); - ioInfo->nameToken = nameDeclarator->nameToken; - return; - } - break; + typeSpec.expr = typeExpr; + return typeSpec; + } - case Declarator::Flavor::Pointer: - { - auto ptrDeclarator = (PointerDeclarator*) declarator.Ptr(); - // TODO(tfoley): we don't support pointers for now - // ioInfo->typeSpec = new PointerTypeExpr(ioInfo->typeSpec); + static RefPtr<DeclBase> ParseDeclaratorDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + CodePosition startPosition = parser->tokenReader.PeekLoc(); - declarator = ptrDeclarator->inner; - } - break; + auto typeSpec = parseTypeSpec(parser); - case Declarator::Flavor::Array: - { - // TODO(tfoley): we don't support pointers for now - auto arrayDeclarator = (ArrayDeclarator*) declarator.Ptr(); + // We may need to build up multiple declarations in a group, + // but the common case will be when we have just a single + // declaration + DeclGroupBuilder declGroupBuilder; + declGroupBuilder.startPosition = startPosition; - auto arrayTypeExpr = new IndexExpressionSyntaxNode(); - arrayTypeExpr->Position = arrayDeclarator->openBracketLoc; - arrayTypeExpr->BaseExpression = ioInfo->typeSpec; - arrayTypeExpr->IndexExpression = arrayDeclarator->elementCountExpr; - ioInfo->typeSpec = arrayTypeExpr; + // The type specifier may include a declaration. E.g., + // it might declare a `struct` type. + if(typeSpec.decl) + declGroupBuilder.addDecl(typeSpec.decl); - declarator = arrayDeclarator->inner; - } - break; + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // No actual variable is being declared here, but + // that might not be an error. - default: - SLANG_UNREACHABLE("all cases handled"); - break; - } + auto result = declGroupBuilder.getResult(); + if( !result ) + { + parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything); } + return result; } - static void UnwrapDeclarator( - InitDeclarator const& initDeclarator, - DeclaratorInfo* ioInfo) - { - UnwrapDeclarator(initDeclarator.declarator, ioInfo); - ioInfo->semantics = initDeclarator.semantics; - ioInfo->initializer = initDeclarator.initializer; - } - // Either a single declaration, or a group of them - struct DeclGroupBuilder - { - CodePosition startPosition; - RefPtr<Decl> decl; - RefPtr<DeclGroup> group; + InitDeclarator initDeclarator = ParseInitDeclarator(parser); - // Add a new declaration to the potential group - void addDecl( - RefPtr<Decl> newDecl) - { - assert(newDecl); - - if( decl ) - { - group = new DeclGroup(); - group->Position = startPosition; - group->decls.Add(decl); - decl = nullptr; - } + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = typeSpec.expr; - if( group ) - { - group->decls.Add(newDecl); - } - else - { - decl = newDecl; - } - } - RefPtr<DeclBase> getResult() - { - if(group) return group; - return decl; - } - }; + // Rather than parse function declarators properly for now, + // we'll just do a quick disambiguation here. This won't + // matter unless we actually decide to support function-type parameters, + // using C syntax. + // + if( parser->tokenReader.PeekTokenType() == TokenType::LParent - // Pares an argument to an application of a generic - RefPtr<ExpressionSyntaxNode> ParseGenericArg(Parser* parser) + // Only parse as a function if we didn't already see mutually-exclusive + // constructs when parsing the declarator. + && !initDeclarator.initializer + && !initDeclarator.semantics) { - return parser->ParseArgExpr(); + // Looks like a function, so parse it like one. + UnwrapDeclarator(initDeclarator, &declaratorInfo); + return ParseFuncDecl(parser, containerDecl, declaratorInfo); } - // Create a type expression that will refer to the given declaration - static RefPtr<ExpressionSyntaxNode> - createDeclRefType(Parser* parser, RefPtr<Decl> decl) + // Otherwise we are looking at a variable declaration, which could be one in a sequence... + + if( AdvanceIf(parser, TokenType::Semicolon) ) { - // For now we just construct an expression that - // will look up the given declaration by name. - // - // TODO: do this better, e.g. by filling in the `declRef` field directly + // easy case: we only had a single declaration! + UnwrapDeclarator(initDeclarator, &declaratorInfo); + RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, firstDecl, declaratorInfo); - auto expr = new VarExpressionSyntaxNode(); - expr->scope = parser->currentScope.Ptr(); - expr->Position = decl->getNameToken().Position; - expr->Variable = decl->getName(); - return expr; + declGroupBuilder.addDecl(firstDecl); + return declGroupBuilder.getResult(); } - // Representation for a parsed type specifier, which might - // include a declaration (e.g., of a `struct` type) - struct TypeSpec - { - // If the type-spec declared something, then put it here - RefPtr<Decl> decl; + // Otherwise we have multiple declarations in a sequence, and these + // declarations need to somehow share both the type spec and modifiers. + // + // If there are any errors in the type specifier, we only want to hear + // about it once, so we need to share structure rather than just + // clone syntax. - // Put the resulting expression (which should evaluate to a type) here - RefPtr<ExpressionSyntaxNode> expr; - }; + auto sharedTypeSpec = new SharedTypeExpr(); + sharedTypeSpec->Position = typeSpec.expr->Position; + sharedTypeSpec->base = TypeExp(typeSpec.expr); - static TypeSpec - parseTypeSpec(Parser* parser) + for(;;) { - TypeSpec typeSpec; + declaratorInfo.typeSpec = sharedTypeSpec; + UnwrapDeclarator(initDeclarator, &declaratorInfo); - // We may see a `struct` type specified here, and need to act accordingly - // - // TODO(tfoley): Handle the case where the user is just using `struct` - // as a way to name an existing struct "tag" (e.g., `struct Foo foo;`) - // - if( parser->LookAheadToken("struct") ) - { - auto decl = parser->ParseStruct(); - typeSpec.decl = decl; - typeSpec.expr = createDeclRefType(parser, decl); - return typeSpec; - } - else if( parser->LookAheadToken("class") ) - { - auto decl = parser->ParseClass(); - typeSpec.decl = decl; - typeSpec.expr = createDeclRefType(parser, decl); - return typeSpec; - } + RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(containerDecl); + CompleteVarDecl(parser, varDecl, declaratorInfo); - Token typeName = parser->ReadToken(TokenType::Identifier); + declGroupBuilder.addDecl(varDecl); + + // end of the sequence? + if(AdvanceIf(parser, TokenType::Semicolon)) + return declGroupBuilder.getResult(); - auto basicType = new VarExpressionSyntaxNode(); - basicType->scope = parser->currentScope.Ptr(); - basicType->Position = typeName.Position; - basicType->Variable = typeName.Content; + // ad-hoc recovery, to avoid infinite loops + if( parser->isRecovering ) + { + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); + } - RefPtr<ExpressionSyntaxNode> typeExpr = basicType; + // Let's default to assuming that a missing `,` + // indicates the end of a declaration, + // where a `;` would be expected, and not + // a continuation of this declaration, where + // a `,` would be expected (this is tailoring + // the diagnostic message a bit). + // + // TODO: a more advanced heuristic here might + // look at whether the next token is on the + // same line, to predict whether `,` or `;` + // would be more likely... - if (parser->LookAheadToken(TokenType::OpLess)) + if (!AdvanceIf(parser, TokenType::Comma)) { - RefPtr<GenericAppExpr> gtype = new GenericAppExpr(); - parser->FillPosition(gtype.Ptr()); // set up scope for lookup - gtype->Position = typeName.Position; - gtype->FunctionExpr = typeExpr; - parser->ReadToken(TokenType::OpLess); - parser->genericDepth++; - // For now assume all generics have at least one argument - gtype->Arguments.Add(ParseGenericArg(parser)); - while (AdvanceIf(parser, TokenType::Comma)) - { - gtype->Arguments.Add(ParseGenericArg(parser)); - } - parser->genericDepth--; - parser->ReadToken(TokenType::OpGreater); - typeExpr = gtype; + parser->ReadToken(TokenType::Semicolon); + return declGroupBuilder.getResult(); } - typeSpec.expr = typeExpr; - return typeSpec; + // expect another variable declaration... + initDeclarator = ParseInitDeclarator(parser); } + } + // + // layout-semantic ::= (register | packoffset) '(' register-name component-mask? ')' + // register-name ::= identifier + // component-mask ::= '.' identifier + // + static void ParseHLSLLayoutSemantic( + Parser* parser, + HLSLLayoutSemantic* semantic) + { + semantic->name = parser->ReadToken(TokenType::Identifier); - static RefPtr<DeclBase> ParseDeclaratorDecl( - Parser* parser, - ContainerDecl* containerDecl) + parser->ReadToken(TokenType::LParent); + semantic->registerName = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Dot)) { - CodePosition startPosition = parser->tokenReader.PeekLoc(); + semantic->componentMask = parser->ReadToken(TokenType::Identifier); + } + parser->ReadToken(TokenType::RParent); + } - auto typeSpec = parseTypeSpec(parser); + // + // semantic ::= identifier ( '(' args ')' )? + // + static RefPtr<Modifier> ParseSemantic( + Parser* parser) + { + if (parser->LookAheadToken("register")) + { + RefPtr<HLSLRegisterSemantic> semantic = new HLSLRegisterSemantic(); + ParseHLSLLayoutSemantic(parser, semantic.Ptr()); + return semantic; + } + else if (parser->LookAheadToken("packoffset")) + { + RefPtr<HLSLPackOffsetSemantic> semantic = new HLSLPackOffsetSemantic(); + ParseHLSLLayoutSemantic(parser, semantic.Ptr()); + return semantic; + } + else + { + RefPtr<HLSLSimpleSemantic> semantic = new HLSLSimpleSemantic(); + semantic->name = parser->ReadToken(TokenType::Identifier); + return semantic; + } + } - // We may need to build up multiple declarations in a group, - // but the common case will be when we have just a single - // declaration - DeclGroupBuilder declGroupBuilder; - declGroupBuilder.startPosition = startPosition; + // + // opt-semantics ::= (':' semantic)* + // + static RefPtr<Modifier> ParseOptSemantics( + Parser* parser) + { + if (!AdvanceIf(parser, TokenType::Colon)) + return nullptr; - // The type specifier may include a declaration. E.g., - // it might declare a `struct` type. - if(typeSpec.decl) - declGroupBuilder.addDecl(typeSpec.decl); + RefPtr<Modifier> result; + RefPtr<Modifier>* link = &result; + assert(!*link); - if( AdvanceIf(parser, TokenType::Semicolon) ) + for (;;) + { + RefPtr<Modifier> semantic = ParseSemantic(parser); + if (semantic) { - // No actual variable is being declared here, but - // that might not be an error. + *link = semantic; + link = &semantic->next; + } - auto result = declGroupBuilder.getResult(); - if( !result ) - { - parser->sink->diagnose(startPosition, Diagnostics::declarationDidntDeclareAnything); - } + switch (parser->tokenReader.PeekTokenType()) + { + case TokenType::LBrace: + case TokenType::Semicolon: + case TokenType::Comma: + case TokenType::RParent: + case TokenType::EndOfFile: return result; + + default: + break; } + parser->ReadToken(TokenType::Colon); + } - InitDeclarator initDeclarator = ParseInitDeclarator(parser); + } - DeclaratorInfo declaratorInfo; - declaratorInfo.typeSpec = typeSpec.expr; + static void ParseOptSemantics( + Parser* parser, + Decl* decl) + { + AddModifiers(decl, ParseOptSemantics(parser)); + } - // Rather than parse function declarators properly for now, - // we'll just do a quick disambiguation here. This won't - // matter unless we actually decide to support function-type parameters, - // using C syntax. - // - if( parser->tokenReader.PeekTokenType() == TokenType::LParent + static RefPtr<Decl> ParseHLSLBufferDecl( + Parser* parser) + { + // An HLSL declaration of a constant buffer like this: + // + // cbuffer Foo : register(b0) { int a; float b; }; + // + // is treated as syntax sugar for a type declaration + // and then a global variable declaration using that type: + // + // struct $anonymous { int a; float b; }; + // ConstantBuffer<$anonymous> Foo; + // + // where `$anonymous` is a fresh name, and the variable + // declaration is made to be "transparent" so that lookup + // will see through it to the members inside. - // Only parse as a function if we didn't already see mutually-exclusive - // constructs when parsing the declarator. - && !initDeclarator.initializer - && !initDeclarator.semantics) - { - // Looks like a function, so parse it like one. - UnwrapDeclarator(initDeclarator, &declaratorInfo); - return ParseFuncDecl(parser, containerDecl, declaratorInfo); - } + // We first look at the declaration keywrod to determine + // the type of buffer to declare: + String bufferWrapperTypeName; + CodePosition bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc(); + if (AdvanceIf(parser, "cbuffer")) + { + bufferWrapperTypeName = "ConstantBuffer"; + } + else if (AdvanceIf(parser, "tbuffer")) + { + bufferWrapperTypeName = "TextureBuffer"; + } + else + { + Unexpected(parser); + } - // Otherwise we are looking at a variable declaration, which could be one in a sequence... + // We are going to represent each buffer as a pair of declarations. + // The first is a type declaration that holds all the members, while + // the second is a variable declaration that uses the buffer type. + RefPtr<StructSyntaxNode> bufferDataTypeDecl = new StructSyntaxNode(); + RefPtr<Variable> bufferVarDecl = new Variable(); - if( AdvanceIf(parser, TokenType::Semicolon) ) - { - // easy case: we only had a single declaration! - UnwrapDeclarator(initDeclarator, &declaratorInfo); - RefPtr<VarDeclBase> firstDecl = CreateVarDeclForContext(containerDecl); - CompleteVarDecl(parser, firstDecl, declaratorInfo); + // Both declarations will have a location that points to the name + parser->FillPosition(bufferDataTypeDecl.Ptr()); + parser->FillPosition(bufferVarDecl.Ptr()); - declGroupBuilder.addDecl(firstDecl); - return declGroupBuilder.getResult(); - } + auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); - // Otherwise we have multiple declarations in a sequence, and these - // declarations need to somehow share both the type spec and modifiers. - // - // If there are any errors in the type specifier, we only want to hear - // about it once, so we need to share structure rather than just - // clone syntax. + // Attach the reflection name to the block so we can use it + auto reflectionNameModifier = new ParameterBlockReflectionName(); + reflectionNameModifier->nameToken = reflectionNameToken; + addModifier(bufferVarDecl, reflectionNameModifier); - auto sharedTypeSpec = new SharedTypeExpr(); - sharedTypeSpec->Position = typeSpec.expr->Position; - sharedTypeSpec->base = TypeExp(typeSpec.expr); + // Both the buffer variable and its type need to have names generated + bufferVarDecl->Name.Content = GenerateName(parser, "SLANG_constantBuffer_" + reflectionNameToken.Content); + bufferDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ConstantBuffer_" + reflectionNameToken.Content); - for(;;) - { - declaratorInfo.typeSpec = sharedTypeSpec; - UnwrapDeclarator(initDeclarator, &declaratorInfo); + addModifier(bufferDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); + addModifier(bufferVarDecl, new ImplicitParameterBlockVariableModifier()); - RefPtr<VarDeclBase> varDecl = CreateVarDeclForContext(containerDecl); - CompleteVarDecl(parser, varDecl, declaratorInfo); + // TODO(tfoley): We end up constructing unchecked syntax here that + // is expected to type check into the right form, but it might be + // cleaner to have a more explicit desugaring pass where we parse + // these constructs directly into the AST and *then* desugar them. - declGroupBuilder.addDecl(varDecl); + // Construct a type expression to reference the buffer data type + auto bufferDataTypeExpr = new VarExpressionSyntaxNode(); + bufferDataTypeExpr->Position = bufferDataTypeDecl->Position; + bufferDataTypeExpr->Variable = bufferDataTypeDecl->Name.Content; + bufferDataTypeExpr->scope = parser->currentScope.Ptr(); - // end of the sequence? - if(AdvanceIf(parser, TokenType::Semicolon)) - return declGroupBuilder.getResult(); + // Construct a type exrpession to reference the type constructor + auto bufferWrapperTypeExpr = new VarExpressionSyntaxNode(); + bufferWrapperTypeExpr->Position = bufferWrapperTypeNamePos; + bufferWrapperTypeExpr->Variable = bufferWrapperTypeName; - // ad-hoc recovery, to avoid infinite loops - if( parser->isRecovering ) - { - parser->ReadToken(TokenType::Semicolon); - return declGroupBuilder.getResult(); - } + // Always need to look this up in the outer scope, + // so that it won't collide with, e.g., a local variable called `ConstantBuffer` + bufferWrapperTypeExpr->scope = parser->outerScope; - // Let's default to assuming that a missing `,` - // indicates the end of a declaration, - // where a `;` would be expected, and not - // a continuation of this declaration, where - // a `,` would be expected (this is tailoring - // the diagnostic message a bit). - // - // TODO: a more advanced heuristic here might - // look at whether the next token is on the - // same line, to predict whether `,` or `;` - // would be more likely... + // Construct a type expression that represents the type for the variable, + // which is the wrapper type applied to the data type + auto bufferVarTypeExpr = new GenericAppExpr(); + bufferVarTypeExpr->Position = bufferVarDecl->Position; + bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr; + bufferVarTypeExpr->Arguments.Add(bufferDataTypeExpr); - if (!AdvanceIf(parser, TokenType::Comma)) - { - parser->ReadToken(TokenType::Semicolon); - return declGroupBuilder.getResult(); - } + bufferVarDecl->Type.exp = bufferVarTypeExpr; - // expect another variable declaration... - initDeclarator = ParseInitDeclarator(parser); - } - } + // Any semantics applied to the bufer declaration are taken as applying + // to the variable instead. + ParseOptSemantics(parser, bufferVarDecl.Ptr()); - // - // layout-semantic ::= (register | packoffset) '(' register-name component-mask? ')' - // register-name ::= identifier - // component-mask ::= '.' identifier - // - static void ParseHLSLLayoutSemantic( - Parser* parser, - HLSLLayoutSemantic* semantic) - { - semantic->name = parser->ReadToken(TokenType::Identifier); + // The declarations in the body belong to the data type. + parseAggTypeDeclBody(parser, bufferDataTypeDecl.Ptr()); - parser->ReadToken(TokenType::LParent); - semantic->registerName = parser->ReadToken(TokenType::Identifier); - if (AdvanceIf(parser, TokenType::Dot)) - { - semantic->componentMask = parser->ReadToken(TokenType::Identifier); - } - parser->ReadToken(TokenType::RParent); - } + // All HLSL buffer declarations are "transparent" in that their + // members are implicitly made visible in the parent scope. + // We achieve this by applying the transparent modifier to the variable. + auto transparentModifier = new TransparentModifier(); + transparentModifier->next = bufferVarDecl->modifiers.first; + bufferVarDecl->modifiers.first = transparentModifier; + // Because we are constructing two declarations, we have a thorny + // issue that were are only supposed to return one. + // For now we handle this by adding the type declaration to + // the current scope manually, and then returning the variable + // declaration. // - // semantic ::= identifier ( '(' args ')' )? - // - static RefPtr<Modifier> ParseSemantic( - Parser* parser) + // Note: this means that any modifiers that have already been parsed + // will get attached to the variable declaration, not the type. + // There might be cases where we need to shuffle things around. + + AddMember(parser->currentScope, bufferDataTypeDecl); + + return bufferVarDecl; + } + + static void removeModifier( + Modifiers& modifiers, + RefPtr<Modifier> modifier) + { + RefPtr<Modifier>* link = &modifiers.first; + while (*link) { - if (parser->LookAheadToken("register")) - { - RefPtr<HLSLRegisterSemantic> semantic = new HLSLRegisterSemantic(); - ParseHLSLLayoutSemantic(parser, semantic.Ptr()); - return semantic; - } - else if (parser->LookAheadToken("packoffset")) + if (*link == modifier) { - RefPtr<HLSLPackOffsetSemantic> semantic = new HLSLPackOffsetSemantic(); - ParseHLSLLayoutSemantic(parser, semantic.Ptr()); - return semantic; - } - else - { - RefPtr<HLSLSimpleSemantic> semantic = new HLSLSimpleSemantic(); - semantic->name = parser->ReadToken(TokenType::Identifier); - return semantic; + *link = (*link)->next; + return; } + + link = &(*link)->next; } + } + static RefPtr<Decl> parseGLSLBlockDecl( + Parser* parser, + Modifiers& modifiers) + { + // An GLSL block like this: // - // opt-semantics ::= (':' semantic)* + // uniform Foo { int a; float b; } foo; // - static RefPtr<Modifier> ParseOptSemantics( - Parser* parser) - { - if (!AdvanceIf(parser, TokenType::Colon)) - return nullptr; - - RefPtr<Modifier> result; - RefPtr<Modifier>* link = &result; - assert(!*link); - - for (;;) - { - RefPtr<Modifier> semantic = ParseSemantic(parser); - if (semantic) - { - *link = semantic; - link = &semantic->next; - } + // is treated as syntax sugar for a type declaration + // and then a global variable declaration using that type: + // + // struct $anonymous { int a; float b; }; + // Block<$anonymous> foo; + // + // where `$anonymous` is a fresh name. + // + // If a "local name" like `foo` is not given, then + // we make the declaration "transparent" so that lookup + // will see through it to the members inside. - switch (parser->tokenReader.PeekTokenType()) - { - case TokenType::LBrace: - case TokenType::Semicolon: - case TokenType::Comma: - case TokenType::RParent: - case TokenType::EndOfFile: - return result; - default: - break; - } + CodePosition pos = parser->tokenReader.PeekLoc(); - parser->ReadToken(TokenType::Colon); - } + // The initial name before the `{` is only supposed + // to be made visible to reflection + auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); + // Look at the qualifiers present on the block to decide what kind + // of block we are looking at. Also *remove* those qualifiers so + // that they don't interfere with downstream work. + String blockWrapperTypeName; + if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() ) + { + removeModifier(modifiers, uniformMod); + blockWrapperTypeName = "ConstantBuffer"; } - - - static void ParseOptSemantics( - Parser* parser, - Decl* decl) + else if( auto inMod = modifiers.findModifier<InModifier>() ) { - AddModifiers(decl, ParseOptSemantics(parser)); + removeModifier(modifiers, inMod); + blockWrapperTypeName = "__GLSLInputParameterBlock"; } - - static RefPtr<Decl> ParseHLSLBufferDecl( - Parser* parser) + else if( auto outMod = modifiers.findModifier<OutModifier>() ) { - // An HLSL declaration of a constant buffer like this: - // - // cbuffer Foo : register(b0) { int a; float b; }; - // - // is treated as syntax sugar for a type declaration - // and then a global variable declaration using that type: - // - // struct $anonymous { int a; float b; }; - // ConstantBuffer<$anonymous> Foo; - // - // where `$anonymous` is a fresh name, and the variable - // declaration is made to be "transparent" so that lookup - // will see through it to the members inside. - - // We first look at the declaration keywrod to determine - // the type of buffer to declare: - String bufferWrapperTypeName; - CodePosition bufferWrapperTypeNamePos = parser->tokenReader.PeekLoc(); - if (AdvanceIf(parser, "cbuffer")) - { - bufferWrapperTypeName = "ConstantBuffer"; - } - else if (AdvanceIf(parser, "tbuffer")) - { - bufferWrapperTypeName = "TextureBuffer"; - } - else - { - Unexpected(parser); - } - - // We are going to represent each buffer as a pair of declarations. - // The first is a type declaration that holds all the members, while - // the second is a variable declaration that uses the buffer type. - RefPtr<StructSyntaxNode> bufferDataTypeDecl = new StructSyntaxNode(); - RefPtr<Variable> bufferVarDecl = new Variable(); - - // Both declarations will have a location that points to the name - parser->FillPosition(bufferDataTypeDecl.Ptr()); - parser->FillPosition(bufferVarDecl.Ptr()); - - auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); - - // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterBlockReflectionName(); - reflectionNameModifier->nameToken = reflectionNameToken; - addModifier(bufferVarDecl, reflectionNameModifier); - - // Both the buffer variable and its type need to have names generated - bufferVarDecl->Name.Content = GenerateName(parser, "SLANG_constantBuffer_" + reflectionNameToken.Content); - bufferDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ConstantBuffer_" + reflectionNameToken.Content); - - addModifier(bufferDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); - addModifier(bufferVarDecl, new ImplicitParameterBlockVariableModifier()); - - // TODO(tfoley): We end up constructing unchecked syntax here that - // is expected to type check into the right form, but it might be - // cleaner to have a more explicit desugaring pass where we parse - // these constructs directly into the AST and *then* desugar them. - - // Construct a type expression to reference the buffer data type - auto bufferDataTypeExpr = new VarExpressionSyntaxNode(); - bufferDataTypeExpr->Position = bufferDataTypeDecl->Position; - bufferDataTypeExpr->Variable = bufferDataTypeDecl->Name.Content; - bufferDataTypeExpr->scope = parser->currentScope.Ptr(); - - // Construct a type exrpession to reference the type constructor - auto bufferWrapperTypeExpr = new VarExpressionSyntaxNode(); - bufferWrapperTypeExpr->Position = bufferWrapperTypeNamePos; - bufferWrapperTypeExpr->Variable = bufferWrapperTypeName; - - // Always need to look this up in the outer scope, - // so that it won't collide with, e.g., a local variable called `ConstantBuffer` - bufferWrapperTypeExpr->scope = parser->outerScope; - - // Construct a type expression that represents the type for the variable, - // which is the wrapper type applied to the data type - auto bufferVarTypeExpr = new GenericAppExpr(); - bufferVarTypeExpr->Position = bufferVarDecl->Position; - bufferVarTypeExpr->FunctionExpr = bufferWrapperTypeExpr; - bufferVarTypeExpr->Arguments.Add(bufferDataTypeExpr); - - bufferVarDecl->Type.exp = bufferVarTypeExpr; - - // Any semantics applied to the bufer declaration are taken as applying - // to the variable instead. - ParseOptSemantics(parser, bufferVarDecl.Ptr()); - - // The declarations in the body belong to the data type. - parseAggTypeDeclBody(parser, bufferDataTypeDecl.Ptr()); - - // All HLSL buffer declarations are "transparent" in that their - // members are implicitly made visible in the parent scope. - // We achieve this by applying the transparent modifier to the variable. - auto transparentModifier = new TransparentModifier(); - transparentModifier->next = bufferVarDecl->modifiers.first; - bufferVarDecl->modifiers.first = transparentModifier; - - // Because we are constructing two declarations, we have a thorny - // issue that were are only supposed to return one. - // For now we handle this by adding the type declaration to - // the current scope manually, and then returning the variable - // declaration. - // - // Note: this means that any modifiers that have already been parsed - // will get attached to the variable declaration, not the type. - // There might be cases where we need to shuffle things around. - - AddMember(parser->currentScope, bufferDataTypeDecl); - - return bufferVarDecl; + removeModifier(modifiers, outMod); + blockWrapperTypeName = "__GLSLOutputParameterBlock"; } - - static void removeModifier( - Modifiers& modifiers, - RefPtr<Modifier> modifier) + else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() ) { - RefPtr<Modifier>* link = &modifiers.first; - while (*link) - { - if (*link == modifier) - { - *link = (*link)->next; - return; - } - - link = &(*link)->next; - } + removeModifier(modifiers, bufferMod); + blockWrapperTypeName = "__GLSLShaderStorageBuffer"; } - - static RefPtr<Decl> parseGLSLBlockDecl( - Parser* parser, - Modifiers& modifiers) + else { - // An GLSL block like this: - // - // uniform Foo { int a; float b; } foo; - // - // is treated as syntax sugar for a type declaration - // and then a global variable declaration using that type: - // - // struct $anonymous { int a; float b; }; - // Block<$anonymous> foo; - // - // where `$anonymous` is a fresh name. - // - // If a "local name" like `foo` is not given, then - // we make the declaration "transparent" so that lookup - // will see through it to the members inside. + // Unknown case: just map to a constant buffer and hope for the best + blockWrapperTypeName = "ConstantBuffer"; + } + // We are going to represent each buffer as a pair of declarations. + // The first is a type declaration that holds all the members, while + // the second is a variable declaration that uses the buffer type. + RefPtr<StructSyntaxNode> blockDataTypeDecl = new StructSyntaxNode(); + RefPtr<Variable> blockVarDecl = new Variable(); - CodePosition pos = parser->tokenReader.PeekLoc(); + addModifier(blockDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); + addModifier(blockVarDecl, new ImplicitParameterBlockVariableModifier()); - // The initial name before the `{` is only supposed - // to be made visible to reflection - auto reflectionNameToken = parser->ReadToken(TokenType::Identifier); + // Attach the reflection name to the block so we can use it + auto reflectionNameModifier = new ParameterBlockReflectionName(); + reflectionNameModifier->nameToken = reflectionNameToken; + addModifier(blockVarDecl, reflectionNameModifier); - // Look at the qualifiers present on the block to decide what kind - // of block we are looking at. Also *remove* those qualifiers so - // that they don't interfere with downstream work. - String blockWrapperTypeName; - if( auto uniformMod = modifiers.findModifier<HLSLUniformModifier>() ) - { - removeModifier(modifiers, uniformMod); - blockWrapperTypeName = "ConstantBuffer"; - } - else if( auto inMod = modifiers.findModifier<InModifier>() ) - { - removeModifier(modifiers, inMod); - blockWrapperTypeName = "__GLSLInputParameterBlock"; - } - else if( auto outMod = modifiers.findModifier<OutModifier>() ) - { - removeModifier(modifiers, outMod); - blockWrapperTypeName = "__GLSLOutputParameterBlock"; - } - else if( auto bufferMod = modifiers.findModifier<GLSLBufferModifier>() ) - { - removeModifier(modifiers, bufferMod); - blockWrapperTypeName = "__GLSLShaderStorageBuffer"; - } - else - { - // Unknown case: just map to a constant buffer and hope for the best - blockWrapperTypeName = "ConstantBuffer"; - } + // Both declarations will have a location that points to the name + parser->FillPosition(blockDataTypeDecl.Ptr()); + parser->FillPosition(blockVarDecl.Ptr()); - // We are going to represent each buffer as a pair of declarations. - // The first is a type declaration that holds all the members, while - // the second is a variable declaration that uses the buffer type. - RefPtr<StructSyntaxNode> blockDataTypeDecl = new StructSyntaxNode(); - RefPtr<Variable> blockVarDecl = new Variable(); - - addModifier(blockDataTypeDecl, new ImplicitParameterBlockElementTypeModifier()); - addModifier(blockVarDecl, new ImplicitParameterBlockVariableModifier()); - - // Attach the reflection name to the block so we can use it - auto reflectionNameModifier = new ParameterBlockReflectionName(); - reflectionNameModifier->nameToken = reflectionNameToken; - addModifier(blockVarDecl, reflectionNameModifier); - - // Both declarations will have a location that points to the name - parser->FillPosition(blockDataTypeDecl.Ptr()); - parser->FillPosition(blockVarDecl.Ptr()); - - // Generate a unique name for the data type - blockDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ParameterBlock_" + reflectionNameToken.Content); - - // TODO(tfoley): We end up constructing unchecked syntax here that - // is expected to type check into the right form, but it might be - // cleaner to have a more explicit desugaring pass where we parse - // these constructs directly into the AST and *then* desugar them. - - // Construct a type expression to reference the buffer data type - auto blockDataTypeExpr = new VarExpressionSyntaxNode(); - blockDataTypeExpr->Position = blockDataTypeDecl->Position; - blockDataTypeExpr->Variable = blockDataTypeDecl->Name.Content; - blockDataTypeExpr->scope = parser->currentScope.Ptr(); - - // Construct a type exrpession to reference the type constructor - auto blockWrapperTypeExpr = new VarExpressionSyntaxNode(); - blockWrapperTypeExpr->Position = pos; - blockWrapperTypeExpr->Variable = blockWrapperTypeName; - // Always need to look this up in the outer scope, - // so that it won't collide with, e.g., a local variable called `ConstantBuffer` - blockWrapperTypeExpr->scope = parser->outerScope; - - // Construct a type expression that represents the type for the variable, - // which is the wrapper type applied to the data type - auto blockVarTypeExpr = new GenericAppExpr(); - blockVarTypeExpr->Position = blockVarDecl->Position; - blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr; - blockVarTypeExpr->Arguments.Add(blockDataTypeExpr); - - blockVarDecl->Type.exp = blockVarTypeExpr; - - // The declarations in the body belong to the data type. - parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr()); - - if( parser->LookAheadToken(TokenType::Identifier) ) - { - // The user gave an explicit name to the block, - // so we need to use that as our variable name - blockVarDecl->Name = parser->ReadToken(TokenType::Identifier); + // Generate a unique name for the data type + blockDataTypeDecl->Name.Content = GenerateName(parser, "SLANG_ParameterBlock_" + reflectionNameToken.Content); - // TODO: in this case we make actually have a more complex - // declarator, including `[]` brackets. - } - else - { - // synthesize a dummy name - blockVarDecl->Name.Content = GenerateName(parser, "SLANG_parameterBlock_" + reflectionNameToken.Content); - - // Otherwise we have a transparent declaration, similar - // to an HLSL `cbuffer` - auto transparentModifier = new TransparentModifier(); - transparentModifier->Position = pos; - addModifier(blockVarDecl, transparentModifier); - } + // TODO(tfoley): We end up constructing unchecked syntax here that + // is expected to type check into the right form, but it might be + // cleaner to have a more explicit desugaring pass where we parse + // these constructs directly into the AST and *then* desugar them. - // Expect a trailing `;` - parser->ReadToken(TokenType::Semicolon); + // Construct a type expression to reference the buffer data type + auto blockDataTypeExpr = new VarExpressionSyntaxNode(); + blockDataTypeExpr->Position = blockDataTypeDecl->Position; + blockDataTypeExpr->Variable = blockDataTypeDecl->Name.Content; + blockDataTypeExpr->scope = parser->currentScope.Ptr(); - // Because we are constructing two declarations, we have a thorny - // issue that were are only supposed to return one. - // For now we handle this by adding the type declaration to - // the current scope manually, and then returning the variable - // declaration. - // - // Note: this means that any modifiers that have already been parsed - // will get attached to the variable declaration, not the type. - // There might be cases where we need to shuffle things around. + // Construct a type exrpession to reference the type constructor + auto blockWrapperTypeExpr = new VarExpressionSyntaxNode(); + blockWrapperTypeExpr->Position = pos; + blockWrapperTypeExpr->Variable = blockWrapperTypeName; + // Always need to look this up in the outer scope, + // so that it won't collide with, e.g., a local variable called `ConstantBuffer` + blockWrapperTypeExpr->scope = parser->outerScope; + + // Construct a type expression that represents the type for the variable, + // which is the wrapper type applied to the data type + auto blockVarTypeExpr = new GenericAppExpr(); + blockVarTypeExpr->Position = blockVarDecl->Position; + blockVarTypeExpr->FunctionExpr = blockWrapperTypeExpr; + blockVarTypeExpr->Arguments.Add(blockDataTypeExpr); + + blockVarDecl->Type.exp = blockVarTypeExpr; + + // The declarations in the body belong to the data type. + parseAggTypeDeclBody(parser, blockDataTypeDecl.Ptr()); + + if( parser->LookAheadToken(TokenType::Identifier) ) + { + // The user gave an explicit name to the block, + // so we need to use that as our variable name + blockVarDecl->Name = parser->ReadToken(TokenType::Identifier); - AddMember(parser->currentScope, blockDataTypeDecl); + // TODO: in this case we make actually have a more complex + // declarator, including `[]` brackets. + } + else + { + // synthesize a dummy name + blockVarDecl->Name.Content = GenerateName(parser, "SLANG_parameterBlock_" + reflectionNameToken.Content); - return blockVarDecl; + // Otherwise we have a transparent declaration, similar + // to an HLSL `cbuffer` + auto transparentModifier = new TransparentModifier(); + transparentModifier->Position = pos; + addModifier(blockVarDecl, transparentModifier); } + // Expect a trailing `;` + parser->ReadToken(TokenType::Semicolon); + + // Because we are constructing two declarations, we have a thorny + // issue that were are only supposed to return one. + // For now we handle this by adding the type declaration to + // the current scope manually, and then returning the variable + // declaration. + // + // Note: this means that any modifiers that have already been parsed + // will get attached to the variable declaration, not the type. + // There might be cases where we need to shuffle things around. + + AddMember(parser->currentScope, blockDataTypeDecl); + + return blockVarDecl; + } + - static RefPtr<Decl> ParseGenericParamDecl( - Parser* parser, - RefPtr<GenericDecl> genericDecl) + static RefPtr<Decl> ParseGenericParamDecl( + Parser* parser, + RefPtr<GenericDecl> genericDecl) + { + // simple syntax to introduce a value parameter + if (AdvanceIf(parser, "let")) { - // simple syntax to introduce a value parameter - if (AdvanceIf(parser, "let")) + // default case is a type parameter + auto paramDecl = new GenericValueParamDecl(); + paramDecl->Name = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Colon)) { - // default case is a type parameter - auto paramDecl = new GenericValueParamDecl(); - paramDecl->Name = parser->ReadToken(TokenType::Identifier); - if (AdvanceIf(parser, TokenType::Colon)) - { - paramDecl->Type = parser->ParseTypeExp(); - } - if (AdvanceIf(parser, TokenType::OpAssign)) - { - paramDecl->Expr = parser->ParseInitExpr(); - } - return paramDecl; + paramDecl->Type = parser->ParseTypeExp(); } - else + if (AdvanceIf(parser, TokenType::OpAssign)) { - // default case is a type parameter - auto paramDecl = new GenericTypeParamDecl(); - parser->FillPosition(paramDecl); - paramDecl->Name = parser->ReadToken(TokenType::Identifier); - if (AdvanceIf(parser, TokenType::Colon)) - { - // The user is apply a constraint to this type parameter... + paramDecl->Expr = parser->ParseInitExpr(); + } + return paramDecl; + } + else + { + // default case is a type parameter + auto paramDecl = new GenericTypeParamDecl(); + parser->FillPosition(paramDecl); + paramDecl->Name = parser->ReadToken(TokenType::Identifier); + if (AdvanceIf(parser, TokenType::Colon)) + { + // The user is apply a constraint to this type parameter... - auto paramConstraint = new GenericTypeConstraintDecl(); - parser->FillPosition(paramConstraint); + auto paramConstraint = new GenericTypeConstraintDecl(); + parser->FillPosition(paramConstraint); - auto paramType = DeclRefType::Create(DeclRef(paramDecl, nullptr)); + auto paramType = DeclRefType::Create(DeclRef(paramDecl, nullptr)); - auto paramTypeExpr = new SharedTypeExpr(); - paramTypeExpr->Position = paramDecl->Position; - paramTypeExpr->base.type = paramType; - paramTypeExpr->Type = new TypeType(paramType); + auto paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->Position = paramDecl->Position; + paramTypeExpr->base.type = paramType; + paramTypeExpr->Type = new TypeType(paramType); - paramConstraint->sub = TypeExp(paramTypeExpr); - paramConstraint->sup = parser->ParseTypeExp(); + paramConstraint->sub = TypeExp(paramTypeExpr); + paramConstraint->sup = parser->ParseTypeExp(); - AddMember(genericDecl, paramConstraint); + AddMember(genericDecl, paramConstraint); - } - if (AdvanceIf(parser, TokenType::OpAssign)) - { - paramDecl->initType = parser->ParseTypeExp(); - } - return paramDecl; } + if (AdvanceIf(parser, TokenType::OpAssign)) + { + paramDecl->initType = parser->ParseTypeExp(); + } + return paramDecl; } + } - static RefPtr<Decl> ParseGenericDecl( - Parser* parser) + static RefPtr<Decl> ParseGenericDecl( + Parser* parser) + { + RefPtr<GenericDecl> decl = new GenericDecl(); + parser->FillPosition(decl.Ptr()); + parser->PushScope(decl.Ptr()); + parser->ReadToken("__generic"); + parser->ReadToken(TokenType::OpLess); + parser->genericDepth++; + while (!parser->LookAheadToken(TokenType::OpGreater)) { - RefPtr<GenericDecl> decl = new GenericDecl(); - parser->FillPosition(decl.Ptr()); - parser->PushScope(decl.Ptr()); - parser->ReadToken("__generic"); - parser->ReadToken(TokenType::OpLess); - parser->genericDepth++; - while (!parser->LookAheadToken(TokenType::OpGreater)) - { - AddMember(decl, ParseGenericParamDecl(parser, decl)); + AddMember(decl, ParseGenericParamDecl(parser, decl)); if( parser->LookAheadToken(TokenType::OpGreater) ) break; parser->ReadToken(TokenType::Comma); - } - parser->genericDepth--; - parser->ReadToken(TokenType::OpGreater); + } + parser->genericDepth--; + parser->ReadToken(TokenType::OpGreater); - decl->inner = ParseSingleDecl(parser, decl.Ptr()); + decl->inner = ParseSingleDecl(parser, decl.Ptr()); - // A generic decl hijacks the name of the declaration - // it wraps, so that lookup can find it. - decl->Name = decl->inner->Name; + // A generic decl hijacks the name of the declaration + // it wraps, so that lookup can find it. + decl->Name = decl->inner->Name; - parser->PopScope(); - return decl; - } + parser->PopScope(); + return decl; + } - static RefPtr<ExtensionDecl> ParseExtensionDecl(Parser* parser) - { - RefPtr<ExtensionDecl> decl = new ExtensionDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("__extension"); - decl->targetType = parser->ParseTypeExp(); + static RefPtr<ExtensionDecl> ParseExtensionDecl(Parser* parser) + { + RefPtr<ExtensionDecl> decl = new ExtensionDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__extension"); + decl->targetType = parser->ParseTypeExp(); - parseAggTypeDeclBody(parser, decl.Ptr()); + parseAggTypeDeclBody(parser, decl.Ptr()); - return decl; - } + return decl; + } - static void parseOptionalInheritanceClause(Parser* parser, AggTypeDecl* decl) + static void parseOptionalInheritanceClause(Parser* parser, AggTypeDecl* decl) + { + if( AdvanceIf(parser, TokenType::Colon) ) { - if( AdvanceIf(parser, TokenType::Colon) ) + do { - do - { - auto base = parser->ParseTypeExp(); + auto base = parser->ParseTypeExp(); - auto inheritanceDecl = new InheritanceDecl(); - inheritanceDecl->Position = base.exp->Position; - inheritanceDecl->base = base; + auto inheritanceDecl = new InheritanceDecl(); + inheritanceDecl->Position = base.exp->Position; + inheritanceDecl->base = base; - AddMember(decl, inheritanceDecl); + AddMember(decl, inheritanceDecl); - } while( AdvanceIf(parser, TokenType::Comma) ); - } + } while( AdvanceIf(parser, TokenType::Comma) ); } + } - static RefPtr<InterfaceDecl> parseInterfaceDecl(Parser* parser) - { - RefPtr<InterfaceDecl> decl = new InterfaceDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("interface"); - decl->Name = parser->ReadToken(TokenType::Identifier); + static RefPtr<InterfaceDecl> parseInterfaceDecl(Parser* parser) + { + RefPtr<InterfaceDecl> decl = new InterfaceDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("interface"); + decl->Name = parser->ReadToken(TokenType::Identifier); - parseOptionalInheritanceClause(parser, decl.Ptr()); + parseOptionalInheritanceClause(parser, decl.Ptr()); - parseAggTypeDeclBody(parser, decl.Ptr()); + parseAggTypeDeclBody(parser, decl.Ptr()); - return decl; - } + return decl; + } - static RefPtr<ConstructorDecl> ParseConstructorDecl(Parser* parser) - { - RefPtr<ConstructorDecl> decl = new ConstructorDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("__init"); + static RefPtr<ConstructorDecl> ParseConstructorDecl(Parser* parser) + { + RefPtr<ConstructorDecl> decl = new ConstructorDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__init"); - parseParameterList(parser, decl); + parseParameterList(parser, decl); - if( AdvanceIf(parser, TokenType::Semicolon) ) - { - // empty body - } - else - { - decl->Body = parser->ParseBlockStatement(); - } - return decl; + if( AdvanceIf(parser, TokenType::Semicolon) ) + { + // empty body } - - static RefPtr<AccessorDecl> parseAccessorDecl(Parser* parser) + else { - RefPtr<AccessorDecl> decl; - if( AdvanceIf(parser, "get") ) - { - decl = new GetterDecl(); - } - else if( AdvanceIf(parser, "set") ) - { - decl = new SetterDecl(); - } - else - { - Unexpected(parser); - return nullptr; - } - - if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) - { - decl->Body = parser->ParseBlockStatement(); - } - else - { - parser->ReadToken(TokenType::Semicolon); - } + decl->Body = parser->ParseBlockStatement(); + } + return decl; + } - return decl; + static RefPtr<AccessorDecl> parseAccessorDecl(Parser* parser) + { + RefPtr<AccessorDecl> decl; + if( AdvanceIf(parser, "get") ) + { + decl = new GetterDecl(); + } + else if( AdvanceIf(parser, "set") ) + { + decl = new SetterDecl(); + } + else + { + Unexpected(parser); + return nullptr; } - static RefPtr<SubscriptDecl> ParseSubscriptDecl(Parser* parser) + if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) { - RefPtr<SubscriptDecl> decl = new SubscriptDecl(); - parser->FillPosition(decl.Ptr()); - parser->ReadToken("__subscript"); + decl->Body = parser->ParseBlockStatement(); + } + else + { + parser->ReadToken(TokenType::Semicolon); + } - // TODO: the use of this name here is a bit magical... - decl->Name.Content = "operator[]"; + return decl; + } - parseParameterList(parser, decl); + static RefPtr<SubscriptDecl> ParseSubscriptDecl(Parser* parser) + { + RefPtr<SubscriptDecl> decl = new SubscriptDecl(); + parser->FillPosition(decl.Ptr()); + parser->ReadToken("__subscript"); - if( AdvanceIf(parser, TokenType::RightArrow) ) - { - decl->ReturnType = parser->ParseTypeExp(); - } + // TODO: the use of this name here is a bit magical... + decl->Name.Content = "operator[]"; - if( AdvanceIf(parser, TokenType::LBrace) ) - { - // We want to parse nested "accessor" declarations - while( !AdvanceIfMatch(parser, TokenType::RBrace) ) - { - auto accessor = parseAccessorDecl(parser); - AddMember(decl, accessor); - } - } - else - { - parser->ReadToken(TokenType::Semicolon); + parseParameterList(parser, decl); - // empty body should be treated like `{ get; }` + if( AdvanceIf(parser, TokenType::RightArrow) ) + { + decl->ReturnType = parser->ParseTypeExp(); + } + + if( AdvanceIf(parser, TokenType::LBrace) ) + { + // We want to parse nested "accessor" declarations + while( !AdvanceIfMatch(parser, TokenType::RBrace) ) + { + auto accessor = parseAccessorDecl(parser); + AddMember(decl, accessor); } + } + else + { + parser->ReadToken(TokenType::Semicolon); - return decl; + // empty body should be treated like `{ get; }` } - // Parse a declaration of a new modifier keyword - static RefPtr<ModifierDecl> parseModifierDecl(Parser* parser) - { - RefPtr<ModifierDecl> decl = new ModifierDecl(); + return decl; + } - // read the `__modifier` keyword - parser->ReadToken(TokenType::Identifier); + // Parse a declaration of a new modifier keyword + static RefPtr<ModifierDecl> parseModifierDecl(Parser* parser) + { + RefPtr<ModifierDecl> decl = new ModifierDecl(); - parser->ReadToken(TokenType::LParent); - decl->classNameToken = parser->ReadToken(TokenType::Identifier); - parser->ReadToken(TokenType::RParent); + // read the `__modifier` keyword + parser->ReadToken(TokenType::Identifier); - parser->FillPosition(decl.Ptr()); - decl->Name = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::LParent); + decl->classNameToken = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::RParent); - parser->ReadToken(TokenType::Semicolon); - return decl; - } + parser->FillPosition(decl.Ptr()); + decl->Name = parser->ReadToken(TokenType::Identifier); - // Finish up work on a declaration that was parsed - static void CompleteDecl( - Parser* /*parser*/, - RefPtr<Decl> decl, - ContainerDecl* containerDecl, - Modifiers modifiers) - { - // Add any modifiers we parsed before the declaration to the list - // of modifiers on the declaration itself. - AddModifiers(decl.Ptr(), modifiers.first); + parser->ReadToken(TokenType::Semicolon); + return decl; + } - // Make sure the decl is properly nested inside its lexical parent - if (containerDecl) - { - AddMember(containerDecl, decl); - } - } + // Finish up work on a declaration that was parsed + static void CompleteDecl( + Parser* /*parser*/, + RefPtr<Decl> decl, + ContainerDecl* containerDecl, + Modifiers modifiers) + { + // Add any modifiers we parsed before the declaration to the list + // of modifiers on the declaration itself. + AddModifiers(decl.Ptr(), modifiers.first); - static RefPtr<DeclBase> ParseDeclWithModifiers( - Parser* parser, - ContainerDecl* containerDecl, - Modifiers modifiers ) - { - RefPtr<DeclBase> decl; - - auto loc = parser->tokenReader.PeekLoc(); - - // TODO: actual dispatch! - if (parser->LookAheadToken("struct")) - decl = ParseDeclaratorDecl(parser, containerDecl); - else if (parser->LookAheadToken("class")) - decl = ParseDeclaratorDecl(parser, containerDecl); - else if (parser->LookAheadToken("typedef")) - decl = ParseTypeDef(parser); - else if (parser->LookAheadToken("cbuffer") || parser->LookAheadToken("tbuffer")) - decl = ParseHLSLBufferDecl(parser); - else if (parser->LookAheadToken("__generic")) - decl = ParseGenericDecl(parser); - else if (parser->LookAheadToken("__extension")) - decl = ParseExtensionDecl(parser); - else if (parser->LookAheadToken("__init")) - decl = ParseConstructorDecl(parser); - else if (parser->LookAheadToken("__subscript")) - decl = ParseSubscriptDecl(parser); - else if (parser->LookAheadToken("interface")) - decl = parseInterfaceDecl(parser); - else if(parser->LookAheadToken("__modifier")) - decl = parseModifierDecl(parser); - else if(parser->LookAheadToken("__import")) - decl = parseImportDecl(parser); - else if (AdvanceIf(parser, TokenType::Semicolon)) - { - decl = new EmptyDecl(); - decl->Position = loc; - } - // GLSL requires that we be able to parse "block" declarations, - // which look superficially similar to declarator declarations - else if( parser->LookAheadToken(TokenType::Identifier) - && parser->LookAheadToken(TokenType::LBrace, 1) ) - { - decl = parseGLSLBlockDecl(parser, modifiers); - } - else - { - // Default case: just parse a declarator-based declaration - decl = ParseDeclaratorDecl(parser, containerDecl); - } + // Make sure the decl is properly nested inside its lexical parent + if (containerDecl) + { + AddMember(containerDecl, decl); + } + } - if (decl) - { - if( auto dd = decl.As<Decl>() ) - { - CompleteDecl(parser, dd, containerDecl, modifiers); - } - else if(auto declGroup = decl.As<DeclGroup>()) - { - // We are going to add the same modifiers to *all* of these declarations, - // so we want to give later passes a way to detect which modifiers - // were shared, vs. which ones are specific to a single declaration. + static RefPtr<DeclBase> ParseDeclWithModifiers( + Parser* parser, + ContainerDecl* containerDecl, + Modifiers modifiers ) + { + RefPtr<DeclBase> decl; - auto sharedModifiers = new SharedModifiers(); - sharedModifiers->next = modifiers.first; - modifiers.first = sharedModifiers; + auto loc = parser->tokenReader.PeekLoc(); - for( auto subDecl : declGroup->decls ) - { - CompleteDecl(parser, subDecl, containerDecl, modifiers); - } - } - } - return decl; + // TODO: actual dispatch! + if (parser->LookAheadToken("struct")) + decl = ParseDeclaratorDecl(parser, containerDecl); + else if (parser->LookAheadToken("class")) + decl = ParseDeclaratorDecl(parser, containerDecl); + else if (parser->LookAheadToken("typedef")) + decl = ParseTypeDef(parser); + else if (parser->LookAheadToken("cbuffer") || parser->LookAheadToken("tbuffer")) + decl = ParseHLSLBufferDecl(parser); + else if (parser->LookAheadToken("__generic")) + decl = ParseGenericDecl(parser); + else if (parser->LookAheadToken("__extension")) + decl = ParseExtensionDecl(parser); + else if (parser->LookAheadToken("__init")) + decl = ParseConstructorDecl(parser); + else if (parser->LookAheadToken("__subscript")) + decl = ParseSubscriptDecl(parser); + else if (parser->LookAheadToken("interface")) + decl = parseInterfaceDecl(parser); + else if(parser->LookAheadToken("__modifier")) + decl = parseModifierDecl(parser); + else if(parser->LookAheadToken("__import")) + decl = parseImportDecl(parser); + else if (AdvanceIf(parser, TokenType::Semicolon)) + { + decl = new EmptyDecl(); + decl->Position = loc; } - - static RefPtr<DeclBase> ParseDecl( - Parser* parser, - ContainerDecl* containerDecl) + // GLSL requires that we be able to parse "block" declarations, + // which look superficially similar to declarator declarations + else if( parser->LookAheadToken(TokenType::Identifier) + && parser->LookAheadToken(TokenType::LBrace, 1) ) + { + decl = parseGLSLBlockDecl(parser, modifiers); + } + else { - Modifiers modifiers = ParseModifiers(parser); - return ParseDeclWithModifiers(parser, containerDecl, modifiers); + // Default case: just parse a declarator-based declaration + decl = ParseDeclaratorDecl(parser, containerDecl); } - static RefPtr<Decl> ParseSingleDecl( - Parser* parser, - ContainerDecl* containerDecl) + if (decl) { - auto declBase = ParseDecl(parser, containerDecl); - if(!declBase) - return nullptr; - if( auto decl = declBase.As<Decl>() ) + if( auto dd = decl.As<Decl>() ) { - return decl; + CompleteDecl(parser, dd, containerDecl, modifiers); } - else if( auto declGroup = declBase.As<DeclGroup>() ) + else if(auto declGroup = decl.As<DeclGroup>()) { - if( declGroup->decls.Count() == 1 ) + // We are going to add the same modifiers to *all* of these declarations, + // so we want to give later passes a way to detect which modifiers + // were shared, vs. which ones are specific to a single declaration. + + auto sharedModifiers = new SharedModifiers(); + sharedModifiers->next = modifiers.first; + modifiers.first = sharedModifiers; + + for( auto subDecl : declGroup->decls ) { - return declGroup->decls[0]; + CompleteDecl(parser, subDecl, containerDecl, modifiers); } } - - parser->sink->diagnose(declBase->Position, Diagnostics::unimplemented, "didn't expect multiple declarations here"); - return nullptr; } + return decl; + } + static RefPtr<DeclBase> ParseDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + Modifiers modifiers = ParseModifiers(parser); + return ParseDeclWithModifiers(parser, containerDecl, modifiers); + } - // Parse a body consisting of declarations - static void ParseDeclBody( - Parser* parser, - ContainerDecl* containerDecl, - TokenType closingToken) + static RefPtr<Decl> ParseSingleDecl( + Parser* parser, + ContainerDecl* containerDecl) + { + auto declBase = ParseDecl(parser, containerDecl); + if(!declBase) + return nullptr; + if( auto decl = declBase.As<Decl>() ) { - while(!AdvanceIfMatch(parser, closingToken)) + return decl; + } + else if( auto declGroup = declBase.As<DeclGroup>() ) + { + if( declGroup->decls.Count() == 1 ) { - ParseDecl(parser, containerDecl); - TryRecover(parser); + return declGroup->decls[0]; } } - // Parse the `{}`-delimeted body of an aggregate type declaration - static void parseAggTypeDeclBody( - Parser* parser, - AggTypeDeclBase* decl) - { - // TODO: the scope used for the body might need to be - // slightly specialized to deal with the complexity - // of how `this` works. - // - // Alternatively, that complexity can be pushed down - // to semantic analysis so that it doesn't clutter - // things here. - parser->PushScope(decl); + parser->sink->diagnose(declBase->Position, Diagnostics::unimplemented, "didn't expect multiple declarations here"); + return nullptr; + } - parser->ReadToken(TokenType::LBrace); - ParseDeclBody(parser, decl, TokenType::RBrace); - parser->PopScope(); + // Parse a body consisting of declarations + static void ParseDeclBody( + Parser* parser, + ContainerDecl* containerDecl, + TokenType closingToken) + { + while(!AdvanceIfMatch(parser, closingToken)) + { + ParseDecl(parser, containerDecl); + TryRecover(parser); } + } + // Parse the `{}`-delimeted body of an aggregate type declaration + static void parseAggTypeDeclBody( + Parser* parser, + AggTypeDeclBase* decl) + { + // TODO: the scope used for the body might need to be + // slightly specialized to deal with the complexity + // of how `this` works. + // + // Alternatively, that complexity can be pushed down + // to semantic analysis so that it doesn't clutter + // things here. + parser->PushScope(decl); - void Parser::parseSourceFile(ProgramSyntaxNode* program) - { - if (outerScope) - { - currentScope = outerScope; - } + parser->ReadToken(TokenType::LBrace); + ParseDeclBody(parser, decl, TokenType::RBrace); - PushScope(program); - program->Position = CodePosition(0, 0, 0, fileName); - ParseDeclBody(this, program, TokenType::EndOfFile); - PopScope(); + parser->PopScope(); + } - assert(currentScope == outerScope); - currentScope = nullptr; - } - RefPtr<ProgramSyntaxNode> Parser::ParseProgram() + void Parser::parseSourceFile(ProgramSyntaxNode* program) + { + if (outerScope) { - RefPtr<ProgramSyntaxNode> program = new ProgramSyntaxNode(); + currentScope = outerScope; + } - parseSourceFile(program.Ptr()); + PushScope(program); + program->Position = CodePosition(0, 0, 0, fileName); + ParseDeclBody(this, program, TokenType::EndOfFile); + PopScope(); - return program; - } + assert(currentScope == outerScope); + currentScope = nullptr; + } - RefPtr<StructSyntaxNode> Parser::ParseStruct() - { - RefPtr<StructSyntaxNode> rs = new StructSyntaxNode(); - FillPosition(rs.Ptr()); - ReadToken("struct"); + RefPtr<ProgramSyntaxNode> Parser::ParseProgram() + { + RefPtr<ProgramSyntaxNode> program = new ProgramSyntaxNode(); - // TODO: support `struct` declaration without tag - rs->Name = ReadToken(TokenType::Identifier); + parseSourceFile(program.Ptr()); - // We allow for an inheritance clause on a `struct` - // so that it can conform to interfaces. - parseOptionalInheritanceClause(this, rs.Ptr()); + return program; + } - parseAggTypeDeclBody(this, rs.Ptr()); + RefPtr<StructSyntaxNode> Parser::ParseStruct() + { + RefPtr<StructSyntaxNode> rs = new StructSyntaxNode(); + FillPosition(rs.Ptr()); + ReadToken("struct"); - return rs; - } + // TODO: support `struct` declaration without tag + rs->Name = ReadToken(TokenType::Identifier); - RefPtr<ClassSyntaxNode> Parser::ParseClass() - { - RefPtr<ClassSyntaxNode> rs = new ClassSyntaxNode(); - FillPosition(rs.Ptr()); - ReadToken("class"); - rs->Name = ReadToken(TokenType::Identifier); - ReadToken(TokenType::LBrace); - parseOptionalInheritanceClause(this, rs.Ptr()); - parseAggTypeDeclBody(this, rs.Ptr()); - return rs; - } + // We allow for an inheritance clause on a `struct` + // so that it can conform to interfaces. + parseOptionalInheritanceClause(this, rs.Ptr()); - static RefPtr<StatementSyntaxNode> ParseSwitchStmt(Parser* parser) - { - RefPtr<SwitchStmt> stmt = new SwitchStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("switch"); - parser->ReadToken(TokenType::LParent); - stmt->condition = parser->ParseExpression(); - parser->ReadToken(TokenType::RParent); - stmt->body = parser->ParseBlockStatement(); - return stmt; - } + parseAggTypeDeclBody(this, rs.Ptr()); - static RefPtr<StatementSyntaxNode> ParseCaseStmt(Parser* parser) - { - RefPtr<CaseStmt> stmt = new CaseStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("case"); - stmt->expr = parser->ParseExpression(); - parser->ReadToken(TokenType::Colon); - return stmt; - } + return rs; + } - static RefPtr<StatementSyntaxNode> ParseDefaultStmt(Parser* parser) - { - RefPtr<DefaultStmt> stmt = new DefaultStmt(); - parser->FillPosition(stmt.Ptr()); - parser->ReadToken("default"); - parser->ReadToken(TokenType::Colon); - return stmt; - } + RefPtr<ClassSyntaxNode> Parser::ParseClass() + { + RefPtr<ClassSyntaxNode> rs = new ClassSyntaxNode(); + FillPosition(rs.Ptr()); + ReadToken("class"); + rs->Name = ReadToken(TokenType::Identifier); + ReadToken(TokenType::LBrace); + parseOptionalInheritanceClause(this, rs.Ptr()); + parseAggTypeDeclBody(this, rs.Ptr()); + return rs; + } - static bool peekTypeName(Parser* parser) - { - if(!parser->LookAheadToken(TokenType::Identifier)) - return false; + static RefPtr<StatementSyntaxNode> ParseSwitchStmt(Parser* parser) + { + RefPtr<SwitchStmt> stmt = new SwitchStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("switch"); + parser->ReadToken(TokenType::LParent); + stmt->condition = parser->ParseExpression(); + parser->ReadToken(TokenType::RParent); + stmt->body = parser->ParseBlockStatement(); + return stmt; + } - auto name = parser->tokenReader.PeekToken().Content; + static RefPtr<StatementSyntaxNode> ParseCaseStmt(Parser* parser) + { + RefPtr<CaseStmt> stmt = new CaseStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("case"); + stmt->expr = parser->ParseExpression(); + parser->ReadToken(TokenType::Colon); + return stmt; + } - auto lookupResult = LookUp(name, parser->currentScope); - if(!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; + static RefPtr<StatementSyntaxNode> ParseDefaultStmt(Parser* parser) + { + RefPtr<DefaultStmt> stmt = new DefaultStmt(); + parser->FillPosition(stmt.Ptr()); + parser->ReadToken("default"); + parser->ReadToken(TokenType::Colon); + return stmt; + } - auto decl = lookupResult.item.declRef.GetDecl(); - if( auto typeDecl = dynamic_cast<AggTypeDecl*>(decl) ) - { - return true; - } - else if( auto typeVarDecl = dynamic_cast<SimpleTypeDecl*>(decl) ) - { - return true; - } - else - { - return false; - } - } + static bool peekTypeName(Parser* parser) + { + if(!parser->LookAheadToken(TokenType::Identifier)) + return false; - RefPtr<StatementSyntaxNode> Parser::ParseStatement() - { - auto modifiers = ParseModifiers(this); + auto name = parser->tokenReader.PeekToken().Content; - RefPtr<StatementSyntaxNode> statement; - if (LookAheadToken(TokenType::LBrace)) - statement = ParseBlockStatement(); - else if (peekTypeName(this)) - statement = ParseVarDeclrStatement(modifiers); - else if (LookAheadToken("if")) - statement = ParseIfStatement(); - else if (LookAheadToken("for")) - statement = ParseForStatement(); - else if (LookAheadToken("while")) - statement = ParseWhileStatement(); - else if (LookAheadToken("do")) - statement = ParseDoWhileStatement(); - else if (LookAheadToken("break")) - statement = ParseBreakStatement(); - else if (LookAheadToken("continue")) - statement = ParseContinueStatement(); - else if (LookAheadToken("return")) - statement = ParseReturnStatement(); - else if (LookAheadToken("discard")) - { - statement = new DiscardStatementSyntaxNode(); - FillPosition(statement.Ptr()); - ReadToken("discard"); - ReadToken(TokenType::Semicolon); - } - else if (LookAheadToken("switch")) - statement = ParseSwitchStmt(this); - else if (LookAheadToken("case")) - statement = ParseCaseStmt(this); - else if (LookAheadToken("default")) - statement = ParseDefaultStmt(this); - else if (LookAheadToken(TokenType::Identifier)) - { - // We might be looking at a local declaration, or an - // expression statement, and we need to figure out which. - // - // We'll solve this with backtracking for now. + auto lookupResult = LookUp(name, parser->currentScope); + if(!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; - Token* startPos = tokenReader.mCursor; + auto decl = lookupResult.item.declRef.GetDecl(); + if( auto typeDecl = dynamic_cast<AggTypeDecl*>(decl) ) + { + return true; + } + else if( auto typeVarDecl = dynamic_cast<SimpleTypeDecl*>(decl) ) + { + return true; + } + else + { + return false; + } + } + + RefPtr<StatementSyntaxNode> Parser::ParseStatement() + { + auto modifiers = ParseModifiers(this); + + RefPtr<StatementSyntaxNode> statement; + if (LookAheadToken(TokenType::LBrace)) + statement = ParseBlockStatement(); + else if (peekTypeName(this)) + statement = ParseVarDeclrStatement(modifiers); + else if (LookAheadToken("if")) + statement = ParseIfStatement(); + else if (LookAheadToken("for")) + statement = ParseForStatement(); + else if (LookAheadToken("while")) + statement = ParseWhileStatement(); + else if (LookAheadToken("do")) + statement = ParseDoWhileStatement(); + else if (LookAheadToken("break")) + statement = ParseBreakStatement(); + else if (LookAheadToken("continue")) + statement = ParseContinueStatement(); + else if (LookAheadToken("return")) + statement = ParseReturnStatement(); + else if (LookAheadToken("discard")) + { + statement = new DiscardStatementSyntaxNode(); + FillPosition(statement.Ptr()); + ReadToken("discard"); + ReadToken(TokenType::Semicolon); + } + else if (LookAheadToken("switch")) + statement = ParseSwitchStmt(this); + else if (LookAheadToken("case")) + statement = ParseCaseStmt(this); + else if (LookAheadToken("default")) + statement = ParseDefaultStmt(this); + else if (LookAheadToken(TokenType::Identifier)) + { + // We might be looking at a local declaration, or an + // expression statement, and we need to figure out which. + // + // We'll solve this with backtracking for now. - // Try to parse a type (knowing that the type grammar is - // a subset of the expression grammar, and so this should - // always succeed). - RefPtr<ExpressionSyntaxNode> type = ParseType(); - // We don't actually care about the type, though, so - // don't retain it - type = nullptr; + Token* startPos = tokenReader.mCursor; - // If the next token after we parsed a type looks like - // we are going to declare a variable, then lets guess - // that this is a declaration. - // - // TODO(tfoley): this wouldn't be robust for more - // general kinds of declarators (notably pointer declarators), - // so we'll need to be careful about this. - if (LookAheadToken(TokenType::Identifier)) - { - // Reset the cursor and try to parse a declaration now. - // Note: the declaration will consume any modifiers - // that had been in place on the statement. - tokenReader.mCursor = startPos; - statement = ParseVarDeclrStatement(modifiers); - return statement; - } + // Try to parse a type (knowing that the type grammar is + // a subset of the expression grammar, and so this should + // always succeed). + RefPtr<ExpressionSyntaxNode> type = ParseType(); + // We don't actually care about the type, though, so + // don't retain it + type = nullptr; - // Fallback: reset and parse an expression + // If the next token after we parsed a type looks like + // we are going to declare a variable, then lets guess + // that this is a declaration. + // + // TODO(tfoley): this wouldn't be robust for more + // general kinds of declarators (notably pointer declarators), + // so we'll need to be careful about this. + if (LookAheadToken(TokenType::Identifier)) + { + // Reset the cursor and try to parse a declaration now. + // Note: the declaration will consume any modifiers + // that had been in place on the statement. tokenReader.mCursor = startPos; - statement = ParseExpressionStatement(); - } - else if (LookAheadToken(TokenType::Semicolon)) - { - statement = new EmptyStatementSyntaxNode(); - FillPosition(statement.Ptr()); - ReadToken(TokenType::Semicolon); - } - else - { - // Default case should always fall back to parsing an expression, - // and then let that detect any errors - statement = ParseExpressionStatement(); + statement = ParseVarDeclrStatement(modifiers); + return statement; } - if (statement) - { - // Install any modifiers onto the statement. - // Note: this path is bypassed in the case of a - // declaration statement, so we don't end up - // doubling up the modifiers. - statement->modifiers = modifiers; - } + // Fallback: reset and parse an expression + tokenReader.mCursor = startPos; + statement = ParseExpressionStatement(); + } + else if (LookAheadToken(TokenType::Semicolon)) + { + statement = new EmptyStatementSyntaxNode(); + FillPosition(statement.Ptr()); + ReadToken(TokenType::Semicolon); + } + else + { + // Default case should always fall back to parsing an expression, + // and then let that detect any errors + statement = ParseExpressionStatement(); + } - return statement; + if (statement) + { + // Install any modifiers onto the statement. + // Note: this path is bypassed in the case of a + // declaration statement, so we don't end up + // doubling up the modifiers. + statement->modifiers = modifiers; } - RefPtr<StatementSyntaxNode> Parser::ParseBlockStatement() + return statement; + } + + RefPtr<StatementSyntaxNode> Parser::ParseBlockStatement() + { + if( options.flags & SLANG_COMPILE_FLAG_NO_CHECKING ) { - if( options.flags & SLANG_COMPILE_FLAG_NO_CHECKING ) - { - // We have been asked to parse the input, but not attempt to understand it. + // We have been asked to parse the input, but not attempt to understand it. - // TODO: record start/end locations... + // TODO: record start/end locations... - List<Token> tokens; + List<Token> tokens; - ReadToken(TokenType::LBrace); + ReadToken(TokenType::LBrace); - int depth = 1; - for( ;;) + int depth = 1; + for( ;;) + { + switch( tokenReader.PeekTokenType() ) { - switch( tokenReader.PeekTokenType() ) - { - case TokenType::EndOfFile: - goto done; - - case TokenType::RBrace: - depth--; - if(depth == 0) - goto done; - break; + case TokenType::EndOfFile: + goto done; - case TokenType::LBrace: - depth++; - break; + case TokenType::RBrace: + depth--; + if(depth == 0) + goto done; + break; - default: - break; - } + case TokenType::LBrace: + depth++; + break; - auto token = tokenReader.AdvanceToken(); - tokens.Add(token); + default: + break; } - done: - ReadToken(TokenType::RBrace); - RefPtr<UnparsedStmt> unparsedStmt = new UnparsedStmt(); - unparsedStmt->tokens = tokens; - return unparsedStmt; + auto token = tokenReader.AdvanceToken(); + tokens.Add(token); } + done: + ReadToken(TokenType::RBrace); - - RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode(); - blockStatement->scopeDecl = scopeDecl; - PushScope(scopeDecl.Ptr()); - ReadToken(TokenType::LBrace); - if(!tokenReader.IsAtEnd()) - { - FillPosition(blockStatement.Ptr()); - } - while (!AdvanceIfMatch(this, TokenType::RBrace)) - { - auto stmt = ParseStatement(); - if(stmt) - { - blockStatement->Statements.Add(stmt); - } - TryRecover(this); - } - PopScope(); - return blockStatement; + RefPtr<UnparsedStmt> unparsedStmt = new UnparsedStmt(); + unparsedStmt->tokens = tokens; + return unparsedStmt; } - RefPtr<VarDeclrStatementSyntaxNode> Parser::ParseVarDeclrStatement( - Modifiers modifiers) + + RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); + RefPtr<BlockStatementSyntaxNode> blockStatement = new BlockStatementSyntaxNode(); + blockStatement->scopeDecl = scopeDecl; + PushScope(scopeDecl.Ptr()); + ReadToken(TokenType::LBrace); + if(!tokenReader.IsAtEnd()) { - RefPtr<VarDeclrStatementSyntaxNode>varDeclrStatement = new VarDeclrStatementSyntaxNode(); - - FillPosition(varDeclrStatement.Ptr()); - auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); - varDeclrStatement->decl = decl; - return varDeclrStatement; + FillPosition(blockStatement.Ptr()); } - - RefPtr<IfStatementSyntaxNode> Parser::ParseIfStatement() + while (!AdvanceIfMatch(this, TokenType::RBrace)) { - RefPtr<IfStatementSyntaxNode> ifStatement = new IfStatementSyntaxNode(); - FillPosition(ifStatement.Ptr()); - ReadToken("if"); - ReadToken(TokenType::LParent); - ifStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - ifStatement->PositiveStatement = ParseStatement(); - if (LookAheadToken("else")) + auto stmt = ParseStatement(); + if(stmt) { - ReadToken("else"); - ifStatement->NegativeStatement = ParseStatement(); + blockStatement->Statements.Add(stmt); } - return ifStatement; + TryRecover(this); } + PopScope(); + return blockStatement; + } - RefPtr<ForStatementSyntaxNode> Parser::ParseForStatement() - { - RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); - RefPtr<ForStatementSyntaxNode> stmt = new ForStatementSyntaxNode(); - stmt->scopeDecl = scopeDecl; + RefPtr<VarDeclrStatementSyntaxNode> Parser::ParseVarDeclrStatement( + Modifiers modifiers) + { + RefPtr<VarDeclrStatementSyntaxNode>varDeclrStatement = new VarDeclrStatementSyntaxNode(); + + FillPosition(varDeclrStatement.Ptr()); + auto decl = ParseDeclWithModifiers(this, currentScope->containerDecl, modifiers); + varDeclrStatement->decl = decl; + return varDeclrStatement; + } + + RefPtr<IfStatementSyntaxNode> Parser::ParseIfStatement() + { + RefPtr<IfStatementSyntaxNode> ifStatement = new IfStatementSyntaxNode(); + FillPosition(ifStatement.Ptr()); + ReadToken("if"); + ReadToken(TokenType::LParent); + ifStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ifStatement->PositiveStatement = ParseStatement(); + if (LookAheadToken("else")) + { + ReadToken("else"); + ifStatement->NegativeStatement = ParseStatement(); + } + return ifStatement; + } + + RefPtr<ForStatementSyntaxNode> Parser::ParseForStatement() + { + RefPtr<ScopeDecl> scopeDecl = new ScopeDecl(); + RefPtr<ForStatementSyntaxNode> stmt = new ForStatementSyntaxNode(); + stmt->scopeDecl = scopeDecl; - // Note(tfoley): HLSL implements `for` with incorrect scoping. - // We need an option to turn on this behavior in a kind of "legacy" mode + // Note(tfoley): HLSL implements `for` with incorrect scoping. + // We need an option to turn on this behavior in a kind of "legacy" mode // PushScope(scopeDecl.Ptr()); - FillPosition(stmt.Ptr()); - ReadToken("for"); - ReadToken(TokenType::LParent); - if (peekTypeName(this)) + FillPosition(stmt.Ptr()); + ReadToken("for"); + ReadToken(TokenType::LParent); + if (peekTypeName(this)) + { + stmt->InitialStatement = ParseVarDeclrStatement(Modifiers()); + } + else + { + if (!LookAheadToken(TokenType::Semicolon)) { - stmt->InitialStatement = ParseVarDeclrStatement(Modifiers()); + stmt->InitialStatement = ParseExpressionStatement(); } else { - if (!LookAheadToken(TokenType::Semicolon)) - { - stmt->InitialStatement = ParseExpressionStatement(); - } - else - { - ReadToken(TokenType::Semicolon); - } + ReadToken(TokenType::Semicolon); } - if (!LookAheadToken(TokenType::Semicolon)) - stmt->PredicateExpression = ParseExpression(); - ReadToken(TokenType::Semicolon); - if (!LookAheadToken(TokenType::RParent)) - stmt->SideEffectExpression = ParseExpression(); - ReadToken(TokenType::RParent); - stmt->Statement = ParseStatement(); -// PopScope(); - return stmt; } + if (!LookAheadToken(TokenType::Semicolon)) + stmt->PredicateExpression = ParseExpression(); + ReadToken(TokenType::Semicolon); + if (!LookAheadToken(TokenType::RParent)) + stmt->SideEffectExpression = ParseExpression(); + ReadToken(TokenType::RParent); + stmt->Statement = ParseStatement(); +// PopScope(); + return stmt; + } - RefPtr<WhileStatementSyntaxNode> Parser::ParseWhileStatement() - { - RefPtr<WhileStatementSyntaxNode> whileStatement = new WhileStatementSyntaxNode(); - FillPosition(whileStatement.Ptr()); - ReadToken("while"); - ReadToken(TokenType::LParent); - whileStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - whileStatement->Statement = ParseStatement(); - return whileStatement; - } + RefPtr<WhileStatementSyntaxNode> Parser::ParseWhileStatement() + { + RefPtr<WhileStatementSyntaxNode> whileStatement = new WhileStatementSyntaxNode(); + FillPosition(whileStatement.Ptr()); + ReadToken("while"); + ReadToken(TokenType::LParent); + whileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + whileStatement->Statement = ParseStatement(); + return whileStatement; + } - RefPtr<DoWhileStatementSyntaxNode> Parser::ParseDoWhileStatement() - { - RefPtr<DoWhileStatementSyntaxNode> doWhileStatement = new DoWhileStatementSyntaxNode(); - FillPosition(doWhileStatement.Ptr()); - ReadToken("do"); - doWhileStatement->Statement = ParseStatement(); - ReadToken("while"); - ReadToken(TokenType::LParent); - doWhileStatement->Predicate = ParseExpression(); - ReadToken(TokenType::RParent); - ReadToken(TokenType::Semicolon); - return doWhileStatement; - } + RefPtr<DoWhileStatementSyntaxNode> Parser::ParseDoWhileStatement() + { + RefPtr<DoWhileStatementSyntaxNode> doWhileStatement = new DoWhileStatementSyntaxNode(); + FillPosition(doWhileStatement.Ptr()); + ReadToken("do"); + doWhileStatement->Statement = ParseStatement(); + ReadToken("while"); + ReadToken(TokenType::LParent); + doWhileStatement->Predicate = ParseExpression(); + ReadToken(TokenType::RParent); + ReadToken(TokenType::Semicolon); + return doWhileStatement; + } - RefPtr<BreakStatementSyntaxNode> Parser::ParseBreakStatement() - { - RefPtr<BreakStatementSyntaxNode> breakStatement = new BreakStatementSyntaxNode(); - FillPosition(breakStatement.Ptr()); - ReadToken("break"); - ReadToken(TokenType::Semicolon); - return breakStatement; - } + RefPtr<BreakStatementSyntaxNode> Parser::ParseBreakStatement() + { + RefPtr<BreakStatementSyntaxNode> breakStatement = new BreakStatementSyntaxNode(); + FillPosition(breakStatement.Ptr()); + ReadToken("break"); + ReadToken(TokenType::Semicolon); + return breakStatement; + } - RefPtr<ContinueStatementSyntaxNode> Parser::ParseContinueStatement() - { - RefPtr<ContinueStatementSyntaxNode> continueStatement = new ContinueStatementSyntaxNode(); - FillPosition(continueStatement.Ptr()); - ReadToken("continue"); - ReadToken(TokenType::Semicolon); - return continueStatement; - } + RefPtr<ContinueStatementSyntaxNode> Parser::ParseContinueStatement() + { + RefPtr<ContinueStatementSyntaxNode> continueStatement = new ContinueStatementSyntaxNode(); + FillPosition(continueStatement.Ptr()); + ReadToken("continue"); + ReadToken(TokenType::Semicolon); + return continueStatement; + } - RefPtr<ReturnStatementSyntaxNode> Parser::ParseReturnStatement() - { - RefPtr<ReturnStatementSyntaxNode> returnStatement = new ReturnStatementSyntaxNode(); - FillPosition(returnStatement.Ptr()); - ReadToken("return"); - if (!LookAheadToken(TokenType::Semicolon)) - returnStatement->Expression = ParseExpression(); - ReadToken(TokenType::Semicolon); - return returnStatement; - } + RefPtr<ReturnStatementSyntaxNode> Parser::ParseReturnStatement() + { + RefPtr<ReturnStatementSyntaxNode> returnStatement = new ReturnStatementSyntaxNode(); + FillPosition(returnStatement.Ptr()); + ReadToken("return"); + if (!LookAheadToken(TokenType::Semicolon)) + returnStatement->Expression = ParseExpression(); + ReadToken(TokenType::Semicolon); + return returnStatement; + } - RefPtr<ExpressionStatementSyntaxNode> Parser::ParseExpressionStatement() - { - RefPtr<ExpressionStatementSyntaxNode> statement = new ExpressionStatementSyntaxNode(); + RefPtr<ExpressionStatementSyntaxNode> Parser::ParseExpressionStatement() + { + RefPtr<ExpressionStatementSyntaxNode> statement = new ExpressionStatementSyntaxNode(); - FillPosition(statement.Ptr()); - statement->Expression = ParseExpression(); + FillPosition(statement.Ptr()); + statement->Expression = ParseExpression(); - ReadToken(TokenType::Semicolon); - return statement; - } + ReadToken(TokenType::Semicolon); + return statement; + } - RefPtr<ParameterSyntaxNode> Parser::ParseParameter() - { - RefPtr<ParameterSyntaxNode> parameter = new ParameterSyntaxNode(); - parameter->modifiers = ParseModifiers(this); + RefPtr<ParameterSyntaxNode> Parser::ParseParameter() + { + RefPtr<ParameterSyntaxNode> parameter = new ParameterSyntaxNode(); + parameter->modifiers = ParseModifiers(this); - DeclaratorInfo declaratorInfo; - declaratorInfo.typeSpec = ParseType(); + DeclaratorInfo declaratorInfo; + declaratorInfo.typeSpec = ParseType(); - InitDeclarator initDeclarator = ParseInitDeclarator(this); - UnwrapDeclarator(initDeclarator, &declaratorInfo); + InitDeclarator initDeclarator = ParseInitDeclarator(this); + UnwrapDeclarator(initDeclarator, &declaratorInfo); - // Assume it is a variable-like declarator - CompleteVarDecl(this, parameter, declaratorInfo); - return parameter; - } + // Assume it is a variable-like declarator + CompleteVarDecl(this, parameter, declaratorInfo); + return parameter; + } - RefPtr<ExpressionSyntaxNode> Parser::ParseType() + RefPtr<ExpressionSyntaxNode> Parser::ParseType() + { + auto typeSpec = parseTypeSpec(this); + if( typeSpec.decl ) { - auto typeSpec = parseTypeSpec(this); - if( typeSpec.decl ) - { - AddMember(currentScope, typeSpec.decl); - } - auto typeExpr = typeSpec.expr; + AddMember(currentScope, typeSpec.decl); + } + auto typeExpr = typeSpec.expr; - while (LookAheadToken(TokenType::LBracket)) + while (LookAheadToken(TokenType::LBracket)) + { + RefPtr<IndexExpressionSyntaxNode> arrType = new IndexExpressionSyntaxNode(); + arrType->Position = typeExpr->Position; + arrType->BaseExpression = typeExpr; + ReadToken(TokenType::LBracket); + if (!LookAheadToken(TokenType::RBracket)) { - RefPtr<IndexExpressionSyntaxNode> arrType = new IndexExpressionSyntaxNode(); - arrType->Position = typeExpr->Position; - arrType->BaseExpression = typeExpr; - ReadToken(TokenType::LBracket); - if (!LookAheadToken(TokenType::RBracket)) - { - arrType->IndexExpression = ParseExpression(); - } - ReadToken(TokenType::RBracket); - typeExpr = arrType; + arrType->IndexExpression = ParseExpression(); } - - return typeExpr; + ReadToken(TokenType::RBracket); + typeExpr = arrType; } + return typeExpr; + } - TypeExp Parser::ParseTypeExp() - { - return TypeExp(ParseType()); - } - enum class Associativity - { - Left, Right - }; + TypeExp Parser::ParseTypeExp() + { + return TypeExp(ParseType()); + } + enum class Associativity + { + Left, Right + }; - Associativity GetAssociativityFromLevel(Precedence level) - { - if (level == Precedence::Assignment) - return Associativity::Right; - else - return Associativity::Left; - } + Associativity GetAssociativityFromLevel(Precedence level) + { + if (level == Precedence::Assignment) + return Associativity::Right; + else + return Associativity::Left; + } - Precedence GetOpLevel(Parser* parser, TokenType type) - { - switch(type) - { - case TokenType::Comma: - return Precedence::Comma; - case TokenType::OpAssign: - case TokenType::OpMulAssign: - case TokenType::OpDivAssign: - case TokenType::OpAddAssign: - case TokenType::OpSubAssign: - case TokenType::OpModAssign: - case TokenType::OpShlAssign: - case TokenType::OpShrAssign: - case TokenType::OpOrAssign: - case TokenType::OpAndAssign: - case TokenType::OpXorAssign: - return Precedence::Assignment; - case TokenType::OpOr: - return Precedence::LogicalOr; - case TokenType::OpAnd: - return Precedence::LogicalAnd; - case TokenType::OpBitOr: - return Precedence::BitOr; - case TokenType::OpBitXor: - return Precedence::BitXor; - case TokenType::OpBitAnd: - return Precedence::BitAnd; - case TokenType::OpEql: - case TokenType::OpNeq: - return Precedence::EqualityComparison; - case TokenType::OpGreater: - case TokenType::OpGeq: - // Don't allow these ops inside a generic argument - if (parser->genericDepth > 0) return Precedence::Invalid; - case TokenType::OpLeq: - case TokenType::OpLess: - return Precedence::RelationalComparison; - case TokenType::OpRsh: - // Don't allow this op inside a generic argument - if (parser->genericDepth > 0) return Precedence::Invalid; - case TokenType::OpLsh: - return Precedence::BitShift; - case TokenType::OpAdd: - case TokenType::OpSub: - return Precedence::Additive; - case TokenType::OpMul: - case TokenType::OpDiv: - case TokenType::OpMod: - return Precedence::Multiplicative; - default: - return Precedence::Invalid; - } + + Precedence GetOpLevel(Parser* parser, TokenType type) + { + switch(type) + { + case TokenType::Comma: + return Precedence::Comma; + case TokenType::OpAssign: + case TokenType::OpMulAssign: + case TokenType::OpDivAssign: + case TokenType::OpAddAssign: + case TokenType::OpSubAssign: + case TokenType::OpModAssign: + case TokenType::OpShlAssign: + case TokenType::OpShrAssign: + case TokenType::OpOrAssign: + case TokenType::OpAndAssign: + case TokenType::OpXorAssign: + return Precedence::Assignment; + case TokenType::OpOr: + return Precedence::LogicalOr; + case TokenType::OpAnd: + return Precedence::LogicalAnd; + case TokenType::OpBitOr: + return Precedence::BitOr; + case TokenType::OpBitXor: + return Precedence::BitXor; + case TokenType::OpBitAnd: + return Precedence::BitAnd; + case TokenType::OpEql: + case TokenType::OpNeq: + return Precedence::EqualityComparison; + case TokenType::OpGreater: + case TokenType::OpGeq: + // Don't allow these ops inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + case TokenType::OpLeq: + case TokenType::OpLess: + return Precedence::RelationalComparison; + case TokenType::OpRsh: + // Don't allow this op inside a generic argument + if (parser->genericDepth > 0) return Precedence::Invalid; + case TokenType::OpLsh: + return Precedence::BitShift; + case TokenType::OpAdd: + case TokenType::OpSub: + return Precedence::Additive; + case TokenType::OpMul: + case TokenType::OpDiv: + case TokenType::OpMod: + return Precedence::Multiplicative; + default: + return Precedence::Invalid; } + } - Operator GetOpFromToken(Token & token) - { - switch(token.Type) - { - case TokenType::Comma: - return Operator::Sequence; - case TokenType::OpAssign: - return Operator::Assign; - case TokenType::OpAddAssign: - return Operator::AddAssign; - case TokenType::OpSubAssign: - return Operator::SubAssign; - case TokenType::OpMulAssign: - return Operator::MulAssign; - case TokenType::OpDivAssign: - return Operator::DivAssign; - case TokenType::OpModAssign: - return Operator::ModAssign; - case TokenType::OpShlAssign: - return Operator::LshAssign; - case TokenType::OpShrAssign: - return Operator::RshAssign; - case TokenType::OpOrAssign: - return Operator::OrAssign; - case TokenType::OpAndAssign: - return Operator::AddAssign; - case TokenType::OpXorAssign: - return Operator::XorAssign; - case TokenType::OpOr: - return Operator::Or; - case TokenType::OpAnd: - return Operator::And; - case TokenType::OpBitOr: - return Operator::BitOr; - case TokenType::OpBitXor: - return Operator::BitXor; - case TokenType::OpBitAnd: - return Operator::BitAnd; - case TokenType::OpEql: - return Operator::Eql; - case TokenType::OpNeq: - return Operator::Neq; - case TokenType::OpGeq: - return Operator::Geq; - case TokenType::OpLeq: - return Operator::Leq; - case TokenType::OpGreater: - return Operator::Greater; - case TokenType::OpLess: - return Operator::Less; - case TokenType::OpLsh: - return Operator::Lsh; - case TokenType::OpRsh: - return Operator::Rsh; - case TokenType::OpAdd: - return Operator::Add; - case TokenType::OpSub: - return Operator::Sub; - case TokenType::OpMul: - return Operator::Mul; - case TokenType::OpDiv: - return Operator::Div; - case TokenType::OpMod: - return Operator::Mod; - case TokenType::OpInc: - return Operator::PostInc; - case TokenType::OpDec: - return Operator::PostDec; - case TokenType::OpNot: - return Operator::Not; - case TokenType::OpBitNot: - return Operator::BitNot; - default: - throw "Illegal TokenType."; - } + Operator GetOpFromToken(Token & token) + { + switch(token.Type) + { + case TokenType::Comma: + return Operator::Sequence; + case TokenType::OpAssign: + return Operator::Assign; + case TokenType::OpAddAssign: + return Operator::AddAssign; + case TokenType::OpSubAssign: + return Operator::SubAssign; + case TokenType::OpMulAssign: + return Operator::MulAssign; + case TokenType::OpDivAssign: + return Operator::DivAssign; + case TokenType::OpModAssign: + return Operator::ModAssign; + case TokenType::OpShlAssign: + return Operator::LshAssign; + case TokenType::OpShrAssign: + return Operator::RshAssign; + case TokenType::OpOrAssign: + return Operator::OrAssign; + case TokenType::OpAndAssign: + return Operator::AddAssign; + case TokenType::OpXorAssign: + return Operator::XorAssign; + case TokenType::OpOr: + return Operator::Or; + case TokenType::OpAnd: + return Operator::And; + case TokenType::OpBitOr: + return Operator::BitOr; + case TokenType::OpBitXor: + return Operator::BitXor; + case TokenType::OpBitAnd: + return Operator::BitAnd; + case TokenType::OpEql: + return Operator::Eql; + case TokenType::OpNeq: + return Operator::Neq; + case TokenType::OpGeq: + return Operator::Geq; + case TokenType::OpLeq: + return Operator::Leq; + case TokenType::OpGreater: + return Operator::Greater; + case TokenType::OpLess: + return Operator::Less; + case TokenType::OpLsh: + return Operator::Lsh; + case TokenType::OpRsh: + return Operator::Rsh; + case TokenType::OpAdd: + return Operator::Add; + case TokenType::OpSub: + return Operator::Sub; + case TokenType::OpMul: + return Operator::Mul; + case TokenType::OpDiv: + return Operator::Div; + case TokenType::OpMod: + return Operator::Mod; + case TokenType::OpInc: + return Operator::PostInc; + case TokenType::OpDec: + return Operator::PostDec; + case TokenType::OpNot: + return Operator::Not; + case TokenType::OpBitNot: + return Operator::BitNot; + default: + throw "Illegal TokenType."; } + } - static RefPtr<ExpressionSyntaxNode> parseOperator(Parser* parser) + static RefPtr<ExpressionSyntaxNode> parseOperator(Parser* parser) + { + Token opToken; + switch(parser->tokenReader.PeekTokenType()) { - Token opToken; - switch(parser->tokenReader.PeekTokenType()) - { - case TokenType::QuestionMark: - opToken = parser->ReadToken(); - opToken.Content = "?:"; - break; + case TokenType::QuestionMark: + opToken = parser->ReadToken(); + opToken.Content = "?:"; + break; - default: - opToken = parser->ReadToken(); - break; - } + default: + opToken = parser->ReadToken(); + break; + } - auto opExpr = new VarExpressionSyntaxNode(); - opExpr->Variable = opToken.Content; - opExpr->scope = parser->currentScope; - opExpr->Position = opToken.Position; + auto opExpr = new VarExpressionSyntaxNode(); + opExpr->Variable = opToken.Content; + opExpr->scope = parser->currentScope; + opExpr->Position = opToken.Position; - return opExpr; + return opExpr; - } + } - RefPtr<ExpressionSyntaxNode> Parser::ParseExpression(Precedence level) + RefPtr<ExpressionSyntaxNode> Parser::ParseExpression(Precedence level) + { + if (level == Precedence::Prefix) + return ParseLeafExpression(); + if (level == Precedence::TernaryConditional) { - if (level == Precedence::Prefix) - return ParseLeafExpression(); - if (level == Precedence::TernaryConditional) + // parse select clause + auto condition = ParseExpression(Precedence(level + 1)); + if (LookAheadToken(TokenType::QuestionMark)) { - // parse select clause - auto condition = ParseExpression(Precedence(level + 1)); - if (LookAheadToken(TokenType::QuestionMark)) - { - RefPtr<SelectExpressionSyntaxNode> select = new SelectExpressionSyntaxNode(); - FillPosition(select.Ptr()); + RefPtr<SelectExpressionSyntaxNode> select = new SelectExpressionSyntaxNode(); + FillPosition(select.Ptr()); - select->Arguments.Add(condition); + select->Arguments.Add(condition); - select->FunctionExpr = parseOperator(this); + select->FunctionExpr = parseOperator(this); - select->Arguments.Add(ParseExpression(level)); - ReadToken(TokenType::Colon); - select->Arguments.Add(ParseExpression(level)); - return select; - } - else - return condition; + select->Arguments.Add(ParseExpression(level)); + ReadToken(TokenType::Colon); + select->Arguments.Add(ParseExpression(level)); + return select; } else + return condition; + } + else + { + if (GetAssociativityFromLevel(level) == Associativity::Left) { - if (GetAssociativityFromLevel(level) == Associativity::Left) + auto left = ParseExpression(Precedence(level + 1)); + while (GetOpLevel(this, tokenReader.PeekTokenType()) == level) { - auto left = ParseExpression(Precedence(level + 1)); - while (GetOpLevel(this, tokenReader.PeekTokenType()) == level) - { - RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); - tmp->FunctionExpr = parseOperator(this); + RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); + tmp->FunctionExpr = parseOperator(this); - tmp->Arguments.Add(left); - FillPosition(tmp.Ptr()); - tmp->Arguments.Add(ParseExpression(Precedence(level + 1))); - left = tmp; - } - return left; + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->Arguments.Add(ParseExpression(Precedence(level + 1))); + left = tmp; } - else + return left; + } + else + { + auto left = ParseExpression(Precedence(level + 1)); + if (GetOpLevel(this, tokenReader.PeekTokenType()) == level) { - auto left = ParseExpression(Precedence(level + 1)); - if (GetOpLevel(this, tokenReader.PeekTokenType()) == level) - { - RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); - tmp->Arguments.Add(left); - FillPosition(tmp.Ptr()); - tmp->FunctionExpr = parseOperator(this); - tmp->Arguments.Add(ParseExpression(level)); - left = tmp; - } - return left; + RefPtr<OperatorExpressionSyntaxNode> tmp = new InfixExpr(); + tmp->Arguments.Add(left); + FillPosition(tmp.Ptr()); + tmp->FunctionExpr = parseOperator(this); + tmp->Arguments.Add(ParseExpression(level)); + left = tmp; } + return left; } } + } + + RefPtr<ExpressionSyntaxNode> Parser::ParseLeafExpression() + { + RefPtr<ExpressionSyntaxNode> rs; + if (LookAheadToken(TokenType::OpInc) || + LookAheadToken(TokenType::OpDec) || + LookAheadToken(TokenType::OpNot) || + LookAheadToken(TokenType::OpBitNot) || + LookAheadToken(TokenType::OpSub)) + { + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PrefixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(ParseLeafExpression()); + rs = unaryExpr; + return rs; + } - RefPtr<ExpressionSyntaxNode> Parser::ParseLeafExpression() + if (LookAheadToken(TokenType::LParent)) { - RefPtr<ExpressionSyntaxNode> rs; - if (LookAheadToken(TokenType::OpInc) || - LookAheadToken(TokenType::OpDec) || - LookAheadToken(TokenType::OpNot) || - LookAheadToken(TokenType::OpBitNot) || - LookAheadToken(TokenType::OpSub)) + ReadToken(TokenType::LParent); + RefPtr<ExpressionSyntaxNode> expr; + if (peekTypeName(this) && LookAheadToken(TokenType::RParent, 1)) { - RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PrefixExpr(); - FillPosition(unaryExpr.Ptr()); - unaryExpr->FunctionExpr = parseOperator(this); - unaryExpr->Arguments.Add(ParseLeafExpression()); - rs = unaryExpr; - return rs; + RefPtr<TypeCastExpressionSyntaxNode> tcexpr = new TypeCastExpressionSyntaxNode(); + FillPosition(tcexpr.Ptr()); + tcexpr->TargetType = ParseTypeExp(); + ReadToken(TokenType::RParent); + tcexpr->Expression = ParseExpression(Precedence::Multiplicative); // Note(tfoley): need to double-check this + expr = tcexpr; } - - if (LookAheadToken(TokenType::LParent)) + else { - ReadToken(TokenType::LParent); - RefPtr<ExpressionSyntaxNode> expr; - if (peekTypeName(this) && LookAheadToken(TokenType::RParent, 1)) - { - RefPtr<TypeCastExpressionSyntaxNode> tcexpr = new TypeCastExpressionSyntaxNode(); - FillPosition(tcexpr.Ptr()); - tcexpr->TargetType = ParseTypeExp(); - ReadToken(TokenType::RParent); - tcexpr->Expression = ParseExpression(Precedence::Multiplicative); // Note(tfoley): need to double-check this - expr = tcexpr; - } - else - { - expr = ParseExpression(); - ReadToken(TokenType::RParent); - } - rs = expr; + expr = ParseExpression(); + ReadToken(TokenType::RParent); } - else if( LookAheadToken(TokenType::LBrace) ) - { - RefPtr<InitializerListExpr> initExpr = new InitializerListExpr(); - FillPosition(initExpr.Ptr()); + rs = expr; + } + else if( LookAheadToken(TokenType::LBrace) ) + { + RefPtr<InitializerListExpr> initExpr = new InitializerListExpr(); + FillPosition(initExpr.Ptr()); + + // Initializer list + ReadToken(TokenType::LBrace); - // Initializer list - ReadToken(TokenType::LBrace); + List<RefPtr<ExpressionSyntaxNode>> exprs; - List<RefPtr<ExpressionSyntaxNode>> exprs; + for(;;) + { + if(AdvanceIfMatch(this, TokenType::RBrace)) + break; - for(;;) + auto expr = ParseArgExpr(); + if( expr ) { - if(AdvanceIfMatch(this, TokenType::RBrace)) - break; + initExpr->args.Add(expr); + } - auto expr = ParseArgExpr(); - if( expr ) - { - initExpr->args.Add(expr); - } + if(AdvanceIfMatch(this, TokenType::RBrace)) + break; - if(AdvanceIfMatch(this, TokenType::RBrace)) - break; + ReadToken(TokenType::Comma); + } + rs = initExpr; + } - ReadToken(TokenType::Comma); - } - rs = initExpr; + else if (LookAheadToken(TokenType::IntLiterial) || + LookAheadToken(TokenType::DoubleLiterial)) + { + RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); + auto token = tokenReader.AdvanceToken(); + FillPosition(constExpr.Ptr()); + if (token.Type == TokenType::IntLiterial) + { + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Int; + constExpr->IntValue = StringToInt(token.Content); + } + else if (token.Type == TokenType::DoubleLiterial) + { + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Float; + constExpr->FloatValue = (FloatingPointLiteralValue) StringToDouble(token.Content); } + rs = constExpr; + } + else if (LookAheadToken("true") || LookAheadToken("false")) + { + RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); + auto token = tokenReader.AdvanceToken(); + FillPosition(constExpr.Ptr()); + constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Bool; + constExpr->IntValue = token.Content == "true" ? 1 : 0; + rs = constExpr; + } + else if (LookAheadToken(TokenType::Identifier)) + { + RefPtr<VarExpressionSyntaxNode> varExpr = new VarExpressionSyntaxNode(); + varExpr->scope = currentScope.Ptr(); + FillPosition(varExpr.Ptr()); + auto token = ReadToken(TokenType::Identifier); + varExpr->Variable = token.Content; + rs = varExpr; + } - else if (LookAheadToken(TokenType::IntLiterial) || - LookAheadToken(TokenType::DoubleLiterial)) + while (!tokenReader.IsAtEnd() && + (LookAheadToken(TokenType::OpInc) || + LookAheadToken(TokenType::OpDec) || + LookAheadToken(TokenType::Dot) || + LookAheadToken(TokenType::LBracket) || + LookAheadToken(TokenType::LParent))) + { + if (LookAheadToken(TokenType::OpInc)) { - RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); - auto token = tokenReader.AdvanceToken(); - FillPosition(constExpr.Ptr()); - if (token.Type == TokenType::IntLiterial) - { - constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Int; - constExpr->IntValue = StringToInt(token.Content); - } - else if (token.Type == TokenType::DoubleLiterial) - { - constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Float; - constExpr->FloatValue = (FloatingPointLiteralValue) StringToDouble(token.Content); - } - rs = constExpr; + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(rs); + rs = unaryExpr; } - else if (LookAheadToken("true") || LookAheadToken("false")) + else if (LookAheadToken(TokenType::OpDec)) { - RefPtr<ConstantExpressionSyntaxNode> constExpr = new ConstantExpressionSyntaxNode(); - auto token = tokenReader.AdvanceToken(); - FillPosition(constExpr.Ptr()); - constExpr->ConstType = ConstantExpressionSyntaxNode::ConstantType::Bool; - constExpr->IntValue = token.Content == "true" ? 1 : 0; - rs = constExpr; + RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); + FillPosition(unaryExpr.Ptr()); + unaryExpr->FunctionExpr = parseOperator(this); + unaryExpr->Arguments.Add(rs); + rs = unaryExpr; } - else if (LookAheadToken(TokenType::Identifier)) + else if (LookAheadToken(TokenType::LBracket)) { - RefPtr<VarExpressionSyntaxNode> varExpr = new VarExpressionSyntaxNode(); - varExpr->scope = currentScope.Ptr(); - FillPosition(varExpr.Ptr()); - auto token = ReadToken(TokenType::Identifier); - varExpr->Variable = token.Content; - rs = varExpr; + RefPtr<IndexExpressionSyntaxNode> indexExpr = new IndexExpressionSyntaxNode(); + indexExpr->BaseExpression = rs; + FillPosition(indexExpr.Ptr()); + ReadToken(TokenType::LBracket); + indexExpr->IndexExpression = ParseExpression(); + ReadToken(TokenType::RBracket); + rs = indexExpr; } - - while (!tokenReader.IsAtEnd() && - (LookAheadToken(TokenType::OpInc) || - LookAheadToken(TokenType::OpDec) || - LookAheadToken(TokenType::Dot) || - LookAheadToken(TokenType::LBracket) || - LookAheadToken(TokenType::LParent))) + else if (LookAheadToken(TokenType::LParent)) { - if (LookAheadToken(TokenType::OpInc)) - { - RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); - FillPosition(unaryExpr.Ptr()); - unaryExpr->FunctionExpr = parseOperator(this); - unaryExpr->Arguments.Add(rs); - rs = unaryExpr; - } - else if (LookAheadToken(TokenType::OpDec)) - { - RefPtr<OperatorExpressionSyntaxNode> unaryExpr = new PostfixExpr(); - FillPosition(unaryExpr.Ptr()); - unaryExpr->FunctionExpr = parseOperator(this); - unaryExpr->Arguments.Add(rs); - rs = unaryExpr; - } - else if (LookAheadToken(TokenType::LBracket)) - { - RefPtr<IndexExpressionSyntaxNode> indexExpr = new IndexExpressionSyntaxNode(); - indexExpr->BaseExpression = rs; - FillPosition(indexExpr.Ptr()); - ReadToken(TokenType::LBracket); - indexExpr->IndexExpression = ParseExpression(); - ReadToken(TokenType::RBracket); - rs = indexExpr; - } - else if (LookAheadToken(TokenType::LParent)) + RefPtr<InvokeExpressionSyntaxNode> invokeExpr = new InvokeExpressionSyntaxNode(); + invokeExpr->FunctionExpr = rs; + FillPosition(invokeExpr.Ptr()); + ReadToken(TokenType::LParent); + while (!tokenReader.IsAtEnd()) { - RefPtr<InvokeExpressionSyntaxNode> invokeExpr = new InvokeExpressionSyntaxNode(); - invokeExpr->FunctionExpr = rs; - FillPosition(invokeExpr.Ptr()); - ReadToken(TokenType::LParent); - while (!tokenReader.IsAtEnd()) + if (!LookAheadToken(TokenType::RParent)) + invokeExpr->Arguments.Add(ParseArgExpr()); + else { - if (!LookAheadToken(TokenType::RParent)) - invokeExpr->Arguments.Add(ParseArgExpr()); - else - { - break; - } - if (!LookAheadToken(TokenType::Comma)) - break; - ReadToken(TokenType::Comma); + break; } - ReadToken(TokenType::RParent); - rs = invokeExpr; - } - else if (LookAheadToken(TokenType::Dot)) - { - RefPtr<MemberExpressionSyntaxNode> memberExpr = new MemberExpressionSyntaxNode(); - memberExpr->scope = currentScope.Ptr(); - FillPosition(memberExpr.Ptr()); - memberExpr->BaseExpression = rs; - ReadToken(TokenType::Dot); - memberExpr->MemberName = ReadToken(TokenType::Identifier).Content; - rs = memberExpr; + if (!LookAheadToken(TokenType::Comma)) + break; + ReadToken(TokenType::Comma); } + ReadToken(TokenType::RParent); + rs = invokeExpr; } - if (!rs) + else if (LookAheadToken(TokenType::Dot)) { - sink->diagnose(tokenReader.PeekLoc(), Diagnostics::syntaxError); + RefPtr<MemberExpressionSyntaxNode> memberExpr = new MemberExpressionSyntaxNode(); + memberExpr->scope = currentScope.Ptr(); + FillPosition(memberExpr.Ptr()); + memberExpr->BaseExpression = rs; + ReadToken(TokenType::Dot); + memberExpr->MemberName = ReadToken(TokenType::Identifier).Content; + rs = memberExpr; } - return rs; } - - // Parse a source file into an existing translation unit - void parseSourceFile( - ProgramSyntaxNode* translationUnitSyntax, - CompileOptions& options, - TokenSpan const& tokens, - DiagnosticSink* sink, - String const& fileName, - RefPtr<Scope> const&outerScope) + if (!rs) { - Parser parser(options, tokens, sink, fileName, outerScope); - return parser.parseSourceFile(translationUnitSyntax); + sink->diagnose(tokenReader.PeekLoc(), Diagnostics::syntaxError); } + return rs; + } + + // Parse a source file into an existing translation unit + void parseSourceFile( + ProgramSyntaxNode* translationUnitSyntax, + CompileOptions& options, + TokenSpan const& tokens, + DiagnosticSink* sink, + String const& fileName, + RefPtr<Scope> const&outerScope) + { + Parser parser(options, tokens, sink, fileName, outerScope); + return parser.parseSourceFile(translationUnitSyntax); } -}
\ No newline at end of file +} diff --git a/source/slang/parser.h b/source/slang/parser.h index 90af69158..cc7649e95 100644 --- a/source/slang/parser.h +++ b/source/slang/parser.h @@ -7,17 +7,14 @@ namespace Slang { - namespace Compiler - { - // Parse a source file into an existing translation unit - void parseSourceFile( - ProgramSyntaxNode* translationUnitSyntax, - CompileOptions& options, - TokenSpan const& tokens, - DiagnosticSink* sink, - String const& fileName, - RefPtr<Scope> const&outerScope); - } + // Parse a source file into an existing translation unit + void parseSourceFile( + ProgramSyntaxNode* translationUnitSyntax, + CompileOptions& options, + TokenSpan const& tokens, + DiagnosticSink* sink, + String const& fileName, + RefPtr<Scope> const&outerScope); } #endif
\ No newline at end of file diff --git a/source/slang/preprocessor.cpp b/source/slang/preprocessor.cpp index 60329c275..3bba307db 100644 --- a/source/slang/preprocessor.cpp +++ b/source/slang/preprocessor.cpp @@ -18,7 +18,7 @@ using namespace CoreLib; // idioms for using the preprocessor, found in shader code in the wild. -namespace Slang{ namespace Compiler { +namespace Slang{ // State of a preprocessor conditional, which can change when // we encounter directives like `#elif` or `#endif` @@ -2056,4 +2056,4 @@ TokenList preprocessSource( return tokens; } -}} +} diff --git a/source/slang/preprocessor.h b/source/slang/preprocessor.h index ab72f3f87..cb5a8baae 100644 --- a/source/slang/preprocessor.h +++ b/source/slang/preprocessor.h @@ -5,7 +5,7 @@ #include "../core/basic.h" #include "../slang/lexer.h" -namespace Slang{ namespace Compiler { +namespace Slang { class DiagnosticSink; class ProgramSyntaxNode; @@ -30,6 +30,6 @@ TokenList preprocessSource( CoreLib::Dictionary<CoreLib::String, CoreLib::String> defines, ProgramSyntaxNode* syntax); -}} +} // namespace Slang #endif diff --git a/source/slang/profile.cpp b/source/slang/profile.cpp index 923dc2841..4420a722a 100644 --- a/source/slang/profile.cpp +++ b/source/slang/profile.cpp @@ -1,10 +1,7 @@ // profile.cpp #include "Profile.h" - namespace Slang { -namespace Compiler { - ProfileFamily getProfileFamily(ProfileVersion version) { @@ -17,4 +14,4 @@ ProfileFamily getProfileFamily(ProfileVersion version) } } -}} +} diff --git a/source/slang/profile.h b/source/slang/profile.h index 31465c38c..f67207c9e 100644 --- a/source/slang/profile.h +++ b/source/slang/profile.h @@ -6,79 +6,76 @@ namespace Slang { - namespace Compiler + // Flavors of translation unit + enum class SourceLanguage : SlangSourceLanguage { - // Flavors of translation unit - enum class SourceLanguage : SlangSourceLanguage - { - Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur - Slang = SLANG_SOURCE_LANGUAGE_SLANG, - HLSL = SLANG_SOURCE_LANGUAGE_HLSL, - GLSL = SLANG_SOURCE_LANGUAGE_GLSL, + Unknown = SLANG_SOURCE_LANGUAGE_UNKNOWN, // should not occur + Slang = SLANG_SOURCE_LANGUAGE_SLANG, + HLSL = SLANG_SOURCE_LANGUAGE_HLSL, + GLSL = SLANG_SOURCE_LANGUAGE_GLSL, - // A separate PACKAGE of Slang code that has been imported - ImportedSlangCode, - }; + // A separate PACKAGE of Slang code that has been imported + ImportedSlangCode, + }; - // TODO(tfoley): This should merge with the above... - enum class Language - { - Unknown, + // TODO(tfoley): This should merge with the above... + enum class Language + { + Unknown, #define LANGUAGE(TAG, NAME) TAG, #include "profile-defs.h" - }; + }; - enum class ProfileFamily - { - Unknown, + enum class ProfileFamily + { + Unknown, #define PROFILE_FAMILY(TAG) TAG, #include "profile-defs.h" - }; + }; - enum class ProfileVersion - { - Unknown, + enum class ProfileVersion + { + Unknown, #define PROFILE_VERSION(TAG, FAMILY) TAG, #include "profile-defs.h" - }; + }; - enum class Stage : SlangStage - { - Unknown = SLANG_STAGE_NONE, + enum class Stage : SlangStage + { + Unknown = SLANG_STAGE_NONE, #define PROFILE_STAGE(TAG, NAME, VAL) TAG = VAL, #include "profile-defs.h" - }; + }; - ProfileFamily getProfileFamily(ProfileVersion version); + ProfileFamily getProfileFamily(ProfileVersion version); - struct Profile + struct Profile + { + typedef uint32_t RawVal; + enum : RawVal { - typedef uint32_t RawVal; - enum : RawVal - { - Unknown, + Unknown, #define PROFILE(TAG, NAME, STAGE, VERSION) TAG = (uint32_t(Stage::STAGE) << 16) | uint32_t(ProfileVersion::VERSION), #include "profile-defs.h" - }; + }; - Profile() {} - Profile(RawVal raw) - : raw(raw) - {} + Profile() {} + Profile(RawVal raw) + : raw(raw) + {} - bool operator==(Profile const& other) const { return raw == other.raw; } - bool operator!=(Profile const& other) const { return raw != other.raw; } + bool operator==(Profile const& other) const { return raw == other.raw; } + bool operator!=(Profile const& other) const { return raw != other.raw; } - Stage GetStage() const { return Stage((uint32_t(raw) >> 16) & 0xFFFF); } - ProfileVersion GetVersion() const { return ProfileVersion(uint32_t(raw) & 0xFFFF); } - ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); } + Stage GetStage() const { return Stage((uint32_t(raw) >> 16) & 0xFFFF); } + ProfileVersion GetVersion() const { return ProfileVersion(uint32_t(raw) & 0xFFFF); } + ProfileFamily getFamily() const { return getProfileFamily(GetVersion()); } - static Profile LookUp(char const* name); + static Profile LookUp(char const* name); - RawVal raw = Unknown; - }; - } + RawVal raw = Unknown; + }; } #endif diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 1a56a8c1f..220e5bcdf 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -9,7 +9,6 @@ // Implementation to back public-facing reflection API using namespace Slang; -using namespace Slang::Compiler; // Conversion routines to help with strongly-typed reflection API @@ -750,13 +749,6 @@ SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangRefl namespace Slang { -namespace Compiler { - - - - - - // Debug helper code: dump reflection data after generation @@ -1401,4 +1393,4 @@ String emitReflectionJSON( return writer.sb.ProduceString(); } -}} +} diff --git a/source/slang/reflection.h b/source/slang/reflection.h index 4d2c53084..627ca8382 100644 --- a/source/slang/reflection.h +++ b/source/slang/reflection.h @@ -16,8 +16,6 @@ typedef int64_t Int64; typedef uintptr_t UInt; typedef uint64_t UInt64; -namespace Compiler { - class ProgramLayout; class TypeLayout; @@ -34,6 +32,6 @@ UInt getReflectionFieldCount(ExpressionType* type); UInt getReflectionFieldByIndex(ExpressionType* type, UInt index); UInt getReflectionFieldByIndex(TypeLayout* typeLayout, UInt index); -}} +} #endif // SLANG_REFLECTION_H diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index 40c391bf4..85c876105 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -951,903 +951,899 @@ using namespace CoreLib::Basic; namespace Slang { - namespace Compiler + static String stdlibPath; + + String getStdlibPath() { - static String stdlibPath; + if(stdlibPath.Length() != 0) + return stdlibPath; - String getStdlibPath() + StringBuilder pathBuilder; + for( auto cc = __FILE__; *cc; ++cc ) { - if(stdlibPath.Length() != 0) - return stdlibPath; - - StringBuilder pathBuilder; - for( auto cc = __FILE__; *cc; ++cc ) + switch( *cc ) { - switch( *cc ) - { - case '\n': - case '\t': - case '\\': - pathBuilder << "\\"; - default: - pathBuilder << *cc; - break; - } + case '\n': + case '\t': + case '\\': + pathBuilder << "\\"; + default: + pathBuilder << *cc; + break; } - stdlibPath = pathBuilder.ProduceString(); - - return stdlibPath; } + stdlibPath = pathBuilder.ProduceString(); - String SlangStdLib::code; + return stdlibPath; + } - enum - { - SINT_MASK = 1 << 0, - FLOAT_MASK = 1 << 1, - COMPARISON = 1 << 2, - BOOL_MASK = 1 << 3, - UINT_MASK = 1 << 4, - ASSIGNMENT = 1 << 5, - POSTFIX = 1 << 6, - - INT_MASK = SINT_MASK | UINT_MASK, - ARITHMETIC_MASK = INT_MASK | FLOAT_MASK, - LOGICAL_MASK = INT_MASK | BOOL_MASK, - ANY_MASK = INT_MASK | FLOAT_MASK | BOOL_MASK, - }; + String SlangStdLib::code; - String SlangStdLib::GetCode() - { - if (code.Length() > 0) - return code; - StringBuilder sb; + enum + { + SINT_MASK = 1 << 0, + FLOAT_MASK = 1 << 1, + COMPARISON = 1 << 2, + BOOL_MASK = 1 << 3, + UINT_MASK = 1 << 4, + ASSIGNMENT = 1 << 5, + POSTFIX = 1 << 6, + + INT_MASK = SINT_MASK | UINT_MASK, + ARITHMETIC_MASK = INT_MASK | FLOAT_MASK, + LOGICAL_MASK = INT_MASK | BOOL_MASK, + ANY_MASK = INT_MASK | FLOAT_MASK | BOOL_MASK, + }; + + String SlangStdLib::GetCode() + { + if (code.Length() > 0) + return code; + StringBuilder sb; - // generate operator overloads + // generate operator overloads - struct OpInfo { IntrinsicOp opCode; char const* opName; unsigned flags; }; + struct OpInfo { IntrinsicOp opCode; char const* opName; unsigned flags; }; - OpInfo unaryOps[] = { - { IntrinsicOp::Neg, "-", ARITHMETIC_MASK }, - { IntrinsicOp::Not, "!", ANY_MASK }, - { IntrinsicOp::Not, "~", INT_MASK }, - { IntrinsicOp::PreInc, "++", ARITHMETIC_MASK | ASSIGNMENT }, - { IntrinsicOp::PreDec, "--", ARITHMETIC_MASK | ASSIGNMENT }, - { IntrinsicOp::PostInc, "++", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, - { IntrinsicOp::PostDec, "--", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, - }; + OpInfo unaryOps[] = { + { IntrinsicOp::Neg, "-", ARITHMETIC_MASK }, + { IntrinsicOp::Not, "!", ANY_MASK }, + { IntrinsicOp::Not, "~", INT_MASK }, + { IntrinsicOp::PreInc, "++", ARITHMETIC_MASK | ASSIGNMENT }, + { IntrinsicOp::PreDec, "--", ARITHMETIC_MASK | ASSIGNMENT }, + { IntrinsicOp::PostInc, "++", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, + { IntrinsicOp::PostDec, "--", ARITHMETIC_MASK | ASSIGNMENT | POSTFIX }, + }; - OpInfo binaryOps[] = { - { IntrinsicOp::Add, "+", ARITHMETIC_MASK }, - { IntrinsicOp::Sub, "-", ARITHMETIC_MASK }, - { IntrinsicOp::Mul, "*", ARITHMETIC_MASK }, - { IntrinsicOp::Div, "/", ARITHMETIC_MASK }, - { IntrinsicOp::Mod, "%", INT_MASK }, + OpInfo binaryOps[] = { + { IntrinsicOp::Add, "+", ARITHMETIC_MASK }, + { IntrinsicOp::Sub, "-", ARITHMETIC_MASK }, + { IntrinsicOp::Mul, "*", ARITHMETIC_MASK }, + { IntrinsicOp::Div, "/", ARITHMETIC_MASK }, + { IntrinsicOp::Mod, "%", INT_MASK }, - { IntrinsicOp::And, "&&", LOGICAL_MASK }, - { IntrinsicOp::Or, "||", LOGICAL_MASK }, + { IntrinsicOp::And, "&&", LOGICAL_MASK }, + { IntrinsicOp::Or, "||", LOGICAL_MASK }, - { IntrinsicOp::BitAnd, "&", LOGICAL_MASK }, - { IntrinsicOp::BitOr, "|", LOGICAL_MASK }, - { IntrinsicOp::BitXor, "^", LOGICAL_MASK }, + { IntrinsicOp::BitAnd, "&", LOGICAL_MASK }, + { IntrinsicOp::BitOr, "|", LOGICAL_MASK }, + { IntrinsicOp::BitXor, "^", LOGICAL_MASK }, - { IntrinsicOp::Lsh, "<<", INT_MASK }, - { IntrinsicOp::Rsh, ">>", INT_MASK }, + { IntrinsicOp::Lsh, "<<", INT_MASK }, + { IntrinsicOp::Rsh, ">>", INT_MASK }, - { IntrinsicOp::Eql, "==", ANY_MASK | COMPARISON }, - { IntrinsicOp::Neq, "!=", ANY_MASK | COMPARISON }, + { IntrinsicOp::Eql, "==", ANY_MASK | COMPARISON }, + { IntrinsicOp::Neq, "!=", ANY_MASK | COMPARISON }, - { IntrinsicOp::Greater, ">", ARITHMETIC_MASK | COMPARISON }, - { IntrinsicOp::Less, "<", ARITHMETIC_MASK | COMPARISON }, - { IntrinsicOp::Geq, ">=", ARITHMETIC_MASK | COMPARISON }, - { IntrinsicOp::Leq, "<=", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Greater, ">", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Less, "<", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Geq, ">=", ARITHMETIC_MASK | COMPARISON }, + { IntrinsicOp::Leq, "<=", ARITHMETIC_MASK | COMPARISON }, - { IntrinsicOp::AddAssign, "+=", ASSIGNMENT | ARITHMETIC_MASK }, - { IntrinsicOp::SubAssign, "-=", ASSIGNMENT | ARITHMETIC_MASK }, - { IntrinsicOp::MulAssign, "*=", ASSIGNMENT | ARITHMETIC_MASK }, - { IntrinsicOp::DivAssign, "/=", ASSIGNMENT | ARITHMETIC_MASK }, - { IntrinsicOp::ModAssign, "%=", ASSIGNMENT | ARITHMETIC_MASK }, - { IntrinsicOp::AndAssign, "&=", ASSIGNMENT | LOGICAL_MASK }, - { IntrinsicOp::OrAssign, "|=", ASSIGNMENT | LOGICAL_MASK }, - { IntrinsicOp::XorAssign, "^=", ASSIGNMENT | LOGICAL_MASK }, - { IntrinsicOp::LshAssign, "<<=", ASSIGNMENT | INT_MASK }, - { IntrinsicOp::RshAssign, ">>=", ASSIGNMENT | INT_MASK }, + { IntrinsicOp::AddAssign, "+=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::SubAssign, "-=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::MulAssign, "*=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::DivAssign, "/=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::ModAssign, "%=", ASSIGNMENT | ARITHMETIC_MASK }, + { IntrinsicOp::AndAssign, "&=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::OrAssign, "|=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::XorAssign, "^=", ASSIGNMENT | LOGICAL_MASK }, + { IntrinsicOp::LshAssign, "<<=", ASSIGNMENT | INT_MASK }, + { IntrinsicOp::RshAssign, ">>=", ASSIGNMENT | INT_MASK }, - }; + }; - /* - String floatTypes[] = { "float", "float2", "float3", "float4" }; - String intTypes[] = { "int", "int2", "int3", "int4" }; - String uintTypes[] = { "uint", "uint2", "uint3", "uint4" }; - */ + /* + String floatTypes[] = { "float", "float2", "float3", "float4" }; + String intTypes[] = { "int", "int2", "int3", "int4" }; + String uintTypes[] = { "uint", "uint2", "uint3", "uint4" }; + */ - String path = getStdlibPath(); + String path = getStdlibPath(); #define EMIT_LINE_DIRECTIVE() sb << "#line " << (__LINE__+1) << " \"" << path << "\"\n" - // Generate declarations for all the base types - - static const struct { - char const* name; - BaseType tag; - unsigned flags; - } kBaseTypes[] = { - { "void", BaseType::Void, 0 }, - { "int", BaseType::Int, SINT_MASK }, - { "float", BaseType::Float, FLOAT_MASK }, - { "uint", BaseType::UInt, UINT_MASK }, - { "bool", BaseType::Bool, BOOL_MASK }, - { "uint64_t", BaseType::UInt64, UINT_MASK }, - }; - static const int kBaseTypeCount = sizeof(kBaseTypes) / sizeof(kBaseTypes[0]); - for (int tt = 0; tt < kBaseTypeCount; ++tt) - { - EMIT_LINE_DIRECTIVE(); - sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name; - - // Declare interface conformances for this type + // Generate declarations for all the base types + + static const struct { + char const* name; + BaseType tag; + unsigned flags; + } kBaseTypes[] = { + { "void", BaseType::Void, 0 }, + { "int", BaseType::Int, SINT_MASK }, + { "float", BaseType::Float, FLOAT_MASK }, + { "uint", BaseType::UInt, UINT_MASK }, + { "bool", BaseType::Bool, BOOL_MASK }, + { "uint64_t", BaseType::UInt64, UINT_MASK }, + }; + static const int kBaseTypeCount = sizeof(kBaseTypes) / sizeof(kBaseTypes[0]); + for (int tt = 0; tt < kBaseTypeCount; ++tt) + { + EMIT_LINE_DIRECTIVE(); + sb << "__builtin_type(" << int(kBaseTypes[tt].tag) << ") struct " << kBaseTypes[tt].name; - sb << "\n : __BuiltinType\n"; + // Declare interface conformances for this type - switch( kBaseTypes[tt].tag ) - { - case BaseType::Float: - sb << "\n , __BuiltinFloatingPointType\n"; - sb << "\n , __BuiltinRealType\n"; - // fall through to: - case BaseType::Int: - sb << "\n , __BuiltinSignedArithmeticType\n"; - // fall through to: - case BaseType::UInt: - case BaseType::UInt64: - sb << "\n , __BuiltinArithmeticType\n"; - // fall through to: - case BaseType::Bool: - sb << "\n , __BuiltinType\n"; - break; - - default: - break; - } + sb << "\n : __BuiltinType\n"; - sb << "\n{\n"; + switch( kBaseTypes[tt].tag ) + { + case BaseType::Float: + sb << "\n , __BuiltinFloatingPointType\n"; + sb << "\n , __BuiltinRealType\n"; + // fall through to: + case BaseType::Int: + sb << "\n , __BuiltinSignedArithmeticType\n"; + // fall through to: + case BaseType::UInt: + case BaseType::UInt64: + sb << "\n , __BuiltinArithmeticType\n"; + // fall through to: + case BaseType::Bool: + sb << "\n , __BuiltinType\n"; + break; + + default: + break; + } + sb << "\n{\n"; - // Declare initializers to convert from various other types - for( int ss = 0; ss < kBaseTypeCount; ++ss ) - { - if( kBaseTypes[ss].tag == BaseType::Void ) - continue; - EMIT_LINE_DIRECTIVE(); - sb << "__init(" << kBaseTypes[ss].name << " value);\n"; - } + // Declare initializers to convert from various other types + for( int ss = 0; ss < kBaseTypeCount; ++ss ) + { + if( kBaseTypes[ss].tag == BaseType::Void ) + continue; - sb << "};\n"; + EMIT_LINE_DIRECTIVE(); + sb << "__init(" << kBaseTypes[ss].name << " value);\n"; } - // Declare ad hoc aliases for some types, just to get things compiling - // - // TODO(tfoley): At the very least, `double` should be treated as a distinct type. - sb << "typedef float double;\n"; - sb << "typedef float half;\n"; + sb << "};\n"; + } - // Declare vector and matrix types + // Declare ad hoc aliases for some types, just to get things compiling + // + // TODO(tfoley): At the very least, `double` should be treated as a distinct type. + sb << "typedef float double;\n"; + sb << "typedef float half;\n"; - sb << "__generic<T = float, let N : int = 4> __magic_type(Vector) struct vector\n{\n"; - sb << " __init(T value);\n"; // initialize from single scalar - sb << "};\n"; - sb << "__generic<T = float, let R : int = 4, let C : int = 4> __magic_type(Matrix) struct matrix {};\n"; + // Declare vector and matrix types - static const struct { - char const* name; - char const* glslPrefix; - } kTypes[] = - { - {"float", ""}, - {"int", "i"}, - {"uint", "u"}, - {"bool", "b"}, - }; - static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); - - for (int tt = 0; tt < kTypeCount; ++tt) + sb << "__generic<T = float, let N : int = 4> __magic_type(Vector) struct vector\n{\n"; + sb << " __init(T value);\n"; // initialize from single scalar + sb << "};\n"; + sb << "__generic<T = float, let R : int = 4, let C : int = 4> __magic_type(Matrix) struct matrix {};\n"; + + static const struct { + char const* name; + char const* glslPrefix; + } kTypes[] = + { + {"float", ""}, + {"int", "i"}, + {"uint", "u"}, + {"bool", "b"}, + }; + static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); + + for (int tt = 0; tt < kTypeCount; ++tt) + { + // Declare HLSL vector types + for (int ii = 1; ii <= 4; ++ii) { - // Declare HLSL vector types - for (int ii = 1; ii <= 4; ++ii) - { - sb << "typedef vector<" << kTypes[tt].name << "," << ii << "> " << kTypes[tt].name << ii << ";\n"; - } + sb << "typedef vector<" << kTypes[tt].name << "," << ii << "> " << kTypes[tt].name << ii << ";\n"; + } - // Declare HLSL matrix types - for (int rr = 2; rr <= 4; ++rr) - for (int cc = 2; cc <= 4; ++cc) - { - sb << "typedef matrix<" << kTypes[tt].name << "," << rr << "," << cc << "> " << kTypes[tt].name << rr << "x" << cc << ";\n"; - } + // Declare HLSL matrix types + for (int rr = 2; rr <= 4; ++rr) + for (int cc = 2; cc <= 4; ++cc) + { + sb << "typedef matrix<" << kTypes[tt].name << "," << rr << "," << cc << "> " << kTypes[tt].name << rr << "x" << cc << ";\n"; } + } - static const char* kComponentNames[]{ "x", "y", "z", "w" }; - static const char* kVectorNames[]{ "", "x", "xy", "xyz", "xyzw" }; + static const char* kComponentNames[]{ "x", "y", "z", "w" }; + static const char* kVectorNames[]{ "", "x", "xy", "xyz", "xyzw" }; - // Need to add constructors to the types above - for (int N = 2; N <= 4; ++N) + // Need to add constructors to the types above + for (int N = 2; N <= 4; ++N) + { + sb << "__generic<T> __extension vector<T, " << N << ">\n{\n"; + + // initialize from N scalars + sb << "__init("; + for (int ii = 0; ii < N; ++ii) { - sb << "__generic<T> __extension vector<T, " << N << ">\n{\n"; + if (ii != 0) sb << ", "; + sb << "T " << kComponentNames[ii]; + } + sb << ");\n"; - // initialize from N scalars - sb << "__init("; - for (int ii = 0; ii < N; ++ii) + // Initialize from an M-vector and then scalars + for (int M = 2; M < N; ++M) + { + sb << "__init(vector<T," << M << "> " << kVectorNames[M]; + for (int ii = M; ii < N; ++ii) { - if (ii != 0) sb << ", "; - sb << "T " << kComponentNames[ii]; + sb << ", T " << kComponentNames[ii]; } sb << ");\n"; + } - // Initialize from an M-vector and then scalars - for (int M = 2; M < N; ++M) - { - sb << "__init(vector<T," << M << "> " << kVectorNames[M]; - for (int ii = M; ii < N; ++ii) - { - sb << ", T " << kComponentNames[ii]; - } - sb << ");\n"; - } + // initialize from another vector of the same size + // + // TODO(tfoley): this overlaps with implicit conversions. + // We should look for a way that we can define implicit + // conversions directly in the stdlib instead... + sb << "__generic<U> __init(vector<U," << N << ">);\n"; - // initialize from another vector of the same size - // - // TODO(tfoley): this overlaps with implicit conversions. - // We should look for a way that we can define implicit - // conversions directly in the stdlib instead... - sb << "__generic<U> __init(vector<U," << N << ">);\n"; + sb << "}\n"; + } + + for( int R = 2; R <= 4; ++R ) + for( int C = 2; C <= 4; ++C ) + { + sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n"; - sb << "}\n"; + // initialize from R*C scalars + sb << "__init("; + for( int ii = 0; ii < R; ++ii ) + for( int jj = 0; jj < C; ++jj ) + { + if ((ii+jj) != 0) sb << ", "; + sb << "T m" << ii << jj; } + sb << ");\n"; - for( int R = 2; R <= 4; ++R ) - for( int C = 2; C <= 4; ++C ) + // Initialize from R C-vectors + sb << "__init("; + for (int ii = 0; ii < R; ++ii) { - sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n"; + if(ii != 0) sb << ", "; + sb << "vector<T," << C << "> row" << ii; + } + sb << ");\n"; - // initialize from R*C scalars - sb << "__init("; - for( int ii = 0; ii < R; ++ii ) - for( int jj = 0; jj < C; ++jj ) - { - if ((ii+jj) != 0) sb << ", "; - sb << "T m" << ii << jj; - } - sb << ");\n"; - // Initialize from R C-vectors - sb << "__init("; - for (int ii = 0; ii < R; ++ii) - { - if(ii != 0) sb << ", "; - sb << "vector<T," << C << "> row" << ii; - } - sb << ");\n"; + // initialize from another matrix of the same size + // + // TODO(tfoley): See comment about how this overlaps + // with implicit conversion, in the `vector` case above + sb << "__generic<U> __init(matrix<U," << R << ", " << C << ">);\n"; + sb << "}\n"; + } - // initialize from another matrix of the same size - // - // TODO(tfoley): See comment about how this overlaps - // with implicit conversion, in the `vector` case above - sb << "__generic<U> __init(matrix<U," << R << ", " << C << ">);\n"; - sb << "}\n"; - } + // Declare built-in texture and sampler types + + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct SamplerState {};"; + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerComparisonState) << ") struct SamplerComparisonState {};"; + + // TODO(tfoley): Need to handle `RW*` variants of texture types as well... + static const struct { + char const* name; + TextureType::Shape baseShape; + int coordCount; + } kBaseTextureTypes[] = { + { "Texture1D", TextureType::Shape1D, 1 }, + { "Texture2D", TextureType::Shape2D, 2 }, + { "Texture3D", TextureType::Shape3D, 3 }, + { "TextureCube", TextureType::ShapeCube, 3 }, + }; + static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); - // Declare built-in texture and sampler types - - sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct SamplerState {};"; - sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerComparisonState) << ") struct SamplerComparisonState {};"; - - // TODO(tfoley): Need to handle `RW*` variants of texture types as well... - static const struct { - char const* name; - TextureType::Shape baseShape; - int coordCount; - } kBaseTextureTypes[] = { - { "Texture1D", TextureType::Shape1D, 1 }, - { "Texture2D", TextureType::Shape2D, 2 }, - { "Texture3D", TextureType::Shape3D, 3 }, - { "TextureCube", TextureType::ShapeCube, 3 }, - }; - static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); - - - static const struct { - char const* name; - SlangResourceAccess access; - } kBaseTextureAccessLevels[] = { - { "", SLANG_RESOURCE_ACCESS_READ }, - { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, - { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, - }; - static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); - - for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + static const struct { + char const* name; + SlangResourceAccess access; + } kBaseTextureAccessLevels[] = { + { "", SLANG_RESOURCE_ACCESS_READ }, + { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, + { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, + }; + static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); + + for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + { + char const* name = kBaseTextureTypes[tt].name; + TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + + for (int isArray = 0; isArray < 2; ++isArray) { - char const* name = kBaseTextureTypes[tt].name; - TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + // Arrays of 3D textures aren't allowed + if (isArray && baseShape == TextureType::Shape3D) continue; - for (int isArray = 0; isArray < 2; ++isArray) + for (int isMultisample = 0; isMultisample < 2; ++isMultisample) + for (int accessLevel = 0; accessLevel < kBaseTextureAccessLevelCount; ++accessLevel) { - // Arrays of 3D textures aren't allowed - if (isArray && baseShape == TextureType::Shape3D) continue; + auto access = kBaseTextureAccessLevels[accessLevel].access; - for (int isMultisample = 0; isMultisample < 2; ++isMultisample) - for (int accessLevel = 0; accessLevel < kBaseTextureAccessLevelCount; ++accessLevel) - { - auto access = kBaseTextureAccessLevels[accessLevel].access; - - // TODO: any constraints to enforce on what gets to be multisampled? + // TODO: any constraints to enforce on what gets to be multisampled? - unsigned flavor = baseShape; - if (isArray) flavor |= TextureType::ArrayFlag; - if (isMultisample) flavor |= TextureType::MultisampleFlag; + unsigned flavor = baseShape; + if (isArray) flavor |= TextureType::ArrayFlag; + if (isMultisample) flavor |= TextureType::MultisampleFlag; // if (isShadow) flavor |= TextureType::ShadowFlag; - flavor |= (access << 8); + flavor |= (access << 8); - // emit a generic signature - // TODO: allow for multisample count to come in as well... - sb << "__generic<T = float4> "; + // emit a generic signature + // TODO: allow for multisample count to come in as well... + sb << "__generic<T = float4> "; - sb << "__magic_type(Texture," << int(flavor) << ") struct "; - sb << kBaseTextureAccessLevels[accessLevel].name; - sb << name; - if (isMultisample) sb << "MS"; - if (isArray) sb << "Array"; + sb << "__magic_type(Texture," << int(flavor) << ") struct "; + sb << kBaseTextureAccessLevels[accessLevel].name; + sb << name; + if (isMultisample) sb << "MS"; + if (isArray) sb << "Array"; // if (isShadow) sb << "Shadow"; - sb << "\n{"; + sb << "\n{"; - if( !isMultisample ) - { - sb << "float CalculateLevelOfDetail(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; + if( !isMultisample ) + { + sb << "float CalculateLevelOfDetail(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; - sb << "float CalculateLevelOfDetailUnclamped(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; + sb << "float CalculateLevelOfDetailUnclamped(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " location);\n"; - // TODO: `Gather` operation - // (tricky because it returns a 4-vector of the element type - // of the texture components...) - } + // TODO: `Gather` operation + // (tricky because it returns a 4-vector of the element type + // of the texture components...) + } + + // TODO: `GetDimensions` operations + + for(int isFloat = 0; isFloat < 2; ++isFloat) + for(int includeMipInfo = 0; includeMipInfo < 2; ++includeMipInfo) + { + char const* t = isFloat ? "out float " : "out UINT "; - // TODO: `GetDimensions` operations + sb << "void GetDimensions("; + if(includeMipInfo) + sb << "UINT mipLevel, "; - for(int isFloat = 0; isFloat < 2; ++isFloat) - for(int includeMipInfo = 0; includeMipInfo < 2; ++includeMipInfo) + switch(baseShape) { - char const* t = isFloat ? "out float " : "out UINT "; - - sb << "void GetDimensions("; - if(includeMipInfo) - sb << "UINT mipLevel, "; - - switch(baseShape) - { - case TextureType::Shape1D: - sb << t << "width"; - break; - - case TextureType::Shape2D: - case TextureType::ShapeCube: - sb << t << "width,"; - sb << t << "height"; - break; - - case TextureType::Shape3D: - sb << t << "width,"; - sb << t << "height,"; - sb << t << "depth"; - break; - - default: - assert(!"unexpected"); - break; - } - - if(isArray) - { - sb << ", " << t << "elements"; - } - - if(includeMipInfo) - sb << ", " << t << "numberOfLevels"; - - sb << ");\n"; + case TextureType::Shape1D: + sb << t << "width"; + break; + + case TextureType::Shape2D: + case TextureType::ShapeCube: + sb << t << "width,"; + sb << t << "height"; + break; + + case TextureType::Shape3D: + sb << t << "width,"; + sb << t << "height,"; + sb << t << "depth"; + break; + + default: + assert(!"unexpected"); + break; } - // `GetSamplePosition()` - if( isMultisample ) + if(isArray) { - sb << "float2 GetSamplePosition(int s);\n"; + sb << ", " << t << "elements"; } - // `Load()` + if(includeMipInfo) + sb << ", " << t << "numberOfLevels"; + + sb << ");\n"; + } + + // `GetSamplePosition()` + if( isMultisample ) + { + sb << "float2 GetSamplePosition(int s);\n"; + } + + // `Load()` + + if( kBaseTextureTypes[tt].coordCount + isArray < 4 ) + { + sb << "T Load("; + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location);\n"; - if( kBaseTextureTypes[tt].coordCount + isArray < 4 ) + if( !isMultisample ) { sb << "T Load("; - sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location);\n"; - - if( !isMultisample ) - { - sb << "T Load("; - sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } - else - { - sb << "T Load("; - sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; - sb << "int sampleIndex, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; } - - if(baseShape != TextureType::ShapeCube) + else { - // subscript operator - sb << "__intrinsic __subscript(uint" << kBaseTextureTypes[tt].coordCount + isArray << " location) -> T;\n"; + sb << "T Load("; + sb << "int" << kBaseTextureTypes[tt].coordCount + isArray + 1 << " location, "; + sb << "int sampleIndex, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; } + } - if( !isMultisample ) - { - // `Sample()` + if(baseShape != TextureType::ShapeCube) + { + // subscript operator + sb << "__intrinsic __subscript(uint" << kBaseTextureTypes[tt].coordCount + isArray << " location) -> T;\n"; + } - sb << "T Sample(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location);\n"; + if( !isMultisample ) + { + // `Sample()` - if( baseShape != TextureType::ShapeCube ) - { - sb << "T Sample(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location);\n"; + if( baseShape != TextureType::ShapeCube ) + { sb << "T Sample(SamplerState s, "; sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - if( baseShape != TextureType::ShapeCube ) - { - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; - } - sb << "float clamp);\n"; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } - sb << "T Sample(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - if( baseShape != TextureType::ShapeCube ) - { - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; - } - sb << "float clamp, out uint status);\n"; + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + if( baseShape != TextureType::ShapeCube ) + { + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; + } + sb << "float clamp);\n"; + + sb << "T Sample(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + if( baseShape != TextureType::ShapeCube ) + { + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset, "; + } + sb << "float clamp, out uint status);\n"; - // `SampleBias()` + // `SampleBias()` + sb << "T SampleBias(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias);\n"; + + if( baseShape != TextureType::ShapeCube ) + { sb << "T SampleBias(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias);\n"; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } - if( baseShape != TextureType::ShapeCube ) - { - sb << "T SampleBias(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, float bias, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + // `SampleCmp()` and `SampleCmpLevelZero` + sb << "T SampleCmp(SamplerComparisonState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue"; + sb << ");\n"; - // `SampleCmp()` and `SampleCmpLevelZero` - sb << "T SampleCmp(SamplerComparisonState s, "; + sb << "T SampleCmpLevelZero(SamplerComparisonState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float compareValue"; + sb << ");\n"; + + if( baseShape != TextureType::ShapeCube ) + { + // Note(tfoley): MSDN seems confused, and claims that the `offset` + // parameter for `SampleCmp` is available for everything but 3D + // textures, while `Sample` and `SampleBias` are consistent in + // saying they only exclude `offset` for cube maps (which makes + // sense). I'm going to assume the documentation for `SampleCmp` + // is just wrong. + + sb << "T SampleCmp(SamplerState s, "; sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float compareValue"; - sb << ");\n"; + sb << "float compareValue, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - sb << "T SampleCmpLevelZero(SamplerComparisonState s, "; + sb << "T SampleCmpLevelZero(SamplerState s, "; sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float compareValue"; - sb << ");\n"; - - if( baseShape != TextureType::ShapeCube ) - { - // Note(tfoley): MSDN seems confused, and claims that the `offset` - // parameter for `SampleCmp` is available for everything but 3D - // textures, while `Sample` and `SampleBias` are consistent in - // saying they only exclude `offset` for cube maps (which makes - // sense). I'm going to assume the documentation for `SampleCmp` - // is just wrong. - - sb << "T SampleCmp(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float compareValue, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - - sb << "T SampleCmpLevelZero(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float compareValue, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + sb << "float compareValue, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } + sb << "T SampleGrad(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, "; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY"; + sb << ");\n"; + + if( baseShape != TextureType::ShapeCube ) + { sb << "T SampleGrad(SamplerState s, "; sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, "; - sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY"; - sb << ");\n"; + sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; + } - if( baseShape != TextureType::ShapeCube ) - { - sb << "T SampleGrad(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float" << kBaseTextureTypes[tt].coordCount << " gradX, "; - sb << "float" << kBaseTextureTypes[tt].coordCount << " gradY, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + // `SampleLevel` - // `SampleLevel` + sb << "T SampleLevel(SamplerState s, "; + sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; + sb << "float level);\n"; + if( baseShape != TextureType::ShapeCube ) + { sb << "T SampleLevel(SamplerState s, "; sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float level);\n"; - - if( baseShape != TextureType::ShapeCube ) - { - sb << "T SampleLevel(SamplerState s, "; - sb << "float" << kBaseTextureTypes[tt].coordCount + isArray << " location, "; - sb << "float level, "; - sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; - } + sb << "float level, "; + sb << "int" << kBaseTextureTypes[tt].coordCount << " offset);\n"; } - - sb << "\n};\n"; } + + sb << "\n};\n"; } } + } - // Declare additional built-in generic types + // Declare additional built-in generic types - sb << "__generic<T> __magic_type(ConstantBuffer) struct ConstantBuffer {};\n"; - sb << "__generic<T> __magic_type(TextureBuffer) struct TextureBuffer {};\n"; + sb << "__generic<T> __magic_type(ConstantBuffer) struct ConstantBuffer {};\n"; + sb << "__generic<T> __magic_type(TextureBuffer) struct TextureBuffer {};\n"; - sb << "__generic<T> __magic_type(PackedBuffer) struct PackedBuffer {};\n"; - sb << "__generic<T> __magic_type(Uniform) struct Uniform {};\n"; - sb << "__generic<T> __magic_type(Patch) struct Patch {};\n"; + sb << "__generic<T> __magic_type(PackedBuffer) struct PackedBuffer {};\n"; + sb << "__generic<T> __magic_type(Uniform) struct Uniform {};\n"; + sb << "__generic<T> __magic_type(Patch) struct Patch {};\n"; - // Stale declarations for GLSL inner-product builtins + // Stale declarations for GLSL inner-product builtins #if 0 - sb << "__intrinsic vec3 operator * (vec3, mat3);\n"; - sb << "__intrinsic vec3 operator * (mat3, vec3);\n"; + sb << "__intrinsic vec3 operator * (vec3, mat3);\n"; + sb << "__intrinsic vec3 operator * (mat3, vec3);\n"; - sb << "__intrinsic vec4 operator * (vec4, mat4);\n"; - sb << "__intrinsic vec4 operator * (mat4, vec4);\n"; + sb << "__intrinsic vec4 operator * (vec4, mat4);\n"; + sb << "__intrinsic vec4 operator * (mat4, vec4);\n"; - sb << "__intrinsic mat3 operator * (mat3, mat3);\n"; - sb << "__intrinsic mat4 operator * (mat4, mat4);\n"; + sb << "__intrinsic mat3 operator * (mat3, mat3);\n"; + sb << "__intrinsic mat4 operator * (mat4, mat4);\n"; #endif - for (auto op : unaryOps) + for (auto op : unaryOps) + { + for (auto type : kBaseTypes) { - for (auto type : kBaseTypes) - { - if ((type.flags & op.flags) == 0) - continue; + if ((type.flags & op.flags) == 0) + continue; - char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix "; - char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : ""; + char const* fixity = (op.flags & POSTFIX) != 0 ? "__postfix " : "__prefix "; + char const* qual = (op.flags & ASSIGNMENT) != 0 ? "in out " : ""; - // scalar version - sb << fixity; - sb << "__intrinsic_op(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n"; + // scalar version + sb << fixity; + sb << "__intrinsic_op(" << int(op.opCode) << ") " << type.name << " operator" << op.opName << "(" << qual << type.name << " value);\n"; - // vector version - sb << "__generic<let N : int> "; - sb << fixity; - sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n"; + // vector version + sb << "__generic<let N : int> "; + sb << fixity; + sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << type.name << ",N> operator" << op.opName << "(" << qual << "vector<" << type.name << ",N> value);\n"; - // matrix version - sb << "__generic<let N : int, let M : int> "; - sb << fixity; - sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n"; - } + // matrix version + sb << "__generic<let N : int, let M : int> "; + sb << fixity; + sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << type.name << ",N,M> operator" << op.opName << "(" << qual << "matrix<" << type.name << ",N,M> value);\n"; } + } - for (auto op : binaryOps) + for (auto op : binaryOps) + { + for (auto type : kBaseTypes) { - for (auto type : kBaseTypes) - { - if ((type.flags & op.flags) == 0) - continue; + if ((type.flags & op.flags) == 0) + continue; - char const* leftType = type.name; - char const* rightType = leftType; - char const* resultType = leftType; + char const* leftType = type.name; + char const* rightType = leftType; + char const* resultType = leftType; - if (op.flags & COMPARISON) resultType = "bool"; + if (op.flags & COMPARISON) resultType = "bool"; - char const* leftQual = ""; - if(op.flags & ASSIGNMENT) leftQual = "in out "; + char const* leftQual = ""; + if(op.flags & ASSIGNMENT) leftQual = "in out "; - // TODO: handle `SHIFT` + // TODO: handle `SHIFT` - // scalar version - sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftQual << leftType << " left, " << rightType << " right);\n"; + // scalar version + sb << "__intrinsic_op(" << int(op.opCode) << ") " << resultType << " operator" << op.opName << "(" << leftQual << leftType << " left, " << rightType << " right);\n"; - // vector version - sb << "__generic<let N : int> "; - sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftQual << "vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n"; + // vector version + sb << "__generic<let N : int> "; + sb << "__intrinsic_op(" << int(op.opCode) << ") vector<" << resultType << ",N> operator" << op.opName << "(" << leftQual << "vector<" << leftType << ",N> left, vector<" << rightType << ",N> right);\n"; - // matrix version - sb << "__generic<let N : int, let M : int> "; - sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n"; - } + // matrix version + sb << "__generic<let N : int, let M : int> "; + sb << "__intrinsic_op(" << int(op.opCode) << ") matrix<" << resultType << ",N,M> operator" << op.opName << "(" << leftQual << "matrix<" << leftType << ",N,M> left, matrix<" << rightType << ",N,M> right);\n"; } + } #if 0 - for (auto op : intUnaryOps) + for (auto op : intUnaryOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) { - String opName = GetOperatorFunctionName(op); - for (int i = 0; i < 4; i++) + auto itype = intTypes[i]; + auto utype = uintTypes[i]; + for (int j = 0; j < 2; j++) { - auto itype = intTypes[i]; - auto utype = uintTypes[i]; - for (int j = 0; j < 2; j++) - { - auto retType = (op == Operator::Not) ? "bool" : j == 0 ? itype : utype; - sb << "__intrinsic " << retType << " operator " << opName << "(" << (j == 0 ? itype : utype) << ");\n"; - } + auto retType = (op == Operator::Not) ? "bool" : j == 0 ? itype : utype; + sb << "__intrinsic " << retType << " operator " << opName << "(" << (j == 0 ? itype : utype) << ");\n"; } } + } - for (auto op : floatUnaryOps) + for (auto op : floatUnaryOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) { - String opName = GetOperatorFunctionName(op); - for (int i = 0; i < 4; i++) - { - auto type = floatTypes[i]; - auto retType = (op == Operator::Not) ? "bool" : type; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ");\n"; - } + auto type = floatTypes[i]; + auto retType = (op == Operator::Not) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ");\n"; } + } - for (auto op : floatOps) + for (auto op : floatOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) { - String opName = GetOperatorFunctionName(op); - for (int i = 0; i < 4; i++) + auto type = floatTypes[i]; + auto itype = intTypes[i]; + auto utype = uintTypes[i]; + auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << itype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << itype << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; + if (i > 0) { - auto type = floatTypes[i]; - auto itype = intTypes[i]; - auto utype = uintTypes[i]; - auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << itype << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << itype << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; - if (i > 0) - { - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << floatTypes[0] << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << floatTypes[0] << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << floatTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << floatTypes[0] << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; - } + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; } } + } - for (auto op : intOps) + for (auto op : intOps) + { + String opName = GetOperatorFunctionName(op); + for (int i = 0; i < 4; i++) { - String opName = GetOperatorFunctionName(op); - for (int i = 0; i < 4; i++) + auto type = intTypes[i]; + auto utype = uintTypes[i]; + auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << utype << ");\n"; + if (i > 0) { - auto type = intTypes[i]; - auto utype = uintTypes[i]; - auto retType = ((op >= Operator::Eql && op <= Operator::Leq) || op == Operator::And || op == Operator::Or) ? "bool" : type; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << utype << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << utype << ", " << utype << ");\n"; - if (i > 0) - { - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << intTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << intTypes[0] << ", " << type << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; - sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; - } + sb << "__intrinsic " << retType << " operator " << opName << "(" << type << ", " << uintTypes[0] << ");\n"; + sb << "__intrinsic " << retType << " operator " << opName << "(" << uintTypes[0] << ", " << type << ");\n"; } } + } #endif - // Output a suitable `#line` directive to point at our raw stdlib code above - sb << "\n#line " << kLibIncludeStringLine << " \"" << path << "\"\n"; - - int chunkCount = sizeof(LibIncludeStringChunks) / sizeof(LibIncludeStringChunks[0]); - for (int cc = 0; cc < chunkCount; ++cc) - { - sb << LibIncludeStringChunks[cc]; - } + // Output a suitable `#line` directive to point at our raw stdlib code above + sb << "\n#line " << kLibIncludeStringLine << " \"" << path << "\"\n"; - code = sb.ProduceString(); - return code; + int chunkCount = sizeof(LibIncludeStringChunks) / sizeof(LibIncludeStringChunks[0]); + for (int cc = 0; cc < chunkCount; ++cc) + { + sb << LibIncludeStringChunks[cc]; } + code = sb.ProduceString(); + return code; + } - // GLSL-specific library code - String glslLibraryCode; + // GLSL-specific library code - String getGLSLLibraryCode() - { - if(glslLibraryCode.Length() != 0) - return glslLibraryCode; + String glslLibraryCode; - String path = getStdlibPath(); + String getGLSLLibraryCode() + { + if(glslLibraryCode.Length() != 0) + return glslLibraryCode; - StringBuilder sb; + String path = getStdlibPath(); + + StringBuilder sb; #define RAW(TEXT) \ - EMIT_LINE_DIRECTIVE(); \ - sb << TEXT; +EMIT_LINE_DIRECTIVE(); \ +sb << TEXT; + + static const struct { + char const* name; + char const* glslPrefix; + } kTypes[] = + { + {"float", ""}, + {"int", "i"}, + {"uint", "u"}, + {"bool", "b"}, + }; + static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); - static const struct { - char const* name; - char const* glslPrefix; - } kTypes[] = + for( int tt = 0; tt < kTypeCount; ++tt ) + { + // Declare GLSL aliases for HLSL types + for (int vv = 2; vv <= 4; ++vv) { - {"float", ""}, - {"int", "i"}, - {"uint", "u"}, - {"bool", "b"}, - }; - static const int kTypeCount = sizeof(kTypes) / sizeof(kTypes[0]); - - for( int tt = 0; tt < kTypeCount; ++tt ) + sb << "typedef " << kTypes[tt].name << vv << " " << kTypes[tt].glslPrefix << "vec" << vv << ";\n"; + sb << "typedef " << kTypes[tt].name << vv << "x" << vv << " " << kTypes[tt].glslPrefix << "mat" << vv << ";\n"; + } + for (int rr = 2; rr <= 4; ++rr) + for (int cc = 2; cc <= 4; ++cc) { - // Declare GLSL aliases for HLSL types - for (int vv = 2; vv <= 4; ++vv) - { - sb << "typedef " << kTypes[tt].name << vv << " " << kTypes[tt].glslPrefix << "vec" << vv << ";\n"; - sb << "typedef " << kTypes[tt].name << vv << "x" << vv << " " << kTypes[tt].glslPrefix << "mat" << vv << ";\n"; - } - for (int rr = 2; rr <= 4; ++rr) - for (int cc = 2; cc <= 4; ++cc) - { - sb << "typedef " << kTypes[tt].name << rr << "x" << cc << " " << kTypes[tt].glslPrefix << "mat" << rr << "x" << cc << ";\n"; - } + sb << "typedef " << kTypes[tt].name << rr << "x" << cc << " " << kTypes[tt].glslPrefix << "mat" << rr << "x" << cc << ";\n"; } + } - // TODO(tfoley): Need to handle `RW*` variants of texture types as well... - static const struct { - char const* name; - TextureType::Shape baseShape; - int coordCount; - } kBaseTextureTypes[] = { - { "1D", TextureType::Shape1D, 1 }, - { "2D", TextureType::Shape2D, 2 }, - { "3D", TextureType::Shape3D, 3 }, - { "Cube", TextureType::ShapeCube, 3 }, - }; - static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); - - - static const struct { - char const* name; - SlangResourceAccess access; - } kBaseTextureAccessLevels[] = { - { "", SLANG_RESOURCE_ACCESS_READ }, - { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, - { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, - }; - static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); - - for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + // TODO(tfoley): Need to handle `RW*` variants of texture types as well... + static const struct { + char const* name; + TextureType::Shape baseShape; + int coordCount; + } kBaseTextureTypes[] = { + { "1D", TextureType::Shape1D, 1 }, + { "2D", TextureType::Shape2D, 2 }, + { "3D", TextureType::Shape3D, 3 }, + { "Cube", TextureType::ShapeCube, 3 }, + }; + static const int kBaseTextureTypeCount = sizeof(kBaseTextureTypes) / sizeof(kBaseTextureTypes[0]); + + + static const struct { + char const* name; + SlangResourceAccess access; + } kBaseTextureAccessLevels[] = { + { "", SLANG_RESOURCE_ACCESS_READ }, + { "RW", SLANG_RESOURCE_ACCESS_READ_WRITE }, + { "RasterizerOrdered", SLANG_RESOURCE_ACCESS_RASTER_ORDERED }, + }; + static const int kBaseTextureAccessLevelCount = sizeof(kBaseTextureAccessLevels) / sizeof(kBaseTextureAccessLevels[0]); + + for (int tt = 0; tt < kBaseTextureTypeCount; ++tt) + { + char const* shapeName = kBaseTextureTypes[tt].name; + TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + + for (int isArray = 0; isArray < 2; ++isArray) { - char const* shapeName = kBaseTextureTypes[tt].name; - TextureType::Shape baseShape = kBaseTextureTypes[tt].baseShape; + // Arrays of 3D textures aren't allowed + if (isArray && baseShape == TextureType::Shape3D) continue; - for (int isArray = 0; isArray < 2; ++isArray) + for (int isMultisample = 0; isMultisample < 2; ++isMultisample) { - // Arrays of 3D textures aren't allowed - if (isArray && baseShape == TextureType::Shape3D) continue; + auto access = SLANG_RESOURCE_ACCESS_READ; - for (int isMultisample = 0; isMultisample < 2; ++isMultisample) - { - auto access = SLANG_RESOURCE_ACCESS_READ; - - // TODO: any constraints to enforce on what gets to be multisampled? + // TODO: any constraints to enforce on what gets to be multisampled? - unsigned flavor = baseShape; - if (isArray) flavor |= TextureType::ArrayFlag; - if (isMultisample) flavor |= TextureType::MultisampleFlag; + unsigned flavor = baseShape; + if (isArray) flavor |= TextureType::ArrayFlag; + if (isMultisample) flavor |= TextureType::MultisampleFlag; // if (isShadow) flavor |= TextureType::ShadowFlag; - flavor |= (access << 8); + flavor |= (access << 8); - StringBuilder nameBuilder; - nameBuilder << shapeName; - if (isMultisample) nameBuilder << "MS"; - if (isArray) nameBuilder << "Array"; - auto name = nameBuilder.ProduceString(); + StringBuilder nameBuilder; + nameBuilder << shapeName; + if (isMultisample) nameBuilder << "MS"; + if (isArray) nameBuilder << "Array"; + auto name = nameBuilder.ProduceString(); - sb << "__generic<T> "; - sb << "__magic_type(TextureSampler," << int(flavor) << ") struct "; - sb << "__sampler" << name; - sb << " {};\n"; + sb << "__generic<T> "; + sb << "__magic_type(TextureSampler," << int(flavor) << ") struct "; + sb << "__sampler" << name; + sb << " {};\n"; - sb << "__generic<T> "; - sb << "__magic_type(Texture," << int(flavor) << ") struct "; - sb << "__texture" << name; - sb << " {};\n"; + sb << "__generic<T> "; + sb << "__magic_type(Texture," << int(flavor) << ") struct "; + sb << "__texture" << name; + sb << " {};\n"; - sb << "__generic<T> "; - sb << "__magic_type(GLSLImageType," << int(flavor) << ") struct "; - sb << "__image" << name; - sb << " {};\n"; + sb << "__generic<T> "; + sb << "__magic_type(GLSLImageType," << int(flavor) << ") struct "; + sb << "__image" << name; + sb << " {};\n"; - // TODO(tfoley): flesh this out for all the available prefixes - static const struct - { - char const* prefix; - char const* elementType; - } kTextureElementTypes[] = { - { "", "vec4" }, - { "i", "ivec4" }, - { "u", "uvec4" }, - { nullptr, nullptr }, - }; - for( auto ee = kTextureElementTypes; ee->prefix; ++ee ) - { - sb << "typedef __sampler" << name << "<" << ee->elementType << "> " << ee->prefix << "sampler" << name << ";\n"; - sb << "typedef __texture" << name << "<" << ee->elementType << "> " << ee->prefix << "texture" << name << ";\n"; - sb << "typedef __image" << name << "<" << ee->elementType << "> " << ee->prefix << "image" << name << ";\n"; - } + // TODO(tfoley): flesh this out for all the available prefixes + static const struct + { + char const* prefix; + char const* elementType; + } kTextureElementTypes[] = { + { "", "vec4" }, + { "i", "ivec4" }, + { "u", "uvec4" }, + { nullptr, nullptr }, + }; + for( auto ee = kTextureElementTypes; ee->prefix; ++ee ) + { + sb << "typedef __sampler" << name << "<" << ee->elementType << "> " << ee->prefix << "sampler" << name << ";\n"; + sb << "typedef __texture" << name << "<" << ee->elementType << "> " << ee->prefix << "texture" << name << ";\n"; + sb << "typedef __image" << name << "<" << ee->elementType << "> " << ee->prefix << "image" << name << ";\n"; } } } + } - sb << "__generic<T> __magic_type(GLSLInputParameterBlockType) struct __GLSLInputParameterBlock {};\n"; - sb << "__generic<T> __magic_type(GLSLOutputParameterBlockType) struct __GLSLOutputParameterBlock {};\n"; - sb << "__generic<T> __magic_type(GLSLShaderStorageBufferType) struct __GLSLShaderStorageBuffer {};\n"; - - sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct sampler {};"; + sb << "__generic<T> __magic_type(GLSLInputParameterBlockType) struct __GLSLInputParameterBlock {};\n"; + sb << "__generic<T> __magic_type(GLSLOutputParameterBlockType) struct __GLSLOutputParameterBlock {};\n"; + sb << "__generic<T> __magic_type(GLSLShaderStorageBufferType) struct __GLSLShaderStorageBuffer {};\n"; - sb << "__magic_type(GLSLInputAttachmentType) struct subpassInput {};"; + sb << "__magic_type(SamplerState," << int(SamplerStateType::Flavor::SamplerState) << ") struct sampler {};"; - // Define additional keywords - sb << "__modifier(GLSLBufferModifier) buffer;\n"; - sb << "__modifier(GLSLWriteOnlyModifier) writeonly;\n"; - sb << "__modifier(GLSLReadOnlyModifier) readonly;\n"; - sb << "__modifier(GLSLPatchModifier) patch;\n"; + sb << "__magic_type(GLSLInputAttachmentType) struct subpassInput {};"; - sb << "__modifier(SimpleModifier) flat;\n"; + // Define additional keywords + sb << "__modifier(GLSLBufferModifier) buffer;\n"; + sb << "__modifier(GLSLWriteOnlyModifier) writeonly;\n"; + sb << "__modifier(GLSLReadOnlyModifier) readonly;\n"; + sb << "__modifier(GLSLPatchModifier) patch;\n"; - glslLibraryCode = sb.ProduceString(); - return glslLibraryCode; - } + sb << "__modifier(SimpleModifier) flat;\n"; + glslLibraryCode = sb.ProduceString(); + return glslLibraryCode; + } - // - void SlangStdLib::Finalize() - { - code = nullptr; - stdlibPath = String(); - glslLibraryCode = String(); - } + // + void SlangStdLib::Finalize() + { + code = nullptr; + stdlibPath = String(); + glslLibraryCode = String(); } -} +} diff --git a/source/slang/slang-stdlib.h b/source/slang/slang-stdlib.h index 65c70ecb5..fd41565d1 100644 --- a/source/slang/slang-stdlib.h +++ b/source/slang/slang-stdlib.h @@ -5,19 +5,16 @@ namespace Slang { - namespace Compiler + class SlangStdLib { - class SlangStdLib - { - private: - static CoreLib::String code; - public: - static CoreLib::String GetCode(); - static void Finalize(); - }; + private: + static CoreLib::String code; + public: + static CoreLib::String GetCode(); + static void Finalize(); + }; - CoreLib::String getGLSLLibraryCode(); - } + CoreLib::String getGLSLLibraryCode(); } #endif
\ No newline at end of file diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 1de94bb14..9f64edc7f 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -19,10 +19,8 @@ using namespace CoreLib::Basic; using namespace CoreLib::IO; -using namespace Slang::Compiler; namespace Slang { -namespace Compiler { static void stdlibDiagnosticCallback( char const* message, @@ -603,16 +601,16 @@ void Session::addBuiltinSource( loadedModuleCode.Add(syntax); } -}} +} // implementation of C interface -#define SESSION(x) reinterpret_cast<Slang::Compiler::Session *>(x) -#define REQ(x) reinterpret_cast<Slang::Compiler::CompileRequest*>(x) +#define SESSION(x) reinterpret_cast<Slang::Session *>(x) +#define REQ(x) reinterpret_cast<Slang::CompileRequest*>(x) SLANG_API SlangSession* spCreateSession(const char * cacheDir) { - return reinterpret_cast<SlangSession *>(new Slang::Compiler::Session((cacheDir ? true : false), cacheDir)); + return reinterpret_cast<SlangSession *>(new Slang::Session((cacheDir ? true : false), cacheDir)); } SLANG_API void spDestroySession( @@ -642,7 +640,7 @@ SLANG_API SlangCompileRequest* spCreateCompileRequest( SlangSession* session) { auto s = SESSION(session); - auto req = new Slang::Compiler::CompileRequest(s); + auto req = new Slang::CompileRequest(s); return reinterpret_cast<SlangCompileRequest*>(req); } @@ -668,14 +666,14 @@ SLANG_API void spSetCodeGenTarget( SlangCompileRequest* request, int target) { - REQ(request)->Options.Target = (CodeGenTarget)target; + REQ(request)->Options.Target = (Slang::CodeGenTarget)target; } SLANG_API void spSetPassThrough( SlangCompileRequest* request, SlangPassThrough passThrough) { - REQ(request)->Options.passThrough = PassThroughMode(passThrough); + REQ(request)->Options.passThrough = Slang::PassThroughMode(passThrough); } SLANG_API void spSetDiagnosticCallback( @@ -723,7 +721,7 @@ SLANG_API int spAddTranslationUnit( auto req = REQ(request); return req->addTranslationUnit( - SourceLanguage(language), + Slang::SourceLanguage(language), name ? name : ""); } @@ -781,7 +779,7 @@ SLANG_API SlangProfileID spFindProfile( SlangSession* session, char const* name) { - return Profile::LookUp(name).raw; + return Slang::Profile::LookUp(name).raw; } SLANG_API int spAddTranslationUnitEntryPoint( @@ -800,7 +798,7 @@ SLANG_API int spAddTranslationUnitEntryPoint( return req->addTranslationUnitEntryPoint( translationUnitIndex, name, - Profile(Profile::RawVal(profile))); + Slang::Profile(Slang::Profile::RawVal(profile))); } diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index bfc1e7317..4f64baf6e 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -1,6 +1,6 @@ <?xml version="1.0" encoding="utf-8"?> <AutoVisualizer xmlns="http://schemas.microsoft.com/vstudio/debugger/natvis/2010"> - <Type Name="Slang::Compiler::CFGNode"> + <Type Name="Slang::CFGNode"> <DisplayString>{{CFG Basic Block}}</DisplayString> <Expand> <LinkedListItems> diff --git a/source/slang/source-loc.h b/source/slang/source-loc.h index dc353f402..123da8c16 100644 --- a/source/slang/source-loc.h +++ b/source/slang/source-loc.h @@ -5,7 +5,6 @@ #include "../core/basic.h" namespace Slang { -namespace Compiler { using namespace CoreLib::Basic; @@ -42,6 +41,6 @@ public: }; -}} +} // namespace Slang #endif diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h index da1984f05..f83334155 100644 --- a/source/slang/syntax-visitors.h +++ b/source/slang/syntax-visitors.h @@ -7,31 +7,28 @@ namespace Slang { - namespace Compiler - { - class CompileOptions; - struct CompileRequest; - class ShaderCompiler; - class ShaderLinkInfo; - class ShaderSymbol; + class CompileOptions; + struct CompileRequest; + class ShaderCompiler; + class ShaderLinkInfo; + class ShaderSymbol; - SyntaxVisitor* CreateSemanticsVisitor( - DiagnosticSink* err, - CompileOptions const& options, - CompileRequest* request); + SyntaxVisitor* CreateSemanticsVisitor( + DiagnosticSink* err, + CompileOptions const& options, + CompileRequest* request); - // Look for a module that matches the given name: - // either one we've loaded already, or one we - // can find vai the search paths available to us. - // - // Needed by import declaration checking. - // - // TODO: need a better location to declare this. - RefPtr<ProgramSyntaxNode> findOrImportModule( - CompileRequest* request, - String const& name, - CodePosition const& loc); - } + // Look for a module that matches the given name: + // either one we've loaded already, or one we + // can find vai the search paths available to us. + // + // Needed by import declaration checking. + // + // TODO: need a better location to declare this. + RefPtr<ProgramSyntaxNode> findOrImportModule( + CompileRequest* request, + String const& name, + CodePosition const& loc); } #endif
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index e47c610c0..38323e2b3 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -5,1481 +5,1478 @@ namespace Slang { - namespace Compiler - { - // BasicExpressionType + // BasicExpressionType - bool BasicExpressionType::EqualsImpl(ExpressionType * type) - { - auto basicType = dynamic_cast<const BasicExpressionType*>(type); - if (basicType == nullptr) - return false; - return basicType->BaseType == BaseType; - } - - ExpressionType* BasicExpressionType::CreateCanonicalType() - { - // A basic type is already canonical, in our setup - return this; - } + bool BasicExpressionType::EqualsImpl(ExpressionType * type) + { + auto basicType = dynamic_cast<const BasicExpressionType*>(type); + if (basicType == nullptr) + return false; + return basicType->BaseType == BaseType; + } - CoreLib::Basic::String BasicExpressionType::ToString() - { - CoreLib::Basic::StringBuilder res; + ExpressionType* BasicExpressionType::CreateCanonicalType() + { + // A basic type is already canonical, in our setup + return this; + } - switch (BaseType) - { - case Compiler::BaseType::Int: - res.Append("int"); - break; - case Compiler::BaseType::UInt: - res.Append("uint"); - break; - case Compiler::BaseType::UInt64: - res.Append("uint64_t"); - break; - case Compiler::BaseType::Bool: - res.Append("bool"); - break; - case Compiler::BaseType::Float: - res.Append("float"); - break; - case Compiler::BaseType::Void: - res.Append("void"); - break; - default: - break; - } - return res.ProduceString(); - } + CoreLib::Basic::String BasicExpressionType::ToString() + { + CoreLib::Basic::StringBuilder res; + + switch (BaseType) + { + case Slang::BaseType::Int: + res.Append("int"); + break; + case Slang::BaseType::UInt: + res.Append("uint"); + break; + case Slang::BaseType::UInt64: + res.Append("uint64_t"); + break; + case Slang::BaseType::Bool: + res.Append("bool"); + break; + case Slang::BaseType::Float: + res.Append("float"); + break; + case Slang::BaseType::Void: + res.Append("void"); + break; + default: + break; + } + return res.ProduceString(); + } - RefPtr<SyntaxNode> ProgramSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitProgram(this); - } + RefPtr<SyntaxNode> ProgramSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitProgram(this); + } - RefPtr<SyntaxNode> FunctionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitFunction(this); - } + RefPtr<SyntaxNode> FunctionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitFunction(this); + } - // + // - RefPtr<SyntaxNode> ScopeDecl::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitScopeDecl(this); - } + RefPtr<SyntaxNode> ScopeDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitScopeDecl(this); + } - // + // - RefPtr<SyntaxNode> BlockStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitBlockStatement(this); - } + RefPtr<SyntaxNode> BlockStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitBlockStatement(this); + } - RefPtr<SyntaxNode> BreakStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitBreakStatement(this); - } + RefPtr<SyntaxNode> BreakStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitBreakStatement(this); + } - RefPtr<SyntaxNode> ContinueStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitContinueStatement(this); - } + RefPtr<SyntaxNode> ContinueStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitContinueStatement(this); + } - RefPtr<SyntaxNode> DoWhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitDoWhileStatement(this); - } + RefPtr<SyntaxNode> DoWhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDoWhileStatement(this); + } - RefPtr<SyntaxNode> EmptyStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitEmptyStatement(this); - } + RefPtr<SyntaxNode> EmptyStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitEmptyStatement(this); + } - RefPtr<SyntaxNode> ForStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitForStatement(this); - } + RefPtr<SyntaxNode> ForStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitForStatement(this); + } - RefPtr<SyntaxNode> IfStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitIfStatement(this); - } + RefPtr<SyntaxNode> IfStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitIfStatement(this); + } - RefPtr<SyntaxNode> ReturnStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitReturnStatement(this); - } + RefPtr<SyntaxNode> ReturnStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitReturnStatement(this); + } - RefPtr<SyntaxNode> VarDeclrStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitVarDeclrStatement(this); - } + RefPtr<SyntaxNode> VarDeclrStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitVarDeclrStatement(this); + } - RefPtr<SyntaxNode> Variable::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitDeclrVariable(this); - } + RefPtr<SyntaxNode> Variable::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDeclrVariable(this); + } - RefPtr<SyntaxNode> WhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitWhileStatement(this); - } + RefPtr<SyntaxNode> WhileStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitWhileStatement(this); + } - RefPtr<SyntaxNode> ExpressionStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitExpressionStatement(this); - } + RefPtr<SyntaxNode> ExpressionStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitExpressionStatement(this); + } - RefPtr<SyntaxNode> ConstantExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitConstantExpression(this); - } + RefPtr<SyntaxNode> ConstantExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitConstantExpression(this); + } - RefPtr<SyntaxNode> IndexExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitIndexExpression(this); - } - RefPtr<SyntaxNode> MemberExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitMemberExpression(this); - } + RefPtr<SyntaxNode> IndexExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitIndexExpression(this); + } + RefPtr<SyntaxNode> MemberExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitMemberExpression(this); + } - // SwizzleExpr + // SwizzleExpr - RefPtr<SyntaxNode> SwizzleExpr::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitSwizzleExpression(this); - } + RefPtr<SyntaxNode> SwizzleExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSwizzleExpression(this); + } - // DerefExpr + // DerefExpr - RefPtr<SyntaxNode> DerefExpr::Accept(SyntaxVisitor * /*visitor*/) - { - // throw "unimplemented"; - return this; - } + RefPtr<SyntaxNode> DerefExpr::Accept(SyntaxVisitor * /*visitor*/) + { + // throw "unimplemented"; + return this; + } - // + // - RefPtr<SyntaxNode> InvokeExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitInvokeExpression(this); - } + RefPtr<SyntaxNode> InvokeExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitInvokeExpression(this); + } - RefPtr<SyntaxNode> TypeCastExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitTypeCastExpression(this); - } + RefPtr<SyntaxNode> TypeCastExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitTypeCastExpression(this); + } - RefPtr<SyntaxNode> VarExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitVarExpression(this); - } + RefPtr<SyntaxNode> VarExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitVarExpression(this); + } - // OverloadedExpr + // OverloadedExpr - RefPtr<SyntaxNode> OverloadedExpr::Accept(SyntaxVisitor * /*visitor*/) - { + RefPtr<SyntaxNode> OverloadedExpr::Accept(SyntaxVisitor * /*visitor*/) + { // throw "unimplemented"; - return this; - } + return this; + } - // + // - RefPtr<SyntaxNode> ParameterSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitParameter(this); - } + RefPtr<SyntaxNode> ParameterSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitParameter(this); + } - // ImportDecl + // ImportDecl - RefPtr<SyntaxNode> ImportDecl::Accept(SyntaxVisitor * visitor) - { - visitor->visitImportDecl(this); - return this; - } + RefPtr<SyntaxNode> ImportDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitImportDecl(this); + return this; + } - // + // - RefPtr<SyntaxNode> StructField::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitStructField(this); - } - RefPtr<SyntaxNode> StructSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitStruct(this); - } - RefPtr<SyntaxNode> ClassSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitClass(this); - } - RefPtr<SyntaxNode> TypeDefDecl::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitTypeDefDecl(this); - } + RefPtr<SyntaxNode> StructField::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitStructField(this); + } + RefPtr<SyntaxNode> StructSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitStruct(this); + } + RefPtr<SyntaxNode> ClassSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitClass(this); + } + RefPtr<SyntaxNode> TypeDefDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitTypeDefDecl(this); + } - RefPtr<SyntaxNode> DiscardStatementSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitDiscardStatement(this); - } + RefPtr<SyntaxNode> DiscardStatementSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDiscardStatement(this); + } - // BasicExpressionType + // BasicExpressionType - BasicExpressionType* BasicExpressionType::GetScalarType() - { - return this; - } + BasicExpressionType* BasicExpressionType::GetScalarType() + { + return this; + } - // + // - bool ExpressionType::Equals(ExpressionType * type) - { - return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); - } + bool ExpressionType::Equals(ExpressionType * type) + { + return GetCanonicalType()->EqualsImpl(type->GetCanonicalType()); + } - bool ExpressionType::Equals(RefPtr<ExpressionType> type) - { - return Equals(type.Ptr()); - } + bool ExpressionType::Equals(RefPtr<ExpressionType> type) + { + return Equals(type.Ptr()); + } - bool ExpressionType::EqualsVal(Val* val) - { - if (auto type = dynamic_cast<ExpressionType*>(val)) - return const_cast<ExpressionType*>(this)->Equals(type); - return false; - } + bool ExpressionType::EqualsVal(Val* val) + { + if (auto type = dynamic_cast<ExpressionType*>(val)) + return const_cast<ExpressionType*>(this)->Equals(type); + return false; + } - NamedExpressionType* ExpressionType::AsNamedType() - { - return dynamic_cast<NamedExpressionType*>(this); - } + NamedExpressionType* ExpressionType::AsNamedType() + { + return dynamic_cast<NamedExpressionType*>(this); + } - RefPtr<Val> ExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - int diff = 0; - auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); + RefPtr<Val> ExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + int diff = 0; + auto canSubst = GetCanonicalType()->SubstituteImpl(subst, &diff); - // If nothing changed, then don't drop any sugar that is applied - if (!diff) - return this; + // If nothing changed, then don't drop any sugar that is applied + if (!diff) + return this; - // If the canonical type changed, then we return a canonical type, - // rather than try to re-construct any amount of sugar - (*ioDiff)++; - return canSubst; - } + // If the canonical type changed, then we return a canonical type, + // rather than try to re-construct any amount of sugar + (*ioDiff)++; + return canSubst; + } - ExpressionType* ExpressionType::GetCanonicalType() + ExpressionType* ExpressionType::GetCanonicalType() + { + if (!this) return nullptr; + ExpressionType* et = const_cast<ExpressionType*>(this); + if (!et->canonicalType) { - if (!this) return nullptr; - ExpressionType* et = const_cast<ExpressionType*>(this); - if (!et->canonicalType) - { - // TODO(tfoley): worry about thread safety here? - et->canonicalType = et->CreateCanonicalType(); - assert(et->canonicalType); - } - return et->canonicalType; + // TODO(tfoley): worry about thread safety here? + et->canonicalType = et->CreateCanonicalType(); + assert(et->canonicalType); } + return et->canonicalType; + } - bool ExpressionType::IsTextureOrSampler() - { - return IsTexture() || IsSampler(); - } - bool ExpressionType::IsStruct() - { - auto declRefType = AsDeclRefType(); - if (!declRefType) return false; - auto structDeclRef = declRefType->declRef.As<StructDeclRef>(); - if (!structDeclRef) return false; - return true; - } + bool ExpressionType::IsTextureOrSampler() + { + return IsTexture() || IsSampler(); + } + bool ExpressionType::IsStruct() + { + auto declRefType = AsDeclRefType(); + if (!declRefType) return false; + auto structDeclRef = declRefType->declRef.As<StructDeclRef>(); + if (!structDeclRef) return false; + return true; + } - bool ExpressionType::IsClass() - { - auto declRefType = AsDeclRefType(); - if (!declRefType) return false; - auto classDeclRef = declRefType->declRef.As<ClassDeclRef>(); - if (!classDeclRef) return false; - return true; - } + bool ExpressionType::IsClass() + { + auto declRefType = AsDeclRefType(); + if (!declRefType) return false; + auto classDeclRef = declRefType->declRef.As<ClassDeclRef>(); + if (!classDeclRef) return false; + return true; + } #if 0 - RefPtr<ExpressionType> ExpressionType::Bool; - RefPtr<ExpressionType> ExpressionType::UInt; - RefPtr<ExpressionType> ExpressionType::Int; - RefPtr<ExpressionType> ExpressionType::Float; - RefPtr<ExpressionType> ExpressionType::Float2; - RefPtr<ExpressionType> ExpressionType::Void; + RefPtr<ExpressionType> ExpressionType::Bool; + RefPtr<ExpressionType> ExpressionType::UInt; + RefPtr<ExpressionType> ExpressionType::Int; + RefPtr<ExpressionType> ExpressionType::Float; + RefPtr<ExpressionType> ExpressionType::Float2; + RefPtr<ExpressionType> ExpressionType::Void; #endif - RefPtr<ExpressionType> ExpressionType::Error; - RefPtr<ExpressionType> ExpressionType::initializerListType; - RefPtr<ExpressionType> ExpressionType::Overloaded; + RefPtr<ExpressionType> ExpressionType::Error; + RefPtr<ExpressionType> ExpressionType::initializerListType; + RefPtr<ExpressionType> ExpressionType::Overloaded; - Dictionary<int, RefPtr<ExpressionType>> ExpressionType::sBuiltinTypes; - Dictionary<String, Decl*> ExpressionType::sMagicDecls; - List<RefPtr<ExpressionType>> ExpressionType::sCanonicalTypes; + Dictionary<int, RefPtr<ExpressionType>> ExpressionType::sBuiltinTypes; + Dictionary<String, Decl*> ExpressionType::sMagicDecls; + List<RefPtr<ExpressionType>> ExpressionType::sCanonicalTypes; - void ExpressionType::Init() - { - Error = new ErrorType(); - initializerListType = new InitializerListType(); - Overloaded = new OverloadGroupType(); - } - void ExpressionType::Finalize() - { - Error = nullptr; - initializerListType = nullptr; - Overloaded = nullptr; - // Note(tfoley): This seems to be just about the only way to clear out a List<T> - sCanonicalTypes = List<RefPtr<ExpressionType>>(); - sBuiltinTypes = Dictionary<int, RefPtr<ExpressionType>>(); - sMagicDecls = Dictionary<String, Decl*>(); - } - bool ArrayExpressionType::EqualsImpl(ExpressionType * type) - { - auto arrType = type->AsArrayType(); - if (!arrType) - return false; - return (ArrayLength == arrType->ArrayLength && BaseType->Equals(arrType->BaseType.Ptr())); - } - ExpressionType* ArrayExpressionType::CreateCanonicalType() - { - auto canonicalBaseType = BaseType->GetCanonicalType(); - auto canonicalArrayType = new ArrayExpressionType(); - sCanonicalTypes.Add(canonicalArrayType); - canonicalArrayType->BaseType = canonicalBaseType; - canonicalArrayType->ArrayLength = ArrayLength; - return canonicalArrayType; - } - int ArrayExpressionType::GetHashCode() - { - if (ArrayLength) - return (BaseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); - else - return BaseType->GetHashCode(); - } - CoreLib::Basic::String ArrayExpressionType::ToString() - { - if (ArrayLength) - return BaseType->ToString() + "[" + ArrayLength->ToString() + "]"; - else - return BaseType->ToString() + "[]"; - } - RefPtr<SyntaxNode> GenericAppExpr::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitGenericApp(this); - } + void ExpressionType::Init() + { + Error = new ErrorType(); + initializerListType = new InitializerListType(); + Overloaded = new OverloadGroupType(); + } + void ExpressionType::Finalize() + { + Error = nullptr; + initializerListType = nullptr; + Overloaded = nullptr; + // Note(tfoley): This seems to be just about the only way to clear out a List<T> + sCanonicalTypes = List<RefPtr<ExpressionType>>(); + sBuiltinTypes = Dictionary<int, RefPtr<ExpressionType>>(); + sMagicDecls = Dictionary<String, Decl*>(); + } + bool ArrayExpressionType::EqualsImpl(ExpressionType * type) + { + auto arrType = type->AsArrayType(); + if (!arrType) + return false; + return (ArrayLength == arrType->ArrayLength && BaseType->Equals(arrType->BaseType.Ptr())); + } + ExpressionType* ArrayExpressionType::CreateCanonicalType() + { + auto canonicalBaseType = BaseType->GetCanonicalType(); + auto canonicalArrayType = new ArrayExpressionType(); + sCanonicalTypes.Add(canonicalArrayType); + canonicalArrayType->BaseType = canonicalBaseType; + canonicalArrayType->ArrayLength = ArrayLength; + return canonicalArrayType; + } + int ArrayExpressionType::GetHashCode() + { + if (ArrayLength) + return (BaseType->GetHashCode() * 16777619) ^ ArrayLength->GetHashCode(); + else + return BaseType->GetHashCode(); + } + CoreLib::Basic::String ArrayExpressionType::ToString() + { + if (ArrayLength) + return BaseType->ToString() + "[" + ArrayLength->ToString() + "]"; + else + return BaseType->ToString() + "[]"; + } + RefPtr<SyntaxNode> GenericAppExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitGenericApp(this); + } - // DeclRefType + // DeclRefType - String DeclRefType::ToString() - { - return declRef.GetName(); - } + String DeclRefType::ToString() + { + return declRef.GetName(); + } - int DeclRefType::GetHashCode() - { - return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); - } + int DeclRefType::GetHashCode() + { + return (declRef.GetHashCode() * 16777619) ^ (int)(typeid(this).hash_code()); + } - bool DeclRefType::EqualsImpl(ExpressionType * type) + bool DeclRefType::EqualsImpl(ExpressionType * type) + { + if (auto declRefType = type->AsDeclRefType()) { - if (auto declRefType = type->AsDeclRefType()) - { - return declRef.Equals(declRefType->declRef); - } - return false; + return declRef.Equals(declRefType->declRef); } + return false; + } - ExpressionType* DeclRefType::CreateCanonicalType() - { - // A declaration reference is already canonical - return this; - } + ExpressionType* DeclRefType::CreateCanonicalType() + { + // A declaration reference is already canonical + return this; + } - RefPtr<Val> DeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - if (!subst) return this; + RefPtr<Val> DeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!subst) return this; - // the case we especially care about is when this type references a declaration - // of a generic parameter, since that is what we might be substituting... - if (auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(declRef.GetDecl())) + // the case we especially care about is when this type references a declaration + // of a generic parameter, since that is what we might be substituting... + if (auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(declRef.GetDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) { - // search for a substitution that might apply to us - for (auto s = subst; s; s = s->outer.Ptr()) + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = s->genericDecl; + if (genericDecl != genericTypeParamDecl->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != genericTypeParamDecl->ParentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->Members) + if (m.Ptr() == genericTypeParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return s->args[index]; + } + else if(auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if(auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else { - if (m.Ptr() == genericTypeParamDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return s->args[index]; - } - else if(auto typeParam = m.As<GenericTypeParamDecl>()) - { - index++; - } - else if(auto valParam = m.As<GenericValueParamDecl>()) - { - index++; - } - else - { - } } - } + } + } - int diff = 0; - DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); + int diff = 0; + DeclRef substDeclRef = declRef.SubstituteImpl(subst, &diff); - if (!diff) - return this; + if (!diff) + return this; - // Make sure to record the difference! - *ioDiff += diff; + // Make sure to record the difference! + *ioDiff += diff; - // Re-construct the type in case we are using a specialized sub-class - return DeclRefType::Create(substDeclRef); - } + // Re-construct the type in case we are using a specialized sub-class + return DeclRefType::Create(substDeclRef); + } + + static RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<Val> val) + { + auto type = val.As<ExpressionType>(); + assert(type.Ptr()); + return type; + } + + static RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<Val> val) + { + auto intVal = val.As<IntVal>(); + assert(intVal.Ptr()); + return intVal; + } - static RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<Val> val) + // TODO: need to figure out how to unify this with the logic + // in the generic case... + DeclRefType* DeclRefType::Create(DeclRef declRef) + { + if (auto builtinMod = declRef.GetDecl()->FindModifier<BuiltinTypeModifier>()) { - auto type = val.As<ExpressionType>(); - assert(type.Ptr()); + auto type = new BasicExpressionType(builtinMod->tag); + type->declRef = declRef; return type; } - - static RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<Val> val) + else if (auto magicMod = declRef.GetDecl()->FindModifier<MagicTypeModifier>()) { - auto intVal = val.As<IntVal>(); - assert(intVal.Ptr()); - return intVal; - } + Substitutions* subst = declRef.substitutions.Ptr(); - // TODO: need to figure out how to unify this with the logic - // in the generic case... - DeclRefType* DeclRefType::Create(DeclRef declRef) - { - if (auto builtinMod = declRef.GetDecl()->FindModifier<BuiltinTypeModifier>()) + if (magicMod->name == "SamplerState") { - auto type = new BasicExpressionType(builtinMod->tag); + auto type = new SamplerStateType(); type->declRef = declRef; + type->flavor = SamplerStateType::Flavor(magicMod->tag); return type; } - else if (auto magicMod = declRef.GetDecl()->FindModifier<MagicTypeModifier>()) + else if (magicMod->name == "Vector") + { + assert(subst && subst->args.Count() == 2); + auto vecType = new VectorExpressionType(); + vecType->declRef = declRef; + vecType->elementType = ExtractGenericArgType(subst->args[0]); + vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); + return vecType; + } + else if (magicMod->name == "Matrix") + { + assert(subst && subst->args.Count() == 3); + auto matType = new MatrixExpressionType(); + matType->declRef = declRef; + return matType; + } + else if (magicMod->name == "Texture") { - Substitutions* subst = declRef.substitutions.Ptr(); + assert(subst && subst->args.Count() >= 1); + auto textureType = new TextureType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "TextureSampler") + { + assert(subst && subst->args.Count() >= 1); + auto textureType = new TextureSamplerType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } + else if (magicMod->name == "GLSLImageType") + { + assert(subst && subst->args.Count() >= 1); + auto textureType = new GLSLImageType( + TextureType::Flavor(magicMod->tag), + ExtractGenericArgType(subst->args[0])); + textureType->declRef = declRef; + return textureType; + } - if (magicMod->name == "SamplerState") - { - auto type = new SamplerStateType(); - type->declRef = declRef; - type->flavor = SamplerStateType::Flavor(magicMod->tag); - return type; - } - else if (magicMod->name == "Vector") - { - assert(subst && subst->args.Count() == 2); - auto vecType = new VectorExpressionType(); - vecType->declRef = declRef; - vecType->elementType = ExtractGenericArgType(subst->args[0]); - vecType->elementCount = ExtractGenericArgInteger(subst->args[1]); - return vecType; - } - else if (magicMod->name == "Matrix") - { - assert(subst && subst->args.Count() == 3); - auto matType = new MatrixExpressionType(); - matType->declRef = declRef; - return matType; - } - else if (magicMod->name == "Texture") - { - assert(subst && subst->args.Count() >= 1); - auto textureType = new TextureType( - TextureType::Flavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->name == "TextureSampler") - { - assert(subst && subst->args.Count() >= 1); - auto textureType = new TextureSamplerType( - TextureType::Flavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->declRef = declRef; - return textureType; - } - else if (magicMod->name == "GLSLImageType") - { - assert(subst && subst->args.Count() >= 1); - auto textureType = new GLSLImageType( - TextureType::Flavor(magicMod->tag), - ExtractGenericArgType(subst->args[0])); - textureType->declRef = declRef; - return textureType; + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + assert(subst && subst->args.Count() == 1); \ + auto type = new T(); \ + type->elementType = ExtractGenericArgType(subst->args[0]); \ + type->declRef = declRef; \ + return type; \ } - #define CASE(n,T) \ - else if(magicMod->name == #n) { \ - assert(subst && subst->args.Count() == 1); \ - auto type = new T(); \ - type->elementType = ExtractGenericArgType(subst->args[0]); \ - type->declRef = declRef; \ - return type; \ - } - - CASE(ConstantBuffer, ConstantBufferType) - CASE(TextureBuffer, TextureBufferType) - CASE(GLSLInputParameterBlockType, GLSLInputParameterBlockType) - CASE(GLSLOutputParameterBlockType, GLSLOutputParameterBlockType) - CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) - - CASE(PackedBuffer, PackedBufferType) - CASE(Uniform, UniformBufferType) - CASE(Patch, PatchType) - - CASE(HLSLBufferType, HLSLBufferType) - CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) - CASE(HLSLRWBufferType, HLSLRWBufferType) - CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) - CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) - CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) - CASE(HLSLInputPatchType, HLSLInputPatchType) - CASE(HLSLOutputPatchType, HLSLOutputPatchType) - - CASE(HLSLPointStreamType, HLSLPointStreamType) - CASE(HLSLLineStreamType, HLSLPointStreamType) - CASE(HLSLTriangleStreamType, HLSLPointStreamType) - - #undef CASE - - // "magic" builtin types which have no generic parameters - #define CASE(n,T) \ - else if(magicMod->name == #n) { \ - auto type = new T(); \ - type->declRef = declRef; \ - return type; \ - } + CASE(ConstantBuffer, ConstantBufferType) + CASE(TextureBuffer, TextureBufferType) + CASE(GLSLInputParameterBlockType, GLSLInputParameterBlockType) + CASE(GLSLOutputParameterBlockType, GLSLOutputParameterBlockType) + CASE(GLSLShaderStorageBufferType, GLSLShaderStorageBufferType) + + CASE(PackedBuffer, PackedBufferType) + CASE(Uniform, UniformBufferType) + CASE(Patch, PatchType) + + CASE(HLSLBufferType, HLSLBufferType) + CASE(HLSLStructuredBufferType, HLSLStructuredBufferType) + CASE(HLSLRWBufferType, HLSLRWBufferType) + CASE(HLSLRWStructuredBufferType, HLSLRWStructuredBufferType) + CASE(HLSLAppendStructuredBufferType, HLSLAppendStructuredBufferType) + CASE(HLSLConsumeStructuredBufferType, HLSLConsumeStructuredBufferType) + CASE(HLSLInputPatchType, HLSLInputPatchType) + CASE(HLSLOutputPatchType, HLSLOutputPatchType) + + CASE(HLSLPointStreamType, HLSLPointStreamType) + CASE(HLSLLineStreamType, HLSLPointStreamType) + CASE(HLSLTriangleStreamType, HLSLPointStreamType) + + #undef CASE + + // "magic" builtin types which have no generic parameters + #define CASE(n,T) \ + else if(magicMod->name == #n) { \ + auto type = new T(); \ + type->declRef = declRef; \ + return type; \ + } - CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) - CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) - CASE(UntypedBufferResourceType, UntypedBufferResourceType) + CASE(HLSLByteAddressBufferType, HLSLByteAddressBufferType) + CASE(HLSLRWByteAddressBufferType, HLSLRWByteAddressBufferType) + CASE(UntypedBufferResourceType, UntypedBufferResourceType) - CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) + CASE(GLSLInputAttachmentType, GLSLInputAttachmentType) - #undef CASE + #undef CASE - else - { - throw "unimplemented"; - } - } else { - return new DeclRefType(declRef); + throw "unimplemented"; } } - - // OverloadGroupType - - String OverloadGroupType::ToString() + else { - return "overload group"; + return new DeclRefType(declRef); } + } - bool OverloadGroupType::EqualsImpl(ExpressionType * /*type*/) - { - return false; - } + // OverloadGroupType - ExpressionType* OverloadGroupType::CreateCanonicalType() - { - return this; - } + String OverloadGroupType::ToString() + { + return "overload group"; + } - int OverloadGroupType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } + bool OverloadGroupType::EqualsImpl(ExpressionType * /*type*/) + { + return false; + } - // InitializerListType + ExpressionType* OverloadGroupType::CreateCanonicalType() + { + return this; + } - String InitializerListType::ToString() - { - return "initializer list"; - } + int OverloadGroupType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } - bool InitializerListType::EqualsImpl(ExpressionType * /*type*/) - { - return false; - } + // InitializerListType - ExpressionType* InitializerListType::CreateCanonicalType() - { - return this; - } + String InitializerListType::ToString() + { + return "initializer list"; + } - int InitializerListType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } + bool InitializerListType::EqualsImpl(ExpressionType * /*type*/) + { + return false; + } - // ErrorType + ExpressionType* InitializerListType::CreateCanonicalType() + { + return this; + } - String ErrorType::ToString() - { - return "error"; - } + int InitializerListType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } - bool ErrorType::EqualsImpl(ExpressionType* type) - { - if (auto errorType = type->As<ErrorType>()) - return true; - return false; - } + // ErrorType - ExpressionType* ErrorType::CreateCanonicalType() - { - return this; - } + String ErrorType::ToString() + { + return "error"; + } - int ErrorType::GetHashCode() - { - return (int)(int64_t)(void*)this; - } + bool ErrorType::EqualsImpl(ExpressionType* type) + { + if (auto errorType = type->As<ErrorType>()) + return true; + return false; + } + ExpressionType* ErrorType::CreateCanonicalType() + { + return this; + } - // NamedExpressionType + int ErrorType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } - String NamedExpressionType::ToString() - { - return declRef.GetName(); - } - bool NamedExpressionType::EqualsImpl(ExpressionType * /*type*/) - { - assert(!"unreachable"); - return false; - } + // NamedExpressionType - ExpressionType* NamedExpressionType::CreateCanonicalType() - { - return declRef.GetType()->GetCanonicalType(); - } + String NamedExpressionType::ToString() + { + return declRef.GetName(); + } - int NamedExpressionType::GetHashCode() - { - assert(!"unreachable"); - return 0; - } + bool NamedExpressionType::EqualsImpl(ExpressionType * /*type*/) + { + assert(!"unreachable"); + return false; + } - // FuncType + ExpressionType* NamedExpressionType::CreateCanonicalType() + { + return declRef.GetType()->GetCanonicalType(); + } - String FuncType::ToString() - { - // TODO: a better approach than this - if (declRef) - return declRef.GetName(); - else - return "/* unknown FuncType */"; - } + int NamedExpressionType::GetHashCode() + { + assert(!"unreachable"); + return 0; + } - bool FuncType::EqualsImpl(ExpressionType * type) - { - if (auto funcType = type->As<FuncType>()) - { - return declRef == funcType->declRef; - } - return false; - } + // FuncType - ExpressionType* FuncType::CreateCanonicalType() - { - return this; - } + String FuncType::ToString() + { + // TODO: a better approach than this + if (declRef) + return declRef.GetName(); + else + return "/* unknown FuncType */"; + } - int FuncType::GetHashCode() + bool FuncType::EqualsImpl(ExpressionType * type) + { + if (auto funcType = type->As<FuncType>()) { - return declRef.GetHashCode(); + return declRef == funcType->declRef; } + return false; + } - // TypeType + ExpressionType* FuncType::CreateCanonicalType() + { + return this; + } - String TypeType::ToString() - { - StringBuilder sb; - sb << "typeof(" << type->ToString() << ")"; - return sb.ProduceString(); - } + int FuncType::GetHashCode() + { + return declRef.GetHashCode(); + } - bool TypeType::EqualsImpl(ExpressionType * t) - { - if (auto typeType = t->As<TypeType>()) - { - return t->Equals(typeType->type); - } - return false; - } + // TypeType - ExpressionType* TypeType::CreateCanonicalType() - { - auto canType = new TypeType(type->GetCanonicalType()); - sCanonicalTypes.Add(canType); - return canType; - } + String TypeType::ToString() + { + StringBuilder sb; + sb << "typeof(" << type->ToString() << ")"; + return sb.ProduceString(); + } - int TypeType::GetHashCode() + bool TypeType::EqualsImpl(ExpressionType * t) + { + if (auto typeType = t->As<TypeType>()) { - assert(!"unreachable"); - return 0; + return t->Equals(typeType->type); } + return false; + } - // GenericDeclRefType + ExpressionType* TypeType::CreateCanonicalType() + { + auto canType = new TypeType(type->GetCanonicalType()); + sCanonicalTypes.Add(canType); + return canType; + } - String GenericDeclRefType::ToString() - { - // TODO: what is appropriate here? - return "<GenericDeclRef>"; - } + int TypeType::GetHashCode() + { + assert(!"unreachable"); + return 0; + } - bool GenericDeclRefType::EqualsImpl(ExpressionType * type) - { - if (auto genericDeclRefType = type->As<GenericDeclRefType>()) - { - return declRef.Equals(genericDeclRefType->declRef); - } - return false; - } + // GenericDeclRefType - int GenericDeclRefType::GetHashCode() - { - return declRef.GetHashCode(); - } + String GenericDeclRefType::ToString() + { + // TODO: what is appropriate here? + return "<GenericDeclRef>"; + } - ExpressionType* GenericDeclRefType::CreateCanonicalType() + bool GenericDeclRefType::EqualsImpl(ExpressionType * type) + { + if (auto genericDeclRefType = type->As<GenericDeclRefType>()) { - return this; + return declRef.Equals(genericDeclRefType->declRef); } + return false; + } - // ArithmeticExpressionType + int GenericDeclRefType::GetHashCode() + { + return declRef.GetHashCode(); + } - // VectorExpressionType + ExpressionType* GenericDeclRefType::CreateCanonicalType() + { + return this; + } - String VectorExpressionType::ToString() - { - StringBuilder sb; - sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; - return sb.ProduceString(); - } + // ArithmeticExpressionType - BasicExpressionType* VectorExpressionType::GetScalarType() - { - return elementType->AsBasicType(); - } + // VectorExpressionType - // MatrixExpressionType + String VectorExpressionType::ToString() + { + StringBuilder sb; + sb << "vector<" << elementType->ToString() << "," << elementCount->ToString() << ">"; + return sb.ProduceString(); + } - String MatrixExpressionType::ToString() - { - StringBuilder sb; - sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; - return sb.ProduceString(); - } + BasicExpressionType* VectorExpressionType::GetScalarType() + { + return elementType->AsBasicType(); + } - BasicExpressionType* MatrixExpressionType::GetScalarType() - { - return getElementType()->AsBasicType(); - } + // MatrixExpressionType - ExpressionType* MatrixExpressionType::getElementType() - { - return this->declRef.substitutions->args[0].As<ExpressionType>().Ptr(); - } + String MatrixExpressionType::ToString() + { + StringBuilder sb; + sb << "matrix<" << getElementType()->ToString() << "," << getRowCount()->ToString() << "," << getColumnCount()->ToString() << ">"; + return sb.ProduceString(); + } - IntVal* MatrixExpressionType::getRowCount() - { - return this->declRef.substitutions->args[1].As<IntVal>().Ptr(); - } + BasicExpressionType* MatrixExpressionType::GetScalarType() + { + return getElementType()->AsBasicType(); + } - IntVal* MatrixExpressionType::getColumnCount() - { - return this->declRef.substitutions->args[2].As<IntVal>().Ptr(); - } + ExpressionType* MatrixExpressionType::getElementType() + { + return this->declRef.substitutions->args[0].As<ExpressionType>().Ptr(); + } + + IntVal* MatrixExpressionType::getRowCount() + { + return this->declRef.substitutions->args[1].As<IntVal>().Ptr(); + } + + IntVal* MatrixExpressionType::getColumnCount() + { + return this->declRef.substitutions->args[2].As<IntVal>().Ptr(); + } - // + // #if 0 - String GetOperatorFunctionName(Operator op) - { - switch (op) - { - case Operator::Add: - case Operator::AddAssign: - return "+"; - case Operator::Sub: - case Operator::SubAssign: - return "-"; - case Operator::Neg: - return "-"; - case Operator::Not: - return "!"; - case Operator::BitNot: - return "~"; - case Operator::PreInc: - case Operator::PostInc: - return "++"; - case Operator::PreDec: - case Operator::PostDec: - return "--"; - case Operator::Mul: - case Operator::MulAssign: - return "*"; - case Operator::Div: - case Operator::DivAssign: - return "/"; - case Operator::Mod: - case Operator::ModAssign: - return "%"; - case Operator::Lsh: - case Operator::LshAssign: - return "<<"; - case Operator::Rsh: - case Operator::RshAssign: - return ">>"; - case Operator::Eql: - return "=="; - case Operator::Neq: - return "!="; - case Operator::Greater: - return ">"; - case Operator::Less: - return "<"; - case Operator::Geq: - return ">="; - case Operator::Leq: - return "<="; - case Operator::BitAnd: - case Operator::AndAssign: - return "&"; - case Operator::BitXor: - case Operator::XorAssign: - return "^"; - case Operator::BitOr: - case Operator::OrAssign: - return "|"; - case Operator::And: - return "&&"; - case Operator::Or: - return "||"; - case Operator::Sequence: - return ","; - case Operator::Select: - return "?:"; - case Operator::Assign: - return "="; - default: - return ""; - } + String GetOperatorFunctionName(Operator op) + { + switch (op) + { + case Operator::Add: + case Operator::AddAssign: + return "+"; + case Operator::Sub: + case Operator::SubAssign: + return "-"; + case Operator::Neg: + return "-"; + case Operator::Not: + return "!"; + case Operator::BitNot: + return "~"; + case Operator::PreInc: + case Operator::PostInc: + return "++"; + case Operator::PreDec: + case Operator::PostDec: + return "--"; + case Operator::Mul: + case Operator::MulAssign: + return "*"; + case Operator::Div: + case Operator::DivAssign: + return "/"; + case Operator::Mod: + case Operator::ModAssign: + return "%"; + case Operator::Lsh: + case Operator::LshAssign: + return "<<"; + case Operator::Rsh: + case Operator::RshAssign: + return ">>"; + case Operator::Eql: + return "=="; + case Operator::Neq: + return "!="; + case Operator::Greater: + return ">"; + case Operator::Less: + return "<"; + case Operator::Geq: + return ">="; + case Operator::Leq: + return "<="; + case Operator::BitAnd: + case Operator::AndAssign: + return "&"; + case Operator::BitXor: + case Operator::XorAssign: + return "^"; + case Operator::BitOr: + case Operator::OrAssign: + return "|"; + case Operator::And: + return "&&"; + case Operator::Or: + return "||"; + case Operator::Sequence: + return ","; + case Operator::Select: + return "?:"; + case Operator::Assign: + return "="; + default: + return ""; } + } #endif - String OperatorToString(Operator op) - { - switch (op) - { - case Slang::Compiler::Operator::Neg: - return "-"; - case Slang::Compiler::Operator::Not: - return "!"; - case Slang::Compiler::Operator::PreInc: - return "++"; - case Slang::Compiler::Operator::PreDec: - return "--"; - case Slang::Compiler::Operator::PostInc: - return "++"; - case Slang::Compiler::Operator::PostDec: - return "--"; - case Slang::Compiler::Operator::Mul: - case Slang::Compiler::Operator::MulAssign: - return "*"; - case Slang::Compiler::Operator::Div: - case Slang::Compiler::Operator::DivAssign: - return "/"; - case Slang::Compiler::Operator::Mod: - case Slang::Compiler::Operator::ModAssign: - return "%"; - case Slang::Compiler::Operator::Add: - case Slang::Compiler::Operator::AddAssign: - return "+"; - case Slang::Compiler::Operator::Sub: - case Slang::Compiler::Operator::SubAssign: - return "-"; - case Slang::Compiler::Operator::Lsh: - case Slang::Compiler::Operator::LshAssign: - return "<<"; - case Slang::Compiler::Operator::Rsh: - case Slang::Compiler::Operator::RshAssign: - return ">>"; - case Slang::Compiler::Operator::Eql: - return "=="; - case Slang::Compiler::Operator::Neq: - return "!="; - case Slang::Compiler::Operator::Greater: - return ">"; - case Slang::Compiler::Operator::Less: - return "<"; - case Slang::Compiler::Operator::Geq: - return ">="; - case Slang::Compiler::Operator::Leq: - return "<="; - case Slang::Compiler::Operator::BitAnd: - case Slang::Compiler::Operator::AndAssign: - return "&"; - case Slang::Compiler::Operator::BitXor: - case Slang::Compiler::Operator::XorAssign: - return "^"; - case Slang::Compiler::Operator::BitOr: - case Slang::Compiler::Operator::OrAssign: - return "|"; - case Slang::Compiler::Operator::And: - return "&&"; - case Slang::Compiler::Operator::Or: - return "||"; - case Slang::Compiler::Operator::Assign: - return "="; - default: - return "ERROR"; - } + String OperatorToString(Operator op) + { + switch (op) + { + case Slang::Operator::Neg: + return "-"; + case Slang::Operator::Not: + return "!"; + case Slang::Operator::PreInc: + return "++"; + case Slang::Operator::PreDec: + return "--"; + case Slang::Operator::PostInc: + return "++"; + case Slang::Operator::PostDec: + return "--"; + case Slang::Operator::Mul: + case Slang::Operator::MulAssign: + return "*"; + case Slang::Operator::Div: + case Slang::Operator::DivAssign: + return "/"; + case Slang::Operator::Mod: + case Slang::Operator::ModAssign: + return "%"; + case Slang::Operator::Add: + case Slang::Operator::AddAssign: + return "+"; + case Slang::Operator::Sub: + case Slang::Operator::SubAssign: + return "-"; + case Slang::Operator::Lsh: + case Slang::Operator::LshAssign: + return "<<"; + case Slang::Operator::Rsh: + case Slang::Operator::RshAssign: + return ">>"; + case Slang::Operator::Eql: + return "=="; + case Slang::Operator::Neq: + return "!="; + case Slang::Operator::Greater: + return ">"; + case Slang::Operator::Less: + return "<"; + case Slang::Operator::Geq: + return ">="; + case Slang::Operator::Leq: + return "<="; + case Slang::Operator::BitAnd: + case Slang::Operator::AndAssign: + return "&"; + case Slang::Operator::BitXor: + case Slang::Operator::XorAssign: + return "^"; + case Slang::Operator::BitOr: + case Slang::Operator::OrAssign: + return "|"; + case Slang::Operator::And: + return "&&"; + case Slang::Operator::Or: + return "||"; + case Slang::Operator::Assign: + return "="; + default: + return "ERROR"; } + } - // TypeExp + // TypeExp - TypeExp TypeExp::Accept(SyntaxVisitor* visitor) - { - return visitor->VisitTypeExp(*this); - } + TypeExp TypeExp::Accept(SyntaxVisitor* visitor) + { + return visitor->VisitTypeExp(*this); + } - // BuiltinTypeModifier + // BuiltinTypeModifier - // MagicTypeModifier + // MagicTypeModifier - // GenericDecl + // GenericDecl - RefPtr<SyntaxNode> GenericDecl::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitGenericDecl(this); - } + RefPtr<SyntaxNode> GenericDecl::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitGenericDecl(this); + } - // GenericTypeParamDecl + // GenericTypeParamDecl - RefPtr<SyntaxNode> GenericTypeParamDecl::Accept(SyntaxVisitor * /*visitor*/) { - //throw "unimplemented"; - return this; - } + RefPtr<SyntaxNode> GenericTypeParamDecl::Accept(SyntaxVisitor * /*visitor*/) { + //throw "unimplemented"; + return this; + } - // GenericTypeConstraintDecl + // GenericTypeConstraintDecl - RefPtr<SyntaxNode> GenericTypeConstraintDecl::Accept(SyntaxVisitor * visitor) - { - return this; - } + RefPtr<SyntaxNode> GenericTypeConstraintDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } - // GenericValueParamDecl + // GenericValueParamDecl - RefPtr<SyntaxNode> GenericValueParamDecl::Accept(SyntaxVisitor * /*visitor*/) { - //throw "unimplemented"; - return this; - } + RefPtr<SyntaxNode> GenericValueParamDecl::Accept(SyntaxVisitor * /*visitor*/) { + //throw "unimplemented"; + return this; + } - // GenericParamIntVal + // GenericParamIntVal - bool GenericParamIntVal::EqualsVal(Val* val) + bool GenericParamIntVal::EqualsVal(Val* val) + { + if (auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val)) { - if (auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val)) - { - return declRef.Equals(genericParamVal->declRef); - } - return false; + return declRef.Equals(genericParamVal->declRef); } + return false; + } - String GenericParamIntVal::ToString() - { - return declRef.GetName(); - } + String GenericParamIntVal::ToString() + { + return declRef.GetName(); + } - int GenericParamIntVal::GetHashCode() - { - return declRef.GetHashCode() ^ 0xFFFF; - } + int GenericParamIntVal::GetHashCode() + { + return declRef.GetHashCode() ^ 0xFFFF; + } - RefPtr<Val> GenericParamIntVal::SubstituteImpl(Substitutions* subst, int* ioDiff) + RefPtr<Val> GenericParamIntVal::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) { - // search for a substitution that might apply to us - for (auto s = subst; s; s = s->outer.Ptr()) - { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != declRef.GetDecl()->ParentDecl) - continue; + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = s->genericDecl; + if (genericDecl != declRef.GetDecl()->ParentDecl) + continue; - int index = 0; - for (auto m : genericDecl->Members) + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == declRef.GetDecl()) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return s->args[index]; + } + else if(auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if(auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else { - if (m.Ptr() == declRef.GetDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return s->args[index]; - } - else if(auto typeParam = m.As<GenericTypeParamDecl>()) - { - index++; - } - else if(auto valParam = m.As<GenericValueParamDecl>()) - { - index++; - } - else - { - } } } - - // Nothing found: don't substittue. - return this; } - // ExtensionDecl - - RefPtr<SyntaxNode> ExtensionDecl::Accept(SyntaxVisitor * visitor) - { - visitor->VisitExtensionDecl(this); - return this; - } + // Nothing found: don't substittue. + return this; + } - // ConstructorDecl + // ExtensionDecl - RefPtr<SyntaxNode> ConstructorDecl::Accept(SyntaxVisitor * visitor) - { - visitor->VisitConstructorDecl(this); - return this; - } + RefPtr<SyntaxNode> ExtensionDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitExtensionDecl(this); + return this; + } - // SubscriptDecl + // ConstructorDecl - RefPtr<SyntaxNode> SubscriptDecl::Accept(SyntaxVisitor * visitor) - { - visitor->visitSubscriptDecl(this); - return this; - } + RefPtr<SyntaxNode> ConstructorDecl::Accept(SyntaxVisitor * visitor) + { + visitor->VisitConstructorDecl(this); + return this; + } - // AccessorDecl + // SubscriptDecl - RefPtr<SyntaxNode> AccessorDecl::Accept(SyntaxVisitor * visitor) - { - visitor->visitAccessorDecl(this); - return this; - } + RefPtr<SyntaxNode> SubscriptDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitSubscriptDecl(this); + return this; + } - // Substitutions + // AccessorDecl - RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - if (!this) return nullptr; + RefPtr<SyntaxNode> AccessorDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitAccessorDecl(this); + return this; + } - int diff = 0; - auto outerSubst = outer->SubstituteImpl(subst, &diff); + // Substitutions - List<RefPtr<Val>> substArgs; - for (auto a : args) - { - substArgs.Add(a->SubstituteImpl(subst, &diff)); - } + RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!this) return nullptr; - if (!diff) return this; + int diff = 0; + auto outerSubst = outer->SubstituteImpl(subst, &diff); - (*ioDiff)++; - auto substSubst = new Substitutions(); - substSubst->genericDecl = genericDecl; - substSubst->args = substArgs; - return substSubst; + List<RefPtr<Val>> substArgs; + for (auto a : args) + { + substArgs.Add(a->SubstituteImpl(subst, &diff)); } - bool Substitutions::Equals(Substitutions* subst) - { - // both must be NULL, or non-NULL - if (!this || !subst) - return !this && !subst; + if (!diff) return this; - if (genericDecl != subst->genericDecl) - return false; + (*ioDiff)++; + auto substSubst = new Substitutions(); + substSubst->genericDecl = genericDecl; + substSubst->args = substArgs; + return substSubst; + } - int argCount = args.Count(); - assert(args.Count() == subst->args.Count()); - for (int aa = 0; aa < argCount; ++aa) - { - if (!args[aa]->EqualsVal(subst->args[aa].Ptr())) - return false; - } + bool Substitutions::Equals(Substitutions* subst) + { + // both must be NULL, or non-NULL + if (!this || !subst) + return !this && !subst; - if (!outer->Equals(subst->outer.Ptr())) - return false; + if (genericDecl != subst->genericDecl) + return false; - return true; + int argCount = args.Count(); + assert(args.Count() == subst->args.Count()); + for (int aa = 0; aa < argCount; ++aa) + { + if (!args[aa]->EqualsVal(subst->args[aa].Ptr())) + return false; } + if (!outer->Equals(subst->outer.Ptr())) + return false; - // DeclRef + return true; + } - RefPtr<ExpressionType> DeclRef::Substitute(RefPtr<ExpressionType> type) const - { - // No substitutions? Easy. - if (!substitutions) - return type; - // Otherwise we need to recurse on the type structure - // and apply substitutions where it makes sense + // DeclRef - return type->Substitute(substitutions.Ptr()).As<ExpressionType>(); - } + RefPtr<ExpressionType> DeclRef::Substitute(RefPtr<ExpressionType> type) const + { + // No substitutions? Easy. + if (!substitutions) + return type; - DeclRef DeclRef::Substitute(DeclRef declRef) const - { - if(!substitutions) - return declRef; + // Otherwise we need to recurse on the type structure + // and apply substitutions where it makes sense - int diff = 0; - return declRef.SubstituteImpl(substitutions.Ptr(), &diff); - } + return type->Substitute(substitutions.Ptr()).As<ExpressionType>(); + } - RefPtr<ExpressionSyntaxNode> DeclRef::Substitute(RefPtr<ExpressionSyntaxNode> expr) const - { - // No substitutions? Easy. - if (!substitutions) - return expr; + DeclRef DeclRef::Substitute(DeclRef declRef) const + { + if(!substitutions) + return declRef; - assert(!"unimplemented"); + int diff = 0; + return declRef.SubstituteImpl(substitutions.Ptr(), &diff); + } + RefPtr<ExpressionSyntaxNode> DeclRef::Substitute(RefPtr<ExpressionSyntaxNode> expr) const + { + // No substitutions? Easy. + if (!substitutions) return expr; - } + assert(!"unimplemented"); - DeclRef DeclRef::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - if (!substitutions) return *this; + return expr; + } - int diff = 0; - RefPtr<Substitutions> substSubst = substitutions->SubstituteImpl(subst, &diff); - if (!diff) - return *this; + DeclRef DeclRef::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!substitutions) return *this; - *ioDiff += diff; + int diff = 0; + RefPtr<Substitutions> substSubst = substitutions->SubstituteImpl(subst, &diff); - DeclRef substDeclRef; - substDeclRef.decl = decl; - substDeclRef.substitutions = substSubst; - return substDeclRef; - } + if (!diff) + return *this; + *ioDiff += diff; - // Check if this is an equivalent declaration reference to another - bool DeclRef::Equals(DeclRef const& declRef) const - { - if (decl != declRef.decl) - return false; + DeclRef substDeclRef; + substDeclRef.decl = decl; + substDeclRef.substitutions = substSubst; + return substDeclRef; + } - if (!substitutions->Equals(declRef.substitutions.Ptr())) - return false; - return true; - } + // Check if this is an equivalent declaration reference to another + bool DeclRef::Equals(DeclRef const& declRef) const + { + if (decl != declRef.decl) + return false; - // Convenience accessors for common properties of declarations - String const& DeclRef::GetName() const - { - return decl->Name.Content; - } + if (!substitutions->Equals(declRef.substitutions.Ptr())) + return false; - DeclRef DeclRef::GetParent() const - { - auto parentDecl = decl->ParentDecl; - if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) - { - // We need to strip away one layer of specialization - assert(substitutions); - return DeclRef(parentGeneric, substitutions->outer); - } - else - { - // If the parent isn't a generic, then it must - // use the same specializations as this declaration - return DeclRef(parentDecl, substitutions); - } + return true; + } - } + // Convenience accessors for common properties of declarations + String const& DeclRef::GetName() const + { + return decl->Name.Content; + } - int DeclRef::GetHashCode() const + DeclRef DeclRef::GetParent() const + { + auto parentDecl = decl->ParentDecl; + if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) { - auto rs = PointerHash<1>::GetHashCode(decl); - if (substitutions) - { - rs *= 16777619; - rs ^= substitutions->GetHashCode(); - } - return rs; + // We need to strip away one layer of specialization + assert(substitutions); + return DeclRef(parentGeneric, substitutions->outer); } - - // Val - - RefPtr<Val> Val::Substitute(Substitutions* subst) + else { - if (!this) return nullptr; - if (!subst) return this; - int diff = 0; - return SubstituteImpl(subst, &diff); + // If the parent isn't a generic, then it must + // use the same specializations as this declaration + return DeclRef(parentDecl, substitutions); } - RefPtr<Val> Val::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/) + } + + int DeclRef::GetHashCode() const + { + auto rs = PointerHash<1>::GetHashCode(decl); + if (substitutions) { - // Default behavior is to not substitute at all - return this; + rs *= 16777619; + rs ^= substitutions->GetHashCode(); } + return rs; + } - // IntVal + // Val - int GetIntVal(RefPtr<IntVal> val) - { - if (auto constantVal = val.As<ConstantIntVal>()) - { - return constantVal->value; - } - assert(!"unexpected"); - return 0; - } + RefPtr<Val> Val::Substitute(Substitutions* subst) + { + if (!this) return nullptr; + if (!subst) return this; + int diff = 0; + return SubstituteImpl(subst, &diff); + } - // ConstantIntVal + RefPtr<Val> Val::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/) + { + // Default behavior is to not substitute at all + return this; + } - bool ConstantIntVal::EqualsVal(Val* val) - { - if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) - return value == intVal->value; - return false; - } + // IntVal - String ConstantIntVal::ToString() + int GetIntVal(RefPtr<IntVal> val) + { + if (auto constantVal = val.As<ConstantIntVal>()) { - return String(value); + return constantVal->value; } + assert(!"unexpected"); + return 0; + } - int ConstantIntVal::GetHashCode() - { - return value; - } + // ConstantIntVal - // SwitchStmt + bool ConstantIntVal::EqualsVal(Val* val) + { + if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + return value == intVal->value; + return false; + } - RefPtr<SyntaxNode> SwitchStmt::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitSwitchStmt(this); - } + String ConstantIntVal::ToString() + { + return String(value); + } - RefPtr<SyntaxNode> CaseStmt::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitCaseStmt(this); - } + int ConstantIntVal::GetHashCode() + { + return value; + } - RefPtr<SyntaxNode> DefaultStmt::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitDefaultStmt(this); - } + // SwitchStmt - // InterfaceDecl + RefPtr<SyntaxNode> SwitchStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSwitchStmt(this); + } - RefPtr<SyntaxNode> InterfaceDecl::Accept(SyntaxVisitor * visitor) - { - visitor->visitInterfaceDecl(this); - return this; - } + RefPtr<SyntaxNode> CaseStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitCaseStmt(this); + } - // InheritanceDecl + RefPtr<SyntaxNode> DefaultStmt::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitDefaultStmt(this); + } - RefPtr<SyntaxNode> InheritanceDecl::Accept(SyntaxVisitor * visitor) - { - visitor->visitInheritanceDecl(this); - return this; - } + // InterfaceDecl - // SharedTypeExpr + RefPtr<SyntaxNode> InterfaceDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitInterfaceDecl(this); + return this; + } - RefPtr<SyntaxNode> SharedTypeExpr::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitSharedTypeExpr(this); - } + // InheritanceDecl + + RefPtr<SyntaxNode> InheritanceDecl::Accept(SyntaxVisitor * visitor) + { + visitor->visitInheritanceDecl(this); + return this; + } + + // SharedTypeExpr + + RefPtr<SyntaxNode> SharedTypeExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitSharedTypeExpr(this); + } - // OperatorExpressionSyntaxNode + // OperatorExpressionSyntaxNode #if 0 - void OperatorExpressionSyntaxNode::SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op) - { - this->Operator = op; - auto opExpr = new VarExpressionSyntaxNode(); - opExpr->Variable = GetOperatorFunctionName(Operator); - opExpr->scope = scope; - opExpr->Position = this->Position; - this->FunctionExpr = opExpr; - } + void OperatorExpressionSyntaxNode::SetOperator(RefPtr<Scope> scope, Slang::Operator op) + { + this->Operator = op; + auto opExpr = new VarExpressionSyntaxNode(); + opExpr->Variable = GetOperatorFunctionName(Operator); + opExpr->scope = scope; + opExpr->Position = this->Position; + this->FunctionExpr = opExpr; + } #endif - RefPtr<SyntaxNode> OperatorExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) - { - return visitor->VisitOperatorExpression(this); - } + RefPtr<SyntaxNode> OperatorExpressionSyntaxNode::Accept(SyntaxVisitor * visitor) + { + return visitor->VisitOperatorExpression(this); + } - // DeclGroup + // DeclGroup - RefPtr<SyntaxNode> DeclGroup::Accept(SyntaxVisitor * visitor) - { - visitor->VisitDeclGroup(this); - return this; - } + RefPtr<SyntaxNode> DeclGroup::Accept(SyntaxVisitor * visitor) + { + visitor->VisitDeclGroup(this); + return this; + } - // + // - void RegisterBuiltinDecl( - RefPtr<Decl> decl, - RefPtr<BuiltinTypeModifier> modifier) - { - auto type = DeclRefType::Create(DeclRef(decl.Ptr(), nullptr)); - ExpressionType::sBuiltinTypes[(int)modifier->tag] = type; - } + void RegisterBuiltinDecl( + RefPtr<Decl> decl, + RefPtr<BuiltinTypeModifier> modifier) + { + auto type = DeclRefType::Create(DeclRef(decl.Ptr(), nullptr)); + ExpressionType::sBuiltinTypes[(int)modifier->tag] = type; + } - void RegisterMagicDecl( - RefPtr<Decl> decl, - RefPtr<MagicTypeModifier> modifier) - { - ExpressionType::sMagicDecls[modifier->name] = decl.Ptr(); - } + void RegisterMagicDecl( + RefPtr<Decl> decl, + RefPtr<MagicTypeModifier> modifier) + { + ExpressionType::sMagicDecls[modifier->name] = decl.Ptr(); + } - RefPtr<Decl> findMagicDecl( - String const& name) - { - return ExpressionType::sMagicDecls[name].GetValue(); - } + RefPtr<Decl> findMagicDecl( + String const& name) + { + return ExpressionType::sMagicDecls[name].GetValue(); + } - ExpressionType* ExpressionType::GetBool() - { - return sBuiltinTypes[(int)BaseType::Bool].GetValue().Ptr(); - } + ExpressionType* ExpressionType::GetBool() + { + return sBuiltinTypes[(int)BaseType::Bool].GetValue().Ptr(); + } - ExpressionType* ExpressionType::GetFloat() - { - return sBuiltinTypes[(int)BaseType::Float].GetValue().Ptr(); - } + ExpressionType* ExpressionType::GetFloat() + { + return sBuiltinTypes[(int)BaseType::Float].GetValue().Ptr(); + } - ExpressionType* ExpressionType::GetInt() - { - return sBuiltinTypes[(int)BaseType::Int].GetValue().Ptr(); - } + ExpressionType* ExpressionType::GetInt() + { + return sBuiltinTypes[(int)BaseType::Int].GetValue().Ptr(); + } - ExpressionType* ExpressionType::GetUInt() - { - return sBuiltinTypes[(int)BaseType::UInt].GetValue().Ptr(); - } + ExpressionType* ExpressionType::GetUInt() + { + return sBuiltinTypes[(int)BaseType::UInt].GetValue().Ptr(); + } - ExpressionType* ExpressionType::GetVoid() - { - return sBuiltinTypes[(int)BaseType::Void].GetValue().Ptr(); - } + ExpressionType* ExpressionType::GetVoid() + { + return sBuiltinTypes[(int)BaseType::Void].GetValue().Ptr(); + } - ExpressionType* ExpressionType::getInitializerListType() - { - return initializerListType.Ptr(); - } + ExpressionType* ExpressionType::getInitializerListType() + { + return initializerListType.Ptr(); + } - ExpressionType* ExpressionType::GetError() - { - return ExpressionType::Error.Ptr(); - } + ExpressionType* ExpressionType::GetError() + { + return ExpressionType::Error.Ptr(); + } - // + // - RefPtr<SyntaxNode> UnparsedStmt::Accept(SyntaxVisitor * visitor) - { - return this; - } + RefPtr<SyntaxNode> UnparsedStmt::Accept(SyntaxVisitor * visitor) + { + return this; + } - // + // - RefPtr<SyntaxNode> InitializerListExpr::Accept(SyntaxVisitor * visitor) - { - return visitor->visitInitializerListExpr(this); - } + RefPtr<SyntaxNode> InitializerListExpr::Accept(SyntaxVisitor * visitor) + { + return visitor->visitInitializerListExpr(this); + } - // + // - RefPtr<SyntaxNode> ModifierDecl::Accept(SyntaxVisitor * visitor) - { - return this; - } + RefPtr<SyntaxNode> ModifierDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } - // + // - RefPtr<SyntaxNode> EmptyDecl::Accept(SyntaxVisitor * visitor) - { - return this; - } + RefPtr<SyntaxNode> EmptyDecl::Accept(SyntaxVisitor * visitor) + { + return this; + } - // + // - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name) - { - if(0) {} - #define CASE(NAME) \ - else if(name == #NAME) return new NAME() + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name) + { + if(0) {} + #define CASE(NAME) \ + else if(name == #NAME) return new NAME() - CASE(GLSLBufferModifier); - CASE(GLSLWriteOnlyModifier); - CASE(GLSLReadOnlyModifier); - CASE(GLSLPatchModifier); - CASE(SimpleModifier); + CASE(GLSLBufferModifier); + CASE(GLSLWriteOnlyModifier); + CASE(GLSLReadOnlyModifier); + CASE(GLSLPatchModifier); + CASE(SimpleModifier); - #undef CASE - else - { - assert(!"unexpected"); - return nullptr; - } + #undef CASE + else + { + assert(!"unexpected"); + return nullptr; } + } - IntrinsicOp findIntrinsicOp(char const* name) - { - // TODO: need to make this faster by using a dictionary... + IntrinsicOp findIntrinsicOp(char const* name) + { + // TODO: need to make this faster by using a dictionary... - if (0) {} + if (0) {} #define INTRINSIC(NAME) else if(strcmp(name, #NAME) == 0) return IntrinsicOp::NAME; #include "intrinsic-defs.h" - return IntrinsicOp::Unknown; - } - + return IntrinsicOp::Unknown; } + }
\ No newline at end of file diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 642f4e99f..8d56cd28e 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -11,2797 +11,2794 @@ namespace Slang { - namespace Compiler - { - using namespace CoreLib::Basic; - class SyntaxVisitor; - class FunctionSyntaxNode; + using namespace CoreLib::Basic; + class SyntaxVisitor; + class FunctionSyntaxNode; - class SyntaxNodeBase : public RefObject - { - public: - CodePosition Position; - }; + class SyntaxNodeBase : public RefObject + { + public: + CodePosition Position; + }; - // - // Other modifiers may have more elaborate data, and so - // are represented as heap-allocated objects, in a linked - // list. - // - class Modifier : public SyntaxNodeBase - { - public: - // Next modifier in linked list of modifiers on same piece of syntax - RefPtr<Modifier> next; + // + // Other modifiers may have more elaborate data, and so + // are represented as heap-allocated objects, in a linked + // list. + // + class Modifier : public SyntaxNodeBase + { + public: + // Next modifier in linked list of modifiers on same piece of syntax + RefPtr<Modifier> next; - // The token that was used to name this modifier. - Token nameToken; - }; + // The token that was used to name this modifier. + Token nameToken; + }; #define SIMPLE_MODIFIER(NAME) \ - class NAME##Modifier : public Modifier {} - - SIMPLE_MODIFIER(Uniform); - SIMPLE_MODIFIER(In); - SIMPLE_MODIFIER(Out); - SIMPLE_MODIFIER(Const); - SIMPLE_MODIFIER(Instance); - SIMPLE_MODIFIER(Builtin); - SIMPLE_MODIFIER(Inline); - SIMPLE_MODIFIER(Public); - SIMPLE_MODIFIER(Require); - SIMPLE_MODIFIER(Param); - SIMPLE_MODIFIER(Extern); - SIMPLE_MODIFIER(Input); - SIMPLE_MODIFIER(Transparent); - SIMPLE_MODIFIER(FromStdLib); - SIMPLE_MODIFIER(Prefix); - SIMPLE_MODIFIER(Postfix); + class NAME##Modifier : public Modifier {} + + SIMPLE_MODIFIER(Uniform); + SIMPLE_MODIFIER(In); + SIMPLE_MODIFIER(Out); + SIMPLE_MODIFIER(Const); + SIMPLE_MODIFIER(Instance); + SIMPLE_MODIFIER(Builtin); + SIMPLE_MODIFIER(Inline); + SIMPLE_MODIFIER(Public); + SIMPLE_MODIFIER(Require); + SIMPLE_MODIFIER(Param); + SIMPLE_MODIFIER(Extern); + SIMPLE_MODIFIER(Input); + SIMPLE_MODIFIER(Transparent); + SIMPLE_MODIFIER(FromStdLib); + SIMPLE_MODIFIER(Prefix); + SIMPLE_MODIFIER(Postfix); #undef SIMPLE_MODIFIER - enum class IntrinsicOp - { - Unknown = 0, + enum class IntrinsicOp + { + Unknown = 0, #define INTRINSIC(NAME) NAME, #include "intrinsic-defs.h" - }; + }; - IntrinsicOp findIntrinsicOp(char const* name); + IntrinsicOp findIntrinsicOp(char const* name); - // Base class for modifiers that mark something as "intrinsic" - // and thus lacking a direct implementation in the language. - class IntrinsicModifierBase : public Modifier - { - }; - - // A modifier that marks something as one of a small set of - // truly intrinsic operations that the compiler knows about - // directly. - class IntrinsicOpModifier : public IntrinsicModifierBase - { - public: - // token that names the intrinsic op - Token opToken; + // Base class for modifiers that mark something as "intrinsic" + // and thus lacking a direct implementation in the language. + class IntrinsicModifierBase : public Modifier + { + }; - // The opcode for the intrinsic operation - IntrinsicOp op = IntrinsicOp::Unknown; - }; + // A modifier that marks something as one of a small set of + // truly intrinsic operations that the compiler knows about + // directly. + class IntrinsicOpModifier : public IntrinsicModifierBase + { + public: + // token that names the intrinsic op + Token opToken; - // A modifier that marks something as an intrinsic function, - // for some subset of targets. - class TargetIntrinsicModifier : public IntrinsicModifierBase - { - public: - // Token that names the target that the operation - // is an intrisic for. - Token targetToken; + // The opcode for the intrinsic operation + IntrinsicOp op = IntrinsicOp::Unknown; + }; - // A custom definition for the operation - Token definitionToken; - }; + // A modifier that marks something as an intrinsic function, + // for some subset of targets. + class TargetIntrinsicModifier : public IntrinsicModifierBase + { + public: + // Token that names the target that the operation + // is an intrisic for. + Token targetToken; + + // A custom definition for the operation + Token definitionToken; + }; + + + + class InOutModifier : public OutModifier {}; + + // This is a special sentinel modifier that gets added + // to the list when we have multiple variable declarations + // all sharing the same modifiers: + // + // static uniform int a : FOO, *b : register(x0); + // + // In this case both `a` and `b` share the syntax + // for part of their modifier list, but then have + // their own modifiers as well: + // + // a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier + // / + // b: RegisterModifier("x0") / + // + class SharedModifiers : public Modifier {}; + + // A GLSL `layout` modifier + // + // We use a distinct modifier for each key that + // appears within the `layout(...)` construct, + // and each key might have an optional value token. + // + // TODO: We probably want a notion of "modifier groups" + // so that we can recover good source location info + // for modifiers that were part of the same vs. + // different constructs. + class GLSLLayoutModifier : public Modifier + { + public: + // THe token used to introduce the modifier is stored + // as the `nameToken` field. + + // TODO: may want to accept a full expression here + Token valToken; + }; + + // We divide GLSL `layout` modifiers into those we have parsed + // (in the sense of having some notion of their semantics), and + // those we have not. + class GLSLParsedLayoutModifier : public GLSLLayoutModifier {}; + class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier {}; + + // Specific cases for known GLSL `layout` modifiers that we need to work with + class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLBindingLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLSetLayoutModifier : public GLSLParsedLayoutModifier {}; + class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier {}; + + // A catch-all for single-keyword modifiers + class SimpleModifier : public Modifier {}; + + // Some GLSL-specific modifiers + class GLSLBufferModifier : public SimpleModifier {}; + class GLSLWriteOnlyModifier : public SimpleModifier {}; + class GLSLReadOnlyModifier : public SimpleModifier {}; + class GLSLPatchModifier : public SimpleModifier {}; + + // Indicates that this is a variable declaration that corresponds to + // a parameter block declaration in the source program. + class ImplicitParameterBlockVariableModifier : public Modifier {}; + + // Indicates that this is a type that corresponds to the element + // type of a parameter block declaration in the source program. + class ImplicitParameterBlockElementTypeModifier : public Modifier {}; + + // An HLSL semantic + class HLSLSemantic : public Modifier + { + public: + Token name; + }; + // An HLSL semantic that affects layout + class HLSLLayoutSemantic : public HLSLSemantic + { + public: + Token registerName; + Token componentMask; + }; - class InOutModifier : public OutModifier {}; + // An HLSL `register` semantic + class HLSLRegisterSemantic : public HLSLLayoutSemantic + { + }; - // This is a special sentinel modifier that gets added - // to the list when we have multiple variable declarations - // all sharing the same modifiers: - // - // static uniform int a : FOO, *b : register(x0); - // - // In this case both `a` and `b` share the syntax - // for part of their modifier list, but then have - // their own modifiers as well: - // - // a: SemanticModifier("FOO") --> SharedModifiers --> StaticModifier --> UniformModifier - // / - // b: RegisterModifier("x0") / - // - class SharedModifiers : public Modifier {}; + // TODO(tfoley): `packoffset` + class HLSLPackOffsetSemantic : public HLSLLayoutSemantic + { + }; - // A GLSL `layout` modifier - // - // We use a distinct modifier for each key that - // appears within the `layout(...)` construct, - // and each key might have an optional value token. - // - // TODO: We probably want a notion of "modifier groups" - // so that we can recover good source location info - // for modifiers that were part of the same vs. - // different constructs. - class GLSLLayoutModifier : public Modifier - { - public: - // THe token used to introduce the modifier is stored - // as the `nameToken` field. + // An HLSL semantic that just associated a declaration with a semantic name + class HLSLSimpleSemantic : public HLSLSemantic + { + }; - // TODO: may want to accept a full expression here - Token valToken; - }; + // GLSL - // We divide GLSL `layout` modifiers into those we have parsed - // (in the sense of having some notion of their semantics), and - // those we have not. - class GLSLParsedLayoutModifier : public GLSLLayoutModifier {}; - class GLSLUnparsedLayoutModifier : public GLSLLayoutModifier {}; + // Directives that came in via the preprocessor, but + // that we need to keep around for later steps + class GLSLPreprocessorDirective : public Modifier + { + }; - // Specific cases for known GLSL `layout` modifiers that we need to work with - class GLSLConstantIDLayoutModifier : public GLSLParsedLayoutModifier {}; - class GLSLBindingLayoutModifier : public GLSLParsedLayoutModifier {}; - class GLSLSetLayoutModifier : public GLSLParsedLayoutModifier {}; - class GLSLLocationLayoutModifier : public GLSLParsedLayoutModifier {}; + // A GLSL `#version` directive + class GLSLVersionDirective : public GLSLPreprocessorDirective + { + public: + // Token giving the version number to use + Token versionNumberToken; - // A catch-all for single-keyword modifiers - class SimpleModifier : public Modifier {}; + // Optional token giving the sub-profile to be used + Token glslProfileToken; + }; - // Some GLSL-specific modifiers - class GLSLBufferModifier : public SimpleModifier {}; - class GLSLWriteOnlyModifier : public SimpleModifier {}; - class GLSLReadOnlyModifier : public SimpleModifier {}; - class GLSLPatchModifier : public SimpleModifier {}; + // A GLSL `#extension` directive + class GLSLExtensionDirective : public GLSLPreprocessorDirective + { + public: + // Token giving the version number to use + Token extensionNameToken; - // Indicates that this is a variable declaration that corresponds to - // a parameter block declaration in the source program. - class ImplicitParameterBlockVariableModifier : public Modifier {}; + // Optional token giving the sub-profile to be used + Token dispositionToken; + }; - // Indicates that this is a type that corresponds to the element - // type of a parameter block declaration in the source program. - class ImplicitParameterBlockElementTypeModifier : public Modifier {}; + class ParameterBlockReflectionName : public Modifier + { + public: + Token nameToken; + }; - // An HLSL semantic - class HLSLSemantic : public Modifier + // Helper class for iterating over a list of heap-allocated modifiers + struct ModifierList + { + struct Iterator { - public: - Token name; - }; + Modifier* current; + Modifier* operator*() + { + return current; + } - // An HLSL semantic that affects layout - class HLSLLayoutSemantic : public HLSLSemantic - { - public: - Token registerName; - Token componentMask; - }; + void operator++() + { + current = current->next.Ptr(); + } - // An HLSL `register` semantic - class HLSLRegisterSemantic : public HLSLLayoutSemantic - { - }; + bool operator!=(Iterator other) + { + return current != other.current; + }; - // TODO(tfoley): `packoffset` - class HLSLPackOffsetSemantic : public HLSLLayoutSemantic - { - }; + Iterator() + : current(nullptr) + {} - // An HLSL semantic that just associated a declaration with a semantic name - class HLSLSimpleSemantic : public HLSLSemantic - { + Iterator(Modifier* modifier) + : current(modifier) + {} }; - // GLSL + ModifierList() + : modifiers(nullptr) + {} - // Directives that came in via the preprocessor, but - // that we need to keep around for later steps - class GLSLPreprocessorDirective : public Modifier - { - }; + ModifierList(Modifier* modifiers) + : modifiers(modifiers) + {} - // A GLSL `#version` directive - class GLSLVersionDirective : public GLSLPreprocessorDirective - { - public: - // Token giving the version number to use - Token versionNumberToken; + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - // Optional token giving the sub-profile to be used - Token glslProfileToken; - }; + Modifier* modifiers; + }; - // A GLSL `#extension` directive - class GLSLExtensionDirective : public GLSLPreprocessorDirective + // Helper class for iterating over heap-allocated modifiers + // of a specific type. + template<typename T> + struct FilteredModifierList + { + struct Iterator { - public: - // Token giving the version number to use - Token extensionNameToken; + Modifier* current; - // Optional token giving the sub-profile to be used - Token dispositionToken; - }; + T* operator*() + { + return (T*)current; + } - class ParameterBlockReflectionName : public Modifier - { - public: - Token nameToken; - }; + void operator++() + { + current = Adjust(current->next.Ptr()); + } - // Helper class for iterating over a list of heap-allocated modifiers - struct ModifierList - { - struct Iterator + bool operator!=(Iterator other) { - Modifier* current; - - Modifier* operator*() - { - return current; - } - - void operator++() - { - current = current->next.Ptr(); - } - - bool operator!=(Iterator other) - { - return current != other.current; - }; - - Iterator() - : current(nullptr) - {} - - Iterator(Modifier* modifier) - : current(modifier) - {} + return current != other.current; }; - ModifierList() - : modifiers(nullptr) + Iterator() + : current(nullptr) {} - ModifierList(Modifier* modifiers) - : modifiers(modifiers) + Iterator(Modifier* modifier) + : current(modifier) {} - - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } - - Modifier* modifiers; }; - // Helper class for iterating over heap-allocated modifiers - // of a specific type. - template<typename T> - struct FilteredModifierList - { - struct Iterator - { - Modifier* current; - - T* operator*() - { - return (T*)current; - } - - void operator++() - { - current = Adjust(current->next.Ptr()); - } - - bool operator!=(Iterator other) - { - return current != other.current; - }; - - Iterator() - : current(nullptr) - {} - - Iterator(Modifier* modifier) - : current(modifier) - {} - }; + FilteredModifierList() + : modifiers(nullptr) + {} - FilteredModifierList() - : modifiers(nullptr) - {} - - FilteredModifierList(Modifier* modifiers) - : modifiers(Adjust(modifiers)) - {} + FilteredModifierList(Modifier* modifiers) + : modifiers(Adjust(modifiers)) + {} - Iterator begin() { return Iterator(modifiers); } - Iterator end() { return Iterator(nullptr); } + Iterator begin() { return Iterator(modifiers); } + Iterator end() { return Iterator(nullptr); } - static Modifier* Adjust(Modifier* modifier) + static Modifier* Adjust(Modifier* modifier) + { + Modifier* m = modifier; + for (;;) { - Modifier* m = modifier; - for (;;) - { - if (!m) return m; - if (dynamic_cast<T*>(m)) return m; - m = m->next.Ptr(); - } + if (!m) return m; + if (dynamic_cast<T*>(m)) return m; + m = m->next.Ptr(); } + } - Modifier* modifiers; - }; - - // A set of modifiers attached to a syntax node - struct Modifiers - { - // The first modifier in the linked list of heap-allocated modifiers - RefPtr<Modifier> first; + Modifier* modifiers; + }; - template<typename T> - FilteredModifierList<T> getModifiersOfType() { return FilteredModifierList<T>(first.Ptr()); } + // A set of modifiers attached to a syntax node + struct Modifiers + { + // The first modifier in the linked list of heap-allocated modifiers + RefPtr<Modifier> first; - // Find the first modifier of a given type, or return `nullptr` if none is found. - template<typename T> - T* findModifier() - { - return *getModifiersOfType<T>().begin(); - } + template<typename T> + FilteredModifierList<T> getModifiersOfType() { return FilteredModifierList<T>(first.Ptr()); } - template<typename T> - bool hasModifier() { return findModifier<T>() != nullptr; } + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* findModifier() + { + return *getModifiersOfType<T>().begin(); + } - FilteredModifierList<Modifier>::Iterator begin() { return FilteredModifierList<Modifier>::Iterator(first.Ptr()); } - FilteredModifierList<Modifier>::Iterator end() { return FilteredModifierList<Modifier>::Iterator(nullptr); } - }; + template<typename T> + bool hasModifier() { return findModifier<T>() != nullptr; } + FilteredModifierList<Modifier>::Iterator begin() { return FilteredModifierList<Modifier>::Iterator(first.Ptr()); } + FilteredModifierList<Modifier>::Iterator end() { return FilteredModifierList<Modifier>::Iterator(nullptr); } + }; - enum class BaseType - { - // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this - Void = 0, - Bool, - Int, - UInt, - UInt64, - Float, + enum class BaseType + { + // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this + + Void = 0, + Bool, + Int, + UInt, + UInt64, + Float, #if 0 - Texture2D = 48, - TextureCube = 49, - Texture2DArray = 50, - Texture2DShadow = 51, - TextureCubeShadow = 52, - Texture2DArrayShadow = 53, - Texture3D = 54, - SamplerState = 4096, SamplerComparisonState = 4097, - Error = 16384, + Texture2D = 48, + TextureCube = 49, + Texture2DArray = 50, + Texture2DShadow = 51, + TextureCubeShadow = 52, + Texture2DArrayShadow = 53, + Texture3D = 54, + SamplerState = 4096, SamplerComparisonState = 4097, + Error = 16384, #endif - }; - - class Decl; - class StructSyntaxNode; - class BasicExpressionType; - class ArrayExpressionType; - class TypeDefDecl; - class DeclRefType; - class NamedExpressionType; - class TypeType; - class GenericDeclRefType; - class VectorExpressionType; - class MatrixExpressionType; - class ArithmeticExpressionType; - class GenericDecl; - class Substitutions; - class TextureType; - class SamplerStateType; - - // A compile-time constant value (usually a type) - class Val : public RefObject - { - public: - // construct a new value by applying a set of parameter - // substitutions to this one - RefPtr<Val> Substitute(Substitutions* subst); - - // Lower-level interface for substition. Like the basic - // `Substitute` above, but also takes a by-reference - // integer parameter that should be incremented when - // returning a modified value (this can help the caller - // decide whether they need to do anything). - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff); - - virtual bool EqualsVal(Val* val) = 0; - virtual String ToString() = 0; - virtual int GetHashCode() = 0; - bool operator == (const Val & v) - { - return EqualsVal(const_cast<Val*>(&v)); - } - }; + }; + + class Decl; + class StructSyntaxNode; + class BasicExpressionType; + class ArrayExpressionType; + class TypeDefDecl; + class DeclRefType; + class NamedExpressionType; + class TypeType; + class GenericDeclRefType; + class VectorExpressionType; + class MatrixExpressionType; + class ArithmeticExpressionType; + class GenericDecl; + class Substitutions; + class TextureType; + class SamplerStateType; + + // A compile-time constant value (usually a type) + class Val : public RefObject + { + public: + // construct a new value by applying a set of parameter + // substitutions to this one + RefPtr<Val> Substitute(Substitutions* subst); + + // Lower-level interface for substition. Like the basic + // `Substitute` above, but also takes a by-reference + // integer parameter that should be incremented when + // returning a modified value (this can help the caller + // decide whether they need to do anything). + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff); + + virtual bool EqualsVal(Val* val) = 0; + virtual String ToString() = 0; + virtual int GetHashCode() = 0; + bool operator == (const Val & v) + { + return EqualsVal(const_cast<Val*>(&v)); + } + }; - // A compile-time integer (may not have a specific concrete value) - class IntVal : public Val - { - }; + // A compile-time integer (may not have a specific concrete value) + class IntVal : public Val + { + }; - // Try to extract a simple integer value from an `IntVal`. - // This fill assert-fail if the object doesn't represent a literal value. - int GetIntVal(RefPtr<IntVal> val); + // Try to extract a simple integer value from an `IntVal`. + // This fill assert-fail if the object doesn't represent a literal value. + int GetIntVal(RefPtr<IntVal> val); - // Trivial case of a value that is just a constant integer - class ConstantIntVal : public IntVal - { - public: - int value; + // Trivial case of a value that is just a constant integer + class ConstantIntVal : public IntVal + { + public: + int value; + + ConstantIntVal(int value) + : value(value) + {} + + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + }; + + // TODO(tfoley): classes for more general compile-time integers, + // including references to template parameters + + // A type, representing a classifier for some term in the AST. + // + // Types can include "sugar" in that they may refer to a + // `typedef` which gives them a good name when printed as + // part of diagnostic messages. + // + // In order to operation on types, though, we often want + // to look past any sugar, and operate on an underlying + // "canonical" type. The reprsentation caches a pointer to + // a canonical type on every type, so we can easily + // operate on the raw representation when needed. + class ExpressionType : public Val + { + public: + static RefPtr<ExpressionType> Error; + static RefPtr<ExpressionType> initializerListType; + static RefPtr<ExpressionType> Overloaded; - ConstantIntVal(int value) - : value(value) - {} + static Dictionary<int, RefPtr<ExpressionType>> sBuiltinTypes; + static Dictionary<String, Decl*> sMagicDecls; - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - }; + // Note: just exists to make sure we can clean up + // canonical types we create along the way + static List<RefPtr<ExpressionType>> sCanonicalTypes; - // TODO(tfoley): classes for more general compile-time integers, - // including references to template parameters - // A type, representing a classifier for some term in the AST. - // - // Types can include "sugar" in that they may refer to a - // `typedef` which gives them a good name when printed as - // part of diagnostic messages. - // - // In order to operation on types, though, we often want - // to look past any sugar, and operate on an underlying - // "canonical" type. The reprsentation caches a pointer to - // a canonical type on every type, so we can easily - // operate on the raw representation when needed. - class ExpressionType : public Val - { - public: - static RefPtr<ExpressionType> Error; - static RefPtr<ExpressionType> initializerListType; - static RefPtr<ExpressionType> Overloaded; - static Dictionary<int, RefPtr<ExpressionType>> sBuiltinTypes; - static Dictionary<String, Decl*> sMagicDecls; + static ExpressionType* GetBool(); + static ExpressionType* GetFloat(); + static ExpressionType* GetInt(); + static ExpressionType* GetUInt(); + static ExpressionType* GetVoid(); + static ExpressionType* getInitializerListType(); + static ExpressionType* GetError(); - // Note: just exists to make sure we can clean up - // canonical types we create along the way - static List<RefPtr<ExpressionType>> sCanonicalTypes; + public: + virtual String ToString() = 0; + bool Equals(ExpressionType * type); + bool Equals(RefPtr<ExpressionType> type); + bool IsVectorType() { return As<VectorExpressionType>() != nullptr; } + bool IsArray() { return As<ArrayExpressionType>() != nullptr; } - static ExpressionType* GetBool(); - static ExpressionType* GetFloat(); - static ExpressionType* GetInt(); - static ExpressionType* GetUInt(); - static ExpressionType* GetVoid(); - static ExpressionType* getInitializerListType(); - static ExpressionType* GetError(); + template<typename T> + T* As() + { + return dynamic_cast<T*>(GetCanonicalType()); + } - public: - virtual String ToString() = 0; + // Convenience/legacy wrappers for `As<>` + ArithmeticExpressionType * AsArithmeticType() { return As<ArithmeticExpressionType>(); } + BasicExpressionType * AsBasicType() { return As<BasicExpressionType>(); } + VectorExpressionType * AsVectorType() { return As<VectorExpressionType>(); } + MatrixExpressionType * AsMatrixType() { return As<MatrixExpressionType>(); } + ArrayExpressionType * AsArrayType() { return As<ArrayExpressionType>(); } - bool Equals(ExpressionType * type); - bool Equals(RefPtr<ExpressionType> type); + DeclRefType* AsDeclRefType() { return As<DeclRefType>(); } - bool IsVectorType() { return As<VectorExpressionType>() != nullptr; } - bool IsArray() { return As<ArrayExpressionType>() != nullptr; } + NamedExpressionType* AsNamedType(); - template<typename T> - T* As() - { - return dynamic_cast<T*>(GetCanonicalType()); - } + bool IsTextureOrSampler(); + bool IsTexture() { return As<TextureType>() != nullptr; } + bool IsSampler() { return As<SamplerStateType>() != nullptr; } + bool IsStruct(); + bool IsClass(); + static void Init(); + static void Finalize(); + ExpressionType* GetCanonicalType(); - // Convenience/legacy wrappers for `As<>` - ArithmeticExpressionType * AsArithmeticType() { return As<ArithmeticExpressionType>(); } - BasicExpressionType * AsBasicType() { return As<BasicExpressionType>(); } - VectorExpressionType * AsVectorType() { return As<VectorExpressionType>(); } - MatrixExpressionType * AsMatrixType() { return As<MatrixExpressionType>(); } - ArrayExpressionType * AsArrayType() { return As<ArrayExpressionType>(); } + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - DeclRefType* AsDeclRefType() { return As<DeclRefType>(); } + virtual bool EqualsVal(Val* val) override; + protected: + virtual bool EqualsImpl(ExpressionType * type) = 0; - NamedExpressionType* AsNamedType(); + virtual ExpressionType* CreateCanonicalType() = 0; + ExpressionType* canonicalType = nullptr; + }; - bool IsTextureOrSampler(); - bool IsTexture() { return As<TextureType>() != nullptr; } - bool IsSampler() { return As<SamplerStateType>() != nullptr; } - bool IsStruct(); - bool IsClass(); - static void Init(); - static void Finalize(); - ExpressionType* GetCanonicalType(); + // A substitution represents a binding of certain + // type-level variables to concrete argument values + class Substitutions : public RefObject + { + public: + // The generic declaration that defines the + // parametesr we are binding to arguments + GenericDecl* genericDecl; - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + // The actual values of the arguments + List<RefPtr<Val>> args; - virtual bool EqualsVal(Val* val) override; - protected: - virtual bool EqualsImpl(ExpressionType * type) = 0; + // Any further substitutions, relating to outer generic declarations + RefPtr<Substitutions> outer; - virtual ExpressionType* CreateCanonicalType() = 0; - ExpressionType* canonicalType = nullptr; - }; + // Apply a set of substitutions to the bindings in this substitution + RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff); - // A substitution represents a binding of certain - // type-level variables to concrete argument values - class Substitutions : public RefObject + // Check if these are equivalent substitutiosn to another set + bool Equals(Substitutions* subst); + bool operator == (const Substitutions & subst) { - public: - // The generic declaration that defines the - // parametesr we are binding to arguments - GenericDecl* genericDecl; - - // The actual values of the arguments - List<RefPtr<Val>> args; - - // Any further substitutions, relating to outer generic declarations - RefPtr<Substitutions> outer; - - // Apply a set of substitutions to the bindings in this substitution - RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff); - - // Check if these are equivalent substitutiosn to another set - bool Equals(Substitutions* subst); - bool operator == (const Substitutions & subst) - { - return Equals(const_cast<Substitutions*>(&subst)); - } - int GetHashCode() const + return Equals(const_cast<Substitutions*>(&subst)); + } + int GetHashCode() const + { + int rs = 0; + for (auto && v : args) { - int rs = 0; - for (auto && v : args) - { - rs ^= v->GetHashCode(); - rs *= 16777619; - } - return rs; + rs ^= v->GetHashCode(); + rs *= 16777619; } - }; - - class SyntaxNode : public SyntaxNodeBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) = 0; - }; - - class ContainerDecl; - class SpecializeModifier; - - // Represents how much checking has been applied to a declaration. - enum class DeclCheckState : uint8_t - { - // The declaration has been parsed, but not checked - Unchecked, - - // We are in the process of checking the declaration "header" - // (those parts of the declaration needed in order to - // reference it) - CheckingHeader, + return rs; + } + }; - // We are done checking the declaration header. - CheckedHeader, + class SyntaxNode : public SyntaxNodeBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) = 0; + }; - // We have checked the declaration fully. - Checked, - }; + class ContainerDecl; + class SpecializeModifier; - // A syntax node which can have modifiers appled - class ModifiableSyntaxNode : public SyntaxNode - { - public: - Modifiers modifiers; + // Represents how much checking has been applied to a declaration. + enum class DeclCheckState : uint8_t + { + // The declaration has been parsed, but not checked + Unchecked, - template<typename T> - FilteredModifierList<T> GetModifiersOfType() { return FilteredModifierList<T>(modifiers.first.Ptr()); } + // We are in the process of checking the declaration "header" + // (those parts of the declaration needed in order to + // reference it) + CheckingHeader, - // Find the first modifier of a given type, or return `nullptr` if none is found. - template<typename T> - T* FindModifier() - { - return *GetModifiersOfType<T>().begin(); - } + // We are done checking the declaration header. + CheckedHeader, - template<typename T> - bool HasModifier() { return FindModifier<T>() != nullptr; } - }; + // We have checked the declaration fully. + Checked, + }; - void addModifier( - RefPtr<ModifiableSyntaxNode> syntax, - RefPtr<Modifier> modifier); + // A syntax node which can have modifiers appled + class ModifiableSyntaxNode : public SyntaxNode + { + public: + Modifiers modifiers; + template<typename T> + FilteredModifierList<T> GetModifiersOfType() { return FilteredModifierList<T>(modifiers.first.Ptr()); } - // An intermediate type to represent either a single declaration, or a group of declarations - class DeclBase : public ModifiableSyntaxNode + // Find the first modifier of a given type, or return `nullptr` if none is found. + template<typename T> + T* FindModifier() { - public: - }; + return *GetModifiersOfType<T>().begin(); + } - class Decl : public DeclBase - { - public: - ContainerDecl* ParentDecl; + template<typename T> + bool HasModifier() { return FindModifier<T>() != nullptr; } + }; - Token Name; - String const& getName() { return Name.Content; } - Token const& getNameToken() { return Name; } + void addModifier( + RefPtr<ModifiableSyntaxNode> syntax, + RefPtr<Modifier> modifier); - DeclCheckState checkState = DeclCheckState::Unchecked; + // An intermediate type to represent either a single declaration, or a group of declarations + class DeclBase : public ModifiableSyntaxNode + { + public: + }; - // The next declaration defined in the same container with the same name - Decl* nextInContainerWithSameName = nullptr; + class Decl : public DeclBase + { + public: + ContainerDecl* ParentDecl; - bool IsChecked(DeclCheckState state) { return checkState >= state; } - void SetCheckState(DeclCheckState state) - { - assert(state >= checkState); - checkState = state; - } - }; + Token Name; + String const& getName() { return Name.Content; } + Token const& getNameToken() { return Name; } - struct QualType - { - RefPtr<ExpressionType> type; - bool IsLeftValue; - QualType() - : IsLeftValue(false) - {} + DeclCheckState checkState = DeclCheckState::Unchecked; - QualType(RefPtr<ExpressionType> type) - : type(type) - , IsLeftValue(false) - {} + // The next declaration defined in the same container with the same name + Decl* nextInContainerWithSameName = nullptr; - QualType(ExpressionType* type) - : type(type) - , IsLeftValue(false) - {} + bool IsChecked(DeclCheckState state) { return checkState >= state; } + void SetCheckState(DeclCheckState state) + { + assert(state >= checkState); + checkState = state; + } + }; - void operator=(RefPtr<ExpressionType> t) - { - *this = QualType(t); - } + struct QualType + { + RefPtr<ExpressionType> type; + bool IsLeftValue; - void operator=(ExpressionType* t) - { - *this = QualType(t); - } + QualType() + : IsLeftValue(false) + {} - ExpressionType* Ptr() { return type.Ptr(); } + QualType(RefPtr<ExpressionType> type) + : type(type) + , IsLeftValue(false) + {} - operator RefPtr<ExpressionType>() { return type; } - RefPtr<ExpressionType> operator->() { return type; } - }; + QualType(ExpressionType* type) + : type(type) + , IsLeftValue(false) + {} - class ExpressionSyntaxNode : public SyntaxNode + void operator=(RefPtr<ExpressionType> t) { - public: - QualType Type; - ExpressionSyntaxNode() - {} - }; - - - + *this = QualType(t); + } - // A reference to a declaration, which may include - // substitutions for generic parameters. - struct DeclRef + void operator=(ExpressionType* t) { - typedef Decl DeclType; - - // The underlying declaration - Decl* decl = nullptr; - Decl* GetDecl() const { return decl; } + *this = QualType(t); + } - // Optionally, a chain of substititions to perform - RefPtr<Substitutions> substitutions; + ExpressionType* Ptr() { return type.Ptr(); } - DeclRef() - {} + operator RefPtr<ExpressionType>() { return type; } + RefPtr<ExpressionType> operator->() { return type; } + }; - DeclRef(Decl* decl, RefPtr<Substitutions> substitutions) - : decl(decl) - , substitutions(substitutions) - {} + class ExpressionSyntaxNode : public SyntaxNode + { + public: + QualType Type; + ExpressionSyntaxNode() + {} + }; - // Apply substitutions to a type or ddeclaration - RefPtr<ExpressionType> Substitute(RefPtr<ExpressionType> type) const; - DeclRef Substitute(DeclRef declRef) const; - // Apply substitutions to an expression - RefPtr<ExpressionSyntaxNode> Substitute(RefPtr<ExpressionSyntaxNode> expr) const; - // Apply substitutions to this declaration reference - DeclRef SubstituteImpl(Substitutions* subst, int* ioDiff); - // Check if this is an equivalent declaration reference to another - bool Equals(DeclRef const& declRef) const; - bool operator == (const DeclRef& other) const - { - return Equals(other); - } + // A reference to a declaration, which may include + // substitutions for generic parameters. + struct DeclRef + { + typedef Decl DeclType; - // Convenience accessors for common properties of declarations - String const& GetName() const; - DeclRef GetParent() const; + // The underlying declaration + Decl* decl = nullptr; + Decl* GetDecl() const { return decl; } - // "dynamic cast" to a more specific declaration reference type - template<typename T> - T As() const - { - T result; - result.decl = dynamic_cast<T::DeclType*>(decl); - result.substitutions = substitutions; - return result; - } + // Optionally, a chain of substititions to perform + RefPtr<Substitutions> substitutions; - // Implicit conversion mostly so we can use a `DeclRef` - // in a conditional context - operator Decl*() const - { - return decl; - } + DeclRef() + {} - int GetHashCode() const; - }; + DeclRef(Decl* decl, RefPtr<Substitutions> substitutions) + : decl(decl) + , substitutions(substitutions) + {} - // Helper macro for defining `DeclRef` subtypes - #define SLANG_DECLARE_DECL_REF(D) \ - typedef D DeclType; \ - D* GetDecl() const { return (D*) decl; } \ - /* */ + // Apply substitutions to a type or ddeclaration + RefPtr<ExpressionType> Substitute(RefPtr<ExpressionType> type) const; + DeclRef Substitute(DeclRef declRef) const; + // Apply substitutions to an expression + RefPtr<ExpressionSyntaxNode> Substitute(RefPtr<ExpressionSyntaxNode> expr) const; + // Apply substitutions to this declaration reference + DeclRef SubstituteImpl(Substitutions* subst, int* ioDiff); - // The type of a reference to an overloaded name - class OverloadGroupType : public ExpressionType + // Check if this is an equivalent declaration reference to another + bool Equals(DeclRef const& declRef) const; + bool operator == (const DeclRef& other) const { - public: - virtual String ToString() override; + return Equals(other); + } - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + // Convenience accessors for common properties of declarations + String const& GetName() const; + DeclRef GetParent() const; - // The type of an initializer-list expression (before it has - // been coerced to some other type) - class InitializerListType : public ExpressionType + // "dynamic cast" to a more specific declaration reference type + template<typename T> + T As() const { - public: - virtual String ToString() override; - - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + T result; + result.decl = dynamic_cast<T::DeclType*>(decl); + result.substitutions = substitutions; + return result; + } - // The type of an expression that was erroneous - class ErrorType : public ExpressionType + // Implicit conversion mostly so we can use a `DeclRef` + // in a conditional context + operator Decl*() const { - public: - virtual String ToString() override; - - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + return decl; + } - // A type that takes the form of a reference to some declaration - class DeclRefType : public ExpressionType - { - public: - DeclRef declRef; + int GetHashCode() const; + }; - virtual String ToString() override; - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + // Helper macro for defining `DeclRef` subtypes + #define SLANG_DECLARE_DECL_REF(D) \ + typedef D DeclType; \ + D* GetDecl() const { return (D*) decl; } \ + /* */ - static DeclRefType* Create(DeclRef declRef); - protected: - DeclRefType() - {} - DeclRefType(DeclRef declRef) - : declRef(declRef) - {} - virtual int GetHashCode() override; - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - }; - // Base class for types that can be used in arithmetic expressions - class ArithmeticExpressionType : public DeclRefType - { - public: - virtual BasicExpressionType* GetScalarType() = 0; - }; + // The type of a reference to an overloaded name + class OverloadGroupType : public ExpressionType + { + public: + virtual String ToString() override; + + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // The type of an initializer-list expression (before it has + // been coerced to some other type) + class InitializerListType : public ExpressionType + { + public: + virtual String ToString() override; - class FunctionDeclBase; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; - class BasicExpressionType : public ArithmeticExpressionType - { - public: - BaseType BaseType; + // The type of an expression that was erroneous + class ErrorType : public ExpressionType + { + public: + virtual String ToString() override; - BasicExpressionType() - { - BaseType = Compiler::BaseType::Int; - } - BasicExpressionType(Compiler::BaseType baseType) - { - BaseType = baseType; - } - virtual CoreLib::Basic::String ToString() override; - protected: - virtual BasicExpressionType* GetScalarType() override; - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - }; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + // A type that takes the form of a reference to some declaration + class DeclRefType : public ExpressionType + { + public: + DeclRef declRef; + + virtual String ToString() override; + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + + static DeclRefType* Create(DeclRef declRef); + + protected: + DeclRefType() + {} + DeclRefType(DeclRef declRef) + : declRef(declRef) + {} + virtual int GetHashCode() override; + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + }; + + // Base class for types that can be used in arithmetic expressions + class ArithmeticExpressionType : public DeclRefType + { + public: + virtual BasicExpressionType* GetScalarType() = 0; + }; - class TextureTypeBase : public DeclRefType - { - public: - // The type that results from fetching an element from this texture - RefPtr<ExpressionType> elementType; + class FunctionDeclBase; - // Bits representing the kind of texture type we are looking at - // (e.g., `Texture2DMS` vs. `TextureCubeArray`) - typedef uint16_t Flavor; - Flavor flavor; + class BasicExpressionType : public ArithmeticExpressionType + { + public: + BaseType BaseType; - enum - { - // Mask for the overall "shape" of the texture - ShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, + BasicExpressionType() + { + BaseType = Slang::BaseType::Int; + } + BasicExpressionType(Slang::BaseType baseType) + { + BaseType = baseType; + } + virtual CoreLib::Basic::String ToString() override; + protected: + virtual BasicExpressionType* GetScalarType() override; + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + }; - // Flag for whether the shape has "array-ness" - ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, - // Whether or not the texture stores multiple samples per pixel - MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG, + class TextureTypeBase : public DeclRefType + { + public: + // The type that results from fetching an element from this texture + RefPtr<ExpressionType> elementType; - // Whether or not this is a shadow texture - // - // TODO(tfoley): is this even meaningful/used? - // ShadowFlag = 0x80, - }; + // Bits representing the kind of texture type we are looking at + // (e.g., `Texture2DMS` vs. `TextureCubeArray`) + typedef uint16_t Flavor; + Flavor flavor; - enum Shape : uint8_t - { - Shape1D = SLANG_TEXTURE_1D, - Shape2D = SLANG_TEXTURE_2D, - Shape3D = SLANG_TEXTURE_3D, - ShapeCube = SLANG_TEXTURE_CUBE, - - Shape1DArray = Shape1D | ArrayFlag, - Shape2DArray = Shape2D | ArrayFlag, - // No Shape3DArray - ShapeCubeArray = ShapeCube | ArrayFlag, - }; - + enum + { + // Mask for the overall "shape" of the texture + ShapeMask = SLANG_RESOURCE_BASE_SHAPE_MASK, - Shape GetBaseShape() const { return Shape(flavor & ShapeMask); } - bool isArray() const { return (flavor & ArrayFlag) != 0; } - bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } -// bool isShadow() const { return (flavor & ShadowFlag) != 0; } + // Flag for whether the shape has "array-ness" + ArrayFlag = SLANG_TEXTURE_ARRAY_FLAG, - SlangResourceShape getShape() const { return flavor & 0xFF; } - SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; } + // Whether or not the texture stores multiple samples per pixel + MultisampleFlag = SLANG_TEXTURE_MULTISAMPLE_FLAG, - TextureTypeBase( - Flavor flavor, - RefPtr<ExpressionType> elementType) - : elementType(elementType) - , flavor(flavor) - {} + // Whether or not this is a shadow texture + // + // TODO(tfoley): is this even meaningful/used? + // ShadowFlag = 0x80, }; - class TextureType : public TextureTypeBase + enum Shape : uint8_t { - public: - TextureType( - Flavor flavor, - RefPtr<ExpressionType> elementType) - : TextureTypeBase(flavor, elementType) - {} - }; + Shape1D = SLANG_TEXTURE_1D, + Shape2D = SLANG_TEXTURE_2D, + Shape3D = SLANG_TEXTURE_3D, + ShapeCube = SLANG_TEXTURE_CUBE, - // This is a base type for texture/sampler pairs, - // as they exist in, e.g., GLSL - class TextureSamplerType : public TextureTypeBase - { - public: - TextureSamplerType( - Flavor flavor, - RefPtr<ExpressionType> elementType) - : TextureTypeBase(flavor, elementType) - {} + Shape1DArray = Shape1D | ArrayFlag, + Shape2DArray = Shape2D | ArrayFlag, + // No Shape3DArray + ShapeCubeArray = ShapeCube | ArrayFlag, }; + - // This is a base type for `image*` types, as they exist in GLSL - class GLSLImageType : public TextureTypeBase - { - public: - GLSLImageType( - Flavor flavor, - RefPtr<ExpressionType> elementType) - : TextureTypeBase(flavor, elementType) - {} - }; + Shape GetBaseShape() const { return Shape(flavor & ShapeMask); } + bool isArray() const { return (flavor & ArrayFlag) != 0; } + bool isMultisample() const { return (flavor & MultisampleFlag) != 0; } +// bool isShadow() const { return (flavor & ShadowFlag) != 0; } - class SamplerStateType : public DeclRefType - { - public: - // What flavor of sampler state is this - enum class Flavor : uint8_t - { - SamplerState, - SamplerComparisonState, - }; - Flavor flavor; - }; + SlangResourceShape getShape() const { return flavor & 0xFF; } + SlangResourceAccess getAccess() const { return (flavor >> 8) & 0xFF; } + + TextureTypeBase( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : elementType(elementType) + , flavor(flavor) + {} + }; - // Other cases of generic types known to the compiler - class BuiltinGenericType : public DeclRefType + class TextureType : public TextureTypeBase + { + public: + TextureType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + // This is a base type for texture/sampler pairs, + // as they exist in, e.g., GLSL + class TextureSamplerType : public TextureTypeBase + { + public: + TextureSamplerType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + // This is a base type for `image*` types, as they exist in GLSL + class GLSLImageType : public TextureTypeBase + { + public: + GLSLImageType( + Flavor flavor, + RefPtr<ExpressionType> elementType) + : TextureTypeBase(flavor, elementType) + {} + }; + + class SamplerStateType : public DeclRefType + { + public: + // What flavor of sampler state is this + enum class Flavor : uint8_t { - public: - RefPtr<ExpressionType> elementType; + SamplerState, + SamplerComparisonState, }; + Flavor flavor; + }; - // Types that behave like pointers, in that they can be - // dereferenced (implicitly) to access members defined - // in the element type. - class PointerLikeType : public BuiltinGenericType - {}; + // Other cases of generic types known to the compiler + class BuiltinGenericType : public DeclRefType + { + public: + RefPtr<ExpressionType> elementType; + }; - // Generic types used in existing Slang code - // TODO(tfoley): check that these are actually working right... - class PatchType : public PointerLikeType {}; - class StorageBufferType : public BuiltinGenericType {}; - class UniformBufferType : public PointerLikeType {}; - class PackedBufferType : public BuiltinGenericType {}; + // Types that behave like pointers, in that they can be + // dereferenced (implicitly) to access members defined + // in the element type. + class PointerLikeType : public BuiltinGenericType + {}; - // HLSL buffer-type resources + // Generic types used in existing Slang code + // TODO(tfoley): check that these are actually working right... + class PatchType : public PointerLikeType {}; + class StorageBufferType : public BuiltinGenericType {}; + class UniformBufferType : public PointerLikeType {}; + class PackedBufferType : public BuiltinGenericType {}; - class HLSLBufferType : public BuiltinGenericType {}; - class HLSLRWBufferType : public BuiltinGenericType {}; - class HLSLStructuredBufferType : public BuiltinGenericType {}; - class HLSLRWStructuredBufferType : public BuiltinGenericType {}; + // HLSL buffer-type resources - class UntypedBufferResourceType : public DeclRefType {}; - class HLSLByteAddressBufferType : public UntypedBufferResourceType {}; - class HLSLRWByteAddressBufferType : public UntypedBufferResourceType {}; + class HLSLBufferType : public BuiltinGenericType {}; + class HLSLRWBufferType : public BuiltinGenericType {}; + class HLSLStructuredBufferType : public BuiltinGenericType {}; + class HLSLRWStructuredBufferType : public BuiltinGenericType {}; - class HLSLAppendStructuredBufferType : public BuiltinGenericType {}; - class HLSLConsumeStructuredBufferType : public BuiltinGenericType {}; + class UntypedBufferResourceType : public DeclRefType {}; + class HLSLByteAddressBufferType : public UntypedBufferResourceType {}; + class HLSLRWByteAddressBufferType : public UntypedBufferResourceType {}; - class HLSLInputPatchType : public BuiltinGenericType {}; - class HLSLOutputPatchType : public BuiltinGenericType {}; + class HLSLAppendStructuredBufferType : public BuiltinGenericType {}; + class HLSLConsumeStructuredBufferType : public BuiltinGenericType {}; - // HLSL geometry shader output stream types + class HLSLInputPatchType : public BuiltinGenericType {}; + class HLSLOutputPatchType : public BuiltinGenericType {}; - class HLSLStreamOutputType : public BuiltinGenericType {}; - class HLSLPointStreamType : public HLSLStreamOutputType {}; - class HLSLLineStreamType : public HLSLStreamOutputType {}; - class HLSLTriangleStreamType : public HLSLStreamOutputType {}; + // HLSL geometry shader output stream types - // - class GLSLInputAttachmentType : public DeclRefType {}; + class HLSLStreamOutputType : public BuiltinGenericType {}; + class HLSLPointStreamType : public HLSLStreamOutputType {}; + class HLSLLineStreamType : public HLSLStreamOutputType {}; + class HLSLTriangleStreamType : public HLSLStreamOutputType {}; - // Base class for types used when desugaring parameter block - // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. - class ParameterBlockType : public PointerLikeType {}; + // + class GLSLInputAttachmentType : public DeclRefType {}; - class UniformParameterBlockType : public ParameterBlockType {}; - class VaryingParameterBlockType : public ParameterBlockType {}; + // Base class for types used when desugaring parameter block + // declarations, includeing HLSL `cbuffer` or GLSL `uniform` blocks. + class ParameterBlockType : public PointerLikeType {}; - // Type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>` - // ALso used for GLSL `uniform` blocks. - class ConstantBufferType : public UniformParameterBlockType {}; + class UniformParameterBlockType : public ParameterBlockType {}; + class VaryingParameterBlockType : public ParameterBlockType {}; - // Type for HLSL `tbuffer` declarations, and `TextureBuffer<T>` - class TextureBufferType : public UniformParameterBlockType {}; + // Type for HLSL `cbuffer` declarations, and `ConstantBuffer<T>` + // ALso used for GLSL `uniform` blocks. + class ConstantBufferType : public UniformParameterBlockType {}; - // Type for GLSL `in` and `out` blocks - class GLSLInputParameterBlockType : public VaryingParameterBlockType {}; - class GLSLOutputParameterBlockType : public VaryingParameterBlockType {}; + // Type for HLSL `tbuffer` declarations, and `TextureBuffer<T>` + class TextureBufferType : public UniformParameterBlockType {}; - // Type for GLLSL `buffer` blocks - class GLSLShaderStorageBufferType : public UniformParameterBlockType {}; + // Type for GLSL `in` and `out` blocks + class GLSLInputParameterBlockType : public VaryingParameterBlockType {}; + class GLSLOutputParameterBlockType : public VaryingParameterBlockType {}; - class ArrayExpressionType : public ExpressionType - { - public: - RefPtr<ExpressionType> BaseType; - RefPtr<IntVal> ArrayLength; - virtual CoreLib::Basic::String ToString() override; - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + // Type for GLLSL `buffer` blocks + class GLSLShaderStorageBufferType : public UniformParameterBlockType {}; - // The "type" of an expression that resolves to a type. - // For example, in the expression `float(2)` the sub-expression, - // `float` would have the type `TypeType(float)`. - class TypeType : public ExpressionType - { - public: - TypeType(RefPtr<ExpressionType> type) - : type(type) - {} + class ArrayExpressionType : public ExpressionType + { + public: + RefPtr<ExpressionType> BaseType; + RefPtr<IntVal> ArrayLength; + virtual CoreLib::Basic::String ToString() override; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // The "type" of an expression that resolves to a type. + // For example, in the expression `float(2)` the sub-expression, + // `float` would have the type `TypeType(float)`. + class TypeType : public ExpressionType + { + public: + TypeType(RefPtr<ExpressionType> type) + : type(type) + {} - // The type that this is the type of... - RefPtr<ExpressionType> type; + // The type that this is the type of... + RefPtr<ExpressionType> type; - virtual String ToString() override; + virtual String ToString() override; - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; - class GenericDecl; + class GenericDecl; - // A vector type, e.g., `vector<T,N>` - class VectorExpressionType : public ArithmeticExpressionType - { - public: + // A vector type, e.g., `vector<T,N>` + class VectorExpressionType : public ArithmeticExpressionType + { + public: #if 0 - VectorExpressionType( - RefPtr<ExpressionType> elementType, - RefPtr<IntVal> elementCount) - : elementType(elementType) - , elementCount(elementCount) - {} + VectorExpressionType( + RefPtr<ExpressionType> elementType, + RefPtr<IntVal> elementCount) + : elementType(elementType) + , elementCount(elementCount) + {} #endif - // The type of vector elements. - // As an invariant, this should be a basic type or an alias. - RefPtr<ExpressionType> elementType; + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + RefPtr<ExpressionType> elementType; - // The number of elements - RefPtr<IntVal> elementCount; + // The number of elements + RefPtr<IntVal> elementCount; - virtual String ToString() override; + virtual String ToString() override; - protected: - virtual BasicExpressionType* GetScalarType() override; - }; + protected: + virtual BasicExpressionType* GetScalarType() override; + }; - // A matrix type, e.g., `matrix<T,R,C>` - class MatrixExpressionType : public ArithmeticExpressionType - { - public: - // TODO: consider adding these back for convenience, - // with a way to initialize them on-demand from the - // real storage (which is in the `DeclRefType` + // A matrix type, e.g., `matrix<T,R,C>` + class MatrixExpressionType : public ArithmeticExpressionType + { + public: + // TODO: consider adding these back for convenience, + // with a way to initialize them on-demand from the + // real storage (which is in the `DeclRefType` #if 0 - // The type of vector elements. - // As an invariant, this should be a basic type or an alias. - RefPtr<ExpressionType> elementType; + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + RefPtr<ExpressionType> elementType; - // The type of the matrix rows - RefPtr<VectorExpressionType> rowType; + // The type of the matrix rows + RefPtr<VectorExpressionType> rowType; - // The number of rows and columns - RefPtr<IntVal> rowCount; - RefPtr<IntVal> colCount; + // The number of rows and columns + RefPtr<IntVal> rowCount; + RefPtr<IntVal> colCount; #endif - ExpressionType* getElementType(); - IntVal* getRowCount(); - IntVal* getColumnCount(); + ExpressionType* getElementType(); + IntVal* getRowCount(); + IntVal* getColumnCount(); - virtual String ToString() override; + virtual String ToString() override; - protected: - virtual BasicExpressionType* GetScalarType() override; - }; + protected: + virtual BasicExpressionType* GetScalarType() override; + }; - inline BaseType GetVectorBaseType(VectorExpressionType* vecType) { - return vecType->elementType->AsBasicType()->BaseType; - } + inline BaseType GetVectorBaseType(VectorExpressionType* vecType) { + return vecType->elementType->AsBasicType()->BaseType; + } - inline int GetVectorSize(VectorExpressionType* vecType) - { - auto constantVal = vecType->elementCount.As<ConstantIntVal>(); - if (constantVal) - return constantVal->value; - // TODO: what to do in this case? - return 0; - } + inline int GetVectorSize(VectorExpressionType* vecType) + { + auto constantVal = vecType->elementCount.As<ConstantIntVal>(); + if (constantVal) + return constantVal->value; + // TODO: what to do in this case? + return 0; + } - class Type - { - public: - RefPtr<ExpressionType> DataType; - // ContrainedWorlds: Implementation must be defined at at least one of of these worlds in order to satisfy global dependency - // FeasibleWorlds: The component can be computed at any of these worlds - EnumerableHashSet<String> ConstrainedWorlds, FeasibleWorlds; - EnumerableHashSet<String> PinnedWorlds; - }; + class Type + { + public: + RefPtr<ExpressionType> DataType; + // ContrainedWorlds: Implementation must be defined at at least one of of these worlds in order to satisfy global dependency + // FeasibleWorlds: The component can be computed at any of these worlds + EnumerableHashSet<String> ConstrainedWorlds, FeasibleWorlds; + EnumerableHashSet<String> PinnedWorlds; + }; - class ContainerDecl; + class ContainerDecl; - // A group of declarations that should be treated as a unit - class DeclGroup : public DeclBase - { - public: - List<RefPtr<Decl>> decls; + // A group of declarations that should be treated as a unit + class DeclGroup : public DeclBase + { + public: + List<RefPtr<Decl>> decls; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - template<typename T> - struct FilteredMemberList - { - typedef RefPtr<Decl> Element; + template<typename T> + struct FilteredMemberList + { + typedef RefPtr<Decl> Element; - FilteredMemberList() - : mBegin(NULL) - , mEnd(NULL) - {} + FilteredMemberList() + : mBegin(NULL) + , mEnd(NULL) + {} - explicit FilteredMemberList( - List<Element> const& list) - : mBegin(Adjust(list.begin(), list.end())) - , mEnd(list.end()) - {} + explicit FilteredMemberList( + List<Element> const& list) + : mBegin(Adjust(list.begin(), list.end())) + , mEnd(list.end()) + {} + + struct Iterator + { + Element* mCursor; + Element* mEnd; - struct Iterator + bool operator!=(Iterator const& other) { - Element* mCursor; - Element* mEnd; - - bool operator!=(Iterator const& other) - { - return mCursor != other.mCursor; - } - - void operator++() - { - mCursor = Adjust(mCursor + 1, mEnd); - } - - RefPtr<T>& operator*() - { - return *(RefPtr<T>*)mCursor; - } - }; + return mCursor != other.mCursor; + } - Iterator begin() + void operator++() { - Iterator iter = { mBegin, mEnd }; - return iter; + mCursor = Adjust(mCursor + 1, mEnd); } - Iterator end() + RefPtr<T>& operator*() { - Iterator iter = { mEnd, mEnd }; - return iter; + return *(RefPtr<T>*)mCursor; } + }; - static Element* Adjust(Element* cursor, Element* end) + Iterator begin() + { + Iterator iter = { mBegin, mEnd }; + return iter; + } + + Iterator end() + { + Iterator iter = { mEnd, mEnd }; + return iter; + } + + static Element* Adjust(Element* cursor, Element* end) + { + while (cursor != end) { - while (cursor != end) - { - if ((*cursor).As<T>()) - return cursor; - cursor++; - } - return cursor; + if ((*cursor).As<T>()) + return cursor; + cursor++; } + return cursor; + } - // TODO(tfoley): It is ugly to have these. - // We should probably fix the call sites instead. - RefPtr<T>& First() { return *begin(); } - int Count() + // TODO(tfoley): It is ugly to have these. + // We should probably fix the call sites instead. + RefPtr<T>& First() { return *begin(); } + int Count() + { + int count = 0; + for (auto iter : (*this)) { - int count = 0; - for (auto iter : (*this)) - { - (void)iter; - count++; - } - return count; + (void)iter; + count++; } + return count; + } - List<RefPtr<T>> ToArray() + List<RefPtr<T>> ToArray() + { + List<RefPtr<T>> result; + for (auto element : (*this)) { - List<RefPtr<T>> result; - for (auto element : (*this)) - { - result.Add(element); - } - return result; + result.Add(element); } + return result; + } - Element* mBegin; - Element* mEnd; - }; + Element* mBegin; + Element* mEnd; + }; - struct TransparentMemberInfo - { - // The declaration of the transparent member - Decl* decl; - }; + struct TransparentMemberInfo + { + // The declaration of the transparent member + Decl* decl; + }; + + // A "container" decl is a parent to other declarations + class ContainerDecl : public Decl + { + public: + List<RefPtr<Decl>> Members; - // A "container" decl is a parent to other declarations - class ContainerDecl : public Decl + template<typename T> + FilteredMemberList<T> GetMembersOfType() { - public: - List<RefPtr<Decl>> Members; + return FilteredMemberList<T>(Members); + } - template<typename T> - FilteredMemberList<T> GetMembersOfType() - { - return FilteredMemberList<T>(Members); - } + // Dictionary for looking up members by name. + // This is built on demand before performing lookup. + Dictionary<String, Decl*> memberDictionary; - // Dictionary for looking up members by name. - // This is built on demand before performing lookup. - Dictionary<String, Decl*> memberDictionary; + // Whether the `memberDictionary` is valid. + // Should be set to `false` if any members get added/remoed. + bool memberDictionaryIsValid = false; - // Whether the `memberDictionary` is valid. - // Should be set to `false` if any members get added/remoed. - bool memberDictionaryIsValid = false; + // A list of transparent members, to be used in lookup + // Note: this is only valid if `memberDictionaryIsValid` is true + List<TransparentMemberInfo> transparentMembers; + }; - // A list of transparent members, to be used in lookup - // Note: this is only valid if `memberDictionaryIsValid` is true - List<TransparentMemberInfo> transparentMembers; - }; + template<typename T> + struct FilteredMemberRefList + { + List<RefPtr<Decl>> const& decls; + RefPtr<Substitutions> substitutions; + + FilteredMemberRefList( + List<RefPtr<Decl>> const& decls, + RefPtr<Substitutions> substitutions) + : decls(decls) + , substitutions(substitutions) + {} + + int Count() const + { + int count = 0; + for (auto d : *this) + count++; + return count; + } - template<typename T> - struct FilteredMemberRefList + List<T> ToArray() const + { + List<T> result; + for (auto d : *this) + result.Add(d); + return result; + } + + struct Iterator { - List<RefPtr<Decl>> const& decls; - RefPtr<Substitutions> substitutions; + FilteredMemberRefList const* list; + RefPtr<Decl>* ptr; + RefPtr<Decl>* end; - FilteredMemberRefList( - List<RefPtr<Decl>> const& decls, - RefPtr<Substitutions> substitutions) - : decls(decls) - , substitutions(substitutions) + Iterator() : list(nullptr), ptr(nullptr) {} + Iterator( + FilteredMemberRefList const* list, + RefPtr<Decl>* ptr, + RefPtr<Decl>* end) + : list(list) + , ptr(ptr) + , end(end) {} - int Count() const + bool operator!=(Iterator other) { - int count = 0; - for (auto d : *this) - count++; - return count; + return ptr != other.ptr; } - List<T> ToArray() const + void operator++() { - List<T> result; - for (auto d : *this) - result.Add(d); - return result; + ptr = list->Adjust(ptr + 1, end); } - struct Iterator + T operator*() { - FilteredMemberRefList const* list; - RefPtr<Decl>* ptr; - RefPtr<Decl>* end; - - Iterator() : list(nullptr), ptr(nullptr) {} - Iterator( - FilteredMemberRefList const* list, - RefPtr<Decl>* ptr, - RefPtr<Decl>* end) - : list(list) - , ptr(ptr) - , end(end) - {} - - bool operator!=(Iterator other) - { - return ptr != other.ptr; - } - - void operator++() - { - ptr = list->Adjust(ptr + 1, end); - } - - T operator*() - { - return DeclRef(ptr->Ptr(), list->substitutions).As<T>(); - } - }; - - Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); } - Iterator end() const { return Iterator(this, decls.end(), decls.end()); } - - RefPtr<Decl>* Adjust(RefPtr<Decl>* ptr, RefPtr<Decl>* end) const - { - while (ptr != end) - { - DeclRef declRef(ptr->Ptr(), substitutions); - if (declRef.As<T>()) - return ptr; - ptr++; - } - return end; + return DeclRef(ptr->Ptr(), list->substitutions).As<T>(); } }; - struct ContainerDeclRef : DeclRef - { - SLANG_DECLARE_DECL_REF(ContainerDecl); - - FilteredMemberRefList<DeclRef> GetMembers() const - { - return FilteredMemberRefList<DeclRef>(GetDecl()->Members, substitutions); - } + Iterator begin() const { return Iterator(this, Adjust(decls.begin(), decls.end()), decls.end()); } + Iterator end() const { return Iterator(this, decls.end(), decls.end()); } - template<typename T> - FilteredMemberRefList<T> GetMembersOfType() const + RefPtr<Decl>* Adjust(RefPtr<Decl>* ptr, RefPtr<Decl>* end) const + { + while (ptr != end) { - return FilteredMemberRefList<T>(GetDecl()->Members, substitutions); + DeclRef declRef(ptr->Ptr(), substitutions); + if (declRef.As<T>()) + return ptr; + ptr++; } + return end; + } + }; - }; + struct ContainerDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(ContainerDecl); - // - // Type Expressions - // + FilteredMemberRefList<DeclRef> GetMembers() const + { + return FilteredMemberRefList<DeclRef>(GetDecl()->Members, substitutions); + } - // A "type expression" is a term that we expect to resolve to a type during checking. - // We store both the original syntax and the resolved type here. - struct TypeExp + template<typename T> + FilteredMemberRefList<T> GetMembersOfType() const { - TypeExp() {} - TypeExp(TypeExp const& other) - : exp(other.exp) - , type(other.type) - {} - explicit TypeExp(RefPtr<ExpressionSyntaxNode> exp) - : exp(exp) - {} - TypeExp(RefPtr<ExpressionSyntaxNode> exp, RefPtr<ExpressionType> type) - : exp(exp) - , type(type) - {} + return FilteredMemberRefList<T>(GetDecl()->Members, substitutions); + } - RefPtr<ExpressionSyntaxNode> exp; - RefPtr<ExpressionType> type; + }; - bool Equals(ExpressionType* other) { - return type->Equals(other); - } - bool Equals(RefPtr<ExpressionType> other) { - return type->Equals(other.Ptr()); - } - ExpressionType* Ptr() { return type.Ptr(); } - operator RefPtr<ExpressionType>() - { - return type; - } - ExpressionType* operator->() { return Ptr(); } + // + // Type Expressions + // - TypeExp Accept(SyntaxVisitor* visitor); - }; + // A "type expression" is a term that we expect to resolve to a type during checking. + // We store both the original syntax and the resolved type here. + struct TypeExp + { + TypeExp() {} + TypeExp(TypeExp const& other) + : exp(other.exp) + , type(other.type) + {} + explicit TypeExp(RefPtr<ExpressionSyntaxNode> exp) + : exp(exp) + {} + TypeExp(RefPtr<ExpressionSyntaxNode> exp, RefPtr<ExpressionType> type) + : exp(exp) + , type(type) + {} + + RefPtr<ExpressionSyntaxNode> exp; + RefPtr<ExpressionType> type; + + bool Equals(ExpressionType* other) { + return type->Equals(other); + } + bool Equals(RefPtr<ExpressionType> other) { + return type->Equals(other.Ptr()); + } + ExpressionType* Ptr() { return type.Ptr(); } + operator RefPtr<ExpressionType>() + { + return type; + } + ExpressionType* operator->() { return Ptr(); } + TypeExp Accept(SyntaxVisitor* visitor); + }; - // - // Declarations - // - // Base class for all variable-like declarations - class VarDeclBase : public Decl - { - public: - // Type of the variable - TypeExp Type; + // + // Declarations + // - ExpressionType* getType() { return Type.type.Ptr(); } + // Base class for all variable-like declarations + class VarDeclBase : public Decl + { + public: + // Type of the variable + TypeExp Type; - // Initializer expression (optional) - RefPtr<ExpressionSyntaxNode> Expr; - }; + ExpressionType* getType() { return Type.type.Ptr(); } - struct VarDeclBaseRef : DeclRef - { - SLANG_DECLARE_DECL_REF(VarDeclBase); + // Initializer expression (optional) + RefPtr<ExpressionSyntaxNode> Expr; + }; - RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } + struct VarDeclBaseRef : DeclRef + { + SLANG_DECLARE_DECL_REF(VarDeclBase); - RefPtr<ExpressionSyntaxNode> getInitExpr() const { return Substitute(GetDecl()->Expr); } - }; + RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } - // A field of a `struct` type - class StructField : public VarDeclBase - { - public: - StructField() - {} - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + RefPtr<ExpressionSyntaxNode> getInitExpr() const { return Substitute(GetDecl()->Expr); } + }; - struct FieldDeclRef : VarDeclBaseRef - { - SLANG_DECLARE_DECL_REF(StructField) - }; + // A field of a `struct` type + class StructField : public VarDeclBase + { + public: + StructField() + {} + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // An `AggTypeDeclBase` captures the shared functionality - // between true aggregate type declarations and extension - // declarations: - // - // - Both can container members (they are `ContainerDecl`s) - // - Both can have declared bases - // - Both expose a `this` variable in their body - // - class AggTypeDeclBase : public ContainerDecl - { - public: - }; + struct FieldDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(StructField) + }; + + // An `AggTypeDeclBase` captures the shared functionality + // between true aggregate type declarations and extension + // declarations: + // + // - Both can container members (they are `ContainerDecl`s) + // - Both can have declared bases + // - Both expose a `this` variable in their body + // + class AggTypeDeclBase : public ContainerDecl + { + public: + }; - struct AggTypeDeclBaseRef : ContainerDeclRef - { - SLANG_DECLARE_DECL_REF(AggTypeDeclBase); - }; + struct AggTypeDeclBaseRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(AggTypeDeclBase); + }; - // An extension to apply to an existing type - class ExtensionDecl : public AggTypeDeclBase - { - public: - TypeExp targetType; + // An extension to apply to an existing type + class ExtensionDecl : public AggTypeDeclBase + { + public: + TypeExp targetType; - // next extension attached to the same nominal type - ExtensionDecl* nextCandidateExtension = nullptr; + // next extension attached to the same nominal type + ExtensionDecl* nextCandidateExtension = nullptr; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - struct ExtensionDeclRef : AggTypeDeclBaseRef - { - SLANG_DECLARE_DECL_REF(ExtensionDecl); + struct ExtensionDeclRef : AggTypeDeclBaseRef + { + SLANG_DECLARE_DECL_REF(ExtensionDecl); - RefPtr<ExpressionType> GetTargetType() const { return Substitute(GetDecl()->targetType); } - }; + RefPtr<ExpressionType> GetTargetType() const { return Substitute(GetDecl()->targetType); } + }; - // Declaration of a type that represents some sort of aggregate - class AggTypeDecl : public AggTypeDeclBase - { - public: + // Declaration of a type that represents some sort of aggregate + class AggTypeDecl : public AggTypeDeclBase + { + public: - // extensions that might apply to this declaration - ExtensionDecl* candidateExtensions = nullptr; - FilteredMemberList<StructField> GetFields() - { - return GetMembersOfType<StructField>(); - } - StructField* FindField(String name) + // extensions that might apply to this declaration + ExtensionDecl* candidateExtensions = nullptr; + FilteredMemberList<StructField> GetFields() + { + return GetMembersOfType<StructField>(); + } + StructField* FindField(String name) + { + for (auto field : GetFields()) { - for (auto field : GetFields()) - { - if (field->Name.Content == name) - return field.Ptr(); - } - return nullptr; + if (field->Name.Content == name) + return field.Ptr(); } - int FindFieldIndex(String name) + return nullptr; + } + int FindFieldIndex(String name) + { + int index = 0; + for (auto field : GetFields()) { - int index = 0; - for (auto field : GetFields()) - { - if (field->Name.Content == name) - return index; - index++; - } - return -1; + if (field->Name.Content == name) + return index; + index++; } - }; + return -1; + } + }; - struct AggTypeDeclRef : public AggTypeDeclBaseRef - { - SLANG_DECLARE_DECL_REF(AggTypeDecl); + struct AggTypeDeclRef : public AggTypeDeclBaseRef + { + SLANG_DECLARE_DECL_REF(AggTypeDecl); - ExtensionDecl* GetCandidateExtensions() const { return GetDecl()->candidateExtensions; } - }; - - class StructSyntaxNode : public AggTypeDecl - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + ExtensionDecl* GetCandidateExtensions() const { return GetDecl()->candidateExtensions; } + }; - struct StructDeclRef : public AggTypeDeclRef - { - SLANG_DECLARE_DECL_REF(StructSyntaxNode); + class StructSyntaxNode : public AggTypeDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } - }; + struct StructDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(StructSyntaxNode); - class ClassSyntaxNode : public AggTypeDecl - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } + }; - struct ClassDeclRef : public AggTypeDeclRef - { - SLANG_DECLARE_DECL_REF(ClassSyntaxNode); + class ClassSyntaxNode : public AggTypeDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } - }; + struct ClassDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(ClassSyntaxNode); - // An interface which other types can conform to - class InterfaceDecl : public AggTypeDecl - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + FilteredMemberRefList<FieldDeclRef> GetFields() const { return GetMembersOfType<FieldDeclRef>(); } + }; - struct InterfaceDeclRef : public AggTypeDeclRef - { - SLANG_DECLARE_DECL_REF(InterfaceDecl); - }; + // An interface which other types can conform to + class InterfaceDecl : public AggTypeDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + struct InterfaceDeclRef : public AggTypeDeclRef + { + SLANG_DECLARE_DECL_REF(InterfaceDecl); + }; - // A kind of pseudo-member that represents an explicit - // or implicit inheritance relationship. - // - class InheritanceDecl : public Decl - { - public: - // The type expression as written - TypeExp base; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // A kind of pseudo-member that represents an explicit + // or implicit inheritance relationship. + // + class InheritanceDecl : public Decl + { + public: + // The type expression as written + TypeExp base; - struct InheritanceDeclRef : public DeclRef - { - SLANG_DECLARE_DECL_REF(InheritanceDecl); + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - RefPtr<ExpressionType> getBaseType() { return Substitute(GetDecl()->base.type); } - }; + struct InheritanceDeclRef : public DeclRef + { + SLANG_DECLARE_DECL_REF(InheritanceDecl); - // TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance + RefPtr<ExpressionType> getBaseType() { return Substitute(GetDecl()->base.type); } + }; + // TODO: may eventually need sub-classes for explicit/direct vs. implicit/indirect inheritance - // A declaration that represents a simple (non-aggregate) type - class SimpleTypeDecl : public Decl - { - }; - struct SimpleTypeDeclRef : DeclRef - { - SLANG_DECLARE_DECL_REF(SimpleTypeDecl) - }; + // A declaration that represents a simple (non-aggregate) type + class SimpleTypeDecl : public Decl + { + }; - // A `typedef` declaration - class TypeDefDecl : public SimpleTypeDecl - { - public: - TypeExp Type; + struct SimpleTypeDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(SimpleTypeDecl) + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // A `typedef` declaration + class TypeDefDecl : public SimpleTypeDecl + { + public: + TypeExp Type; - struct TypeDefDeclRef : SimpleTypeDeclRef - { - SLANG_DECLARE_DECL_REF(TypeDefDecl); + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } - }; + struct TypeDefDeclRef : SimpleTypeDeclRef + { + SLANG_DECLARE_DECL_REF(TypeDefDecl); - // A type alias of some kind (e.g., via `typedef`) - class NamedExpressionType : public ExpressionType - { - public: - NamedExpressionType(TypeDefDeclRef declRef) - : declRef(declRef) - {} + RefPtr<ExpressionType> GetType() const { return Substitute(GetDecl()->Type); } + }; - TypeDefDeclRef declRef; + // A type alias of some kind (e.g., via `typedef`) + class NamedExpressionType : public ExpressionType + { + public: + NamedExpressionType(TypeDefDeclRef declRef) + : declRef(declRef) + {} - virtual String ToString() override; + TypeDefDeclRef declRef; - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + virtual String ToString() override; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; - class StatementSyntaxNode : public ModifiableSyntaxNode - { - public: - }; - // A scope for local declarations (e.g., as part of a statement) - class ScopeDecl : public ContainerDecl - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class StatementSyntaxNode : public ModifiableSyntaxNode + { + public: + }; - class ScopeStmt : public StatementSyntaxNode - { - public: - RefPtr<ScopeDecl> scopeDecl; - }; + // A scope for local declarations (e.g., as part of a statement) + class ScopeDecl : public ContainerDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class BlockStatementSyntaxNode : public ScopeStmt - { - public: - List<RefPtr<StatementSyntaxNode>> Statements; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class ScopeStmt : public StatementSyntaxNode + { + public: + RefPtr<ScopeDecl> scopeDecl; + }; - class UnparsedStmt : public StatementSyntaxNode - { - public: - // The tokens that were contained between `{` and `}` - List<Token> tokens; + class BlockStatementSyntaxNode : public ScopeStmt + { + public: + List<RefPtr<StatementSyntaxNode>> Statements; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class UnparsedStmt : public StatementSyntaxNode + { + public: + // The tokens that were contained between `{` and `}` + List<Token> tokens; - class ParameterSyntaxNode : public VarDeclBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - struct ParamDeclRef : VarDeclBaseRef - { - SLANG_DECLARE_DECL_REF(ParameterSyntaxNode); - }; + class ParameterSyntaxNode : public VarDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // Base class for things that have parameter lists and can thus be applied to arguments ("called") - class CallableDecl : public ContainerDecl - { - public: - FilteredMemberList<ParameterSyntaxNode> GetParameters() - { - return GetMembersOfType<ParameterSyntaxNode>(); - } - TypeExp ReturnType; - }; + struct ParamDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(ParameterSyntaxNode); + }; - struct CallableDeclRef : ContainerDeclRef + // Base class for things that have parameter lists and can thus be applied to arguments ("called") + class CallableDecl : public ContainerDecl + { + public: + FilteredMemberList<ParameterSyntaxNode> GetParameters() { - SLANG_DECLARE_DECL_REF(CallableDecl); - - RefPtr<ExpressionType> GetResultType() const - { - return Substitute(GetDecl()->ReturnType.type.Ptr()); - } - - FilteredMemberRefList<ParamDeclRef> GetParameters() - { - return GetMembersOfType<ParamDeclRef>(); - } - }; + return GetMembersOfType<ParameterSyntaxNode>(); + } + TypeExp ReturnType; + }; - // Base class for callable things that may also have a body that is evaluated to produce their result - class FunctionDeclBase : public CallableDecl - { - public: - RefPtr<StatementSyntaxNode> Body; - }; + struct CallableDeclRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(CallableDecl); - struct FuncDeclBaseRef : CallableDeclRef + RefPtr<ExpressionType> GetResultType() const { - SLANG_DECLARE_DECL_REF(FunctionDeclBase); - }; + return Substitute(GetDecl()->ReturnType.type.Ptr()); + } - // Function types are currently used for references to symbols that name - // either ordinary functions, or "component functions." - // We do not directly store a representation of the type, and instead - // use a reference to the symbol to stand in for its logical type - class FuncType : public ExpressionType + FilteredMemberRefList<ParamDeclRef> GetParameters() { - public: - CallableDeclRef declRef; - - virtual String ToString() override; - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual ExpressionType* CreateCanonicalType() override; - virtual int GetHashCode() override; - }; + return GetMembersOfType<ParamDeclRef>(); + } + }; - // A constructor/initializer to create instances of a type - class ConstructorDecl : public FunctionDeclBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // Base class for callable things that may also have a body that is evaluated to produce their result + class FunctionDeclBase : public CallableDecl + { + public: + RefPtr<StatementSyntaxNode> Body; + }; - struct ConstructorDeclRef : FuncDeclBaseRef - { - SLANG_DECLARE_DECL_REF(ConstructorDecl); - }; + struct FuncDeclBaseRef : CallableDeclRef + { + SLANG_DECLARE_DECL_REF(FunctionDeclBase); + }; + + // Function types are currently used for references to symbols that name + // either ordinary functions, or "component functions." + // We do not directly store a representation of the type, and instead + // use a reference to the symbol to stand in for its logical type + class FuncType : public ExpressionType + { + public: + CallableDeclRef declRef; + + virtual String ToString() override; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual ExpressionType* CreateCanonicalType() override; + virtual int GetHashCode() override; + }; + + // A constructor/initializer to create instances of a type + class ConstructorDecl : public FunctionDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A subscript operation used to index instances of a type - class SubscriptDecl : public CallableDecl - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + struct ConstructorDeclRef : FuncDeclBaseRef + { + SLANG_DECLARE_DECL_REF(ConstructorDecl); + }; - struct SubscriptDeclRef : CallableDeclRef - { - SLANG_DECLARE_DECL_REF(SubscriptDecl); - }; + // A subscript operation used to index instances of a type + class SubscriptDecl : public CallableDecl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // An "accessor" for a subscript or property - class AccessorDecl : public FunctionDeclBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + struct SubscriptDeclRef : CallableDeclRef + { + SLANG_DECLARE_DECL_REF(SubscriptDecl); + }; - class GetterDecl : public AccessorDecl - { - }; + // An "accessor" for a subscript or property + class AccessorDecl : public FunctionDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class SetterDecl : public AccessorDecl - { - }; + class GetterDecl : public AccessorDecl + { + }; - // + class SetterDecl : public AccessorDecl + { + }; - class FunctionSyntaxNode : public FunctionDeclBase - { - public: - String InternalName; - bool IsInline() { return HasModifier<InlineModifier>(); } - bool IsExtern() { return HasModifier<ExternModifier>(); } - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - FunctionSyntaxNode() - { - } - }; + // - struct FuncDeclRef : FuncDeclBaseRef + class FunctionSyntaxNode : public FunctionDeclBase + { + public: + String InternalName; + bool IsInline() { return HasModifier<InlineModifier>(); } + bool IsExtern() { return HasModifier<ExternModifier>(); } + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + FunctionSyntaxNode() { - SLANG_DECLARE_DECL_REF(FunctionSyntaxNode); - }; + } + }; + struct FuncDeclRef : FuncDeclBaseRef + { + SLANG_DECLARE_DECL_REF(FunctionSyntaxNode); + }; - struct Scope : public RefObject - { - // The parent of this scope (where lookup should go if nothing is found locally) - RefPtr<Scope> parent; - // The next sibling of this scope (a peer for lookup) - RefPtr<Scope> nextSibling; + struct Scope : public RefObject + { + // The parent of this scope (where lookup should go if nothing is found locally) + RefPtr<Scope> parent; - // The container to use for lookup - // - // Note(tfoley): This is kept as an unowned pointer - // so that a scope can't keep parts of the AST alive, - // but the opposite it allowed. - ContainerDecl* containerDecl; - }; + // The next sibling of this scope (a peer for lookup) + RefPtr<Scope> nextSibling; - // Base class for expressions that will reference declarations - class DeclRefExpr : public ExpressionSyntaxNode - { - public: - // The scope in which to perform lookup - RefPtr<Scope> scope; + // The container to use for lookup + // + // Note(tfoley): This is kept as an unowned pointer + // so that a scope can't keep parts of the AST alive, + // but the opposite it allowed. + ContainerDecl* containerDecl; + }; + + // Base class for expressions that will reference declarations + class DeclRefExpr : public ExpressionSyntaxNode + { + public: + // The scope in which to perform lookup + RefPtr<Scope> scope; - // The declaration of the symbol being referenced - DeclRef declRef; - }; + // The declaration of the symbol being referenced + DeclRef declRef; + }; - class VarExpressionSyntaxNode : public DeclRefExpr - { - public: - // The name of the symbol being referenced - String Variable; + class VarExpressionSyntaxNode : public DeclRefExpr + { + public: + // The name of the symbol being referenced + String Variable; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // Masks to be applied when lookup up declarations - enum class LookupMask : uint8_t - { - Type = 0x1, - Function = 0x2, - Value = 0x4, + // Masks to be applied when lookup up declarations + enum class LookupMask : uint8_t + { + Type = 0x1, + Function = 0x2, + Value = 0x4, - All = Type | Function | Value, - }; + All = Type | Function | Value, + }; - // Represents one item found during lookup - struct LookupResultItem + // Represents one item found during lookup + struct LookupResultItem + { + // Sometimes lookup finds an item, but there were additional + // "hops" taken to reach it. We need to remember these steps + // so that if/when we consturct a full expression we generate + // appropriate AST nodes for all the steps. + // + // We build up a list of these "breadcrumbs" while doing + // lookup, and store them alongside each item found. + class Breadcrumb : public RefObject { - // Sometimes lookup finds an item, but there were additional - // "hops" taken to reach it. We need to remember these steps - // so that if/when we consturct a full expression we generate - // appropriate AST nodes for all the steps. - // - // We build up a list of these "breadcrumbs" while doing - // lookup, and store them alongside each item found. - class Breadcrumb : public RefObject + public: + enum class Kind { - public: - enum class Kind - { - Member, // A member was references - Deref, // A value with pointer(-like) type was dereferenced - }; - - Kind kind; - DeclRef declRef; - RefPtr<Breadcrumb> next; - - Breadcrumb(Kind kind, DeclRef declRef, RefPtr<Breadcrumb> next) - : kind(kind) - , declRef(declRef) - , next(next) - {} + Member, // A member was references + Deref, // A value with pointer(-like) type was dereferenced }; - // A properly-specialized reference to the declaration that was found. + Kind kind; DeclRef declRef; + RefPtr<Breadcrumb> next; - // Any breadcrumbs needed in order to turn that declaration - // reference into a well-formed expression. - // - // This is unused in the simple case where a declaration - // is being referenced directly (rather than through - // transparent members). - RefPtr<Breadcrumb> breadcrumbs; - - LookupResultItem() = default; - explicit LookupResultItem(DeclRef declRef) - : declRef(declRef) - {} - LookupResultItem(DeclRef declRef, RefPtr<Breadcrumb> breadcrumbs) - : declRef(declRef) - , breadcrumbs(breadcrumbs) + Breadcrumb(Kind kind, DeclRef declRef, RefPtr<Breadcrumb> next) + : kind(kind) + , declRef(declRef) + , next(next) {} }; + // A properly-specialized reference to the declaration that was found. + DeclRef declRef; - // Result of looking up a name in some lexical/semantic environment. - // Can be used to enumerate all the declarations matching that name, - // in the case where the result is overloaded. - struct LookupResult - { - // The one item that was found, in the smple case - LookupResultItem item; + // Any breadcrumbs needed in order to turn that declaration + // reference into a well-formed expression. + // + // This is unused in the simple case where a declaration + // is being referenced directly (rather than through + // transparent members). + RefPtr<Breadcrumb> breadcrumbs; + + LookupResultItem() = default; + explicit LookupResultItem(DeclRef declRef) + : declRef(declRef) + {} + LookupResultItem(DeclRef declRef, RefPtr<Breadcrumb> breadcrumbs) + : declRef(declRef) + , breadcrumbs(breadcrumbs) + {} + }; + + + // Result of looking up a name in some lexical/semantic environment. + // Can be used to enumerate all the declarations matching that name, + // in the case where the result is overloaded. + struct LookupResult + { + // The one item that was found, in the smple case + LookupResultItem item; - // All of the items that were found, in the complex case. - // Note: if there was no overloading, then this list isn't - // used at all, to avoid allocation. - List<LookupResultItem> items; + // All of the items that were found, in the complex case. + // Note: if there was no overloading, then this list isn't + // used at all, to avoid allocation. + List<LookupResultItem> items; - // Was at least one result found? - bool isValid() const { return item.declRef.GetDecl() != nullptr; } + // Was at least one result found? + bool isValid() const { return item.declRef.GetDecl() != nullptr; } - bool isOverloaded() const { return items.Count() > 1; } - }; + bool isOverloaded() const { return items.Count() > 1; } + }; - struct LookupRequest - { - RefPtr<Scope> scope = nullptr; - RefPtr<Scope> endScope = nullptr; + struct LookupRequest + { + RefPtr<Scope> scope = nullptr; + RefPtr<Scope> endScope = nullptr; - LookupMask mask = LookupMask::All; - }; + LookupMask mask = LookupMask::All; + }; - // An expression that references an overloaded set of declarations - // having the same name. - class OverloadedExpr : public ExpressionSyntaxNode - { - public: - // Optional: the base expression is this overloaded result - // arose from a member-reference expression. - RefPtr<ExpressionSyntaxNode> base; + // An expression that references an overloaded set of declarations + // having the same name. + class OverloadedExpr : public ExpressionSyntaxNode + { + public: + // Optional: the base expression is this overloaded result + // arose from a member-reference expression. + RefPtr<ExpressionSyntaxNode> base; - // The lookup result that was ambiguous - LookupResult lookupResult2; + // The lookup result that was ambiguous + LookupResult lookupResult2; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - typedef double FloatingPointLiteralValue; + typedef double FloatingPointLiteralValue; - class ConstantExpressionSyntaxNode : public ExpressionSyntaxNode + class ConstantExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + enum class ConstantType { - public: - enum class ConstantType - { - Int, Bool, Float - }; - ConstantType ConstType; - union - { - int IntValue; - FloatingPointLiteralValue FloatValue; - }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + Int, Bool, Float }; - - enum class Operator - { - Neg, Not, BitNot, PreInc, PreDec, PostInc, PostDec, - Mul, Div, Mod, - Add, Sub, - Lsh, Rsh, - Eql, Neq, Greater, Less, Geq, Leq, - BitAnd, BitXor, BitOr, - And, - Or, - Sequence, - Select, - Assign = 200, AddAssign, SubAssign, MulAssign, DivAssign, ModAssign, - LshAssign, RshAssign, OrAssign, AndAssign, XorAssign, + ConstantType ConstType; + union + { + int IntValue; + FloatingPointLiteralValue FloatValue; }; - String GetOperatorFunctionName(Operator op); - String OperatorToString(Operator op); + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // An initializer list, e.g. `{ 1, 2, 3 }` - class InitializerListExpr : public ExpressionSyntaxNode - { - public: - List<RefPtr<ExpressionSyntaxNode>> args; + enum class Operator + { + Neg, Not, BitNot, PreInc, PreDec, PostInc, PostDec, + Mul, Div, Mod, + Add, Sub, + Lsh, Rsh, + Eql, Neq, Greater, Less, Geq, Leq, + BitAnd, BitXor, BitOr, + And, + Or, + Sequence, + Select, + Assign = 200, AddAssign, SubAssign, MulAssign, DivAssign, ModAssign, + LshAssign, RshAssign, OrAssign, AndAssign, XorAssign, + }; + String GetOperatorFunctionName(Operator op); + String OperatorToString(Operator op); + + // An initializer list, e.g. `{ 1, 2, 3 }` + class InitializerListExpr : public ExpressionSyntaxNode + { + public: + List<RefPtr<ExpressionSyntaxNode>> args; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A base expression being applied to arguments: covers - // both ordinary `()` function calls and `<>` generic application - class AppExprBase : public ExpressionSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> FunctionExpr; - List<RefPtr<ExpressionSyntaxNode>> Arguments; - }; + // A base expression being applied to arguments: covers + // both ordinary `()` function calls and `<>` generic application + class AppExprBase : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> FunctionExpr; + List<RefPtr<ExpressionSyntaxNode>> Arguments; + }; - class InvokeExpressionSyntaxNode : public AppExprBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class InvokeExpressionSyntaxNode : public AppExprBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class OperatorExpressionSyntaxNode : public InvokeExpressionSyntaxNode - { - public: + class OperatorExpressionSyntaxNode : public InvokeExpressionSyntaxNode + { + public: // Operator Operator; -// void SetOperator(RefPtr<Scope> scope, Slang::Compiler::Operator op); - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; +// void SetOperator(RefPtr<Scope> scope, Slang::Operator op); + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class InfixExpr : public OperatorExpressionSyntaxNode {}; - class PrefixExpr : public OperatorExpressionSyntaxNode {}; - class PostfixExpr : public OperatorExpressionSyntaxNode {}; + class InfixExpr : public OperatorExpressionSyntaxNode {}; + class PrefixExpr : public OperatorExpressionSyntaxNode {}; + class PostfixExpr : public OperatorExpressionSyntaxNode {}; - class IndexExpressionSyntaxNode : public ExpressionSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> BaseExpression; - RefPtr<ExpressionSyntaxNode> IndexExpression; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class IndexExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> BaseExpression; + RefPtr<ExpressionSyntaxNode> IndexExpression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class MemberExpressionSyntaxNode : public DeclRefExpr - { - public: - RefPtr<ExpressionSyntaxNode> BaseExpression; - String MemberName; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class MemberExpressionSyntaxNode : public DeclRefExpr + { + public: + RefPtr<ExpressionSyntaxNode> BaseExpression; + String MemberName; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class SwizzleExpr : public ExpressionSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> base; - int elementCount; - int elementIndices[4]; + class SwizzleExpr : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> base; + int elementCount; + int elementIndices[4]; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A dereference of a pointer or pointer-like type - class DerefExpr : public ExpressionSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> base; + // A dereference of a pointer or pointer-like type + class DerefExpr : public ExpressionSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> base; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class TypeCastExpressionSyntaxNode : public ExpressionSyntaxNode - { - public: - TypeExp TargetType; - RefPtr<ExpressionSyntaxNode> Expression; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class TypeCastExpressionSyntaxNode : public ExpressionSyntaxNode + { + public: + TypeExp TargetType; + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class SelectExpressionSyntaxNode : public OperatorExpressionSyntaxNode - { - public: - }; + class SelectExpressionSyntaxNode : public OperatorExpressionSyntaxNode + { + public: + }; - class EmptyStatementSyntaxNode : public StatementSyntaxNode - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class EmptyStatementSyntaxNode : public StatementSyntaxNode + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class DiscardStatementSyntaxNode : public StatementSyntaxNode - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class DiscardStatementSyntaxNode : public StatementSyntaxNode + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - struct Variable : public VarDeclBase - { - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + struct Variable : public VarDeclBase + { + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class VarDeclrStatementSyntaxNode : public StatementSyntaxNode + class VarDeclrStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<DeclBase> decl; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A "module" of code (essentiately, a single translation unit) + // that provides a scope for some number of declarations. + class ProgramSyntaxNode : public ContainerDecl + { + public: + // Access members of specific types + FilteredMemberList<FunctionSyntaxNode> GetFunctions() { - public: - RefPtr<DeclBase> decl; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + return GetMembersOfType<FunctionSyntaxNode>(); + } - // A "module" of code (essentiately, a single translation unit) - // that provides a scope for some number of declarations. - class ProgramSyntaxNode : public ContainerDecl + FilteredMemberList<ClassSyntaxNode> GetClasses() { - public: - // Access members of specific types - FilteredMemberList<FunctionSyntaxNode> GetFunctions() - { - return GetMembersOfType<FunctionSyntaxNode>(); - } + return GetMembersOfType<ClassSyntaxNode>(); + } + FilteredMemberList<StructSyntaxNode> GetStructs() + { + return GetMembersOfType<StructSyntaxNode>(); + } + FilteredMemberList<TypeDefDecl> GetTypeDefs() + { + return GetMembersOfType<TypeDefDecl>(); + } - FilteredMemberList<ClassSyntaxNode> GetClasses() - { - return GetMembersOfType<ClassSyntaxNode>(); - } - FilteredMemberList<StructSyntaxNode> GetStructs() - { - return GetMembersOfType<StructSyntaxNode>(); - } - FilteredMemberList<TypeDefDecl> GetTypeDefs() - { - return GetMembersOfType<TypeDefDecl>(); - } + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class ImportDecl : public Decl + { + public: + // The name of the module we are trying to import + Token nameToken; - class ImportDecl : public Decl - { - public: - // The name of the module we are trying to import - Token nameToken; + // The scope that we want to import into + RefPtr<Scope> scope; - // The scope that we want to import into - RefPtr<Scope> scope; + // The module that actually got imported + RefPtr<ProgramSyntaxNode> importedModuleDecl; - // The module that actually got imported - RefPtr<ProgramSyntaxNode> importedModuleDecl; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class IfStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Predicate; + RefPtr<StatementSyntaxNode> PositiveStatement; + RefPtr<StatementSyntaxNode> NegativeStatement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // A statement that can be escaped with a `break` + class BreakableStmt : public ScopeStmt + {}; + + class SwitchStmt : public BreakableStmt + { + public: + RefPtr<ExpressionSyntaxNode> condition; + RefPtr<StatementSyntaxNode> body; - class IfStatementSyntaxNode : public StatementSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> Predicate; - RefPtr<StatementSyntaxNode> PositiveStatement; - RefPtr<StatementSyntaxNode> NegativeStatement; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A statement that can be escaped with a `break` - class BreakableStmt : public ScopeStmt - {}; + // A statement that is expected to appear lexically nested inside + // some other construct, and thus needs to keep track of the + // outer statement that it is associated with... + class ChildStmt : public StatementSyntaxNode + { + public: + StatementSyntaxNode* parentStmt = nullptr; + }; + + // a `case` or `default` statement inside a `switch` + // + // Note(tfoley): A correct AST for a C-like language would treat + // these as a labelled statement, and so they would contain a + // sub-statement. I'm leaving that out for now for simplicity. + class CaseStmtBase : public ChildStmt + { + public: + }; - class SwitchStmt : public BreakableStmt - { - public: - RefPtr<ExpressionSyntaxNode> condition; - RefPtr<StatementSyntaxNode> body; + // a `case` statement inside a `switch` + class CaseStmt : public CaseStmtBase + { + public: + RefPtr<ExpressionSyntaxNode> expr; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A statement that is expected to appear lexically nested inside - // some other construct, and thus needs to keep track of the - // outer statement that it is associated with... - class ChildStmt : public StatementSyntaxNode - { - public: - StatementSyntaxNode* parentStmt = nullptr; - }; + // a `default` statement inside a `switch` + class DefaultStmt : public CaseStmtBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // a `case` or `default` statement inside a `switch` - // - // Note(tfoley): A correct AST for a C-like language would treat - // these as a labelled statement, and so they would contain a - // sub-statement. I'm leaving that out for now for simplicity. - class CaseStmtBase : public ChildStmt - { - public: - }; + // A statement that represents a loop, and can thus be escaped with a `continue` + class LoopStmt : public BreakableStmt + {}; - // a `case` statement inside a `switch` - class CaseStmt : public CaseStmtBase - { - public: - RefPtr<ExpressionSyntaxNode> expr; + class ForStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<StatementSyntaxNode> InitialStatement; + RefPtr<ExpressionSyntaxNode> SideEffectExpression, PredicateExpression; + RefPtr<StatementSyntaxNode> Statement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + class WhileStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<ExpressionSyntaxNode> Predicate; + RefPtr<StatementSyntaxNode> Statement; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class DoWhileStatementSyntaxNode : public LoopStmt + { + public: + RefPtr<StatementSyntaxNode> Statement; + RefPtr<ExpressionSyntaxNode> Predicate; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // The case of child statements that do control flow relative + // to their parent statement. + class JumpStmt : public ChildStmt + { + public: + StatementSyntaxNode* parentStmt = nullptr; + }; - // a `default` statement inside a `switch` - class DefaultStmt : public CaseStmtBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class BreakStatementSyntaxNode : public JumpStmt + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // A statement that represents a loop, and can thus be escaped with a `continue` - class LoopStmt : public BreakableStmt - {}; + class ContinueStatementSyntaxNode : public JumpStmt + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class ForStatementSyntaxNode : public LoopStmt - { - public: - RefPtr<StatementSyntaxNode> InitialStatement; - RefPtr<ExpressionSyntaxNode> SideEffectExpression, PredicateExpression; - RefPtr<StatementSyntaxNode> Statement; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class ReturnStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class WhileStatementSyntaxNode : public LoopStmt - { - public: - RefPtr<ExpressionSyntaxNode> Predicate; - RefPtr<StatementSyntaxNode> Statement; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + class ExpressionStatementSyntaxNode : public StatementSyntaxNode + { + public: + RefPtr<ExpressionSyntaxNode> Expression; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // Note(tfoley): Moved this further down in the file because it depends on + // `ExpressionSyntaxNode` and a forward reference just isn't good enough + // for `RefPtr`. + // + class GenericAppExpr : public AppExprBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class DoWhileStatementSyntaxNode : public LoopStmt - { - public: - RefPtr<StatementSyntaxNode> Statement; - RefPtr<ExpressionSyntaxNode> Predicate; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // An expression representing re-use of the syntax for a type in more + // than once conceptually-distinct declaration + class SharedTypeExpr : public ExpressionSyntaxNode + { + public: + // The underlying type expression that we want to share + TypeExp base; - // The case of child statements that do control flow relative - // to their parent statement. - class JumpStmt : public ChildStmt - { - public: - StatementSyntaxNode* parentStmt = nullptr; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - class BreakStatementSyntaxNode : public JumpStmt - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; - class ContinueStatementSyntaxNode : public JumpStmt - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // A modifier that indicates a built-in base type (e.g., `float`) + class BuiltinTypeModifier : public Modifier + { + public: + BaseType tag; + }; + + // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) + // + // TODO(tfoley): This deserves a better name than "magic" + class MagicTypeModifier : public Modifier + { + public: + String name; + uint32_t tag; + }; - class ReturnStatementSyntaxNode : public StatementSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> Expression; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // Modifiers that affect the storage layout for matrices + class MatrixLayoutModifier : public Modifier {}; - class ExpressionStatementSyntaxNode : public StatementSyntaxNode - { - public: - RefPtr<ExpressionSyntaxNode> Expression; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // Modifiers that specify row- and column-major layout, respectively + class RowMajorLayoutModifier : public MatrixLayoutModifier {}; + class ColumnMajorLayoutModifier : public MatrixLayoutModifier {}; - // Note(tfoley): Moved this further down in the file because it depends on - // `ExpressionSyntaxNode` and a forward reference just isn't good enough - // for `RefPtr`. - // - class GenericAppExpr : public AppExprBase - { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // The HLSL flavor of those modifiers + class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier {}; + class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier {}; - // An expression representing re-use of the syntax for a type in more - // than once conceptually-distinct declaration - class SharedTypeExpr : public ExpressionSyntaxNode - { - public: - // The underlying type expression that we want to share - TypeExp base; + // The GLSL flavor of those modifiers + // + // Note(tfoley): The GLSL versions of these modifiers are "backwards" + // in the sense that when a GLSL programmer requests row-major layout, + // we actually interpret that as requesting column-major. This makes + // sense because we interpret matrix conventions backwards from how + // GLSL specifies them. + class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier {}; + class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier {}; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // More HLSL Keyword + // HLSL `nointerpolation` modifier + class HLSLNoInterpolationModifier : public Modifier {}; - // A modifier that indicates a built-in base type (e.g., `float`) - class BuiltinTypeModifier : public Modifier - { - public: - BaseType tag; - }; + // HLSL `linear` modifier + class HLSLLinearModifier : public Modifier {}; - // A modifier that indicates a built-in type that isn't a base type (e.g., `vector`) - // - // TODO(tfoley): This deserves a better name than "magic" - class MagicTypeModifier : public Modifier - { - public: - String name; - uint32_t tag; - }; + // HLSL `sample` modifier + class HLSLSampleModifier : public Modifier {}; - // Modifiers that affect the storage layout for matrices - class MatrixLayoutModifier : public Modifier {}; + // HLSL `centroid` modifier + class HLSLCentroidModifier : public Modifier {}; - // Modifiers that specify row- and column-major layout, respectively - class RowMajorLayoutModifier : public MatrixLayoutModifier {}; - class ColumnMajorLayoutModifier : public MatrixLayoutModifier {}; + // HLSL `precise` modifier + class HLSLPreciseModifier : public Modifier {}; - // The HLSL flavor of those modifiers - class HLSLRowMajorLayoutModifier : public RowMajorLayoutModifier {}; - class HLSLColumnMajorLayoutModifier : public ColumnMajorLayoutModifier {}; + // HLSL `shared` modifier (which is used by the effect system, + // and shouldn't be confused with `groupshared`) + class HLSLEffectSharedModifier : public Modifier {}; - // The GLSL flavor of those modifiers - // - // Note(tfoley): The GLSL versions of these modifiers are "backwards" - // in the sense that when a GLSL programmer requests row-major layout, - // we actually interpret that as requesting column-major. This makes - // sense because we interpret matrix conventions backwards from how - // GLSL specifies them. - class GLSLRowMajorLayoutModifier : public ColumnMajorLayoutModifier {}; - class GLSLColumnMajorLayoutModifier : public RowMajorLayoutModifier {}; + // HLSL `groupshared` modifier + class HLSLGroupSharedModifier : public Modifier {}; - // More HLSL Keyword + // HLSL `static` modifier (probably doesn't need to be + // treated as HLSL-specific) + class HLSLStaticModifier : public Modifier {}; - // HLSL `nointerpolation` modifier - class HLSLNoInterpolationModifier : public Modifier {}; + // HLSL `uniform` modifier (distinct meaning from GLSL + // use of the keyword) + class HLSLUniformModifier : public Modifier {}; - // HLSL `linear` modifier - class HLSLLinearModifier : public Modifier {}; + // HLSL `volatile` modifier (ignored) + class HLSLVolatileModifier : public Modifier {}; - // HLSL `sample` modifier - class HLSLSampleModifier : public Modifier {}; + // An HLSL `[name(arg0, ...)]` style attribute. + class HLSLAttribute : public Modifier + { + public: + Token nameToken; + List<RefPtr<ExpressionSyntaxNode>> args; + }; + + // An HLSL `[name(...)]` attribute that hasn't undergone + // any semantic analysis. + // After analysis, this might be transformed into a more specific case. + class HLSLUncheckedAttribute : public HLSLAttribute + { + public: + }; - // HLSL `centroid` modifier - class HLSLCentroidModifier : public Modifier {}; + // An HLSL `[numthreads(x,y,z)]` attribute + class HLSLNumThreadsAttribute : public HLSLAttribute + { + public: + // The number of threads to use along each axis + int32_t x; + int32_t y; + int32_t z; + }; + + // HLSL modifiers for geometry shader input topology + class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier {}; + class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; + + // + + // A generic declaration, parameterized on types/values + class GenericDecl : public ContainerDecl + { + public: + // The decl that is genericized... + RefPtr<Decl> inner; - // HLSL `precise` modifier - class HLSLPreciseModifier : public Modifier {}; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - // HLSL `shared` modifier (which is used by the effect system, - // and shouldn't be confused with `groupshared`) - class HLSLEffectSharedModifier : public Modifier {}; + struct GenericDeclRef : ContainerDeclRef + { + SLANG_DECLARE_DECL_REF(GenericDecl); - // HLSL `groupshared` modifier - class HLSLGroupSharedModifier : public Modifier {}; + Decl* GetInner() const { return GetDecl()->inner.Ptr(); } + }; - // HLSL `static` modifier (probably doesn't need to be - // treated as HLSL-specific) - class HLSLStaticModifier : public Modifier {}; + // The "type" of an expression that names a generic declaration. + class GenericDeclRefType : public ExpressionType + { + public: + GenericDeclRefType(GenericDeclRef declRef) + : declRef(declRef) + {} - // HLSL `uniform` modifier (distinct meaning from GLSL - // use of the keyword) - class HLSLUniformModifier : public Modifier {}; + GenericDeclRef declRef; + GenericDeclRef const& GetDeclRef() const { return declRef; } - // HLSL `volatile` modifier (ignored) - class HLSLVolatileModifier : public Modifier {}; + virtual String ToString() override; - // An HLSL `[name(arg0, ...)]` style attribute. - class HLSLAttribute : public Modifier - { - public: - Token nameToken; - List<RefPtr<ExpressionSyntaxNode>> args; - }; - - // An HLSL `[name(...)]` attribute that hasn't undergone - // any semantic analysis. - // After analysis, this might be transformed into a more specific case. - class HLSLUncheckedAttribute : public HLSLAttribute - { - public: - }; + protected: + virtual bool EqualsImpl(ExpressionType * type) override; + virtual int GetHashCode() override; + virtual ExpressionType* CreateCanonicalType() override; + }; - // An HLSL `[numthreads(x,y,z)]` attribute - class HLSLNumThreadsAttribute : public HLSLAttribute - { - public: - // The number of threads to use along each axis - int32_t x; - int32_t y; - int32_t z; - }; - // HLSL modifiers for geometry shader input topology - class HLSLGeometryShaderInputPrimitiveTypeModifier : public Modifier {}; - class HLSLPointModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; - class HLSLLineModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; - class HLSLTriangleModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; - class HLSLLineAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; - class HLSLTriangleAdjModifier : public HLSLGeometryShaderInputPrimitiveTypeModifier {}; - // + class GenericTypeParamDecl : public SimpleTypeDecl + { + public: + // The bound for the type parameter represents a trait that any + // type used as this parameter must conform to +// TypeExp bound; - // A generic declaration, parameterized on types/values - class GenericDecl : public ContainerDecl - { - public: - // The decl that is genericized... - RefPtr<Decl> inner; + // The "initializer" for the parameter represents a default value + TypeExp initType; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - struct GenericDeclRef : ContainerDeclRef - { - SLANG_DECLARE_DECL_REF(GenericDecl); + struct GenericTypeParamDeclRef : SimpleTypeDeclRef + { + SLANG_DECLARE_DECL_REF(GenericTypeParamDecl); + }; - Decl* GetInner() const { return GetDecl()->inner.Ptr(); } - }; + // A constraint placed as part of a generic declaration + class GenericTypeConstraintDecl : public Decl + { + public: + // A type constraint like `T : U` is constraining `T` to be "below" `U` + // on a lattice of types. This may not be a subtyping relationship + // per se, but it makes sense to use that terminology here, so we + // think of these fields as the sub-type and sup-ertype, respectively. + TypeExp sub; + TypeExp sup; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + struct GenericTypeConstraintDeclRef : DeclRef + { + SLANG_DECLARE_DECL_REF(GenericTypeConstraintDecl); - // The "type" of an expression that names a generic declaration. - class GenericDeclRefType : public ExpressionType - { - public: - GenericDeclRefType(GenericDeclRef declRef) - : declRef(declRef) - {} + RefPtr<ExpressionType> GetSub() { return Substitute(GetDecl()->sub); } + RefPtr<ExpressionType> GetSup() { return Substitute(GetDecl()->sup); } + }; - GenericDeclRef declRef; - GenericDeclRef const& GetDeclRef() const { return declRef; } - virtual String ToString() override; + class GenericValueParamDecl : public VarDeclBase + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - protected: - virtual bool EqualsImpl(ExpressionType * type) override; - virtual int GetHashCode() override; - virtual ExpressionType* CreateCanonicalType() override; - }; + struct GenericValueParamDeclRef : VarDeclBaseRef + { + SLANG_DECLARE_DECL_REF(GenericValueParamDecl); + }; + // The logical "value" of a rererence to a generic value parameter + class GenericParamIntVal : public IntVal + { + public: + VarDeclBaseRef declRef; + GenericParamIntVal(VarDeclBaseRef declRef) + : declRef(declRef) + {} - class GenericTypeParamDecl : public SimpleTypeDecl - { - public: - // The bound for the type parameter represents a trait that any - // type used as this parameter must conform to -// TypeExp bound; + virtual bool EqualsVal(Val* val) override; + virtual String ToString() override; + virtual int GetHashCode() override; + virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + }; - // The "initializer" for the parameter represents a default value - TypeExp initType; + // Declaration of a user-defined modifier + class ModifierDecl : public Decl + { + public: + // The name of the C++ class to instantiate + // (this is a reference to a class in the compiler source code, + // and not the user's source code) + Token classNameToken; + + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; + + // An empty declaration (which might still have modifiers attached). + // + // An empty declaration is uncommon in HLSL, but + // in GLSL it is often used at the global scope + // to declare metadata that logically belongs + // to the entry point, e.g.: + // + // layout(local_size_x = 16) in; + // + class EmptyDecl : public Decl + { + public: + virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; + }; - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + // - struct GenericTypeParamDeclRef : SimpleTypeDeclRef - { - SLANG_DECLARE_DECL_REF(GenericTypeParamDecl); - }; + class SyntaxVisitor : public Object + { + protected: + DiagnosticSink * sink = nullptr; + DiagnosticSink* getSink() { return sink; } - // A constraint placed as part of a generic declaration - class GenericTypeConstraintDecl : public Decl + SourceLanguage sourceLanguage = SourceLanguage::Unknown; + public: + void setSourceLanguage(SourceLanguage language) { - public: - // A type constraint like `T : U` is constraining `T` to be "below" `U` - // on a lattice of types. This may not be a subtyping relationship - // per se, but it makes sense to use that terminology here, so we - // think of these fields as the sub-type and sup-ertype, respectively. - TypeExp sub; - TypeExp sup; - - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; + sourceLanguage = language; + } - struct GenericTypeConstraintDeclRef : DeclRef + SyntaxVisitor(DiagnosticSink * sink) + : sink(sink) + {} + virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode* program) { - SLANG_DECLARE_DECL_REF(GenericTypeConstraintDecl); - - RefPtr<ExpressionType> GetSub() { return Substitute(GetDecl()->sub); } - RefPtr<ExpressionType> GetSup() { return Substitute(GetDecl()->sup); } - }; + for (auto & m : program->Members) + m = m->Accept(this).As<Decl>(); + return program; + } + virtual void visitImportDecl(ImportDecl * decl) = 0; - class GenericValueParamDecl : public VarDeclBase + virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode* func) { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; - - struct GenericValueParamDeclRef : VarDeclBaseRef + func->ReturnType = func->ReturnType.Accept(this); + for (auto & member : func->Members) + member = member->Accept(this).As<Decl>(); + if (func->Body) + func->Body = func->Body->Accept(this).As<BlockStatementSyntaxNode>(); + return func; + } + virtual RefPtr<ScopeDecl> VisitScopeDecl(ScopeDecl* decl) { - SLANG_DECLARE_DECL_REF(GenericValueParamDecl); - }; - - // The logical "value" of a rererence to a generic value parameter - class GenericParamIntVal : public IntVal + // By default don't visit children, because they will always + // be encountered in the ordinary flow of the corresponding statement. + return decl; + } + virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * s) { - public: - VarDeclBaseRef declRef; - - GenericParamIntVal(VarDeclBaseRef declRef) - : declRef(declRef) - {} - - virtual bool EqualsVal(Val* val) override; - virtual String ToString() override; - virtual int GetHashCode() override; - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - }; - - // Declaration of a user-defined modifier - class ModifierDecl : public Decl + for (auto & f : s->Members) + f = f->Accept(this).As<Decl>(); + return s; + } + virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * s) { - public: - // The name of the C++ class to instantiate - // (this is a reference to a class in the compiler source code, - // and not the user's source code) - Token classNameToken; - - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; - - // An empty declaration (which might still have modifiers attached). - // - // An empty declaration is uncommon in HLSL, but - // in GLSL it is often used at the global scope - // to declare metadata that logically belongs - // to the entry point, e.g.: - // - // layout(local_size_x = 16) in; - // - class EmptyDecl : public Decl + for (auto & f : s->Members) + f = f->Accept(this).As<Decl>(); + return s; + } + virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl * decl) { - public: - virtual RefPtr<SyntaxNode> Accept(SyntaxVisitor * visitor) override; - }; - - // - - class SyntaxVisitor : public Object + for (auto & m : decl->Members) + m = m->Accept(this).As<Decl>(); + decl->inner = decl->inner->Accept(this).As<Decl>(); + return decl; + } + virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) { - protected: - DiagnosticSink * sink = nullptr; - DiagnosticSink* getSink() { return sink; } - - SourceLanguage sourceLanguage = SourceLanguage::Unknown; - public: - void setSourceLanguage(SourceLanguage language) - { - sourceLanguage = language; - } - - SyntaxVisitor(DiagnosticSink * sink) - : sink(sink) - {} - virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode* program) - { - for (auto & m : program->Members) - m = m->Accept(this).As<Decl>(); - return program; - } - - virtual void visitImportDecl(ImportDecl * decl) = 0; - - virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode* func) - { - func->ReturnType = func->ReturnType.Accept(this); - for (auto & member : func->Members) - member = member->Accept(this).As<Decl>(); - if (func->Body) - func->Body = func->Body->Accept(this).As<BlockStatementSyntaxNode>(); - return func; - } - virtual RefPtr<ScopeDecl> VisitScopeDecl(ScopeDecl* decl) - { - // By default don't visit children, because they will always - // be encountered in the ordinary flow of the corresponding statement. - return decl; - } - virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * s) - { - for (auto & f : s->Members) - f = f->Accept(this).As<Decl>(); - return s; - } - virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * s) - { - for (auto & f : s->Members) - f = f->Accept(this).As<Decl>(); - return s; - } - virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl * decl) - { - for (auto & m : decl->Members) - m = m->Accept(this).As<Decl>(); - decl->inner = decl->inner->Accept(this).As<Decl>(); - return decl; - } - virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) - { - decl->Type = decl->Type.Accept(this); - return decl; - } - virtual RefPtr<StatementSyntaxNode> VisitDiscardStatement(DiscardStatementSyntaxNode * stmt) - { - return stmt; - } - virtual RefPtr<StructField> VisitStructField(StructField * f) - { - f->Type = f->Type.Accept(this); - return f; - } - virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode* stmt) - { - for (auto & s : stmt->Statements) - s = s->Accept(this).As<StatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode* stmt) - { - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode* stmt) - { - return stmt; - } + decl->Type = decl->Type.Accept(this); + return decl; + } + virtual RefPtr<StatementSyntaxNode> VisitDiscardStatement(DiscardStatementSyntaxNode * stmt) + { + return stmt; + } + virtual RefPtr<StructField> VisitStructField(StructField * f) + { + f->Type = f->Type.Accept(this); + return f; + } + virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode* stmt) + { + for (auto & s : stmt->Statements) + s = s->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode* stmt) + { + return stmt; + } - virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode* stmt) - { - if (stmt->Predicate) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->Statement) - stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitEmptyStatement(EmptyStatementSyntaxNode* stmt) - { - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode* stmt) - { - if (stmt->InitialStatement) - stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); - if (stmt->PredicateExpression) - stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->SideEffectExpression) - stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->Statement) - stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode* stmt) - { - if (stmt->Predicate) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->PositiveStatement) - stmt->PositiveStatement = stmt->PositiveStatement->Accept(this).As<StatementSyntaxNode>(); - if (stmt->NegativeStatement) - stmt->NegativeStatement = stmt->NegativeStatement->Accept(this).As<StatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) - { - if (stmt->condition) - stmt->condition = stmt->condition->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->body) - stmt->body = stmt->body->Accept(this).As<BlockStatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) - { - if (stmt->expr) - stmt->expr = stmt->expr->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) - { - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode* stmt) - { - if (stmt->Expression) - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitVarDeclrStatement(VarDeclrStatementSyntaxNode* stmt) - { - stmt->decl = stmt->decl->Accept(this).As<DeclBase>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode* stmt) - { - if (stmt->Predicate) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (stmt->Statement) - stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode* stmt) - { - if (stmt->Expression) - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } + virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitEmptyStatement(EmptyStatementSyntaxNode* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode* stmt) + { + if (stmt->InitialStatement) + stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); + if (stmt->PredicateExpression) + stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->SideEffectExpression) + stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->PositiveStatement) + stmt->PositiveStatement = stmt->PositiveStatement->Accept(this).As<StatementSyntaxNode>(); + if (stmt->NegativeStatement) + stmt->NegativeStatement = stmt->NegativeStatement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) + { + if (stmt->condition) + stmt->condition = stmt->condition->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->body) + stmt->body = stmt->body->Accept(this).As<BlockStatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) + { + if (stmt->expr) + stmt->expr = stmt->expr->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) + { + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode* stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitVarDeclrStatement(VarDeclrStatementSyntaxNode* stmt) + { + stmt->decl = stmt->decl->Accept(this).As<DeclBase>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode* stmt) + { + if (stmt->Predicate) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (stmt->Statement) + stmt->Statement = stmt->Statement->Accept(this).As<StatementSyntaxNode>(); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode* stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } - virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode* expr) - { - for (auto && child : expr->Arguments) - child->Accept(this); - return expr; - } - virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode* expr) - { - return expr; - } - virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* expr) - { - if (expr->BaseExpression) - expr->BaseExpression = expr->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); - if (expr->IndexExpression) - expr->IndexExpression = expr->IndexExpression->Accept(this).As<ExpressionSyntaxNode>(); - return expr; - } - virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * stmt) - { - if (stmt->BaseExpression) - stmt->BaseExpression = stmt->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<ExpressionSyntaxNode> VisitSwizzleExpression(SwizzleExpr * expr) - { - if (expr->base) - expr->base->Accept(this); - return expr; - } - virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode* stmt) - { - stmt->FunctionExpr->Accept(this); - for (auto & arg : stmt->Arguments) - arg = arg->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * stmt) - { - if (stmt->Expression) - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt->Expression; - } - virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode* expr) - { - return expr; - } + virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode* expr) + { + for (auto && child : expr->Arguments) + child->Accept(this); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode* expr) + { + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* expr) + { + if (expr->BaseExpression) + expr->BaseExpression = expr->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (expr->IndexExpression) + expr->IndexExpression = expr->IndexExpression->Accept(this).As<ExpressionSyntaxNode>(); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * stmt) + { + if (stmt->BaseExpression) + stmt->BaseExpression = stmt->BaseExpression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitSwizzleExpression(SwizzleExpr * expr) + { + if (expr->base) + expr->base->Accept(this); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode* stmt) + { + stmt->FunctionExpr->Accept(this); + for (auto & arg : stmt->Arguments) + arg = arg->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * stmt) + { + if (stmt->Expression) + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt->Expression; + } + virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode* expr) + { + return expr; + } - virtual RefPtr<ParameterSyntaxNode> VisitParameter(ParameterSyntaxNode* param) - { - return param; - } - virtual RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr* type) - { - return type; - } + virtual RefPtr<ParameterSyntaxNode> VisitParameter(ParameterSyntaxNode* param) + { + return param; + } + virtual RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr* type) + { + return type; + } - virtual RefPtr<Variable> VisitDeclrVariable(Variable* dclr) - { - if (dclr->Expr) - dclr->Expr = dclr->Expr->Accept(this).As<ExpressionSyntaxNode>(); - return dclr; - } + virtual RefPtr<Variable> VisitDeclrVariable(Variable* dclr) + { + if (dclr->Expr) + dclr->Expr = dclr->Expr->Accept(this).As<ExpressionSyntaxNode>(); + return dclr; + } - virtual TypeExp VisitTypeExp(TypeExp const& typeExp) + virtual TypeExp VisitTypeExp(TypeExp const& typeExp) + { + TypeExp result = typeExp; + result.exp = typeExp.exp->Accept(this).As<ExpressionSyntaxNode>(); + if (auto typeType = result.exp->Type.type.As<TypeType>()) { - TypeExp result = typeExp; - result.exp = typeExp.exp->Accept(this).As<ExpressionSyntaxNode>(); - if (auto typeType = result.exp->Type.type.As<TypeType>()) - { - result.type = typeType->type; - } - return result; + result.type = typeType->type; } + return result; + } - virtual void VisitExtensionDecl(ExtensionDecl* /*decl*/) - {} + virtual void VisitExtensionDecl(ExtensionDecl* /*decl*/) + {} - virtual void VisitConstructorDecl(ConstructorDecl* /*decl*/) - {} + virtual void VisitConstructorDecl(ConstructorDecl* /*decl*/) + {} - virtual void visitSubscriptDecl(SubscriptDecl* decl) = 0; + virtual void visitSubscriptDecl(SubscriptDecl* decl) = 0; - virtual void visitAccessorDecl(AccessorDecl* decl) = 0; + virtual void visitAccessorDecl(AccessorDecl* decl) = 0; - virtual void visitInterfaceDecl(InterfaceDecl* /*decl*/) = 0; + virtual void visitInterfaceDecl(InterfaceDecl* /*decl*/) = 0; - virtual void visitInheritanceDecl(InheritanceDecl* /*decl*/) = 0; + virtual void visitInheritanceDecl(InheritanceDecl* /*decl*/) = 0; - virtual RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* typeExpr) - { - return typeExpr; - } + virtual RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* typeExpr) + { + return typeExpr; + } - virtual void VisitDeclGroup(DeclGroup* declGroup) + virtual void VisitDeclGroup(DeclGroup* declGroup) + { + for (auto decl : declGroup->decls) { - for (auto decl : declGroup->decls) - { - decl->Accept(this); - } + decl->Accept(this); } + } - virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) = 0; - }; - - // Note(tfoley): These logically belong to `ExpressionType`, - // but order-of-declaration stuff makes that tricky - // - // TODO(tfoley): These should really belong to the compilation context! - // - void RegisterBuiltinDecl( - RefPtr<Decl> decl, - RefPtr<BuiltinTypeModifier> modifier); - void RegisterMagicDecl( - RefPtr<Decl> decl, - RefPtr<MagicTypeModifier> modifier); - - // Look up a magic declaration by its name - RefPtr<Decl> findMagicDecl( - String const& name); - - // Create an instance of a syntax class by name - SyntaxNodeBase* createInstanceOfSyntaxClassByName( - String const& name); - - } -} + virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) = 0; + }; + + // Note(tfoley): These logically belong to `ExpressionType`, + // but order-of-declaration stuff makes that tricky + // + // TODO(tfoley): These should really belong to the compilation context! + // + void RegisterBuiltinDecl( + RefPtr<Decl> decl, + RefPtr<BuiltinTypeModifier> modifier); + void RegisterMagicDecl( + RefPtr<Decl> decl, + RefPtr<MagicTypeModifier> modifier); + + // Look up a magic declaration by its name + RefPtr<Decl> findMagicDecl( + String const& name); + + // Create an instance of a syntax class by name + SyntaxNodeBase* createInstanceOfSyntaxClassByName( + String const& name); + +} // namespace Slang #endif
\ No newline at end of file diff --git a/source/slang/token.cpp b/source/slang/token.cpp index 436a3a740..ff2ded818 100644 --- a/source/slang/token.cpp +++ b/source/slang/token.cpp @@ -4,7 +4,6 @@ #include <assert.h> namespace Slang { -namespace Compiler { char const* TokenTypeToString(TokenType type) { @@ -19,4 +18,4 @@ char const* TokenTypeToString(TokenType type) } } -}} +} // namespace Slang diff --git a/source/slang/token.h b/source/slang/token.h index 00a55feb1..08eafccae 100644 --- a/source/slang/token.h +++ b/source/slang/token.h @@ -7,7 +7,6 @@ #include "source-loc.h" namespace Slang { -namespace Compiler { using namespace CoreLib::Basic; @@ -45,6 +44,6 @@ public: -}} +} // namespace Slang #endif diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 4e5c98ed5..8b390b718 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -6,7 +6,6 @@ #include <assert.h> namespace Slang { -namespace Compiler { size_t RoundToAlignment(size_t offset, size_t alignment) { @@ -1122,4 +1121,4 @@ SimpleLayoutInfo GetLayout(ExpressionType* type, LayoutRule rule) return GetLayout(type, rulesImpl); } -}} +} // namespace Slang diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index be54bbf53..df937c1d4 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -12,8 +12,6 @@ namespace Slang { typedef intptr_t Int; typedef uintptr_t UInt; -namespace Compiler { - // Forward declarations enum class BaseType; @@ -545,6 +543,6 @@ createStructuredBufferTypeLayout( // -}} +} #endif
\ No newline at end of file diff --git a/source/slangc/main.cpp b/source/slangc/main.cpp index b0ddc7ab6..39be1a1fd 100644 --- a/source/slangc/main.cpp +++ b/source/slangc/main.cpp @@ -149,7 +149,7 @@ struct OptionsParser #undef CASE #define CASE(EXT, LANG, PROFILE) \ - else if(path.EndsWith(EXT)) do { addInputForeignShaderPath(path, SLANG_SOURCE_LANGUAGE_##LANG, SlangProfileID(Slang::Compiler::Profile::PROFILE)); } while(0) + else if(path.EndsWith(EXT)) do { addInputForeignShaderPath(path, SLANG_SOURCE_LANGUAGE_##LANG, SlangProfileID(Slang::Profile::PROFILE)); } while(0) // TODO: need a way to pass along stage/profile and entry-point info for these cases... CASE(".vert", GLSL, GLSL_Vertex); CASE(".frag", GLSL, GLSL_Fragment); |
