diff options
Diffstat (limited to 'source/slang/check.cpp')
| -rw-r--r-- | source/slang/check.cpp | 7781 |
1 files changed, 3889 insertions, 3892 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 6ff8efe9e..b3e1baf79 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -7,3169 +7,3112 @@ namespace Slang { - namespace Compiler + bool IsNumeric(BaseType t) { - bool IsNumeric(BaseType t) - { - return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; - } - - String TranslateHLSLTypeNames(String name) - { - if (name == "float2" || name == "half2") - return "vec2"; - else if (name == "float3" || name == "half3") - return "vec3"; - else if (name == "float4" || name == "half4") - return "vec4"; - else if (name == "half") - return "float"; - else if (name == "int2") - return "ivec2"; - else if (name == "int3") - return "ivec3"; - else if (name == "int4") - return "ivec4"; - else if (name == "uint2") - return "uvec2"; - else if (name == "uint3") - return "uvec3"; - else if (name == "uint4") - return "uvec4"; - else if (name == "float3x3" || name == "half3x3") - return "mat3"; - else if (name == "float4x4" || name == "half4x4") - return "mat4"; - else - return name; - } + return t == BaseType::Int || t == BaseType::Float || t == BaseType::UInt; + } - class SemanticsVisitor : public SyntaxVisitor - { - ProgramSyntaxNode * program = nullptr; - FunctionSyntaxNode * function = nullptr; - CompileOptions const* options = nullptr; - CompileRequest* request = nullptr; + String TranslateHLSLTypeNames(String name) + { + if (name == "float2" || name == "half2") + return "vec2"; + else if (name == "float3" || name == "half3") + return "vec3"; + else if (name == "float4" || name == "half4") + return "vec4"; + else if (name == "half") + return "float"; + else if (name == "int2") + return "ivec2"; + else if (name == "int3") + return "ivec3"; + else if (name == "int4") + return "ivec4"; + else if (name == "uint2") + return "uvec2"; + else if (name == "uint3") + return "uvec3"; + else if (name == "uint4") + return "uvec4"; + else if (name == "float3x3" || name == "half3x3") + return "mat3"; + else if (name == "float4x4" || name == "half4x4") + return "mat4"; + else + return name; + } - // lexical outer statements - List<StatementSyntaxNode*> outerStmts; - public: - SemanticsVisitor( - DiagnosticSink * pErr, - CompileOptions const& options, - CompileRequest* request) - : SyntaxVisitor(pErr) - , options(&options) - , request(request) - { - } + class SemanticsVisitor : public SyntaxVisitor + { + ProgramSyntaxNode * program = nullptr; + FunctionSyntaxNode * function = nullptr; + CompileOptions const* options = nullptr; + CompileRequest* request = nullptr; + + // lexical outer statements + List<StatementSyntaxNode*> outerStmts; + public: + SemanticsVisitor( + DiagnosticSink * pErr, + CompileOptions const& options, + CompileRequest* request) + : SyntaxVisitor(pErr) + , options(&options) + , request(request) + { + } - CompileOptions const& getOptions() { return *options; } + CompileOptions const& getOptions() { return *options; } - public: - // Translate Types - RefPtr<ExpressionType> typeResult; - RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node) - { - if (!node) return nullptr; - auto expr = node->Accept(this).As<ExpressionSyntaxNode>(); - expr = ExpectATypeRepr(expr); - return expr; - } - RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr) + public: + // Translate Types + RefPtr<ExpressionType> typeResult; + RefPtr<ExpressionSyntaxNode> TranslateTypeNodeImpl(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto expr = node->Accept(this).As<ExpressionSyntaxNode>(); + expr = ExpectATypeRepr(expr); + return expr; + } + RefPtr<ExpressionType> ExtractTypeFromTypeRepr(const RefPtr<ExpressionSyntaxNode>& typeRepr) + { + if (!typeRepr) return nullptr; + if (auto typeType = typeRepr->Type->As<TypeType>()) { - if (!typeRepr) return nullptr; - if (auto typeType = typeRepr->Type->As<TypeType>()) - { - return typeType->type; - } - return ExpressionType::Error; + return typeType->type; } - RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node) + return ExpressionType::Error; + } + RefPtr<ExpressionType> TranslateTypeNode(const RefPtr<ExpressionSyntaxNode> & node) + { + if (!node) return nullptr; + auto typeRepr = TranslateTypeNodeImpl(node); + return ExtractTypeFromTypeRepr(typeRepr); + } + TypeExp TranslateTypeNode(TypeExp const& typeExp) + { + // HACK(tfoley): It seems that in some cases we end up re-checking + // syntax that we've already checked. We need to root-cause that + // issue, but for now a quick fix in this case is to early + // exist if we've already got a type associated here: + if (typeExp.type) { - if (!node) return nullptr; - auto typeRepr = TranslateTypeNodeImpl(node); - return ExtractTypeFromTypeRepr(typeRepr); + return typeExp; } - TypeExp TranslateTypeNode(TypeExp const& typeExp) - { - // HACK(tfoley): It seems that in some cases we end up re-checking - // syntax that we've already checked. We need to root-cause that - // issue, but for now a quick fix in this case is to early - // exist if we've already got a type associated here: - if (typeExp.type) - { - return typeExp; - } - auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); + auto typeRepr = TranslateTypeNodeImpl(typeExp.exp); - TypeExp result; - result.exp = typeRepr; - result.type = ExtractTypeFromTypeRepr(typeRepr); - return result; - } + TypeExp result; + result.exp = typeRepr; + result.type = ExtractTypeFromTypeRepr(typeRepr); + return result; + } - RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr( - DeclRef declRef, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + RefPtr<ExpressionSyntaxNode> ConstructDeclRefExpr( + DeclRef declRef, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (baseExpr) + { + auto expr = new MemberExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->BaseExpression = baseExpr; + expr->MemberName = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; + } + else { - if (baseExpr) - { - auto expr = new MemberExpressionSyntaxNode(); - expr->Position = originalExpr->Position; - expr->BaseExpression = baseExpr; - expr->MemberName = declRef.GetName(); - expr->Type = GetTypeForDeclRef(declRef); - expr->declRef = declRef; - return expr; - } - else - { - auto expr = new VarExpressionSyntaxNode(); - expr->Position = originalExpr->Position; - expr->Variable = declRef.GetName(); - expr->Type = GetTypeForDeclRef(declRef); - expr->declRef = declRef; - return expr; - } + auto expr = new VarExpressionSyntaxNode(); + expr->Position = originalExpr->Position; + expr->Variable = declRef.GetName(); + expr->Type = GetTypeForDeclRef(declRef); + expr->declRef = declRef; + return expr; } + } - RefPtr<ExpressionSyntaxNode> ConstructDerefExpr( - RefPtr<ExpressionSyntaxNode> base, - RefPtr<ExpressionSyntaxNode> originalExpr) - { - auto ptrLikeType = base->Type->As<PointerLikeType>(); - assert(ptrLikeType); + RefPtr<ExpressionSyntaxNode> ConstructDerefExpr( + RefPtr<ExpressionSyntaxNode> base, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + auto ptrLikeType = base->Type->As<PointerLikeType>(); + assert(ptrLikeType); - auto derefExpr = new DerefExpr(); - derefExpr->Position = originalExpr->Position; - derefExpr->base = base; - derefExpr->Type = ptrLikeType->elementType; + auto derefExpr = new DerefExpr(); + derefExpr->Position = originalExpr->Position; + derefExpr->base = base; + derefExpr->Type = ptrLikeType->elementType; - // TODO(tfoley): handle l-value status here + // TODO(tfoley): handle l-value status here - return derefExpr; - } + return derefExpr; + } - RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr( - LookupResultItem const& item, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + RefPtr<ExpressionSyntaxNode> ConstructLookupResultExpr( + LookupResultItem const& item, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + // If we collected any breadcrumbs, then these represent + // additional segments of the lookup path that we need + // to expand here. + auto bb = baseExpr; + for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) { - // If we collected any breadcrumbs, then these represent - // additional segments of the lookup path that we need - // to expand here. - auto bb = baseExpr; - for (auto breadcrumb = item.breadcrumbs; breadcrumb; breadcrumb = breadcrumb->next) + switch (breadcrumb->kind) { - switch (breadcrumb->kind) - { - case LookupResultItem::Breadcrumb::Kind::Member: - bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr); - break; - case LookupResultItem::Breadcrumb::Kind::Deref: - bb = ConstructDerefExpr(bb, originalExpr); - break; - default: - SLANG_UNREACHABLE("all cases handle"); - } + case LookupResultItem::Breadcrumb::Kind::Member: + bb = ConstructDeclRefExpr(breadcrumb->declRef, bb, originalExpr); + break; + case LookupResultItem::Breadcrumb::Kind::Deref: + bb = ConstructDerefExpr(bb, originalExpr); + break; + default: + SLANG_UNREACHABLE("all cases handle"); } - - return ConstructDeclRefExpr(item.declRef, bb, originalExpr); } - RefPtr<ExpressionSyntaxNode> createLookupResultExpr( - LookupResult const& lookupResult, - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<ExpressionSyntaxNode> originalExpr) + return ConstructDeclRefExpr(item.declRef, bb, originalExpr); + } + + RefPtr<ExpressionSyntaxNode> createLookupResultExpr( + LookupResult const& lookupResult, + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> originalExpr) + { + if (lookupResult.isOverloaded()) { - if (lookupResult.isOverloaded()) - { - auto overloadedExpr = new OverloadedExpr(); - overloadedExpr->Position = originalExpr->Position; - overloadedExpr->Type = ExpressionType::Overloaded; - overloadedExpr->base = baseExpr; - overloadedExpr->lookupResult2 = lookupResult; - return overloadedExpr; - } - else - { - return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr); - } + auto overloadedExpr = new OverloadedExpr(); + overloadedExpr->Position = originalExpr->Position; + overloadedExpr->Type = ExpressionType::Overloaded; + overloadedExpr->base = baseExpr; + overloadedExpr->lookupResult2 = lookupResult; + return overloadedExpr; } - - RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask) + else { - auto lookupResult = overloadedExpr->lookupResult2; - assert(lookupResult.isValid() && lookupResult.isOverloaded()); - - // Take the lookup result we had, and refine it based on what is expected in context. - lookupResult = refineLookup(lookupResult, mask); - - if (!lookupResult.isValid()) - { - // If we didn't find any symbols after filtering, then just - // use the original and report errors that way - return overloadedExpr; - } - - if (lookupResult.isOverloaded()) - { - // We had an ambiguity anyway, so report it. - getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); + return ConstructLookupResultExpr(lookupResult.item, baseExpr, originalExpr); + } + } - for(auto item : lookupResult.items) - { - String declString = getDeclSignatureString(item); - getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); - } + RefPtr<ExpressionSyntaxNode> ResolveOverloadedExpr(RefPtr<OverloadedExpr> overloadedExpr, LookupMask mask) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isValid() && lookupResult.isOverloaded()); - // TODO(tfoley): should we construct a new ErrorExpr here? - overloadedExpr->Type = ExpressionType::Error; - return overloadedExpr; - } + // Take the lookup result we had, and refine it based on what is expected in context. + lookupResult = refineLookup(lookupResult, mask); - // otherwise, we had a single decl and it was valid, hooray! - return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr); + if (!lookupResult.isValid()) + { + // If we didn't find any symbols after filtering, then just + // use the original and report errors that way + return overloadedExpr; } - RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr) + if (lookupResult.isOverloaded()) { - if (auto overloadedExpr = expr.As<OverloadedExpr>()) - { - expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); - } + // We had an ambiguity anyway, so report it. + getSink()->diagnose(overloadedExpr, Diagnostics::ambiguousReference, lookupResult.items[0].declRef.GetName()); - if (auto typeType = expr->Type.type->As<TypeType>()) + for(auto item : lookupResult.items) { - return expr; - } - else if (expr->Type.type->Equals(ExpressionType::Error)) - { - return expr; + String declString = getDeclSignatureString(item); + getSink()->diagnose(item.declRef, Diagnostics::overloadCandidate, declString); } - getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); - // TODO: construct some kind of `ErrorExpr`? - return expr; + // TODO(tfoley): should we construct a new ErrorExpr here? + overloadedExpr->Type = ExpressionType::Error; + return overloadedExpr; } - RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr) + // otherwise, we had a single decl and it was valid, hooray! + return ConstructLookupResultExpr(lookupResult.item, overloadedExpr->base, overloadedExpr); + } + + RefPtr<ExpressionSyntaxNode> ExpectATypeRepr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto overloadedExpr = expr.As<OverloadedExpr>()) { - auto typeRepr = ExpectATypeRepr(expr); - if (auto typeType = typeRepr->Type->As<TypeType>()) - { - return typeType->type; - } - return ExpressionType::Error; + expr = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); } - RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp) + if (auto typeType = expr->Type.type->As<TypeType>()) { - return ExpectAType(exp); + return expr; } + else if (expr->Type.type->Equals(ExpressionType::Error)) + { + return expr; + } + + getSink()->diagnose(expr, Diagnostics::unimplemented, "expected a type"); + // TODO: construct some kind of `ErrorExpr`? + return expr; + } - RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp) + RefPtr<ExpressionType> ExpectAType(RefPtr<ExpressionSyntaxNode> expr) + { + auto typeRepr = ExpectATypeRepr(expr); + if (auto typeType = typeRepr->Type->As<TypeType>()) { - return CheckIntegerConstantExpression(exp.Ptr()); + return typeType->type; } + return ExpressionType::Error; + } + + RefPtr<ExpressionType> ExtractGenericArgType(RefPtr<ExpressionSyntaxNode> exp) + { + return ExpectAType(exp); + } + + RefPtr<IntVal> ExtractGenericArgInteger(RefPtr<ExpressionSyntaxNode> exp) + { + return CheckIntegerConstantExpression(exp.Ptr()); + } - RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp) + RefPtr<Val> ExtractGenericArgVal(RefPtr<ExpressionSyntaxNode> exp) + { + if (auto overloadedExpr = exp.As<OverloadedExpr>()) { - if (auto overloadedExpr = exp.As<OverloadedExpr>()) - { - // assume that if it is overloaded, we want a type - exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); - } + // assume that if it is overloaded, we want a type + exp = ResolveOverloadedExpr(overloadedExpr, LookupMask::Type); + } - if (auto typeType = exp->Type->As<TypeType>()) - { - return typeType->type; - } - else if (exp->Type->Equals(ExpressionType::Error)) - { - return exp->Type.type; - } - else - { - return ExtractGenericArgInteger(exp); - } + if (auto typeType = exp->Type->As<TypeType>()) + { + return typeType->type; + } + else if (exp->Type->Equals(ExpressionType::Error)) + { + return exp->Type.type; } + else + { + return ExtractGenericArgInteger(exp); + } + } - // Construct a type reprsenting the instantiation of - // the given generic declaration for the given arguments. - // The arguments should already be checked against - // the declaration. - RefPtr<ExpressionType> InstantiateGenericType( - GenericDeclRef genericDeclRef, - List<RefPtr<ExpressionSyntaxNode>> const& args) + // Construct a type reprsenting the instantiation of + // the given generic declaration for the given arguments. + // The arguments should already be checked against + // the declaration. + RefPtr<ExpressionType> InstantiateGenericType( + GenericDeclRef genericDeclRef, + List<RefPtr<ExpressionSyntaxNode>> const& args) + { + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = genericDeclRef.GetDecl(); + subst->outer = genericDeclRef.substitutions; + + for (auto argExpr : args) { - RefPtr<Substitutions> subst = new Substitutions(); - subst->genericDecl = genericDeclRef.GetDecl(); - subst->outer = genericDeclRef.substitutions; + subst->args.Add(ExtractGenericArgVal(argExpr)); + } - for (auto argExpr : args) - { - subst->args.Add(ExtractGenericArgVal(argExpr)); - } + DeclRef innerDeclRef; + innerDeclRef.decl = genericDeclRef.GetInner(); + innerDeclRef.substitutions = subst; - DeclRef innerDeclRef; - innerDeclRef.decl = genericDeclRef.GetInner(); - innerDeclRef.substitutions = subst; + return DeclRefType::Create(innerDeclRef); + } - return DeclRefType::Create(innerDeclRef); + // Make sure a declaration has been checked, so we can refer to it. + // Note that this may lead to us recursively invoking checking, + // so this may not be the best way to handle things. + void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader) + { + if (decl->IsChecked(state)) return; + if (decl->checkState == DeclCheckState::CheckingHeader) + { + // We tried to reference the same declaration while checking it! + throw "circularity"; } - // Make sure a declaration has been checked, so we can refer to it. - // Note that this may lead to us recursively invoking checking, - // so this may not be the best way to handle things. - void EnsureDecl(RefPtr<Decl> decl, DeclCheckState state = DeclCheckState::CheckedHeader) + if (DeclCheckState::CheckingHeader > decl->checkState) { - if (decl->IsChecked(state)) return; - if (decl->checkState == DeclCheckState::CheckingHeader) - { - // We tried to reference the same declaration while checking it! - throw "circularity"; - } - - if (DeclCheckState::CheckingHeader > decl->checkState) - { - decl->SetCheckState(DeclCheckState::CheckingHeader); - } + decl->SetCheckState(DeclCheckState::CheckingHeader); + } - // TODO: not all of the `Visit` cases are ready to - // handle this being called on-the-fly - decl->Accept(this); + // TODO: not all of the `Visit` cases are ready to + // handle this being called on-the-fly + decl->Accept(this); - decl->SetCheckState(DeclCheckState::Checked); - } + decl->SetCheckState(DeclCheckState::Checked); + } - void EnusreAllDeclsRec(RefPtr<Decl> decl) + void EnusreAllDeclsRec(RefPtr<Decl> decl) + { + EnsureDecl(decl, DeclCheckState::Checked); + if (auto containerDecl = decl.As<ContainerDecl>()) { - EnsureDecl(decl, DeclCheckState::Checked); - if (auto containerDecl = decl.As<ContainerDecl>()) + for (auto m : containerDecl->Members) { - for (auto m : containerDecl->Members) - { - EnusreAllDeclsRec(m); - } + EnusreAllDeclsRec(m); } } + } - // A "proper" type is one that can be used as the type of an expression. - // Put simply, it can be a concrete type like `int`, or a generic - // type that is applied to arguments, like `Texture2D<float4>`. - // The type `void` is also a proper type, since we can have expressions - // that return a `void` result (e.g., many function calls). - // - // A "non-proper" type is any type that can't actually have values. - // A simple example of this in C++ is `std::vector` - you can't have - // a value of this type. - // - // Part of what this function does is give errors if somebody tries - // to use a non-proper type as the type of a variable (or anything - // else that needs a proper type). - // - // The other thing it handles is the fact that HLSL lets you use - // the name of a non-proper type, and then have the compiler fill - // in the default values for its type arguments (e.g., a variable - // given type `Texture2D` will actually have type `Texture2D<float4>`). - bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType) - { - ExpressionType* type = typeExp.type.Ptr(); - if (auto genericDeclRefType = type->As<GenericDeclRefType>()) - { - // We are using a reference to a generic declaration as a concrete - // type. This means we should substitute in any default parameter values - // if they are available. - // - // TODO(tfoley): A more expressive type system would substitute in - // "fresh" variables and then solve for their values... - // + // A "proper" type is one that can be used as the type of an expression. + // Put simply, it can be a concrete type like `int`, or a generic + // type that is applied to arguments, like `Texture2D<float4>`. + // The type `void` is also a proper type, since we can have expressions + // that return a `void` result (e.g., many function calls). + // + // A "non-proper" type is any type that can't actually have values. + // A simple example of this in C++ is `std::vector` - you can't have + // a value of this type. + // + // Part of what this function does is give errors if somebody tries + // to use a non-proper type as the type of a variable (or anything + // else that needs a proper type). + // + // The other thing it handles is the fact that HLSL lets you use + // the name of a non-proper type, and then have the compiler fill + // in the default values for its type arguments (e.g., a variable + // given type `Texture2D` will actually have type `Texture2D<float4>`). + bool CoerceToProperTypeImpl(TypeExp const& typeExp, RefPtr<ExpressionType>* outProperType) + { + ExpressionType* type = typeExp.type.Ptr(); + if (auto genericDeclRefType = type->As<GenericDeclRefType>()) + { + // We are using a reference to a generic declaration as a concrete + // type. This means we should substitute in any default parameter values + // if they are available. + // + // TODO(tfoley): A more expressive type system would substitute in + // "fresh" variables and then solve for their values... + // - auto genericDeclRef = genericDeclRefType->GetDeclRef(); - EnsureDecl(genericDeclRef.decl); - List<RefPtr<ExpressionSyntaxNode>> args; - for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + auto genericDeclRef = genericDeclRefType->GetDeclRef(); + EnsureDecl(genericDeclRef.decl); + List<RefPtr<ExpressionSyntaxNode>> args; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + { + if (auto typeParam = member.As<GenericTypeParamDecl>()) { - if (auto typeParam = member.As<GenericTypeParamDecl>()) + if (!typeParam->initType.exp) { - if (!typeParam->initType.exp) + if (outProperType) { - if (outProperType) - { - getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = ExpressionType::Error; - } - return false; + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; } - - // TODO: this is one place where syntax should get cloned! - if(outProperType) - args.Add(typeParam->initType.exp); + return false; } - else if (auto valParam = member.As<GenericValueParamDecl>()) + + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(typeParam->initType.exp); + } + else if (auto valParam = member.As<GenericValueParamDecl>()) + { + if (!valParam->Expr) { - if (!valParam->Expr) + if (outProperType) { - if (outProperType) - { - getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); - *outProperType = ExpressionType::Error; - } - return false; + getSink()->diagnose(typeExp.exp.Ptr(), Diagnostics::unimplemented, "can't fill in default for generic type parameter"); + *outProperType = ExpressionType::Error; } - - // TODO: this is one place where syntax should get cloned! - if(outProperType) - args.Add(valParam->Expr); - } - else - { - // ignore non-parameter members + return false; } - } - if (outProperType) + // TODO: this is one place where syntax should get cloned! + if(outProperType) + args.Add(valParam->Expr); + } + else { - *outProperType = InstantiateGenericType(genericDeclRef, args); + // ignore non-parameter members } - return true; } - else + + if (outProperType) { - // default case: we expect this to already be a proper type - if (outProperType) - { - *outProperType = type; - } - return true; + *outProperType = InstantiateGenericType(genericDeclRef, args); } + return true; + } + else + { + // default case: we expect this to already be a proper type + if (outProperType) + { + *outProperType = type; + } + return true; } + } - TypeExp CoerceToProperType(TypeExp const& typeExp) - { - TypeExp result = typeExp; - CoerceToProperTypeImpl(typeExp, &result.type); - return result; - } + TypeExp CoerceToProperType(TypeExp const& typeExp) + { + TypeExp result = typeExp; + CoerceToProperTypeImpl(typeExp, &result.type); + return result; + } - bool CanCoerceToProperType(TypeExp const& typeExp) - { - return CoerceToProperTypeImpl(typeExp, nullptr); - } + bool CanCoerceToProperType(TypeExp const& typeExp) + { + return CoerceToProperTypeImpl(typeExp, nullptr); + } - // Check a type, and coerce it to be proper - TypeExp CheckProperType(TypeExp typeExp) - { - return CoerceToProperType(TranslateTypeNode(typeExp)); - } + // Check a type, and coerce it to be proper + TypeExp CheckProperType(TypeExp typeExp) + { + return CoerceToProperType(TranslateTypeNode(typeExp)); + } - // For our purposes, a "usable" type is one that can be - // used to declare a function parameter, variable, etc. - // These turn out to be all the proper types except - // `void`. - // - // TODO(tfoley): consider just allowing `void` as a - // simple example of a "unit" type, and get rid of - // this check. - TypeExp CoerceToUsableType(TypeExp const& typeExp) - { - TypeExp result = CoerceToProperType(typeExp); - ExpressionType* type = result.type.Ptr(); - if (auto basicType = type->As<BasicExpressionType>()) + // For our purposes, a "usable" type is one that can be + // used to declare a function parameter, variable, etc. + // These turn out to be all the proper types except + // `void`. + // + // TODO(tfoley): consider just allowing `void` as a + // simple example of a "unit" type, and get rid of + // this check. + TypeExp CoerceToUsableType(TypeExp const& typeExp) + { + TypeExp result = CoerceToProperType(typeExp); + ExpressionType* type = result.type.Ptr(); + if (auto basicType = type->As<BasicExpressionType>()) + { + // TODO: `void` shouldn't be a basic type, to make this easier to avoid + if (basicType->BaseType == BaseType::Void) { - // TODO: `void` shouldn't be a basic type, to make this easier to avoid - if (basicType->BaseType == BaseType::Void) - { - // TODO(tfoley): pick the right diagnostic message - getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); - result.type = ExpressionType::Error; - return result; - } + // TODO(tfoley): pick the right diagnostic message + getSink()->diagnose(result.exp.Ptr(), Diagnostics::invalidTypeVoid); + result.type = ExpressionType::Error; + return result; } - return result; } + return result; + } - // Check a type, and coerce it to be usable - TypeExp CheckUsableType(TypeExp typeExp) - { - return CoerceToUsableType(TranslateTypeNode(typeExp)); - } + // Check a type, and coerce it to be usable + TypeExp CheckUsableType(TypeExp typeExp) + { + return CoerceToUsableType(TranslateTypeNode(typeExp)); + } - RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term) - { - if (!term) return nullptr; - return term->Accept(this).As<ExpressionSyntaxNode>(); - } + RefPtr<ExpressionSyntaxNode> CheckTerm(RefPtr<ExpressionSyntaxNode> term) + { + if (!term) return nullptr; + return term->Accept(this).As<ExpressionSyntaxNode>(); + } - RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr) - { - expr->Type = ExpressionType::Error; - return expr; - } + RefPtr<ExpressionSyntaxNode> CreateErrorExpr(ExpressionSyntaxNode* expr) + { + expr->Type = ExpressionType::Error; + return expr; + } - bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr) - { - // TODO: we may want other cases here... + bool IsErrorExpr(RefPtr<ExpressionSyntaxNode> expr) + { + // TODO: we may want other cases here... - if (expr->Type->Equals(ExpressionType::Error)) - return true; + if (expr->Type->Equals(ExpressionType::Error)) + return true; - return false; - } + return false; + } - // Capture the "base" expression in case this is a member reference - RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr) + // Capture the "base" expression in case this is a member reference + RefPtr<ExpressionSyntaxNode> GetBaseExpr(RefPtr<ExpressionSyntaxNode> expr) + { + if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) { - if (auto memberExpr = expr.As<MemberExpressionSyntaxNode>()) - { - return memberExpr->BaseExpression; - } - else if(auto overloadedExpr = expr.As<OverloadedExpr>()) - { - return overloadedExpr->base; - } - return nullptr; + return memberExpr->BaseExpression; + } + else if(auto overloadedExpr = expr.As<OverloadedExpr>()) + { + return overloadedExpr->base; } + return nullptr; + } - public: + public: - typedef unsigned int ConversionCost; - enum : ConversionCost - { - // No conversion at all - kConversionCost_None = 0, + typedef unsigned int ConversionCost; + enum : ConversionCost + { + // No conversion at all + kConversionCost_None = 0, - // Conversions based on explicit sub-typing relationships are the cheapest - // - // TODO(tfoley): We will eventually need a discipline for ranking - // when two up-casts are comparable. - kConversionCost_CastToInterface = 50, + // Conversions based on explicit sub-typing relationships are the cheapest + // + // TODO(tfoley): We will eventually need a discipline for ranking + // when two up-casts are comparable. + kConversionCost_CastToInterface = 50, - // Conversion that is lossless and keeps the "kind" of the value the same - kConversionCost_RankPromotion = 100, + // Conversion that is lossless and keeps the "kind" of the value the same + kConversionCost_RankPromotion = 100, - // Conversions that are lossless, but change "kind" - kConversionCost_UnsignedToSignedPromotion = 200, + // Conversions that are lossless, but change "kind" + kConversionCost_UnsignedToSignedPromotion = 200, - // Conversion from signed->unsigned integer of same or greater size - kConversionCost_SignedToUnsignedConversion = 300, + // Conversion from signed->unsigned integer of same or greater size + kConversionCost_SignedToUnsignedConversion = 300, - // Cost of converting an integer to a floating-point type - kConversionCost_IntegerToFloatConversion = 400, + // Cost of converting an integer to a floating-point type + kConversionCost_IntegerToFloatConversion = 400, - // Catch-all for conversions that should be discouraged - // (i.e., that really shouldn't be made implicitly) - // - // TODO: make these conversions not be allowed implicitly in "Slang mode" - kConversionCost_GeneralConversion = 900, + // Catch-all for conversions that should be discouraged + // (i.e., that really shouldn't be made implicitly) + // + // TODO: make these conversions not be allowed implicitly in "Slang mode" + kConversionCost_GeneralConversion = 900, - // Additional conversion cost to add when promoting from a scalar to - // a vector (this will be added to the cost, if any, of converting - // the element type of the vector) - kConversionCost_ScalarToVector = 1, - }; + // Additional conversion cost to add when promoting from a scalar to + // a vector (this will be added to the cost, if any, of converting + // the element type of the vector) + kConversionCost_ScalarToVector = 1, + }; - enum BaseTypeConversionKind : uint8_t - { - kBaseTypeConversionKind_Signed, - kBaseTypeConversionKind_Unsigned, - kBaseTypeConversionKind_Float, - kBaseTypeConversionKind_Error, - }; + enum BaseTypeConversionKind : uint8_t + { + kBaseTypeConversionKind_Signed, + kBaseTypeConversionKind_Unsigned, + kBaseTypeConversionKind_Float, + kBaseTypeConversionKind_Error, + }; - enum BaseTypeConversionRank : uint8_t - { - kBaseTypeConversionRank_Bool, - kBaseTypeConversionRank_Int8, - kBaseTypeConversionRank_Int16, - kBaseTypeConversionRank_Int32, - kBaseTypeConversionRank_IntPtr, - kBaseTypeConversionRank_Int64, - kBaseTypeConversionRank_Error, - }; + enum BaseTypeConversionRank : uint8_t + { + kBaseTypeConversionRank_Bool, + kBaseTypeConversionRank_Int8, + kBaseTypeConversionRank_Int16, + kBaseTypeConversionRank_Int32, + kBaseTypeConversionRank_IntPtr, + kBaseTypeConversionRank_Int64, + kBaseTypeConversionRank_Error, + }; - struct BaseTypeConversionInfo - { - BaseTypeConversionKind kind; - BaseTypeConversionRank rank; - }; - static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType) + struct BaseTypeConversionInfo + { + BaseTypeConversionKind kind; + BaseTypeConversionRank rank; + }; + static BaseTypeConversionInfo GetBaseTypeConversionInfo(BaseType baseType) + { + switch (baseType) { - switch (baseType) - { - #define CASE(TAG, KIND, RANK) \ - case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break + #define CASE(TAG, KIND, RANK) \ + case BaseType::TAG: { BaseTypeConversionInfo info = {kBaseTypeConversionKind_##KIND, kBaseTypeConversionRank_##RANK}; return info; } break - CASE(Bool, Unsigned, Bool); - CASE(Int, Signed, Int32); - CASE(UInt, Unsigned, Int32); - CASE(UInt64, Unsigned, Int64); - CASE(Float, Float, Int32); - CASE(Void, Error, Error); + CASE(Bool, Unsigned, Bool); + CASE(Int, Signed, Int32); + CASE(UInt, Unsigned, Int32); + CASE(UInt64, Unsigned, Int64); + CASE(Float, Float, Int32); + CASE(Void, Error, Error); - #undef CASE + #undef CASE - default: - break; - } - SLANG_UNREACHABLE("all cases handled"); + default: + break; } + SLANG_UNREACHABLE("all cases handled"); + } - bool ValuesAreEqual( - RefPtr<IntVal> left, - RefPtr<IntVal> right) - { - if(left == right) return true; + bool ValuesAreEqual( + RefPtr<IntVal> left, + RefPtr<IntVal> right) + { + if(left == right) return true; - if(auto leftConst = left.As<ConstantIntVal>()) + if(auto leftConst = left.As<ConstantIntVal>()) + { + if(auto rightConst = right.As<ConstantIntVal>()) { - if(auto rightConst = right.As<ConstantIntVal>()) - { - return leftConst->value == rightConst->value; - } + return leftConst->value == rightConst->value; } + } - if(auto leftVar = left.As<GenericParamIntVal>()) + if(auto leftVar = left.As<GenericParamIntVal>()) + { + if(auto rightVar = right.As<GenericParamIntVal>()) { - if(auto rightVar = right.As<GenericParamIntVal>()) - { - return leftVar->declRef.Equals(rightVar->declRef); - } + return leftVar->declRef.Equals(rightVar->declRef); } - - return false; } - // Central engine for implementing implicit coercion logic - bool TryCoerceImpl( - RefPtr<ExpressionType> toType, // the target type for conversion - RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression - RefPtr<ExpressionType> fromType, // the source type for the conversion - RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression - ConversionCost* outCost) // (optional) a place to stuff the conversion cost + return false; + } + + // Central engine for implementing implicit coercion logic + bool TryCoerceImpl( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionSyntaxNode>* outToExpr, // (optional) a place to stuff the target expression + RefPtr<ExpressionType> fromType, // the source type for the conversion + RefPtr<ExpressionSyntaxNode> fromExpr, // the source expression + ConversionCost* outCost) // (optional) a place to stuff the conversion cost + { + // Easy case: the types are equal + if (toType->Equals(fromType)) { - // Easy case: the types are equal - if (toType->Equals(fromType)) - { - if (outToExpr) - *outToExpr = fromExpr; - if (outCost) - *outCost = kConversionCost_None; - return true; - } + if (outToExpr) + *outToExpr = fromExpr; + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // If either type is an error, then let things pass. - if (toType->As<ErrorType>() || fromType->As<ErrorType>()) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = kConversionCost_None; - return true; - } + // If either type is an error, then let things pass. + if (toType->As<ErrorType>() || fromType->As<ErrorType>()) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_None; + return true; + } - // Coercion from an initializer list is allowed for many types - if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) - { - auto argCount = fromInitializerListExpr->args.Count(); + // Coercion from an initializer list is allowed for many types + if( auto fromInitializerListExpr = fromExpr.As<InitializerListExpr>()) + { + auto argCount = fromInitializerListExpr->args.Count(); - // In the case where we need to build a reuslt expression, - // we will collect the new arguments here - List<RefPtr<ExpressionSyntaxNode>> coercedArgs; + // In the case where we need to build a reuslt expression, + // we will collect the new arguments here + List<RefPtr<ExpressionSyntaxNode>> coercedArgs; - if(auto toDeclRefType = toType->As<DeclRefType>()) + if(auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>()) { - auto toTypeDeclRef = toDeclRefType->declRef; - if(auto toStructDeclRef = toTypeDeclRef.As<StructDeclRef>()) - { - // Trying to initialize a `struct` type given an initializer list. - // We will go through the fields in order and try to match them - // up with initializer arguments. + // Trying to initialize a `struct` type given an initializer list. + // We will go through the fields in order and try to match them + // up with initializer arguments. - int argIndex = 0; - for(auto fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>()) + int argIndex = 0; + for(auto fieldDeclRef : toStructDeclRef.GetMembersOfType<FieldDeclRef>()) + { + if(argIndex >= argCount) { - if(argIndex >= argCount) - { - // We've consumed all the arguments, so we should stop - break; - } - - auto arg = fromInitializerListExpr->args[argIndex++]; - - // - RefPtr<ExpressionSyntaxNode> coercedArg; - ConversionCost argCost; - - bool argResult = TryCoerceImpl( - fieldDeclRef.GetType(), - outToExpr ? &coercedArg : nullptr, - arg->Type, - arg, - outCost ? &argCost : nullptr); - - // No point in trying further if any argument fails - if(!argResult) - return false; - - // TODO(tfoley): what to do with cost? - // This only matters if/when we allow an initializer list as an argument to - // an overloaded call. - - if( outToExpr ) - { - coercedArgs.Add(coercedArg); - } + // We've consumed all the arguments, so we should stop + break; } - } - } - else if(auto toArrayType = toType->As<ArrayExpressionType>()) - { - // TODO(tfoley): If we can compute the size of the array statically, - // then we want to check that there aren't too many initializers present - auto toElementType = toArrayType->BaseType; + auto arg = fromInitializerListExpr->args[argIndex++]; - for(auto& arg : fromInitializerListExpr->args) - { + // RefPtr<ExpressionSyntaxNode> coercedArg; ConversionCost argCost; bool argResult = TryCoerceImpl( - toElementType, - outToExpr ? &coercedArg : nullptr, - arg->Type, - arg, - outCost ? &argCost : nullptr); + fieldDeclRef.GetType(), + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); // No point in trying further if any argument fails if(!argResult) return false; + // TODO(tfoley): what to do with cost? + // This only matters if/when we allow an initializer list as an argument to + // an overloaded call. + if( outToExpr ) { coercedArgs.Add(coercedArg); } } } - else - { - // By default, we don't allow a type to be initialized using - // an initializer list. - return false; - } + } + else if(auto toArrayType = toType->As<ArrayExpressionType>()) + { + // TODO(tfoley): If we can compute the size of the array statically, + // then we want to check that there aren't too many initializers present - // For now, coercion from an initializer list has no cost - if(outCost) + auto toElementType = toArrayType->BaseType; + + for(auto& arg : fromInitializerListExpr->args) { - *outCost = kConversionCost_None; + RefPtr<ExpressionSyntaxNode> coercedArg; + ConversionCost argCost; + + bool argResult = TryCoerceImpl( + toElementType, + outToExpr ? &coercedArg : nullptr, + arg->Type, + arg, + outCost ? &argCost : nullptr); + + // No point in trying further if any argument fails + if(!argResult) + return false; + + if( outToExpr ) + { + coercedArgs.Add(coercedArg); + } } + } + else + { + // By default, we don't allow a type to be initialized using + // an initializer list. + return false; + } - // We were able to coerce all the arguments given, and so - // we need to construct a suitable expression to remember the result - if(outToExpr) - { - auto toInitializerListExpr = new InitializerListExpr(); - toInitializerListExpr->Position = fromInitializerListExpr->Position; - toInitializerListExpr->Type = toType; - toInitializerListExpr->args = coercedArgs; + // For now, coercion from an initializer list has no cost + if(outCost) + { + *outCost = kConversionCost_None; + } + // We were able to coerce all the arguments given, and so + // we need to construct a suitable expression to remember the result + if(outToExpr) + { + auto toInitializerListExpr = new InitializerListExpr(); + toInitializerListExpr->Position = fromInitializerListExpr->Position; + toInitializerListExpr->Type = toType; + toInitializerListExpr->args = coercedArgs; - *outToExpr = toInitializerListExpr; - } - return true; + *outToExpr = toInitializerListExpr; } - // + return true; + } - if (auto toBasicType = toType->AsBasicType()) + // + + if (auto toBasicType = toType->AsBasicType()) + { + if (auto fromBasicType = fromType->AsBasicType()) { - if (auto fromBasicType = fromType->AsBasicType()) - { - // Conversions between base types are always allowed, - // and the only question is what the cost will be. + // Conversions between base types are always allowed, + // and the only question is what the cost will be. - auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType); - auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType); + auto toInfo = GetBaseTypeConversionInfo(toBasicType->BaseType); + auto fromInfo = GetBaseTypeConversionInfo(fromBasicType->BaseType); - // We expect identical types to have been dealt with already. - assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank); + // We expect identical types to have been dealt with already. + assert(toInfo.kind != fromInfo.kind || toInfo.rank != fromInfo.rank); - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) + if (outCost) + { + // Conversions within the same kind are easist to handle + if (toInfo.kind == fromInfo.kind) { - // Conversions within the same kind are easist to handle - if (toInfo.kind == fromInfo.kind) - { - // If we are converting to a "larger" type, then - // we are doing a lossless promotion, and otherwise - // we are doing a demotion. - if( toInfo.rank > fromInfo.rank) - *outCost = kConversionCost_RankPromotion; - else - *outCost = kConversionCost_GeneralConversion; - } - // If we are converting from an unsigned integer type to - // a signed integer type that is guaranteed to be larger, - // then that is also a lossless promotion. - else if(toInfo.kind == kBaseTypeConversionKind_Signed - && fromInfo.kind == kBaseTypeConversionKind_Unsigned - && toInfo.rank > fromInfo.rank) - { - // TODO: probably need to weed out cases involving - // "pointer-sized" integers if these are treated - // as distinct from 32- and 64-bit types. - // E.g., there is no guarantee that conversion - // from 32-bit unsigned to pointer-sized signed - // is lossless, because pointers could be 32-bit, - // and the same applies for conversion from - // `uintptr` to `uint64`. - *outCost = kConversionCost_UnsignedToSignedPromotion; - } - // Conversion from signed to unsigned is always lossy, - // but it is preferred over conversions from unsigned - // to signed, for same-size types. - else if(toInfo.kind == kBaseTypeConversionKind_Unsigned - && fromInfo.kind == kBaseTypeConversionKind_Signed - && toInfo.rank >= fromInfo.rank) - { - *outCost = kConversionCost_SignedToUnsignedConversion; - } - // Conversion from an integer to a floating-point type - // is never considered a promotion (even when the value - // would fit in the available bits). - // If the destination type is at least 32 bits we consider - // this a reasonably good conversion, though. - else if (toInfo.kind == kBaseTypeConversionKind_Float - && toInfo.rank >= kBaseTypeConversionRank_Int32) - { - *outCost = kConversionCost_IntegerToFloatConversion; - } - // All other cases are considered as "general" conversions, - // where we don't consider any one conversion better than - // any others. + // If we are converting to a "larger" type, then + // we are doing a lossless promotion, and otherwise + // we are doing a demotion. + if( toInfo.rank > fromInfo.rank) + *outCost = kConversionCost_RankPromotion; else - { *outCost = kConversionCost_GeneralConversion; - } } - - return true; + // If we are converting from an unsigned integer type to + // a signed integer type that is guaranteed to be larger, + // then that is also a lossless promotion. + else if(toInfo.kind == kBaseTypeConversionKind_Signed + && fromInfo.kind == kBaseTypeConversionKind_Unsigned + && toInfo.rank > fromInfo.rank) + { + // TODO: probably need to weed out cases involving + // "pointer-sized" integers if these are treated + // as distinct from 32- and 64-bit types. + // E.g., there is no guarantee that conversion + // from 32-bit unsigned to pointer-sized signed + // is lossless, because pointers could be 32-bit, + // and the same applies for conversion from + // `uintptr` to `uint64`. + *outCost = kConversionCost_UnsignedToSignedPromotion; + } + // Conversion from signed to unsigned is always lossy, + // but it is preferred over conversions from unsigned + // to signed, for same-size types. + else if(toInfo.kind == kBaseTypeConversionKind_Unsigned + && fromInfo.kind == kBaseTypeConversionKind_Signed + && toInfo.rank >= fromInfo.rank) + { + *outCost = kConversionCost_SignedToUnsignedConversion; + } + // Conversion from an integer to a floating-point type + // is never considered a promotion (even when the value + // would fit in the available bits). + // If the destination type is at least 32 bits we consider + // this a reasonably good conversion, though. + else if (toInfo.kind == kBaseTypeConversionKind_Float + && toInfo.rank >= kBaseTypeConversionRank_Int32) + { + *outCost = kConversionCost_IntegerToFloatConversion; + } + // All other cases are considered as "general" conversions, + // where we don't consider any one conversion better than + // any others. + else + { + *outCost = kConversionCost_GeneralConversion; + } } + + return true; } + } - if (auto toVectorType = toType->AsVectorType()) + if (auto toVectorType = toType->AsVectorType()) + { + if (auto fromVectorType = fromType->AsVectorType()) { - if (auto fromVectorType = fromType->AsVectorType()) - { - // Conversion between vector types. + // Conversion between vector types. - // If element counts don't match, then bail: - if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount)) - return false; + // If element counts don't match, then bail: + if (!ValuesAreEqual(toVectorType->elementCount, fromVectorType->elementCount)) + return false; - // Otherwise, if we can convert the element types, we are golden - ConversionCost elementCost; - if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = elementCost; - return true; - } + // Otherwise, if we can convert the element types, we are golden + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromVectorType->elementType, &elementCost)) + { + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost; + return true; } - else if (auto fromScalarType = fromType->AsBasicType()) + } + else if (auto fromScalarType = fromType->AsBasicType()) + { + // Conversion from scalar to vector. + // Should allow as long as we can coerce the scalar to our element type. + ConversionCost elementCost; + if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost)) { - // Conversion from scalar to vector. - // Should allow as long as we can coerce the scalar to our element type. - ConversionCost elementCost; - if (CanCoerce(toVectorType->elementType, fromScalarType, &elementCost)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = elementCost + kConversionCost_ScalarToVector; - return true; - } + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = elementCost + kConversionCost_ScalarToVector; + return true; } } + } - if (auto toDeclRefType = toType->As<DeclRefType>()) + if (auto toDeclRefType = toType->As<DeclRefType>()) + { + auto toTypeDeclRef = toDeclRefType->declRef; + if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDeclRef>()) { - auto toTypeDeclRef = toDeclRefType->declRef; - if (auto interfaceDeclRef = toTypeDeclRef.As<InterfaceDeclRef>()) + // Trying to convert to an interface type. + // + // We will allow this if the type conforms to the interface. + if (DoesTypeConformToInterface(fromType, interfaceDeclRef)) { - // Trying to convert to an interface type. - // - // We will allow this if the type conforms to the interface. - if (DoesTypeConformToInterface(fromType, interfaceDeclRef)) - { - if (outToExpr) - *outToExpr = CreateImplicitCastExpr(toType, fromExpr); - if (outCost) - *outCost = kConversionCost_CastToInterface; - return true; - } + if (outToExpr) + *outToExpr = CreateImplicitCastExpr(toType, fromExpr); + if (outCost) + *outCost = kConversionCost_CastToInterface; + return true; } } + } - // TODO: more cases! + // TODO: more cases! - return false; - } + return false; + } - // Check whether a type coercion is possible - bool CanCoerce( - RefPtr<ExpressionType> toType, // the target type for conversion - RefPtr<ExpressionType> fromType, // the source type for the conversion - ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost - { - return TryCoerceImpl( - toType, - nullptr, - fromType, - nullptr, - outCost); - } + // Check whether a type coercion is possible + bool CanCoerce( + RefPtr<ExpressionType> toType, // the target type for conversion + RefPtr<ExpressionType> fromType, // the source type for the conversion + ConversionCost* outCost = 0) // (optional) a place to stuff the conversion cost + { + return TryCoerceImpl( + toType, + nullptr, + fromType, + nullptr, + outCost); + } - RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr( - RefPtr<ExpressionType> toType, - RefPtr<ExpressionSyntaxNode> fromExpr) + RefPtr<ExpressionSyntaxNode> CreateImplicitCastExpr( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + auto castExpr = new TypeCastExpressionSyntaxNode(); + castExpr->Position = fromExpr->Position; + castExpr->TargetType.type = toType; + castExpr->Type = toType; + castExpr->Expression = fromExpr; + return castExpr; + } + + + // Perform type coercion, and emit errors if it isn't possible + RefPtr<ExpressionSyntaxNode> Coerce( + RefPtr<ExpressionType> toType, + RefPtr<ExpressionSyntaxNode> fromExpr) + { + // If semantic checking is being suppressed, then we might see + // expressions without a type, and we need to ignore them. + if( !fromExpr->Type.type ) { - auto castExpr = new TypeCastExpressionSyntaxNode(); - castExpr->Position = fromExpr->Position; - castExpr->TargetType.type = toType; - castExpr->Type = toType; - castExpr->Expression = fromExpr; - return castExpr; + if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING ) + return fromExpr; } - - // Perform type coercion, and emit errors if it isn't possible - RefPtr<ExpressionSyntaxNode> Coerce( - RefPtr<ExpressionType> toType, - RefPtr<ExpressionSyntaxNode> fromExpr) + RefPtr<ExpressionSyntaxNode> expr; + if (!TryCoerceImpl( + toType, + &expr, + fromExpr->Type.Ptr(), + fromExpr.Ptr(), + nullptr)) { - // If semantic checking is being suppressed, then we might see - // expressions without a type, and we need to ignore them. - if( !fromExpr->Type.type ) + if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING)) { - if(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING ) - return fromExpr; + getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type); } - RefPtr<ExpressionSyntaxNode> expr; - if (!TryCoerceImpl( - toType, - &expr, - fromExpr->Type.Ptr(), - fromExpr.Ptr(), - nullptr)) - { - if(!(getOptions().flags & SLANG_COMPILE_FLAG_NO_CHECKING)) - { - getSink()->diagnose(fromExpr->Position, Diagnostics::typeMismatch, toType, fromExpr->Type); - } - - // Note(tfoley): We don't call `CreateErrorExpr` here, because that would - // clobber the type on `fromExpr`, and an invariant here is that coercion - // really shouldn't *change* the expression that is passed in, but should - // introduce new AST nodes to coerce its value to a different type... - return CreateImplicitCastExpr(ExpressionType::Error, fromExpr); - } - return expr; + // Note(tfoley): We don't call `CreateErrorExpr` here, because that would + // clobber the type on `fromExpr`, and an invariant here is that coercion + // really shouldn't *change* the expression that is passed in, but should + // introduce new AST nodes to coerce its value to a different type... + return CreateImplicitCastExpr(ExpressionType::Error, fromExpr); } + return expr; + } - void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) - { - // Check the type, if one was given - TypeExp type = CheckUsableType(varDecl->Type); + void CheckVarDeclCommon(RefPtr<VarDeclBase> varDecl) + { + // Check the type, if one was given + TypeExp type = CheckUsableType(varDecl->Type); - // TODO: Additional validation rules on types should go here, - // but we need to deal with the fact that some cases might be - // allowed in one context (e.g., an unsized array parameter) - // but not in othters (e.g., an unsized array field in a struct). + // TODO: Additional validation rules on types should go here, + // but we need to deal with the fact that some cases might be + // allowed in one context (e.g., an unsized array parameter) + // but not in othters (e.g., an unsized array field in a struct). - // Check the initializers, if one was given - RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr); + // Check the initializers, if one was given + RefPtr<ExpressionSyntaxNode> initExpr = CheckTerm(varDecl->Expr); - // If a type was given, ... - if (type.Ptr()) + // If a type was given, ... + if (type.Ptr()) + { + // then coerce any initializer to the type + if (initExpr) { - // then coerce any initializer to the type - if (initExpr) - { - initExpr = Coerce(type, initExpr); - } + initExpr = Coerce(type, initExpr); + } + } + else + { + // TODO: infer a type from the initializers + if (!initExpr) + { + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer"); } else { - // TODO: infer a type from the initializers - if (!initExpr) - { - getSink()->diagnose(varDecl, Diagnostics::unimplemented, "variable declaration with no type must have initializer"); - } - else - { - getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration"); - } + getSink()->diagnose(varDecl, Diagnostics::unimplemented, "type inference for variable declaration"); } - - varDecl->Type = type; - varDecl->Expr = initExpr; } - void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) - { - // TODO: are there any other validations we can do at this point? - // - // There probably needs to be a kind of "occurs check" to make - // sure that the constraint actually applies to at least one - // of the parameters of the generic. + varDecl->Type = type; + varDecl->Expr = initExpr; + } - decl->sub = TranslateTypeNode(decl->sub); - decl->sup = TranslateTypeNode(decl->sup); - } + void CheckGenericConstraintDecl(GenericTypeConstraintDecl* decl) + { + // TODO: are there any other validations we can do at this point? + // + // There probably needs to be a kind of "occurs check" to make + // sure that the constraint actually applies to at least one + // of the parameters of the generic. - virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override + decl->sub = TranslateTypeNode(decl->sub); + decl->sup = TranslateTypeNode(decl->sup); + } + + virtual RefPtr<GenericDecl> VisitGenericDecl(GenericDecl* genericDecl) override + { + // check the parameters + for (auto m : genericDecl->Members) { - // check the parameters - for (auto m : genericDecl->Members) + if (auto typeParam = m.As<GenericTypeParamDecl>()) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) - { - typeParam->initType = CheckProperType(typeParam->initType); - } - else if (auto valParam = m.As<GenericValueParamDecl>()) - { - // TODO: some real checking here... - CheckVarDeclCommon(valParam); - } - else if(auto constraint = m.As<GenericTypeConstraintDecl>()) - { - CheckGenericConstraintDecl(constraint.Ptr()); - } + typeParam->initType = CheckProperType(typeParam->initType); + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + // TODO: some real checking here... + CheckVarDeclCommon(valParam); + } + else if(auto constraint = m.As<GenericTypeConstraintDecl>()) + { + CheckGenericConstraintDecl(constraint.Ptr()); } - - // check the nested declaration - // TODO: this needs to be done in an appropriate environment... - genericDecl->inner->Accept(this); - return genericDecl; } - virtual void visitInterfaceDecl(InterfaceDecl* decl) override - { - // TODO: do some actual checking of members here - } + // check the nested declaration + // TODO: this needs to be done in an appropriate environment... + genericDecl->inner->Accept(this); + return genericDecl; + } - virtual void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) override - { - // check the type being inherited from - auto base = inheritanceDecl->base; - base = TranslateTypeNode(base); - inheritanceDecl->base = base; + virtual void visitInterfaceDecl(InterfaceDecl* decl) override + { + // TODO: do some actual checking of members here + } + + virtual void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) override + { + // check the type being inherited from + auto base = inheritanceDecl->base; + base = TranslateTypeNode(base); + inheritanceDecl->base = base; - // For now we only allow inheritance from interfaces, so - // we will validate that the type expression names an interface + // For now we only allow inheritance from interfaces, so + // we will validate that the type expression names an interface - if(auto declRefType = base.type->As<DeclRefType>()) + if(auto declRefType = base.type->As<DeclRefType>()) + { + if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDeclRef>()) { - if(auto interfaceDeclRef = declRefType->declRef.As<InterfaceDeclRef>()) - { - return; - } + return; } - - // If type expression didn't name an interface, we'll emit an error here - // TODO: deal with the case of an error in the type expression (don't cascade) - getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); } - RefPtr<ConstantIntVal> checkConstantIntVal( - RefPtr<ExpressionSyntaxNode> expr) - { - // First type-check the expression as normal - expr = CheckExpr(expr); + // If type expression didn't name an interface, we'll emit an error here + // TODO: deal with the case of an error in the type expression (don't cascade) + getSink()->diagnose( base.exp, Diagnostics::expectedAnInterfaceGot, base.type); + } - auto intVal = CheckIntegerConstantExpression(expr.Ptr()); - if(!intVal) - return nullptr; + RefPtr<ConstantIntVal> checkConstantIntVal( + RefPtr<ExpressionSyntaxNode> expr) + { + // First type-check the expression as normal + expr = CheckExpr(expr); - auto constIntVal = intVal.As<ConstantIntVal>(); - if(!constIntVal) - { - getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral); - return nullptr; - } - return constIntVal; + auto intVal = CheckIntegerConstantExpression(expr.Ptr()); + if(!intVal) + return nullptr; + + auto constIntVal = intVal.As<ConstantIntVal>(); + if(!constIntVal) + { + getSink()->diagnose(expr->Position, Diagnostics::expectedIntegerConstantNotLiteral); + return nullptr; } + return constIntVal; + } - RefPtr<Modifier> checkModifier( - RefPtr<Modifier> m, - Decl* /*decl*/) + RefPtr<Modifier> checkModifier( + RefPtr<Modifier> m, + Decl* /*decl*/) + { + if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>()) { - if(auto hlslUncheckedAttribute = m.As<HLSLUncheckedAttribute>()) + // We have an HLSL `[name(arg,...)]` attribute, and we'd like + // to check that it is provides all the expected arguments + // + // For now we will do this in a completely ad hoc fashion, + // but it would be nice to have some generic routine to + // do the needed type checking/coercion. + if(hlslUncheckedAttribute->nameToken.Content == "numthreads") { - // We have an HLSL `[name(arg,...)]` attribute, and we'd like - // to check that it is provides all the expected arguments - // - // For now we will do this in a completely ad hoc fashion, - // but it would be nice to have some generic routine to - // do the needed type checking/coercion. - if(hlslUncheckedAttribute->nameToken.Content == "numthreads") - { - if(hlslUncheckedAttribute->args.Count() != 3) - return m; + if(hlslUncheckedAttribute->args.Count() != 3) + return m; - auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]); - auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]); - auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]); + auto xVal = checkConstantIntVal(hlslUncheckedAttribute->args[0]); + auto yVal = checkConstantIntVal(hlslUncheckedAttribute->args[1]); + auto zVal = checkConstantIntVal(hlslUncheckedAttribute->args[2]); - if(!xVal) return m; - if(!yVal) return m; - if(!zVal) return m; + if(!xVal) return m; + if(!yVal) return m; + if(!zVal) return m; - auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute(); + auto hlslNumThreadsAttribute = new HLSLNumThreadsAttribute(); - hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position; - hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken; - hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args; - hlslNumThreadsAttribute->x = xVal->value; - hlslNumThreadsAttribute->y = yVal->value; - hlslNumThreadsAttribute->z = zVal->value; + hlslNumThreadsAttribute->Position = hlslUncheckedAttribute->Position; + hlslNumThreadsAttribute->nameToken = hlslUncheckedAttribute->nameToken; + hlslNumThreadsAttribute->args = hlslUncheckedAttribute->args; + hlslNumThreadsAttribute->x = xVal->value; + hlslNumThreadsAttribute->y = yVal->value; + hlslNumThreadsAttribute->z = zVal->value; - return hlslNumThreadsAttribute; - } + return hlslNumThreadsAttribute; } + } - // Default behavior is to leave things as they are, - // and assume that modifiers are mostly already checked. - // - // TODO: This would be a good place to validate that - // a modifier is actually valid for the thing it is - // being applied to, and potentially to check that - // it isn't in conflict with any other modifiers - // on the same declaration. + // Default behavior is to leave things as they are, + // and assume that modifiers are mostly already checked. + // + // TODO: This would be a good place to validate that + // a modifier is actually valid for the thing it is + // being applied to, and potentially to check that + // it isn't in conflict with any other modifiers + // on the same declaration. - return m; - } + return m; + } - void checkModifiers(Decl* decl) + void checkModifiers(Decl* decl) + { + // TODO(tfoley): need to make sure this only + // performs semantic checks on a `SharedModifier` once... + + // The process of checking a modifier may produce a new modifier in its place, + // so we will build up a new linked list of modifiers that will replace + // the old list. + RefPtr<Modifier> resultModifiers; + RefPtr<Modifier>* resultModifierLink = &resultModifiers; + + RefPtr<Modifier> modifier = decl->modifiers.first; + while(modifier) { - // TODO(tfoley): need to make sure this only - // performs semantic checks on a `SharedModifier` once... + // Because we are rewriting the list in place, we need to extract + // the next modifier here (not at the end of the loop). + auto next = modifier->next; - // The process of checking a modifier may produce a new modifier in its place, - // so we will build up a new linked list of modifiers that will replace - // the old list. - RefPtr<Modifier> resultModifiers; - RefPtr<Modifier>* resultModifierLink = &resultModifiers; + // We also go ahead and clobber the `next` field on the modifier + // itself, so that the default behavior of `checkModifier()` can + // be to return a single unlinked modifier. + modifier->next = nullptr; - RefPtr<Modifier> modifier = decl->modifiers.first; - while(modifier) + auto checkedModifier = checkModifier(modifier, decl); + if(checkedModifier) { - // Because we are rewriting the list in place, we need to extract - // the next modifier here (not at the end of the loop). - auto next = modifier->next; + // If checking gave us a modifier to add, then we + // had better add it. - // We also go ahead and clobber the `next` field on the modifier - // itself, so that the default behavior of `checkModifier()` can - // be to return a single unlinked modifier. - modifier->next = nullptr; + // Just in case `checkModifier` ever returns multiple + // modifiers, lets advance to the end of the list we + // are building. + while(*resultModifierLink) + resultModifierLink = &(*resultModifierLink)->next; - auto checkedModifier = checkModifier(modifier, decl); - if(checkedModifier) - { - // If checking gave us a modifier to add, then we - // had better add it. - - // Just in case `checkModifier` ever returns multiple - // modifiers, lets advance to the end of the list we - // are building. - while(*resultModifierLink) - resultModifierLink = &(*resultModifierLink)->next; - - // attach the new modifier at the end of the list, - // and now set the "link" to point to its `next` field - *resultModifierLink = checkedModifier; - resultModifierLink = &checkedModifier->next; - } - - // Move along to the next modifier - modifier = next; + // attach the new modifier at the end of the list, + // and now set the "link" to point to its `next` field + *resultModifierLink = checkedModifier; + resultModifierLink = &checkedModifier->next; } - // Whether we actually re-wrote anything or note, lets - // install the new list of modifiers on the declaration - decl->modifiers.first = resultModifiers; + // Move along to the next modifier + modifier = next; } - virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override + // Whether we actually re-wrote anything or note, lets + // install the new list of modifiers on the declaration + decl->modifiers.first = resultModifiers; + } + + virtual RefPtr<ProgramSyntaxNode> VisitProgram(ProgramSyntaxNode * programNode) override + { + // Try to register all the builtin decls + for (auto decl : programNode->Members) { - // Try to register all the builtin decls - for (auto decl : programNode->Members) + auto inner = decl; + if (auto genericDecl = decl.As<GenericDecl>()) { - auto inner = decl; - if (auto genericDecl = decl.As<GenericDecl>()) - { - inner = genericDecl->inner; - } - - if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>()) - { - RegisterBuiltinDecl(decl, builtinMod); - } - if (auto magicMod = inner->FindModifier<MagicTypeModifier>()) - { - RegisterMagicDecl(decl, magicMod); - } + inner = genericDecl->inner; } - // - - HashSet<String> funcNames; - this->program = programNode; - this->function = nullptr; - - for (auto & s : program->GetTypeDefs()) - VisitTypeDefDecl(s.Ptr()); - for (auto & s : program->GetStructs()) + if (auto builtinMod = inner->FindModifier<BuiltinTypeModifier>()) { - VisitStruct(s.Ptr()); + RegisterBuiltinDecl(decl, builtinMod); } - for (auto & s : program->GetClasses()) - { - VisitClass(s.Ptr()); - } - // HACK(tfoley): Visiting all generic declarations here, - // because otherwise they won't get visited. - for (auto & g : program->GetMembersOfType<GenericDecl>()) + if (auto magicMod = inner->FindModifier<MagicTypeModifier>()) { - VisitGenericDecl(g.Ptr()); + RegisterMagicDecl(decl, magicMod); } + } - for (auto & func : program->GetFunctions()) - { - if (!func->IsChecked(DeclCheckState::Checked)) - { - VisitFunctionDeclaration(func.Ptr()); - } - } - for (auto & func : program->GetFunctions()) - { - EnsureDecl(func); - } - - if (sink->GetErrorCount() != 0) - return programNode; - - // Force everything to be fully checked, just in case - // Note that we don't just call this on the program, - // because we'd end up recursing into this very code path... - for (auto d : programNode->Members) - { - EnusreAllDeclsRec(d); - } + // - // Do any semantic checking required on modifiers? - for (auto d : programNode->Members) - { - checkModifiers(d.Ptr()); - } + HashSet<String> funcNames; + this->program = programNode; + this->function = nullptr; - return programNode; + for (auto & s : program->GetTypeDefs()) + VisitTypeDefDecl(s.Ptr()); + for (auto & s : program->GetStructs()) + { + VisitStruct(s.Ptr()); } - - virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override + for (auto & s : program->GetClasses()) { - if (classNode->IsChecked(DeclCheckState::Checked)) - return classNode; - classNode->SetCheckState(DeclCheckState::Checked); - - for (auto field : classNode->GetFields()) - { - field->Type = CheckUsableType(field->Type); - field->SetCheckState(DeclCheckState::Checked); - } - return classNode; + VisitClass(s.Ptr()); } - - virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override + // HACK(tfoley): Visiting all generic declarations here, + // because otherwise they won't get visited. + for (auto & g : program->GetMembersOfType<GenericDecl>()) { - if (structNode->IsChecked(DeclCheckState::Checked)) - return structNode; - structNode->SetCheckState(DeclCheckState::Checked); + VisitGenericDecl(g.Ptr()); + } - for (auto field : structNode->GetFields()) + for (auto & func : program->GetFunctions()) + { + if (!func->IsChecked(DeclCheckState::Checked)) { - field->Type = CheckUsableType(field->Type); - field->SetCheckState(DeclCheckState::Checked); + VisitFunctionDeclaration(func.Ptr()); } - return structNode; } - - virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override + for (auto & func : program->GetFunctions()) { - if (decl->IsChecked(DeclCheckState::Checked)) return decl; - - decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->Type = CheckProperType(decl->Type); - decl->SetCheckState(DeclCheckState::Checked); - return decl; + EnsureDecl(func); } - - virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override + + if (sink->GetErrorCount() != 0) + return programNode; + + // Force everything to be fully checked, just in case + // Note that we don't just call this on the program, + // because we'd end up recursing into this very code path... + for (auto d : programNode->Members) { - if (functionNode->IsChecked(DeclCheckState::Checked)) - return functionNode; - - VisitFunctionDeclaration(functionNode); - functionNode->SetCheckState(DeclCheckState::Checked); - - if (!functionNode->IsExtern()) - { - this->function = functionNode; - if (functionNode->Body) - { - functionNode->Body->Accept(this); - } - this->function = nullptr; - } - return functionNode; + EnusreAllDeclsRec(d); } - // Check if two functions have the same signature for the purposes - // of overload resolution. - bool DoFunctionSignaturesMatch( - FunctionSyntaxNode* fst, - FunctionSyntaxNode* snd) + // Do any semantic checking required on modifiers? + for (auto d : programNode->Members) { - // TODO(tfoley): This function won't do anything sensible for generics, - // so we need to figure out a plan for that... + checkModifiers(d.Ptr()); + } - // TODO(tfoley): This copies the parameter array, which is bad for performance. - auto fstParams = fst->GetParameters().ToArray(); - auto sndParams = snd->GetParameters().ToArray(); + return programNode; + } - // If the functions have different numbers of parameters, then - // their signatures trivially don't match. - auto fstParamCount = fstParams.Count(); - auto sndParamCount = sndParams.Count(); - if (fstParamCount != sndParamCount) - return false; + virtual RefPtr<ClassSyntaxNode> VisitClass(ClassSyntaxNode * classNode) override + { + if (classNode->IsChecked(DeclCheckState::Checked)) + return classNode; + classNode->SetCheckState(DeclCheckState::Checked); - for (int ii = 0; ii < fstParamCount; ++ii) - { - auto fstParam = fstParams[ii]; - auto sndParam = sndParams[ii]; + for (auto field : classNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return classNode; + } - // If a given parameter type doesn't match, then signatures don't match - if (!fstParam->Type.Equals(sndParam->Type)) - return false; + virtual RefPtr<StructSyntaxNode> VisitStruct(StructSyntaxNode * structNode) override + { + if (structNode->IsChecked(DeclCheckState::Checked)) + return structNode; + structNode->SetCheckState(DeclCheckState::Checked); - // If one parameter is `out` and the other isn't, then they don't match - // - // Note(tfoley): we don't consider `out` and `inout` as distinct here, - // because there is no way for overload resolution to pick between them. - if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>()) - return false; - } + for (auto field : structNode->GetFields()) + { + field->Type = CheckUsableType(field->Type); + field->SetCheckState(DeclCheckState::Checked); + } + return structNode; + } - // Note(tfoley): return type doesn't enter into it, because we can't take - // calling context into account during overload resolution. + virtual RefPtr<TypeDefDecl> VisitTypeDefDecl(TypeDefDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return decl; - return true; - } + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->Type = CheckProperType(decl->Type); + decl->SetCheckState(DeclCheckState::Checked); + return decl; + } - void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl) - { - auto parentDecl = funcDecl->ParentDecl; - assert(parentDecl); - if (!parentDecl) return; + virtual RefPtr<FunctionSyntaxNode> VisitFunction(FunctionSyntaxNode *functionNode) override + { + if (functionNode->IsChecked(DeclCheckState::Checked)) + return functionNode; - // Look at previously-declared functions with the same name, - // in the same container - buildMemberDictionary(parentDecl); + VisitFunctionDeclaration(functionNode); + functionNode->SetCheckState(DeclCheckState::Checked); - for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName) + if (!functionNode->IsExtern()) + { + this->function = functionNode; + if (functionNode->Body) { - // Look through generics to the declaration underneath - auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl); - if (prevGenericDecl) - prevDecl = prevGenericDecl->inner.Ptr(); - - // We only care about previously-declared functions - // Note(tfoley): although we should really error out if the - // name is already in use for something else, like a variable... - auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl); - if (!prevFuncDecl) - continue; + functionNode->Body->Accept(this); + } + this->function = nullptr; + } + return functionNode; + } - // If the parameter signatures don't match, then don't worry - if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl)) - continue; + // Check if two functions have the same signature for the purposes + // of overload resolution. + bool DoFunctionSignaturesMatch( + FunctionSyntaxNode* fst, + FunctionSyntaxNode* snd) + { + // TODO(tfoley): This function won't do anything sensible for generics, + // so we need to figure out a plan for that... + + // TODO(tfoley): This copies the parameter array, which is bad for performance. + auto fstParams = fst->GetParameters().ToArray(); + auto sndParams = snd->GetParameters().ToArray(); + + // If the functions have different numbers of parameters, then + // their signatures trivially don't match. + auto fstParamCount = fstParams.Count(); + auto sndParamCount = sndParams.Count(); + if (fstParamCount != sndParamCount) + return false; - // If we get this far, then we've got two declarations in the same - // scope, with the same name and signature. - // - // They might just be redeclarations, which we would want to allow. + for (int ii = 0; ii < fstParamCount; ++ii) + { + auto fstParam = fstParams[ii]; + auto sndParam = sndParams[ii]; - // First, check if the return types match. - // TODO(tfolye): this code won't work for generics - if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType)) - { - // Bad dedeclaration - getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type"); + // If a given parameter type doesn't match, then signatures don't match + if (!fstParam->Type.Equals(sndParam->Type)) + return false; - // Don't bother emitting other errors at this point - break; - } + // If one parameter is `out` and the other isn't, then they don't match + // + // Note(tfoley): we don't consider `out` and `inout` as distinct here, + // because there is no way for overload resolution to pick between them. + if (fstParam->HasModifier<OutModifier>() != sndParam->HasModifier<OutModifier>()) + return false; + } - // TODO(tfoley): track the fact that there is redeclaration going on, - // so that we can detect it and react accordingly during overload resolution - // (e.g., by only considering one declaration as the canonical one...) + // Note(tfoley): return type doesn't enter into it, because we can't take + // calling context into account during overload resolution. - // If both have a body, then there is trouble - if (funcDecl->Body && prevFuncDecl->Body) - { - // Redefinition - getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition"); + return true; + } - // Don't bother emitting other errors - break; - } + void ValidateFunctionRedeclaration(FunctionSyntaxNode* funcDecl) + { + auto parentDecl = funcDecl->ParentDecl; + assert(parentDecl); + if (!parentDecl) return; + + // Look at previously-declared functions with the same name, + // in the same container + buildMemberDictionary(parentDecl); + + for (auto prevDecl = funcDecl->nextInContainerWithSameName; prevDecl; prevDecl = prevDecl->nextInContainerWithSameName) + { + // Look through generics to the declaration underneath + auto prevGenericDecl = dynamic_cast<GenericDecl*>(prevDecl); + if (prevGenericDecl) + prevDecl = prevGenericDecl->inner.Ptr(); + + // We only care about previously-declared functions + // Note(tfoley): although we should really error out if the + // name is already in use for something else, like a variable... + auto prevFuncDecl = dynamic_cast<FunctionSyntaxNode*>(prevDecl); + if (!prevFuncDecl) + continue; + + // If the parameter signatures don't match, then don't worry + if (!DoFunctionSignaturesMatch(funcDecl, prevFuncDecl)) + continue; + + // If we get this far, then we've got two declarations in the same + // scope, with the same name and signature. + // + // They might just be redeclarations, which we would want to allow. + + // First, check if the return types match. + // TODO(tfolye): this code won't work for generics + if (!funcDecl->ReturnType.Equals(prevFuncDecl->ReturnType)) + { + // Bad dedeclaration + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "redeclaration has a different return type"); - // TODO(tfoley): If both specific default argument expressions - // for the same value, then that is an error too... + // Don't bother emitting other errors at this point + break; } - } - void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode) - { - if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; - functionNode->SetCheckState(DeclCheckState::CheckingHeader); + // TODO(tfoley): track the fact that there is redeclaration going on, + // so that we can detect it and react accordingly during overload resolution + // (e.g., by only considering one declaration as the canonical one...) - this->function = functionNode; - auto returnType = CheckProperType(functionNode->ReturnType); - functionNode->ReturnType = returnType; - HashSet<String> paraNames; - for (auto & para : functionNode->GetParameters()) + // If both have a body, then there is trouble + if (funcDecl->Body && prevFuncDecl->Body) { - if (paraNames.Contains(para->Name.Content)) - getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name); - else - paraNames.Add(para->Name.Content); - para->Type = CheckUsableType(para->Type); - if (para->Type.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid); + // Redefinition + getSink()->diagnose(funcDecl, Diagnostics::unimplemented, "function redefinition"); + + // Don't bother emitting other errors + break; } - this->function = NULL; - functionNode->SetCheckState(DeclCheckState::CheckedHeader); - // One last bit of validation: check if we are redeclaring an existing function - ValidateFunctionRedeclaration(functionNode); + // TODO(tfoley): If both specific default argument expressions + // for the same value, then that is an error too... } + } - virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override + void VisitFunctionDeclaration(FunctionSyntaxNode *functionNode) + { + if (functionNode->IsChecked(DeclCheckState::CheckedHeader)) return; + functionNode->SetCheckState(DeclCheckState::CheckingHeader); + + this->function = functionNode; + auto returnType = CheckProperType(functionNode->ReturnType); + functionNode->ReturnType = returnType; + HashSet<String> paraNames; + for (auto & para : functionNode->GetParameters()) { - for (auto & node : stmt->Statements) - { - node->Accept(this); - } - return stmt; + if (paraNames.Contains(para->Name.Content)) + getSink()->diagnose(para, Diagnostics::parameterAlreadyDefined, para->Name); + else + paraNames.Add(para->Name.Content); + para->Type = CheckUsableType(para->Type); + if (para->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(para, Diagnostics::parameterCannotBeVoid); } + this->function = NULL; + functionNode->SetCheckState(DeclCheckState::CheckedHeader); - template<typename T> - T* FindOuterStmt() + // One last bit of validation: check if we are redeclaring an existing function + ValidateFunctionRedeclaration(functionNode); + } + + virtual RefPtr<StatementSyntaxNode> VisitBlockStatement(BlockStatementSyntaxNode *stmt) override + { + for (auto & node : stmt->Statements) { - int outerStmtCount = outerStmts.Count(); - for (int ii = outerStmtCount - 1; ii >= 0; --ii) - { - auto outerStmt = outerStmts[ii]; - auto found = dynamic_cast<T*>(outerStmt); - if (found) - return found; - } - return nullptr; + node->Accept(this); } + return stmt; + } - virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override + template<typename T> + T* FindOuterStmt() + { + int outerStmtCount = outerStmts.Count(); + for (int ii = outerStmtCount - 1; ii >= 0; --ii) { - auto outer = FindOuterStmt<BreakableStmt>(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); - } - stmt->parentStmt = outer; - return stmt; + auto outerStmt = outerStmts[ii]; + auto found = dynamic_cast<T*>(outerStmt); + if (found) + return found; } - virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override + return nullptr; + } + + virtual RefPtr<StatementSyntaxNode> VisitBreakStatement(BreakStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<BreakableStmt>(); + if (!outer) { - auto outer = FindOuterStmt<LoopStmt>(); - if (!outer) - { - getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); - } - stmt->parentStmt = outer; - return stmt; + getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } - - void PushOuterStmt(StatementSyntaxNode* stmt) + stmt->parentStmt = outer; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitContinueStatement(ContinueStatementSyntaxNode *stmt) override + { + auto outer = FindOuterStmt<LoopStmt>(); + if (!outer) { - outerStmts.Add(stmt); + getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + stmt->parentStmt = outer; + return stmt; + } + + void PushOuterStmt(StatementSyntaxNode* stmt) + { + outerStmts.Add(stmt); + } - void PopOuterStmt(StatementSyntaxNode* /*stmt*/) + void PopOuterStmt(StatementSyntaxNode* /*stmt*/) + { + outerStmts.RemoveAt(outerStmts.Count() - 1); + } + + virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) { - outerStmts.RemoveAt(outerStmts.Count() - 1); + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError); } + stmt->Statement->Accept(this); - virtual RefPtr<StatementSyntaxNode> VisitDoWhileStatement(DoWhileStatementSyntaxNode *stmt) override + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + if (stmt->InitialStatement) { - PushOuterStmt(stmt); - if (stmt->Predicate != NULL) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) - { - getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError); - } - stmt->Statement->Accept(this); - - PopOuterStmt(stmt); - return stmt; + stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); } - virtual RefPtr<StatementSyntaxNode> VisitForStatement(ForStatementSyntaxNode *stmt) override + if (stmt->PredicateExpression) { - PushOuterStmt(stmt); - if (stmt->InitialStatement) - { - stmt->InitialStatement = stmt->InitialStatement->Accept(this).As<StatementSyntaxNode>(); - } - if (stmt->PredicateExpression) - { - stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) && - !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) && - !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt())) - { - getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError); - } - } - if (stmt->SideEffectExpression) + stmt->PredicateExpression = stmt->PredicateExpression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->PredicateExpression->Type->Equals(ExpressionType::GetBool()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetInt()) && + !stmt->PredicateExpression->Type->Equals(ExpressionType::GetUInt())) { - stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); + getSink()->diagnose(stmt->PredicateExpression.Ptr(), Diagnostics::forPredicateTypeError); } - stmt->Statement->Accept(this); - - PopOuterStmt(stmt); - return stmt; } - virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override + if (stmt->SideEffectExpression) { - PushOuterStmt(stmt); - // TODO(tfoley): need to coerce condition to an integral type... - stmt->condition = CheckExpr(stmt->condition); - stmt->body->Accept(this); - PopOuterStmt(stmt); - return stmt; + stmt->SideEffectExpression = stmt->SideEffectExpression->Accept(this).As<ExpressionSyntaxNode>(); } - virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override - { - auto expr = CheckExpr(stmt->expr); - auto switchStmt = FindOuterStmt<SwitchStmt>(); + stmt->Statement->Accept(this); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); - } - else - { - // TODO: need to do some basic matching to ensure the type - // for the `case` is consistent with the type for the `switch`... - } - - stmt->expr = expr; - stmt->parentStmt = switchStmt; + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<SwitchStmt> VisitSwitchStmt(SwitchStmt* stmt) override + { + PushOuterStmt(stmt); + // TODO(tfoley): need to coerce condition to an integral type... + stmt->condition = CheckExpr(stmt->condition); + stmt->body->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<CaseStmt> VisitCaseStmt(CaseStmt* stmt) override + { + auto expr = CheckExpr(stmt->expr); + auto switchStmt = FindOuterStmt<SwitchStmt>(); - return stmt; + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); } - virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override + else { - auto switchStmt = FindOuterStmt<SwitchStmt>(); - if (!switchStmt) - { - getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); - } - stmt->parentStmt = switchStmt; - return stmt; + // TODO: need to do some basic matching to ensure the type + // for the `case` is consistent with the type for the `switch`... } - virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override + + stmt->expr = expr; + stmt->parentStmt = switchStmt; + + return stmt; + } + virtual RefPtr<DefaultStmt> VisitDefaultStmt(DefaultStmt* stmt) override + { + auto switchStmt = FindOuterStmt<SwitchStmt>(); + if (!switchStmt) { - auto condition = stmt->Predicate; - condition = CheckTerm(condition); - condition = Coerce(ExpressionType::GetBool(), condition); + getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); + } + stmt->parentStmt = switchStmt; + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitIfStatement(IfStatementSyntaxNode *stmt) override + { + auto condition = stmt->Predicate; + condition = CheckTerm(condition); + condition = Coerce(ExpressionType::GetBool(), condition); - stmt->Predicate = condition; + stmt->Predicate = condition; #if 0 - if (stmt->Predicate != NULL) - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) - && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))) - getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError); + if (stmt->Predicate != NULL) + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) + && (!stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool()))) + getSink()->diagnose(stmt, Diagnostics::ifPredicateTypeError); #endif - if (stmt->PositiveStatement != NULL) - stmt->PositiveStatement->Accept(this); + if (stmt->PositiveStatement != NULL) + stmt->PositiveStatement->Accept(this); - if (stmt->NegativeStatement != NULL) - stmt->NegativeStatement->Accept(this); - return stmt; + if (stmt->NegativeStatement != NULL) + stmt->NegativeStatement->Accept(this); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override + { + if (!stmt->Expression) + { + if (function && !function->ReturnType.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); } - virtual RefPtr<StatementSyntaxNode> VisitReturnStatement(ReturnStatementSyntaxNode *stmt) override + else { - if (!stmt->Expression) - { - if (function && !function->ReturnType.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(stmt, Diagnostics::returnNeedsExpression); - } - else + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr())) { - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Expression->Type->Equals(ExpressionType::Error.Ptr())) + if (function) { - if (function) - { - stmt->Expression = Coerce(function->ReturnType, stmt->Expression); - } - else - { - // TODO(tfoley): this case currently gets triggered for member functions, - // which aren't being checked consistently (because of the whole symbol - // table idea getting in the way). + stmt->Expression = Coerce(function->ReturnType, stmt->Expression); + } + else + { + // TODO(tfoley): this case currently gets triggered for member functions, + // which aren't being checked consistently (because of the whole symbol + // table idea getting in the way). // getSink()->diagnose(stmt, Diagnostics::unimplemented, "case for return stmt"); - } } } - return stmt; } + return stmt; + } - int GetMinBound(RefPtr<IntVal> val) - { - if (auto constantVal = val.As<ConstantIntVal>()) - return constantVal->value; + int GetMinBound(RefPtr<IntVal> val) + { + if (auto constantVal = val.As<ConstantIntVal>()) + return constantVal->value; - // TODO(tfoley): Need to track intervals so that this isn't just a lie... - return 1; - } + // TODO(tfoley): Need to track intervals so that this isn't just a lie... + return 1; + } - void maybeInferArraySizeForVariable(Variable* varDecl) - { - // Not an array? - auto arrayType = varDecl->Type->AsArrayType(); - if (!arrayType) return; + void maybeInferArraySizeForVariable(Variable* varDecl) + { + // Not an array? + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; - // Explicit element count given? - auto elementCount = arrayType->ArrayLength; - if (elementCount) return; + // Explicit element count given? + auto elementCount = arrayType->ArrayLength; + if (elementCount) return; - // No initializer? - auto initExpr = varDecl->Expr; - if(!initExpr) return; + // No initializer? + auto initExpr = varDecl->Expr; + if(!initExpr) return; - // Is the initializer an initializer list? - if(auto initializerListExpr = initExpr.As<InitializerListExpr>()) - { - auto argCount = initializerListExpr->args.Count(); - elementCount = new ConstantIntVal(argCount); - } - // Is the type of the initializer an array type? - else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>()) - { - elementCount = arrayInitType->ArrayLength; - } - else - { - // Nothing to do: we couldn't infer a size - return; - } + // Is the initializer an initializer list? + if(auto initializerListExpr = initExpr.As<InitializerListExpr>()) + { + auto argCount = initializerListExpr->args.Count(); + elementCount = new ConstantIntVal(argCount); + } + // Is the type of the initializer an array type? + else if(auto arrayInitType = initExpr->Type->As<ArrayExpressionType>()) + { + elementCount = arrayInitType->ArrayLength; + } + else + { + // Nothing to do: we couldn't infer a size + return; + } - // Create a new array type based on the size we found, - // and install it into our type. - auto newArrayType = new ArrayExpressionType(); - newArrayType->BaseType = arrayType->BaseType; - newArrayType->ArrayLength = elementCount; + // Create a new array type based on the size we found, + // and install it into our type. + auto newArrayType = new ArrayExpressionType(); + newArrayType->BaseType = arrayType->BaseType; + newArrayType->ArrayLength = elementCount; - // Okay we are good to go! - varDecl->Type.type = newArrayType; - } + // Okay we are good to go! + varDecl->Type.type = newArrayType; + } - void ValidateArraySizeForVariable(Variable* varDecl) - { - auto arrayType = varDecl->Type->AsArrayType(); - if (!arrayType) return; + void ValidateArraySizeForVariable(Variable* varDecl) + { + auto arrayType = varDecl->Type->AsArrayType(); + if (!arrayType) return; - auto elementCount = arrayType->ArrayLength; - if (!elementCount) - { - // Note(tfoley): For now we allow arrays of unspecified size - // everywhere, because some source languages (e.g., GLSL) - // allow them in specific cases. + auto elementCount = arrayType->ArrayLength; + if (!elementCount) + { + // Note(tfoley): For now we allow arrays of unspecified size + // everywhere, because some source languages (e.g., GLSL) + // allow them in specific cases. #if 0 - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); #endif - return; - } - - // TODO(tfoley): How to handle the case where bound isn't known? - if (GetMinBound(elementCount) <= 0) - { - getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); - return; - } + return; } - virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl) + // TODO(tfoley): How to handle the case where bound isn't known? + if (GetMinBound(elementCount) <= 0) { - TypeExp typeExp = CheckUsableType(varDecl->Type); + getSink()->diagnose(varDecl, Diagnostics::invalidArraySize); + return; + } + } + + virtual RefPtr<Variable> VisitDeclrVariable(Variable* varDecl) + { + TypeExp typeExp = CheckUsableType(varDecl->Type); #if 0 - if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable) + if (typeExp.type->GetBindableResourceType() != BindableResourceType::NonBindable) + { + // We don't want to allow bindable resource types as local variables (at least for now). + auto parentDecl = varDecl->ParentDecl; + if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl)) { - // We don't want to allow bindable resource types as local variables (at least for now). - auto parentDecl = varDecl->ParentDecl; - if (auto parentScopeDecl = dynamic_cast<ScopeDecl*>(parentDecl)) - { - getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable); - } + getSink()->diagnose(varDecl->Type, Diagnostics::invalidTypeForLocalVariable); } + } #endif - varDecl->Type = typeExp; - if (varDecl->Type.Equals(ExpressionType::GetVoid())) - getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); - - if(auto initExpr = varDecl->Expr) - { - initExpr = CheckTerm(initExpr); - varDecl->Expr = initExpr; - } - - // If this is an array variable, then we first want to give - // it a chance to infer an array size from its initializer - // - // TODO(tfoley): May need to extend this to handle the - // multi-dimensional case... - maybeInferArraySizeForVariable(varDecl); - // - // Next we want to make sure that the declared (or inferred) - // size for the array meets whatever language-specific - // constraints we want to enforce (e.g., disallow empty - // arrays in specific cases) - ValidateArraySizeForVariable(varDecl); + varDecl->Type = typeExp; + if (varDecl->Type.Equals(ExpressionType::GetVoid())) + getSink()->diagnose(varDecl, Diagnostics::invalidTypeVoid); + if(auto initExpr = varDecl->Expr) + { + initExpr = CheckTerm(initExpr); + varDecl->Expr = initExpr; + } - if(auto initExpr = varDecl->Expr) - { - // TODO(tfoley): should coercion of initializer lists be special-cased - // here, or handled as a general case for coercion? + // If this is an array variable, then we first want to give + // it a chance to infer an array size from its initializer + // + // TODO(tfoley): May need to extend this to handle the + // multi-dimensional case... + maybeInferArraySizeForVariable(varDecl); + // + // Next we want to make sure that the declared (or inferred) + // size for the array meets whatever language-specific + // constraints we want to enforce (e.g., disallow empty + // arrays in specific cases) + ValidateArraySizeForVariable(varDecl); - initExpr = Coerce(varDecl->Type, initExpr); - varDecl->Expr = initExpr; - } - varDecl->SetCheckState(DeclCheckState::Checked); + if(auto initExpr = varDecl->Expr) + { + // TODO(tfoley): should coercion of initializer lists be special-cased + // here, or handled as a general case for coercion? - return varDecl; + initExpr = Coerce(varDecl->Type, initExpr); + varDecl->Expr = initExpr; } - virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override - { - PushOuterStmt(stmt); - stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); - if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && - !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) - getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2); + varDecl->SetCheckState(DeclCheckState::Checked); - stmt->Statement->Accept(this); - PopOuterStmt(stmt); - return stmt; - } - virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override - { - stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); - return stmt; - } - virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override - { + return varDecl; + } + + virtual RefPtr<StatementSyntaxNode> VisitWhileStatement(WhileStatementSyntaxNode *stmt) override + { + PushOuterStmt(stmt); + stmt->Predicate = stmt->Predicate->Accept(this).As<ExpressionSyntaxNode>(); + if (!stmt->Predicate->Type->Equals(ExpressionType::GetError()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetInt()) && + !stmt->Predicate->Type->Equals(ExpressionType::GetBool())) + getSink()->diagnose(stmt, Diagnostics::whilePredicateTypeError2); + + stmt->Statement->Accept(this); + PopOuterStmt(stmt); + return stmt; + } + virtual RefPtr<StatementSyntaxNode> VisitExpressionStatement(ExpressionStatementSyntaxNode *stmt) override + { + stmt->Expression = stmt->Expression->Accept(this).As<ExpressionSyntaxNode>(); + return stmt; + } + virtual RefPtr<ExpressionSyntaxNode> VisitOperatorExpression(OperatorExpressionSyntaxNode *expr) override + { #if 0 - for (int i = 0; i < expr->Arguments.Count(); i++) - expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>(); - auto & leftType = expr->Arguments[0]->Type; - QualType rightType; - if (expr->Arguments.Count() == 2) - rightType = expr->Arguments[1]->Type; - RefPtr<ExpressionType> matchedType; - auto checkAssign = [&]() + for (int i = 0; i < expr->Arguments.Count(); i++) + expr->Arguments[i] = expr->Arguments[i]->Accept(this).As<ExpressionSyntaxNode>(); + auto & leftType = expr->Arguments[0]->Type; + QualType rightType; + if (expr->Arguments.Count() == 2) + rightType = expr->Arguments[1]->Type; + RefPtr<ExpressionType> matchedType; + auto checkAssign = [&]() + { + if (!leftType.IsLeftValue && + !leftType->Equals(ExpressionType::Error.Ptr())) + getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue); + if (expr->Operator == Operator::AndAssign || + expr->Operator == Operator::OrAssign || + expr->Operator == Operator::XorAssign || + expr->Operator == Operator::LshAssign || + expr->Operator == Operator::RshAssign) { - if (!leftType.IsLeftValue && - !leftType->Equals(ExpressionType::Error.Ptr())) - getSink()->diagnose(expr->Arguments[0].Ptr(), Diagnostics::assignNonLValue); - if (expr->Operator == Operator::AndAssign || - expr->Operator == Operator::OrAssign || - expr->Operator == Operator::XorAssign || - expr->Operator == Operator::LshAssign || - expr->Operator == Operator::RshAssign) - { #if 0 - if (!(leftType->IsIntegral() && rightType->IsIntegral())) - { - // TODO(tfoley): This diagnostic shouldn't be handled here + if (!(leftType->IsIntegral() && rightType->IsIntegral())) + { + // TODO(tfoley): This diagnostic shouldn't be handled here // getSink()->diagnose(expr, Diagnostics::bitOperationNonIntegral); - } -#endif } - - // TODO(tfoley): Need to actual insert coercion here... - if(CanCoerce(leftType, expr->Type)) - expr->Type = leftType; - else - expr->Type = ExpressionType::Error; - }; -#if 0 - if (expr->Operator == Operator::Assign) - { - expr->Type = rightType; - checkAssign(); - } - else #endif - { - expr->FunctionExpr = CheckExpr(expr->FunctionExpr); - CheckInvokeExprWithCheckedOperands(expr); - if (expr->Operator > Operator::Assign) - checkAssign(); } - return expr; -#endif - // Treat operator application just like a function call - return VisitInvokeExpression(expr); + // TODO(tfoley): Need to actual insert coercion here... + if(CanCoerce(leftType, expr->Type)) + expr->Type = leftType; + else + expr->Type = ExpressionType::Error; + }; +#if 0 + if (expr->Operator == Operator::Assign) + { + expr->Type = rightType; + checkAssign(); } - virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override + else +#endif { - switch (expr->ConstType) - { - case ConstantExpressionSyntaxNode::ConstantType::Int: - expr->Type = ExpressionType::GetInt(); - break; - case ConstantExpressionSyntaxNode::ConstantType::Bool: - expr->Type = ExpressionType::GetBool(); - break; - case ConstantExpressionSyntaxNode::ConstantType::Float: - expr->Type = ExpressionType::GetFloat(); - break; - default: - expr->Type = ExpressionType::Error; - throw "Invalid constant type."; - break; - } - return expr; + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + CheckInvokeExprWithCheckedOperands(expr); + if (expr->Operator > Operator::Assign) + checkAssign(); } + return expr; +#endif - IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr) - { - // TODO(tfoley): don't keep allocating here! - return new ConstantIntVal(expr->IntValue); + // Treat operator application just like a function call + return VisitInvokeExpression(expr); + } + virtual RefPtr<ExpressionSyntaxNode> VisitConstantExpression(ConstantExpressionSyntaxNode *expr) override + { + switch (expr->ConstType) + { + case ConstantExpressionSyntaxNode::ConstantType::Int: + expr->Type = ExpressionType::GetInt(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Bool: + expr->Type = ExpressionType::GetBool(); + break; + case ConstantExpressionSyntaxNode::ConstantType::Float: + expr->Type = ExpressionType::GetFloat(); + break; + default: + expr->Type = ExpressionType::Error; + throw "Invalid constant type."; + break; } + return expr; + } - RefPtr<IntVal> TryConstantFoldExpr( - InvokeExpressionSyntaxNode* invokeExpr) - { - // We need all the operands to the expression + IntVal* GetIntVal(ConstantExpressionSyntaxNode* expr) + { + // TODO(tfoley): don't keep allocating here! + return new ConstantIntVal(expr->IntValue); + } - // Check if the callee is an operation that is amenable to constant-folding. - // - // For right now we will look for calls to intrinsic functions, and then inspect - // their names (this is bad and slow). - auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); - if (!funcDeclRefExpr) return nullptr; - - auto funcDeclRef = funcDeclRefExpr->declRef; - auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicOpModifier>(); - if (!intrinsicMod) return nullptr; - - // Let's not constant-fold operations with more than a certain number of arguments, for simplicity - static const int kMaxArgs = 8; - if (invokeExpr->Arguments.Count() > kMaxArgs) - return nullptr; + RefPtr<IntVal> TryConstantFoldExpr( + InvokeExpressionSyntaxNode* invokeExpr) + { + // We need all the operands to the expression - // Before checking the operation name, let's look at the arguments - RefPtr<IntVal> argVals[kMaxArgs]; - int constArgVals[kMaxArgs]; - int argCount = 0; - bool allConst = true; - for (auto argExpr : invokeExpr->Arguments) - { - auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); - if (!argVal) - return nullptr; + // Check if the callee is an operation that is amenable to constant-folding. + // + // For right now we will look for calls to intrinsic functions, and then inspect + // their names (this is bad and slow). + auto funcDeclRefExpr = invokeExpr->FunctionExpr.As<DeclRefExpr>(); + if (!funcDeclRefExpr) return nullptr; + + auto funcDeclRef = funcDeclRefExpr->declRef; + auto intrinsicMod = funcDeclRef.GetDecl()->FindModifier<IntrinsicOpModifier>(); + if (!intrinsicMod) return nullptr; + + // Let's not constant-fold operations with more than a certain number of arguments, for simplicity + static const int kMaxArgs = 8; + if (invokeExpr->Arguments.Count() > kMaxArgs) + return nullptr; - argVals[argCount] = argVal; + // Before checking the operation name, let's look at the arguments + RefPtr<IntVal> argVals[kMaxArgs]; + int constArgVals[kMaxArgs]; + int argCount = 0; + bool allConst = true; + for (auto argExpr : invokeExpr->Arguments) + { + auto argVal = TryCheckIntegerConstantExpression(argExpr.Ptr()); + if (!argVal) + return nullptr; - if (auto constArgVal = argVal.As<ConstantIntVal>()) - { - constArgVals[argCount] = constArgVal->value; - } - else - { - allConst = false; - } - argCount++; - } + argVals[argCount] = argVal; - if (!allConst) + if (auto constArgVal = argVal.As<ConstantIntVal>()) { - // TODO(tfoley): We probably want to support a very limited number of operations - // on "constants" that aren't actually known, to be able to handle a generic - // that takes an integer `N` but then constructs a vector of size `N+1`. - // - // The hard part there is implementing the rules for value unification in the - // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd - // need inference to be smart enough to know that `2 + N` and `N + 2` are the - // same value, as are `N + M + 1 + 1` and `M + 2 + N`. - // - // For now we can just bail in this case. - return nullptr; + constArgVals[argCount] = constArgVal->value; + } + else + { + allConst = false; } + argCount++; + } - // At this point, all the operands had simple integer values, so we are golden. - int resultValue = 0; - auto opName = funcDeclRef.GetName(); + if (!allConst) + { + // TODO(tfoley): We probably want to support a very limited number of operations + // on "constants" that aren't actually known, to be able to handle a generic + // that takes an integer `N` but then constructs a vector of size `N+1`. + // + // The hard part there is implementing the rules for value unification in the + // presence of more complicated `IntVal` subclasses, like `SumIntVal`. You'd + // need inference to be smart enough to know that `2 + N` and `N + 2` are the + // same value, as are `N + M + 1 + 1` and `M + 2 + N`. + // + // For now we can just bail in this case. + return nullptr; + } + + // At this point, all the operands had simple integer values, so we are golden. + int resultValue = 0; + auto opName = funcDeclRef.GetName(); - // handle binary operators - if (opName == "-") + // handle binary operators + if (opName == "-") + { + if (argCount == 1) { - if (argCount == 1) - { - resultValue = -constArgVals[0]; - } - else if (argCount == 2) - { - resultValue = constArgVals[0] - constArgVals[1]; - } + resultValue = -constArgVals[0]; + } + else if (argCount == 2) + { + resultValue = constArgVals[0] - constArgVals[1]; } + } - // simple binary operators + // simple binary operators #define CASE(OP) \ - else if(opName == #OP) do { \ - if(argCount != 2) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) - CASE(+); // TODO: this can also be unary... - CASE(*); + CASE(+); // TODO: this can also be unary... + CASE(*); #undef CASE - // binary operators with chance of divide-by-zero - // TODO: issue a suitable error in that case + // binary operators with chance of divide-by-zero + // TODO: issue a suitable error in that case #define CASE(OP) \ - else if(opName == #OP) do { \ - if(argCount != 2) return nullptr; \ - if(!constArgVals[1]) return nullptr; \ - resultValue = constArgVals[0] OP constArgVals[1]; \ - } while(0) - - CASE(/); - CASE(%); + else if(opName == #OP) do { \ + if(argCount != 2) return nullptr; \ + if(!constArgVals[1]) return nullptr; \ + resultValue = constArgVals[0] OP constArgVals[1]; \ + } while(0) + + CASE(/); + CASE(%); #undef CASE - // TODO(tfoley): more cases - else - { - return nullptr; - } + // TODO(tfoley): more cases + else + { + return nullptr; + } + + RefPtr<IntVal> result = new ConstantIntVal(resultValue); + return result; + } - RefPtr<IntVal> result = new ConstantIntVal(resultValue); - return result; + RefPtr<IntVal> TryConstantFoldExpr( + ExpressionSyntaxNode* expr) + { + // TODO(tfoley): more serious constant folding here + if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr)) + { + return GetIntVal(constExp); } - RefPtr<IntVal> TryConstantFoldExpr( - ExpressionSyntaxNode* expr) + // it is possible that we are referring to a generic value param + if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr)) { - // TODO(tfoley): more serious constant folding here - if (auto constExp = dynamic_cast<ConstantExpressionSyntaxNode*>(expr)) + auto declRef = declRefExpr->declRef; + + if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>()) { - return GetIntVal(constExp); + // TODO(tfoley): handle the case of non-`int` value parameters... + return new GenericParamIntVal(genericValParamRef); } - // it is possible that we are referring to a generic value param - if (auto declRefExpr = dynamic_cast<DeclRefExpr*>(expr)) + // We may also need to check for references to variables that + // are defined in a way that can be used as a constant expression: + if(auto varRef = declRef.As<VarDeclBaseRef>()) { - auto declRef = declRefExpr->declRef; - - if (auto genericValParamRef = declRef.As<GenericValueParamDeclRef>()) - { - // TODO(tfoley): handle the case of non-`int` value parameters... - return new GenericParamIntVal(genericValParamRef); - } + auto varDecl = varRef.GetDecl(); - // We may also need to check for references to variables that - // are defined in a way that can be used as a constant expression: - if(auto varRef = declRef.As<VarDeclBaseRef>()) + switch(sourceLanguage) { - auto varDecl = varRef.GetDecl(); - - switch(sourceLanguage) + case SourceLanguage::Slang: + case SourceLanguage::HLSL: + // HLSL: `static const` is used to mark compile-time constant expressions + if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) { - case SourceLanguage::Slang: - case SourceLanguage::HLSL: - // HLSL: `static const` is used to mark compile-time constant expressions - if(auto staticAttr = varDecl->FindModifier<HLSLStaticModifier>()) - { - if(auto constAttr = varDecl->FindModifier<ConstModifier>()) - { - // HLSL `static const` can be used as a constant expression - if(auto initExpr = varRef.getInitExpr()) - { - return TryConstantFoldExpr(initExpr.Ptr()); - } - } - } - break; - - case SourceLanguage::GLSL: - // GLSL: `const` indicates compile-time constant expression - // - // TODO(tfoley): The current logic here isn't robust against - // GLSL "specialization constants" - we will extract the - // initializer for a `const` variable and use it to extract - // a value, when we really should be using an opaque - // reference to the variable. if(auto constAttr = varDecl->FindModifier<ConstModifier>()) { - // We need to handle a "specialization constant" (with a `constant_id` layout modifier) - // differently from an ordinary compile-time constant. The latter can/should be reduced - // to a value, while the former should be kept as a symbolic reference - - if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) - { - // Retain the specialization constant as a symbolic reference - // - // TODO(tfoley): handle the case of non-`int` value parameters... - // - // TODO(tfoley): this is cloned from the case above that handles generic value parameters - return new GenericParamIntVal(varRef); - } - else if(auto initExpr = varRef.getInitExpr()) + // HLSL `static const` can be used as a constant expression + if(auto initExpr = varRef.getInitExpr()) { - // This is an ordinary constant, and not a specialization constant, so we - // can try to fold its value right now. return TryConstantFoldExpr(initExpr.Ptr()); } } - break; } + break; + case SourceLanguage::GLSL: + // GLSL: `const` indicates compile-time constant expression + // + // TODO(tfoley): The current logic here isn't robust against + // GLSL "specialization constants" - we will extract the + // initializer for a `const` variable and use it to extract + // a value, when we really should be using an opaque + // reference to the variable. + if(auto constAttr = varDecl->FindModifier<ConstModifier>()) + { + // We need to handle a "specialization constant" (with a `constant_id` layout modifier) + // differently from an ordinary compile-time constant. The latter can/should be reduced + // to a value, while the former should be kept as a symbolic reference + + if(auto constantIDModifier = varDecl->FindModifier<GLSLConstantIDLayoutModifier>()) + { + // Retain the specialization constant as a symbolic reference + // + // TODO(tfoley): handle the case of non-`int` value parameters... + // + // TODO(tfoley): this is cloned from the case above that handles generic value parameters + return new GenericParamIntVal(varRef); + } + else if(auto initExpr = varRef.getInitExpr()) + { + // This is an ordinary constant, and not a specialization constant, so we + // can try to fold its value right now. + return TryConstantFoldExpr(initExpr.Ptr()); + } + } + break; } - } - if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr)) - { - auto val = TryConstantFoldExpr(invokeExpr); - if (val) - return val; - } - else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr)) - { - auto val = TryConstantFoldExpr(castExpr->Expression.Ptr()); - if(val) - return val; } + } - return nullptr; + if (auto invokeExpr = dynamic_cast<InvokeExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(invokeExpr); + if (val) + return val; + } + else if(auto castExpr = dynamic_cast<TypeCastExpressionSyntaxNode*>(expr)) + { + auto val = TryConstantFoldExpr(castExpr->Expression.Ptr()); + if(val) + return val; } - // Try to check an integer constant expression, either returning the value, - // or NULL if the expression isn't recognized as a constant. - RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp) + return nullptr; + } + + // Try to check an integer constant expression, either returning the value, + // or NULL if the expression isn't recognized as a constant. + RefPtr<IntVal> TryCheckIntegerConstantExpression(ExpressionSyntaxNode* exp) + { + if (!exp->Type.type->Equals(ExpressionType::GetInt())) { - if (!exp->Type.type->Equals(ExpressionType::GetInt())) - { - return nullptr; - } + return nullptr; + } - // Otherwise, we need to consider operations that we might be able to constant-fold... - return TryConstantFoldExpr(exp); - } + // Otherwise, we need to consider operations that we might be able to constant-fold... + return TryConstantFoldExpr(exp); + } - // Enforce that an expression resolves to an integer constant, and get its value - RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr) + // Enforce that an expression resolves to an integer constant, and get its value + RefPtr<IntVal> CheckIntegerConstantExpression(ExpressionSyntaxNode* inExpr) + { + // First coerce the expression to the expected type + auto expr = Coerce(ExpressionType::GetInt(),inExpr); + auto result = TryCheckIntegerConstantExpression(expr.Ptr()); + if (!result) { - // First coerce the expression to the expected type - auto expr = Coerce(ExpressionType::GetInt(),inExpr); - auto result = TryCheckIntegerConstantExpression(expr.Ptr()); - if (!result) - { - getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); - } - return result; + getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); } + return result; + } - RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr( - RefPtr<IndexExpressionSyntaxNode> subscriptExpr, - RefPtr<ExpressionType> elementType) - { - auto baseExpr = subscriptExpr->BaseExpression; - auto indexExpr = subscriptExpr->IndexExpression; + RefPtr<ExpressionSyntaxNode> CheckSimpleSubscriptExpr( + RefPtr<IndexExpressionSyntaxNode> subscriptExpr, + RefPtr<ExpressionType> elementType) + { + auto baseExpr = subscriptExpr->BaseExpression; + auto indexExpr = subscriptExpr->IndexExpression; - if (!indexExpr->Type->Equals(ExpressionType::GetInt()) && - !indexExpr->Type->Equals(ExpressionType::GetUInt())) - { - getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); - return CreateErrorExpr(subscriptExpr.Ptr()); - } + if (!indexExpr->Type->Equals(ExpressionType::GetInt()) && + !indexExpr->Type->Equals(ExpressionType::GetUInt())) + { + getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger); + return CreateErrorExpr(subscriptExpr.Ptr()); + } - subscriptExpr->Type = elementType; + subscriptExpr->Type = elementType; - // TODO(tfoley): need to be more careful about this stuff - subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue; + // TODO(tfoley): need to be more careful about this stuff + subscriptExpr->Type.IsLeftValue = baseExpr->Type.IsLeftValue; - return subscriptExpr; - } + return subscriptExpr; + } - // The way that we have designed out type system, pretyt much *every* - // type is a reference to some declaration in the standard library. - // That means that when we construct a new type on the fly, we need - // to make sure that it is wired up to reference the appropriate - // declaration, or else it won't compare as equal to other types - // that *do* reference the declaration. - // - // This function is used to construct a `vector<T,N>` type - // programmatically, so that it will work just like a type of - // that form constructed by the user. - RefPtr<VectorExpressionType> createVectorType( - RefPtr<ExpressionType> elementType, - RefPtr<IntVal> elementCount) - { - auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>(); - auto vectorTypeDecl = vectorGenericDecl->inner; + // The way that we have designed out type system, pretyt much *every* + // type is a reference to some declaration in the standard library. + // That means that when we construct a new type on the fly, we need + // to make sure that it is wired up to reference the appropriate + // declaration, or else it won't compare as equal to other types + // that *do* reference the declaration. + // + // This function is used to construct a `vector<T,N>` type + // programmatically, so that it will work just like a type of + // that form constructed by the user. + RefPtr<VectorExpressionType> createVectorType( + RefPtr<ExpressionType> elementType, + RefPtr<IntVal> elementCount) + { + auto vectorGenericDecl = findMagicDecl("Vector").As<GenericDecl>(); + auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = new Substitutions(); - substitutions->genericDecl = vectorGenericDecl.Ptr(); - substitutions->args.Add(elementType); - substitutions->args.Add(elementCount); + auto substitutions = new Substitutions(); + substitutions->genericDecl = vectorGenericDecl.Ptr(); + substitutions->args.Add(elementType); + substitutions->args.Add(elementCount); - auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + auto declRef = DeclRef(vectorTypeDecl.Ptr(), substitutions); + + return DeclRefType::Create(declRef)->As<VectorExpressionType>(); + } + + virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override + { + auto baseExpr = subscriptExpr->BaseExpression; + baseExpr = CheckExpr(baseExpr); - return DeclRefType::Create(declRef)->As<VectorExpressionType>(); + RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression; + if (indexExpr) + { + indexExpr = CheckExpr(indexExpr); } - virtual RefPtr<ExpressionSyntaxNode> VisitIndexExpression(IndexExpressionSyntaxNode* subscriptExpr) override + subscriptExpr->BaseExpression = baseExpr; + subscriptExpr->IndexExpression = indexExpr; + + // If anything went wrong in the base expression, + // then just move along... + if (IsErrorExpr(baseExpr)) + return CreateErrorExpr(subscriptExpr); + + // Otherwise, we need to look at the type of the base expression, + // to figure out how subscripting should work. + auto baseType = baseExpr->Type.Ptr(); + if (auto baseTypeType = baseType->As<TypeType>()) { - auto baseExpr = subscriptExpr->BaseExpression; - baseExpr = CheckExpr(baseExpr); + // We are trying to "index" into a type, so we have an expression like `float[2]` + // which should be interpreted as resolving to an array type. - RefPtr<ExpressionSyntaxNode> indexExpr = subscriptExpr->IndexExpression; + RefPtr<IntVal> elementCount = nullptr; if (indexExpr) { - indexExpr = CheckExpr(indexExpr); + elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); } - subscriptExpr->BaseExpression = baseExpr; - subscriptExpr->IndexExpression = indexExpr; + auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); + auto arrayType = new ArrayExpressionType(); + arrayType->BaseType = elementType; + arrayType->ArrayLength = elementCount; - // If anything went wrong in the base expression, - // then just move along... - if (IsErrorExpr(baseExpr)) - return CreateErrorExpr(subscriptExpr); - - // Otherwise, we need to look at the type of the base expression, - // to figure out how subscripting should work. - auto baseType = baseExpr->Type.Ptr(); - if (auto baseTypeType = baseType->As<TypeType>()) - { - // We are trying to "index" into a type, so we have an expression like `float[2]` - // which should be interpreted as resolving to an array type. + typeResult = arrayType; + subscriptExpr->Type = new TypeType(arrayType); + return subscriptExpr; + } + else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + baseArrayType->BaseType); + } + else if (auto vecType = baseType->As<VectorExpressionType>()) + { + return CheckSimpleSubscriptExpr( + subscriptExpr, + vecType->elementType); + } + else if (auto matType = baseType->As<MatrixExpressionType>()) + { + // TODO(tfoley): We shouldn't go and recompute + // row types over and over like this... :( + auto rowType = createVectorType( + matType->getElementType(), + matType->getColumnCount()); - RefPtr<IntVal> elementCount = nullptr; - if (indexExpr) - { - elementCount = CheckIntegerConstantExpression(indexExpr.Ptr()); - } + return CheckSimpleSubscriptExpr( + subscriptExpr, + rowType); + } - auto elementType = CoerceToUsableType(TypeExp(baseExpr, baseTypeType->type)); - auto arrayType = new ArrayExpressionType(); - arrayType->BaseType = elementType; - arrayType->ArrayLength = elementCount; + // Default behavior is to look at all available `__subscript` + // declarations on the type and try to call one of them. - typeResult = arrayType; - subscriptExpr->Type = new TypeType(arrayType); - return subscriptExpr; - } - else if (auto baseArrayType = baseType->As<ArrayExpressionType>()) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - baseArrayType->BaseType); - } - else if (auto vecType = baseType->As<VectorExpressionType>()) - { - return CheckSimpleSubscriptExpr( - subscriptExpr, - vecType->elementType); - } - else if (auto matType = baseType->As<MatrixExpressionType>()) + if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - // TODO(tfoley): We shouldn't go and recompute - // row types over and over like this... :( - auto rowType = createVectorType( - matType->getElementType(), - matType->getColumnCount()); - - return CheckSimpleSubscriptExpr( - subscriptExpr, - rowType); - } + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - // Default behavior is to look at all available `__subscript` - // declarations on the type and try to call one of them. - - if (auto declRefType = baseType->AsDeclRefType()) - { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) + // Note(tfoley): The name used for lookup here is a bit magical, since + // it must match what the parser installed in subscript declarations. + LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef); + if (!lookupResult.isValid()) { - // Checking of the type must be complete before we can reference its members safely - EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - - // Note(tfoley): The name used for lookup here is a bit magical, since - // it must match what the parser installed in subscript declarations. - LookupResult lookupResult = LookUpLocal("operator[]", aggTypeDeclRef); - if (!lookupResult.isValid()) - { - goto fail; - } + goto fail; + } - RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr( - lookupResult, subscriptExpr->BaseExpression, subscriptExpr); + RefPtr<ExpressionSyntaxNode> subscriptFuncExpr = createLookupResultExpr( + lookupResult, subscriptExpr->BaseExpression, subscriptExpr); - // Now that we know there is at least one subscript member, - // we will construct a reference to it and try to call it + // Now that we know there is at least one subscript member, + // we will construct a reference to it and try to call it - RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode(); - subscriptCallExpr->Position = subscriptExpr->Position; - subscriptCallExpr->FunctionExpr = subscriptFuncExpr; + RefPtr<InvokeExpressionSyntaxNode> subscriptCallExpr = new InvokeExpressionSyntaxNode(); + subscriptCallExpr->Position = subscriptExpr->Position; + subscriptCallExpr->FunctionExpr = subscriptFuncExpr; - // TODO(tfoley): This path can support multiple arguments easily - subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression); + // TODO(tfoley): This path can support multiple arguments easily + subscriptCallExpr->Arguments.Add(subscriptExpr->IndexExpression); - return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); - } - } - - fail: - { - getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); - return CreateErrorExpr(subscriptExpr); + return CheckInvokeExprWithCheckedOperands(subscriptCallExpr.Ptr()); } } - bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args) + fail: { - if (functionNode->GetParameters().Count() != args.Count()) - return false; - int i = 0; - for (auto param : functionNode->GetParameters()) - { - if (!param->Type.Equals(args[i]->Type.Ptr())) - return false; - i++; - } - return true; + getSink()->diagnose(subscriptExpr, Diagnostics::subscriptNonArray, baseType); + return CreateErrorExpr(subscriptExpr); } + } - // Coerce an expression to a specific type that it is expected to have in context - RefPtr<ExpressionSyntaxNode> CoerceExprToType( - RefPtr<ExpressionSyntaxNode> expr, - RefPtr<ExpressionType> type) + bool MatchArguments(FunctionSyntaxNode * functionNode, List <RefPtr<ExpressionSyntaxNode>> &args) + { + if (functionNode->GetParameters().Count() != args.Count()) + return false; + int i = 0; + for (auto param : functionNode->GetParameters()) { - // TODO(tfoley): clean this up so there is only one version... - return Coerce(type, expr); + if (!param->Type.Equals(args[i]->Type.Ptr())) + return false; + i++; } + return true; + } - // Resolve a call to a function, represented here - // by a symbol with a `FuncType` type. - RefPtr<ExpressionSyntaxNode> ResolveFunctionApp( - RefPtr<FuncType> funcType, - InvokeExpressionSyntaxNode* /*appExpr*/) - { - // TODO(tfoley): Actual checking logic needs to go here... + // Coerce an expression to a specific type that it is expected to have in context + RefPtr<ExpressionSyntaxNode> CoerceExprToType( + RefPtr<ExpressionSyntaxNode> expr, + RefPtr<ExpressionType> type) + { + // TODO(tfoley): clean this up so there is only one version... + return Coerce(type, expr); + } + + // Resolve a call to a function, represented here + // by a symbol with a `FuncType` type. + RefPtr<ExpressionSyntaxNode> ResolveFunctionApp( + RefPtr<FuncType> funcType, + InvokeExpressionSyntaxNode* /*appExpr*/) + { + // TODO(tfoley): Actual checking logic needs to go here... #if 0 - auto& args = appExpr->Arguments; - List<RefPtr<ParameterSyntaxNode>> params; - RefPtr<ExpressionType> resultType; - if (auto funcDeclRef = funcType->declRef) - { - EnsureDecl(funcDeclRef.GetDecl()); + auto& args = appExpr->Arguments; + List<RefPtr<ParameterSyntaxNode>> params; + RefPtr<ExpressionType> resultType; + if (auto funcDeclRef = funcType->declRef) + { + EnsureDecl(funcDeclRef.GetDecl()); - params = funcDeclRef->GetParameters().ToArray(); - resultType = funcDecl->ReturnType; - } - else if (auto funcSym = funcType->Func) - { - auto funcDecl = funcSym->SyntaxNode; - EnsureDecl(funcDecl); + params = funcDeclRef->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto funcSym = funcType->Func) + { + auto funcDecl = funcSym->SyntaxNode; + EnsureDecl(funcDecl); - params = funcDecl->GetParameters().ToArray(); - resultType = funcDecl->ReturnType; - } - else if (auto componentFuncSym = funcType->Component) - { - auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode; - params = componentFuncDecl->GetParameters().ToArray(); - resultType = componentFuncDecl->Type; - } + params = funcDecl->GetParameters().ToArray(); + resultType = funcDecl->ReturnType; + } + else if (auto componentFuncSym = funcType->Component) + { + auto componentFuncDecl = componentFuncSym->Implementations.First()->SyntaxNode; + params = componentFuncDecl->GetParameters().ToArray(); + resultType = componentFuncDecl->Type; + } - auto argCount = args.Count(); - auto paramCount = params.Count(); - if (argCount != paramCount) - { - getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call"); - appExpr->Type = ExpressionType::Error; - return appExpr; - } + auto argCount = args.Count(); + auto paramCount = params.Count(); + if (argCount != paramCount) + { + getSink()->diagnose(appExpr, Diagnostics::unimplemented, "wrong number of arguments for call"); + appExpr->Type = ExpressionType::Error; + return appExpr; + } - for (int ii = 0; ii < argCount; ++ii) - { - auto arg = args[ii]; - auto param = params[ii]; + for (int ii = 0; ii < argCount; ++ii) + { + auto arg = args[ii]; + auto param = params[ii]; - arg = CoerceExprToType(arg, param->Type); + arg = CoerceExprToType(arg, param->Type); - args[ii] = arg; - } + args[ii] = arg; + } - assert(resultType); - appExpr->Type = resultType; - return appExpr; + assert(resultType); + appExpr->Type = resultType; + return appExpr; #else - throw "unimplemented"; + throw "unimplemented"; #endif - } + } - // Resolve a constructor call, formed by apply a type to arguments - RefPtr<ExpressionSyntaxNode> ResolveConstructorApp( - RefPtr<ExpressionType> type, - InvokeExpressionSyntaxNode* appExpr) - { - // TODO(tfoley): Actual checking logic needs to go here... + // Resolve a constructor call, formed by apply a type to arguments + RefPtr<ExpressionSyntaxNode> ResolveConstructorApp( + RefPtr<ExpressionType> type, + InvokeExpressionSyntaxNode* appExpr) + { + // TODO(tfoley): Actual checking logic needs to go here... - appExpr->Type = type; - return appExpr; - } + appExpr->Type = type; + return appExpr; + } - // + // - virtual void VisitExtensionDecl(ExtensionDecl* decl) override - { - if (decl->IsChecked(DeclCheckState::Checked)) return; + virtual void VisitExtensionDecl(ExtensionDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->targetType = CheckProperType(decl->targetType); + decl->SetCheckState(DeclCheckState::CheckingHeader); + decl->targetType = CheckProperType(decl->targetType); - // TODO: need to check that the target type names a declaration... + // TODO: need to check that the target type names a declaration... - if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) - { - // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>()) - { - auto aggTypeDecl = aggTypeDeclRef.GetDecl(); - decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; - aggTypeDecl->candidateExtensions = decl; - } - else - { - getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); - } - } - else if (decl->targetType->Equals(ExpressionType::Error)) + if (auto targetDeclRefType = decl->targetType->As<DeclRefType>()) + { + // Attach our extension to that type as a candidate... + if (auto aggTypeDeclRef = targetDeclRefType->declRef.As<AggTypeDeclRef>()) { - // there was an error, so ignore + auto aggTypeDecl = aggTypeDeclRef.GetDecl(); + decl->nextCandidateExtension = aggTypeDecl->candidateExtensions; + aggTypeDecl->candidateExtensions = decl; } else { getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); } - - decl->SetCheckState(DeclCheckState::CheckedHeader); - - // now check the members of the extension - for (auto m : decl->Members) - { - EnsureDecl(m); - } - - decl->SetCheckState(DeclCheckState::Checked); } - - virtual void VisitConstructorDecl(ConstructorDecl* decl) override + else if (decl->targetType->Equals(ExpressionType::Error)) { - if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); + // there was an error, so ignore + } + else + { + getSink()->diagnose(decl->targetType.exp, Diagnostics::unimplemented, "expected a nominal type here"); + } - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->Type = CheckUsableType(paramDecl->Type); - } - decl->SetCheckState(DeclCheckState::CheckedHeader); + decl->SetCheckState(DeclCheckState::CheckedHeader); - // TODO(tfoley): check body - decl->SetCheckState(DeclCheckState::Checked); + // now check the members of the extension + for (auto m : decl->Members) + { + EnsureDecl(m); } + decl->SetCheckState(DeclCheckState::Checked); + } - virtual void visitSubscriptDecl(SubscriptDecl* decl) override + virtual void VisitConstructorDecl(ConstructorDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); + + for (auto& paramDecl : decl->GetParameters()) { - if (decl->IsChecked(DeclCheckState::Checked)) return; - decl->SetCheckState(DeclCheckState::CheckingHeader); + paramDecl->Type = CheckUsableType(paramDecl->Type); + } + decl->SetCheckState(DeclCheckState::CheckedHeader); - for (auto& paramDecl : decl->GetParameters()) - { - paramDecl->Type = CheckUsableType(paramDecl->Type); - } + // TODO(tfoley): check body + decl->SetCheckState(DeclCheckState::Checked); + } - decl->ReturnType = CheckUsableType(decl->ReturnType); - decl->SetCheckState(DeclCheckState::CheckedHeader); + virtual void visitSubscriptDecl(SubscriptDecl* decl) override + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckingHeader); - decl->SetCheckState(DeclCheckState::Checked); + for (auto& paramDecl : decl->GetParameters()) + { + paramDecl->Type = CheckUsableType(paramDecl->Type); } - virtual void visitAccessorDecl(AccessorDecl* decl) override - { - // TODO: check the body! + decl->ReturnType = CheckUsableType(decl->ReturnType); - decl->SetCheckState(DeclCheckState::Checked); - } + decl->SetCheckState(DeclCheckState::CheckedHeader); + decl->SetCheckState(DeclCheckState::Checked); + } - // + virtual void visitAccessorDecl(AccessorDecl* decl) override + { + // TODO: check the body! - struct Constraint - { - Decl* decl; // the declaration of the thing being constraints - RefPtr<Val> val; // the value to which we are constraining it - bool satisfied = false; // Has this constraint been met? - }; + decl->SetCheckState(DeclCheckState::Checked); + } - // A collection of constraints that will need to be satisified (solved) - // in order for checking to suceed. - struct ConstraintSystem - { - List<Constraint> constraints; - }; - RefPtr<ExpressionType> TryJoinVectorAndScalarType( - RefPtr<VectorExpressionType> vectorType, - RefPtr<BasicExpressionType> scalarType) - { - // Join( vector<T,N>, S ) -> vetor<Join(T,S), N> - // - // That is, the join of a vector and a scalar type is - // a vector type with a joined element type. - auto joinElementType = TryJoinTypes( - vectorType->elementType, - scalarType); - if(!joinElementType) - return nullptr; + // - return createVectorType( - joinElementType, - vectorType->elementCount); - } + struct Constraint + { + Decl* decl; // the declaration of the thing being constraints + RefPtr<Val> val; // the value to which we are constraining it + bool satisfied = false; // Has this constraint been met? + }; + + // A collection of constraints that will need to be satisified (solved) + // in order for checking to suceed. + struct ConstraintSystem + { + List<Constraint> constraints; + }; + + RefPtr<ExpressionType> TryJoinVectorAndScalarType( + RefPtr<VectorExpressionType> vectorType, + RefPtr<BasicExpressionType> scalarType) + { + // Join( vector<T,N>, S ) -> vetor<Join(T,S), N> + // + // That is, the join of a vector and a scalar type is + // a vector type with a joined element type. + auto joinElementType = TryJoinTypes( + vectorType->elementType, + scalarType); + if(!joinElementType) + return nullptr; - bool DoesTypeConformToInterface( - RefPtr<ExpressionType> type, - InterfaceDeclRef interfaceDeclRef) + return createVectorType( + joinElementType, + vectorType->elementCount); + } + + bool DoesTypeConformToInterface( + RefPtr<ExpressionType> type, + InterfaceDeclRef interfaceDeclRef) + { + // for now look up a conformance member... + if(auto declRefType = type->As<DeclRefType>()) { - // for now look up a conformance member... - if(auto declRefType = type->As<DeclRefType>()) + if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) { - if( auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>() ) + for( auto inheritanceDeclRef : aggTypeDeclRef.GetMembersOfType<InheritanceDeclRef>()) { - for( auto inheritanceDeclRef : aggTypeDeclRef.GetMembersOfType<InheritanceDeclRef>()) - { - EnsureDecl(inheritanceDeclRef.GetDecl()); + EnsureDecl(inheritanceDeclRef.GetDecl()); - auto inheritedDeclRefType = inheritanceDeclRef.getBaseType()->As<DeclRefType>(); - if (!inheritedDeclRefType) - continue; + auto inheritedDeclRefType = inheritanceDeclRef.getBaseType()->As<DeclRefType>(); + if (!inheritedDeclRefType) + continue; - if(interfaceDeclRef.Equals(inheritedDeclRefType->declRef)) - return true; - } + if(interfaceDeclRef.Equals(inheritedDeclRefType->declRef)) + return true; } } - - // default is failure - return false; } - RefPtr<ExpressionType> TryJoinTypeWithInterface( - RefPtr<ExpressionType> type, - InterfaceDeclRef interfaceDeclRef) - { - // The most basic test here should be: does the type declare conformance to the trait. - if(DoesTypeConformToInterface(type, interfaceDeclRef)) - return type; + // default is failure + return false; + } - // There is a more nuanced case if `type` is a builtin type, and we need to make it - // conform to a trait that some but not all builtin types support (the main problem - // here is when an operation wants an integer type, but one of our operands is a `float`. - // The HLSL rules will allow that, with implicit conversion, but our default join rules - // will end up picking `float` and we don't want that...). + RefPtr<ExpressionType> TryJoinTypeWithInterface( + RefPtr<ExpressionType> type, + InterfaceDeclRef interfaceDeclRef) + { + // The most basic test here should be: does the type declare conformance to the trait. + if(DoesTypeConformToInterface(type, interfaceDeclRef)) + return type; - // For now we don't handle the hard case and just bail - return nullptr; - } + // There is a more nuanced case if `type` is a builtin type, and we need to make it + // conform to a trait that some but not all builtin types support (the main problem + // here is when an operation wants an integer type, but one of our operands is a `float`. + // The HLSL rules will allow that, with implicit conversion, but our default join rules + // will end up picking `float` and we don't want that...). - // Try to compute the "join" between two types - RefPtr<ExpressionType> TryJoinTypes( - RefPtr<ExpressionType> left, - RefPtr<ExpressionType> right) - { - // Easy case: they are the same type! - if (left->Equals(right)) - return left; + // For now we don't handle the hard case and just bail + return nullptr; + } - // We can join two basic types by picking the "better" of the two - if (auto leftBasic = left->As<BasicExpressionType>()) - { - if (auto rightBasic = right->As<BasicExpressionType>()) - { - auto leftFlavor = leftBasic->BaseType; - auto rightFlavor = rightBasic->BaseType; + // Try to compute the "join" between two types + RefPtr<ExpressionType> TryJoinTypes( + RefPtr<ExpressionType> left, + RefPtr<ExpressionType> right) + { + // Easy case: they are the same type! + if (left->Equals(right)) + return left; - // TODO(tfoley): Need a special-case rule here that if - // either operand is of type `half`, then we promote - // to at least `float` + // We can join two basic types by picking the "better" of the two + if (auto leftBasic = left->As<BasicExpressionType>()) + { + if (auto rightBasic = right->As<BasicExpressionType>()) + { + auto leftFlavor = leftBasic->BaseType; + auto rightFlavor = rightBasic->BaseType; - // Return the one that had higher rank... - if (leftFlavor > rightFlavor) - return left; - else - { - assert(rightFlavor > leftFlavor); - return right; - } - } + // TODO(tfoley): Need a special-case rule here that if + // either operand is of type `half`, then we promote + // to at least `float` - // We can also join a vector and a scalar - if(auto rightVector = right->As<VectorExpressionType>()) + // Return the one that had higher rank... + if (leftFlavor > rightFlavor) + return left; + else { - return TryJoinVectorAndScalarType(rightVector, leftBasic); + assert(rightFlavor > leftFlavor); + return right; } } - // We can join two vector types by joining their element types - // (and also their sizes...) - if( auto leftVector = left->As<VectorExpressionType>()) + // We can also join a vector and a scalar + if(auto rightVector = right->As<VectorExpressionType>()) { - if(auto rightVector = right->As<VectorExpressionType>()) - { - // Check if the vector sizes match - if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) - return nullptr; - - // Try to join the element types - auto joinElementType = TryJoinTypes( - leftVector->elementType, - rightVector->elementType); - if(!joinElementType) - return nullptr; - - return createVectorType( - joinElementType, - leftVector->elementCount); - } - - // We can also join a vector and a scalar - if(auto rightBasic = right->As<BasicExpressionType>()) - { - return TryJoinVectorAndScalarType(leftVector, rightBasic); - } + return TryJoinVectorAndScalarType(rightVector, leftBasic); } + } - // HACK: trying to work trait types in here... - if(auto leftDeclRefType = left->As<DeclRefType>()) + // We can join two vector types by joining their element types + // (and also their sizes...) + if( auto leftVector = left->As<VectorExpressionType>()) + { + if(auto rightVector = right->As<VectorExpressionType>()) { - if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDeclRef>() ) - { - // - return TryJoinTypeWithInterface(right, leftInterfaceRef); - } + // Check if the vector sizes match + if(!leftVector->elementCount->EqualsVal(rightVector->elementCount.Ptr())) + return nullptr; + + // Try to join the element types + auto joinElementType = TryJoinTypes( + leftVector->elementType, + rightVector->elementType); + if(!joinElementType) + return nullptr; + + return createVectorType( + joinElementType, + leftVector->elementCount); } - if(auto rightDeclRefType = right->As<DeclRefType>()) + + // We can also join a vector and a scalar + if(auto rightBasic = right->As<BasicExpressionType>()) { - if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDeclRef>() ) - { - // - return TryJoinTypeWithInterface(left, rightInterfaceRef); - } + return TryJoinVectorAndScalarType(leftVector, rightBasic); } - - // TODO: all the cases for vectors apply to matrices too! - - // Default case is that we just fail. - return nullptr; } - // Try to solve a system of generic constraints. - // The `system` argument provides the constraints. - // The `varSubst` argument provides the list of constraint - // variables that were created for the system. - // - // Returns a new substitution representing the values that - // we solved for along the way. - RefPtr<Substitutions> TrySolveConstraintSystem( - ConstraintSystem* system, - GenericDeclRef genericDeclRef) + // HACK: trying to work trait types in here... + if(auto leftDeclRefType = left->As<DeclRefType>()) { - // For now the "solver" is going to be ridiculously simplistic. - - // The generic itself will have some constraints, so we need to try and solve those too - for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() ) + if( auto leftInterfaceRef = leftDeclRefType->declRef.As<InterfaceDeclRef>() ) { - if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup())) - return nullptr; + // + return TryJoinTypeWithInterface(right, leftInterfaceRef); } + } + if(auto rightDeclRefType = right->As<DeclRefType>()) + { + if( auto rightInterfaceRef = rightDeclRefType->declRef.As<InterfaceDeclRef>() ) + { + // + return TryJoinTypeWithInterface(left, rightInterfaceRef); + } + } + + // TODO: all the cases for vectors apply to matrices too! + + // Default case is that we just fail. + return nullptr; + } + + // Try to solve a system of generic constraints. + // The `system` argument provides the constraints. + // The `varSubst` argument provides the list of constraint + // variables that were created for the system. + // + // Returns a new substitution representing the values that + // we solved for along the way. + RefPtr<Substitutions> TrySolveConstraintSystem( + ConstraintSystem* system, + GenericDeclRef genericDeclRef) + { + // For now the "solver" is going to be ridiculously simplistic. + + // The generic itself will have some constraints, so we need to try and solve those too + for( auto constraintDeclRef : genericDeclRef.GetMembersOfType<GenericTypeConstraintDeclRef>() ) + { + if(!TryUnifyTypes(*system, constraintDeclRef.GetSub(), constraintDeclRef.GetSup())) + return nullptr; + } - // We will loop over the generic parameters, and for - // each we will try to find a way to satisfy all - // the constraints for that parameter - List<RefPtr<Val>> args; - for (auto m : genericDeclRef.GetMembers()) + // We will loop over the generic parameters, and for + // each we will try to find a way to satisfy all + // the constraints for that parameter + List<RefPtr<Val>> args; + for (auto m : genericDeclRef.GetMembers()) + { + if (auto typeParam = m.As<GenericTypeParamDeclRef>()) { - if (auto typeParam = m.As<GenericTypeParamDeclRef>()) + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) { - RefPtr<ExpressionType> type = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != typeParam.GetDecl()) - continue; + if (c.decl != typeParam.GetDecl()) + continue; - auto cType = c.val.As<ExpressionType>(); - assert(cType.Ptr()); + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); - if (!type) - { - type = cType; - } - else + if (!type) + { + type = cType; + } + else + { + auto joinType = TryJoinTypes(type, cType); + if (!joinType) { - auto joinType = TryJoinTypes(type, cType); - if (!joinType) - { - // failure! - return nullptr; - } - type = joinType; + // failure! + return nullptr; } - - c.satisfied = true; + type = joinType; } - if (!type) - { - // failure! - return nullptr; - } - args.Add(type); + c.satisfied = true; } - else if (auto valParam = m.As<GenericValueParamDeclRef>()) + + if (!type) { - // TODO(tfoley): maybe support more than integers some day? - // TODO(tfoley): figure out how this needs to interact with - // compile-time integers that aren't just constants... - RefPtr<IntVal> val = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != valParam.GetDecl()) - continue; + // failure! + return nullptr; + } + args.Add(type); + } + else if (auto valParam = m.As<GenericValueParamDeclRef>()) + { + // TODO(tfoley): maybe support more than integers some day? + // TODO(tfoley): figure out how this needs to interact with + // compile-time integers that aren't just constants... + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valParam.GetDecl()) + continue; - auto cVal = c.val.As<IntVal>(); - assert(cVal.Ptr()); + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); - if (!val) - { - val = cVal; - } - else + if (!val) + { + val = cVal; + } + else + { + if(!val->EqualsVal(cVal.Ptr())) { - if(!val->EqualsVal(cVal.Ptr())) - { - // failure! - return nullptr; - } + // failure! + return nullptr; } - - c.satisfied = true; } - if (!val) - { - // failure! - return nullptr; - } - args.Add(val); + c.satisfied = true; } - else + + if (!val) { - // ignore anything that isn't a generic parameter + // failure! + return nullptr; } + args.Add(val); } + else + { + // ignore anything that isn't a generic parameter + } + } - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) { - if (!c.satisfied) - { - return nullptr; - } + return nullptr; } + } - // Consruct a reference to the extension with our constraint variables - // as the - RefPtr<Substitutions> solvedSubst = new Substitutions(); - solvedSubst->genericDecl = genericDeclRef.GetDecl(); - solvedSubst->outer = genericDeclRef.substitutions; - solvedSubst->args = args; + // Consruct a reference to the extension with our constraint variables + // as the + RefPtr<Substitutions> solvedSubst = new Substitutions(); + solvedSubst->genericDecl = genericDeclRef.GetDecl(); + solvedSubst->outer = genericDeclRef.substitutions; + solvedSubst->args = args; - return solvedSubst; + return solvedSubst; #if 0 - List<RefPtr<Val>> solvedArgs; - for (auto varArg : varSubst->args) + List<RefPtr<Val>> solvedArgs; + for (auto varArg : varSubst->args) + { + if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr())) { - if (auto typeVar = dynamic_cast<ConstraintVarType*>(varArg.Ptr())) + RefPtr<ExpressionType> type = nullptr; + for (auto& c : system->constraints) { - RefPtr<ExpressionType> type = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != typeVar->declRef.GetDecl()) - continue; + if (c.decl != typeVar->declRef.GetDecl()) + continue; - auto cType = c.val.As<ExpressionType>(); - assert(cType.Ptr()); + auto cType = c.val.As<ExpressionType>(); + assert(cType.Ptr()); - if (!type) - { - type = cType; - } - else + if (!type) + { + type = cType; + } + else + { + if (!type->Equals(cType)) { - if (!type->Equals(cType)) - { - // failure! - return nullptr; - } + // failure! + return nullptr; } - - c.satisfied = true; } - if (!type) - { - // failure! - return nullptr; - } - solvedArgs.Add(type); + c.satisfied = true; } - else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr())) + + if (!type) { - // TODO(tfoley): maybe support more than integers some day? - RefPtr<IntVal> val = nullptr; - for (auto& c : system->constraints) - { - if (c.decl != valueVar->declRef.GetDecl()) - continue; + // failure! + return nullptr; + } + solvedArgs.Add(type); + } + else if (auto valueVar = dynamic_cast<ConstraintVarInt*>(varArg.Ptr())) + { + // TODO(tfoley): maybe support more than integers some day? + RefPtr<IntVal> val = nullptr; + for (auto& c : system->constraints) + { + if (c.decl != valueVar->declRef.GetDecl()) + continue; - auto cVal = c.val.As<IntVal>(); - assert(cVal.Ptr()); + auto cVal = c.val.As<IntVal>(); + assert(cVal.Ptr()); - if (!val) + if (!val) + { + val = cVal; + } + else + { + if (val->value != cVal->value) { - val = cVal; + // failure! + return nullptr; } - else - { - if (val->value != cVal->value) - { - // failure! - return nullptr; - } - } - - c.satisfied = true; } - if (!val) - { - // failure! - return nullptr; - } - solvedArgs.Add(val); + c.satisfied = true; } - else + + if (!val) { - // ignore anything that isn't a generic parameter + // failure! + return nullptr; } + solvedArgs.Add(val); } + else + { + // ignore anything that isn't a generic parameter + } + } - // Make sure we haven't constructed any spurious constraints - // that we aren't able to satisfy: - for (auto c : system->constraints) + // Make sure we haven't constructed any spurious constraints + // that we aren't able to satisfy: + for (auto c : system->constraints) + { + if (!c.satisfied) { - if (!c.satisfied) - { - return nullptr; - } + return nullptr; } + } - RefPtr<Substitutions> newSubst = new Substitutions(); - newSubst->genericDecl = varSubst->genericDecl; - newSubst->outer = varSubst->outer; - newSubst->args = solvedArgs; - return newSubst; + RefPtr<Substitutions> newSubst = new Substitutions(); + newSubst->genericDecl = varSubst->genericDecl; + newSubst->outer = varSubst->outer; + newSubst->args = solvedArgs; + return newSubst; #endif - } + } - // + // - struct OverloadCandidate + struct OverloadCandidate + { + enum class Flavor { - enum class Flavor - { - Func, - Generic, - UnspecializedGeneric, - }; - Flavor flavor; + Func, + Generic, + UnspecializedGeneric, + }; + Flavor flavor; - enum class Status - { - GenericArgumentInferenceFailed, - Unchecked, - ArityChecked, - FixityChecked, - TypeChecked, - Appicable, - }; - Status status = Status::Unchecked; - - // Reference to the declaration being applied - LookupResultItem item; - - // The type of the result expression if this candidate is selected - RefPtr<ExpressionType> resultType; - - // A system for tracking constraints introduced on generic parameters - ConstraintSystem constraintSystem; - - // How much conversion cost should be considered for this overload, - // when ranking candidates. - ConversionCost conversionCostSum = kConversionCost_None; + enum class Status + { + GenericArgumentInferenceFailed, + Unchecked, + ArityChecked, + FixityChecked, + TypeChecked, + Appicable, }; + Status status = Status::Unchecked; + // Reference to the declaration being applied + LookupResultItem item; + // The type of the result expression if this candidate is selected + RefPtr<ExpressionType> resultType; - // State related to overload resolution for a call - // to an overloaded symbol - struct OverloadResolveContext - { - enum class Mode - { - // We are just checking if a candidate works or not - JustTrying, + // A system for tracking constraints introduced on generic parameters + ConstraintSystem constraintSystem; - // We want to actually update the AST for a chosen candidate - ForReal, - }; + // How much conversion cost should be considered for this overload, + // when ranking candidates. + ConversionCost conversionCostSum = kConversionCost_None; + }; - RefPtr<AppExprBase> appExpr; - RefPtr<ExpressionSyntaxNode> baseExpr; - // Are we still trying out candidates, or are we - // checking the chosen one for real? - Mode mode = Mode::JustTrying; - // We store one candidate directly, so that we don't - // need to do dynamic allocation on the list every time - OverloadCandidate bestCandidateStorage; - OverloadCandidate* bestCandidate = nullptr; + // State related to overload resolution for a call + // to an overloaded symbol + struct OverloadResolveContext + { + enum class Mode + { + // We are just checking if a candidate works or not + JustTrying, - // Full list of all candidates being considered, in the ambiguous case - List<OverloadCandidate> bestCandidates; + // We want to actually update the AST for a chosen candidate + ForReal, }; - struct ParamCounts - { - int required; - int allowed; - }; + RefPtr<AppExprBase> appExpr; + RefPtr<ExpressionSyntaxNode> baseExpr; - // count the number of parameters required/allowed for a callable - ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params) + // Are we still trying out candidates, or are we + // checking the chosen one for real? + Mode mode = Mode::JustTrying; + + // We store one candidate directly, so that we don't + // need to do dynamic allocation on the list every time + OverloadCandidate bestCandidateStorage; + OverloadCandidate* bestCandidate = nullptr; + + // Full list of all candidates being considered, in the ambiguous case + List<OverloadCandidate> bestCandidates; + }; + + struct ParamCounts + { + int required; + int allowed; + }; + + // count the number of parameters required/allowed for a callable + ParamCounts CountParameters(FilteredMemberRefList<ParamDeclRef> params) + { + ParamCounts counts = { 0, 0 }; + for (auto param : params) { - ParamCounts counts = { 0, 0 }; - for (auto param : params) - { - counts.allowed++; + counts.allowed++; - // No initializer means no default value - // - // TODO(tfoley): The logic here is currently broken in two ways: - // - // 1. We are assuming that once one parameter has a default, then all do. - // This can/should be validated earlier, so that we can assume it here. - // - // 2. We are not handling the possibility of multiple declarations for - // a single function, where we'd need to merge default parameters across - // all the declarations. - if (!param.GetDecl()->Expr) - { - counts.required++; - } + // No initializer means no default value + // + // TODO(tfoley): The logic here is currently broken in two ways: + // + // 1. We are assuming that once one parameter has a default, then all do. + // This can/should be validated earlier, so that we can assume it here. + // + // 2. We are not handling the possibility of multiple declarations for + // a single function, where we'd need to merge default parameters across + // all the declarations. + if (!param.GetDecl()->Expr) + { + counts.required++; } - return counts; } + return counts; + } - // count the number of parameters required/allowed for a generic - ParamCounts CountParameters(GenericDeclRef genericRef) + // count the number of parameters required/allowed for a generic + ParamCounts CountParameters(GenericDeclRef genericRef) + { + ParamCounts counts = { 0, 0 }; + for (auto m : genericRef.GetDecl()->Members) { - ParamCounts counts = { 0, 0 }; - for (auto m : genericRef.GetDecl()->Members) + if (auto typeParam = m.As<GenericTypeParamDecl>()) { - if (auto typeParam = m.As<GenericTypeParamDecl>()) + counts.allowed++; + if (!typeParam->initType.Ptr()) { - counts.allowed++; - if (!typeParam->initType.Ptr()) - { - counts.required++; - } + counts.required++; } - else if (auto valParam = m.As<GenericValueParamDecl>()) + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + counts.allowed++; + if (!valParam->Expr) { - counts.allowed++; - if (!valParam->Expr) - { - counts.required++; - } + counts.required++; } } - return counts; } + return counts; + } - bool TryCheckOverloadCandidateArity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) + bool TryCheckOverloadCandidateArity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + int argCount = context.appExpr->Arguments.Count(); + ParamCounts paramCounts = { 0, 0 }; + switch (candidate.flavor) { - int argCount = context.appExpr->Arguments.Count(); - ParamCounts paramCounts = { 0, 0 }; - switch (candidate.flavor) - { - case OverloadCandidate::Flavor::Func: - paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters()); - break; + case OverloadCandidate::Flavor::Func: + paramCounts = CountParameters(candidate.item.declRef.As<CallableDeclRef>().GetParameters()); + break; - case OverloadCandidate::Flavor::Generic: - paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>()); - break; + case OverloadCandidate::Flavor::Generic: + paramCounts = CountParameters(candidate.item.declRef.As<GenericDeclRef>()); + break; - default: - assert(!"unexpected"); - break; - } + default: + assert(!"unexpected"); + break; + } - if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) - return true; + if (argCount >= paramCounts.required && argCount <= paramCounts.allowed) + return true; - // Emit an error message if we are checking this call for real - if (context.mode != OverloadResolveContext::Mode::JustTrying) + // Emit an error message if we are checking this call for real + if (context.mode != OverloadResolveContext::Mode::JustTrying) + { + if (argCount < paramCounts.required) { - if (argCount < paramCounts.required) - { - getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required); - } - else - { - assert(argCount > paramCounts.allowed); - getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); - } + getSink()->diagnose(context.appExpr, Diagnostics::notEnoughArguments, argCount, paramCounts.required); + } + else + { + assert(argCount > paramCounts.allowed); + getSink()->diagnose(context.appExpr, Diagnostics::tooManyArguments, argCount, paramCounts.allowed); } - - return false; } - bool TryCheckOverloadCandidateFixity( - OverloadResolveContext& context, - OverloadCandidate const& candidate) - { - auto expr = context.appExpr; + return false; + } - auto decl = candidate.item.declRef.decl; + bool TryCheckOverloadCandidateFixity( + OverloadResolveContext& context, + OverloadCandidate const& candidate) + { + auto expr = context.appExpr; - if(auto prefixExpr = expr.As<PrefixExpr>()) - { - if(decl->HasModifier<PrefixModifier>()) - return true; + auto decl = candidate.item.declRef.decl; - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + if(auto prefixExpr = expr.As<PrefixExpr>()) + { + if(decl->HasModifier<PrefixModifier>()) + return true; - return false; - } - else if(auto postfixExpr = expr.As<PostfixExpr>()) + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - if(decl->HasModifier<PostfixModifier>()) - return true; + getSink()->diagnose(context.appExpr, Diagnostics::expectedPrefixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); + } - if (context.mode != OverloadResolveContext::Mode::JustTrying) - { - getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator); - getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); - } + return false; + } + else if(auto postfixExpr = expr.As<PostfixExpr>()) + { + if(decl->HasModifier<PostfixModifier>()) + return true; - return false; - } - else + if (context.mode != OverloadResolveContext::Mode::JustTrying) { - return true; + getSink()->diagnose(context.appExpr, Diagnostics::expectedPostfixOperator); + getSink()->diagnose(decl, Diagnostics::seeDefinitionOf, decl->getName()); } return false; } - - bool TryCheckGenericOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) + else { - auto& args = context.appExpr->Arguments; + return true; + } + + return false; + } + + bool TryCheckGenericOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; - auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>(); + auto genericDeclRef = candidate.item.declRef.As<GenericDeclRef>(); - int aa = 0; - for (auto memberRef : genericDeclRef.GetMembers()) + int aa = 0; + for (auto memberRef : genericDeclRef.GetMembers()) + { + if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>()) { - if (auto typeParamRef = memberRef.As<GenericTypeParamDeclRef>()) - { - auto arg = args[aa++]; + auto arg = args[aa++]; - if (context.mode == OverloadResolveContext::Mode::JustTrying) - { - if (!CanCoerceToProperType(TypeExp(arg))) - { - return false; - } - } - else - { - TypeExp typeExp = CoerceToProperType(TypeExp(arg)); - } - } - else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>()) + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - auto arg = args[aa++]; - - if (context.mode == OverloadResolveContext::Mode::JustTrying) + if (!CanCoerceToProperType(TypeExp(arg))) { - ConversionCost cost = kConversionCost_None; - if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost)) - { - return false; - } - candidate.conversionCostSum += cost; - } - else - { - arg = Coerce(valParamRef.GetType(), arg); - auto val = ExtractGenericArgInteger(arg); + return false; } } else { - continue; + TypeExp typeExp = CoerceToProperType(TypeExp(arg)); } } - - return true; - } - - bool TryCheckOverloadCandidateTypes( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - auto& args = context.appExpr->Arguments; - int argCount = args.Count(); - - List<ParamDeclRef> params; - switch (candidate.flavor) + else if (auto valParamRef = memberRef.As<GenericValueParamDeclRef>()) { - case OverloadCandidate::Flavor::Func: - params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray(); - break; - - case OverloadCandidate::Flavor::Generic: - return TryCheckGenericOverloadCandidateTypes(context, candidate); - - default: - assert(!"unexpected"); - break; - } - - // Note(tfoley): We might have fewer arguments than parameters in the - // case where one or more parameters had defaults. - assert(argCount <= params.Count()); - - for (int ii = 0; ii < argCount; ++ii) - { - auto& arg = args[ii]; - auto param = params[ii]; + auto arg = args[aa++]; if (context.mode == OverloadResolveContext::Mode::JustTrying) { ConversionCost cost = kConversionCost_None; - if (!CanCoerce(param.GetType(), arg->Type, &cost)) + if (!CanCoerce(valParamRef.GetType(), arg->Type, &cost)) { return false; } @@ -3177,531 +3120,532 @@ namespace Slang } else { - arg = Coerce(param.GetType(), arg); + arg = Coerce(valParamRef.GetType(), arg); + auto val = ExtractGenericArgInteger(arg); } } - return true; + else + { + continue; + } } - bool TryCheckOverloadCandidateDirections( - OverloadResolveContext& /*context*/, - OverloadCandidate const& /*candidate*/) - { - // TODO(tfoley): check `in` and `out` markers, as needed. - return true; - } + return true; + } - // Try to check an overload candidate, but bail out - // if any step fails - void TryCheckOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - if (!TryCheckOverloadCandidateArity(context, candidate)) - return; + bool TryCheckOverloadCandidateTypes( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + auto& args = context.appExpr->Arguments; + int argCount = args.Count(); - candidate.status = OverloadCandidate::Status::ArityChecked; - if (!TryCheckOverloadCandidateFixity(context, candidate)) - return; + List<ParamDeclRef> params; + switch (candidate.flavor) + { + case OverloadCandidate::Flavor::Func: + params = candidate.item.declRef.As<CallableDeclRef>().GetParameters().ToArray(); + break; - candidate.status = OverloadCandidate::Status::FixityChecked; - if (!TryCheckOverloadCandidateTypes(context, candidate)) - return; + case OverloadCandidate::Flavor::Generic: + return TryCheckGenericOverloadCandidateTypes(context, candidate); - candidate.status = OverloadCandidate::Status::TypeChecked; - if (!TryCheckOverloadCandidateDirections(context, candidate)) - return; - - candidate.status = OverloadCandidate::Status::Appicable; + default: + assert(!"unexpected"); + break; } - // Create the representation of a given generic applied to some arguments - RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef( - RefPtr<ExpressionSyntaxNode> baseExpr, - RefPtr<AppExprBase> appExpr) + // Note(tfoley): We might have fewer arguments than parameters in the + // case where one or more parameters had defaults. + assert(argCount <= params.Count()); + + for (int ii = 0; ii < argCount; ++ii) { - auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); - if (!baseDeclRefExpr) + auto& arg = args[ii]; + auto param = params[ii]; + + if (context.mode == OverloadResolveContext::Mode::JustTrying) { - assert(!"unexpected"); - return CreateErrorExpr(appExpr.Ptr()); + ConversionCost cost = kConversionCost_None; + if (!CanCoerce(param.GetType(), arg->Type, &cost)) + { + return false; + } + candidate.conversionCostSum += cost; } - auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>(); - if (!baseGenericRef) + else { - assert(!"unexpected"); - return CreateErrorExpr(appExpr.Ptr()); + arg = Coerce(param.GetType(), arg); } + } + return true; + } - RefPtr<Substitutions> subst = new Substitutions(); - subst->genericDecl = baseGenericRef.GetDecl(); - subst->outer = baseGenericRef.substitutions; + bool TryCheckOverloadCandidateDirections( + OverloadResolveContext& /*context*/, + OverloadCandidate const& /*candidate*/) + { + // TODO(tfoley): check `in` and `out` markers, as needed. + return true; + } - for (auto arg : appExpr->Arguments) - { - subst->args.Add(ExtractGenericArgVal(arg)); - } + // Try to check an overload candidate, but bail out + // if any step fails + void TryCheckOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + if (!TryCheckOverloadCandidateArity(context, candidate)) + return; - DeclRef innerDeclRef(baseGenericRef.GetInner(), subst); + candidate.status = OverloadCandidate::Status::ArityChecked; + if (!TryCheckOverloadCandidateFixity(context, candidate)) + return; - return ConstructDeclRefExpr( - innerDeclRef, - nullptr, - appExpr); + candidate.status = OverloadCandidate::Status::FixityChecked; + if (!TryCheckOverloadCandidateTypes(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::TypeChecked; + if (!TryCheckOverloadCandidateDirections(context, candidate)) + return; + + candidate.status = OverloadCandidate::Status::Appicable; + } + + // Create the representation of a given generic applied to some arguments + RefPtr<ExpressionSyntaxNode> CreateGenericDeclRef( + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<AppExprBase> appExpr) + { + auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); + if (!baseDeclRefExpr) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); + } + auto baseGenericRef = baseDeclRefExpr->declRef.As<GenericDeclRef>(); + if (!baseGenericRef) + { + assert(!"unexpected"); + return CreateErrorExpr(appExpr.Ptr()); } - // Take an overload candidate that previously got through - // `TryCheckOverloadCandidate` above, and try to finish - // up the work and turn it into a real expression. - // - // If the candidate isn't actually applicable, this is - // where we'd start reporting the issue(s). - RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // special case for generic argument inference failure - if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) - { - String callString = GetCallSignatureString(context.appExpr); - getSink()->diagnose( - context.appExpr, - Diagnostics::genericArgumentInferenceFailed, - callString); + RefPtr<Substitutions> subst = new Substitutions(); + subst->genericDecl = baseGenericRef.GetDecl(); + subst->outer = baseGenericRef.substitutions; - String declString = getDeclSignatureString(candidate.item); - getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); - goto error; - } + for (auto arg : appExpr->Arguments) + { + subst->args.Add(ExtractGenericArgVal(arg)); + } - context.mode = OverloadResolveContext::Mode::ForReal; - context.appExpr->Type = ExpressionType::Error; + DeclRef innerDeclRef(baseGenericRef.GetInner(), subst); - if (!TryCheckOverloadCandidateArity(context, candidate)) - goto error; + return ConstructDeclRefExpr( + innerDeclRef, + nullptr, + appExpr); + } - if (!TryCheckOverloadCandidateFixity(context, candidate)) - goto error; + // Take an overload candidate that previously got through + // `TryCheckOverloadCandidate` above, and try to finish + // up the work and turn it into a real expression. + // + // If the candidate isn't actually applicable, this is + // where we'd start reporting the issue(s). + RefPtr<ExpressionSyntaxNode> CompleteOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // special case for generic argument inference failure + if (candidate.status == OverloadCandidate::Status::GenericArgumentInferenceFailed) + { + String callString = GetCallSignatureString(context.appExpr); + getSink()->diagnose( + context.appExpr, + Diagnostics::genericArgumentInferenceFailed, + callString); - if (!TryCheckOverloadCandidateTypes(context, candidate)) - goto error; + String declString = getDeclSignatureString(candidate.item); + getSink()->diagnose(candidate.item.declRef, Diagnostics::genericSignatureTried, declString); + goto error; + } - if (!TryCheckOverloadCandidateDirections(context, candidate)) - goto error; + context.mode = OverloadResolveContext::Mode::ForReal; + context.appExpr->Type = ExpressionType::Error; + if (!TryCheckOverloadCandidateArity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateFixity(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateTypes(context, candidate)) + goto error; + + if (!TryCheckOverloadCandidateDirections(context, candidate)) + goto error; + + { + auto baseExpr = ConstructLookupResultExpr( + candidate.item, context.baseExpr, context.appExpr->FunctionExpr); + + switch(candidate.flavor) { - auto baseExpr = ConstructLookupResultExpr( - candidate.item, context.baseExpr, context.appExpr->FunctionExpr); + case OverloadCandidate::Flavor::Func: + context.appExpr->FunctionExpr = baseExpr; + context.appExpr->Type = candidate.resultType; - switch(candidate.flavor) + // A call may yield an l-value, and we should take a look at the candidate to be sure + if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>()) { - case OverloadCandidate::Flavor::Func: - context.appExpr->FunctionExpr = baseExpr; - context.appExpr->Type = candidate.resultType; - - // A call may yield an l-value, and we should take a look at the candidate to be sure - if(auto subscriptDeclRef = candidate.item.declRef.As<SubscriptDeclRef>()) + for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>()) { - for(auto setter : subscriptDeclRef.GetDecl()->GetMembersOfType<SetterDecl>()) - { - context.appExpr->Type.IsLeftValue = true; - } + context.appExpr->Type.IsLeftValue = true; } + } - // TODO: there may be other cases that confer l-value-ness + // TODO: there may be other cases that confer l-value-ness - return context.appExpr; - break; + return context.appExpr; + break; - case OverloadCandidate::Flavor::Generic: - return CreateGenericDeclRef(baseExpr, context.appExpr); - break; + case OverloadCandidate::Flavor::Generic: + return CreateGenericDeclRef(baseExpr, context.appExpr); + break; - default: - assert(!"unexpected"); - break; - } + default: + assert(!"unexpected"); + break; } - - - error: - return CreateErrorExpr(context.appExpr.Ptr()); } - // Implement a comparison operation between overload candidates, - // so that the better candidate compares as less-than the other - int CompareOverloadCandidates( - OverloadCandidate* left, - OverloadCandidate* right) - { - // If one candidate got further along in validation, pick it - if (left->status != right->status) - return int(right->status) - int(left->status); - // If both candidates are applicable, then we need to compare - // the costs of their type conversion sequences - if(left->status == OverloadCandidate::Status::Appicable) - { - if (left->conversionCostSum != right->conversionCostSum) - return left->conversionCostSum - right->conversionCostSum; - } + error: + return CreateErrorExpr(context.appExpr.Ptr()); + } - return 0; - } + // Implement a comparison operation between overload candidates, + // so that the better candidate compares as less-than the other + int CompareOverloadCandidates( + OverloadCandidate* left, + OverloadCandidate* right) + { + // If one candidate got further along in validation, pick it + if (left->status != right->status) + return int(right->status) - int(left->status); - void AddOverloadCandidateInner( - OverloadResolveContext& context, - OverloadCandidate& candidate) + // If both candidates are applicable, then we need to compare + // the costs of their type conversion sequences + if(left->status == OverloadCandidate::Status::Appicable) { - // Filter our existing candidates, to remove any that are worse than our new one + if (left->conversionCostSum != right->conversionCostSum) + return left->conversionCostSum - right->conversionCostSum; + } - bool keepThisCandidate = true; // should this candidate be kept? + return 0; + } - if (context.bestCandidates.Count() != 0) - { - // We have multiple candidates right now, so filter them. - bool anyFiltered = false; - // Note that we are querying the list length on every iteration, - // because we might remove things. - for (int cc = 0; cc < context.bestCandidates.Count(); ++cc) - { - int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); - if (cmp < 0) - { - // our new candidate is better! + void AddOverloadCandidateInner( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Filter our existing candidates, to remove any that are worse than our new one - // remove it from the list (by swapping in a later one) - context.bestCandidates.FastRemoveAt(cc); - // and then reduce our index so that we re-visit the same index - --cc; + bool keepThisCandidate = true; // should this candidate be kept? - anyFiltered = true; - } - else if(cmp > 0) - { - // our candidate is worse! - keepThisCandidate = false; - } - } - // It should not be possible that we removed some existing candidate *and* - // chose not to keep this candidate (otherwise the better-ness relation - // isn't transitive). Therefore we confirm that we either chose to keep - // this candidate (in which case filtering is okay), or we didn't filter - // anything. - assert(keepThisCandidate || !anyFiltered); - } - else if(context.bestCandidate) + if (context.bestCandidates.Count() != 0) + { + // We have multiple candidates right now, so filter them. + bool anyFiltered = false; + // Note that we are querying the list length on every iteration, + // because we might remove things. + for (int cc = 0; cc < context.bestCandidates.Count(); ++cc) { - // There's only one candidate so far - int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); - if(cmp < 0) + int cmp = CompareOverloadCandidates(&candidate, &context.bestCandidates[cc]); + if (cmp < 0) { // our new candidate is better! - context.bestCandidate = nullptr; + + // remove it from the list (by swapping in a later one) + context.bestCandidates.FastRemoveAt(cc); + // and then reduce our index so that we re-visit the same index + --cc; + + anyFiltered = true; } - else if (cmp > 0) + else if(cmp > 0) { // our candidate is worse! keepThisCandidate = false; } } - - // If our candidate isn't good enough, then drop it - if (!keepThisCandidate) - return; - - // Otherwise we want to keep the candidate - if (context.bestCandidates.Count() > 0) - { - // There were already multiple candidates, and we are adding one more - context.bestCandidates.Add(candidate); - } - else if (context.bestCandidate) + // It should not be possible that we removed some existing candidate *and* + // chose not to keep this candidate (otherwise the better-ness relation + // isn't transitive). Therefore we confirm that we either chose to keep + // this candidate (in which case filtering is okay), or we didn't filter + // anything. + assert(keepThisCandidate || !anyFiltered); + } + else if(context.bestCandidate) + { + // There's only one candidate so far + int cmp = CompareOverloadCandidates(&candidate, context.bestCandidate); + if(cmp < 0) { - // There was a unique best candidate, but now we are ambiguous - context.bestCandidates.Add(*context.bestCandidate); - context.bestCandidates.Add(candidate); + // our new candidate is better! context.bestCandidate = nullptr; } - else + else if (cmp > 0) { - // This is the only candidate worthe keeping track of right now - context.bestCandidateStorage = candidate; - context.bestCandidate = &context.bestCandidateStorage; + // our candidate is worse! + keepThisCandidate = false; } } - void AddOverloadCandidate( - OverloadResolveContext& context, - OverloadCandidate& candidate) - { - // Try the candidate out, to see if it is applicable at all. - TryCheckOverloadCandidate(context, candidate); + // If our candidate isn't good enough, then drop it + if (!keepThisCandidate) + return; - // Now (potentially) add it to the set of candidate overloads to consider. - AddOverloadCandidateInner(context, candidate); + // Otherwise we want to keep the candidate + if (context.bestCandidates.Count() > 0) + { + // There were already multiple candidates, and we are adding one more + context.bestCandidates.Add(candidate); } - - void AddFuncOverloadCandidate( - LookupResultItem item, - CallableDeclRef funcDeclRef, - OverloadResolveContext& context) + else if (context.bestCandidate) { - EnsureDecl(funcDeclRef.GetDecl()); - - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = item; - candidate.resultType = funcDeclRef.GetResultType(); - - AddOverloadCandidate(context, candidate); + // There was a unique best candidate, but now we are ambiguous + context.bestCandidates.Add(*context.bestCandidate); + context.bestCandidates.Add(candidate); + context.bestCandidate = nullptr; } - - void AddFuncOverloadCandidate( - RefPtr<FuncType> /*funcType*/, - OverloadResolveContext& /*context*/) + else { -#if 0 - if (funcType->decl) - { - AddFuncOverloadCandidate(funcType->decl, context); - } - else if (funcType->Func) - { - AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); - } - else if (funcType->Component) - { - AddComponentFuncOverloadCandidate(funcType->Component, context); - } -#else - throw "unimplemented"; -#endif + // This is the only candidate worthe keeping track of right now + context.bestCandidateStorage = candidate; + context.bestCandidate = &context.bestCandidateStorage; } + } - void AddCtorOverloadCandidate( - LookupResultItem typeItem, - RefPtr<ExpressionType> type, - ConstructorDeclRef ctorDeclRef, - OverloadResolveContext& context) - { - EnsureDecl(ctorDeclRef.GetDecl()); + void AddOverloadCandidate( + OverloadResolveContext& context, + OverloadCandidate& candidate) + { + // Try the candidate out, to see if it is applicable at all. + TryCheckOverloadCandidate(context, candidate); - // `typeItem` refers to the type being constructed (the thing - // that was applied as a function) so we need to construct - // a `LookupResultItem` that refers to the constructor instead + // Now (potentially) add it to the set of candidate overloads to consider. + AddOverloadCandidateInner(context, candidate); + } - LookupResultItem ctorItem; - ctorItem.declRef = ctorDeclRef; - ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs); + void AddFuncOverloadCandidate( + LookupResultItem item, + CallableDeclRef funcDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(funcDeclRef.GetDecl()); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Func; - candidate.item = ctorItem; - candidate.resultType = type; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = item; + candidate.resultType = funcDeclRef.GetResultType(); - AddOverloadCandidate(context, candidate); - } + AddOverloadCandidate(context, candidate); + } - // If the given declaration has generic parameters, then - // return the corresponding `GenericDecl` that holds the - // parameters, etc. - GenericDecl* GetOuterGeneric(Decl* decl) + void AddFuncOverloadCandidate( + RefPtr<FuncType> /*funcType*/, + OverloadResolveContext& /*context*/) + { +#if 0 + if (funcType->decl) { - auto parentDecl = decl->ParentDecl; - if (!parentDecl) return nullptr; - auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl); - return parentGeneric; + AddFuncOverloadCandidate(funcType->decl, context); } - - // Try to find a unification for two values - bool TryUnifyVals( - ConstraintSystem& constraints, - RefPtr<Val> fst, - RefPtr<Val> snd) + else if (funcType->Func) { - // if both values are types, then unify types - if (auto fstType = fst.As<ExpressionType>()) - { - if (auto sndType = snd.As<ExpressionType>()) - { - return TryUnifyTypes(constraints, fstType, sndType); - } - } + AddFuncOverloadCandidate(funcType->Func->SyntaxNode, context); + } + else if (funcType->Component) + { + AddComponentFuncOverloadCandidate(funcType->Component, context); + } +#else + throw "unimplemented"; +#endif + } - // if both values are constant integers, then compare them - if (auto fstIntVal = fst.As<ConstantIntVal>()) - { - if (auto sndIntVal = snd.As<ConstantIntVal>()) - { - return fstIntVal->value == sndIntVal->value; - } - } + void AddCtorOverloadCandidate( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + ConstructorDeclRef ctorDeclRef, + OverloadResolveContext& context) + { + EnsureDecl(ctorDeclRef.GetDecl()); - // Check if both are integer values in general - if (auto fstInt = fst.As<IntVal>()) - { - if (auto sndInt = snd.As<IntVal>()) - { - auto fstParam = fstInt.As<GenericParamIntVal>(); - auto sndParam = sndInt.As<GenericParamIntVal>(); + // `typeItem` refers to the type being constructed (the thing + // that was applied as a function) so we need to construct + // a `LookupResultItem` that refers to the constructor instead - if (fstParam) - TryUnifyIntParam(constraints, fstParam->declRef, sndInt); - if (sndParam) - TryUnifyIntParam(constraints, sndParam->declRef, fstInt); + LookupResultItem ctorItem; + ctorItem.declRef = ctorDeclRef; + ctorItem.breadcrumbs = new LookupResultItem::Breadcrumb(LookupResultItem::Breadcrumb::Kind::Member, typeItem.declRef, typeItem.breadcrumbs); - if (fstParam || sndParam) - return true; - } - } + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Func; + candidate.item = ctorItem; + candidate.resultType = type; - throw "unimplemented"; + AddOverloadCandidate(context, candidate); + } - // default: fail - return false; - } + // If the given declaration has generic parameters, then + // return the corresponding `GenericDecl` that holds the + // parameters, etc. + GenericDecl* GetOuterGeneric(Decl* decl) + { + auto parentDecl = decl->ParentDecl; + if (!parentDecl) return nullptr; + auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl); + return parentGeneric; + } - bool TryUnifySubstitutions( - ConstraintSystem& constraints, - RefPtr<Substitutions> fst, - RefPtr<Substitutions> snd) + // Try to find a unification for two values + bool TryUnifyVals( + ConstraintSystem& constraints, + RefPtr<Val> fst, + RefPtr<Val> snd) + { + // if both values are types, then unify types + if (auto fstType = fst.As<ExpressionType>()) { - // They must both be NULL or non-NULL - if (!fst || !snd) - return fst == snd; - - // They must be specializing the same generic - if (fst->genericDecl != snd->genericDecl) - return false; - - // Their arguments must unify - assert(fst->args.Count() == snd->args.Count()); - int argCount = fst->args.Count(); - for (int aa = 0; aa < argCount; ++aa) + if (auto sndType = snd.As<ExpressionType>()) { - if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) - return false; + return TryUnifyTypes(constraints, fstType, sndType); } + } - // Their "base" specializations must unify - if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) - return false; - - return true; + // if both values are constant integers, then compare them + if (auto fstIntVal = fst.As<ConstantIntVal>()) + { + if (auto sndIntVal = snd.As<ConstantIntVal>()) + { + return fstIntVal->value == sndIntVal->value; + } } - bool TryUnifyTypeParam( - ConstraintSystem& constraints, - RefPtr<GenericTypeParamDecl> typeParamDecl, - RefPtr<ExpressionType> type) + // Check if both are integer values in general + if (auto fstInt = fst.As<IntVal>()) { - // We want to constrain the given type parameter - // to equal the given type. - Constraint constraint; - constraint.decl = typeParamDecl.Ptr(); - constraint.val = type; + if (auto sndInt = snd.As<IntVal>()) + { + auto fstParam = fstInt.As<GenericParamIntVal>(); + auto sndParam = sndInt.As<GenericParamIntVal>(); - constraints.constraints.Add(constraint); + if (fstParam) + TryUnifyIntParam(constraints, fstParam->declRef, sndInt); + if (sndParam) + TryUnifyIntParam(constraints, sndParam->declRef, fstInt); - return true; + if (fstParam || sndParam) + return true; + } } - bool TryUnifyIntParam( - ConstraintSystem& constraints, - RefPtr<GenericValueParamDecl> paramDecl, - RefPtr<IntVal> val) - { - // We want to constrain the given parameter to equal the given value. - Constraint constraint; - constraint.decl = paramDecl.Ptr(); - constraint.val = val; + throw "unimplemented"; - constraints.constraints.Add(constraint); + // default: fail + return false; + } - return true; - } + bool TryUnifySubstitutions( + ConstraintSystem& constraints, + RefPtr<Substitutions> fst, + RefPtr<Substitutions> snd) + { + // They must both be NULL or non-NULL + if (!fst || !snd) + return fst == snd; + + // They must be specializing the same generic + if (fst->genericDecl != snd->genericDecl) + return false; - bool TryUnifyIntParam( - ConstraintSystem& constraints, - VarDeclBaseRef const& varRef, - RefPtr<IntVal> val) + // Their arguments must unify + assert(fst->args.Count() == snd->args.Count()); + int argCount = fst->args.Count(); + for (int aa = 0; aa < argCount; ++aa) { - if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>()) - { - return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val); - } - else - { + if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) return false; - } } - bool TryUnifyTypesByStructuralMatch( - ConstraintSystem& constraints, - RefPtr<ExpressionType> fst, - RefPtr<ExpressionType> snd) - { - if (auto fstDeclRefType = fst->As<DeclRefType>()) - { - auto fstDeclRef = fstDeclRefType->declRef; + // Their "base" specializations must unify + if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) + return false; + + return true; + } - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); + bool TryUnifyTypeParam( + ConstraintSystem& constraints, + RefPtr<GenericTypeParamDecl> typeParamDecl, + RefPtr<ExpressionType> type) + { + // We want to constrain the given type parameter + // to equal the given type. + Constraint constraint; + constraint.decl = typeParamDecl.Ptr(); + constraint.val = type; - if (auto sndDeclRefType = snd->As<DeclRefType>()) - { - auto sndDeclRef = sndDeclRefType->declRef; + constraints.constraints.Add(constraint); - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, fst); + return true; + } - // can't be unified if they refer to differnt declarations. - if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false; + bool TryUnifyIntParam( + ConstraintSystem& constraints, + RefPtr<GenericValueParamDecl> paramDecl, + RefPtr<IntVal> val) + { + // We want to constrain the given parameter to equal the given value. + Constraint constraint; + constraint.decl = paramDecl.Ptr(); + constraint.val = val; - // next we need to unify the substitutions applied - // to each decalration reference. - if (!TryUnifySubstitutions( - constraints, - fstDeclRef.substitutions, - sndDeclRef.substitutions)) - { - return false; - } + constraints.constraints.Add(constraint); - return true; - } - } + return true; + } + bool TryUnifyIntParam( + ConstraintSystem& constraints, + VarDeclBaseRef const& varRef, + RefPtr<IntVal> val) + { + if(auto genericValueParamRef = varRef.As<GenericValueParamDeclRef>()) + { + return TryUnifyIntParam(constraints, genericValueParamRef.GetDecl(), val); + } + else + { return false; } + } - bool TryUnifyTypes( - ConstraintSystem& constraints, - RefPtr<ExpressionType> fst, - RefPtr<ExpressionType> snd) + bool TryUnifyTypesByStructuralMatch( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (auto fstDeclRefType = fst->As<DeclRefType>()) { - if (fst->Equals(snd)) return true; - - // An error type can unify with anything, just so we avoid cascading errors. + auto fstDeclRef = fstDeclRefType->declRef; - if (auto fstErrorType = fst->As<ErrorType>()) - return true; - - if (auto sndErrorType = snd->As<ErrorType>()) - return true; - - // A generic parameter type can unify with anything. - // TODO: there actually needs to be some kind of "occurs check" sort - // of thing here... - - if (auto fstDeclRefType = fst->As<DeclRefType>()) - { - auto fstDeclRef = fstDeclRefType->declRef; - - if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) - return TryUnifyTypeParam(constraints, typeParamDecl, snd); - } + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); if (auto sndDeclRefType = snd->As<DeclRefType>()) { @@ -3709,1336 +3653,1389 @@ namespace Slang if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) return TryUnifyTypeParam(constraints, typeParamDecl, fst); - } - - // If we can unify the types structurally, then we are golden - if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) - return true; - // Now we need to consider cases where coercion might - // need to be applied. For now we can try to do this - // in a completely ad hoc fashion, but eventually we'd - // want to do it more formally. + // can't be unified if they refer to differnt declarations. + if (fstDeclRef.GetDecl() != sndDeclRef.GetDecl()) return false; - if(auto fstVectorType = fst->As<VectorExpressionType>()) - { - if(auto sndScalarType = snd->As<BasicExpressionType>()) + // next we need to unify the substitutions applied + // to each decalration reference. + if (!TryUnifySubstitutions( + constraints, + fstDeclRef.substitutions, + sndDeclRef.substitutions)) { - return TryUnifyTypes( - constraints, - fstVectorType->elementType, - sndScalarType); + return false; } - } - if(auto fstScalarType = fst->As<BasicExpressionType>()) - { - if(auto sndVectorType = snd->As<VectorExpressionType>()) - { - return TryUnifyTypes( - constraints, - fstScalarType, - sndVectorType->elementType); - } + return true; } + } - // TODO: the same thing for vectors... + return false; + } - return false; - } + bool TryUnifyTypes( + ConstraintSystem& constraints, + RefPtr<ExpressionType> fst, + RefPtr<ExpressionType> snd) + { + if (fst->Equals(snd)) return true; + + // An error type can unify with anything, just so we avoid cascading errors. - // Is the candidate extension declaration actually applicable to the given type - ExtensionDeclRef ApplyExtensionToType( - ExtensionDecl* extDecl, - RefPtr<ExpressionType> type) + if (auto fstErrorType = fst->As<ErrorType>()) + return true; + + if (auto sndErrorType = snd->As<ErrorType>()) + return true; + + // A generic parameter type can unify with anything. + // TODO: there actually needs to be some kind of "occurs check" sort + // of thing here... + + if (auto fstDeclRefType = fst->As<DeclRefType>()) { - if (auto extGenericDecl = GetOuterGeneric(extDecl)) - { - ConstraintSystem constraints; + auto fstDeclRef = fstDeclRefType->declRef; - if (!TryUnifyTypes(constraints, extDecl->targetType, type)) - return DeclRef().As<ExtensionDeclRef>(); + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(fstDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, snd); + } - auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>()); - if (!constraintSubst) - { - return DeclRef().As<ExtensionDeclRef>(); - } + if (auto sndDeclRefType = snd->As<DeclRefType>()) + { + auto sndDeclRef = sndDeclRefType->declRef; - // Consruct a reference to the extension with our constraint variables - // set as they were found by solving the constraint system. - ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>(); + if (auto typeParamDecl = dynamic_cast<GenericTypeParamDecl*>(sndDeclRef.GetDecl())) + return TryUnifyTypeParam(constraints, typeParamDecl, fst); + } - // We expect/require that the result of unification is such that - // the target types are now equal - assert(extDeclRef.GetTargetType()->Equals(type)); + // If we can unify the types structurally, then we are golden + if(TryUnifyTypesByStructuralMatch(constraints, fst, snd)) + return true; - return extDeclRef; - } - else + // Now we need to consider cases where coercion might + // need to be applied. For now we can try to do this + // in a completely ad hoc fashion, but eventually we'd + // want to do it more formally. + + if(auto fstVectorType = fst->As<VectorExpressionType>()) + { + if(auto sndScalarType = snd->As<BasicExpressionType>()) { - // The easy case is when the extension isn't generic: - // either it applies to the type or not. - if (!type->Equals(extDecl->targetType)) - return DeclRef().As<ExtensionDeclRef>(); - return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>(); + return TryUnifyTypes( + constraints, + fstVectorType->elementType, + sndScalarType); } } - bool TryUnifyArgAndParamTypes( - ConstraintSystem& system, - RefPtr<ExpressionSyntaxNode> argExpr, - ParamDeclRef paramDeclRef) + if(auto fstScalarType = fst->As<BasicExpressionType>()) { - // TODO(tfoley): potentially need a bit more - // nuance in case where argument might be - // an overload group... - return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType()); + if(auto sndVectorType = snd->As<VectorExpressionType>()) + { + return TryUnifyTypes( + constraints, + fstScalarType, + sndVectorType->elementType); + } } - // Take a generic declaration and try to specialize its parameters - // so that the resulting inner declaration can be applicable in - // a particular context... - DeclRef SpecializeGenericForOverload( - GenericDeclRef genericDeclRef, - OverloadResolveContext& context) + // TODO: the same thing for vectors... + + return false; + } + + // Is the candidate extension declaration actually applicable to the given type + ExtensionDeclRef ApplyExtensionToType( + ExtensionDecl* extDecl, + RefPtr<ExpressionType> type) + { + if (auto extGenericDecl = GetOuterGeneric(extDecl)) { ConstraintSystem constraints; - // Construct a reference to the inner declaration that has any generic - // parameter substitutions in place already, but *not* any substutions - // for the generic declaration we are currently trying to infer. - auto innerDecl = genericDeclRef.GetInner(); - DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + if (!TryUnifyTypes(constraints, extDecl->targetType, type)) + return DeclRef().As<ExtensionDeclRef>(); - // Check what type of declaration we are dealing with, and then try - // to match it up with the arguments accordingly... - if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>()) + auto constraintSubst = TrySolveConstraintSystem(&constraints, DeclRef(extGenericDecl, nullptr).As<GenericDeclRef>()); + if (!constraintSubst) { - auto& args = context.appExpr->Arguments; - auto params = funcDeclRef.GetParameters().ToArray(); + return DeclRef().As<ExtensionDeclRef>(); + } - int argCount = args.Count(); - int paramCount = params.Count(); + // Consruct a reference to the extension with our constraint variables + // set as they were found by solving the constraint system. + ExtensionDeclRef extDeclRef = DeclRef(extDecl, constraintSubst).As<ExtensionDeclRef>(); - // Bail out on mismatch. - // TODO(tfoley): need more nuance here - if (argCount != paramCount) - { - return DeclRef(nullptr, nullptr); - } + // We expect/require that the result of unification is such that + // the target types are now equal + assert(extDeclRef.GetTargetType()->Equals(type)); - for (int aa = 0; aa < argCount; ++aa) - { -#if 0 - if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) - return DeclRef(nullptr, nullptr); -#else - // The question here is whether failure to "unify" an argument - // and parameter should lead to immediate failure. - // - // The case that is interesting is if we want to unify, say: - // `vector<float,N>` and `vector<int,3>` - // - // It is clear that we should solve with `N = 3`, and then - // a later step may find that the resulting types aren't - // actually a match. - // - // A more refined approach to "unification" could of course - // see that `int` can convert to `float` and use that fact. - // (and indeed we already use something like this to unify - // `float` and `vector<T,3>`) - // - // So the question is then whether a mismatch during the - // unification step should be taken as an immediate failure... + return extDeclRef; + } + else + { + // The easy case is when the extension isn't generic: + // either it applies to the type or not. + if (!type->Equals(extDecl->targetType)) + return DeclRef().As<ExtensionDeclRef>(); + return DeclRef(extDecl, nullptr).As<ExtensionDeclRef>(); + } + } - TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]); -#endif - } - } - else + bool TryUnifyArgAndParamTypes( + ConstraintSystem& system, + RefPtr<ExpressionSyntaxNode> argExpr, + ParamDeclRef paramDeclRef) + { + // TODO(tfoley): potentially need a bit more + // nuance in case where argument might be + // an overload group... + return TryUnifyTypes(system, argExpr->Type, paramDeclRef.GetType()); + } + + // Take a generic declaration and try to specialize its parameters + // so that the resulting inner declaration can be applicable in + // a particular context... + DeclRef SpecializeGenericForOverload( + GenericDeclRef genericDeclRef, + OverloadResolveContext& context) + { + ConstraintSystem constraints; + + // Construct a reference to the inner declaration that has any generic + // parameter substitutions in place already, but *not* any substutions + // for the generic declaration we are currently trying to infer. + auto innerDecl = genericDeclRef.GetInner(); + DeclRef unspecializedInnerRef = DeclRef(innerDecl, genericDeclRef.substitutions); + + // Check what type of declaration we are dealing with, and then try + // to match it up with the arguments accordingly... + if (auto funcDeclRef = unspecializedInnerRef.As<CallableDeclRef>()) + { + auto& args = context.appExpr->Arguments; + auto params = funcDeclRef.GetParameters().ToArray(); + + int argCount = args.Count(); + int paramCount = params.Count(); + + // Bail out on mismatch. + // TODO(tfoley): need more nuance here + if (argCount != paramCount) { - // TODO(tfoley): any other cases needed here? return DeclRef(nullptr, nullptr); } - auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); - if (!constraintSubst) + for (int aa = 0; aa < argCount; ++aa) { - // constraint solving failed - return DeclRef(nullptr, nullptr); +#if 0 + if (!TryUnifyArgAndParamTypes(constraints, args[aa], params[aa])) + return DeclRef(nullptr, nullptr); +#else + // The question here is whether failure to "unify" an argument + // and parameter should lead to immediate failure. + // + // The case that is interesting is if we want to unify, say: + // `vector<float,N>` and `vector<int,3>` + // + // It is clear that we should solve with `N = 3`, and then + // a later step may find that the resulting types aren't + // actually a match. + // + // A more refined approach to "unification" could of course + // see that `int` can convert to `float` and use that fact. + // (and indeed we already use something like this to unify + // `float` and `vector<T,3>`) + // + // So the question is then whether a mismatch during the + // unification step should be taken as an immediate failure... + + TryUnifyArgAndParamTypes(constraints, args[aa], params[aa]); +#endif } + } + else + { + // TODO(tfoley): any other cases needed here? + return DeclRef(nullptr, nullptr); + } + + auto constraintSubst = TrySolveConstraintSystem(&constraints, genericDeclRef); + if (!constraintSubst) + { + // constraint solving failed + return DeclRef(nullptr, nullptr); + } - // We can now construct a reference to the inner declaration using - // the solution to our constraints. - return DeclRef(innerDecl, constraintSubst); + // We can now construct a reference to the inner declaration using + // the solution to our constraints. + return DeclRef(innerDecl, constraintSubst); + } + + void AddAggTypeOverloadCandidates( + LookupResultItem typeItem, + RefPtr<ExpressionType> type, + AggTypeDeclRef aggTypeDeclRef, + OverloadResolveContext& context) + { + for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>()) + { + // now work through this candidate... + AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); } - void AddAggTypeOverloadCandidates( - LookupResultItem typeItem, - RefPtr<ExpressionType> type, - AggTypeDeclRef aggTypeDeclRef, - OverloadResolveContext& context) + // Now walk through any extensions we can find for this types + for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension) { - for (auto ctorDeclRef : aggTypeDeclRef.GetMembersOfType<ConstructorDeclRef>()) + auto extDeclRef = ApplyExtensionToType(ext, type); + if (!extDeclRef) + continue; + + for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>()) { + // TODO(tfoley): `typeItem` here should really reference the extension... + // now work through this candidate... AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); } - // Now walk through any extensions we can find for this types - for (auto ext = aggTypeDeclRef.GetCandidateExtensions(); ext; ext = ext->nextCandidateExtension) + // Also check for generic constructors + for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>()) { - auto extDeclRef = ApplyExtensionToType(ext, type); - if (!extDeclRef) - continue; - - for (auto ctorDeclRef : extDeclRef.GetMembersOfType<ConstructorDeclRef>()) + if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>()) { - // TODO(tfoley): `typeItem` here should really reference the extension... - - // now work through this candidate... - AddCtorOverloadCandidate(typeItem, type, ctorDeclRef, context); - } - - // Also check for generic constructors - for (auto genericDeclRef : extDeclRef.GetMembersOfType<GenericDeclRef>()) - { - if (auto ctorDecl = genericDeclRef.GetDecl()->inner.As<ConstructorDecl>()) - { - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (!innerRef) - continue; + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + if (!innerRef) + continue; - ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>(); + ConstructorDeclRef innerCtorRef = innerRef.As<ConstructorDeclRef>(); - AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context); + AddCtorOverloadCandidate(typeItem, type, innerCtorRef, context); - // TODO(tfoley): need a way to do the solving step for the constraint system - } + // TODO(tfoley): need a way to do the solving step for the constraint system } } } + } - void AddTypeOverloadCandidates( - RefPtr<ExpressionType> type, - OverloadResolveContext& context) + void AddTypeOverloadCandidates( + RefPtr<ExpressionType> type, + OverloadResolveContext& context) + { + if (auto declRefType = type->As<DeclRefType>()) { - if (auto declRefType = type->As<DeclRefType>()) + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) - { - AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context); - } + AddAggTypeOverloadCandidates(LookupResultItem(aggTypeDeclRef), type, aggTypeDeclRef, context); } } + } + + void AddDeclRefOverloadCandidates( + LookupResultItem item, + OverloadResolveContext& context) + { + auto declRef = item.declRef; - void AddDeclRefOverloadCandidates( - LookupResultItem item, - OverloadResolveContext& context) + if (auto funcDeclRef = item.declRef.As<CallableDeclRef>()) { - auto declRef = item.declRef; + AddFuncOverloadCandidate(item, funcDeclRef, context); + } + else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context); + } + else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>()) + { + // Try to infer generic arguments, based on the context + DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); - if (auto funcDeclRef = item.declRef.As<CallableDeclRef>()) - { - AddFuncOverloadCandidate(item, funcDeclRef, context); - } - else if (auto aggTypeDeclRef = item.declRef.As<AggTypeDeclRef>()) - { - auto type = DeclRefType::Create(aggTypeDeclRef); - AddAggTypeOverloadCandidates(item, type, aggTypeDeclRef, context); - } - else if (auto genericDeclRef = item.declRef.As<GenericDeclRef>()) + if (innerRef) { - // Try to infer generic arguments, based on the context - DeclRef innerRef = SpecializeGenericForOverload(genericDeclRef, context); + // If inference works, then we've now got a + // specialized declaration reference we can apply. - if (innerRef) - { - // If inference works, then we've now got a - // specialized declaration reference we can apply. - - LookupResultItem innerItem; - innerItem.breadcrumbs = item.breadcrumbs; - innerItem.declRef = innerRef; + LookupResultItem innerItem; + innerItem.breadcrumbs = item.breadcrumbs; + innerItem.declRef = innerRef; - AddDeclRefOverloadCandidates(innerItem, context); - } - else - { - // If inference failed, then we need to create - // a candidate that can be used to reflect that fact - // (so we can report a good error) - OverloadCandidate candidate; - candidate.item = item; - candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; - candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; - - AddOverloadCandidateInner(context, candidate); - } - } - else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() ) - { - AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context); + AddDeclRefOverloadCandidates(innerItem, context); } else { - // TODO(tfoley): any other cases needed here? + // If inference failed, then we need to create + // a candidate that can be used to reflect that fact + // (so we can report a good error) + OverloadCandidate candidate; + candidate.item = item; + candidate.flavor = OverloadCandidate::Flavor::UnspecializedGeneric; + candidate.status = OverloadCandidate::Status::GenericArgumentInferenceFailed; + + AddOverloadCandidateInner(context, candidate); } } - - void AddOverloadCandidates( - RefPtr<ExpressionSyntaxNode> funcExpr, - OverloadResolveContext& context) + else if( auto typeDefDeclRef = item.declRef.As<TypeDefDeclRef>() ) + { + AddTypeOverloadCandidates(typeDefDeclRef.GetType(), context); + } + else { - auto funcExprType = funcExpr->Type; + // TODO(tfoley): any other cases needed here? + } + } - if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) - { - // The expression referenced a function declaration - AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context); - } - else if (auto funcType = funcExprType->As<FuncType>()) - { - // TODO(tfoley): deprecate this path... - AddFuncOverloadCandidate(funcType, context); - } - else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) - { - auto lookupResult = overloadedExpr->lookupResult2; - assert(lookupResult.isOverloaded()); - for(auto item : lookupResult.items) - { - AddDeclRefOverloadCandidates(item, context); - } - } - else if (auto typeType = funcExprType->As<TypeType>()) + void AddOverloadCandidates( + RefPtr<ExpressionSyntaxNode> funcExpr, + OverloadResolveContext& context) + { + auto funcExprType = funcExpr->Type; + + if (auto funcDeclRefExpr = funcExpr.As<DeclRefExpr>()) + { + // The expression referenced a function declaration + AddDeclRefOverloadCandidates(LookupResultItem(funcDeclRefExpr->declRef), context); + } + else if (auto funcType = funcExprType->As<FuncType>()) + { + // TODO(tfoley): deprecate this path... + AddFuncOverloadCandidate(funcType, context); + } + else if (auto overloadedExpr = funcExpr.As<OverloadedExpr>()) + { + auto lookupResult = overloadedExpr->lookupResult2; + assert(lookupResult.isOverloaded()); + for(auto item : lookupResult.items) { - // If none of the above cases matched, but we are - // looking at a type, then I suppose we have - // a constructor call on our hands. - // - // TODO(tfoley): are there any meaningful types left - // that aren't declaration references? - AddTypeOverloadCandidates(typeType->type, context); - return; + AddDeclRefOverloadCandidates(item, context); } } + else if (auto typeType = funcExprType->As<TypeType>()) + { + // If none of the above cases matched, but we are + // looking at a type, then I suppose we have + // a constructor call on our hands. + // + // TODO(tfoley): are there any meaningful types left + // that aren't declaration references? + AddTypeOverloadCandidates(typeType->type, context); + return; + } + } + + void formatType(StringBuilder& sb, RefPtr<ExpressionType> type) + { + sb << type->ToString(); + } + + void formatVal(StringBuilder& sb, RefPtr<Val> val) + { + sb << val->ToString(); + } - void formatType(StringBuilder& sb, RefPtr<ExpressionType> type) + void formatDeclPath(StringBuilder& sb, DeclRef declRef) + { + // Find the parent declaration + auto parentDeclRef = declRef.GetParent(); + + // If the immediate parent is a generic, then we probably + // want the declaration above that... + auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>(); + if(parentGenericDeclRef) { - sb << type->ToString(); + parentDeclRef = parentGenericDeclRef.GetParent(); } - void formatVal(StringBuilder& sb, RefPtr<Val> val) + // Depending on what the parent is, we may want to format things specially + if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>()) { - sb << val->ToString(); + formatDeclPath(sb, aggTypeDeclRef); + sb << "."; } - void formatDeclPath(StringBuilder& sb, DeclRef declRef) + sb << declRef.GetName(); + + // If the parent declaration is a generic, then we need to print out its + // signature + if( parentGenericDeclRef ) { - // Find the parent declaration - auto parentDeclRef = declRef.GetParent(); + assert(declRef.substitutions); + assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl()); - // If the immediate parent is a generic, then we probably - // want the declaration above that... - auto parentGenericDeclRef = parentDeclRef.As<GenericDeclRef>(); - if(parentGenericDeclRef) + sb << "<"; + bool first = true; + for(auto arg : declRef.substitutions->args) { - parentDeclRef = parentGenericDeclRef.GetParent(); + if(!first) sb << ", "; + formatVal(sb, arg); + first = false; } + sb << ">"; + } + } - // Depending on what the parent is, we may want to format things specially - if(auto aggTypeDeclRef = parentDeclRef.As<AggTypeDeclRef>()) - { - formatDeclPath(sb, aggTypeDeclRef); - sb << "."; - } + void formatDeclParams(StringBuilder& sb, DeclRef declRef) + { + if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { - sb << declRef.GetName(); + // This is something callable, so we need to also print parameter types for overloading + sb << "("; - // If the parent declaration is a generic, then we need to print out its - // signature - if( parentGenericDeclRef ) + bool first = true; + for (auto paramDeclRef : funcDeclRef.GetParameters()) { - assert(declRef.substitutions); - assert(declRef.substitutions->genericDecl == parentGenericDeclRef.GetDecl()); + if (!first) sb << ", "; + + formatType(sb, paramDeclRef.GetType()); + + first = false; - sb << "<"; - bool first = true; - for(auto arg : declRef.substitutions->args) - { - if(!first) sb << ", "; - formatVal(sb, arg); - first = false; - } - sb << ">"; } - } - void formatDeclParams(StringBuilder& sb, DeclRef declRef) + sb << ")"; + } + else if(auto genericDeclRef = declRef.As<GenericDeclRef>()) { - if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + sb << "<"; + bool first = true; + for (auto paramDeclRef : genericDeclRef.GetMembers()) { - - // This is something callable, so we need to also print parameter types for overloading - sb << "("; - - bool first = true; - for (auto paramDeclRef : funcDeclRef.GetParameters()) + if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>()) { if (!first) sb << ", "; - - formatType(sb, paramDeclRef.GetType()); - first = false; + sb << genericTypeParam.GetName(); } - - sb << ")"; - } - else if(auto genericDeclRef = declRef.As<GenericDeclRef>()) - { - sb << "<"; - bool first = true; - for (auto paramDeclRef : genericDeclRef.GetMembers()) + else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>()) { - if(auto genericTypeParam = paramDeclRef.As<GenericTypeParamDeclRef>()) - { - if (!first) sb << ", "; - first = false; - - sb << genericTypeParam.GetName(); - } - else if(auto genericValParam = paramDeclRef.As<GenericValueParamDeclRef>()) - { - if (!first) sb << ", "; - first = false; + if (!first) sb << ", "; + first = false; - formatType(sb, genericValParam.GetType()); - sb << " "; - sb << genericValParam.GetName(); - } - else - {} + formatType(sb, genericValParam.GetType()); + sb << " "; + sb << genericValParam.GetName(); } - sb << ">"; - - formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions)); - } - else - { + else + {} } - } + sb << ">"; - void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + formatDeclParams(sb, DeclRef(genericDeclRef.GetInner(), genericDeclRef.substitutions)); + } + else { - formatDeclPath(sb, declRef); - formatDeclParams(sb, declRef); } + } + + void formatDeclSignature(StringBuilder& sb, DeclRef declRef) + { + formatDeclPath(sb, declRef); + formatDeclParams(sb, declRef); + } + + String getDeclSignatureString(DeclRef declRef) + { + StringBuilder sb; + formatDeclSignature(sb, declRef); + return sb.ProduceString(); + } + + String getDeclSignatureString(LookupResultItem item) + { + return getDeclSignatureString(item.declRef); + } - String getDeclSignatureString(DeclRef declRef) + String GetCallSignatureString(RefPtr<AppExprBase> expr) + { + StringBuilder argsListBuilder; + argsListBuilder << "("; + bool first = true; + for (auto a : expr->Arguments) { - StringBuilder sb; - formatDeclSignature(sb, declRef); - return sb.ProduceString(); + if (!first) argsListBuilder << ", "; + argsListBuilder << a->Type->ToString(); + first = false; } + argsListBuilder << ")"; + return argsListBuilder.ProduceString(); + } + - String getDeclSignatureString(LookupResultItem item) + RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr) + { + // Look at the base expression for the call, and figure out how to invoke it. + auto funcExpr = expr->FunctionExpr; + auto funcExprType = funcExpr->Type; + + // If we are trying to apply an erroroneous expression, then just bail out now. + if(IsErrorExpr(funcExpr)) { - return getDeclSignatureString(item.declRef); + return CreateErrorExpr(expr); } - String GetCallSignatureString(RefPtr<AppExprBase> expr) + OverloadResolveContext context; + context.appExpr = expr; + if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) { - StringBuilder argsListBuilder; - argsListBuilder << "("; - bool first = true; - for (auto a : expr->Arguments) - { - if (!first) argsListBuilder << ", "; - argsListBuilder << a->Type->ToString(); - first = false; - } - argsListBuilder << ")"; - return argsListBuilder.ProduceString(); + context.baseExpr = funcMemberExpr->BaseExpression; } - - - RefPtr<ExpressionSyntaxNode> ResolveInvoke(InvokeExpressionSyntaxNode * expr) + else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) { - // Look at the base expression for the call, and figure out how to invoke it. - auto funcExpr = expr->FunctionExpr; - auto funcExprType = funcExpr->Type; - - // If we are trying to apply an erroroneous expression, then just bail out now. - if(IsErrorExpr(funcExpr)) - { - return CreateErrorExpr(expr); - } + context.baseExpr = funcOverloadExpr->base; + } + AddOverloadCandidates(funcExpr, context); - OverloadResolveContext context; - context.appExpr = expr; - if (auto funcMemberExpr = funcExpr.As<MemberExpressionSyntaxNode>()) - { - context.baseExpr = funcMemberExpr->BaseExpression; - } - else if (auto funcOverloadExpr = funcExpr.As<OverloadedExpr>()) - { - context.baseExpr = funcOverloadExpr->base; - } - AddOverloadCandidates(funcExpr, context); + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. - if (context.bestCandidates.Count() > 0) + // It might be that things were only ambiguous because + // one of the argument expressions had an error, and + // so a bunch of candidates could match at that position. + // + // If any argument was an error, we skip out on printing + // another message, to avoid cascading errors. + for (auto arg : expr->Arguments) { - // Things were ambiguous. - - // It might be that things were only ambiguous because - // one of the argument expressions had an error, and - // so a bunch of candidates could match at that position. - // - // If any argument was an error, we skip out on printing - // another message, to avoid cascading errors. - for (auto arg : expr->Arguments) + if (IsErrorExpr(arg)) { - if (IsErrorExpr(arg)) - { - return CreateErrorExpr(expr); - } + return CreateErrorExpr(expr); } + } - String funcName; - if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>()) - funcName = baseVar->Variable; - else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>()) - funcName = baseMemberRef->MemberName; + String funcName; + if (auto baseVar = funcExpr.As<VarExpressionSyntaxNode>()) + funcName = baseVar->Variable; + else if(auto baseMemberRef = funcExpr.As<MemberExpressionSyntaxNode>()) + funcName = baseMemberRef->MemberName; - String argsList = GetCallSignatureString(expr); + String argsList = GetCallSignatureString(expr); - if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) + { + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. + if (funcName.Length() != 0) { - // There were multple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. - if (funcName.Length() != 0) - { - getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); - } + getSink()->diagnose(expr, Diagnostics::noApplicableOverloadForNameWithArgs, funcName, argsList); } else { - // There were multiple applicable candidates, so we need to report them. - - if (funcName.Length() != 0) - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); - } - else - { - getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); - } + getSink()->diagnose(expr, Diagnostics::noApplicableWithArgs, argsList); } + } + else + { + // There were multiple applicable candidates, so we need to report them. - int candidateCount = context.bestCandidates.Count(); - int maxCandidatesToPrint = 10; // don't show too many candidates at once... - int candidateIndex = 0; - for (auto candidate : context.bestCandidates) + if (funcName.Length() != 0) { - String declString = getDeclSignatureString(candidate.item); - - declString = declString + "[" + String(candidate.conversionCostSum) + "]"; - - getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); - - candidateIndex++; - if (candidateIndex == maxCandidatesToPrint) - break; + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadForNameWithArgs, funcName, argsList); } - if (candidateIndex != candidateCount) + else { - getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); + getSink()->diagnose(expr, Diagnostics::ambiguousOverloadWithArgs, argsList); } - - return CreateErrorExpr(expr); } - else if (context.bestCandidate) + + int candidateCount = context.bestCandidates.Count(); + int maxCandidatesToPrint = 10; // don't show too many candidates at once... + int candidateIndex = 0; + for (auto candidate : context.bestCandidates) { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); + String declString = getDeclSignatureString(candidate.item); + + declString = declString + "[" + String(candidate.conversionCostSum) + "]"; + + getSink()->diagnose(candidate.item.declRef, Diagnostics::overloadCandidate, declString); + + candidateIndex++; + if (candidateIndex == maxCandidatesToPrint) + break; } - else + if (candidateIndex != candidateCount) { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction); - expr->Type = ExpressionType::Error; - return expr; + getSink()->diagnose(expr, Diagnostics::moreOverloadCandidates, candidateCount - candidateIndex); } + + return CreateErrorExpr(expr); + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(expr->FunctionExpr, Diagnostics::expectedFunction); + expr->Type = ExpressionType::Error; + return expr; + } + } - void AddGenericOverloadCandidate( - LookupResultItem baseItem, - OverloadResolveContext& context) + void AddGenericOverloadCandidate( + LookupResultItem baseItem, + OverloadResolveContext& context) + { + if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>()) { - if (auto genericDeclRef = baseItem.declRef.As<GenericDeclRef>()) - { - EnsureDecl(genericDeclRef.GetDecl()); + EnsureDecl(genericDeclRef.GetDecl()); - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Generic; - candidate.item = baseItem; - candidate.resultType = nullptr; + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Generic; + candidate.item = baseItem; + candidate.resultType = nullptr; - AddOverloadCandidate(context, candidate); - } + AddOverloadCandidate(context, candidate); } + } - void AddGenericOverloadCandidates( - RefPtr<ExpressionSyntaxNode> baseExpr, - OverloadResolveContext& context) + void AddGenericOverloadCandidates( + RefPtr<ExpressionSyntaxNode> baseExpr, + OverloadResolveContext& context) + { + if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) { - if(auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>()) - { - auto declRef = baseDeclRefExpr->declRef; - AddGenericOverloadCandidate(LookupResultItem(declRef), context); - } - else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) - { - // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) - { - AddGenericOverloadCandidate(item, context); - } - } - else + auto declRef = baseDeclRefExpr->declRef; + AddGenericOverloadCandidate(LookupResultItem(declRef), context); + } + else if (auto overloadedExpr = baseExpr.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) { - // any other cases? + AddGenericOverloadCandidate(item, context); } } - - RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override + else { - // We are applying a generic to arguments, but there might be multiple generic - // declarations with the same name, so this becomes a specialized case of - // overload resolution. + // any other cases? + } + } + RefPtr<ExpressionSyntaxNode> VisitGenericApp(GenericAppExpr * genericAppExpr) override + { + // We are applying a generic to arguments, but there might be multiple generic + // declarations with the same name, so this becomes a specialized case of + // overload resolution. - // Start by checking the base expression and arguments. - auto& baseExpr = genericAppExpr->FunctionExpr; - baseExpr = CheckTerm(baseExpr); - auto& args = genericAppExpr->Arguments; - for (auto& arg : args) - { - arg = CheckTerm(arg); - } - // If there was an error in the base expression, or in any of - // the arguments, then just bail. - if (IsErrorExpr(baseExpr)) + // Start by checking the base expression and arguments. + auto& baseExpr = genericAppExpr->FunctionExpr; + baseExpr = CheckTerm(baseExpr); + auto& args = genericAppExpr->Arguments; + for (auto& arg : args) + { + arg = CheckTerm(arg); + } + + // If there was an error in the base expression, or in any of + // the arguments, then just bail. + if (IsErrorExpr(baseExpr)) + { + return CreateErrorExpr(genericAppExpr); + } + for (auto argExpr : args) + { + if (IsErrorExpr(argExpr)) { return CreateErrorExpr(genericAppExpr); } - for (auto argExpr : args) - { - if (IsErrorExpr(argExpr)) - { - return CreateErrorExpr(genericAppExpr); - } - } + } - // Otherwise, let's start looking at how to find an overload... + // Otherwise, let's start looking at how to find an overload... - OverloadResolveContext context; - context.appExpr = genericAppExpr; - context.baseExpr = GetBaseExpr(baseExpr); + OverloadResolveContext context; + context.appExpr = genericAppExpr; + context.baseExpr = GetBaseExpr(baseExpr); - AddGenericOverloadCandidates(baseExpr, context); + AddGenericOverloadCandidates(baseExpr, context); - if (context.bestCandidates.Count() > 0) + if (context.bestCandidates.Count() > 0) + { + // Things were ambiguous. + if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) { - // Things were ambiguous. - if (context.bestCandidates[0].status != OverloadCandidate::Status::Appicable) - { - // There were multple equally-good candidates, but none actually usable. - // We will construct a diagnostic message to help out. + // There were multple equally-good candidates, but none actually usable. + // We will construct a diagnostic message to help out. - // TODO(tfoley): print a reasonable message here... + // TODO(tfoley): print a reasonable message here... - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "no applicable generic"); - return CreateErrorExpr(genericAppExpr); - } - else - { - // There were multiple viable candidates, but that isn't an error: we just need - // to complete all of them and create an overloaded expression as a result. + return CreateErrorExpr(genericAppExpr); + } + else + { + // There were multiple viable candidates, but that isn't an error: we just need + // to complete all of them and create an overloaded expression as a result. - LookupResult result; - for (auto candidate : context.bestCandidates) - { - auto candidateExpr = CompleteOverloadCandidate(context, candidate); - } + LookupResult result; + for (auto candidate : context.bestCandidates) + { + auto candidateExpr = CompleteOverloadCandidate(context, candidate); + } - throw "what now?"; + throw "what now?"; // auto overloadedExpr = new OverloadedExpr(); // return overloadedExpr; - } - } - else if (context.bestCandidate) - { - // There was one best candidate, even if it might not have been - // applicable in the end. - // We will report errors for this one candidate, then, to give - // the user the most help we can. - return CompleteOverloadCandidate(context, *context.bestCandidate); - } - else - { - // Nothing at all was found that we could even consider invoking - getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); - return CreateErrorExpr(genericAppExpr); } + } + else if (context.bestCandidate) + { + // There was one best candidate, even if it might not have been + // applicable in the end. + // We will report errors for this one candidate, then, to give + // the user the most help we can. + return CompleteOverloadCandidate(context, *context.bestCandidate); + } + else + { + // Nothing at all was found that we could even consider invoking + getSink()->diagnose(genericAppExpr, Diagnostics::unimplemented, "expected a generic"); + return CreateErrorExpr(genericAppExpr); + } #if TIMREMOVED - if (IsErrorExpr(base)) - { - return CreateErrorExpr(typeNode); - } - else if(auto baseDeclRefExpr = base.As<DeclRefExpr>()) - { - auto declRef = baseDeclRefExpr->declRef; + if (IsErrorExpr(base)) + { + return CreateErrorExpr(typeNode); + } + else if(auto baseDeclRefExpr = base.As<DeclRefExpr>()) + { + auto declRef = baseDeclRefExpr->declRef; - if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + int argCount = typeNode->Args.Count(); + int argIndex = 0; + for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) { - int argCount = typeNode->Args.Count(); - int argIndex = 0; - for (RefPtr<Decl> member : genericDeclRef.GetDecl()->Members) + if (auto typeParam = member.As<GenericTypeParamDecl>()) { - if (auto typeParam = member.As<GenericTypeParamDecl>()) + if (argIndex == argCount) { - if (argIndex == argCount) - { - // Too few arguments! - - } + // Too few arguments! - // TODO: checking! } - else if (auto valParam = member.As<GenericValueParamDecl>()) - { - // TODO: checking - } - else - { - } + // TODO: checking! } - if (argIndex != argCount) + else if (auto valParam = member.As<GenericValueParamDecl>()) { - // Too many arguments! + // TODO: checking } + else + { - // Now instantiate the declaration given those arguments - auto type = InstantiateGenericType(genericDeclRef, args); - typeResult = type; - typeNode->Type = new TypeExpressionType(type); - return typeNode; + } } + if (argIndex != argCount) + { + // Too many arguments! + } + + // Now instantiate the declaration given those arguments + auto type = InstantiateGenericType(genericDeclRef, args); + typeResult = type; + typeNode->Type = new TypeExpressionType(type); + return typeNode; } - else if (auto overloadedExpr = base.As<OverloadedExpr>()) + } + else if (auto overloadedExpr = base.As<OverloadedExpr>()) + { + // We are referring to a bunch of declarations, each of which might be generic + LookupResult result; + for (auto item : overloadedExpr->lookupResult2.items) { - // We are referring to a bunch of declarations, each of which might be generic - LookupResult result; - for (auto item : overloadedExpr->lookupResult2.items) - { - auto applied = TryApplyGeneric(item, typeNode); - if (!applied) - continue; + auto applied = TryApplyGeneric(item, typeNode); + if (!applied) + continue; - AddToLookupResult(result, appliedItem); - } + AddToLookupResult(result, appliedItem); } + } - // TODO: correct diagnostic here! - getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type); - return CreateErrorExpr(typeNode); + // TODO: correct diagnostic here! + getSink()->diagnose(typeNode, Diagnostics::expectedAGeneric, base->Type); + return CreateErrorExpr(typeNode); #endif - } + } - RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override + RefPtr<ExpressionSyntaxNode> VisitSharedTypeExpr(SharedTypeExpr* expr) override + { + if (!expr->Type.Ptr()) { - if (!expr->Type.Ptr()) - { - expr->base = CheckProperType(expr->base); - expr->Type = expr->base.exp->Type; - } - return expr; + expr->base = CheckProperType(expr->base); + expr->Type = expr->base.exp->Type; } + return expr; + } - RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr) - { - return expr->Accept(this).As<ExpressionSyntaxNode>(); - } + RefPtr<ExpressionSyntaxNode> CheckExpr(RefPtr<ExpressionSyntaxNode> expr) + { + return expr->Accept(this).As<ExpressionSyntaxNode>(); + } - RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr) - { + RefPtr<ExpressionSyntaxNode> CheckInvokeExprWithCheckedOperands(InvokeExpressionSyntaxNode *expr) + { - auto rs = ResolveInvoke(expr); - if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr())) + auto rs = ResolveInvoke(expr); + if (auto invoke = dynamic_cast<InvokeExpressionSyntaxNode*>(rs.Ptr())) + { + // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues + if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>()) { - // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues - if(auto funcType = invoke->FunctionExpr->Type->As<FuncType>()) + List<RefPtr<ParameterSyntaxNode>> paramsStorage; + List<RefPtr<ParameterSyntaxNode>> * params = nullptr; + if (auto func = funcType->declRef.GetDecl()) { - List<RefPtr<ParameterSyntaxNode>> paramsStorage; - List<RefPtr<ParameterSyntaxNode>> * params = nullptr; - if (auto func = funcType->declRef.GetDecl()) - { - paramsStorage = func->GetParameters().ToArray(); - params = ¶msStorage; - } - if (params) + paramsStorage = func->GetParameters().ToArray(); + params = ¶msStorage; + } + if (params) + { + for (int i = 0; i < (*params).Count(); i++) { - for (int i = 0; i < (*params).Count(); i++) + if ((*params)[i]->HasModifier<OutModifier>()) { - if ((*params)[i]->HasModifier<OutModifier>()) + if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() && + !expr->Arguments[i]->Type.IsLeftValue) { - if (i < expr->Arguments.Count() && expr->Arguments[i]->Type->AsBasicType() && - !expr->Arguments[i]->Type.IsLeftValue) - { - getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name); - } + getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->Name); } } } } } - return rs; } + return rs; + } - virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override + virtual RefPtr<ExpressionSyntaxNode> VisitInvokeExpression(InvokeExpressionSyntaxNode *expr) override + { + // check the base expression first + expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + + // Next check the argument expressions + for (auto & arg : expr->Arguments) { - // check the base expression first - expr->FunctionExpr = CheckExpr(expr->FunctionExpr); + arg = CheckExpr(arg); + } - // Next check the argument expressions - for (auto & arg : expr->Arguments) - { - arg = CheckExpr(arg); - } + return CheckInvokeExprWithCheckedOperands(expr); + } - return CheckInvokeExprWithCheckedOperands(expr); - } + virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override + { + // If we've already resolved this expression, don't try again. + if (expr->declRef) + return expr; + + expr->Type = ExpressionType::Error; - virtual RefPtr<ExpressionSyntaxNode> VisitVarExpression(VarExpressionSyntaxNode *expr) override + auto lookupResult = LookUp(expr->Variable, expr->scope); + if (lookupResult.isValid()) { - // If we've already resolved this expression, don't try again. - if (expr->declRef) - return expr; + return createLookupResultExpr( + lookupResult, + nullptr, + expr); + } - expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable); - auto lookupResult = LookUp(expr->Variable, expr->scope); - if (lookupResult.isValid()) - { - return createLookupResultExpr( - lookupResult, - nullptr, - expr); - } - - getSink()->diagnose(expr, Diagnostics::undefinedIdentifier2, expr->Variable); + return expr; + } + virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override + { + expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>(); + auto targetType = CheckProperType(expr->TargetType); + expr->TargetType = targetType; + // The way to perform casting depends on the types involved + if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr())) + { + // If the expression being casted has an error type, then just silently succeed + expr->Type = targetType; return expr; } - virtual RefPtr<ExpressionSyntaxNode> VisitTypeCastExpression(TypeCastExpressionSyntaxNode * expr) override + else if (auto targetArithType = targetType->AsArithmeticType()) { - expr->Expression = expr->Expression->Accept(this).As<ExpressionSyntaxNode>(); - auto targetType = CheckProperType(expr->TargetType); - expr->TargetType = targetType; - - // The way to perform casting depends on the types involved - if (expr->Expression->Type->Equals(ExpressionType::Error.Ptr())) + if (auto exprArithType = expr->Expression->Type->AsArithmeticType()) { - // If the expression being casted has an error type, then just silently succeed - expr->Type = targetType; - return expr; - } - else if (auto targetArithType = targetType->AsArithmeticType()) - { - if (auto exprArithType = expr->Expression->Type->AsArithmeticType()) - { - // Both source and destination types are arithmetic, so we might - // have a valid cast - auto targetScalarType = targetArithType->GetScalarType(); - auto exprScalarType = exprArithType->GetScalarType(); + // Both source and destination types are arithmetic, so we might + // have a valid cast + auto targetScalarType = targetArithType->GetScalarType(); + auto exprScalarType = exprArithType->GetScalarType(); - if (!IsNumeric(exprScalarType->BaseType)) goto fail; - if (!IsNumeric(targetScalarType->BaseType)) goto fail; + if (!IsNumeric(exprScalarType->BaseType)) goto fail; + if (!IsNumeric(targetScalarType->BaseType)) goto fail; - // TODO(tfoley): this checking is incomplete here, and could - // lead to downstream compilation failures - expr->Type = targetType; - return expr; - } + // TODO(tfoley): this checking is incomplete here, and could + // lead to downstream compilation failures + expr->Type = targetType; + return expr; } - - fail: - // Default: in no other case succeds, then the cast failed and we emit a diagnostic. - getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString()); - expr->Type = ExpressionType::Error; - return expr; } + + fail: + // Default: in no other case succeds, then the cast failed and we emit a diagnostic. + getSink()->diagnose(expr, Diagnostics::invalidTypeCast, expr->Expression->Type, targetType->ToString()); + expr->Type = ExpressionType::Error; + return expr; + } #if TIMREMOVED - virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override + virtual RefPtr<ExpressionSyntaxNode> VisitSelectExpression(SelectExpressionSyntaxNode * expr) override + { + auto selectorExpr = expr->SelectorExpr; + selectorExpr = CheckExpr(selectorExpr); + selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr); + expr->SelectorExpr = selectorExpr; + + // TODO(tfoley): We need a general purpose "join" on types for inferring + // generic argument types for builtins/intrinsics, so this should really + // be using the exact same logic... + // + expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>(); + expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>(); + if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr())) { - auto selectorExpr = expr->SelectorExpr; - selectorExpr = CheckExpr(selectorExpr); - selectorExpr = Coerce(ExpressionType::GetBool(), selectorExpr); - expr->SelectorExpr = selectorExpr; - - // TODO(tfoley): We need a general purpose "join" on types for inferring - // generic argument types for builtins/intrinsics, so this should really - // be using the exact same logic... - // - expr->Expr0 = expr->Expr0->Accept(this).As<ExpressionSyntaxNode>(); - expr->Expr1 = expr->Expr1->Accept(this).As<ExpressionSyntaxNode>(); - if (!expr->Expr0->Type->Equals(expr->Expr1->Type.Ptr())) - { - getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch); - } - expr->Type = expr->Expr0->Type; - return expr; + getSink()->diagnose(expr, Diagnostics::selectValuesTypeMismatch); } + expr->Type = expr->Expr0->Type; + return expr; + } #endif - // Get the type to use when referencing a declaration - QualType GetTypeForDeclRef(DeclRef declRef) - { - return getTypeForDeclRef( - this, - getSink(), - declRef, - &typeResult); - } + // Get the type to use when referencing a declaration + QualType GetTypeForDeclRef(DeclRef declRef) + { + return getTypeForDeclRef( + this, + getSink(), + declRef, + &typeResult); + } - RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr) + RefPtr<ExpressionSyntaxNode> MaybeDereference(RefPtr<ExpressionSyntaxNode> inExpr) + { + RefPtr<ExpressionSyntaxNode> expr = inExpr; + for (;;) { - RefPtr<ExpressionSyntaxNode> expr = inExpr; - for (;;) + auto& type = expr->Type; + if (auto pointerLikeType = type->As<PointerLikeType>()) { - auto& type = expr->Type; - if (auto pointerLikeType = type->As<PointerLikeType>()) - { - type = pointerLikeType->elementType; + type = pointerLikeType->elementType; - auto derefExpr = new DerefExpr(); - derefExpr->base = expr; - derefExpr->Type = pointerLikeType->elementType; + auto derefExpr = new DerefExpr(); + derefExpr->base = expr; + derefExpr->Type = pointerLikeType->elementType; - // TODO(tfoley): deal with l-value-ness here - - expr = derefExpr; - continue; - } + // TODO(tfoley): deal with l-value-ness here - // Default case: just use the expression as-is - return expr; + expr = derefExpr; + continue; } + + // Default case: just use the expression as-is + return expr; } + } - RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( - MemberExpressionSyntaxNode* memberRefExpr, - RefPtr<ExpressionType> baseElementType, - int baseElementCount) - { - RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); - swizExpr->Position = memberRefExpr->Position; - swizExpr->base = memberRefExpr->BaseExpression; + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + int baseElementCount) + { + RefPtr<SwizzleExpr> swizExpr = new SwizzleExpr(); + swizExpr->Position = memberRefExpr->Position; + swizExpr->base = memberRefExpr->BaseExpression; - int limitElement = baseElementCount; + int limitElement = baseElementCount; - int elementIndices[4]; - int elementCount = 0; + int elementIndices[4]; + int elementCount = 0; - bool elementUsed[4] = { false, false, false, false }; - bool anyDuplicates = false; - bool anyError = false; + bool elementUsed[4] = { false, false, false, false }; + bool anyDuplicates = false; + bool anyError = false; - for (int i = 0; i < memberRefExpr->MemberName.Length(); i++) + for (int i = 0; i < memberRefExpr->MemberName.Length(); i++) + { + auto ch = memberRefExpr->MemberName[i]; + int elementIndex = -1; + switch (ch) { - auto ch = memberRefExpr->MemberName[i]; - int elementIndex = -1; - switch (ch) - { - case 'x': case 'r': elementIndex = 0; break; - case 'y': case 'g': elementIndex = 1; break; - case 'z': case 'b': elementIndex = 2; break; - case 'w': case 'a': elementIndex = 3; break; - default: - // An invalid character in the swizzle is an error - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); - anyError = true; - continue; - } - - // TODO(tfoley): GLSL requires that all component names - // come from the same "family"... - - // Make sure the index is in range for the source type - if (elementIndex >= limitElement) - { - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); - anyError = true; - continue; - } - - // Check if we've seen this index before - for (int ee = 0; ee < elementCount; ee++) - { - if (elementIndices[ee] == elementIndex) - anyDuplicates = true; - } - - // add to our list... - elementIndices[elementCount++] = elementIndex; + case 'x': case 'r': elementIndex = 0; break; + case 'y': case 'g': elementIndex = 1; break; + case 'z': case 'b': elementIndex = 2; break; + case 'w': case 'a': elementIndex = 3; break; + default: + // An invalid character in the swizzle is an error + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); + anyError = true; + continue; } - for (int ee = 0; ee < elementCount; ++ee) - { - swizExpr->elementIndices[ee] = elementIndices[ee]; - } - swizExpr->elementCount = elementCount; + // TODO(tfoley): GLSL requires that all component names + // come from the same "family"... - if (anyError) + // Make sure the index is in range for the source type + if (elementIndex >= limitElement) { - swizExpr->Type = ExpressionType::Error; + getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); + anyError = true; + continue; } - else if (elementCount == 1) - { - // single-component swizzle produces a scalar - // - // Note(tfoley): the official HLSL rules seem to be that it produces - // a one-component vector, which is then implicitly convertible to - // a scalar, but that seems like it just adds complexity. - swizExpr->Type = baseElementType; - } - else + + // Check if we've seen this index before + for (int ee = 0; ee < elementCount; ee++) { - // TODO(tfoley): would be nice to "re-sugar" type - // here if the input type had a sugared name... - swizExpr->Type = createVectorType( - baseElementType, - new ConstantIntVal(elementCount)); + if (elementIndices[ee] == elementIndex) + anyDuplicates = true; } - // A swizzle can be used as an l-value as long as there - // were no duplicates in the list of components - swizExpr->Type.IsLeftValue = !anyDuplicates; + // add to our list... + elementIndices[elementCount++] = elementIndex; + } - return swizExpr; + for (int ee = 0; ee < elementCount; ++ee) + { + swizExpr->elementIndices[ee] = elementIndices[ee]; } + swizExpr->elementCount = elementCount; - RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( - MemberExpressionSyntaxNode* memberRefExpr, - RefPtr<ExpressionType> baseElementType, - RefPtr<IntVal> baseElementCount) + if (anyError) { - if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) - { - return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); - } - else - { - getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); - return CreateErrorExpr(memberRefExpr); - } + swizExpr->Type = ExpressionType::Error; + } + else if (elementCount == 1) + { + // single-component swizzle produces a scalar + // + // Note(tfoley): the official HLSL rules seem to be that it produces + // a one-component vector, which is then implicitly convertible to + // a scalar, but that seems like it just adds complexity. + swizExpr->Type = baseElementType; + } + else + { + // TODO(tfoley): would be nice to "re-sugar" type + // here if the input type had a sugared name... + swizExpr->Type = createVectorType( + baseElementType, + new ConstantIntVal(elementCount)); } + // A swizzle can be used as an l-value as long as there + // were no duplicates in the list of components + swizExpr->Type.IsLeftValue = !anyDuplicates; - virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override + return swizExpr; + } + + RefPtr<ExpressionSyntaxNode> CheckSwizzleExpr( + MemberExpressionSyntaxNode* memberRefExpr, + RefPtr<ExpressionType> baseElementType, + RefPtr<IntVal> baseElementCount) + { + if (auto constantElementCount = baseElementCount.As<ConstantIntVal>()) { - expr->BaseExpression = CheckExpr(expr->BaseExpression); + return CheckSwizzleExpr(memberRefExpr, baseElementType, constantElementCount->value); + } + else + { + getSink()->diagnose(memberRefExpr, Diagnostics::unimplemented, "swizzle on vector of unknown size"); + return CreateErrorExpr(memberRefExpr); + } + } - expr->BaseExpression = MaybeDereference(expr->BaseExpression); - auto & baseType = expr->BaseExpression->Type; + virtual RefPtr<ExpressionSyntaxNode> VisitMemberExpression(MemberExpressionSyntaxNode * expr) override + { + expr->BaseExpression = CheckExpr(expr->BaseExpression); - // Note: Checking for vector types before declaration-reference types, - // because vectors are also declaration reference types... - if (auto baseVecType = baseType->AsVectorType()) - { - return CheckSwizzleExpr( - expr, - baseVecType->elementType, - baseVecType->elementCount); - } - else if(auto baseScalarType = baseType->AsBasicType()) - { - // Treat scalar like a 1-element vector when swizzling - return CheckSwizzleExpr( - expr, - baseScalarType, - 1); - } - else if (auto declRefType = baseType->AsDeclRefType()) + expr->BaseExpression = MaybeDereference(expr->BaseExpression); + + auto & baseType = expr->BaseExpression->Type; + + // Note: Checking for vector types before declaration-reference types, + // because vectors are also declaration reference types... + if (auto baseVecType = baseType->AsVectorType()) + { + return CheckSwizzleExpr( + expr, + baseVecType->elementType, + baseVecType->elementCount); + } + else if(auto baseScalarType = baseType->AsBasicType()) + { + // Treat scalar like a 1-element vector when swizzling + return CheckSwizzleExpr( + expr, + baseScalarType, + 1); + } + else if (auto declRefType = baseType->AsDeclRefType()) + { + if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) { - if (auto aggTypeDeclRef = declRefType->declRef.As<AggTypeDeclRef>()) - { - // Checking of the type must be complete before we can reference its members safely - EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); + // Checking of the type must be complete before we can reference its members safely + EnsureDecl(aggTypeDeclRef.GetDecl(), DeclCheckState::Checked); - LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef); - if (!lookupResult.isValid()) - { - goto fail; - } + LookupResult lookupResult = LookUpLocal(expr->MemberName, aggTypeDeclRef); + if (!lookupResult.isValid()) + { + goto fail; + } - return createLookupResultExpr( - lookupResult, - expr->BaseExpression, - expr); + return createLookupResultExpr( + lookupResult, + expr->BaseExpression, + expr); #if 0 - DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions); - return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr); + DeclRef memberDeclRef(lookupResult.decl, aggTypeDeclRef.substitutions); + return ConstructDeclRefExpr(memberDeclRef, expr->BaseExpression, expr); #endif #if 0 - // TODO(tfoley): It is unfortunate that the lookup strategy - // here isn't unified with the ordinary `Scope` case. - // In particular, if we add support for "transparent" declarations, - // etc. here then we would need to add them in ordinary lookup - // as well. - - Decl* memberDecl = nullptr; // The first declaration we found, if any - Decl* secondDecl = nullptr; // Another declaration with the same name, if any - for (auto m : aggTypeDeclRef.GetMembers()) - { - if (m.GetName() != expr->MemberName) - continue; + // TODO(tfoley): It is unfortunate that the lookup strategy + // here isn't unified with the ordinary `Scope` case. + // In particular, if we add support for "transparent" declarations, + // etc. here then we would need to add them in ordinary lookup + // as well. - if (!memberDecl) - { - memberDecl = m.GetDecl(); - } - else - { - secondDecl = m.GetDecl(); - break; - } - } + Decl* memberDecl = nullptr; // The first declaration we found, if any + Decl* secondDecl = nullptr; // Another declaration with the same name, if any + for (auto m : aggTypeDeclRef.GetMembers()) + { + if (m.GetName() != expr->MemberName) + continue; - // If we didn't find any member, then we signal an error if (!memberDecl) { - expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); - return expr; + memberDecl = m.GetDecl(); } - - // If we found only a single member, then we are fine - if (!secondDecl) + else { - // TODO: need to - DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions); - - expr->declRef = memberDeclRef; - expr->Type = GetTypeForDeclRef(memberDeclRef); - - // When referencing a member variable, the result is an l-value - // if and only if the base expression was. - if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl)) - { - expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; - } - return expr; + secondDecl = m.GetDecl(); + break; } + } - // We found multiple members with the same name, and need - // to resolve the embiguity at some point... + // If we didn't find any member, then we signal an error + if (!memberDecl) + { expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference"); + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); return expr; + } -#endif + // If we found only a single member, then we are fine + if (!secondDecl) + { + // TODO: need to + DeclRef memberDeclRef(memberDecl, aggTypeDeclRef.substitutions); -#if 0 + expr->declRef = memberDeclRef; + expr->Type = GetTypeForDeclRef(memberDeclRef); - StructField* field = structDecl->FindField(expr->MemberName); - if (!field) + // When referencing a member variable, the result is an l-value + // if and only if the base expression was. + if (auto memberVarDecl = dynamic_cast<VarDeclBase*>(memberDecl)) { - expr->Type = ExpressionType::Error; - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; } - else - expr->Type = field->Type; - - // A reference to a struct member is an l-value if the reference to the struct - // value was also an l-value. - expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; return expr; -#endif } - // catch-all - fail: - getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + // We found multiple members with the same name, and need + // to resolve the embiguity at some point... expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::unimplemented, "ambiguous member reference"); return expr; - } - // All remaining cases assume we have a `BasicType` - else if (!baseType->AsBasicType()) - expr->Type = ExpressionType::Error; - else - expr->Type = ExpressionType::Error; - if (!baseType->Equals(ExpressionType::Error.Ptr()) && - expr->Type->Equals(ExpressionType::Error.Ptr())) - { - getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName); - } - return expr; - } - SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; +#endif - // +#if 0 - virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override - { - // When faced with an initializer list, we first just check the sub-expressions blindly. - // Actually making them conform to a desired type will wait for when we know the desired - // type based on context. + StructField* field = structDecl->FindField(expr->MemberName); + if (!field) + { + expr->Type = ExpressionType::Error; + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + } + else + expr->Type = field->Type; - for( auto& arg : expr->args ) - { - arg = CheckTerm(arg); + // A reference to a struct member is an l-value if the reference to the struct + // value was also an l-value. + expr->Type.IsLeftValue = expr->BaseExpression->Type.IsLeftValue; + return expr; +#endif } - expr->Type = ExpressionType::getInitializerListType(); - + // catch-all + fail: + getSink()->diagnose(expr, Diagnostics::noMemberOfNameInType, expr->MemberName, baseType); + expr->Type = ExpressionType::Error; return expr; } - - virtual void visitImportDecl(ImportDecl* decl) override + // All remaining cases assume we have a `BasicType` + else if (!baseType->AsBasicType()) + expr->Type = ExpressionType::Error; + else + expr->Type = ExpressionType::Error; + if (!baseType->Equals(ExpressionType::Error.Ptr()) && + expr->Type->Equals(ExpressionType::Error.Ptr())) { - // We need to look for a module with the specified name - // (whether it has already been loaded, or needs to - // be loaded), and then put its declarations into - // the current scope. + getSink()->diagnose(expr, Diagnostics::typeHasNoPublicMemberOfName, baseType, expr->MemberName); + } + return expr; + } + SemanticsVisitor & operator = (const SemanticsVisitor &) = delete; - auto name = decl->nameToken.Content; - auto scope = decl->scope; - // Try to load a module matching the name - auto importedModuleDecl = findOrImportModule(request, name, decl->nameToken.Position); + // - // If we didn't find a matching module, then bail out - if (!importedModuleDecl) - return; + virtual RefPtr<ExpressionSyntaxNode> visitInitializerListExpr(InitializerListExpr* expr) override + { + // When faced with an initializer list, we first just check the sub-expressions blindly. + // Actually making them conform to a desired type will wait for when we know the desired + // type based on context. - // Record the module that was imported, so that we can use - // it later during code generation. - decl->importedModuleDecl = importedModuleDecl; + for( auto& arg : expr->args ) + { + arg = CheckTerm(arg); + } - // Create a new sub-scope to wire the module - // into our lookup chain. - auto subScope = new Scope(); - subScope->containerDecl = importedModuleDecl.Ptr(); + expr->Type = ExpressionType::getInitializerListType(); - subScope->nextSibling = scope->nextSibling; - scope->nextSibling = subScope; - } - }; + return expr; + } - SyntaxVisitor* CreateSemanticsVisitor( - DiagnosticSink* err, - CompileOptions const& options, - CompileRequest* request) + virtual void visitImportDecl(ImportDecl* decl) override { - return new SemanticsVisitor(err, options, request); + // We need to look for a module with the specified name + // (whether it has already been loaded, or needs to + // be loaded), and then put its declarations into + // the current scope. + + auto name = decl->nameToken.Content; + auto scope = decl->scope; + + // Try to load a module matching the name + auto importedModuleDecl = findOrImportModule(request, name, decl->nameToken.Position); + + // If we didn't find a matching module, then bail out + if (!importedModuleDecl) + return; + + // Record the module that was imported, so that we can use + // it later during code generation. + decl->importedModuleDecl = importedModuleDecl; + + // Create a new sub-scope to wire the module + // into our lookup chain. + auto subScope = new Scope(); + subScope->containerDecl = importedModuleDecl.Ptr(); + + subScope->nextSibling = scope->nextSibling; + scope->nextSibling = subScope; } + }; - // + SyntaxVisitor* CreateSemanticsVisitor( + DiagnosticSink* err, + CompileOptions const& options, + CompileRequest* request) + { + return new SemanticsVisitor(err, options, request); + } - // Get the type to use when referencing a declaration - QualType getTypeForDeclRef( - SemanticsVisitor* sema, - DiagnosticSink* sink, - DeclRef declRef, - RefPtr<ExpressionType>* outTypeResult) - { - if( sema ) - { - sema->EnsureDecl(declRef.GetDecl()); - } + // - // We need to insert an appropriate type for the expression, based on - // what we found. - if (auto varDeclRef = declRef.As<VarDeclBaseRef>()) - { - QualType qualType; - qualType.type = varDeclRef.GetType(); - qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables - return qualType; - } - else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>()) - { - auto type = new NamedExpressionType(typeAliasDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>()) - { - auto type = DeclRefType::Create(aggTypeDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>()) - { - auto type = DeclRefType::Create(simpleTypeDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto genericDeclRef = declRef.As<GenericDeclRef>()) - { - auto type = new GenericDeclRefType(genericDeclRef); - *outTypeResult = type; - return new TypeType(type); - } - else if (auto funcDeclRef = declRef.As<CallableDeclRef>()) - { - auto type = new FuncType(); - type->declRef = funcDeclRef; - return type; - } + // Get the type to use when referencing a declaration + QualType getTypeForDeclRef( + SemanticsVisitor* sema, + DiagnosticSink* sink, + DeclRef declRef, + RefPtr<ExpressionType>* outTypeResult) + { + if( sema ) + { + sema->EnsureDecl(declRef.GetDecl()); + } - if( sink ) - { - sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); - } - return ExpressionType::Error; + // We need to insert an appropriate type for the expression, based on + // what we found. + if (auto varDeclRef = declRef.As<VarDeclBaseRef>()) + { + QualType qualType; + qualType.type = varDeclRef.GetType(); + qualType.IsLeftValue = true; // TODO(tfoley): allow explicit `const` or `let` variables + return qualType; + } + else if (auto typeAliasDeclRef = declRef.As<TypeDefDeclRef>()) + { + auto type = new NamedExpressionType(typeAliasDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto aggTypeDeclRef = declRef.As<AggTypeDeclRef>()) + { + auto type = DeclRefType::Create(aggTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto simpleTypeDeclRef = declRef.As<SimpleTypeDeclRef>()) + { + auto type = DeclRefType::Create(simpleTypeDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto genericDeclRef = declRef.As<GenericDeclRef>()) + { + auto type = new GenericDeclRefType(genericDeclRef); + *outTypeResult = type; + return new TypeType(type); + } + else if (auto funcDeclRef = declRef.As<CallableDeclRef>()) + { + auto type = new FuncType(); + type->declRef = funcDeclRef; + return type; } - QualType getTypeForDeclRef( - DeclRef declRef) + if( sink ) { - RefPtr<ExpressionType> typeResult; - return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult); + sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); } + return ExpressionType::Error; + } + QualType getTypeForDeclRef( + DeclRef declRef) + { + RefPtr<ExpressionType> typeResult; + return getTypeForDeclRef(nullptr, nullptr, declRef, &typeResult); } -}
\ No newline at end of file + +} |
